Commit 2dbd3eae authored by Shane Hathaway's avatar Shane Hathaway

Synced with Zope-2_7-branch.

Removed support for Python 2.1, fixed yield test, and added a test for
bad names set by exception handlers.
parent c6e35709
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
############################################################################## ##############################################################################
from __future__ import nested_scopes from __future__ import nested_scopes
__version__='$Revision: 1.11 $'[11:-2] __version__='$Revision: 1.12 $'[11:-2]
import exceptions import exceptions
import new import new
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
Python standard library. Python standard library.
""" """
__version__='$Revision: 1.4 $'[11:-2] __version__='$Revision: 1.5 $'[11:-2]
from compiler import ast, parse, misc, syntax from compiler import ast, parse, misc, syntax
......
...@@ -15,7 +15,7 @@ RestrictionMutator modifies a tree produced by ...@@ -15,7 +15,7 @@ RestrictionMutator modifies a tree produced by
compiler.transformer.Transformer, restricting and enhancing the compiler.transformer.Transformer, restricting and enhancing the
code in various ways before sending it to pycodegen. code in various ways before sending it to pycodegen.
''' '''
__version__='$Revision: 1.11 $'[11:-2] __version__='$Revision: 1.12 $'[11:-2]
from SelectCompiler import ast, parse, OP_ASSIGN, OP_DELETE, OP_APPLY from SelectCompiler import ast, parse, OP_ASSIGN, OP_DELETE, OP_APPLY
...@@ -40,12 +40,11 @@ def exprNode(txt): ...@@ -40,12 +40,11 @@ def exprNode(txt):
'''Make a "clean" expression node''' '''Make a "clean" expression node'''
return stmtNode(txt).expr return stmtNode(txt).expr
# There should be up to four objects in the global namespace. # There should be up to four objects in the global namespace. If a
# If a wrapper function or print target is needed in a particular # wrapper function or print target is needed in a particular module or
# module or function, it is obtained from one of these objects. # function, it is obtained from one of these objects. There is a
# It is stored in a variable with the same name as the global # local and a global binding for each object: the global name has a
# object, but without a single trailing underscore. This variable is # trailing underscore, while the local name does not.
# local, and therefore efficient to access, in function scopes.
_print_target_name = ast.Name('_print') _print_target_name = ast.Name('_print')
_getattr_name = ast.Name('_getattr') _getattr_name = ast.Name('_getattr')
_getattr_name_expr = ast.Name('_getattr_') _getattr_name_expr = ast.Name('_getattr_')
...@@ -90,6 +89,8 @@ class RestrictionMutator: ...@@ -90,6 +89,8 @@ class RestrictionMutator:
self.used_names = {} self.used_names = {}
def error(self, node, info): def error(self, node, info):
"""Records a security error discovered during compilation.
"""
lineno = getattr(node, 'lineno', None) lineno = getattr(node, 'lineno', None)
if lineno is not None and lineno > 0: if lineno is not None and lineno > 0:
self.errors.append('Line %d: %s' % (lineno, info)) self.errors.append('Line %d: %s' % (lineno, info))
...@@ -97,6 +98,20 @@ class RestrictionMutator: ...@@ -97,6 +98,20 @@ class RestrictionMutator:
self.errors.append(info) self.errors.append(info)
def checkName(self, node, name): def checkName(self, node, name):
"""Verifies that a name being assigned is safe.
This is to prevent people from doing things like:
__metatype__ = mytype (opens up metaclasses, a big unknown
in terms of security)
__path__ = foo (could this confuse the import machinery?)
_getattr = somefunc (not very useful, but could open a hole)
Note that assigning a variable is not the only way to assign
a name. def _badname, class _badname, import foo as _badname,
and perhaps other statements assign names. Special case:
'_' is allowed.
"""
if len(name) > 1 and name[0] == '_': if len(name) > 1 and name[0] == '_':
# Note: "_" *is* allowed. # Note: "_" *is* allowed.
self.error(node, '"%s" is an invalid variable name because' self.error(node, '"%s" is an invalid variable name because'
...@@ -105,9 +120,12 @@ class RestrictionMutator: ...@@ -105,9 +120,12 @@ class RestrictionMutator:
self.error(node, '"printed" is a reserved name.') self.error(node, '"printed" is a reserved name.')
def checkAttrName(self, node): def checkAttrName(self, node):
# This prevents access to protected attributes of guards """Verifies that an attribute name does not start with _.
# and is thus essential regardless of the security policy,
# unless some other solution is devised. As long as guards (security proxies) have underscored names,
this underscore protection is important regardless of the
security policy. Special case: '_' is allowed.
"""
name = node.attrname name = node.attrname
if len(name) > 1 and name[0] == '_': if len(name) > 1 and name[0] == '_':
# Note: "_" *is* allowed. # Note: "_" *is* allowed.
...@@ -115,7 +133,15 @@ class RestrictionMutator: ...@@ -115,7 +133,15 @@ class RestrictionMutator:
'because it starts with "_".' % name) 'because it starts with "_".' % name)
def prepBody(self, body): def prepBody(self, body):
"""Appends prep code to the beginning of a code suite. """Prepends preparation code to a code suite.
For example, if a code suite uses getattr operations,
this places the following code at the beginning of the suite:
global _getattr_
_getattr = _getattr_
Similarly for _getitem_, _print_, and _write_.
""" """
info = self.funcinfo info = self.funcinfo
if info._print_used or info._printed_used: if info._print_used or info._printed_used:
...@@ -135,6 +161,12 @@ class RestrictionMutator: ...@@ -135,6 +161,12 @@ class RestrictionMutator:
body[0:0] = _prep_code['write'] body[0:0] = _prep_code['write']
def visitFunction(self, node, walker): def visitFunction(self, node, walker):
"""Checks and mutates a function definition.
Checks the name of the function and the argument names using
checkName(). It also calls prepBody() to prepend code to the
beginning of the code suite.
"""
self.checkName(node, node.name) self.checkName(node, node.name)
for argname in node.argnames: for argname in node.argnames:
self.checkName(node, argname) self.checkName(node, argname)
...@@ -149,11 +181,30 @@ class RestrictionMutator: ...@@ -149,11 +181,30 @@ class RestrictionMutator:
return node return node
def visitLambda(self, node, walker): def visitLambda(self, node, walker):
"""Checks and mutates an anonymous function definition.
Checks the argument names using checkName(). It also calls
prepBody() to prepend code to the beginning of the code suite.
"""
for argname in node.argnames: for argname in node.argnames:
self.checkName(node, argname) self.checkName(node, argname)
return walker.defaultVisitNode(node) return walker.defaultVisitNode(node)
def visitPrint(self, node, walker): def visitPrint(self, node, walker):
"""Checks and mutates a print statement.
Adds a target to all print statements. 'print foo' becomes
'print >> _print, foo', where _print is the default print
target defined for this scope.
Alternatively, if the untrusted code provides its own target,
we have to check the 'write' method of the target.
'print >> ob, foo' becomes
'print >> (_getattr(ob, 'write') and ob), foo'.
Otherwise, it would be possible to call the write method of
templates and scripts; 'write' happens to be the name of the
method that changes them.
"""
node = walker.defaultVisitNode(node) node = walker.defaultVisitNode(node)
self.funcinfo._print_used = 1 self.funcinfo._print_used = 1
if node.dest is None: if node.dest is None:
...@@ -171,6 +222,10 @@ class RestrictionMutator: ...@@ -171,6 +222,10 @@ class RestrictionMutator:
visitPrintnl = visitPrint visitPrintnl = visitPrint
def visitName(self, node, walker): def visitName(self, node, walker):
"""Prevents access to protected names as defined by checkName().
Also converts use of the name 'printed' to an expression.
"""
if node.name == 'printed': if node.name == 'printed':
# Replace name lookup with an expression. # Replace name lookup with an expression.
self.funcinfo._printed_used = 1 self.funcinfo._printed_used = 1
...@@ -180,10 +235,19 @@ class RestrictionMutator: ...@@ -180,10 +235,19 @@ class RestrictionMutator:
return node return node
def visitAssName(self, node, walker): def visitAssName(self, node, walker):
"""Checks a name assignment using checkName().
"""
self.checkName(node, node.name) self.checkName(node, node.name)
return node return node
def visitGetattr(self, node, walker): def visitGetattr(self, node, walker):
"""Converts attribute access to a function call.
'foo.bar' becomes '_getattr(foo, "bar")'.
Also prevents augmented assignment of attributes, which would
be difficult to support correctly.
"""
self.checkAttrName(node) self.checkAttrName(node)
node = walker.defaultVisitNode(node) node = walker.defaultVisitNode(node)
if getattr(node, 'in_aug_assign', 0): if getattr(node, 'in_aug_assign', 0):
...@@ -195,15 +259,30 @@ class RestrictionMutator: ...@@ -195,15 +259,30 @@ class RestrictionMutator:
#self.funcinfo._write_used = 1 #self.funcinfo._write_used = 1
self.funcinfo._getattr_used = 1 self.funcinfo._getattr_used = 1
if self.funcinfo._is_suite: if self.funcinfo._is_suite:
# Use the local function _getattr().
ga = _getattr_name ga = _getattr_name
else: else:
# Use the global function _getattr_().
ga = _getattr_name_expr ga = _getattr_name_expr
return ast.CallFunc(ga, [node.expr, ast.Const(node.attrname)]) return ast.CallFunc(ga, [node.expr, ast.Const(node.attrname)])
def visitSubscript(self, node, walker): def visitSubscript(self, node, walker):
"""Checks all kinds of subscripts.
'foo[bar] += baz' is disallowed.
'a = foo[bar, baz]' becomes 'a = _getitem(foo, (bar, baz))'.
'a = foo[bar]' becomes 'a = _getitem(foo, bar)'.
'a = foo[bar:baz]' becomes 'a = _getitem(foo, slice(bar, baz))'.
'a = foo[:baz]' becomes 'a = _getitem(foo, slice(None, baz))'.
'a = foo[bar:]' becomes 'a = _getitem(foo, slice(bar, None))'.
'del foo[bar]' becomes 'del _write(foo)[bar]'.
'foo[bar] = a' becomes '_write(foo)[bar] = a'.
The _write function returns a security proxy.
"""
node = walker.defaultVisitNode(node) node = walker.defaultVisitNode(node)
if node.flags == OP_APPLY: if node.flags == OP_APPLY:
# get subscript or slice # Set 'subs' to the node that represents the subscript or slice.
if getattr(node, 'in_aug_assign', 0): if getattr(node, 'in_aug_assign', 0):
# We're in an augmented assignment # We're in an augmented assignment
# We might support this later... # We might support this later...
...@@ -245,6 +324,11 @@ class RestrictionMutator: ...@@ -245,6 +324,11 @@ class RestrictionMutator:
visitSlice = visitSubscript visitSlice = visitSubscript
def visitAssAttr(self, node, walker): def visitAssAttr(self, node, walker):
"""Checks and mutates attribute assignment.
'a.b = c' becomes '_write(a).b = c'.
The _write function returns a security proxy.
"""
self.checkAttrName(node) self.checkAttrName(node)
node = walker.defaultVisitNode(node) node = walker.defaultVisitNode(node)
node.expr = ast.CallFunc(_write_guard_name, [node.expr]) node.expr = ast.CallFunc(_write_guard_name, [node.expr])
...@@ -258,23 +342,45 @@ class RestrictionMutator: ...@@ -258,23 +342,45 @@ class RestrictionMutator:
self.error(node, 'Yield statements are not allowed.') self.error(node, 'Yield statements are not allowed.')
def visitClass(self, node, walker): def visitClass(self, node, walker):
# Should classes be allowed at all?? """Checks the name of a class using checkName().
Should classes be allowed at all? They don't cause security
issues, but they aren't very useful either since untrusted
code can't assign instance attributes.
"""
self.checkName(node, node.name) self.checkName(node, node.name)
return walker.defaultVisitNode(node) return walker.defaultVisitNode(node)
def visitModule(self, node, walker): def visitModule(self, node, walker):
"""Adds prep code at module scope.
Zope doesn't make use of this. The body of Python scripts is
always at function scope.
"""
self.funcinfo._is_suite = 1 self.funcinfo._is_suite = 1
node = walker.defaultVisitNode(node) node = walker.defaultVisitNode(node)
self.prepBody(node.node.nodes) self.prepBody(node.node.nodes)
return node return node
def visitAugAssign(self, node, walker): def visitAugAssign(self, node, walker):
"""Makes a note that augmented assignment is in use.
Note that although augmented assignment of attributes and
subscripts is disallowed, augmented assignment of names (such
as 'n += 1') is allowed.
This could be a problem if untrusted code got access to a
mutable database object that supports augmented assignment.
"""
node.node.in_aug_assign = 1 node.node.in_aug_assign = 1
return walker.defaultVisitNode(node) return walker.defaultVisitNode(node)
def visitImport(self, node, walker): def visitImport(self, node, walker):
"""Checks names imported using checkName().
"""
for name, asname in node.names: for name, asname in node.names:
self.checkName(node, name) self.checkName(node, name)
if asname: if asname:
self.checkName(node, asname) self.checkName(node, asname)
return node return node
...@@ -12,30 +12,18 @@ ...@@ -12,30 +12,18 @@
############################################################################## ##############################################################################
''' '''
Compiler selector. Compiler selector.
$Id: SelectCompiler.py,v 1.4 2002/08/14 21:44:31 mj Exp $ $Id: SelectCompiler.py,v 1.5 2003/11/06 17:11:49 shane Exp $
''' '''
import sys import sys
if sys.version_info[1] < 2: # Use the compiler from the standard library.
# Use the compiler_2_1 package. import compiler
from compiler_2_1 import ast from compiler import ast
from compiler_2_1.transformer import parse from compiler.transformer import parse
from compiler_2_1.consts import OP_ASSIGN, OP_DELETE, OP_APPLY from compiler.consts import OP_ASSIGN, OP_DELETE, OP_APPLY
from RCompile_2_1 import \ from RCompile import \
compile_restricted, \
compile_restricted_function, \
compile_restricted_exec, \
compile_restricted_eval
else:
# Use the compiler from the standard library.
import compiler
from compiler import ast
from compiler.transformer import parse
from compiler.consts import OP_ASSIGN, OP_DELETE, OP_APPLY
from RCompile import \
compile_restricted, \ compile_restricted, \
compile_restricted_function, \ compile_restricted_function, \
compile_restricted_exec, \ compile_restricted_exec, \
......
"""Package for parsing and compiling Python source code
There are several functions defined at the top level that are imported
from modules contained in the package.
parse(buf) -> AST
Converts a string containing Python source code to an abstract
syntax tree (AST). The AST is defined in compiler.ast.
parseFile(path) -> AST
The same as parse(open(path))
walk(ast, visitor, verbose=None)
Does a pre-order walk over the ast using the visitor instance.
See compiler.visitor for details.
compile(filename)
Generates a .pyc file by compilining filename.
"""
from transformer import parse, parseFile
from visitor import walk
from pycodegen import compile
"""Python abstract syntax node definitions
This file is automatically generated.
"""
from types import TupleType, ListType
from consts import CO_VARARGS, CO_VARKEYWORDS
def flatten(list):
l = []
for elt in list:
t = type(elt)
if t is TupleType or t is ListType:
for elt2 in flatten(elt):
l.append(elt2)
else:
l.append(elt)
return l
def asList(nodes):
l = []
for item in nodes:
if hasattr(item, "asList"):
l.append(item.asList())
else:
t = type(item)
if t is TupleType or t is ListType:
l.append(tuple(asList(item)))
else:
l.append(item)
return l
nodes = {}
class Node:
lineno = None
def getType(self):
pass
def getChildren(self):
# XXX It would be better to generate flat values to begin with
return flatten(self._getChildren())
def asList(self):
return tuple(asList(self.getChildren()))
class EmptyNode(Node):
def __init__(self):
self.lineno = None
class If(Node):
nodes["if"] = "If"
def __init__(self, tests, else_):
self.tests = tests
self.else_ = else_
def _getChildren(self):
return self.tests, self.else_
def __repr__(self):
return "If(%s, %s)" % (repr(self.tests), repr(self.else_))
class ListComp(Node):
nodes["listcomp"] = "ListComp"
def __init__(self, expr, quals):
self.expr = expr
self.quals = quals
def _getChildren(self):
return self.expr, self.quals
def __repr__(self):
return "ListComp(%s, %s)" % (repr(self.expr), repr(self.quals))
class Bitor(Node):
nodes["bitor"] = "Bitor"
def __init__(self, nodes):
self.nodes = nodes
def _getChildren(self):
return self.nodes,
def __repr__(self):
return "Bitor(%s)" % (repr(self.nodes),)
class Pass(Node):
nodes["pass"] = "Pass"
def __init__(self, ):
pass
def _getChildren(self):
return ()
def __repr__(self):
return "Pass()"
class Module(Node):
nodes["module"] = "Module"
def __init__(self, doc, node):
self.doc = doc
self.node = node
def _getChildren(self):
return self.doc, self.node
def __repr__(self):
return "Module(%s, %s)" % (repr(self.doc), repr(self.node))
class Global(Node):
nodes["global"] = "Global"
def __init__(self, names):
self.names = names
def _getChildren(self):
return self.names,
def __repr__(self):
return "Global(%s)" % (repr(self.names),)
class CallFunc(Node):
nodes["callfunc"] = "CallFunc"
def __init__(self, node, args, star_args = None, dstar_args = None):
self.node = node
self.args = args
self.star_args = star_args
self.dstar_args = dstar_args
def _getChildren(self):
return self.node, self.args, self.star_args, self.dstar_args
def __repr__(self):
return "CallFunc(%s, %s, %s, %s)" % (repr(self.node), repr(self.args), repr(self.star_args), repr(self.dstar_args))
class Printnl(Node):
nodes["printnl"] = "Printnl"
def __init__(self, nodes, dest):
self.nodes = nodes
self.dest = dest
def _getChildren(self):
return self.nodes, self.dest
def __repr__(self):
return "Printnl(%s, %s)" % (repr(self.nodes), repr(self.dest))
class Tuple(Node):
nodes["tuple"] = "Tuple"
def __init__(self, nodes):
self.nodes = nodes
def _getChildren(self):
return self.nodes,
def __repr__(self):
return "Tuple(%s)" % (repr(self.nodes),)
class Compare(Node):
nodes["compare"] = "Compare"
def __init__(self, expr, ops):
self.expr = expr
self.ops = ops
def _getChildren(self):
return self.expr, self.ops
def __repr__(self):
return "Compare(%s, %s)" % (repr(self.expr), repr(self.ops))
class And(Node):
nodes["and"] = "And"
def __init__(self, nodes):
self.nodes = nodes
def _getChildren(self):
return self.nodes,
def __repr__(self):
return "And(%s)" % (repr(self.nodes),)
class Lambda(Node):
nodes["lambda"] = "Lambda"
def __init__(self, argnames, defaults, flags, code):
self.argnames = argnames
self.defaults = defaults
self.flags = flags
self.code = code
self.varargs = self.kwargs = None
if flags & CO_VARARGS:
self.varargs = 1
if flags & CO_VARKEYWORDS:
self.kwargs = 1
def _getChildren(self):
return self.argnames, self.defaults, self.flags, self.code
def __repr__(self):
return "Lambda(%s, %s, %s, %s)" % (repr(self.argnames), repr(self.defaults), repr(self.flags), repr(self.code))
class Assign(Node):
nodes["assign"] = "Assign"
def __init__(self, nodes, expr):
self.nodes = nodes
self.expr = expr
def _getChildren(self):
return self.nodes, self.expr
def __repr__(self):
return "Assign(%s, %s)" % (repr(self.nodes), repr(self.expr))
class Sub(Node):
nodes["sub"] = "Sub"
def __init__(self, (left, right)):
self.left = left
self.right = right
def _getChildren(self):
return self.left, self.right
def __repr__(self):
return "Sub(%s, %s)" % (repr(self.left), repr(self.right))
class ListCompIf(Node):
nodes["listcompif"] = "ListCompIf"
def __init__(self, test):
self.test = test
def _getChildren(self):
return self.test,
def __repr__(self):
return "ListCompIf(%s)" % (repr(self.test),)
class Div(Node):
nodes["div"] = "Div"
def __init__(self, (left, right)):
self.left = left
self.right = right
def _getChildren(self):
return self.left, self.right
def __repr__(self):
return "Div(%s, %s)" % (repr(self.left), repr(self.right))
class Discard(Node):
nodes["discard"] = "Discard"
def __init__(self, expr):
self.expr = expr
def _getChildren(self):
return self.expr,
def __repr__(self):
return "Discard(%s)" % (repr(self.expr),)
class Backquote(Node):
nodes["backquote"] = "Backquote"
def __init__(self, expr):
self.expr = expr
def _getChildren(self):
return self.expr,
def __repr__(self):
return "Backquote(%s)" % (repr(self.expr),)
class RightShift(Node):
nodes["rightshift"] = "RightShift"
def __init__(self, (left, right)):
self.left = left
self.right = right
def _getChildren(self):
return self.left, self.right
def __repr__(self):
return "RightShift(%s, %s)" % (repr(self.left), repr(self.right))
class Continue(Node):
nodes["continue"] = "Continue"
def __init__(self, ):
pass
def _getChildren(self):
return ()
def __repr__(self):
return "Continue()"
class While(Node):
nodes["while"] = "While"
def __init__(self, test, body, else_):
self.test = test
self.body = body
self.else_ = else_
def _getChildren(self):
return self.test, self.body, self.else_
def __repr__(self):
return "While(%s, %s, %s)" % (repr(self.test), repr(self.body), repr(self.else_))
class AssName(Node):
nodes["assname"] = "AssName"
def __init__(self, name, flags):
self.name = name
self.flags = flags
def _getChildren(self):
return self.name, self.flags
def __repr__(self):
return "AssName(%s, %s)" % (repr(self.name), repr(self.flags))
class LeftShift(Node):
nodes["leftshift"] = "LeftShift"
def __init__(self, (left, right)):
self.left = left
self.right = right
def _getChildren(self):
return self.left, self.right
def __repr__(self):
return "LeftShift(%s, %s)" % (repr(self.left), repr(self.right))
class Mul(Node):
nodes["mul"] = "Mul"
def __init__(self, (left, right)):
self.left = left
self.right = right
def _getChildren(self):
return self.left, self.right
def __repr__(self):
return "Mul(%s, %s)" % (repr(self.left), repr(self.right))
class List(Node):
nodes["list"] = "List"
def __init__(self, nodes):
self.nodes = nodes
def _getChildren(self):
return self.nodes,
def __repr__(self):
return "List(%s)" % (repr(self.nodes),)
class AugAssign(Node):
nodes["augassign"] = "AugAssign"
def __init__(self, node, op, expr):
self.node = node
self.op = op
self.expr = expr
def _getChildren(self):
return self.node, self.op, self.expr
def __repr__(self):
return "AugAssign(%s, %s, %s)" % (repr(self.node), repr(self.op), repr(self.expr))
class Or(Node):
nodes["or"] = "Or"
def __init__(self, nodes):
self.nodes = nodes
def _getChildren(self):
return self.nodes,
def __repr__(self):
return "Or(%s)" % (repr(self.nodes),)
class Keyword(Node):
nodes["keyword"] = "Keyword"
def __init__(self, name, expr):
self.name = name
self.expr = expr
def _getChildren(self):
return self.name, self.expr
def __repr__(self):
return "Keyword(%s, %s)" % (repr(self.name), repr(self.expr))
class AssAttr(Node):
nodes["assattr"] = "AssAttr"
def __init__(self, expr, attrname, flags):
self.expr = expr
self.attrname = attrname
self.flags = flags
def _getChildren(self):
return self.expr, self.attrname, self.flags
def __repr__(self):
return "AssAttr(%s, %s, %s)" % (repr(self.expr), repr(self.attrname), repr(self.flags))
class Const(Node):
nodes["const"] = "Const"
def __init__(self, value):
self.value = value
def _getChildren(self):
return self.value,
def __repr__(self):
return "Const(%s)" % (repr(self.value),)
class Mod(Node):
nodes["mod"] = "Mod"
def __init__(self, (left, right)):
self.left = left
self.right = right
def _getChildren(self):
return self.left, self.right
def __repr__(self):
return "Mod(%s, %s)" % (repr(self.left), repr(self.right))
class Class(Node):
nodes["class"] = "Class"
def __init__(self, name, bases, doc, code):
self.name = name
self.bases = bases
self.doc = doc
self.code = code
def _getChildren(self):
return self.name, self.bases, self.doc, self.code
def __repr__(self):
return "Class(%s, %s, %s, %s)" % (repr(self.name), repr(self.bases), repr(self.doc), repr(self.code))
class Not(Node):
nodes["not"] = "Not"
def __init__(self, expr):
self.expr = expr
def _getChildren(self):
return self.expr,
def __repr__(self):
return "Not(%s)" % (repr(self.expr),)
class Bitxor(Node):
nodes["bitxor"] = "Bitxor"
def __init__(self, nodes):
self.nodes = nodes
def _getChildren(self):
return self.nodes,
def __repr__(self):
return "Bitxor(%s)" % (repr(self.nodes),)
class TryFinally(Node):
nodes["tryfinally"] = "TryFinally"
def __init__(self, body, final):
self.body = body
self.final = final
def _getChildren(self):
return self.body, self.final
def __repr__(self):
return "TryFinally(%s, %s)" % (repr(self.body), repr(self.final))
class Bitand(Node):
nodes["bitand"] = "Bitand"
def __init__(self, nodes):
self.nodes = nodes
def _getChildren(self):
return self.nodes,
def __repr__(self):
return "Bitand(%s)" % (repr(self.nodes),)
class Break(Node):
nodes["break"] = "Break"
def __init__(self, ):
pass
def _getChildren(self):
return ()
def __repr__(self):
return "Break()"
class Stmt(Node):
nodes["stmt"] = "Stmt"
def __init__(self, nodes):
self.nodes = nodes
def _getChildren(self):
return self.nodes,
def __repr__(self):
return "Stmt(%s)" % (repr(self.nodes),)
class Assert(Node):
nodes["assert"] = "Assert"
def __init__(self, test, fail):
self.test = test
self.fail = fail
def _getChildren(self):
return self.test, self.fail
def __repr__(self):
return "Assert(%s, %s)" % (repr(self.test), repr(self.fail))
class Exec(Node):
nodes["exec"] = "Exec"
def __init__(self, expr, locals, globals):
self.expr = expr
self.locals = locals
self.globals = globals
def _getChildren(self):
return self.expr, self.locals, self.globals
def __repr__(self):
return "Exec(%s, %s, %s)" % (repr(self.expr), repr(self.locals), repr(self.globals))
class Power(Node):
nodes["power"] = "Power"
def __init__(self, (left, right)):
self.left = left
self.right = right
def _getChildren(self):
return self.left, self.right
def __repr__(self):
return "Power(%s, %s)" % (repr(self.left), repr(self.right))
class Import(Node):
nodes["import"] = "Import"
def __init__(self, names):
self.names = names
def _getChildren(self):
return self.names,
def __repr__(self):
return "Import(%s)" % (repr(self.names),)
class Return(Node):
nodes["return"] = "Return"
def __init__(self, value):
self.value = value
def _getChildren(self):
return self.value,
def __repr__(self):
return "Return(%s)" % (repr(self.value),)
class Add(Node):
nodes["add"] = "Add"
def __init__(self, (left, right)):
self.left = left
self.right = right
def _getChildren(self):
return self.left, self.right
def __repr__(self):
return "Add(%s, %s)" % (repr(self.left), repr(self.right))
class Function(Node):
nodes["function"] = "Function"
def __init__(self, name, argnames, defaults, flags, doc, code):
self.name = name
self.argnames = argnames
self.defaults = defaults
self.flags = flags
self.doc = doc
self.code = code
self.varargs = self.kwargs = None
if flags & CO_VARARGS:
self.varargs = 1
if flags & CO_VARKEYWORDS:
self.kwargs = 1
def _getChildren(self):
return self.name, self.argnames, self.defaults, self.flags, self.doc, self.code
def __repr__(self):
return "Function(%s, %s, %s, %s, %s, %s)" % (repr(self.name), repr(self.argnames), repr(self.defaults), repr(self.flags), repr(self.doc), repr(self.code))
class TryExcept(Node):
nodes["tryexcept"] = "TryExcept"
def __init__(self, body, handlers, else_):
self.body = body
self.handlers = handlers
self.else_ = else_
def _getChildren(self):
return self.body, self.handlers, self.else_
def __repr__(self):
return "TryExcept(%s, %s, %s)" % (repr(self.body), repr(self.handlers), repr(self.else_))
class Subscript(Node):
nodes["subscript"] = "Subscript"
def __init__(self, expr, flags, subs):
self.expr = expr
self.flags = flags
self.subs = subs
def _getChildren(self):
return self.expr, self.flags, self.subs
def __repr__(self):
return "Subscript(%s, %s, %s)" % (repr(self.expr), repr(self.flags), repr(self.subs))
class Ellipsis(Node):
nodes["ellipsis"] = "Ellipsis"
def __init__(self, ):
pass
def _getChildren(self):
return ()
def __repr__(self):
return "Ellipsis()"
class Print(Node):
nodes["print"] = "Print"
def __init__(self, nodes, dest):
self.nodes = nodes
self.dest = dest
def _getChildren(self):
return self.nodes, self.dest
def __repr__(self):
return "Print(%s, %s)" % (repr(self.nodes), repr(self.dest))
class UnaryAdd(Node):
nodes["unaryadd"] = "UnaryAdd"
def __init__(self, expr):
self.expr = expr
def _getChildren(self):
return self.expr,
def __repr__(self):
return "UnaryAdd(%s)" % (repr(self.expr),)
class ListCompFor(Node):
nodes["listcompfor"] = "ListCompFor"
def __init__(self, assign, list, ifs):
self.assign = assign
self.list = list
self.ifs = ifs
def _getChildren(self):
return self.assign, self.list, self.ifs
def __repr__(self):
return "ListCompFor(%s, %s, %s)" % (repr(self.assign), repr(self.list), repr(self.ifs))
class Dict(Node):
nodes["dict"] = "Dict"
def __init__(self, items):
self.items = items
def _getChildren(self):
return self.items,
def __repr__(self):
return "Dict(%s)" % (repr(self.items),)
class Getattr(Node):
nodes["getattr"] = "Getattr"
def __init__(self, expr, attrname):
self.expr = expr
self.attrname = attrname
def _getChildren(self):
return self.expr, self.attrname
def __repr__(self):
return "Getattr(%s, %s)" % (repr(self.expr), repr(self.attrname))
class AssList(Node):
nodes["asslist"] = "AssList"
def __init__(self, nodes):
self.nodes = nodes
def _getChildren(self):
return self.nodes,
def __repr__(self):
return "AssList(%s)" % (repr(self.nodes),)
class UnarySub(Node):
nodes["unarysub"] = "UnarySub"
def __init__(self, expr):
self.expr = expr
def _getChildren(self):
return self.expr,
def __repr__(self):
return "UnarySub(%s)" % (repr(self.expr),)
class Sliceobj(Node):
nodes["sliceobj"] = "Sliceobj"
def __init__(self, nodes):
self.nodes = nodes
def _getChildren(self):
return self.nodes,
def __repr__(self):
return "Sliceobj(%s)" % (repr(self.nodes),)
class Invert(Node):
nodes["invert"] = "Invert"
def __init__(self, expr):
self.expr = expr
def _getChildren(self):
return self.expr,
def __repr__(self):
return "Invert(%s)" % (repr(self.expr),)
class Name(Node):
nodes["name"] = "Name"
def __init__(self, name):
self.name = name
def _getChildren(self):
return self.name,
def __repr__(self):
return "Name(%s)" % (repr(self.name),)
class AssTuple(Node):
nodes["asstuple"] = "AssTuple"
def __init__(self, nodes):
self.nodes = nodes
def _getChildren(self):
return self.nodes,
def __repr__(self):
return "AssTuple(%s)" % (repr(self.nodes),)
class For(Node):
nodes["for"] = "For"
def __init__(self, assign, list, body, else_):
self.assign = assign
self.list = list
self.body = body
self.else_ = else_
def _getChildren(self):
return self.assign, self.list, self.body, self.else_
def __repr__(self):
return "For(%s, %s, %s, %s)" % (repr(self.assign), repr(self.list), repr(self.body), repr(self.else_))
class Raise(Node):
nodes["raise"] = "Raise"
def __init__(self, expr1, expr2, expr3):
self.expr1 = expr1
self.expr2 = expr2
self.expr3 = expr3
def _getChildren(self):
return self.expr1, self.expr2, self.expr3
def __repr__(self):
return "Raise(%s, %s, %s)" % (repr(self.expr1), repr(self.expr2), repr(self.expr3))
class From(Node):
nodes["from"] = "From"
def __init__(self, modname, names):
self.modname = modname
self.names = names
def _getChildren(self):
return self.modname, self.names
def __repr__(self):
return "From(%s, %s)" % (repr(self.modname), repr(self.names))
class Slice(Node):
nodes["slice"] = "Slice"
def __init__(self, expr, flags, lower, upper):
self.expr = expr
self.flags = flags
self.lower = lower
self.upper = upper
def _getChildren(self):
return self.expr, self.flags, self.lower, self.upper
def __repr__(self):
return "Slice(%s, %s, %s, %s)" % (repr(self.expr), repr(self.flags), repr(self.lower), repr(self.upper))
klasses = globals()
for k in nodes.keys():
nodes[k] = klasses[nodes[k]]
Module: doc, node
Stmt: nodes
Function: name, argnames, defaults, flags, doc, code
Lambda: argnames, defaults, flags, code
Class: name, bases, doc, code
Pass:
Break:
Continue:
For: assign, list, body, else_
While: test, body, else_
If: tests, else_
Exec: expr, locals, globals
From: modname, names
Import: names
Raise: expr1, expr2, expr3
TryFinally: body, final
TryExcept: body, handlers, else_
Return: value
Const: value
Print: nodes, dest
Printnl: nodes, dest
Discard: expr
AugAssign: node, op, expr
Assign: nodes, expr
AssTuple: nodes
AssList: nodes
AssName: name, flags
AssAttr: expr, attrname, flags
ListComp: expr, quals
ListCompFor: assign, list, ifs
ListCompIf: test
List: nodes
Dict: items
Not: expr
Compare: expr, ops
Name: name
Global: names
Backquote: expr
Getattr: expr, attrname
CallFunc: node, args, star_args = None, dstar_args = None
Keyword: name, expr
Subscript: expr, flags, subs
Ellipsis:
Sliceobj: nodes
Slice: expr, flags, lower, upper
Assert: test, fail
Tuple: nodes
Or: nodes
And: nodes
Bitor: nodes
Bitxor: nodes
Bitand: nodes
LeftShift: (left, right)
RightShift: (left, right)
Add: (left, right)
Sub: (left, right)
Mul: (left, right)
Div: (left, right)
Mod: (left, right)
Power: (left, right)
UnaryAdd: expr
UnarySub: expr
Invert: expr
init(Function):
self.varargs = self.kwargs = None
if flags & CO_VARARGS:
self.varargs = 1
if flags & CO_VARKEYWORDS:
self.kwargs = 1
init(Lambda):
self.varargs = self.kwargs = None
if flags & CO_VARARGS:
self.varargs = 1
if flags & CO_VARKEYWORDS:
self.kwargs = 1
"""Generate ast module from specification"""
import fileinput
import getopt
import re
import sys
from StringIO import StringIO
SPEC = "ast.txt"
COMMA = ", "
def load_boilerplate(file):
f = open(file)
buf = f.read()
f.close()
i = buf.find('### ''PROLOGUE')
j = buf.find('### ''EPILOGUE')
pro = buf[i+12:j].strip()
epi = buf[j+12:].strip()
return pro, epi
def strip_default(arg):
"""Return the argname from an 'arg = default' string"""
i = arg.find('=')
if i == -1:
return arg
return arg[:i].strip()
class NodeInfo:
"""Each instance describes a specific AST node"""
def __init__(self, name, args):
self.name = name
self.args = args.strip()
self.argnames = self.get_argnames()
self.nargs = len(self.argnames)
self.children = COMMA.join(["self.%s" % c
for c in self.argnames])
self.init = []
def get_argnames(self):
if '(' in self.args:
i = self.args.find('(')
j = self.args.rfind(')')
args = self.args[i+1:j]
else:
args = self.args
return [strip_default(arg.strip())
for arg in args.split(',') if arg]
def gen_source(self):
buf = StringIO()
print >> buf, "class %s(Node):" % self.name
print >> buf, ' nodes["%s"] = "%s"' % (self.name.lower(), self.name)
self._gen_init(buf)
self._gen_getChildren(buf)
self._gen_repr(buf)
buf.seek(0, 0)
return buf.read()
def _gen_init(self, buf):
print >> buf, " def __init__(self, %s):" % self.args
if self.argnames:
for name in self.argnames:
print >> buf, " self.%s = %s" % (name, name)
else:
print >> buf, " pass"
if self.init:
print >> buf, "".join([" " + line for line in self.init])
def _gen_getChildren(self, buf):
print >> buf, " def _getChildren(self):"
if self.argnames:
if self.nargs == 1:
print >> buf, " return %s," % self.children
else:
print >> buf, " return %s" % self.children
else:
print >> buf, " return ()"
def _gen_repr(self, buf):
print >> buf, " def __repr__(self):"
if self.argnames:
fmt = COMMA.join(["%s"] * self.nargs)
vals = ["repr(self.%s)" % name for name in self.argnames]
vals = COMMA.join(vals)
if self.nargs == 1:
vals = vals + ","
print >> buf, ' return "%s(%s)" %% (%s)' % \
(self.name, fmt, vals)
else:
print >> buf, ' return "%s()"' % self.name
rx_init = re.compile('init\((.*)\):')
def parse_spec(file):
classes = {}
cur = None
for line in fileinput.input(file):
mo = rx_init.search(line)
if mo is None:
if cur is None:
# a normal entry
try:
name, args = line.split(':')
except ValueError:
continue
classes[name] = NodeInfo(name, args)
cur = None
else:
# some code for the __init__ method
cur.init.append(line)
else:
# some extra code for a Node's __init__ method
name = mo.group(1)
cur = classes[name]
return classes.values()
def main():
prologue, epilogue = load_boilerplate(sys.argv[-1])
print prologue
print
classes = parse_spec(SPEC)
for info in classes:
print info.gen_source()
print epilogue
if __name__ == "__main__":
main()
sys.exit(0)
### PROLOGUE
"""Python abstract syntax node definitions
This file is automatically generated.
"""
from types import TupleType, ListType
from consts import CO_VARARGS, CO_VARKEYWORDS
def flatten(list):
l = []
for elt in list:
t = type(elt)
if t is TupleType or t is ListType:
for elt2 in flatten(elt):
l.append(elt2)
else:
l.append(elt)
return l
def asList(nodes):
l = []
for item in nodes:
if hasattr(item, "asList"):
l.append(item.asList())
else:
t = type(item)
if t is TupleType or t is ListType:
l.append(tuple(asList(item)))
else:
l.append(item)
return l
nodes = {}
class Node:
lineno = None
def getType(self):
pass
def getChildren(self):
# XXX It would be better to generate flat values to begin with
return flatten(self._getChildren())
def asList(self):
return tuple(asList(self.getChildren()))
class EmptyNode(Node):
def __init__(self):
self.lineno = None
### EPILOGUE
klasses = globals()
for k in nodes.keys():
nodes[k] = klasses[nodes[k]]
# operation flags
OP_ASSIGN = 'OP_ASSIGN'
OP_DELETE = 'OP_DELETE'
OP_APPLY = 'OP_APPLY'
SC_LOCAL = 1
SC_GLOBAL = 2
SC_FREE = 3
SC_CELL = 4
SC_UNKNOWN = 5
CO_OPTIMIZED = 0x0001
CO_NEWLOCALS = 0x0002
CO_VARARGS = 0x0004
CO_VARKEYWORDS = 0x0008
CO_NESTED = 0x0010
"""Parser for future statements
"""
import ast
from visitor import walk
def is_future(stmt):
"""Return true if statement is a well-formed future statement"""
if not isinstance(stmt, ast.From):
return 0
if stmt.modname == "__future__":
return 1
else:
return 0
class FutureParser:
features = ("nested_scopes",)
def __init__(self):
self.found = {} # set
def visitModule(self, node):
if node.doc is None:
off = 0
else:
off = 1
stmt = node.node
for s in stmt.nodes[off:]:
if not self.check_stmt(s):
break
def check_stmt(self, stmt):
if is_future(stmt):
for name, asname in stmt.names:
if name in self.features:
self.found[name] = 1
else:
raise SyntaxError, \
"future feature %s is not defined" % name
stmt.valid_future = 1
return 1
return 0
def get_features(self):
"""Return list of features enabled by future statements"""
return self.found.keys()
class BadFutureParser:
"""Check for invalid future statements"""
def visitFrom(self, node):
if hasattr(node, 'valid_future'):
return
if node.modname != "__future__":
return
raise SyntaxError, "invalid future statement"
def find_futures(node):
p1 = FutureParser()
p2 = BadFutureParser()
walk(node, p1)
walk(node, p2)
return p1.get_features()
if __name__ == "__main__":
import sys
from transformer import parseFile
for file in sys.argv[1:]:
print file
tree = parseFile(file)
v = FutureParser()
walk(tree, v)
print v.found
print
import types
def flatten(tup):
elts = []
for elt in tup:
if type(elt) == types.TupleType:
elts = elts + flatten(elt)
else:
elts.append(elt)
return elts
class Set:
def __init__(self):
self.elts = {}
def __len__(self):
return len(self.elts)
def __contains__(self, elt):
return self.elts.has_key(elt)
def add(self, elt):
self.elts[elt] = elt
def elements(self):
return self.elts.keys()
def has_elt(self, elt):
return self.elts.has_key(elt)
def remove(self, elt):
del self.elts[elt]
def copy(self):
c = Set()
c.elts.update(self.elts)
return c
class Stack:
def __init__(self):
self.stack = []
self.pop = self.stack.pop
def __len__(self):
return len(self.stack)
def push(self, elt):
self.stack.append(elt)
def top(self):
return self.stack[-1]
def __getitem__(self, index): # needed by visitContinue()
return self.stack[index]
MANGLE_LEN = 256 # magic constant from compile.c
def mangle(name, klass):
if not name.startswith('__'):
return name
if len(name) + 2 >= MANGLE_LEN:
return name
if name.endswith('__'):
return name
try:
i = 0
while klass[i] == '_':
i = i + 1
except IndexError:
return name
klass = klass[i:]
tlen = len(klass) + len(name)
if tlen > MANGLE_LEN:
klass = klass[:MANGLE_LEN-tlen]
return "_%s%s" % (klass, name)
"""A flow graph representation for Python bytecode"""
from __future__ import nested_scopes
import dis
import new
import string
import sys
import types
import misc
from consts import CO_OPTIMIZED, CO_NEWLOCALS, CO_VARARGS, CO_VARKEYWORDS
def xxx_sort(l):
l = l[:]
def sorter(a, b):
return cmp(a.bid, b.bid)
l.sort(sorter)
return l
class FlowGraph:
def __init__(self):
self.current = self.entry = Block()
self.exit = Block("exit")
self.blocks = misc.Set()
self.blocks.add(self.entry)
self.blocks.add(self.exit)
def startBlock(self, block):
if self._debug:
if self.current:
print "end", repr(self.current)
print " next", self.current.next
print " ", self.current.get_children()
print repr(block)
self.current = block
def nextBlock(self, block=None):
# XXX think we need to specify when there is implicit transfer
# from one block to the next. might be better to represent this
# with explicit JUMP_ABSOLUTE instructions that are optimized
# out when they are unnecessary.
#
# I think this strategy works: each block has a child
# designated as "next" which is returned as the last of the
# children. because the nodes in a graph are emitted in
# reverse post order, the "next" block will always be emitted
# immediately after its parent.
# Worry: maintaining this invariant could be tricky
if block is None:
block = self.newBlock()
# Note: If the current block ends with an unconditional
# control transfer, then it is incorrect to add an implicit
# transfer to the block graph. The current code requires
# these edges to get the blocks emitted in the right order,
# however. :-( If a client needs to remove these edges, call
# pruneEdges().
self.current.addNext(block)
self.startBlock(block)
def newBlock(self):
b = Block()
self.blocks.add(b)
return b
def startExitBlock(self):
self.startBlock(self.exit)
_debug = 0
def _enable_debug(self):
self._debug = 1
def _disable_debug(self):
self._debug = 0
def emit(self, *inst):
if self._debug:
print "\t", inst
if inst[0] == 'RETURN_VALUE':
self.current.addOutEdge(self.exit)
if len(inst) == 2 and isinstance(inst[1], Block):
self.current.addOutEdge(inst[1])
self.current.emit(inst)
def getBlocksInOrder(self):
"""Return the blocks in reverse postorder
i.e. each node appears before all of its successors
"""
# XXX make sure every node that doesn't have an explicit next
# is set so that next points to exit
for b in self.blocks.elements():
if b is self.exit:
continue
if not b.next:
b.addNext(self.exit)
order = dfs_postorder(self.entry, {})
order.reverse()
self.fixupOrder(order, self.exit)
# hack alert
if not self.exit in order:
order.append(self.exit)
return order
def fixupOrder(self, blocks, default_next):
"""Fixup bad order introduced by DFS."""
# XXX This is a total mess. There must be a better way to get
# the code blocks in the right order.
self.fixupOrderHonorNext(blocks, default_next)
self.fixupOrderForward(blocks, default_next)
def fixupOrderHonorNext(self, blocks, default_next):
"""Fix one problem with DFS.
The DFS uses child block, but doesn't know about the special
"next" block. As a result, the DFS can order blocks so that a
block isn't next to the right block for implicit control
transfers.
"""
index = {}
for i in range(len(blocks)):
index[blocks[i]] = i
for i in range(0, len(blocks) - 1):
b = blocks[i]
n = blocks[i + 1]
if not b.next or b.next[0] == default_next or b.next[0] == n:
continue
# The blocks are in the wrong order. Find the chain of
# blocks to insert where they belong.
cur = b
chain = []
elt = cur
while elt.next and elt.next[0] != default_next:
chain.append(elt.next[0])
elt = elt.next[0]
# Now remove the blocks in the chain from the current
# block list, so that they can be re-inserted.
l = []
for b in chain:
assert index[b] > i
l.append((index[b], b))
l.sort()
l.reverse()
for j, b in l:
del blocks[index[b]]
# Insert the chain in the proper location
blocks[i:i + 1] = [cur] + chain
# Finally, re-compute the block indexes
for i in range(len(blocks)):
index[blocks[i]] = i
def fixupOrderForward(self, blocks, default_next):
"""Make sure all JUMP_FORWARDs jump forward"""
index = {}
chains = []
cur = []
for b in blocks:
index[b] = len(chains)
cur.append(b)
if b.next and b.next[0] == default_next:
chains.append(cur)
cur = []
chains.append(cur)
while 1:
constraints = []
for i in range(len(chains)):
l = chains[i]
for b in l:
for c in b.get_children():
if index[c] < i:
forward_p = 0
for inst in b.insts:
if inst[0] == 'JUMP_FORWARD':
if inst[1] == c:
forward_p = 1
if not forward_p:
continue
constraints.append((index[c], i))
if not constraints:
break
# XXX just do one for now
# do swaps to get things in the right order
goes_before, a_chain = constraints[0]
assert a_chain > goes_before
c = chains[a_chain]
chains.remove(c)
chains.insert(goes_before, c)
del blocks[:]
for c in chains:
for b in c:
blocks.append(b)
def getBlocks(self):
return self.blocks.elements()
def getRoot(self):
"""Return nodes appropriate for use with dominator"""
return self.entry
def getContainedGraphs(self):
l = []
for b in self.getBlocks():
l.extend(b.getContainedGraphs())
return l
def dfs_postorder(b, seen):
"""Depth-first search of tree rooted at b, return in postorder"""
order = []
seen[b] = b
for c in b.get_children():
if seen.has_key(c):
continue
order = order + dfs_postorder(c, seen)
order.append(b)
return order
class Block:
_count = 0
def __init__(self, label=''):
self.insts = []
self.inEdges = misc.Set()
self.outEdges = misc.Set()
self.label = label
self.bid = Block._count
self.next = []
Block._count = Block._count + 1
def __repr__(self):
if self.label:
return "<block %s id=%d>" % (self.label, self.bid)
else:
return "<block id=%d>" % (self.bid)
def __str__(self):
insts = map(str, self.insts)
return "<block %s %d:\n%s>" % (self.label, self.bid,
string.join(insts, '\n'))
def emit(self, inst):
op = inst[0]
if op[:4] == 'JUMP':
self.outEdges.add(inst[1])
self.insts.append(inst)
def getInstructions(self):
return self.insts
def addInEdge(self, block):
self.inEdges.add(block)
def addOutEdge(self, block):
self.outEdges.add(block)
def addNext(self, block):
self.next.append(block)
assert len(self.next) == 1, map(str, self.next)
_uncond_transfer = ('RETURN_VALUE', 'RAISE_VARARGS',
'JUMP_ABSOLUTE', 'JUMP_FORWARD', 'CONTINUE_LOOP')
def pruneNext(self):
"""Remove bogus edge for unconditional transfers
Each block has a next edge that accounts for implicit control
transfers, e.g. from a JUMP_IF_FALSE to the block that will be
executed if the test is true.
These edges must remain for the current assembler code to
work. If they are removed, the dfs_postorder gets things in
weird orders. However, they shouldn't be there for other
purposes, e.g. conversion to SSA form. This method will
remove the next edge when it follows an unconditional control
transfer.
"""
try:
op, arg = self.insts[-1]
except (IndexError, ValueError):
return
if op in self._uncond_transfer:
self.next = []
def get_children(self):
if self.next and self.next[0] in self.outEdges:
self.outEdges.remove(self.next[0])
return self.outEdges.elements() + self.next
def getContainedGraphs(self):
"""Return all graphs contained within this block.
For example, a MAKE_FUNCTION block will contain a reference to
the graph for the function body.
"""
contained = []
for inst in self.insts:
if len(inst) == 1:
continue
op = inst[1]
if hasattr(op, 'graph'):
contained.append(op.graph)
return contained
# flags for code objects
# the FlowGraph is transformed in place; it exists in one of these states
RAW = "RAW"
FLAT = "FLAT"
CONV = "CONV"
DONE = "DONE"
class PyFlowGraph(FlowGraph):
super_init = FlowGraph.__init__
def __init__(self, name, filename, args=(), optimized=0, klass=None):
self.super_init()
self.name = name
assert isinstance(filename, types.StringType)
self.filename = filename
self.docstring = None
self.args = args # XXX
self.argcount = getArgCount(args)
self.klass = klass
if optimized:
self.flags = CO_OPTIMIZED | CO_NEWLOCALS
else:
self.flags = 0
self.consts = []
self.names = []
# Free variables found by the symbol table scan, including
# variables used only in nested scopes, are included here.
self.freevars = []
self.cellvars = []
# The closure list is used to track the order of cell
# variables and free variables in the resulting code object.
# The offsets used by LOAD_CLOSURE/LOAD_DEREF refer to both
# kinds of variables.
self.closure = []
self.varnames = list(args) or []
for i in range(len(self.varnames)):
var = self.varnames[i]
if isinstance(var, TupleArg):
self.varnames[i] = var.getName()
self.stage = RAW
def setDocstring(self, doc):
self.docstring = doc
def setFlag(self, flag):
self.flags = self.flags | flag
if flag == CO_VARARGS:
self.argcount = self.argcount - 1
def checkFlag(self, flag):
if self.flags & flag:
return 1
def setFreeVars(self, names):
self.freevars = list(names)
def setCellVars(self, names):
self.cellvars = names
def getCode(self):
"""Get a Python code object"""
if self.stage == RAW:
self.computeStackDepth()
self.flattenGraph()
if self.stage == FLAT:
self.convertArgs()
if self.stage == CONV:
self.makeByteCode()
if self.stage == DONE:
return self.newCodeObject()
raise RuntimeError, "inconsistent PyFlowGraph state"
def dump(self, io=None):
if io:
save = sys.stdout
sys.stdout = io
pc = 0
for t in self.insts:
opname = t[0]
if opname == "SET_LINENO":
print
if len(t) == 1:
print "\t", "%3d" % pc, opname
pc = pc + 1
else:
print "\t", "%3d" % pc, opname, t[1]
pc = pc + 3
if io:
sys.stdout = save
def computeStackDepth(self):
"""Compute the max stack depth.
Approach is to compute the stack effect of each basic block.
Then find the path through the code with the largest total
effect.
"""
depth = {}
exit = None
for b in self.getBlocks():
depth[b] = findDepth(b.getInstructions())
seen = {}
def max_depth(b, d):
if seen.has_key(b):
return d
seen[b] = 1
d = d + depth[b]
children = b.get_children()
if children:
return max([max_depth(c, d) for c in children])
else:
if not b.label == "exit":
return max_depth(self.exit, d)
else:
return d
self.stacksize = max_depth(self.entry, 0)
def flattenGraph(self):
"""Arrange the blocks in order and resolve jumps"""
assert self.stage == RAW
self.insts = insts = []
pc = 0
begin = {}
end = {}
for b in self.getBlocksInOrder():
begin[b] = pc
for inst in b.getInstructions():
insts.append(inst)
if len(inst) == 1:
pc = pc + 1
else:
# arg takes 2 bytes
pc = pc + 3
end[b] = pc
pc = 0
for i in range(len(insts)):
inst = insts[i]
if len(inst) == 1:
pc = pc + 1
else:
pc = pc + 3
opname = inst[0]
if self.hasjrel.has_elt(opname):
oparg = inst[1]
offset = begin[oparg] - pc
insts[i] = opname, offset
elif self.hasjabs.has_elt(opname):
insts[i] = opname, begin[inst[1]]
self.stage = FLAT
hasjrel = misc.Set()
for i in dis.hasjrel:
hasjrel.add(dis.opname[i])
hasjabs = misc.Set()
for i in dis.hasjabs:
hasjabs.add(dis.opname[i])
def convertArgs(self):
"""Convert arguments from symbolic to concrete form"""
assert self.stage == FLAT
self.consts.insert(0, self.docstring)
self.sort_cellvars()
for i in range(len(self.insts)):
t = self.insts[i]
if len(t) == 2:
opname, oparg = t
conv = self._converters.get(opname, None)
if conv:
self.insts[i] = opname, conv(self, oparg)
self.stage = CONV
def sort_cellvars(self):
"""Sort cellvars in the order of varnames and prune from freevars.
"""
cells = {}
for name in self.cellvars:
cells[name] = 1
self.cellvars = [name for name in self.varnames
if cells.has_key(name)]
for name in self.cellvars:
del cells[name]
self.cellvars = self.cellvars + cells.keys()
self.closure = self.cellvars + self.freevars
def _lookupName(self, name, list):
"""Return index of name in list, appending if necessary
This routine uses a list instead of a dictionary, because a
dictionary can't store two different keys if the keys have the
same value but different types, e.g. 2 and 2L. The compiler
must treat these two separately, so it does an explicit type
comparison before comparing the values.
"""
t = type(name)
for i in range(len(list)):
if t == type(list[i]) and list[i] == name:
return i
end = len(list)
list.append(name)
return end
_converters = {}
def _convert_LOAD_CONST(self, arg):
if hasattr(arg, 'getCode'):
arg = arg.getCode()
return self._lookupName(arg, self.consts)
def _convert_LOAD_FAST(self, arg):
self._lookupName(arg, self.names)
return self._lookupName(arg, self.varnames)
_convert_STORE_FAST = _convert_LOAD_FAST
_convert_DELETE_FAST = _convert_LOAD_FAST
def _convert_LOAD_NAME(self, arg):
if self.klass is None:
self._lookupName(arg, self.varnames)
return self._lookupName(arg, self.names)
def _convert_NAME(self, arg):
if self.klass is None:
self._lookupName(arg, self.varnames)
return self._lookupName(arg, self.names)
_convert_STORE_NAME = _convert_NAME
_convert_DELETE_NAME = _convert_NAME
_convert_IMPORT_NAME = _convert_NAME
_convert_IMPORT_FROM = _convert_NAME
_convert_STORE_ATTR = _convert_NAME
_convert_LOAD_ATTR = _convert_NAME
_convert_DELETE_ATTR = _convert_NAME
_convert_LOAD_GLOBAL = _convert_NAME
_convert_STORE_GLOBAL = _convert_NAME
_convert_DELETE_GLOBAL = _convert_NAME
def _convert_DEREF(self, arg):
self._lookupName(arg, self.names)
self._lookupName(arg, self.varnames)
return self._lookupName(arg, self.closure)
_convert_LOAD_DEREF = _convert_DEREF
_convert_STORE_DEREF = _convert_DEREF
def _convert_LOAD_CLOSURE(self, arg):
self._lookupName(arg, self.varnames)
return self._lookupName(arg, self.closure)
_cmp = list(dis.cmp_op)
def _convert_COMPARE_OP(self, arg):
return self._cmp.index(arg)
# similarly for other opcodes...
for name, obj in locals().items():
if name[:9] == "_convert_":
opname = name[9:]
_converters[opname] = obj
del name, obj, opname
def makeByteCode(self):
assert self.stage == CONV
self.lnotab = lnotab = LineAddrTable()
for t in self.insts:
opname = t[0]
if len(t) == 1:
lnotab.addCode(self.opnum[opname])
else:
oparg = t[1]
if opname == "SET_LINENO":
lnotab.nextLine(oparg)
hi, lo = twobyte(oparg)
try:
lnotab.addCode(self.opnum[opname], lo, hi)
except ValueError:
print opname, oparg
print self.opnum[opname], lo, hi
raise
self.stage = DONE
opnum = {}
for num in range(len(dis.opname)):
opnum[dis.opname[num]] = num
del num
def newCodeObject(self):
assert self.stage == DONE
if (self.flags & CO_NEWLOCALS) == 0:
nlocals = 0
else:
nlocals = len(self.varnames)
argcount = self.argcount
if self.flags & CO_VARKEYWORDS:
argcount = argcount - 1
return new.code(argcount, nlocals, self.stacksize, self.flags,
self.lnotab.getCode(), self.getConsts(),
tuple(self.names), tuple(self.varnames),
self.filename, self.name, self.lnotab.firstline,
self.lnotab.getTable(), tuple(self.freevars),
tuple(self.cellvars))
def getConsts(self):
"""Return a tuple for the const slot of the code object
Must convert references to code (MAKE_FUNCTION) to code
objects recursively.
"""
l = []
for elt in self.consts:
if isinstance(elt, PyFlowGraph):
elt = elt.getCode()
l.append(elt)
return tuple(l)
def isJump(opname):
if opname[:4] == 'JUMP':
return 1
class TupleArg:
"""Helper for marking func defs with nested tuples in arglist"""
def __init__(self, count, names):
self.count = count
self.names = names
def __repr__(self):
return "TupleArg(%s, %s)" % (self.count, self.names)
def getName(self):
return ".%d" % self.count
def getArgCount(args):
argcount = len(args)
if args:
for arg in args:
if isinstance(arg, TupleArg):
numNames = len(misc.flatten(arg.names))
argcount = argcount - numNames
return argcount
def twobyte(val):
"""Convert an int argument into high and low bytes"""
assert type(val) == types.IntType
return divmod(val, 256)
class LineAddrTable:
"""lnotab
This class builds the lnotab, which is documented in compile.c.
Here's a brief recap:
For each SET_LINENO instruction after the first one, two bytes are
added to lnotab. (In some cases, multiple two-byte entries are
added.) The first byte is the distance in bytes between the
instruction for the last SET_LINENO and the current SET_LINENO.
The second byte is offset in line numbers. If either offset is
greater than 255, multiple two-byte entries are added -- see
compile.c for the delicate details.
"""
def __init__(self):
self.code = []
self.codeOffset = 0
self.firstline = 0
self.lastline = 0
self.lastoff = 0
self.lnotab = []
def addCode(self, *args):
for arg in args:
self.code.append(chr(arg))
self.codeOffset = self.codeOffset + len(args)
def nextLine(self, lineno):
if self.firstline == 0:
self.firstline = lineno
self.lastline = lineno
else:
# compute deltas
addr = self.codeOffset - self.lastoff
line = lineno - self.lastline
# Python assumes that lineno always increases with
# increasing bytecode address (lnotab is unsigned char).
# Depending on when SET_LINENO instructions are emitted
# this is not always true. Consider the code:
# a = (1,
# b)
# In the bytecode stream, the assignment to "a" occurs
# after the loading of "b". This works with the C Python
# compiler because it only generates a SET_LINENO instruction
# for the assignment.
if line > 0:
push = self.lnotab.append
while addr > 255:
push(255); push(0)
addr -= 255
while line > 255:
push(addr); push(255)
line -= 255
addr = 0
if addr > 0 or line > 0:
push(addr); push(line)
self.lastline = lineno
self.lastoff = self.codeOffset
def getCode(self):
return string.join(self.code, '')
def getTable(self):
return string.join(map(chr, self.lnotab), '')
class StackDepthTracker:
# XXX 1. need to keep track of stack depth on jumps
# XXX 2. at least partly as a result, this code is broken
def findDepth(self, insts, debug=0):
depth = 0
maxDepth = 0
for i in insts:
opname = i[0]
if debug:
print i,
delta = self.effect.get(opname, None)
if delta is not None:
depth = depth + delta
else:
# now check patterns
for pat, pat_delta in self.patterns:
if opname[:len(pat)] == pat:
delta = pat_delta
depth = depth + delta
break
# if we still haven't found a match
if delta is None:
meth = getattr(self, opname, None)
if meth is not None:
depth = depth + meth(i[1])
if depth > maxDepth:
maxDepth = depth
if debug:
print depth, maxDepth
return maxDepth
effect = {
'POP_TOP': -1,
'DUP_TOP': 1,
'SLICE+1': -1,
'SLICE+2': -1,
'SLICE+3': -2,
'STORE_SLICE+0': -1,
'STORE_SLICE+1': -2,
'STORE_SLICE+2': -2,
'STORE_SLICE+3': -3,
'DELETE_SLICE+0': -1,
'DELETE_SLICE+1': -2,
'DELETE_SLICE+2': -2,
'DELETE_SLICE+3': -3,
'STORE_SUBSCR': -3,
'DELETE_SUBSCR': -2,
# PRINT_EXPR?
'PRINT_ITEM': -1,
'RETURN_VALUE': -1,
'EXEC_STMT': -3,
'BUILD_CLASS': -2,
'STORE_NAME': -1,
'STORE_ATTR': -2,
'DELETE_ATTR': -1,
'STORE_GLOBAL': -1,
'BUILD_MAP': 1,
'COMPARE_OP': -1,
'STORE_FAST': -1,
'IMPORT_STAR': -1,
'IMPORT_NAME': 0,
'IMPORT_FROM': 1,
'LOAD_ATTR': 0, # unlike other loads
# close enough...
'SETUP_EXCEPT': 3,
'SETUP_FINALLY': 3,
'FOR_LOOP': 1,
}
# use pattern match
patterns = [
('BINARY_', -1),
('LOAD_', 1),
]
def UNPACK_SEQUENCE(self, count):
return count-1
def BUILD_TUPLE(self, count):
return -count+1
def BUILD_LIST(self, count):
return -count+1
def CALL_FUNCTION(self, argc):
hi, lo = divmod(argc, 256)
return -(lo + hi * 2)
def CALL_FUNCTION_VAR(self, argc):
return self.CALL_FUNCTION(argc)-1
def CALL_FUNCTION_KW(self, argc):
return self.CALL_FUNCTION(argc)-1
def CALL_FUNCTION_VAR_KW(self, argc):
return self.CALL_FUNCTION(argc)-2
def MAKE_FUNCTION(self, argc):
return -argc
def MAKE_CLOSURE(self, argc):
# XXX need to account for free variables too!
return -argc
def BUILD_SLICE(self, argc):
if argc == 2:
return -1
elif argc == 3:
return -2
def DUP_TOPX(self, argc):
return argc
findDepth = StackDepthTracker().findDepth
import imp
import os
import marshal
import stat
import string
import struct
import sys
import types
from cStringIO import StringIO
import ast
from transformer import parse
from visitor import walk
import pyassem, misc, future, symbols
from consts import SC_LOCAL, SC_GLOBAL, SC_FREE, SC_CELL
from consts import CO_VARARGS, CO_VARKEYWORDS, CO_NEWLOCALS, CO_NESTED
from pyassem import TupleArg
# Do we have Python 1.x or Python 2.x?
try:
VERSION = sys.version_info[0]
except AttributeError:
VERSION = 1
callfunc_opcode_info = {
# (Have *args, Have **args) : opcode
(0,0) : "CALL_FUNCTION",
(1,0) : "CALL_FUNCTION_VAR",
(0,1) : "CALL_FUNCTION_KW",
(1,1) : "CALL_FUNCTION_VAR_KW",
}
LOOP = 1
EXCEPT = 2
TRY_FINALLY = 3
END_FINALLY = 4
def compile(filename, display=0):
f = open(filename)
buf = f.read()
f.close()
mod = Module(buf, filename)
try:
mod.compile(display)
except SyntaxError:
raise
else:
f = open(filename + "c", "wb")
mod.dump(f)
f.close()
class Module:
def __init__(self, source, filename):
self.filename = filename
self.source = source
self.code = None
def compile(self, display=0):
tree = parse(self.source)
root, filename = os.path.split(self.filename)
if "nested_scopes" in future.find_futures(tree):
gen = NestedScopeModuleCodeGenerator(filename)
else:
gen = ModuleCodeGenerator(filename)
walk(tree, gen, 1)
if display:
import pprint
print pprint.pprint(tree)
self.code = gen.getCode()
def dump(self, f):
f.write(self.getPycHeader())
marshal.dump(self.code, f)
MAGIC = imp.get_magic()
def getPycHeader(self):
# compile.c uses marshal to write a long directly, with
# calling the interface that would also generate a 1-byte code
# to indicate the type of the value. simplest way to get the
# same effect is to call marshal and then skip the code.
mtime = os.stat(self.filename)[stat.ST_MTIME]
mtime = struct.pack('i', mtime)
return self.MAGIC + mtime
class LocalNameFinder:
"""Find local names in scope"""
def __init__(self, names=()):
self.names = misc.Set()
self.globals = misc.Set()
for name in names:
self.names.add(name)
# XXX list comprehensions and for loops
def getLocals(self):
for elt in self.globals.elements():
if self.names.has_elt(elt):
self.names.remove(elt)
return self.names
def visitDict(self, node):
pass
def visitGlobal(self, node):
for name in node.names:
self.globals.add(name)
def visitFunction(self, node):
self.names.add(node.name)
def visitLambda(self, node):
pass
def visitImport(self, node):
for name, alias in node.names:
self.names.add(alias or name)
def visitFrom(self, node):
for name, alias in node.names:
self.names.add(alias or name)
def visitClass(self, node):
self.names.add(node.name)
def visitAssName(self, node):
self.names.add(node.name)
def is_constant_false(node):
if isinstance(node, ast.Const):
if not node.value:
return 1
return 0
class CodeGenerator:
"""Defines basic code generator for Python bytecode
This class is an abstract base class. Concrete subclasses must
define an __init__() that defines self.graph and then calls the
__init__() defined in this class.
The concrete class must also define the class attributes
NameFinder, FunctionGen, and ClassGen. These attributes can be
defined in the initClass() method, which is a hook for
initializing these methods after all the classes have been
defined.
"""
optimized = 0 # is namespace access optimized?
__initialized = None
class_name = None # provide default for instance variable
def __init__(self, filename):
if self.__initialized is None:
self.initClass()
self.__class__.__initialized = 1
self.checkClass()
self.filename = filename
self.locals = misc.Stack()
self.setups = misc.Stack()
self.curStack = 0
self.maxStack = 0
self.last_lineno = None
self._setupGraphDelegation()
def initClass(self):
"""This method is called once for each class"""
def checkClass(self):
"""Verify that class is constructed correctly"""
try:
assert hasattr(self, 'graph')
assert getattr(self, 'NameFinder')
assert getattr(self, 'FunctionGen')
assert getattr(self, 'ClassGen')
except AssertionError, msg:
intro = "Bad class construction for %s" % self.__class__.__name__
raise AssertionError, intro
def _setupGraphDelegation(self):
self.emit = self.graph.emit
self.newBlock = self.graph.newBlock
self.startBlock = self.graph.startBlock
self.nextBlock = self.graph.nextBlock
self.setDocstring = self.graph.setDocstring
def getCode(self):
"""Return a code object"""
return self.graph.getCode()
def mangle(self, name):
if self.class_name is not None:
return misc.mangle(name, self.class_name)
else:
return name
def parseSymbols(self, tree):
s = symbols.SymbolVisitor()
walk(tree, s)
return s.scopes
# Next five methods handle name access
def isLocalName(self, name):
return self.locals.top().has_elt(name)
def storeName(self, name):
self._nameOp('STORE', name)
def loadName(self, name):
self._nameOp('LOAD', name)
def delName(self, name):
self._nameOp('DELETE', name)
def _nameOp(self, prefix, name):
name = self.mangle(name)
scope = self.scope.check_name(name)
if scope == SC_LOCAL:
if not self.optimized:
self.emit(prefix + '_NAME', name)
else:
self.emit(prefix + '_FAST', name)
elif scope == SC_GLOBAL:
if not self.optimized:
self.emit(prefix + '_NAME', name)
else:
self.emit(prefix + '_GLOBAL', name)
elif scope == SC_FREE or scope == SC_CELL:
self.emit(prefix + '_DEREF', name)
else:
raise RuntimeError, "unsupported scope for var %s: %d" % \
(name, scope)
def _implicitNameOp(self, prefix, name):
"""Emit name ops for names generated implicitly by for loops
The interpreter generates names that start with a period or
dollar sign. The symbol table ignores these names because
they aren't present in the program text.
"""
if self.optimized:
self.emit(prefix + '_FAST', name)
else:
self.emit(prefix + '_NAME', name)
def set_lineno(self, node, force=0):
"""Emit SET_LINENO if node has lineno attribute and it is
different than the last lineno emitted.
Returns true if SET_LINENO was emitted.
There are no rules for when an AST node should have a lineno
attribute. The transformer and AST code need to be reviewed
and a consistent policy implemented and documented. Until
then, this method works around missing line numbers.
"""
lineno = getattr(node, 'lineno', None)
if lineno is not None and (lineno != self.last_lineno
or force):
self.emit('SET_LINENO', lineno)
self.last_lineno = lineno
return 1
return 0
# The first few visitor methods handle nodes that generator new
# code objects. They use class attributes to determine what
# specialized code generators to use.
NameFinder = LocalNameFinder
FunctionGen = None
ClassGen = None
def visitModule(self, node):
self.scopes = self.parseSymbols(node)
self.scope = self.scopes[node]
self.emit('SET_LINENO', 0)
if node.doc:
self.emit('LOAD_CONST', node.doc)
self.storeName('__doc__')
lnf = walk(node.node, self.NameFinder(), verbose=0)
self.locals.push(lnf.getLocals())
self.visit(node.node)
self.emit('LOAD_CONST', None)
self.emit('RETURN_VALUE')
def visitFunction(self, node):
self._visitFuncOrLambda(node, isLambda=0)
if node.doc:
self.setDocstring(node.doc)
self.storeName(node.name)
def visitLambda(self, node):
self._visitFuncOrLambda(node, isLambda=1)
def _visitFuncOrLambda(self, node, isLambda=0):
gen = self.FunctionGen(node, self.filename, self.scopes, isLambda,
self.class_name)
walk(node.code, gen)
gen.finish()
self.set_lineno(node)
for default in node.defaults:
self.visit(default)
self.emit('LOAD_CONST', gen)
self.emit('MAKE_FUNCTION', len(node.defaults))
def visitClass(self, node):
gen = self.ClassGen(node, self.scopes, self.filename)
walk(node.code, gen)
gen.finish()
self.set_lineno(node)
self.emit('LOAD_CONST', node.name)
for base in node.bases:
self.visit(base)
self.emit('BUILD_TUPLE', len(node.bases))
self.emit('LOAD_CONST', gen)
self.emit('MAKE_FUNCTION', 0)
self.emit('CALL_FUNCTION', 0)
self.emit('BUILD_CLASS')
self.storeName(node.name)
# The rest are standard visitor methods
# The next few implement control-flow statements
def visitIf(self, node):
end = self.newBlock()
numtests = len(node.tests)
for i in range(numtests):
test, suite = node.tests[i]
if is_constant_false(test):
continue
self.set_lineno(test)
self.visit(test)
nextTest = self.newBlock()
self.emit('JUMP_IF_FALSE', nextTest)
self.nextBlock()
self.emit('POP_TOP')
self.visit(suite)
self.emit('JUMP_FORWARD', end)
self.startBlock(nextTest)
self.emit('POP_TOP')
if node.else_:
self.visit(node.else_)
self.nextBlock(end)
def visitWhile(self, node):
self.set_lineno(node)
loop = self.newBlock()
else_ = self.newBlock()
after = self.newBlock()
self.emit('SETUP_LOOP', after)
self.nextBlock(loop)
self.setups.push((LOOP, loop))
self.set_lineno(node, force=1)
self.visit(node.test)
self.emit('JUMP_IF_FALSE', else_ or after)
self.nextBlock()
self.emit('POP_TOP')
self.visit(node.body)
self.emit('JUMP_ABSOLUTE', loop)
self.startBlock(else_) # or just the POPs if not else clause
self.emit('POP_TOP')
self.emit('POP_BLOCK')
self.setups.pop()
if node.else_:
self.visit(node.else_)
self.nextBlock(after)
def visitFor(self, node):
start = self.newBlock()
anchor = self.newBlock()
after = self.newBlock()
self.setups.push((LOOP, start))
self.set_lineno(node)
self.emit('SETUP_LOOP', after)
self.visit(node.list)
self.visit(ast.Const(0))
self.nextBlock(start)
self.set_lineno(node, force=1)
self.emit('FOR_LOOP', anchor)
## self.nextBlock()
self.visit(node.assign)
self.visit(node.body)
self.emit('JUMP_ABSOLUTE', start)
## self.startBlock(anchor)
self.nextBlock(anchor)
self.emit('POP_BLOCK')
self.setups.pop()
if node.else_:
self.visit(node.else_)
self.nextBlock(after)
def visitBreak(self, node):
if not self.setups:
raise SyntaxError, "'break' outside loop (%s, %d)" % \
(self.filename, node.lineno)
self.set_lineno(node)
self.emit('BREAK_LOOP')
def visitContinue(self, node):
# XXX test_grammar.py, line 351
if not self.setups:
raise SyntaxError, "'continue' outside loop (%s, %d)" % \
(self.filename, node.lineno)
kind, block = self.setups.top()
if kind == LOOP:
self.set_lineno(node)
self.emit('JUMP_ABSOLUTE', block)
self.nextBlock()
elif kind == EXCEPT or kind == TRY_FINALLY:
self.set_lineno(node)
# find the block that starts the loop
top = len(self.setups)
while top > 0:
top = top - 1
kind, loop_block = self.setups[top]
if kind == LOOP:
break
if kind != LOOP:
raise SyntaxError, "'continue' outside loop (%s, %d)" % \
(self.filename, node.lineno)
self.emit('CONTINUE_LOOP', loop_block)
self.nextBlock()
elif kind == END_FINALLY:
msg = "'continue' not allowed inside 'finally' clause (%s, %d)"
raise SyntaxError, msg % (self.filename, node.lineno)
def visitTest(self, node, jump):
end = self.newBlock()
for child in node.nodes[:-1]:
self.visit(child)
self.emit(jump, end)
self.nextBlock()
self.emit('POP_TOP')
self.visit(node.nodes[-1])
self.nextBlock(end)
def visitAnd(self, node):
self.visitTest(node, 'JUMP_IF_FALSE')
def visitOr(self, node):
self.visitTest(node, 'JUMP_IF_TRUE')
def visitCompare(self, node):
self.visit(node.expr)
cleanup = self.newBlock()
for op, code in node.ops[:-1]:
self.visit(code)
self.emit('DUP_TOP')
self.emit('ROT_THREE')
self.emit('COMPARE_OP', op)
self.emit('JUMP_IF_FALSE', cleanup)
self.nextBlock()
self.emit('POP_TOP')
# now do the last comparison
if node.ops:
op, code = node.ops[-1]
self.visit(code)
self.emit('COMPARE_OP', op)
if len(node.ops) > 1:
end = self.newBlock()
self.emit('JUMP_FORWARD', end)
self.startBlock(cleanup)
self.emit('ROT_TWO')
self.emit('POP_TOP')
self.nextBlock(end)
# list comprehensions
__list_count = 0
def visitListComp(self, node):
self.set_lineno(node)
# setup list
append = "$append%d" % self.__list_count
self.__list_count = self.__list_count + 1
self.emit('BUILD_LIST', 0)
self.emit('DUP_TOP')
self.emit('LOAD_ATTR', 'append')
self._implicitNameOp('STORE', append)
stack = []
for i, for_ in zip(range(len(node.quals)), node.quals):
start, anchor = self.visit(for_)
cont = None
for if_ in for_.ifs:
if cont is None:
cont = self.newBlock()
self.visit(if_, cont)
stack.insert(0, (start, cont, anchor))
self._implicitNameOp('LOAD', append)
self.visit(node.expr)
self.emit('CALL_FUNCTION', 1)
self.emit('POP_TOP')
for start, cont, anchor in stack:
if cont:
skip_one = self.newBlock()
self.emit('JUMP_FORWARD', skip_one)
self.startBlock(cont)
self.emit('POP_TOP')
self.nextBlock(skip_one)
self.emit('JUMP_ABSOLUTE', start)
self.startBlock(anchor)
self._implicitNameOp('DELETE', append)
self.__list_count = self.__list_count - 1
def visitListCompFor(self, node):
start = self.newBlock()
anchor = self.newBlock()
self.visit(node.list)
self.visit(ast.Const(0))
self.nextBlock(start)
self.emit('SET_LINENO', node.lineno)
self.emit('FOR_LOOP', anchor)
self.nextBlock()
self.visit(node.assign)
return start, anchor
def visitListCompIf(self, node, branch):
self.set_lineno(node, force=1)
self.visit(node.test)
self.emit('JUMP_IF_FALSE', branch)
self.newBlock()
self.emit('POP_TOP')
# exception related
def visitAssert(self, node):
# XXX would be interesting to implement this via a
# transformation of the AST before this stage
end = self.newBlock()
self.set_lineno(node)
# XXX __debug__ and AssertionError appear to be special cases
# -- they are always loaded as globals even if there are local
# names. I guess this is a sort of renaming op.
self.emit('LOAD_GLOBAL', '__debug__')
self.emit('JUMP_IF_FALSE', end)
self.nextBlock()
self.emit('POP_TOP')
self.visit(node.test)
self.emit('JUMP_IF_TRUE', end)
self.nextBlock()
self.emit('POP_TOP')
self.emit('LOAD_GLOBAL', 'AssertionError')
if node.fail:
self.visit(node.fail)
self.emit('RAISE_VARARGS', 2)
else:
self.emit('RAISE_VARARGS', 1)
self.nextBlock(end)
self.emit('POP_TOP')
def visitRaise(self, node):
self.set_lineno(node)
n = 0
if node.expr1:
self.visit(node.expr1)
n = n + 1
if node.expr2:
self.visit(node.expr2)
n = n + 1
if node.expr3:
self.visit(node.expr3)
n = n + 1
self.emit('RAISE_VARARGS', n)
def visitTryExcept(self, node):
body = self.newBlock()
handlers = self.newBlock()
end = self.newBlock()
if node.else_:
lElse = self.newBlock()
else:
lElse = end
self.set_lineno(node)
self.emit('SETUP_EXCEPT', handlers)
self.nextBlock(body)
self.setups.push((EXCEPT, body))
self.visit(node.body)
self.emit('POP_BLOCK')
self.setups.pop()
self.emit('JUMP_FORWARD', lElse)
self.startBlock(handlers)
last = len(node.handlers) - 1
for i in range(len(node.handlers)):
expr, target, body = node.handlers[i]
self.set_lineno(expr)
if expr:
self.emit('DUP_TOP')
self.visit(expr)
self.emit('COMPARE_OP', 'exception match')
next = self.newBlock()
self.emit('JUMP_IF_FALSE', next)
self.nextBlock()
self.emit('POP_TOP')
self.emit('POP_TOP')
if target:
self.visit(target)
else:
self.emit('POP_TOP')
self.emit('POP_TOP')
self.visit(body)
self.emit('JUMP_FORWARD', end)
if expr:
self.nextBlock(next)
else:
self.nextBlock()
self.emit('POP_TOP')
self.emit('END_FINALLY')
if node.else_:
self.nextBlock(lElse)
self.visit(node.else_)
self.nextBlock(end)
def visitTryFinally(self, node):
body = self.newBlock()
final = self.newBlock()
self.set_lineno(node)
self.emit('SETUP_FINALLY', final)
self.nextBlock(body)
self.setups.push((TRY_FINALLY, body))
self.visit(node.body)
self.emit('POP_BLOCK')
self.setups.pop()
self.emit('LOAD_CONST', None)
self.nextBlock(final)
self.setups.push((END_FINALLY, final))
self.visit(node.final)
self.emit('END_FINALLY')
self.setups.pop()
# misc
def visitDiscard(self, node):
self.set_lineno(node)
self.visit(node.expr)
self.emit('POP_TOP')
def visitConst(self, node):
self.emit('LOAD_CONST', node.value)
def visitKeyword(self, node):
self.emit('LOAD_CONST', node.name)
self.visit(node.expr)
def visitGlobal(self, node):
# no code to generate
pass
def visitName(self, node):
self.set_lineno(node)
self.loadName(node.name)
def visitPass(self, node):
self.set_lineno(node)
def visitImport(self, node):
self.set_lineno(node)
for name, alias in node.names:
if VERSION > 1:
self.emit('LOAD_CONST', None)
self.emit('IMPORT_NAME', name)
mod = string.split(name, ".")[0]
self.storeName(alias or mod)
def visitFrom(self, node):
self.set_lineno(node)
fromlist = map(lambda (name, alias): name, node.names)
if VERSION > 1:
self.emit('LOAD_CONST', tuple(fromlist))
self.emit('IMPORT_NAME', node.modname)
for name, alias in node.names:
if VERSION > 1:
if name == '*':
self.namespace = 0
self.emit('IMPORT_STAR')
# There can only be one name w/ from ... import *
assert len(node.names) == 1
return
else:
self.emit('IMPORT_FROM', name)
self._resolveDots(name)
self.storeName(alias or name)
else:
self.emit('IMPORT_FROM', name)
self.emit('POP_TOP')
def _resolveDots(self, name):
elts = string.split(name, ".")
if len(elts) == 1:
return
for elt in elts[1:]:
self.emit('LOAD_ATTR', elt)
def visitGetattr(self, node):
self.visit(node.expr)
self.emit('LOAD_ATTR', self.mangle(node.attrname))
# next five implement assignments
def visitAssign(self, node):
self.set_lineno(node)
self.visit(node.expr)
dups = len(node.nodes) - 1
for i in range(len(node.nodes)):
elt = node.nodes[i]
if i < dups:
self.emit('DUP_TOP')
if isinstance(elt, ast.Node):
self.visit(elt)
def visitAssName(self, node):
if node.flags == 'OP_ASSIGN':
self.storeName(node.name)
elif node.flags == 'OP_DELETE':
self.set_lineno(node)
self.delName(node.name)
else:
print "oops", node.flags
def visitAssAttr(self, node):
self.visit(node.expr)
if node.flags == 'OP_ASSIGN':
self.emit('STORE_ATTR', self.mangle(node.attrname))
elif node.flags == 'OP_DELETE':
self.emit('DELETE_ATTR', self.mangle(node.attrname))
else:
print "warning: unexpected flags:", node.flags
print node
def _visitAssSequence(self, node, op='UNPACK_SEQUENCE'):
if findOp(node) != 'OP_DELETE':
self.emit(op, len(node.nodes))
for child in node.nodes:
self.visit(child)
if VERSION > 1:
visitAssTuple = _visitAssSequence
visitAssList = _visitAssSequence
else:
def visitAssTuple(self, node):
self._visitAssSequence(node, 'UNPACK_TUPLE')
def visitAssList(self, node):
self._visitAssSequence(node, 'UNPACK_LIST')
# augmented assignment
def visitAugAssign(self, node):
self.set_lineno(node)
aug_node = wrap_aug(node.node)
self.visit(aug_node, "load")
self.visit(node.expr)
self.emit(self._augmented_opcode[node.op])
self.visit(aug_node, "store")
_augmented_opcode = {
'+=' : 'INPLACE_ADD',
'-=' : 'INPLACE_SUBTRACT',
'*=' : 'INPLACE_MULTIPLY',
'/=' : 'INPLACE_DIVIDE',
'%=' : 'INPLACE_MODULO',
'**=': 'INPLACE_POWER',
'>>=': 'INPLACE_RSHIFT',
'<<=': 'INPLACE_LSHIFT',
'&=' : 'INPLACE_AND',
'^=' : 'INPLACE_XOR',
'|=' : 'INPLACE_OR',
}
def visitAugName(self, node, mode):
if mode == "load":
self.loadName(node.name)
elif mode == "store":
self.storeName(node.name)
def visitAugGetattr(self, node, mode):
if mode == "load":
self.visit(node.expr)
self.emit('DUP_TOP')
self.emit('LOAD_ATTR', self.mangle(node.attrname))
elif mode == "store":
self.emit('ROT_TWO')
self.emit('STORE_ATTR', self.mangle(node.attrname))
def visitAugSlice(self, node, mode):
if mode == "load":
self.visitSlice(node, 1)
elif mode == "store":
slice = 0
if node.lower:
slice = slice | 1
if node.upper:
slice = slice | 2
if slice == 0:
self.emit('ROT_TWO')
elif slice == 3:
self.emit('ROT_FOUR')
else:
self.emit('ROT_THREE')
self.emit('STORE_SLICE+%d' % slice)
def visitAugSubscript(self, node, mode):
if len(node.subs) > 1:
raise SyntaxError, "augmented assignment to tuple is not possible"
if mode == "load":
self.visitSubscript(node, 1)
elif mode == "store":
self.emit('ROT_THREE')
self.emit('STORE_SUBSCR')
def visitExec(self, node):
self.visit(node.expr)
if node.locals is None:
self.emit('LOAD_CONST', None)
else:
self.visit(node.locals)
if node.globals is None:
self.emit('DUP_TOP')
else:
self.visit(node.globals)
self.emit('EXEC_STMT')
def visitCallFunc(self, node):
pos = 0
kw = 0
self.set_lineno(node)
self.visit(node.node)
for arg in node.args:
self.visit(arg)
if isinstance(arg, ast.Keyword):
kw = kw + 1
else:
pos = pos + 1
if node.star_args is not None:
self.visit(node.star_args)
if node.dstar_args is not None:
self.visit(node.dstar_args)
have_star = node.star_args is not None
have_dstar = node.dstar_args is not None
opcode = callfunc_opcode_info[have_star, have_dstar]
self.emit(opcode, kw << 8 | pos)
def visitPrint(self, node, newline=0):
self.set_lineno(node)
if node.dest:
self.visit(node.dest)
for child in node.nodes:
if node.dest:
self.emit('DUP_TOP')
self.visit(child)
if node.dest:
self.emit('ROT_TWO')
self.emit('PRINT_ITEM_TO')
else:
self.emit('PRINT_ITEM')
if node.dest and not newline:
self.emit('POP_TOP')
def visitPrintnl(self, node):
self.visitPrint(node, newline=1)
if node.dest:
self.emit('PRINT_NEWLINE_TO')
else:
self.emit('PRINT_NEWLINE')
def visitReturn(self, node):
self.set_lineno(node)
self.visit(node.value)
self.emit('RETURN_VALUE')
# slice and subscript stuff
def visitSlice(self, node, aug_flag=None):
# aug_flag is used by visitAugSlice
self.visit(node.expr)
slice = 0
if node.lower:
self.visit(node.lower)
slice = slice | 1
if node.upper:
self.visit(node.upper)
slice = slice | 2
if aug_flag:
if slice == 0:
self.emit('DUP_TOP')
elif slice == 3:
self.emit('DUP_TOPX', 3)
else:
self.emit('DUP_TOPX', 2)
if node.flags == 'OP_APPLY':
self.emit('SLICE+%d' % slice)
elif node.flags == 'OP_ASSIGN':
self.emit('STORE_SLICE+%d' % slice)
elif node.flags == 'OP_DELETE':
self.emit('DELETE_SLICE+%d' % slice)
else:
print "weird slice", node.flags
raise
def visitSubscript(self, node, aug_flag=None):
self.visit(node.expr)
for sub in node.subs:
self.visit(sub)
if aug_flag:
self.emit('DUP_TOPX', 2)
if len(node.subs) > 1:
self.emit('BUILD_TUPLE', len(node.subs))
if node.flags == 'OP_APPLY':
self.emit('BINARY_SUBSCR')
elif node.flags == 'OP_ASSIGN':
self.emit('STORE_SUBSCR')
elif node.flags == 'OP_DELETE':
self.emit('DELETE_SUBSCR')
# binary ops
def binaryOp(self, node, op):
self.visit(node.left)
self.visit(node.right)
self.emit(op)
def visitAdd(self, node):
return self.binaryOp(node, 'BINARY_ADD')
def visitSub(self, node):
return self.binaryOp(node, 'BINARY_SUBTRACT')
def visitMul(self, node):
return self.binaryOp(node, 'BINARY_MULTIPLY')
def visitDiv(self, node):
return self.binaryOp(node, 'BINARY_DIVIDE')
def visitMod(self, node):
return self.binaryOp(node, 'BINARY_MODULO')
def visitPower(self, node):
return self.binaryOp(node, 'BINARY_POWER')
def visitLeftShift(self, node):
return self.binaryOp(node, 'BINARY_LSHIFT')
def visitRightShift(self, node):
return self.binaryOp(node, 'BINARY_RSHIFT')
# unary ops
def unaryOp(self, node, op):
self.visit(node.expr)
self.emit(op)
def visitInvert(self, node):
return self.unaryOp(node, 'UNARY_INVERT')
def visitUnarySub(self, node):
return self.unaryOp(node, 'UNARY_NEGATIVE')
def visitUnaryAdd(self, node):
return self.unaryOp(node, 'UNARY_POSITIVE')
def visitUnaryInvert(self, node):
return self.unaryOp(node, 'UNARY_INVERT')
def visitNot(self, node):
return self.unaryOp(node, 'UNARY_NOT')
def visitBackquote(self, node):
return self.unaryOp(node, 'UNARY_CONVERT')
# bit ops
def bitOp(self, nodes, op):
self.visit(nodes[0])
for node in nodes[1:]:
self.visit(node)
self.emit(op)
def visitBitand(self, node):
return self.bitOp(node.nodes, 'BINARY_AND')
def visitBitor(self, node):
return self.bitOp(node.nodes, 'BINARY_OR')
def visitBitxor(self, node):
return self.bitOp(node.nodes, 'BINARY_XOR')
# object constructors
def visitEllipsis(self, node):
self.emit('LOAD_CONST', Ellipsis)
def visitTuple(self, node):
self.set_lineno(node)
for elt in node.nodes:
self.visit(elt)
self.emit('BUILD_TUPLE', len(node.nodes))
def visitList(self, node):
self.set_lineno(node)
for elt in node.nodes:
self.visit(elt)
self.emit('BUILD_LIST', len(node.nodes))
def visitSliceobj(self, node):
for child in node.nodes:
self.visit(child)
self.emit('BUILD_SLICE', len(node.nodes))
def visitDict(self, node):
lineno = getattr(node, 'lineno', None)
if lineno:
self.emit('SET_LINENO', lineno)
self.emit('BUILD_MAP', 0)
for k, v in node.items:
lineno2 = getattr(node, 'lineno', None)
if lineno2 is not None and lineno != lineno2:
self.emit('SET_LINENO', lineno2)
lineno = lineno2
self.emit('DUP_TOP')
self.visit(v)
self.emit('ROT_TWO')
self.visit(k)
self.emit('STORE_SUBSCR')
class NestedScopeCodeGenerator(CodeGenerator):
__super_visitModule = CodeGenerator.visitModule
__super_visitClass = CodeGenerator.visitClass
__super__visitFuncOrLambda = CodeGenerator._visitFuncOrLambda
def parseSymbols(self, tree):
s = symbols.SymbolVisitor()
walk(tree, s)
return s.scopes
def visitModule(self, node):
self.scopes = self.parseSymbols(node)
self.scope = self.scopes[node]
self.__super_visitModule(node)
def _nameOp(self, prefix, name):
name = self.mangle(name)
scope = self.scope.check_name(name)
if scope == SC_LOCAL:
if not self.optimized:
self.emit(prefix + '_NAME', name)
else:
self.emit(prefix + '_FAST', name)
elif scope == SC_GLOBAL:
self.emit(prefix + '_GLOBAL', name)
elif scope == SC_FREE or scope == SC_CELL:
self.emit(prefix + '_DEREF', name)
else:
raise RuntimeError, "unsupported scope for var %s: %d" % \
(name, scope)
def _visitFuncOrLambda(self, node, isLambda=0):
gen = self.FunctionGen(node, self.filename, self.scopes, isLambda,
self.class_name)
walk(node.code, gen)
gen.finish()
self.set_lineno(node)
for default in node.defaults:
self.visit(default)
frees = gen.scope.get_free_vars()
if frees:
for name in frees:
self.emit('LOAD_CLOSURE', name)
self.emit('LOAD_CONST', gen)
self.emit('MAKE_CLOSURE', len(node.defaults))
else:
self.emit('LOAD_CONST', gen)
self.emit('MAKE_FUNCTION', len(node.defaults))
def visitClass(self, node):
gen = self.ClassGen(node, self.scopes, self.filename)
walk(node.code, gen)
gen.finish()
self.set_lineno(node)
self.emit('LOAD_CONST', node.name)
for base in node.bases:
self.visit(base)
self.emit('BUILD_TUPLE', len(node.bases))
frees = gen.scope.get_free_vars()
for name in frees:
self.emit('LOAD_CLOSURE', name)
self.emit('LOAD_CONST', gen)
if frees:
self.emit('MAKE_CLOSURE', 0)
else:
self.emit('MAKE_FUNCTION', 0)
self.emit('CALL_FUNCTION', 0)
self.emit('BUILD_CLASS')
self.storeName(node.name)
class LGBScopeMixin:
"""Defines initClass() for Python 2.1-compatible scoping"""
def initClass(self):
self.__class__.NameFinder = LocalNameFinder
self.__class__.FunctionGen = FunctionCodeGenerator
self.__class__.ClassGen = ClassCodeGenerator
class NestedScopeMixin:
"""Defines initClass() for nested scoping (Python 2.2-compatible)"""
def initClass(self):
self.__class__.NameFinder = LocalNameFinder
self.__class__.FunctionGen = NestedFunctionCodeGenerator
self.__class__.ClassGen = NestedClassCodeGenerator
class ModuleCodeGenerator(LGBScopeMixin, CodeGenerator):
__super_init = CodeGenerator.__init__
scopes = None
def __init__(self, filename):
self.graph = pyassem.PyFlowGraph("<module>", filename)
self.__super_init(filename)
class NestedScopeModuleCodeGenerator(NestedScopeMixin,
NestedScopeCodeGenerator):
__super_init = CodeGenerator.__init__
def __init__(self, filename):
self.graph = pyassem.PyFlowGraph("<module>", filename)
self.__super_init(filename)
self.graph.setFlag(CO_NESTED)
class AbstractFunctionCode:
optimized = 1
lambdaCount = 0
def __init__(self, func, filename, scopes, isLambda, class_name):
self.scopes = scopes
self.scope = scopes[func]
self.class_name = class_name
if isLambda:
klass = FunctionCodeGenerator
name = "<lambda.%d>" % klass.lambdaCount
klass.lambdaCount = klass.lambdaCount + 1
else:
name = func.name
args, hasTupleArg = generateArgList(func.argnames)
self.graph = pyassem.PyFlowGraph(name, filename, args,
optimized=1)
self.isLambda = isLambda
self.super_init(filename)
if not isLambda and func.doc:
self.setDocstring(func.doc)
lnf = walk(func.code, self.NameFinder(args), verbose=0)
self.locals.push(lnf.getLocals())
if func.varargs:
self.graph.setFlag(CO_VARARGS)
if func.kwargs:
self.graph.setFlag(CO_VARKEYWORDS)
self.set_lineno(func)
if hasTupleArg:
self.generateArgUnpack(func.argnames)
def finish(self):
self.graph.startExitBlock()
if not self.isLambda:
self.emit('LOAD_CONST', None)
self.emit('RETURN_VALUE')
def generateArgUnpack(self, args):
for i in range(len(args)):
arg = args[i]
if type(arg) == types.TupleType:
self.emit('LOAD_FAST', '.%d' % (i * 2))
self.unpackSequence(arg)
def unpackSequence(self, tup):
if VERSION > 1:
self.emit('UNPACK_SEQUENCE', len(tup))
else:
self.emit('UNPACK_TUPLE', len(tup))
for elt in tup:
if type(elt) == types.TupleType:
self.unpackSequence(elt)
else:
self._nameOp('STORE', elt)
unpackTuple = unpackSequence
class FunctionCodeGenerator(LGBScopeMixin, AbstractFunctionCode,
CodeGenerator):
super_init = CodeGenerator.__init__ # call be other init
scopes = None
class NestedFunctionCodeGenerator(AbstractFunctionCode,
NestedScopeMixin,
NestedScopeCodeGenerator):
super_init = NestedScopeCodeGenerator.__init__ # call be other init
__super_init = AbstractFunctionCode.__init__
def __init__(self, func, filename, scopes, isLambda, class_name):
self.scopes = scopes
self.scope = scopes[func]
self.__super_init(func, filename, scopes, isLambda, class_name)
self.graph.setFreeVars(self.scope.get_free_vars())
self.graph.setCellVars(self.scope.get_cell_vars())
self.graph.setFlag(CO_NESTED)
class AbstractClassCode:
def __init__(self, klass, scopes, filename):
assert isinstance(filename, types.StringType)
assert isinstance(scopes, types.DictType)
self.graph = pyassem.PyFlowGraph(klass.name, filename,
optimized=0)
self.super_init(filename)
lnf = walk(klass.code, self.NameFinder(), 0)
self.locals.push(lnf.getLocals())
self.graph.setFlag(CO_NEWLOCALS)
if klass.doc:
self.setDocstring(klass.doc)
def _nameOp(self, prefix, name):
# Class namespaces are always unoptimized
self.emit(prefix + '_NAME', name)
def finish(self):
self.graph.startExitBlock()
self.emit('LOAD_LOCALS')
self.emit('RETURN_VALUE')
class ClassCodeGenerator(LGBScopeMixin, AbstractClassCode, CodeGenerator):
super_init = CodeGenerator.__init__
scopes = None
__super_init = AbstractClassCode.__init__
def __init__(self, klass, scopes, filename):
self.scopes = scopes
self.scope = scopes[klass]
self.__super_init(klass, scopes, filename)
self.graph.setFreeVars(self.scope.get_free_vars())
self.graph.setCellVars(self.scope.get_cell_vars())
class NestedClassCodeGenerator(AbstractClassCode,
NestedScopeMixin,
NestedScopeCodeGenerator):
super_init = NestedScopeCodeGenerator.__init__ # call be other init
__super_init = AbstractClassCode.__init__
def __init__(self, klass, scopes, filename):
assert isinstance(filename, types.StringType)
self.scopes = scopes
self.scope = scopes[klass]
self.__super_init(klass, scopes, filename)
self.graph.setFreeVars(self.scope.get_free_vars())
self.graph.setCellVars(self.scope.get_cell_vars())
self.graph.setFlag(CO_NESTED)
def generateArgList(arglist):
"""Generate an arg list marking TupleArgs"""
args = []
extra = []
count = 0
for i in range(len(arglist)):
elt = arglist[i]
if type(elt) == types.StringType:
args.append(elt)
elif type(elt) == types.TupleType:
args.append(TupleArg(i * 2, elt))
extra.extend(misc.flatten(elt))
count = count + 1
else:
raise ValueError, "unexpect argument type:", elt
return args + extra, count
def findOp(node):
"""Find the op (DELETE, LOAD, STORE) in an AssTuple tree"""
v = OpFinder()
walk(node, v, verbose=0)
return v.op
class OpFinder:
def __init__(self):
self.op = None
def visitAssName(self, node):
if self.op is None:
self.op = node.flags
elif self.op != node.flags:
raise ValueError, "mixed ops in stmt"
visitAssAttr = visitAssName
visitSubscript = visitAssName
class Delegator:
"""Base class to support delegation for augmented assignment nodes
To generator code for augmented assignments, we use the following
wrapper classes. In visitAugAssign, the left-hand expression node
is visited twice. The first time the visit uses the normal method
for that node . The second time the visit uses a different method
that generates the appropriate code to perform the assignment.
These delegator classes wrap the original AST nodes in order to
support the variant visit methods.
"""
def __init__(self, obj):
self.obj = obj
def __getattr__(self, attr):
return getattr(self.obj, attr)
class AugGetattr(Delegator):
pass
class AugName(Delegator):
pass
class AugSlice(Delegator):
pass
class AugSubscript(Delegator):
pass
wrapper = {
ast.Getattr: AugGetattr,
ast.Name: AugName,
ast.Slice: AugSlice,
ast.Subscript: AugSubscript,
}
def wrap_aug(node):
return wrapper[node.__class__](node)
if __name__ == "__main__":
import sys
for file in sys.argv[1:]:
compile(file)
"""Module symbol-table generator"""
import ast
from consts import SC_LOCAL, SC_GLOBAL, SC_FREE, SC_CELL, SC_UNKNOWN
from misc import mangle
import types
import sys
MANGLE_LEN = 256
class Scope:
# XXX how much information do I need about each name?
def __init__(self, name, module, klass=None):
self.name = name
self.module = module
self.defs = {}
self.uses = {}
self.globals = {}
self.params = {}
self.frees = {}
self.cells = {}
self.children = []
# nested is true if the class could contain free variables,
# i.e. if it is nested within another function.
self.nested = None
self.klass = None
if klass is not None:
for i in range(len(klass)):
if klass[i] != '_':
self.klass = klass[i:]
break
def __repr__(self):
return "<%s: %s>" % (self.__class__.__name__, self.name)
def mangle(self, name):
if self.klass is None:
return name
return mangle(name, self.klass)
def add_def(self, name):
self.defs[self.mangle(name)] = 1
def add_use(self, name):
self.uses[self.mangle(name)] = 1
def add_global(self, name):
name = self.mangle(name)
if self.uses.has_key(name) or self.defs.has_key(name):
pass # XXX warn about global following def/use
if self.params.has_key(name):
raise SyntaxError, "%s in %s is global and parameter" % \
(name, self.name)
self.globals[name] = 1
self.module.add_def(name)
def add_param(self, name):
name = self.mangle(name)
self.defs[name] = 1
self.params[name] = 1
def get_names(self):
d = {}
d.update(self.defs)
d.update(self.uses)
d.update(self.globals)
return d.keys()
def add_child(self, child):
self.children.append(child)
def get_children(self):
return self.children
def DEBUG(self):
return
print >> sys.stderr, self.name, self.nested and "nested" or ""
print >> sys.stderr, "\tglobals: ", self.globals
print >> sys.stderr, "\tcells: ", self.cells
print >> sys.stderr, "\tdefs: ", self.defs
print >> sys.stderr, "\tuses: ", self.uses
print >> sys.stderr, "\tfrees:", self.frees
def check_name(self, name):
"""Return scope of name.
The scope of a name could be LOCAL, GLOBAL, FREE, or CELL.
"""
if self.globals.has_key(name):
return SC_GLOBAL
if self.cells.has_key(name):
return SC_CELL
if self.defs.has_key(name):
return SC_LOCAL
if self.nested and (self.frees.has_key(name) or
self.uses.has_key(name)):
return SC_FREE
if self.nested:
return SC_UNKNOWN
else:
return SC_GLOBAL
def get_free_vars(self):
if not self.nested:
return ()
free = {}
free.update(self.frees)
for name in self.uses.keys():
if not (self.defs.has_key(name) or
self.globals.has_key(name)):
free[name] = 1
return free.keys()
def handle_children(self):
for child in self.children:
frees = child.get_free_vars()
globals = self.add_frees(frees)
for name in globals:
child.force_global(name)
def force_global(self, name):
"""Force name to be global in scope.
Some child of the current node had a free reference to name.
When the child was processed, it was labelled a free
variable. Now that all its enclosing scope have been
processed, the name is known to be a global or builtin. So
walk back down the child chain and set the name to be global
rather than free.
Be careful to stop if a child does not think the name is
free.
"""
self.globals[name] = 1
if self.frees.has_key(name):
del self.frees[name]
for child in self.children:
if child.check_name(name) == SC_FREE:
child.force_global(name)
def add_frees(self, names):
"""Process list of free vars from nested scope.
Returns a list of names that are either 1) declared global in the
parent or 2) undefined in a top-level parent. In either case,
the nested scope should treat them as globals.
"""
child_globals = []
for name in names:
sc = self.check_name(name)
if self.nested:
if sc == SC_UNKNOWN or sc == SC_FREE \
or isinstance(self, ClassScope):
self.frees[name] = 1
elif sc == SC_GLOBAL:
child_globals.append(name)
elif isinstance(self, FunctionScope) and sc == SC_LOCAL:
self.cells[name] = 1
elif sc != SC_CELL:
child_globals.append(name)
else:
if sc == SC_LOCAL:
self.cells[name] = 1
elif sc != SC_CELL:
child_globals.append(name)
return child_globals
def get_cell_vars(self):
return self.cells.keys()
class ModuleScope(Scope):
__super_init = Scope.__init__
def __init__(self):
self.__super_init("global", self)
class FunctionScope(Scope):
pass
class LambdaScope(FunctionScope):
__super_init = Scope.__init__
__counter = 1
def __init__(self, module, klass=None):
i = self.__counter
self.__counter += 1
self.__super_init("lambda.%d" % i, module, klass)
class ClassScope(Scope):
__super_init = Scope.__init__
def __init__(self, name, module):
self.__super_init(name, module, name)
class SymbolVisitor:
def __init__(self):
self.scopes = {}
self.klass = None
# node that define new scopes
def visitModule(self, node):
scope = self.module = self.scopes[node] = ModuleScope()
self.visit(node.node, scope)
def visitFunction(self, node, parent):
parent.add_def(node.name)
for n in node.defaults:
self.visit(n, parent)
scope = FunctionScope(node.name, self.module, self.klass)
if parent.nested or isinstance(parent, FunctionScope):
scope.nested = 1
self.scopes[node] = scope
self._do_args(scope, node.argnames)
self.visit(node.code, scope)
self.handle_free_vars(scope, parent)
scope.DEBUG()
def visitLambda(self, node, parent):
for n in node.defaults:
self.visit(n, parent)
scope = LambdaScope(self.module, self.klass)
if parent.nested or isinstance(parent, FunctionScope):
scope.nested = 1
self.scopes[node] = scope
self._do_args(scope, node.argnames)
self.visit(node.code, scope)
self.handle_free_vars(scope, parent)
def _do_args(self, scope, args):
for name in args:
if type(name) == types.TupleType:
self._do_args(scope, name)
else:
scope.add_param(name)
def handle_free_vars(self, scope, parent):
parent.add_child(scope)
if scope.children:
scope.DEBUG()
scope.handle_children()
def visitClass(self, node, parent):
parent.add_def(node.name)
for n in node.bases:
self.visit(n, parent)
scope = ClassScope(node.name, self.module)
if parent.nested or isinstance(parent, FunctionScope):
scope.nested = 1
self.scopes[node] = scope
prev = self.klass
self.klass = node.name
self.visit(node.code, scope)
self.klass = prev
self.handle_free_vars(scope, parent)
# name can be a def or a use
# XXX a few calls and nodes expect a third "assign" arg that is
# true if the name is being used as an assignment. only
# expressions contained within statements may have the assign arg.
def visitName(self, node, scope, assign=0):
if assign:
scope.add_def(node.name)
else:
scope.add_use(node.name)
# operations that bind new names
def visitFor(self, node, scope):
self.visit(node.assign, scope, 1)
self.visit(node.list, scope)
self.visit(node.body, scope)
if node.else_:
self.visit(node.else_, scope)
def visitFrom(self, node, scope):
for name, asname in node.names:
if name == "*":
continue
scope.add_def(asname or name)
def visitImport(self, node, scope):
for name, asname in node.names:
i = name.find(".")
if i > -1:
name = name[:i]
scope.add_def(asname or name)
def visitGlobal(self, node, scope):
for name in node.names:
scope.add_global(name)
def visitAssign(self, node, scope):
"""Propagate assignment flag down to child nodes.
The Assign node doesn't itself contains the variables being
assigned to. Instead, the children in node.nodes are visited
with the assign flag set to true. When the names occur in
those nodes, they are marked as defs.
Some names that occur in an assignment target are not bound by
the assignment, e.g. a name occurring inside a slice. The
visitor handles these nodes specially; they do not propagate
the assign flag to their children.
"""
for n in node.nodes:
self.visit(n, scope, 1)
self.visit(node.expr, scope)
def visitAssName(self, node, scope, assign=1):
scope.add_def(node.name)
def visitAssAttr(self, node, scope, assign=0):
self.visit(node.expr, scope, 0)
def visitSubscript(self, node, scope, assign=0):
self.visit(node.expr, scope, 0)
for n in node.subs:
self.visit(n, scope, 0)
def visitSlice(self, node, scope, assign=0):
self.visit(node.expr, scope, 0)
if node.lower:
self.visit(node.lower, scope, 0)
if node.upper:
self.visit(node.upper, scope, 0)
def visitAugAssign(self, node, scope):
# If the LHS is a name, then this counts as assignment.
# Otherwise, it's just use.
self.visit(node.node, scope)
if isinstance(node.node, ast.Name):
self.visit(node.node, scope, 1) # XXX worry about this
self.visit(node.expr, scope)
# prune if statements if tests are false
_const_types = types.StringType, types.IntType, types.FloatType
def visitIf(self, node, scope):
for test, body in node.tests:
if isinstance(test, ast.Const):
if type(test.value) in self._const_types:
if not test.value:
continue
self.visit(test, scope)
self.visit(body, scope)
if node.else_:
self.visit(node.else_, scope)
def sort(l):
l = l[:]
l.sort()
return l
def list_eq(l1, l2):
return sort(l1) == sort(l2)
if __name__ == "__main__":
import sys
from transformer import parseFile
from visitor import walk
import symtable
def get_names(syms):
return [s for s in [s.get_name() for s in syms.get_symbols()]
if not (s.startswith('_[') or s.startswith('.'))]
for file in sys.argv[1:]:
print file
f = open(file)
buf = f.read()
f.close()
syms = symtable.symtable(buf, file, "exec")
mod_names = get_names(syms)
tree = parseFile(file)
s = SymbolVisitor()
walk(tree, s)
# compare module-level symbols
names2 = s.scopes[tree].get_names()
if not list_eq(mod_names, names2):
print
print "oops", file
print sort(mod_names)
print sort(names2)
sys.exit(-1)
d = {}
d.update(s.scopes)
del d[tree]
scopes = d.values()
del d
for s in syms.get_symbols():
if s.is_namespace():
l = [sc for sc in scopes
if sc.name == s.get_name()]
if len(l) > 1:
print "skipping", s.get_name()
else:
if not list_eq(get_names(s.get_namespace()),
l[0].get_names()):
print s.get_name()
print sort(get_names(s.get_namespace()))
print sort(l[0].get_names())
sys.exit(-1)
"""Parse tree transformation module.
Transforms Python source code into an abstract syntax tree (AST)
defined in the ast module.
The simplest ways to invoke this module are via parse and parseFile.
parse(buf) -> AST
parseFile(path) -> AST
"""
# Original version written by Greg Stein (gstein@lyra.org)
# and Bill Tutt (rassilon@lima.mudlib.org)
# February 1997.
#
# Modifications and improvements for Python 2.0 by Jeremy Hylton and
# Mark Hammond
# Portions of this file are:
# Copyright (C) 1997-1998 Greg Stein. All Rights Reserved.
#
# This module is provided under a BSD-ish license. See
# http://www.opensource.org/licenses/bsd-license.html
# and replace OWNER, ORGANIZATION, and YEAR as appropriate.
from ast import *
import parser
# Care must be taken to use only symbols and tokens defined in Python
# 1.5.2 for code branches executed in 1.5.2
import symbol
import token
import string
import sys
error = 'walker.error'
from consts import CO_VARARGS, CO_VARKEYWORDS
from consts import OP_ASSIGN, OP_DELETE, OP_APPLY
def parseFile(path):
f = open(path)
src = f.read()
f.close()
return parse(src)
def parse(buf):
return Transformer().parsesuite(buf)
def asList(nodes):
l = []
for item in nodes:
if hasattr(item, "asList"):
l.append(item.asList())
else:
if type(item) is type( (None, None) ):
l.append(tuple(asList(item)))
elif type(item) is type( [] ):
l.append(asList(item))
else:
l.append(item)
return l
def Node(*args):
kind = args[0]
if nodes.has_key(kind):
try:
return apply(nodes[kind], args[1:])
except TypeError:
print nodes[kind], len(args), args
raise
else:
raise error, "Can't find appropriate Node type: %s" % str(args)
#return apply(ast.Node, args)
class Transformer:
"""Utility object for transforming Python parse trees.
Exposes the following methods:
tree = transform(ast_tree)
tree = parsesuite(text)
tree = parseexpr(text)
tree = parsefile(fileob | filename)
"""
def __init__(self):
self._dispatch = {}
for value, name in symbol.sym_name.items():
if hasattr(self, name):
self._dispatch[value] = getattr(self, name)
self._dispatch[token.NEWLINE] = self.com_NEWLINE
self._atom_dispatch = {token.LPAR: self.atom_lpar,
token.LSQB: self.atom_lsqb,
token.LBRACE: self.atom_lbrace,
token.BACKQUOTE: self.atom_backquote,
token.NUMBER: self.atom_number,
token.STRING: self.atom_string,
token.NAME: self.atom_name,
}
def transform(self, tree):
"""Transform an AST into a modified parse tree."""
if type(tree) != type(()) and type(tree) != type([]):
tree = parser.ast2tuple(tree,1)
return self.compile_node(tree)
def parsesuite(self, text):
"""Return a modified parse tree for the given suite text."""
# Hack for handling non-native line endings on non-DOS like OSs.
text = string.replace(text, '\x0d', '')
return self.transform(parser.suite(text))
def parseexpr(self, text):
"""Return a modified parse tree for the given expression text."""
return self.transform(parser.expr(text))
def parsefile(self, file):
"""Return a modified parse tree for the contents of the given file."""
if type(file) == type(''):
file = open(file)
return self.parsesuite(file.read())
# --------------------------------------------------------------
#
# PRIVATE METHODS
#
def compile_node(self, node):
### emit a line-number node?
n = node[0]
if n == symbol.single_input:
return self.single_input(node[1:])
if n == symbol.file_input:
return self.file_input(node[1:])
if n == symbol.eval_input:
return self.eval_input(node[1:])
if n == symbol.lambdef:
return self.lambdef(node[1:])
if n == symbol.funcdef:
return self.funcdef(node[1:])
if n == symbol.classdef:
return self.classdef(node[1:])
raise error, ('unexpected node type', n)
def single_input(self, node):
### do we want to do anything about being "interactive" ?
# NEWLINE | simple_stmt | compound_stmt NEWLINE
n = node[0][0]
if n != token.NEWLINE:
return self.com_stmt(node[0])
return Pass()
def file_input(self, nodelist):
doc = self.get_docstring(nodelist, symbol.file_input)
stmts = []
for node in nodelist:
if node[0] != token.ENDMARKER and node[0] != token.NEWLINE:
self.com_append_stmt(stmts, node)
return Module(doc, Stmt(stmts))
def eval_input(self, nodelist):
# from the built-in function input()
### is this sufficient?
return self.com_node(nodelist[0])
def funcdef(self, nodelist):
# funcdef: 'def' NAME parameters ':' suite
# parameters: '(' [varargslist] ')'
lineno = nodelist[1][2]
name = nodelist[1][1]
args = nodelist[2][2]
if args[0] == symbol.varargslist:
names, defaults, flags = self.com_arglist(args[1:])
else:
names = defaults = ()
flags = 0
doc = self.get_docstring(nodelist[4])
# code for function
code = self.com_node(nodelist[4])
if doc is not None:
assert isinstance(code, Stmt)
assert isinstance(code.nodes[0], Discard)
del code.nodes[0]
n = Function(name, names, defaults, flags, doc, code)
n.lineno = lineno
return n
def lambdef(self, nodelist):
# lambdef: 'lambda' [varargslist] ':' test
if nodelist[2][0] == symbol.varargslist:
names, defaults, flags = self.com_arglist(nodelist[2][1:])
else:
names = defaults = ()
flags = 0
# code for lambda
code = self.com_node(nodelist[-1])
n = Lambda(names, defaults, flags, code)
n.lineno = nodelist[1][2]
return n
def classdef(self, nodelist):
# classdef: 'class' NAME ['(' testlist ')'] ':' suite
name = nodelist[1][1]
doc = self.get_docstring(nodelist[-1])
if nodelist[2][0] == token.COLON:
bases = []
else:
bases = self.com_bases(nodelist[3])
# code for class
code = self.com_node(nodelist[-1])
n = Class(name, bases, doc, code)
n.lineno = nodelist[1][2]
return n
def stmt(self, nodelist):
return self.com_stmt(nodelist[0])
small_stmt = stmt
flow_stmt = stmt
compound_stmt = stmt
def simple_stmt(self, nodelist):
# small_stmt (';' small_stmt)* [';'] NEWLINE
stmts = []
for i in range(0, len(nodelist), 2):
self.com_append_stmt(stmts, nodelist[i])
return Stmt(stmts)
def parameters(self, nodelist):
raise error
def varargslist(self, nodelist):
raise error
def fpdef(self, nodelist):
raise error
def fplist(self, nodelist):
raise error
def dotted_name(self, nodelist):
raise error
def comp_op(self, nodelist):
raise error
def trailer(self, nodelist):
raise error
def sliceop(self, nodelist):
raise error
def argument(self, nodelist):
raise error
# --------------------------------------------------------------
#
# STATEMENT NODES (invoked by com_node())
#
def expr_stmt(self, nodelist):
# augassign testlist | testlist ('=' testlist)*
exprNode = self.com_node(nodelist[-1])
if len(nodelist) == 1:
n = Discard(exprNode)
n.lineno = exprNode.lineno
return n
if nodelist[1][0] == token.EQUAL:
nodes = []
for i in range(0, len(nodelist) - 2, 2):
nodes.append(self.com_assign(nodelist[i], OP_ASSIGN))
n = Assign(nodes, exprNode)
n.lineno = nodelist[1][2]
else:
lval = self.com_augassign(nodelist[0])
op = self.com_augassign_op(nodelist[1])
n = AugAssign(lval, op[1], exprNode)
n.lineno = op[2]
return n
def print_stmt(self, nodelist):
# print ([ test (',' test)* [','] ] | '>>' test [ (',' test)+ [','] ])
items = []
if len(nodelist) == 1:
start = 1
dest = None
elif nodelist[1][0] == token.RIGHTSHIFT:
assert len(nodelist) == 3 \
or nodelist[3][0] == token.COMMA
dest = self.com_node(nodelist[2])
start = 4
else:
dest = None
start = 1
for i in range(start, len(nodelist), 2):
items.append(self.com_node(nodelist[i]))
if nodelist[-1][0] == token.COMMA:
n = Print(items, dest)
n.lineno = nodelist[0][2]
return n
n = Printnl(items, dest)
n.lineno = nodelist[0][2]
return n
def del_stmt(self, nodelist):
return self.com_assign(nodelist[1], OP_DELETE)
def pass_stmt(self, nodelist):
n = Pass()
n.lineno = nodelist[0][2]
return n
def break_stmt(self, nodelist):
n = Break()
n.lineno = nodelist[0][2]
return n
def continue_stmt(self, nodelist):
n = Continue()
n.lineno = nodelist[0][2]
return n
def return_stmt(self, nodelist):
# return: [testlist]
if len(nodelist) < 2:
n = Return(Const(None))
n.lineno = nodelist[0][2]
return n
n = Return(self.com_node(nodelist[1]))
n.lineno = nodelist[0][2]
return n
def raise_stmt(self, nodelist):
# raise: [test [',' test [',' test]]]
if len(nodelist) > 5:
expr3 = self.com_node(nodelist[5])
else:
expr3 = None
if len(nodelist) > 3:
expr2 = self.com_node(nodelist[3])
else:
expr2 = None
if len(nodelist) > 1:
expr1 = self.com_node(nodelist[1])
else:
expr1 = None
n = Raise(expr1, expr2, expr3)
n.lineno = nodelist[0][2]
return n
def import_stmt(self, nodelist):
# import_stmt: 'import' dotted_as_name (',' dotted_as_name)* |
# from: 'from' dotted_name 'import'
# ('*' | import_as_name (',' import_as_name)*)
if nodelist[0][1] == 'from':
names = []
if nodelist[3][0] == token.NAME:
for i in range(3, len(nodelist), 2):
names.append((nodelist[i][1], None))
else:
for i in range(3, len(nodelist), 2):
names.append(self.com_import_as_name(nodelist[i]))
n = From(self.com_dotted_name(nodelist[1]), names)
n.lineno = nodelist[0][2]
return n
if nodelist[1][0] == symbol.dotted_name:
names = [(self.com_dotted_name(nodelist[1][1:]), None)]
else:
names = []
for i in range(1, len(nodelist), 2):
names.append(self.com_dotted_as_name(nodelist[i]))
n = Import(names)
n.lineno = nodelist[0][2]
return n
def global_stmt(self, nodelist):
# global: NAME (',' NAME)*
names = []
for i in range(1, len(nodelist), 2):
names.append(nodelist[i][1])
n = Global(names)
n.lineno = nodelist[0][2]
return n
def exec_stmt(self, nodelist):
# exec_stmt: 'exec' expr ['in' expr [',' expr]]
expr1 = self.com_node(nodelist[1])
if len(nodelist) >= 4:
expr2 = self.com_node(nodelist[3])
if len(nodelist) >= 6:
expr3 = self.com_node(nodelist[5])
else:
expr3 = None
else:
expr2 = expr3 = None
n = Exec(expr1, expr2, expr3)
n.lineno = nodelist[0][2]
return n
def assert_stmt(self, nodelist):
# 'assert': test, [',' test]
expr1 = self.com_node(nodelist[1])
if (len(nodelist) == 4):
expr2 = self.com_node(nodelist[3])
else:
expr2 = None
n = Assert(expr1, expr2)
n.lineno = nodelist[0][2]
return n
def if_stmt(self, nodelist):
# if: test ':' suite ('elif' test ':' suite)* ['else' ':' suite]
tests = []
for i in range(0, len(nodelist) - 3, 4):
testNode = self.com_node(nodelist[i + 1])
suiteNode = self.com_node(nodelist[i + 3])
tests.append((testNode, suiteNode))
if len(nodelist) % 4 == 3:
elseNode = self.com_node(nodelist[-1])
## elseNode.lineno = nodelist[-1][1][2]
else:
elseNode = None
n = If(tests, elseNode)
n.lineno = nodelist[0][2]
return n
def while_stmt(self, nodelist):
# 'while' test ':' suite ['else' ':' suite]
testNode = self.com_node(nodelist[1])
bodyNode = self.com_node(nodelist[3])
if len(nodelist) > 4:
elseNode = self.com_node(nodelist[6])
else:
elseNode = None
n = While(testNode, bodyNode, elseNode)
n.lineno = nodelist[0][2]
return n
def for_stmt(self, nodelist):
# 'for' exprlist 'in' exprlist ':' suite ['else' ':' suite]
assignNode = self.com_assign(nodelist[1], OP_ASSIGN)
listNode = self.com_node(nodelist[3])
bodyNode = self.com_node(nodelist[5])
if len(nodelist) > 8:
elseNode = self.com_node(nodelist[8])
else:
elseNode = None
n = For(assignNode, listNode, bodyNode, elseNode)
n.lineno = nodelist[0][2]
return n
def try_stmt(self, nodelist):
# 'try' ':' suite (except_clause ':' suite)+ ['else' ':' suite]
# | 'try' ':' suite 'finally' ':' suite
if nodelist[3][0] != symbol.except_clause:
return self.com_try_finally(nodelist)
return self.com_try_except(nodelist)
def suite(self, nodelist):
# simple_stmt | NEWLINE INDENT NEWLINE* (stmt NEWLINE*)+ DEDENT
if len(nodelist) == 1:
return self.com_stmt(nodelist[0])
stmts = []
for node in nodelist:
if node[0] == symbol.stmt:
self.com_append_stmt(stmts, node)
return Stmt(stmts)
# --------------------------------------------------------------
#
# EXPRESSION NODES (invoked by com_node())
#
def testlist(self, nodelist):
# testlist: expr (',' expr)* [',']
# exprlist: expr (',' expr)* [',']
return self.com_binary(Tuple, nodelist)
exprlist = testlist
def test(self, nodelist):
# and_test ('or' and_test)* | lambdef
if len(nodelist) == 1 and nodelist[0][0] == symbol.lambdef:
return self.lambdef(nodelist[0])
return self.com_binary(Or, nodelist)
def and_test(self, nodelist):
# not_test ('and' not_test)*
return self.com_binary(And, nodelist)
def not_test(self, nodelist):
# 'not' not_test | comparison
result = self.com_node(nodelist[-1])
if len(nodelist) == 2:
n = Not(result)
n.lineno = nodelist[0][2]
return n
return result
def comparison(self, nodelist):
# comparison: expr (comp_op expr)*
node = self.com_node(nodelist[0])
if len(nodelist) == 1:
return node
results = []
for i in range(2, len(nodelist), 2):
nl = nodelist[i-1]
# comp_op: '<' | '>' | '=' | '>=' | '<=' | '<>' | '!=' | '=='
# | 'in' | 'not' 'in' | 'is' | 'is' 'not'
n = nl[1]
if n[0] == token.NAME:
type = n[1]
if len(nl) == 3:
if type == 'not':
type = 'not in'
else:
type = 'is not'
else:
type = _cmp_types[n[0]]
lineno = nl[1][2]
results.append((type, self.com_node(nodelist[i])))
# we need a special "compare" node so that we can distinguish
# 3 < x < 5 from (3 < x) < 5
# the two have very different semantics and results (note that the
# latter form is always true)
n = Compare(node, results)
n.lineno = lineno
return n
def expr(self, nodelist):
# xor_expr ('|' xor_expr)*
return self.com_binary(Bitor, nodelist)
def xor_expr(self, nodelist):
# xor_expr ('^' xor_expr)*
return self.com_binary(Bitxor, nodelist)
def and_expr(self, nodelist):
# xor_expr ('&' xor_expr)*
return self.com_binary(Bitand, nodelist)
def shift_expr(self, nodelist):
# shift_expr ('<<'|'>>' shift_expr)*
node = self.com_node(nodelist[0])
for i in range(2, len(nodelist), 2):
right = self.com_node(nodelist[i])
if nodelist[i-1][0] == token.LEFTSHIFT:
node = LeftShift([node, right])
node.lineno = nodelist[1][2]
else:
node = RightShift([node, right])
node.lineno = nodelist[1][2]
return node
def arith_expr(self, nodelist):
node = self.com_node(nodelist[0])
for i in range(2, len(nodelist), 2):
right = self.com_node(nodelist[i])
if nodelist[i-1][0] == token.PLUS:
node = Add([node, right])
node.lineno = nodelist[1][2]
else:
node = Sub([node, right])
node.lineno = nodelist[1][2]
return node
def term(self, nodelist):
node = self.com_node(nodelist[0])
for i in range(2, len(nodelist), 2):
right = self.com_node(nodelist[i])
t = nodelist[i-1][0]
if t == token.STAR:
node = Mul([node, right])
elif t == token.SLASH:
node = Div([node, right])
else:
node = Mod([node, right])
node.lineno = nodelist[1][2]
return node
def factor(self, nodelist):
elt = nodelist[0]
t = elt[0]
node = self.com_node(nodelist[-1])
if t == token.PLUS:
node = UnaryAdd(node)
node.lineno = elt[2]
elif t == token.MINUS:
node = UnarySub(node)
node.lineno = elt[2]
elif t == token.TILDE:
node = Invert(node)
node.lineno = elt[2]
return node
def power(self, nodelist):
# power: atom trailer* ('**' factor)*
node = self.com_node(nodelist[0])
for i in range(1, len(nodelist)):
elt = nodelist[i]
if elt[0] == token.DOUBLESTAR:
n = Power([node, self.com_node(nodelist[i+1])])
n.lineno = elt[2]
return n
node = self.com_apply_trailer(node, elt)
return node
def atom(self, nodelist):
return self._atom_dispatch[nodelist[0][0]](nodelist)
def atom_lpar(self, nodelist):
if nodelist[1][0] == token.RPAR:
n = Tuple(())
n.lineno = nodelist[0][2]
return n
return self.com_node(nodelist[1])
def atom_lsqb(self, nodelist):
if nodelist[1][0] == token.RSQB:
n = List(())
n.lineno = nodelist[0][2]
return n
return self.com_list_constructor(nodelist[1])
def atom_lbrace(self, nodelist):
if nodelist[1][0] == token.RBRACE:
return Dict(())
return self.com_dictmaker(nodelist[1])
def atom_backquote(self, nodelist):
n = Backquote(self.com_node(nodelist[1]))
n.lineno = nodelist[0][2]
return n
def atom_number(self, nodelist):
### need to verify this matches compile.c
k = eval(nodelist[0][1])
n = Const(k)
n.lineno = nodelist[0][2]
return n
def atom_string(self, nodelist):
### need to verify this matches compile.c
k = ''
for node in nodelist:
k = k + eval(node[1])
n = Const(k)
n.lineno = nodelist[0][2]
return n
def atom_name(self, nodelist):
### any processing to do?
n = Name(nodelist[0][1])
n.lineno = nodelist[0][2]
return n
# --------------------------------------------------------------
#
# INTERNAL PARSING UTILITIES
#
def com_node(self, node):
# Note: compile.c has handling in com_node for del_stmt, pass_stmt,
# break_stmt, stmt, small_stmt, flow_stmt, simple_stmt,
# and compound_stmt.
# We'll just dispatch them.
return self._dispatch[node[0]](node[1:])
def com_NEWLINE(self, *args):
# A ';' at the end of a line can make a NEWLINE token appear
# here, Render it harmless. (genc discards ('discard',
# ('const', xxxx)) Nodes)
return Discard(Const(None))
def com_arglist(self, nodelist):
# varargslist:
# (fpdef ['=' test] ',')* ('*' NAME [',' ('**'|'*' '*') NAME]
# | fpdef ['=' test] (',' fpdef ['=' test])* [',']
# | ('**'|'*' '*') NAME)
# fpdef: NAME | '(' fplist ')'
# fplist: fpdef (',' fpdef)* [',']
names = []
defaults = []
flags = 0
i = 0
while i < len(nodelist):
node = nodelist[i]
if node[0] == token.STAR or node[0] == token.DOUBLESTAR:
if node[0] == token.STAR:
node = nodelist[i+1]
if node[0] == token.NAME:
names.append(node[1])
flags = flags | CO_VARARGS
i = i + 3
if i < len(nodelist):
# should be DOUBLESTAR or STAR STAR
if nodelist[i][0] == token.DOUBLESTAR:
node = nodelist[i+1]
else:
node = nodelist[i+2]
names.append(node[1])
flags = flags | CO_VARKEYWORDS
break
# fpdef: NAME | '(' fplist ')'
names.append(self.com_fpdef(node))
i = i + 1
if i >= len(nodelist):
break
if nodelist[i][0] == token.EQUAL:
defaults.append(self.com_node(nodelist[i + 1]))
i = i + 2
elif len(defaults):
# Treat "(a=1, b)" as "(a=1, b=None)"
defaults.append(Const(None))
i = i + 1
return names, defaults, flags
def com_fpdef(self, node):
# fpdef: NAME | '(' fplist ')'
if node[1][0] == token.LPAR:
return self.com_fplist(node[2])
return node[1][1]
def com_fplist(self, node):
# fplist: fpdef (',' fpdef)* [',']
if len(node) == 2:
return self.com_fpdef(node[1])
list = []
for i in range(1, len(node), 2):
list.append(self.com_fpdef(node[i]))
return tuple(list)
def com_dotted_name(self, node):
# String together the dotted names and return the string
name = ""
for n in node:
if type(n) == type(()) and n[0] == 1:
name = name + n[1] + '.'
return name[:-1]
def com_dotted_as_name(self, node):
dot = self.com_dotted_name(node[1])
if len(node) <= 2:
return dot, None
if node[0] == symbol.dotted_name:
pass
else:
assert node[2][1] == 'as'
assert node[3][0] == token.NAME
return dot, node[3][1]
def com_import_as_name(self, node):
if node[0] == token.STAR:
return '*', None
assert node[0] == symbol.import_as_name
node = node[1:]
if len(node) == 1:
assert node[0][0] == token.NAME
return node[0][1], None
assert node[1][1] == 'as', node
assert node[2][0] == token.NAME
return node[0][1], node[2][1]
def com_bases(self, node):
bases = []
for i in range(1, len(node), 2):
bases.append(self.com_node(node[i]))
return bases
def com_try_finally(self, nodelist):
# try_fin_stmt: "try" ":" suite "finally" ":" suite
n = TryFinally(self.com_node(nodelist[2]),
self.com_node(nodelist[5]))
n.lineno = nodelist[0][2]
return n
def com_try_except(self, nodelist):
# try_except: 'try' ':' suite (except_clause ':' suite)* ['else' suite]
#tryexcept: [TryNode, [except_clauses], elseNode)]
stmt = self.com_node(nodelist[2])
clauses = []
elseNode = None
for i in range(3, len(nodelist), 3):
node = nodelist[i]
if node[0] == symbol.except_clause:
# except_clause: 'except' [expr [',' expr]] */
if len(node) > 2:
expr1 = self.com_node(node[2])
if len(node) > 4:
expr2 = self.com_assign(node[4], OP_ASSIGN)
else:
expr2 = None
else:
expr1 = expr2 = None
clauses.append((expr1, expr2, self.com_node(nodelist[i+2])))
if node[0] == token.NAME:
elseNode = self.com_node(nodelist[i+2])
n = TryExcept(self.com_node(nodelist[2]), clauses, elseNode)
n.lineno = nodelist[0][2]
return n
def com_augassign_op(self, node):
assert node[0] == symbol.augassign
return node[1]
def com_augassign(self, node):
"""Return node suitable for lvalue of augmented assignment
Names, slices, and attributes are the only allowable nodes.
"""
l = self.com_node(node)
if l.__class__ in (Name, Slice, Subscript, Getattr):
return l
raise SyntaxError, "can't assign to %s" % l.__class__.__name__
def com_assign(self, node, assigning):
# return a node suitable for use as an "lvalue"
# loop to avoid trivial recursion
while 1:
t = node[0]
if t == symbol.exprlist or t == symbol.testlist:
if len(node) > 2:
return self.com_assign_tuple(node, assigning)
node = node[1]
elif t in _assign_types:
if len(node) > 2:
raise SyntaxError, "can't assign to operator"
node = node[1]
elif t == symbol.power:
if node[1][0] != symbol.atom:
raise SyntaxError, "can't assign to operator"
if len(node) > 2:
primary = self.com_node(node[1])
for i in range(2, len(node)-1):
ch = node[i]
if ch[0] == token.DOUBLESTAR:
raise SyntaxError, "can't assign to operator"
primary = self.com_apply_trailer(primary, ch)
return self.com_assign_trailer(primary, node[-1],
assigning)
node = node[1]
elif t == symbol.atom:
t = node[1][0]
if t == token.LPAR:
node = node[2]
if node[0] == token.RPAR:
raise SyntaxError, "can't assign to ()"
elif t == token.LSQB:
node = node[2]
if node[0] == token.RSQB:
raise SyntaxError, "can't assign to []"
return self.com_assign_list(node, assigning)
elif t == token.NAME:
return self.com_assign_name(node[1], assigning)
else:
raise SyntaxError, "can't assign to literal"
else:
raise SyntaxError, "bad assignment"
def com_assign_tuple(self, node, assigning):
assigns = []
for i in range(1, len(node), 2):
assigns.append(self.com_assign(node[i], assigning))
return AssTuple(assigns)
def com_assign_list(self, node, assigning):
assigns = []
for i in range(1, len(node), 2):
assigns.append(self.com_assign(node[i], assigning))
return AssList(assigns)
def com_assign_name(self, node, assigning):
n = AssName(node[1], assigning)
n.lineno = node[2]
return n
def com_assign_trailer(self, primary, node, assigning):
t = node[1][0]
if t == token.DOT:
return self.com_assign_attr(primary, node[2], assigning)
if t == token.LSQB:
return self.com_subscriptlist(primary, node[2], assigning)
if t == token.LPAR:
raise SyntaxError, "can't assign to function call"
raise SyntaxError, "unknown trailer type: %s" % t
def com_assign_attr(self, primary, node, assigning):
return AssAttr(primary, node[1], assigning)
def com_binary(self, constructor, nodelist):
"Compile 'NODE (OP NODE)*' into (type, [ node1, ..., nodeN ])."
l = len(nodelist)
if l == 1:
return self.com_node(nodelist[0])
items = []
for i in range(0, l, 2):
items.append(self.com_node(nodelist[i]))
return constructor(items)
def com_stmt(self, node):
result = self.com_node(node)
assert result is not None
if isinstance(result, Stmt):
return result
return Stmt([result])
def com_append_stmt(self, stmts, node):
result = self.com_node(node)
assert result is not None
if isinstance(result, Stmt):
stmts.extend(result.nodes)
else:
stmts.append(result)
if hasattr(symbol, 'list_for'):
def com_list_constructor(self, nodelist):
# listmaker: test ( list_for | (',' test)* [','] )
values = []
for i in range(1, len(nodelist)):
if nodelist[i][0] == symbol.list_for:
assert len(nodelist[i:]) == 1
return self.com_list_comprehension(values[0],
nodelist[i])
elif nodelist[i][0] == token.COMMA:
continue
values.append(self.com_node(nodelist[i]))
return List(values)
def com_list_comprehension(self, expr, node):
# list_iter: list_for | list_if
# list_for: 'for' exprlist 'in' testlist [list_iter]
# list_if: 'if' test [list_iter]
# XXX should raise SyntaxError for assignment
lineno = node[1][2]
fors = []
while node:
t = node[1][1]
if t == 'for':
assignNode = self.com_assign(node[2], OP_ASSIGN)
listNode = self.com_node(node[4])
newfor = ListCompFor(assignNode, listNode, [])
newfor.lineno = node[1][2]
fors.append(newfor)
if len(node) == 5:
node = None
else:
node = self.com_list_iter(node[5])
elif t == 'if':
test = self.com_node(node[2])
newif = ListCompIf(test)
newif.lineno = node[1][2]
newfor.ifs.append(newif)
if len(node) == 3:
node = None
else:
node = self.com_list_iter(node[3])
else:
raise SyntaxError, \
("unexpected list comprehension element: %s %d"
% (node, lineno))
n = ListComp(expr, fors)
n.lineno = lineno
return n
def com_list_iter(self, node):
assert node[0] == symbol.list_iter
return node[1]
else:
def com_list_constructor(self, nodelist):
values = []
for i in range(1, len(nodelist), 2):
values.append(self.com_node(nodelist[i]))
return List(values)
def com_dictmaker(self, nodelist):
# dictmaker: test ':' test (',' test ':' value)* [',']
items = []
for i in range(1, len(nodelist), 4):
items.append((self.com_node(nodelist[i]),
self.com_node(nodelist[i+2])))
return Dict(items)
def com_apply_trailer(self, primaryNode, nodelist):
t = nodelist[1][0]
if t == token.LPAR:
return self.com_call_function(primaryNode, nodelist[2])
if t == token.DOT:
return self.com_select_member(primaryNode, nodelist[2])
if t == token.LSQB:
return self.com_subscriptlist(primaryNode, nodelist[2], OP_APPLY)
raise SyntaxError, 'unknown node type: %s' % t
def com_select_member(self, primaryNode, nodelist):
if nodelist[0] != token.NAME:
raise SyntaxError, "member must be a name"
n = Getattr(primaryNode, nodelist[1])
n.lineno = nodelist[2]
return n
def com_call_function(self, primaryNode, nodelist):
if nodelist[0] == token.RPAR:
return CallFunc(primaryNode, [])
args = []
kw = 0
len_nodelist = len(nodelist)
for i in range(1, len_nodelist, 2):
node = nodelist[i]
if node[0] == token.STAR or node[0] == token.DOUBLESTAR:
break
kw, result = self.com_argument(node, kw)
args.append(result)
else:
# No broken by star arg, so skip the last one we processed.
i = i + 1
if i < len_nodelist and nodelist[i][0] == token.COMMA:
# need to accept an application that looks like "f(a, b,)"
i = i + 1
star_node = dstar_node = None
while i < len_nodelist:
tok = nodelist[i]
ch = nodelist[i+1]
i = i + 3
if tok[0]==token.STAR:
if star_node is not None:
raise SyntaxError, 'already have the varargs indentifier'
star_node = self.com_node(ch)
elif tok[0]==token.DOUBLESTAR:
if dstar_node is not None:
raise SyntaxError, 'already have the kwargs indentifier'
dstar_node = self.com_node(ch)
else:
raise SyntaxError, 'unknown node type: %s' % tok
return CallFunc(primaryNode, args, star_node, dstar_node)
def com_argument(self, nodelist, kw):
if len(nodelist) == 2:
if kw:
raise SyntaxError, "non-keyword arg after keyword arg"
return 0, self.com_node(nodelist[1])
result = self.com_node(nodelist[3])
n = nodelist[1]
while len(n) == 2 and n[0] != token.NAME:
n = n[1]
if n[0] != token.NAME:
raise SyntaxError, "keyword can't be an expression (%s)"%n[0]
node = Keyword(n[1], result)
node.lineno = n[2]
return 1, node
def com_subscriptlist(self, primary, nodelist, assigning):
# slicing: simple_slicing | extended_slicing
# simple_slicing: primary "[" short_slice "]"
# extended_slicing: primary "[" slice_list "]"
# slice_list: slice_item ("," slice_item)* [","]
# backwards compat slice for '[i:j]'
if len(nodelist) == 2:
sub = nodelist[1]
if (sub[1][0] == token.COLON or \
(len(sub) > 2 and sub[2][0] == token.COLON)) and \
sub[-1][0] != symbol.sliceop:
return self.com_slice(primary, sub, assigning)
subscripts = []
for i in range(1, len(nodelist), 2):
subscripts.append(self.com_subscript(nodelist[i]))
return Subscript(primary, assigning, subscripts)
def com_subscript(self, node):
# slice_item: expression | proper_slice | ellipsis
ch = node[1]
t = ch[0]
if t == token.DOT and node[2][0] == token.DOT:
return Ellipsis()
if t == token.COLON or len(node) > 2:
return self.com_sliceobj(node)
return self.com_node(ch)
def com_sliceobj(self, node):
# proper_slice: short_slice | long_slice
# short_slice: [lower_bound] ":" [upper_bound]
# long_slice: short_slice ":" [stride]
# lower_bound: expression
# upper_bound: expression
# stride: expression
#
# Note: a stride may be further slicing...
items = []
if node[1][0] == token.COLON:
items.append(Const(None))
i = 2
else:
items.append(self.com_node(node[1]))
# i == 2 is a COLON
i = 3
if i < len(node) and node[i][0] == symbol.test:
items.append(self.com_node(node[i]))
i = i + 1
else:
items.append(Const(None))
# a short_slice has been built. look for long_slice now by looking
# for strides...
for j in range(i, len(node)):
ch = node[j]
if len(ch) == 2:
items.append(Const(None))
else:
items.append(self.com_node(ch[2]))
return Sliceobj(items)
def com_slice(self, primary, node, assigning):
# short_slice: [lower_bound] ":" [upper_bound]
lower = upper = None
if len(node) == 3:
if node[1][0] == token.COLON:
upper = self.com_node(node[2])
else:
lower = self.com_node(node[1])
elif len(node) == 4:
lower = self.com_node(node[1])
upper = self.com_node(node[3])
return Slice(primary, assigning, lower, upper)
def get_docstring(self, node, n=None):
if n is None:
n = node[0]
node = node[1:]
if n == symbol.suite:
if len(node) == 1:
return self.get_docstring(node[0])
for sub in node:
if sub[0] == symbol.stmt:
return self.get_docstring(sub)
return None
if n == symbol.file_input:
for sub in node:
if sub[0] == symbol.stmt:
return self.get_docstring(sub)
return None
if n == symbol.atom:
if node[0][0] == token.STRING:
s = ''
for t in node:
s = s + eval(t[1])
return s
return None
if n == symbol.stmt or n == symbol.simple_stmt \
or n == symbol.small_stmt:
return self.get_docstring(node[0])
if n in _doc_nodes and len(node) == 1:
return self.get_docstring(node[0])
return None
_doc_nodes = [
symbol.expr_stmt,
symbol.testlist,
symbol.test,
symbol.and_test,
symbol.not_test,
symbol.comparison,
symbol.expr,
symbol.xor_expr,
symbol.and_expr,
symbol.shift_expr,
symbol.arith_expr,
symbol.term,
symbol.factor,
symbol.power,
]
# comp_op: '<' | '>' | '=' | '>=' | '<=' | '<>' | '!=' | '=='
# | 'in' | 'not' 'in' | 'is' | 'is' 'not'
_cmp_types = {
token.LESS : '<',
token.GREATER : '>',
token.EQEQUAL : '==',
token.EQUAL : '==',
token.LESSEQUAL : '<=',
token.GREATEREQUAL : '>=',
token.NOTEQUAL : '!=',
}
_legal_node_types = [
symbol.funcdef,
symbol.classdef,
symbol.stmt,
symbol.small_stmt,
symbol.flow_stmt,
symbol.simple_stmt,
symbol.compound_stmt,
symbol.expr_stmt,
symbol.print_stmt,
symbol.del_stmt,
symbol.pass_stmt,
symbol.break_stmt,
symbol.continue_stmt,
symbol.return_stmt,
symbol.raise_stmt,
symbol.import_stmt,
symbol.global_stmt,
symbol.exec_stmt,
symbol.assert_stmt,
symbol.if_stmt,
symbol.while_stmt,
symbol.for_stmt,
symbol.try_stmt,
symbol.suite,
symbol.testlist,
symbol.test,
symbol.and_test,
symbol.not_test,
symbol.comparison,
symbol.exprlist,
symbol.expr,
symbol.xor_expr,
symbol.and_expr,
symbol.shift_expr,
symbol.arith_expr,
symbol.term,
symbol.factor,
symbol.power,
symbol.atom,
]
_assign_types = [
symbol.test,
symbol.and_test,
symbol.not_test,
symbol.comparison,
symbol.expr,
symbol.xor_expr,
symbol.and_expr,
symbol.shift_expr,
symbol.arith_expr,
symbol.term,
symbol.factor,
]
import types
_names = {}
for k, v in symbol.sym_name.items():
_names[k] = v
for k, v in token.tok_name.items():
_names[k] = v
def debug_tree(tree):
l = []
for elt in tree:
if type(elt) == types.IntType:
l.append(_names.get(elt, elt))
elif type(elt) == types.StringType:
l.append(elt)
else:
l.append(debug_tree(elt))
return l
import sys
import ast
class ASTVisitor:
"""Performs a depth-first walk of the AST
The ASTVisitor will walk the AST, performing either a preorder or
postorder traversal depending on which method is called.
methods:
preorder(tree, visitor)
postorder(tree, visitor)
tree: an instance of ast.Node
visitor: an instance with visitXXX methods
The ASTVisitor is responsible for walking over the tree in the
correct order. For each node, it checks the visitor argument for
a method named 'visitNodeType' where NodeType is the name of the
node's class, e.g. Class. If the method exists, it is called
with the node as its sole argument.
The visitor method for a particular node type can control how
child nodes are visited during a preorder walk. (It can't control
the order during a postorder walk, because it is called _after_
the walk has occurred.) The ASTVisitor modifies the visitor
argument by adding a visit method to the visitor; this method can
be used to visit a particular child node. If the visitor method
returns a true value, the ASTVisitor will not traverse the child
nodes.
XXX The interface for controlling the preorder walk needs to be
re-considered. The current interface is convenient for visitors
that mostly let the ASTVisitor do everything. For something like
a code generator, where you want to walk to occur in a specific
order, it's a pain to add "return 1" to the end of each method.
"""
VERBOSE = 0
def __init__(self):
self.node = None
self._cache = {}
def default(self, node, *args):
for child in node.getChildren():
if isinstance(child, ast.Node):
apply(self._preorder, (child,) + args)
def dispatch(self, node, *args):
self.node = node
klass = node.__class__
meth = self._cache.get(klass, None)
if meth is None:
className = klass.__name__
meth = getattr(self.visitor, 'visit' + className, self.default)
self._cache[klass] = meth
if self.VERBOSE > 0:
className = klass.__name__
if self.VERBOSE == 1:
if meth == 0:
print "dispatch", className
else:
print "dispatch", className, (meth and meth.__name__ or '')
return meth(node, *args)
def preorder(self, tree, visitor, *args):
"""Do preorder walk of tree using visitor"""
self.visitor = visitor
visitor.visit = self._preorder
self._preorder(tree, *args) # XXX *args make sense?
_preorder = dispatch
class ExampleASTVisitor(ASTVisitor):
"""Prints examples of the nodes that aren't visited
This visitor-driver is only useful for development, when it's
helpful to develop a visitor incremently, and get feedback on what
you still have to do.
"""
examples = {}
def dispatch(self, node, *args):
self.node = node
meth = self._cache.get(node.__class__, None)
className = node.__class__.__name__
if meth is None:
meth = getattr(self.visitor, 'visit' + className, 0)
self._cache[node.__class__] = meth
if self.VERBOSE > 1:
print "dispatch", className, (meth and meth.__name__ or '')
if meth:
return apply(meth, (node,) + args)
elif self.VERBOSE > 0:
klass = node.__class__
if not self.examples.has_key(klass):
self.examples[klass] = klass
print
print self.visitor
print klass
for attr in dir(node):
if attr[0] != '_':
print "\t", "%-12.12s" % attr, getattr(node, attr)
print
return apply(self.default, (node,) + args)
_walker = ASTVisitor
def walk(tree, visitor, verbose=None):
w = _walker()
if verbose is not None:
w.VERBOSE = verbose
w.preorder(tree, visitor)
return w.visitor
def dumpNode(node):
print node.__class__
for attr in dir(node):
if attr[0] != '_':
print "\t", "%-10.10s" % attr, getattr(node, attr)
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
# compile_restricted() but not when using compile(). # compile_restricted() but not when using compile().
# Each function in this module is compiled using compile_restricted(). # Each function in this module is compiled using compile_restricted().
from __future__ import generators
def overrideGuardWithFunction(): def overrideGuardWithFunction():
def _getattr(o): return o def _getattr(o): return o
...@@ -40,3 +42,12 @@ def check_getattr_in_lambda(arg=lambda _getattr=(lambda ob, name: name): ...@@ -40,3 +42,12 @@ def check_getattr_in_lambda(arg=lambda _getattr=(lambda ob, name: name):
def import_as_bad_name(): def import_as_bad_name():
import os as _leading_underscore import os as _leading_underscore
def except_using_bad_name():
try:
foo
except NameError, _leading_underscore:
# The name of choice (say, _write) is now assigned to an exception
# object. Hard to exploit, but conceivable.
pass
...@@ -127,6 +127,12 @@ def guarded_getitem(ob, index): ...@@ -127,6 +127,12 @@ def guarded_getitem(ob, index):
raise AccessDenied raise AccessDenied
return v return v
def minimal_import(name, _globals, _locals, names):
if name != "__future__":
raise ValueError, "Only future imports are allowed"
import __future__
return __future__
class TestGuard: class TestGuard:
'''A guard class''' '''A guard class'''
...@@ -152,7 +158,6 @@ class TestGuard: ...@@ -152,7 +158,6 @@ class TestGuard:
_ob = self.__dict__['_ob'] _ob = self.__dict__['_ob']
_ob[lo:hi] = value _ob[lo:hi] = value
## attribute_of_anything = 98.6
class RestrictionTests(unittest.TestCase): class RestrictionTests(unittest.TestCase):
def execFunc(self, name, *args, **kw): def execFunc(self, name, *args, **kw):
...@@ -222,11 +227,12 @@ class RestrictionTests(unittest.TestCase): ...@@ -222,11 +227,12 @@ class RestrictionTests(unittest.TestCase):
f.close() f.close()
# Unrestricted compile. # Unrestricted compile.
code = compile(source, fn, 'exec') code = compile(source, fn, 'exec')
m = {'__builtins__':None} m = {'__builtins__': {'__import__':minimal_import}}
exec code in m exec code in m
for k, v in m.items(): for k, v in m.items():
if hasattr(v, 'func_code'): if hasattr(v, 'func_code'):
filename, source = find_source(fn, v.func_code) filename, source = find_source(fn, v.func_code)
source = "from __future__ import generators\n\n" + source
# Now compile it with restrictions # Now compile it with restrictions
try: try:
code = compile_restricted(source, filename, 'exec') code = compile_restricted(source, filename, 'exec')
...@@ -236,10 +242,6 @@ class RestrictionTests(unittest.TestCase): ...@@ -236,10 +242,6 @@ class RestrictionTests(unittest.TestCase):
else: else:
raise AssertionError, '%s should not have compiled' % k raise AssertionError, '%s should not have compiled' % k
## def checkStrangeAttribute(self):
## res = self.execFunc('strange_attribute')
## assert res == 98.6, res
def checkOrderOfOperations(self): def checkOrderOfOperations(self):
res = self.execFunc('order_of_operations') res = self.execFunc('order_of_operations')
assert (res == 0), res assert (res == 0), res
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment