Commit b7a7f421 authored by Christian Bergmiller's avatar Christian Bergmiller

test refactoring wip

parent fc60e03c
...@@ -523,8 +523,8 @@ class Client(object): ...@@ -523,8 +523,8 @@ class Client(object):
uries = await self.get_namespace_array() uries = await self.get_namespace_array()
return uries.index(uri) return uries.index(uri)
def delete_nodes(self, nodes, recursive=False): async def delete_nodes(self, nodes, recursive=False):
return delete_nodes(self.uaclient, nodes, recursive) return await delete_nodes(self.uaclient, nodes, recursive)
def import_xml(self, path=None, xmlstring=None): def import_xml(self, path=None, xmlstring=None):
""" """
......
...@@ -103,7 +103,7 @@ class UASocketProtocol(asyncio.Protocol): ...@@ -103,7 +103,7 @@ class UASocketProtocol(asyncio.Protocol):
self.transport.write(msg) self.transport.write(msg)
return future return future
async def send_request(self, request, callback=None, timeout=1000, message_type=ua.MessageType.SecureMessage): async def send_request(self, request, callback=None, timeout=10, message_type=ua.MessageType.SecureMessage):
""" """
Send a request to the server. Send a request to the server.
Timeout is the timeout written in ua header. Timeout is the timeout written in ua header.
......
...@@ -14,7 +14,7 @@ async def copy_node(parent, node, nodeid=None, recursive=True): ...@@ -14,7 +14,7 @@ async def copy_node(parent, node, nodeid=None, recursive=True):
rdesc = await _rdesc_from_node(parent, node) rdesc = await _rdesc_from_node(parent, node)
if nodeid is None: if nodeid is None:
nodeid = ua.NodeId(namespaceidx=node.nodeid.NamespaceIndex) nodeid = ua.NodeId(namespaceidx=node.nodeid.NamespaceIndex)
added_nodeids = _copy_node(parent.server, parent.nodeid, rdesc, nodeid, recursive) added_nodeids = await _copy_node(parent.server, parent.nodeid, rdesc, nodeid, recursive)
return [Node(parent.server, nid) for nid in added_nodeids] return [Node(parent.server, nid) for nid in added_nodeids]
...@@ -29,12 +29,12 @@ async def _copy_node(server, parent_nodeid, rdesc, nodeid, recursive): ...@@ -29,12 +29,12 @@ async def _copy_node(server, parent_nodeid, rdesc, nodeid, recursive):
node_to_copy = Node(server, rdesc.NodeId) node_to_copy = Node(server, rdesc.NodeId)
attr_obj = getattr(ua, rdesc.NodeClass.name + "Attributes") attr_obj = getattr(ua, rdesc.NodeClass.name + "Attributes")
await _read_and_copy_attrs(node_to_copy, attr_obj(), addnode) await _read_and_copy_attrs(node_to_copy, attr_obj(), addnode)
res = await server.add_nodes([addnode])[0] res = (await server.add_nodes([addnode]))[0]
added_nodes = [res.AddedNodeId] added_nodes = [res.AddedNodeId]
if recursive: if recursive:
descs = await node_to_copy.get_children_descriptions() descs = await node_to_copy.get_children_descriptions()
for desc in descs: for desc in descs:
nodes = _copy_node(server, res.AddedNodeId, desc, nodeid=ua.NodeId(namespaceidx=desc.NodeId.NamespaceIndex), recursive=True) nodes = await _copy_node(server, res.AddedNodeId, desc, nodeid=ua.NodeId(namespaceidx=desc.NodeId.NamespaceIndex), recursive=True)
added_nodes.extend(nodes) added_nodes.extend(nodes)
return added_nodes return added_nodes
......
...@@ -406,14 +406,14 @@ def _guess_datatype(variant): ...@@ -406,14 +406,14 @@ def _guess_datatype(variant):
return ua.NodeId(getattr(ua.ObjectIds, variant.VariantType.name)) return ua.NodeId(getattr(ua.ObjectIds, variant.VariantType.name))
def delete_nodes(server, nodes, recursive=False, delete_target_references=True): async def delete_nodes(server, nodes, recursive=False, delete_target_references=True):
""" """
Delete specified nodes. Optionally delete recursively all nodes with a Delete specified nodes. Optionally delete recursively all nodes with a
downward hierachic references to the node downward hierachic references to the node
""" """
nodestodelete = [] nodestodelete = []
if recursive: if recursive:
nodes += _add_childs(nodes) nodes += await _add_childs(nodes)
for mynode in nodes: for mynode in nodes:
it = ua.DeleteNodesItem() it = ua.DeleteNodesItem()
it.NodeId = mynode.nodeid it.NodeId = mynode.nodeid
...@@ -421,11 +421,11 @@ def delete_nodes(server, nodes, recursive=False, delete_target_references=True): ...@@ -421,11 +421,11 @@ def delete_nodes(server, nodes, recursive=False, delete_target_references=True):
nodestodelete.append(it) nodestodelete.append(it)
params = ua.DeleteNodesParameters() params = ua.DeleteNodesParameters()
params.NodesToDelete = nodestodelete params.NodesToDelete = nodestodelete
return server.delete_nodes(params) return await server.delete_nodes(params)
def _add_childs(nodes): async def _add_childs(nodes):
results = [] results = []
for mynode in nodes[:]: for mynode in nodes[:]:
results += mynode.get_children() results += await mynode.get_children()
return results return results
...@@ -6,7 +6,7 @@ from opcua import ua ...@@ -6,7 +6,7 @@ from opcua import ua
from opcua.common import node from opcua.common import node
def call_method(parent, methodid, *args): async def call_method(parent, methodid, *args):
""" """
Call an OPC-UA method. methodid is browse name of child method or the Call an OPC-UA method. methodid is browse name of child method or the
nodeid of method as a NodeId object nodeid of method as a NodeId object
...@@ -14,7 +14,7 @@ def call_method(parent, methodid, *args): ...@@ -14,7 +14,7 @@ def call_method(parent, methodid, *args):
which may be of different types which may be of different types
returns a list of values or a single value depending on the output of the method returns a list of values or a single value depending on the output of the method
""" """
result = call_method_full(parent, methodid, *args) result = await call_method_full(parent, methodid, *args)
if len(result.OutputArguments) == 0: if len(result.OutputArguments) == 0:
return None return None
...@@ -24,7 +24,7 @@ def call_method(parent, methodid, *args): ...@@ -24,7 +24,7 @@ def call_method(parent, methodid, *args):
return result.OutputArguments return result.OutputArguments
def call_method_full(parent, methodid, *args): async def call_method_full(parent, methodid, *args):
""" """
Call an OPC-UA method. methodid is browse name of child method or the Call an OPC-UA method. methodid is browse name of child method or the
nodeid of method as a NodeId object nodeid of method as a NodeId object
...@@ -33,7 +33,7 @@ def call_method_full(parent, methodid, *args): ...@@ -33,7 +33,7 @@ def call_method_full(parent, methodid, *args):
returns a CallMethodResult object with converted OutputArguments returns a CallMethodResult object with converted OutputArguments
""" """
if isinstance(methodid, (str, ua.uatypes.QualifiedName)): if isinstance(methodid, (str, ua.uatypes.QualifiedName)):
methodid = parent.get_child(methodid).nodeid methodid = (await parent.get_child(methodid)).nodeid
elif isinstance(methodid, node.Node): elif isinstance(methodid, node.Node):
methodid = methodid.nodeid methodid = methodid.nodeid
......
...@@ -405,15 +405,11 @@ class Node: ...@@ -405,15 +405,11 @@ class Node:
Since address space may have circular references, a max length is specified Since address space may have circular references, a max length is specified
""" """
path = [] path = await self._get_path(max_length)
for ref in await self._get_path(max_length): path = [Node(self.server, ref.NodeId) for ref in path]
path.append(Node(self.server, ref.NodeId))
path.append(self) path.append(self)
if as_string: if as_string:
str_path = [] path = [(await el.get_browse_name()).to_string() for el in path]
for el in path:
name = await el.get_browse_name()
str_path.append(name.to_string())
return path return path
async def _get_path(self, max_length=20): async def _get_path(self, max_length=20):
...@@ -598,7 +594,7 @@ class Node: ...@@ -598,7 +594,7 @@ class Node:
else: else:
raise ua.UaStatusCodeError(ua.StatusCodes.BadNotFound) raise ua.UaStatusCodeError(ua.StatusCodes.BadNotFound)
ditem = self._fill_delete_reference_item(rdesc, bidirectional) ditem = self._fill_delete_reference_item(rdesc, bidirectional)
await self.server.delete_references([ditem])[0].check() (await self.server.delete_references([ditem]))[0].check()
async def add_reference(self, target, reftype, forward=True, bidirectional=True): async def add_reference(self, target, reftype, forward=True, bidirectional=True):
""" """
...@@ -660,6 +656,7 @@ class Node: ...@@ -660,6 +656,7 @@ class Node:
return opcua.common.manage_nodes.create_method(self, *args) return opcua.common.manage_nodes.create_method(self, *args)
def add_reference_type(self, nodeid, bname, symmetric=True, inversename=None): def add_reference_type(self, nodeid, bname, symmetric=True, inversename=None):
"""COROUTINE"""
return opcua.common.manage_nodes.create_reference_type(self, nodeid, bname, symmetric, inversename) return opcua.common.manage_nodes.create_reference_type(self, nodeid, bname, symmetric, inversename)
def call_method(self, methodid, *args): def call_method(self, methodid, *args):
......
...@@ -141,7 +141,7 @@ class InternalServer(object): ...@@ -141,7 +141,7 @@ class InternalServer(object):
def add_endpoint(self, endpoint): def add_endpoint(self, endpoint):
self.endpoints.append(endpoint) self.endpoints.append(endpoint)
def get_endpoints(self, params=None, sockname=None): async def get_endpoints(self, params=None, sockname=None):
self.logger.info('get endpoint') self.logger.info('get endpoint')
if sockname: if sockname:
# return to client the ip address it has access to # return to client the ip address it has access to
...@@ -261,10 +261,10 @@ class InternalSession(object): ...@@ -261,10 +261,10 @@ class InternalSession(object):
return 'InternalSession(name:{0}, user:{1}, id:{2}, auth_token:{3})'.format( return 'InternalSession(name:{0}, user:{1}, id:{2}, auth_token:{3})'.format(
self.name, self.user, self.session_id, self.authentication_token) self.name, self.user, self.session_id, self.authentication_token)
def get_endpoints(self, params=None, sockname=None): async def get_endpoints(self, params=None, sockname=None):
return self.iserver.get_endpoints(params, sockname) return await self.iserver.get_endpoints(params, sockname)
def create_session(self, params, sockname=None): async def create_session(self, params, sockname=None):
self.logger.info('Create session request') self.logger.info('Create session request')
result = ua.CreateSessionResult() result = ua.CreateSessionResult()
...@@ -274,7 +274,7 @@ class InternalSession(object): ...@@ -274,7 +274,7 @@ class InternalSession(object):
result.MaxRequestMessageSize = 65536 result.MaxRequestMessageSize = 65536
self.nonce = utils.create_nonce(32) self.nonce = utils.create_nonce(32)
result.ServerNonce = self.nonce result.ServerNonce = self.nonce
result.ServerEndpoints = self.get_endpoints(sockname=sockname) result.ServerEndpoints = await self.get_endpoints(sockname=sockname)
return result return result
......
...@@ -23,13 +23,9 @@ from opcua.common.structures import load_type_definitions ...@@ -23,13 +23,9 @@ from opcua.common.structures import load_type_definitions
from opcua.common.xmlexporter import XmlExporter from opcua.common.xmlexporter import XmlExporter
from opcua.common.xmlimporter import XmlImporter from opcua.common.xmlimporter import XmlImporter
from opcua.common.ua_utils import get_nodes_of_namespace from opcua.common.ua_utils import get_nodes_of_namespace
from opcua.crypto import uacrypto
use_crypto = True _logger = logging.getLogger(__name__)
try:
from opcua.crypto import uacrypto
except ImportError:
logging.getLogger(__name__).warning("cryptography is not installed, use of crypto disabled")
use_crypto = False
class Server: class Server:
...@@ -159,7 +155,7 @@ class Server: ...@@ -159,7 +155,7 @@ class Server:
uries.append(uri) uries.append(uri)
await ns_node.set_value(uries) await ns_node.set_value(uries)
def find_servers(self, uris=None): async def find_servers(self, uris=None):
""" """
find_servers. mainly implemented for symmetry with client find_servers. mainly implemented for symmetry with client
""" """
...@@ -211,8 +207,8 @@ class Server: ...@@ -211,8 +207,8 @@ class Server:
def set_endpoint(self, url): def set_endpoint(self, url):
self.endpoint = urlparse(url) self.endpoint = urlparse(url)
def get_endpoints(self): async def get_endpoints(self):
return self.iserver.get_endpoints() return await self.iserver.get_endpoints()
def set_security_policy(self, security_policy): def set_security_policy(self, security_policy):
""" """
...@@ -545,8 +541,8 @@ class Server: ...@@ -545,8 +541,8 @@ class Server:
nodes = await get_nodes_of_namespace(self, namespaces) nodes = await get_nodes_of_namespace(self, namespaces)
self.export_xml(nodes, path) self.export_xml(nodes, path)
def delete_nodes(self, nodes, recursive=False): async def delete_nodes(self, nodes, recursive=False):
return delete_nodes(self.iserver.isession, nodes, recursive) return await delete_nodes(self.iserver.isession, nodes, recursive)
async def historize_node_data_change(self, node, period=timedelta(days=7), count=0): async def historize_node_data_change(self, node, period=timedelta(days=7), count=0):
""" """
......
...@@ -4,13 +4,14 @@ from threading import RLock, Lock ...@@ -4,13 +4,14 @@ from threading import RLock, Lock
import time import time
from opcua import ua from opcua import ua
from opcua.server.internal_server import InternalServer, InternalSession
from opcua.ua.ua_binary import nodeid_from_binary, struct_from_binary from opcua.ua.ua_binary import nodeid_from_binary, struct_from_binary
from opcua.ua.ua_binary import struct_to_binary, uatcp_to_binary from opcua.ua.ua_binary import struct_to_binary, uatcp_to_binary
from opcua.common import utils from opcua.common import utils
from opcua.common.connection import SecureConnection from opcua.common.connection import SecureConnection
class PublishRequestData(object): class PublishRequestData:
def __init__(self): def __init__(self):
self.requesthdr = None self.requesthdr = None
...@@ -19,14 +20,14 @@ class PublishRequestData(object): ...@@ -19,14 +20,14 @@ class PublishRequestData(object):
self.timestamp = time.time() self.timestamp = time.time()
class UaProcessor(object): class UaProcessor:
def __init__(self, internal_server, socket): def __init__(self, internal_server: InternalServer, socket):
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
self.iserver = internal_server self.iserver: InternalServer = internal_server
self.name = socket.get_extra_info('peername') self.name = socket.get_extra_info('peername')
self.sockname = socket.get_extra_info('sockname') self.sockname = socket.get_extra_info('sockname')
self.session = None self.session: InternalSession = None
self.socket = socket self.socket = socket
self._datalock = RLock() self._datalock = RLock()
self._publishdata_queue = [] self._publishdata_queue = []
...@@ -119,7 +120,7 @@ class UaProcessor(object): ...@@ -119,7 +120,7 @@ class UaProcessor(object):
# create the session on server # create the session on server
self.session = self.iserver.create_session(self.name, external=True) self.session = self.iserver.create_session(self.name, external=True)
# get a session creation result to send back # get a session creation result to send back
sessiondata = self.session.create_session(params, sockname=self.sockname) sessiondata = await self.session.create_session(params, sockname=self.sockname)
response = ua.CreateSessionResponse() response = ua.CreateSessionResponse()
response.Parameters = sessiondata response.Parameters = sessiondata
response.Parameters.ServerCertificate = self._connection.security_policy.client_certificate response.Parameters.ServerCertificate = self._connection.security_policy.client_certificate
...@@ -188,7 +189,7 @@ class UaProcessor(object): ...@@ -188,7 +189,7 @@ class UaProcessor(object):
elif typeid == ua.NodeId(ua.ObjectIds.GetEndpointsRequest_Encoding_DefaultBinary): elif typeid == ua.NodeId(ua.ObjectIds.GetEndpointsRequest_Encoding_DefaultBinary):
self.logger.info("get endpoints request") self.logger.info("get endpoints request")
params = struct_from_binary(ua.GetEndpointsParameters, body) params = struct_from_binary(ua.GetEndpointsParameters, body)
endpoints = self.iserver.get_endpoints(params, sockname=self.sockname) endpoints = await self.iserver.get_endpoints(params, sockname=self.sockname)
response = ua.GetEndpointsResponse() response = ua.GetEndpointsResponse()
response.Endpoints = endpoints response.Endpoints = endpoints
self.logger.info("sending get endpoints response") self.logger.info("sending get endpoints response")
......
...@@ -22,7 +22,7 @@ async def opc(request): ...@@ -22,7 +22,7 @@ async def opc(request):
await srv.start() await srv.start()
# start client # start client
# long timeout since travis (automated testing) can be really slow # long timeout since travis (automated testing) can be really slow
clt = Client(f'opc.tcp://127.0.0.1:{port_num}', timeout=10) clt = Client(f'opc.tcp://admin@127.0.0.1:{port_num}', timeout=10)
await clt.connect() await clt.connect()
yield clt yield clt
await clt.disconnect() await clt.disconnect()
......
...@@ -7,7 +7,7 @@ from opcua import Server ...@@ -7,7 +7,7 @@ from opcua import Server
from opcua import ua from opcua import ua
from .test_common import add_server_methods from .test_common import add_server_methods
from .tests_enum_struct import add_server_custom_enum_struct from .util_enum_struct import add_server_custom_enum_struct
port_num1 = 48510 port_num1 = 48510
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
......
This diff is collapsed.
import unittest import pytest
from opcua import Client from opcua import Client
from opcua import Server from opcua import Server
...@@ -13,10 +13,110 @@ except ImportError: ...@@ -13,10 +13,110 @@ except ImportError:
else: else:
disable_crypto_tests = False disable_crypto_tests = False
pytestmark = pytest.mark.asyncio
port_num1 = 48515 port_num1 = 48515
port_num2 = 48512 port_num2 = 48512
uri_crypto = 'opc.tcp://127.0.0.1:{0:d}'.format(port_num1)
uri_no_crypto = 'opc.tcp://127.0.0.1:{0:d}'.format(port_num2)
@pytest.fixture()
async def srv_crypto():
# start our own server
srv = Server()
await srv.init()
srv.set_endpoint(uri_crypto)
await srv.load_certificate("examples/certificate-example.der")
await srv.load_private_key("examples/private-key-example.pem")
await srv.start()
yield srv
# stop the server
await srv.stop()
@pytest.fixture()
async def srv_no_crypto():
# start our own server
srv = Server()
await srv.init()
srv.set_endpoint(uri_no_crypto)
await srv.start()
yield srv
# stop the server
await srv.stop()
async def test_nocrypto(srv_no_crypto):
clt = Client(uri_no_crypto)
async with clt:
await clt.get_objects_node().get_children()
async def test_nocrypto_fail():
clt = Client(uri_no_crypto)
with pytest.raises(ua.UaError):
await clt.set_security_string("Basic256,Sign,examples/certificate-example.der,examples/private-key-example.pem")
async def test_basic256(srv_crypto):
clt = Client(uri_crypto)
await clt.set_security_string("Basic256,Sign,examples/certificate-example.der,examples/private-key-example.pem")
async with clt:
assert await clt.get_objects_node().get_children()
async def test_basic256_encrypt():
clt = Client(uri_crypto)
await clt.set_security_string(
"Basic256,SignAndEncrypt,examples/certificate-example.der,examples/private-key-example.pem")
async with clt:
assert await clt.get_objects_node().get_children()
async def test_basic128Rsa15():
clt = Client(uri_crypto)
await clt.set_security_string("Basic128Rsa15,Sign,examples/certificate-example.der,examples/private-key-example.pem")
async with clt:
assert await clt.get_objects_node().get_children()
async def test_basic128Rsa15_encrypt():
clt = Client(uri_crypto)
await clt.set_security_string(
"Basic128Rsa15,SignAndEncrypt,examples/certificate-example.der,examples/private-key-example.pem"
)
async with clt:
assert await clt.get_objects_node().get_children()
async def test_basic256_encrypt_success():
clt = Client(uri_crypto)
await clt.set_security(
security_policies.SecurityPolicyBasic256,
'examples/certificate-example.der',
'examples/private-key-example.pem',
None,
ua.MessageSecurityMode.SignAndEncrypt
)
async with clt:
assert await clt.get_objects_node().get_children()
async def test_basic256_encrypt_feil():
# FIXME: how to make it feil???
clt = Client(uri_crypto)
with pytest.raises(ua.UaError):
await clt.set_security(
security_policies.SecurityPolicyBasic256,
'examples/certificate-example.der',
'examples/private-key-example.pem',
None,
ua.MessageSecurityMode.None_
)
"""
@unittest.skipIf(disable_crypto_tests, "crypto not available") @unittest.skipIf(disable_crypto_tests, "crypto not available")
class TestCryptoConnect(unittest.TestCase): class TestCryptoConnect(unittest.TestCase):
...@@ -48,27 +148,6 @@ class TestCryptoConnect(unittest.TestCase): ...@@ -48,27 +148,6 @@ class TestCryptoConnect(unittest.TestCase):
cls.srv_no_crypto.stop() cls.srv_no_crypto.stop()
cls.srv_crypto.stop() cls.srv_crypto.stop()
def test_nocrypto(self):
clt = Client(self.uri_no_crypto)
clt.connect()
try:
clt.get_objects_node().get_children()
finally:
clt.disconnect()
def test_nocrypto_feil(self):
clt = Client(self.uri_no_crypto)
with self.assertRaises(ua.UaError):
clt.set_security_string("Basic256,Sign,examples/certificate-example.der,examples/private-key-example.pem")
def test_basic256(self):
clt = Client(self.uri_crypto)
try:
clt.set_security_string("Basic256,Sign,examples/certificate-example.der,examples/private-key-example.pem")
clt.connect()
self.assertTrue(clt.get_objects_node().get_children())
finally:
clt.disconnect()
def test_basic256_encrypt(self): def test_basic256_encrypt(self):
clt = Client(self.uri_crypto) clt = Client(self.uri_crypto)
...@@ -121,3 +200,4 @@ class TestCryptoConnect(unittest.TestCase): ...@@ -121,3 +200,4 @@ class TestCryptoConnect(unittest.TestCase):
None, None,
ua.MessageSecurityMode.None_ ua.MessageSecurityMode.None_
) )
"""
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment