Commit 8ec6ae1d authored by David Wilson's avatar David Wilson

importer: module whitelist/blacklist support

Hoped to avoid it, but it's the obvious solution for Ansible.
parent 43ba1c76
...@@ -394,7 +394,7 @@ class Importer(object): ...@@ -394,7 +394,7 @@ class Importer(object):
:param context: Context to communicate via. :param context: Context to communicate via.
""" """
def __init__(self, router, context, core_src): def __init__(self, router, context, core_src, whitelist=(), blacklist=()):
self._context = context self._context = context
self._present = {'mitogen': [ self._present = {'mitogen': [
'mitogen.compat', 'mitogen.compat',
...@@ -407,6 +407,15 @@ class Importer(object): ...@@ -407,6 +407,15 @@ class Importer(object):
'mitogen.utils', 'mitogen.utils',
]} ]}
self._lock = threading.Lock() self._lock = threading.Lock()
self.whitelist = whitelist or ['']
self.blacklist = list(blacklist) + [
# 2.x generates needless imports for 'builtins', while 3.x does the
# same for '__builtin__'. The correct one is built-in, the other
# always a negative round-trip.
'builtins',
'__builtin__',
]
# Presence of an entry in this map indicates in-flight GET_MODULE. # Presence of an entry in this map indicates in-flight GET_MODULE.
self._callbacks = {} self._callbacks = {}
router.add_handler(self._on_load_module, LOAD_MODULE) router.add_handler(self._on_load_module, LOAD_MODULE)
...@@ -451,12 +460,9 @@ class Importer(object): ...@@ -451,12 +460,9 @@ class Importer(object):
finally: finally:
del _tls.running del _tls.running
def _load_module_hacks(self, fullname): def _refuse_imports(self, fullname):
if fullname in ('builtins', '__builtin__'): if ((not any(fullname.startswith(s) for s in self.whitelist)) or
# Python 2.x will generate needless imports for 'builtins', while (any(fullname.startswith(s) for s in self.blacklist))):
# Python 3.x will generate needless imports for '__builtin__'. The
# correct one is already present in sys.modules, the other is
# always a negative round-trip.
raise ImportError('Refused') raise ImportError('Refused')
f = sys._getframe(2) f = sys._getframe(2)
...@@ -515,7 +521,7 @@ class Importer(object): ...@@ -515,7 +521,7 @@ class Importer(object):
def load_module(self, fullname): def load_module(self, fullname):
_v and LOG.debug('Importer.load_module(%r)', fullname) _v and LOG.debug('Importer.load_module(%r)', fullname)
self._load_module_hacks(fullname) self._refuse_imports(fullname)
event = threading.Event() event = threading.Event()
self._request_module(fullname, event.set) self._request_module(fullname, event.set)
...@@ -1260,7 +1266,7 @@ class ExternalContext(object): ...@@ -1260,7 +1266,7 @@ class ExternalContext(object):
if debug: if debug:
enable_debug_logging() enable_debug_logging()
def _setup_importer(self, core_src_fd): def _setup_importer(self, core_src_fd, whitelist, blacklist):
if core_src_fd: if core_src_fd:
with os.fdopen(101, 'r', 1) as fp: with os.fdopen(101, 'r', 1) as fp:
core_size = int(fp.readline()) core_size = int(fp.readline())
...@@ -1271,7 +1277,9 @@ class ExternalContext(object): ...@@ -1271,7 +1277,9 @@ class ExternalContext(object):
else: else:
core_src = None core_src = None
self.importer = Importer(self.router, self.parent, core_src) self.importer = Importer(self.router, self.parent, core_src,
whitelist, blacklist)
self.router.importer = self.importer
sys.meta_path.append(self.importer) sys.meta_path.append(self.importer)
def _setup_package(self, context_id, parent_ids): def _setup_package(self, context_id, parent_ids):
...@@ -1328,12 +1336,13 @@ class ExternalContext(object): ...@@ -1328,12 +1336,13 @@ class ExternalContext(object):
self.dispatch_stopped = True self.dispatch_stopped = True
def main(self, parent_ids, context_id, debug, profiling, log_level, def main(self, parent_ids, context_id, debug, profiling, log_level,
in_fd=100, out_fd=1, core_src_fd=101, setup_stdio=True): in_fd=100, out_fd=1, core_src_fd=101, setup_stdio=True,
whitelist=(), blacklist=()):
self._setup_master(profiling, parent_ids[0], context_id, in_fd, out_fd) self._setup_master(profiling, parent_ids[0], context_id, in_fd, out_fd)
try: try:
try: try:
self._setup_logging(debug, log_level) self._setup_logging(debug, log_level)
self._setup_importer(core_src_fd) self._setup_importer(core_src_fd, whitelist, blacklist)
self._setup_package(context_id, parent_ids) self._setup_package(context_id, parent_ids)
if setup_stdio: if setup_stdio:
self._setup_stdio() self._setup_stdio()
......
...@@ -341,17 +341,17 @@ def run(dest, router, args, deadline=None, econtext=None): ...@@ -341,17 +341,17 @@ def run(dest, router, args, deadline=None, econtext=None):
fp.write('#!%s\n' % (sys.executable,)) fp.write('#!%s\n' % (sys.executable,))
fp.write(inspect.getsource(mitogen.core)) fp.write(inspect.getsource(mitogen.core))
fp.write('\n') fp.write('\n')
fp.write('ExternalContext().main%r\n' % (( fp.write('ExternalContext().main(**%r)\n' % ({
parent_ids, # parent_ids 'parent_ids': parent_ids,
context_id, # context_id 'context_id': context_id,
router.debug, # debug 'debug': router.debug,
router.profiling, # profiling 'profiling': router.profiling,
logging.getLogger().level, # log_level 'log_level': mitogen.parent.get_log_level(),
sock2.fileno(), # in_fd 'in_fd': sock2.fileno(),
sock2.fileno(), # out_fd 'out_fd': sock2.fileno(),
None, # core_src_fd 'core_src_fd': None,
False, # setup_stdio 'setup_stdio': False,
),)) },))
finally: finally:
fp.close() fp.close()
......
...@@ -441,6 +441,8 @@ class ModuleResponder(object): ...@@ -441,6 +441,8 @@ class ModuleResponder(object):
self._router = router self._router = router
self._finder = ModuleFinder() self._finder = ModuleFinder()
self._cache = {} # fullname -> pickled self._cache = {} # fullname -> pickled
self.blacklist = []
self.whitelist = []
router.add_handler(self._on_get_module, mitogen.core.GET_MODULE) router.add_handler(self._on_get_module, mitogen.core.GET_MODULE)
def __repr__(self): def __repr__(self):
...@@ -448,6 +450,12 @@ class ModuleResponder(object): ...@@ -448,6 +450,12 @@ class ModuleResponder(object):
MAIN_RE = re.compile(r'^if\s+__name__\s*==\s*.__main__.\s*:', re.M) MAIN_RE = re.compile(r'^if\s+__name__\s*==\s*.__main__.\s*:', re.M)
def whitelist_prefix(self, fullname):
self.whitelist.append(fullname)
def blacklist_prefix(self, fullname):
self.blacklist.append(fullname)
def neutralize_main(self, src): def neutralize_main(self, src):
"""Given the source for the __main__ module, try to find where it """Given the source for the __main__ module, try to find where it
begins conditional execution based on a "if __name__ == '__main__'" begins conditional execution based on a "if __name__ == '__main__'"
...@@ -458,6 +466,9 @@ class ModuleResponder(object): ...@@ -458,6 +466,9 @@ class ModuleResponder(object):
return src return src
def _build_tuple(self, fullname): def _build_tuple(self, fullname):
if fullname in self._blacklist:
raise ImportError('blacklisted')
if fullname in self._cache: if fullname in self._cache:
return self._cache[fullname] return self._cache[fullname]
......
...@@ -63,6 +63,10 @@ class Argv(object): ...@@ -63,6 +63,10 @@ class Argv(object):
return ' '.join(map(self.escape, self.argv)) return ' '.join(map(self.escape, self.argv))
def get_log_level():
return (LOG.level or logging.getLogger().level or logging.INFO)
def minimize_source(source): def minimize_source(source):
subber = lambda match: '""' + ('\n' * match.group(0).count('\n')) subber = lambda match: '""' + ('\n' * match.group(0).count('\n'))
source = DOCSTRING_RE.sub(subber, source) source = DOCSTRING_RE.sub(subber, source)
...@@ -336,14 +340,17 @@ class Stream(mitogen.core.Stream): ...@@ -336,14 +340,17 @@ class Stream(mitogen.core.Stream):
def get_preamble(self): def get_preamble(self):
parent_ids = mitogen.parent_ids[:] parent_ids = mitogen.parent_ids[:]
parent_ids.insert(0, mitogen.context_id) parent_ids.insert(0, mitogen.context_id)
source = inspect.getsource(mitogen.core) source = inspect.getsource(mitogen.core)
source += '\nExternalContext().main%r\n' % (( source += '\nExternalContext().main(**%r)\n' % ({
parent_ids, # parent_ids 'parent_ids': parent_ids,
self.remote_id, # context_id 'context_id': self.remote_id,
self.debug, 'debug': self.debug,
self.profiling, 'profiling': self.profiling,
LOG.level or logging.getLogger().level or logging.INFO, 'log_level': get_log_level(),
),) 'whitelist': self._router.get_module_whitelist(),
'blacklist': self._router.get_module_blacklist(),
},)
compressed = zlib.compress(minimize_source(source)) compressed = zlib.compress(minimize_source(source))
return str(len(compressed)) + '\n' + compressed return str(len(compressed)) + '\n' + compressed
...@@ -385,6 +392,16 @@ class ChildIdAllocator(object): ...@@ -385,6 +392,16 @@ class ChildIdAllocator(object):
class Router(mitogen.core.Router): class Router(mitogen.core.Router):
context_class = mitogen.core.Context context_class = mitogen.core.Context
def get_module_blacklist(self):
if mitogen.context_id == 0:
return self.responder.blacklist
return self.importer.blacklist
def get_module_whitelist(self):
if mitogen.context_id == 0:
return self.responder.whitelist
return self.importer.whitelist
def allocate_id(self): def allocate_id(self):
return self.id_allocator.allocate() return self.id_allocator.allocate()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment