Commit 05c9f108 authored by Jérome Perrin's avatar Jérome Perrin

util: use safe variant or xml_marshaller

Also change places where xml_marshaller were used directly to always use
the wrapper from utils (except in tests for simplicity)
parent 6cd47aed
Pipeline #10064 failed with stage
in 0 seconds
...@@ -49,7 +49,7 @@ try: ...@@ -49,7 +49,7 @@ try:
ComputerPartition as SlapComputerPartition, ComputerPartition as SlapComputerPartition,
SoftwareInstance, SoftwareInstance,
SoftwareRelease) SoftwareRelease)
from slapos.util import dict2xml, xml2dict, calculate_dict_hash from slapos.util import dict2xml, xml2dict, calculate_dict_hash, loads, dumps
except ImportError: except ImportError:
# Do no prevent instance from starting # Do no prevent instance from starting
# if libs are not installed # if libs are not installed
...@@ -71,9 +71,12 @@ except ImportError: ...@@ -71,9 +71,12 @@ except ImportError:
raise ImportError raise ImportError
def calculate_dict_hash(dictionary): def calculate_dict_hash(dictionary):
raise ImportError raise ImportError
def loads(*args):
raise ImportError
def dumps(*args):
raise ImportError
from zLOG import LOG, INFO from zLOG import LOG, INFO
import xml_marshaller
import StringIO import StringIO
import pkg_resources import pkg_resources
from Products.Vifib.Conduit import VifibConduit from Products.Vifib.Conduit import VifibConduit
...@@ -185,7 +188,7 @@ class SlapTool(BaseTool): ...@@ -185,7 +188,7 @@ class SlapTool(BaseTool):
portal_type="Computer Partition"): portal_type="Computer Partition"):
slap_computer._computer_partition_list.append( slap_computer._computer_partition_list.append(
self._getSlapPartitionByPackingList(_assertACI(computer_partition.getObject()))) self._getSlapPartitionByPackingList(_assertACI(computer_partition.getObject())))
return xml_marshaller.xml_marshaller.dumps(slap_computer) return dumps(slap_computer)
def _fillComputerInformationCache(self, computer_id, user): def _fillComputerInformationCache(self, computer_id, user):
key = '%s_%s' % (computer_id, user) key = '%s_%s' % (computer_id, user)
...@@ -278,7 +281,7 @@ class SlapTool(BaseTool): ...@@ -278,7 +281,7 @@ class SlapTool(BaseTool):
for computer_partition in computer_partition_list: for computer_partition in computer_partition_list:
slap_computer._computer_partition_list.append( slap_computer._computer_partition_list.append(
self._getSlapPartitionByPackingList(_assertACI(computer_partition.getObject()))) self._getSlapPartitionByPackingList(_assertACI(computer_partition.getObject())))
return xml_marshaller.xml_marshaller.dumps(slap_computer) return dumps(slap_computer)
@UnrestrictedMethod @UnrestrictedMethod
def _getHostingSubscriptionIpList(self, computer_id, computer_partition_id): def _getHostingSubscriptionIpList(self, computer_id, computer_partition_id):
...@@ -287,7 +290,7 @@ class SlapTool(BaseTool): ...@@ -287,7 +290,7 @@ class SlapTool(BaseTool):
if software_instance is None or \ if software_instance is None or \
software_instance.getSlapState() == 'destroy_requested': software_instance.getSlapState() == 'destroy_requested':
return xml_marshaller.xml_marshaller.dumps([]) return dumps([])
# Search hosting subscription # Search hosting subscription
hosting = software_instance.getSpecialiseValue() hosting = software_instance.getSpecialiseValue()
while hosting and hosting.getPortalType() != "Hosting Subscription": while hosting and hosting.getPortalType() != "Hosting Subscription":
...@@ -305,7 +308,7 @@ class SlapTool(BaseTool): ...@@ -305,7 +308,7 @@ class SlapTool(BaseTool):
internet_protocol_address.getIpAddress().decode("UTF-8")) internet_protocol_address.getIpAddress().decode("UTF-8"))
) )
return xml_marshaller.xml_marshaller.dumps(ip_address_list) return dumps(ip_address_list)
security.declareProtected(Permissions.AccessContentsInformation, security.declareProtected(Permissions.AccessContentsInformation,
'getFullComputerInformation') 'getFullComputerInformation')
...@@ -366,7 +369,7 @@ class SlapTool(BaseTool): ...@@ -366,7 +369,7 @@ class SlapTool(BaseTool):
key=software_instance.getSslKey(), key=software_instance.getSslKey(),
certificate=software_instance.getSslCertificate() certificate=software_instance.getSslCertificate()
) )
result = xml_marshaller.xml_marshaller.dumps(certificate_dict) result = dumps(certificate_dict)
# Cache with revalidation # Cache with revalidation
self.REQUEST.response.setStatus(200) self.REQUEST.response.setStatus(200)
self.REQUEST.response.setHeader('Cache-Control', self.REQUEST.response.setHeader('Cache-Control',
...@@ -458,7 +461,7 @@ class SlapTool(BaseTool): ...@@ -458,7 +461,7 @@ class SlapTool(BaseTool):
reference=software_product_reference, reference=software_product_reference,
validation_state='published') validation_state='published')
if len(software_product_list) is 0: if len(software_product_list) is 0:
return xml_marshaller.xml_marshaller.dumps([]) return dumps([])
if len(software_product_list) > 1: if len(software_product_list) > 1:
raise NotImplementedError('Several Software Product with the same title.') raise NotImplementedError('Several Software Product with the same title.')
software_release_list = \ software_release_list = \
...@@ -477,7 +480,7 @@ class SlapTool(BaseTool): ...@@ -477,7 +480,7 @@ class SlapTool(BaseTool):
key=sortkey, key=sortkey,
reverse=True, reverse=True,
) )
return xml_marshaller.xml_marshaller.dumps( return dumps(
[software_release.getUrlString() [software_release.getUrlString()
for software_release in software_release_list for software_release in software_release_list
if software_release.getValidationState() in \ if software_release.getValidationState() in \
...@@ -555,7 +558,7 @@ class SlapTool(BaseTool): ...@@ -555,7 +558,7 @@ class SlapTool(BaseTool):
person = portal.portal_membership.getAuthenticatedMember().getUserValue() person = portal.portal_membership.getAuthenticatedMember().getUserValue()
person.requestComputer(computer_title=computer_title) person.requestComputer(computer_title=computer_title)
computer = Computer(self.REQUEST.get('computer_reference').decode("UTF-8")) computer = Computer(self.REQUEST.get('computer_reference').decode("UTF-8"))
return xml_marshaller.xml_marshaller.dumps(computer) return dumps(computer)
security.declareProtected(Permissions.AccessContentsInformation, security.declareProtected(Permissions.AccessContentsInformation,
'requestComputer') 'requestComputer')
...@@ -724,7 +727,7 @@ class SlapTool(BaseTool): ...@@ -724,7 +727,7 @@ class SlapTool(BaseTool):
'loadComputerConfigurationFromXML') 'loadComputerConfigurationFromXML')
def loadComputerConfigurationFromXML(self, xml): def loadComputerConfigurationFromXML(self, xml):
"Load the given xml as configuration for the computer object" "Load the given xml as configuration for the computer object"
computer_dict = xml_marshaller.xml_marshaller.loads(xml) computer_dict = loads(xml)
computer = self._getComputerDocument(computer_dict['reference']) computer = self._getComputerDocument(computer_dict['reference'])
computer.Computer_updateFromDict(computer_dict) computer.Computer_updateFromDict(computer_dict)
return 'Content properly posted.' return 'Content properly posted.'
...@@ -750,7 +753,7 @@ class SlapTool(BaseTool): ...@@ -750,7 +753,7 @@ class SlapTool(BaseTool):
'certificate': self.REQUEST.get('computer_certificate').decode("UTF-8"), 'certificate': self.REQUEST.get('computer_certificate').decode("UTF-8"),
'key': self.REQUEST.get('computer_key').decode("UTF-8") 'key': self.REQUEST.get('computer_key').decode("UTF-8")
} }
return xml_marshaller.xml_marshaller.dumps(result) return dumps(result)
security.declareProtected(Permissions.AccessContentsInformation, security.declareProtected(Permissions.AccessContentsInformation,
'generateComputerCertificate') 'generateComputerCertificate')
...@@ -842,7 +845,7 @@ class SlapTool(BaseTool): ...@@ -842,7 +845,7 @@ class SlapTool(BaseTool):
slave_instance_dict.pop("xml"))) slave_instance_dict.pop("xml")))
slap_partition._parameter_dict.update(parameter_dict) slap_partition._parameter_dict.update(parameter_dict)
result = xml_marshaller.xml_marshaller.dumps(slap_partition) result = dumps(slap_partition)
# Keep in cache server for 7 days # Keep in cache server for 7 days
self.REQUEST.response.setStatus(200) self.REQUEST.response.setStatus(200)
...@@ -1145,7 +1148,7 @@ class SlapTool(BaseTool): ...@@ -1145,7 +1148,7 @@ class SlapTool(BaseTool):
'REMOTE_USER') 'REMOTE_USER')
self.REQUEST.response.setHeader('Last-Modified', last_modified) self.REQUEST.response.setHeader('Last-Modified', last_modified)
self.REQUEST.response.setHeader('Content-Type', 'text/xml; charset=utf-8') self.REQUEST.response.setHeader('Content-Type', 'text/xml; charset=utf-8')
self.REQUEST.response.setBody(xml_marshaller.xml_marshaller.dumps(d)) self.REQUEST.response.setBody(dumps(d))
return self.REQUEST.response return self.REQUEST.response
@convertToREST @convertToREST
...@@ -1218,8 +1221,7 @@ class SlapTool(BaseTool): ...@@ -1218,8 +1221,7 @@ class SlapTool(BaseTool):
computer_id, computer_id,
computer_partition_id, computer_partition_id,
slave_reference) slave_reference)
connection_xml = dict2xml(xml_marshaller.xml_marshaller.loads( connection_xml = dict2xml(loads(connection_xml))
connection_xml))
reference = software_instance.getReference() reference = software_instance.getReference()
if self._getLastData(reference) != connection_xml: if self._getLastData(reference) != connection_xml:
software_instance.updateConnection( software_instance.updateConnection(
...@@ -1244,20 +1246,19 @@ class SlapTool(BaseTool): ...@@ -1244,20 +1246,19 @@ class SlapTool(BaseTool):
In any other case returns not important data and HTTP code is 403 Forbidden In any other case returns not important data and HTTP code is 403 Forbidden
""" """
if state: if state:
state = xml_marshaller.xml_marshaller.loads(state) state = loads(state)
if state is None: if state is None:
state = 'started' state = 'started'
if shared_xml is not _MARKER: if shared_xml is not _MARKER:
shared = xml_marshaller.xml_marshaller.loads(shared_xml) shared = loads(shared_xml)
else: else:
shared = False shared = False
if partition_parameter_xml: if partition_parameter_xml:
partition_parameter_kw = xml_marshaller.xml_marshaller.loads( partition_parameter_kw = loads(partition_parameter_xml)
partition_parameter_xml)
else: else:
partition_parameter_kw = dict() partition_parameter_kw = dict()
if filter_xml: if filter_xml:
filter_kw = xml_marshaller.xml_marshaller.loads(filter_xml) filter_kw = loads(filter_xml)
if software_type == 'pull-backup' and not 'retention_delay' in filter_kw: if software_type == 'pull-backup' and not 'retention_delay' in filter_kw:
filter_kw['retention_delay'] = 7.0 filter_kw['retention_delay'] = 7.0
else: else:
...@@ -1363,7 +1364,7 @@ class SlapTool(BaseTool): ...@@ -1363,7 +1364,7 @@ class SlapTool(BaseTool):
software_instance._filter_dict = filter_xml software_instance._filter_dict = filter_xml
software_instance._requested_state = state software_instance._requested_state = state
software_instance._instance_guid = instance_guid software_instance._instance_guid = instance_guid
return xml_marshaller.xml_marshaller.dumps(software_instance) return dumps(software_instance)
@UnrestrictedMethod @UnrestrictedMethod
def _updateComputerPartitionRelatedInstanceList(self, computer_id, def _updateComputerPartitionRelatedInstanceList(self, computer_id,
...@@ -1383,8 +1384,7 @@ class SlapTool(BaseTool): ...@@ -1383,8 +1384,7 @@ class SlapTool(BaseTool):
cache_reference = '%s-PREDLIST' % software_instance_document.getReference() cache_reference = '%s-PREDLIST' % software_instance_document.getReference()
if self._getLastData(cache_reference) != instance_reference_xml: if self._getLastData(cache_reference) != instance_reference_xml:
instance_reference_list = xml_marshaller.xml_marshaller.loads( instance_reference_list = loads(instance_reference_xml)
instance_reference_xml)
current_predecessor_list = software_instance_document.getPredecessorValueList( current_predecessor_list = software_instance_document.getPredecessorValueList(
portal_type=['Software Instance', 'Slave Instance']) portal_type=['Software Instance', 'Slave Instance'])
......
...@@ -51,7 +51,6 @@ try: ...@@ -51,7 +51,6 @@ try:
except ImportError: # XXX to be removed once we depend on typing except ImportError: # XXX to be removed once we depend on typing
pass pass
import xml_marshaller
import zope.interface import zope.interface
import psutil import psutil
...@@ -61,6 +60,8 @@ from .interface.slap import IRequester ...@@ -61,6 +60,8 @@ from .interface.slap import IRequester
from ..grid.slapgrid import SLAPGRID_PROMISE_FAIL from ..grid.slapgrid import SLAPGRID_PROMISE_FAIL
from .slap import slap from .slap import slap
from ..util import dumps
from ..grid.svcbackend import getSupervisorRPC from ..grid.svcbackend import getSupervisorRPC
...@@ -534,7 +535,7 @@ class StandaloneSlapOS(object): ...@@ -534,7 +535,7 @@ class StandaloneSlapOS(object):
**locals())) **locals()))
self.computer.updateConfiguration( self.computer.updateConfiguration(
xml_marshaller.xml_marshaller.dumps({ dumps({
'address': ipv4_address, 'address': ipv4_address,
'netmask': '255.255.255.255', 'netmask': '255.255.255.255',
'partition_list': partition_list, 'partition_list': partition_list,
......
...@@ -217,6 +217,20 @@ class TestUtil(unittest.TestCase): ...@@ -217,6 +217,20 @@ class TestUtil(unittest.TestCase):
slapos.util.dict2xml(self.xml2dict1_dict) slapos.util.dict2xml(self.xml2dict1_dict)
) )
def test_dumps_loads(self):
simple_object = {"ok": [True]}
self.assertEqual(simple_object, slapos.util.loads(slapos.util.dumps(simple_object)))
self.assertRaises(
Exception,
slapos.util.loads,
b'<marshal><object id="i2" module="nasty" class="klass">'
b'<tuple></tuple><dictionary id="i3"/></object></marshal>')
class Nasty(object):
pass
self.assertRaises(Exception, slapos.util.dumps, Nasty())
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -33,7 +33,7 @@ import socket ...@@ -33,7 +33,7 @@ import socket
import struct import struct
import subprocess import subprocess
import sqlite3 import sqlite3
from xml_marshaller.xml_marshaller import dumps, loads from xml_marshaller.xml_marshaller import Marshaller, Unmarshaller
from lxml import etree from lxml import etree
import six import six
from six.moves.urllib import parse from six.moves.urllib import parse
...@@ -49,6 +49,36 @@ except NameError: # make pylint happy on python2... ...@@ -49,6 +49,36 @@ except NameError: # make pylint happy on python2...
_ALLOWED_CLASS_SET = frozenset((
('slapos.slap.slap', 'Computer'),
('slapos.slap.slap', 'ComputerPartition'),
('slapos.slap.slap', 'SoftwareRelease'),
('slapos.slap.slap', 'SoftwareInstance'),
))
class SafeXMLMarshaller(Marshaller):
def m_instance(self, value, kw):
cls = value.__class__
if (cls.__module__, cls.__name__) in _ALLOWED_CLASS_SET:
return super(SafeXMLMarshaller, self).m_instance(value, kw)
raise RuntimeError("Refusing to marshall {}.{}".format(
cls.__module__, cls.__name__))
dumps = SafeXMLMarshaller().dumps
class SafeXMLUnmrshaller(Unmarshaller, object):
def find_class(self, module, name):
if (module, name) in _ALLOWED_CLASS_SET:
return super(SafeXMLUnmrshaller, self).find_class(module, name)
raise RuntimeError("Refusing to unmarshall {}.{}".format(module, name))
loads = SafeXMLUnmrshaller().loads
def mkdir_p(path, mode=0o700): def mkdir_p(path, mode=0o700):
"""\ """\
Creates a directory and its parents, if needed. Creates a directory and its parents, if needed.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment