Commit ff6969d2 authored by Christoph Ziebuhr's avatar Christoph Ziebuhr Committed by oroulet

Respect EndpointUrl request parameter in GetEndpoints, FindServers and CreateSession

This is especially handy when the server runs behind NAT or inside docker container,
but a static EndpointUrl cannot be used because the host is multihomed.
parent b39672b9
......@@ -72,6 +72,7 @@ class InternalServer:
self.current_time_node = Node(self.isession, ua.NodeId(ua.ObjectIds.Server_ServerStatus_CurrentTime))
self.time_task = None
self._time_task_stop = False
self.match_discovery_endpoint_url: bool = True
self.match_discovery_source_ip: bool = True
self.supported_tokens = []
......@@ -200,32 +201,52 @@ class InternalServer:
def add_endpoint(self, endpoint):
self.endpoints.append(endpoint)
def _mangle_endpoint_url(self, ep_url, params_ep_url=None, sockname=None):
url = urlparse(ep_url)
if self.match_discovery_endpoint_url and params_ep_url:
try:
netloc = urlparse(params_ep_url).netloc
except ValueError:
netloc = ''
if netloc:
return url._replace(netloc=netloc).geturl()
if self.match_discovery_source_ip and sockname:
return url._replace(netloc=sockname[0] + ':' + str(sockname[1])).geturl()
return url.geturl()
async def get_endpoints(self, params=None, sockname=None):
self.logger.info('get endpoint')
if sockname:
# return to client the ip address it has access to
edps = []
for edp in self.endpoints:
edp1 = copy(edp)
url = urlparse(edp1.EndpointUrl)
if self.match_discovery_source_ip:
url = url._replace(netloc=sockname[0] + ':' + str(sockname[1]))
edp1.EndpointUrl = url.geturl()
edps.append(edp1)
return edps
return self.endpoints[:]
def find_servers(self, params):
if not params.ServerUris:
return [desc.Server for desc in self._known_servers.values()]
edps = []
params_ep_url = params.EndpointUrl if params else None
for edp in self.endpoints:
edp = copy(edp)
edp.EndpointUrl = self._mangle_endpoint_url(edp.EndpointUrl, params_ep_url=params_ep_url, sockname=sockname)
edp.Server = copy(edp.Server)
edp.Server.DiscoveryUrls = [
self._mangle_endpoint_url(url, params_ep_url=params_ep_url, sockname=sockname)
for url in edp.Server.DiscoveryUrls
]
edps.append(edp)
return edps
def find_servers(self, params, sockname=None):
servers = []
for serv in self._known_servers.values():
serv_uri = serv.Server.ApplicationUri.split(':')
for uri in params.ServerUris:
uri = uri.split(':')
if serv_uri[: len(uri)] == uri:
servers.append(serv.Server)
break
params_server_uris = [uri.split(':') for uri in params.ServerUris]
our_application_uris = [edp.Server.ApplicationUri for edp in self.endpoints]
for desc in self._known_servers.values():
if params_server_uris:
serv_uri = desc.Server.ApplicationUri.split(':')
if not any(serv_uri[: len(uri)] == uri for uri in params_server_uris):
continue
if desc.Server.ApplicationUri in our_application_uris:
serv = copy(desc.Server)
serv.DiscoveryUrls = [
self._mangle_endpoint_url(url, params_ep_url=params.EndpointUrl, sockname=sockname)
for url in serv.DiscoveryUrls
]
else:
serv = desc.Server
servers.append(serv)
return servers
def register_server(self, server, conf=None):
......
......@@ -64,7 +64,10 @@ class InternalSession(AbstractSession):
result.MaxRequestMessageSize = 65536
self.nonce = create_nonce(32)
result.ServerNonce = self.nonce
result.ServerEndpoints = await self.get_endpoints(sockname=sockname)
ep_params = ua.GetEndpointsParameters()
ep_params.EndpointUrl = params.EndpointUrl
result.ServerEndpoints = await self.get_endpoints(params=ep_params, sockname=sockname)
return result
......
......@@ -128,6 +128,15 @@ class Server:
await self.set_build_info(self.product_uri, self.manufacturer_name, self.name, "1.0pre", "0", datetime.now())
def set_match_discovery_endpoint_url(self, match_discovery_endpoint_url: bool):
"""
Enables or disables the matching of the EndpointUrl request parameter during discovery.
When True (default), the host/port of endpoints sent during the discovery is modified to the host/port
which is specified in the EndpointUrl request parameter.
"""
self.iserver.match_discovery_endpoint_url = match_discovery_endpoint_url
def set_match_discovery_client_ip(self, match_discovery_client_ip: bool):
"""
Enables or disables the matching of an endpoint IP to a client IP during discovery.
......
......@@ -233,7 +233,7 @@ class UaProcessor:
elif typeid == ua.NodeId(ua.ObjectIds.FindServersRequest_Encoding_DefaultBinary):
_logger.info("find servers request (%s)", user)
params = struct_from_binary(ua.FindServersParameters, body)
servers = self.iserver.find_servers(params)
servers = self.iserver.find_servers(params, sockname=self.sockname)
response = ua.FindServersResponse()
response.Servers = servers
# _logger.info("sending find servers response")
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment