Commit e6a6db32 authored by ORD's avatar ORD

Merge pull request #94 from alkor/client-cleanup

Clean up client's security policy API
parents 49fbc066 234008be
......@@ -212,13 +212,16 @@ class BinaryClient(object):
uaprotocol_auto.py and uaprotocol_hand.py
"""
def __init__(self, timeout=1, security_policy=ua.SecurityPolicy()):
def __init__(self, timeout=1):
self.logger = logging.getLogger(__name__)
self._publishcallbacks = {}
self._lock = Lock()
self._timeout = timeout
self._uasocket = None
self._security_policy = security_policy
self._security_policy = ua.SecurityPolicy()
def set_security(self, policy):
self._security_policy = policy
def connect_socket(self, host, port):
"""
......
......@@ -10,6 +10,7 @@ except ImportError: # support for python2
from opcua import uaprotocol as ua
from opcua import BinaryClient, Node, Subscription
from opcua import utils
from opcua import security_policies
use_crypto = True
try:
from opcua import uacrypto
......@@ -67,7 +68,7 @@ class Client(object):
which offers a raw OPC-UA interface.
"""
def __init__(self, url, timeout=1, security_policy=ua.SecurityPolicy()):
def __init__(self, url, timeout=1):
"""
used url argument to connect to server.
if you are unsure of url, write at least hostname and port
......@@ -82,27 +83,79 @@ class Client(object):
self.description = self.name
self.application_uri = "urn:freeopcua:client"
self.product_uri = "urn:freeopcua.github.no:client"
self.security_policy = security_policy
self.security_policy = ua.SecurityPolicy()
self.secure_channel_id = None
self.default_timeout = 3600000
self.secure_channel_timeout = self.default_timeout
self.session_timeout = self.default_timeout
self._policy_ids = []
self.bclient = BinaryClient(timeout, security_policy=security_policy)
self.server_certificate = None
self.client_certificate = None
self.private_key = None
self.bclient = BinaryClient(timeout)
self.user_certificate = None
self.user_private_key = None
self._session_counter = 1
self.keepalive = None
@staticmethod
def find_endpoint(endpoints, security_mode, policy_uri):
"""
Find endpoint with required security mode and policy URI
"""
for ep in endpoints:
if (ep.EndpointUrl.startswith(ua.OPC_TCP_SCHEME) and
ep.SecurityMode == security_mode and
ep.SecurityPolicyUri == policy_uri):
return ep
raise ValueError("No matching endpoints: {}, {}".format(
security_mode, policy_uri))
def set_security_string(self, string):
"""
Set SecureConnection mode. String format:
Policy,Mode,certificate,private_key[,server_private_key]
where Policy is Basic128Rsa15 or Basic256,
Mode is Sign or SignAndEncrypt
certificate, private_key and server_private_key are
paths to .pem or .der files
Call this before connect()
"""
if not string:
return
parts = string.split(',')
if len(parts) < 4:
raise Exception('Wrong format: `{}`, expected at least 4 '
'comma-separated values'.format(string))
policy_class = getattr(security_policies, 'SecurityPolicy' + parts[0])
mode = getattr(ua.MessageSecurityMode, parts[1])
return self.set_security(policy_class, parts[2], parts[3],
parts[4] if len(parts) >= 5 else None, mode)
def set_security(self, policy, certificate_path, private_key_path,
server_certificate_path=None,
mode=ua.MessageSecurityMode.SignAndEncrypt):
"""
Set SecureConnection mode.
Call this before connect()
"""
if server_certificate_path is None:
# load certificate from server's list of endpoints
endpoints = self.connect_and_get_server_endpoints()
endpoint = Client.find_endpoint(endpoints, mode, policy.URI)
server_cert = uacrypto.x509_from_der(endpoint.ServerCertificate)
else:
server_cert = uacrypto.load_certificate(server_certificate_path)
cert = uacrypto.load_certificate(certificate_path)
pk = uacrypto.load_private_key(private_key_path)
self.security_policy = policy(server_cert, cert, pk, mode)
self.bclient.set_security(self.security_policy)
def load_client_certificate(self, path):
"""
load our certificate from file, either pem or der
"""
self.client_certificate = uacrypto.load_certificate(path)
self.user_certificate = uacrypto.load_certificate(path)
def load_private_key(self, path):
self.private_key = uacrypto.load_private_key(path)
self.user_private_key = uacrypto.load_private_key(path)
def connect_and_get_server_endpoints(self):
"""
......@@ -149,7 +202,7 @@ class Client(object):
self.send_hello()
self.open_secure_channel()
self.create_session()
self.activate_session(username=self.server_url.username, password=self.server_url.password, certificate=self.client_certificate)
self.activate_session(username=self.server_url.username, password=self.server_url.password, certificate=self.user_certificate)
def disconnect(self):
"""
......@@ -257,11 +310,13 @@ class Client(object):
response = self.bclient.create_session(params)
self.security_policy.asymmetric_cryptography.verify(self.security_policy.client_certificate + nonce, response.ServerSignature.Signature)
self._server_nonce = response.ServerNonce
self.server_certificate = response.ServerCertificate
for ep in response.ServerEndpoints:
if urlparse(ep.EndpointUrl).scheme == self.server_url.scheme and ep.SecurityMode == self.security_policy.Mode:
# remember PolicyId's: we will use them in activate_session()
self._policy_ids = ep.UserIdentityTokens
if not self.security_policy.server_certificate:
self.security_policy.server_certificate = response.ServerCertificate
elif self.security_policy.server_certificate != response.ServerCertificate:
raise Exception("Server certificate mismatch")
# remember PolicyId's: we will use them in activate_session()
ep = Client.find_endpoint(response.ServerEndpoints, self.security_policy.Mode, self.security_policy.URI)
self._policy_ids = ep.UserIdentityTokens
self.session_timeout = response.RevisedSessionTimeout
self.keepalive = KeepAlive(self, min(self.session_timeout, self.secure_channel_timeout) * 0.7) # 0.7 is from spec
self.keepalive.start()
......@@ -292,8 +347,8 @@ class Client(object):
elif certificate:
params.UserIdentityToken = ua.X509IdentityToken()
params.UserIdentityToken.PolicyId = self.server_policy_id(ua.UserTokenType.Certificate, b"certificate_basic256")
params.UserIdentityToken.CertificateData = uacrypto.der_from_x509(self.client_certificate)
sig = uacrypto.sign_sha1(self.private_key, certificate)
params.UserIdentityToken.CertificateData = uacrypto.der_from_x509(certificate)
sig = uacrypto.sign_sha1(self.user_private_key, certificate)
params.UserTokenSignature = ua.SignatureData()
params.UserTokenSignature.Algorithm = b"http://www.w3.org/2000/09/xmldsig#rsa-sha1"
params.UserTokenSignature.Signature = sig
......@@ -301,7 +356,7 @@ class Client(object):
params.UserIdentityToken = ua.UserNameIdentityToken()
params.UserIdentityToken.UserName = username
if self.server_url.password:
pubkey = self.server_certificate.publick_key()
pubkey = uacrypto.x509_from_der(self.security_policy.server_certificate).public_key()
data = uacrypto.encrypt_basic256(pubkey, bytes(password, "utf8"))
params.UserIdentityToken.Password = data
params.UserIdentityToken.PolicyId = self.server_policy_id(ua.UserTokenType.UserName, b"username_basic256")
......
This diff is collapsed.
......@@ -64,26 +64,6 @@ def add_common_args(parser, default_node='i=84'):
default='')
def client_security(security, url, timeout):
parts = security.split(',')
if len(parts) < 4:
raise Exception('Wrong format: `{}`, expected at least 4 comma-separated values'.format(security))
policy_class = getattr(security_policies, 'SecurityPolicy' + parts[0])
mode = getattr(ua.MessageSecurityMode, parts[1])
cert = open(parts[2], 'rb').read()
pk = open(parts[3], 'rb').read()
server_cert = None
if len(parts) == 5:
server_cert = open(parts[4], 'rb').read()
else:
# we need server's certificate too. Let's get it from the list of endpoints
client = Client(url, timeout=timeout)
for ep in client.connect_and_get_server_endpoints():
if ep.EndpointUrl.startswith(ua.OPC_TCP_SCHEME) and ep.SecurityMode == mode and ep.SecurityPolicyUri == policy_class.URI:
server_cert = ep.ServerCertificate
return policy_class(server_cert, cert, pk, mode)
def _require_nodeid(parser, args):
# check that a nodeid has been given explicitly, a bit hackish...
if args.nodeid == "i=84" and args.path == "":
......@@ -98,11 +78,6 @@ def parse_args(parser, requirenodeid=False):
if args.url and '://' not in args.url:
logging.info("Adding default scheme %s to URL %s", ua.OPC_TCP_SCHEME, args.url)
args.url = ua.OPC_TCP_SCHEME + '://' + args.url
if hasattr(args, 'security'):
if args.security:
args.security = client_security(args.security, args.url, args.timeout)
else:
args.security = ua.SecurityPolicy()
if requirenodeid:
_require_nodeid(parser, args)
return args
......@@ -133,7 +108,8 @@ def uaread():
args = parse_args(parser, requirenodeid=True)
client = Client(args.url, timeout=args.timeout, security_policy=args.security)
client = Client(args.url, timeout=args.timeout)
client.set_security_string(args.security)
client.connect()
try:
node = get_node(client, args)
......@@ -269,7 +245,8 @@ def uawrite():
metavar="VALUE")
args = parse_args(parser, requirenodeid=True)
client = Client(args.url, timeout=args.timeout, security_policy=args.security)
client = Client(args.url, timeout=args.timeout)
client.set_security_string(args.security)
client.connect()
try:
node = get_node(client, args)
......@@ -300,7 +277,8 @@ def uals():
if args.long_format is None:
args.long_format = 1
client = Client(args.url, timeout=args.timeout, security_policy=args.security)
client = Client(args.url, timeout=args.timeout)
client.set_security_string(args.security)
client.connect()
try:
node = get_node(client, args)
......@@ -391,7 +369,8 @@ def uasubscribe():
if args.nodeid == "i=84" and args.path == "":
args.nodeid = "i=2253"
client = Client(args.url, timeout=args.timeout, security_policy=args.security)
client = Client(args.url, timeout=args.timeout)
client.set_security_string(args.security)
client.connect()
try:
node = get_node(client, args)
......@@ -472,7 +451,8 @@ def uaclient():
help="set client private key")
args = parse_args(parser)
client = Client(args.url, timeout=args.timeout, security_policy=args.security)
client = Client(args.url, timeout=args.timeout)
client.set_security_string(args.security)
client.connect()
if args.certificate:
client.load_client_certificate(args.certificate)
......@@ -633,7 +613,8 @@ def uahistoryread():
args = parse_args(parser, requirenodeid=True)
client = Client(args.url, timeout=args.timeout, security_policy=args.security)
client = Client(args.url, timeout=args.timeout)
client.set_security_string(args.security)
client.connect()
try:
node = get_node(client, args)
......
import os
from cryptography import x509
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import hmac
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.ciphers import Cipher
from cryptography.hazmat.primitives.ciphers import algorithms
from cryptography.hazmat.primitives.ciphers import modes
def load_certificate(path):
......@@ -22,9 +27,19 @@ def x509_from_der(data):
return x509.load_der_x509_certificate(data, default_backend())
def x509_to_der(cert):
if not data:
return b''
return cert.public_bytes(serialization.Encoding.DER)
def load_private_key(path):
_, ext = os.path.splitext(path)
with open(path, "br") as f:
return serialization.load_pem_private_key(f.read(), password=None, backend=default_backend())
if ext == ".pem":
return serialization.load_pem_private_key(f.read(), password=None, backend=default_backend())
else:
return serialization.load_der_private_key(f.read(), password=None, backend=default_backend())
def der_from_x509(certificate):
......@@ -42,6 +57,15 @@ def sign_sha1(private_key, data):
return signer.finalize()
def verify_sha1(certificate, data, signature):
verifier = certificate.public_key().verifier(
signature,
padding.PKCS1v15(),
hashes.SHA1())
verifier.update(data)
verifier.verify()
def encrypt_basic256(public_key, data):
ciphertext = public_key.encrypt(
data,
......@@ -53,6 +77,90 @@ def encrypt_basic256(public_key, data):
return ciphertext
def encrypt_rsa_oaep(public_key, data):
ciphertext = public_key.encrypt(
data,
padding.OAEP(
mgf=padding.MGF1(algorithm=hashes.SHA1()),
algorithm=hashes.SHA1(),
label=None)
)
return ciphertext
def encrypt_rsa15(public_key, data):
ciphertext = public_key.encrypt(
data,
padding.PKCS1v15()
)
return ciphertext
def decrypt_rsa_oaep(private_key, data):
text = private_key.decrypt(
data,
padding.OAEP(
mgf=padding.MGF1(algorithm=hashes.SHA1()),
algorithm=hashes.SHA1(),
label=None)
)
return text
def decrypt_rsa15(private_key, data):
text = private_key.decrypt(
data,
padding.PKCS1v15()
)
return text
def cipher_aes_cbc(key, init_vec):
return Cipher(algorithms.AES(key), modes.CBC(init_vec), default_backend())
def cipher_encrypt(cipher, data):
encryptor = cipher.encryptor()
return encryptor.update(data) + encryptor.finalize()
def cipher_decrypt(cipher, data):
decryptor = cipher.decryptor()
return decryptor.update(data) + decryptor.finalize()
def hash_hmac(key, message):
hasher = hmac.HMAC(key, hashes.SHA1(), backend=default_backend())
hasher.update(message)
return hasher.finalize()
def sha1_size():
return hashes.SHA1.digest_size
def p_sha1(key, body, sizes=()):
"""
Derive one or more keys from key and body.
Lengths of keys will match sizes argument
"""
full_size = 0
for size in sizes:
full_size += size
result = b''
accum = body
while len(result) < full_size:
accum = hash_hmac(key, accum)
result += hash_hmac(key, accum + body)
parts = []
for size in sizes:
parts.append(result[:size])
result = result[size:]
return tuple(parts)
if __name__ == "__main__":
# Convert from PEM to DER
cert = load_certificate("../examples/server_cert.pem")
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment