diff --git a/lib/mako/_ast_util.py b/lib/mako/_ast_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..728eef881971f2fe27f3a43f4470abd06c106200
--- /dev/null
+++ b/lib/mako/_ast_util.py
@@ -0,0 +1,833 @@
+# -*- coding: utf-8 -*-
+"""
+    ast
+    ~~~
+
+    The `ast` module helps Python applications to process trees of the Python
+    abstract syntax grammar.  The abstract syntax itself might change with
+    each Python release; this module helps to find out programmatically what
+    the current grammar looks like and allows modifications of it.
+
+    An abstract syntax tree can be generated by passing `ast.PyCF_ONLY_AST` as
+    a flag to the `compile()` builtin function or by using the `parse()`
+    function from this module.  The result will be a tree of objects whose
+    classes all inherit from `ast.AST`.
+
+    A modified abstract syntax tree can be compiled into a Python code object
+    using the built-in `compile()` function.
+
+    Additionally various helper functions are provided that make working with
+    the trees simpler.  The main intention of the helper functions and this
+    module in general is to provide an easy to use interface for libraries
+    that work tightly with the python syntax (template engines for example).
+
+
+    :copyright: Copyright 2008 by Armin Ronacher.
+    :license: Python License.
+"""
+from _ast import *
+
+
+BOOLOP_SYMBOLS = {
+    And:        'and',
+    Or:         'or'
+}
+
+BINOP_SYMBOLS = {
+    Add:        '+',
+    Sub:        '-',
+    Mult:       '*',
+    Div:        '/',
+    FloorDiv:   '//',
+    Mod:        '%',
+    LShift:     '<<',
+    RShift:     '>>',
+    BitOr:      '|',
+    BitAnd:     '&',
+    BitXor:     '^'
+}
+
+CMPOP_SYMBOLS = {
+    Eq:         '==',
+    Gt:         '>',
+    GtE:        '>=',
+    In:         'in',
+    Is:         'is',
+    IsNot:      'is not',
+    Lt:         '<',
+    LtE:        '<=',
+    NotEq:      '!=',
+    NotIn:      'not in'
+}
+
+UNARYOP_SYMBOLS = {
+    Invert:     '~',
+    Not:        'not',
+    UAdd:       '+',
+    USub:       '-'
+}
+
+ALL_SYMBOLS = {}
+ALL_SYMBOLS.update(BOOLOP_SYMBOLS)
+ALL_SYMBOLS.update(BINOP_SYMBOLS)
+ALL_SYMBOLS.update(CMPOP_SYMBOLS)
+ALL_SYMBOLS.update(UNARYOP_SYMBOLS)
+
+
+def parse(expr, filename='<unknown>', mode='exec'):
+    """Parse an expression into an AST node."""
+    return compile(expr, filename, mode, PyCF_ONLY_AST)
+
+
+def to_source(node, indent_with=' ' * 4):
+    """
+    This function can convert a node tree back into python sourcecode.  This
+    is useful for debugging purposes, especially if you're dealing with custom
+    asts not generated by python itself.
+
+    It could be that the sourcecode is evaluable when the AST itself is not
+    compilable / evaluable.  The reason for this is that the AST contains some
+    more data than regular sourcecode does, which is dropped during
+    conversion.
+
+    Each level of indentation is replaced with `indent_with`.  Per default this
+    parameter is equal to four spaces as suggested by PEP 8, but it might be
+    adjusted to match the application's styleguide.
+    """
+    generator = SourceGenerator(indent_with)
+    generator.visit(node)
+    return ''.join(generator.result)
+
+
+def dump(node):
+    """
+    A very verbose representation of the node passed.  This is useful for
+    debugging purposes.
+    """
+    def _format(node):
+        if isinstance(node, AST):
+            return '%s(%s)' % (node.__class__.__name__,
+                               ', '.join('%s=%s' % (a, _format(b))
+                                         for a, b in iter_fields(node)))
+        elif isinstance(node, list):
+            return '[%s]' % ', '.join(_format(x) for x in node)
+        return repr(node)
+    if not isinstance(node, AST):
+        raise TypeError('expected AST, got %r' % node.__class__.__name__)
+    return _format(node)
+
+
+def copy_location(new_node, old_node):
+    """
+    Copy the source location hint (`lineno` and `col_offset`) from the
+    old to the new node if possible and return the new one.
+    """
+    for attr in 'lineno', 'col_offset':
+        if attr in old_node._attributes and attr in new_node._attributes \
+           and hasattr(old_node, attr):
+            setattr(new_node, attr, getattr(old_node, attr))
+    return new_node
+
+
+def fix_missing_locations(node):
+    """
+    Some nodes require a line number and the column offset.  Without that
+    information the compiler will abort the compilation.  Because it can be
+    a dull task to add appropriate line numbers and column offsets when
+    adding new nodes this function can help.  It copies the line number and
+    column offset of the parent node to the child nodes without this
+    information.
+
+    Unlike `copy_location` this works recursive and won't touch nodes that
+    already have a location information.
+    """
+    def _fix(node, lineno, col_offset):
+        if 'lineno' in node._attributes:
+            if not hasattr(node, 'lineno'):
+                node.lineno = lineno
+            else:
+                lineno = node.lineno
+        if 'col_offset' in node._attributes:
+            if not hasattr(node, 'col_offset'):
+                node.col_offset = col_offset
+            else:
+                col_offset = node.col_offset
+        for child in iter_child_nodes(node):
+            _fix(child, lineno, col_offset)
+    _fix(node, 1, 0)
+    return node
+
+
+def increment_lineno(node, n=1):
+    """
+    Increment the line numbers of all nodes by `n` if they have line number
+    attributes.  This is useful to "move code" to a different location in a
+    file.
+    """
+    for node in zip((node,), walk(node)):
+        if 'lineno' in node._attributes:
+            node.lineno = getattr(node, 'lineno', 0) + n
+
+
+def iter_fields(node):
+    """Iterate over all fields of a node, only yielding existing fields."""
+    if not hasattr(node, '_fields') or not node._fields:
+        return
+    for field in node._fields:
+        try:
+            yield field, getattr(node, field)
+        except AttributeError:
+            pass
+
+
+def get_fields(node):
+    """Like `iter_fiels` but returns a dict."""
+    return dict(iter_fields(node))
+
+
+def iter_child_nodes(node):
+    """Iterate over all child nodes or a node."""
+    for name, field in iter_fields(node):
+        if isinstance(field, AST):
+            yield field
+        elif isinstance(field, list):
+            for item in field:
+                if isinstance(item, AST):
+                    yield item
+
+
+def get_child_nodes(node):
+    """Like `iter_child_nodes` but returns a list."""
+    return list(iter_child_nodes(node))
+
+
+def get_compile_mode(node):
+    """
+    Get the mode for `compile` of a given node.  If the node is not a `mod`
+    node (`Expression`, `Module` etc.) a `TypeError` is thrown.
+    """
+    if not isinstance(node, mod):
+        raise TypeError('expected mod node, got %r' % node.__class__.__name__)
+    return {
+        Expression:     'eval',
+        Interactive:    'single'
+    }.get(node.__class__, 'expr')
+
+
+def get_docstring(node):
+    """
+    Return the docstring for the given node or `None` if no docstring can be
+    found.  If the node provided does not accept docstrings a `TypeError`
+    will be raised.
+    """
+    if not isinstance(node, (FunctionDef, ClassDef, Module)):
+        raise TypeError("%r can't have docstrings" % node.__class__.__name__)
+    if node.body and isinstance(node.body[0], Str):
+        return node.body[0].s
+
+
+def walk(node):
+    """
+    Iterate over all nodes.  This is useful if you only want to modify nodes in
+    place and don't care about the context or the order the nodes are returned.
+    """
+    from collections import deque
+    todo = deque([node])
+    while todo:
+        node = todo.popleft()
+        todo.extend(iter_child_nodes(node))
+        yield node
+
+
+class NodeVisitor(object):
+    """
+    Walks the abstract syntax tree and call visitor functions for every node
+    found.  The visitor functions may return values which will be forwarded
+    by the `visit` method.
+
+    Per default the visitor functions for the nodes are ``'visit_'`` +
+    class name of the node.  So a `TryFinally` node visit function would
+    be `visit_TryFinally`.  This behavior can be changed by overriding
+    the `get_visitor` function.  If no visitor function exists for a node
+    (return value `None`) the `generic_visit` visitor is used instead.
+
+    Don't use the `NodeVisitor` if you want to apply changes to nodes during
+    traversing.  For this a special visitor exists (`NodeTransformer`) that
+    allows modifications.
+    """
+
+    def get_visitor(self, node):
+        """
+        Return the visitor function for this node or `None` if no visitor
+        exists for this node.  In that case the generic visit function is
+        used instead.
+        """
+        method = 'visit_' + node.__class__.__name__
+        return getattr(self, method, None)
+
+    def visit(self, node):
+        """Visit a node."""
+        f = self.get_visitor(node)
+        if f is not None:
+            return f(node)
+        return self.generic_visit(node)
+
+    def generic_visit(self, node):
+        """Called if no explicit visitor function exists for a node."""
+        for field, value in iter_fields(node):
+            if isinstance(value, list):
+                for item in value:
+                    if isinstance(item, AST):
+                        self.visit(item)
+            elif isinstance(value, AST):
+                self.visit(value)
+
+
+class NodeTransformer(NodeVisitor):
+    """
+    Walks the abstract syntax tree and allows modifications of nodes.
+
+    The `NodeTransformer` will walk the AST and use the return value of the
+    visitor functions to replace or remove the old node.  If the return
+    value of the visitor function is `None` the node will be removed
+    from the previous location otherwise it's replaced with the return
+    value.  The return value may be the original node in which case no
+    replacement takes place.
+
+    Here an example transformer that rewrites all `foo` to `data['foo']`::
+
+        class RewriteName(NodeTransformer):
+
+            def visit_Name(self, node):
+                return copy_location(Subscript(
+                    value=Name(id='data', ctx=Load()),
+                    slice=Index(value=Str(s=node.id)),
+                    ctx=node.ctx
+                ), node)
+
+    Keep in mind that if the node you're operating on has child nodes
+    you must either transform the child nodes yourself or call the generic
+    visit function for the node first.
+
+    Nodes that were part of a collection of statements (that applies to
+    all statement nodes) may also return a list of nodes rather than just
+    a single node.
+
+    Usually you use the transformer like this::
+
+        node = YourTransformer().visit(node)
+    """
+
+    def generic_visit(self, node):
+        for field, old_value in iter_fields(node):
+            old_value = getattr(node, field, None)
+            if isinstance(old_value, list):
+                new_values = []
+                for value in old_value:
+                    if isinstance(value, AST):
+                        value = self.visit(value)
+                        if value is None:
+                            continue
+                        elif not isinstance(value, AST):
+                            new_values.extend(value)
+                            continue
+                    new_values.append(value)
+                old_value[:] = new_values
+            elif isinstance(old_value, AST):
+                new_node = self.visit(old_value)
+                if new_node is None:
+                    delattr(node, field)
+                else:
+                    setattr(node, field, new_node)
+        return node
+
+
+class SourceGenerator(NodeVisitor):
+    """
+    This visitor is able to transform a well formed syntax tree into python
+    sourcecode.  For more details have a look at the docstring of the
+    `node_to_source` function.
+    """
+
+    def __init__(self, indent_with):
+        self.result = []
+        self.indent_with = indent_with
+        self.indentation = 0
+        self.new_lines = 0
+
+    def write(self, x):
+        if self.new_lines:
+            if self.result:
+                self.result.append('\n' * self.new_lines)
+            self.result.append(self.indent_with * self.indentation)
+            self.new_lines = 0
+        self.result.append(x)
+
+    def newline(self, n=1):
+        self.new_lines = max(self.new_lines, n)
+
+    def body(self, statements):
+        self.new_line = True
+        self.indentation += 1
+        for stmt in statements:
+            self.visit(stmt)
+        self.indentation -= 1
+
+    def body_or_else(self, node):
+        self.body(node.body)
+        if node.orelse:
+            self.newline()
+            self.write('else:')
+            self.body(node.orelse)
+
+    def signature(self, node):
+        want_comma = []
+        def write_comma():
+            if want_comma:
+                self.write(', ')
+            else:
+                want_comma.append(True)
+
+        padding = [None] * (len(node.args) - len(node.defaults))
+        for arg, default in zip(node.args, padding + node.defaults):
+            write_comma()
+            self.visit(arg)
+            if default is not None:
+                self.write('=')
+                self.visit(default)
+        if node.vararg is not None:
+            write_comma()
+            self.write('*' + node.vararg)
+        if node.kwarg is not None:
+            write_comma()
+            self.write('**' + node.kwarg)
+
+    def decorators(self, node):
+        for decorator in node.decorator_list:
+            self.newline()
+            self.write('@')
+            self.visit(decorator)
+
+    # Statements
+
+    def visit_Assign(self, node):
+        self.newline()
+        for idx, target in enumerate(node.targets):
+            if idx:
+                self.write(', ')
+            self.visit(target)
+        self.write(' = ')
+        self.visit(node.value)
+
+    def visit_AugAssign(self, node):
+        self.newline()
+        self.visit(node.target)
+        self.write(BINOP_SYMBOLS[type(node.op)] + '=')
+        self.visit(node.value)
+
+    def visit_ImportFrom(self, node):
+        self.newline()
+        self.write('from %s%s import ' % ('.' * node.level, node.module))
+        for idx, item in enumerate(node.names):
+            if idx:
+                self.write(', ')
+            self.write(item)
+
+    def visit_Import(self, node):
+        self.newline()
+        for item in node.names:
+            self.write('import ')
+            self.visit(item)
+
+    def visit_Expr(self, node):
+        self.newline()
+        self.generic_visit(node)
+
+    def visit_FunctionDef(self, node):
+        self.newline(n=2)
+        self.decorators(node)
+        self.newline()
+        self.write('def %s(' % node.name)
+        self.signature(node.args)
+        self.write('):')
+        self.body(node.body)
+
+    def visit_ClassDef(self, node):
+        have_args = []
+        def paren_or_comma():
+            if have_args:
+                self.write(', ')
+            else:
+                have_args.append(True)
+                self.write('(')
+
+        self.newline(n=3)
+        self.decorators(node)
+        self.newline()
+        self.write('class %s' % node.name)
+        for base in node.bases:
+            paren_or_comma()
+            self.visit(base)
+        # XXX: the if here is used to keep this module compatible
+        #      with python 2.6.
+        if hasattr(node, 'keywords'):
+            for keyword in node.keywords:
+                paren_or_comma()
+                self.write(keyword.arg + '=')
+                self.visit(keyword.value)
+            if node.starargs is not None:
+                paren_or_comma()
+                self.write('*')
+                self.visit(node.starargs)
+            if node.kwargs is not None:
+                paren_or_comma()
+                self.write('**')
+                self.visit(node.kwargs)
+        self.write(have_args and '):' or ':')
+        self.body(node.body)
+
+    def visit_If(self, node):
+        self.newline()
+        self.write('if ')
+        self.visit(node.test)
+        self.write(':')
+        self.body(node.body)
+        while True:
+            else_ = node.orelse
+            if len(else_) == 1 and isinstance(else_[0], If):
+                node = else_[0]
+                self.newline()
+                self.write('elif ')
+                self.visit(node.test)
+                self.write(':')
+                self.body(node.body)
+            else:
+                self.newline()
+                self.write('else:')
+                self.body(else_)
+                break
+
+    def visit_For(self, node):
+        self.newline()
+        self.write('for ')
+        self.visit(node.target)
+        self.write(' in ')
+        self.visit(node.iter)
+        self.write(':')
+        self.body_or_else(node)
+
+    def visit_While(self, node):
+        self.newline()
+        self.write('while ')
+        self.visit(node.test)
+        self.write(':')
+        self.body_or_else(node)
+
+    def visit_With(self, node):
+        self.newline()
+        self.write('with ')
+        self.visit(node.context_expr)
+        if node.optional_vars is not None:
+            self.write(' as ')
+            self.visit(node.optional_vars)
+        self.write(':')
+        self.body(node.body)
+
+    def visit_Pass(self, node):
+        self.newline()
+        self.write('pass')
+
+    def visit_Print(self, node):
+        # XXX: python 2.6 only
+        self.newline()
+        self.write('print ')
+        want_comma = False
+        if node.dest is not None:
+            self.write(' >> ')
+            self.visit(node.dest)
+            want_comma = True
+        for value in node.values:
+            if want_comma:
+                self.write(', ')
+            self.visit(value)
+            want_comma = True
+        if not node.nl:
+            self.write(',')
+
+    def visit_Delete(self, node):
+        self.newline()
+        self.write('del ')
+        for idx, target in enumerate(node):
+            if idx:
+                self.write(', ')
+            self.visit(target)
+
+    def visit_TryExcept(self, node):
+        self.newline()
+        self.write('try:')
+        self.body(node.body)
+        for handler in node.handlers:
+            self.visit(handler)
+
+    def visit_TryFinally(self, node):
+        self.newline()
+        self.write('try:')
+        self.body(node.body)
+        self.newline()
+        self.write('finally:')
+        self.body(node.finalbody)
+
+    def visit_Global(self, node):
+        self.newline()
+        self.write('global ' + ', '.join(node.names))
+
+    def visit_Nonlocal(self, node):
+        self.newline()
+        self.write('nonlocal ' + ', '.join(node.names))
+
+    def visit_Return(self, node):
+        self.newline()
+        self.write('return ')
+        self.visit(node.value)
+
+    def visit_Break(self, node):
+        self.newline()
+        self.write('break')
+
+    def visit_Continue(self, node):
+        self.newline()
+        self.write('continue')
+
+    def visit_Raise(self, node):
+        # XXX: Python 2.6 / 3.0 compatibility
+        self.newline()
+        self.write('raise')
+        if hasattr(node, 'exc') and node.exc is not None:
+            self.write(' ')
+            self.visit(node.exc)
+            if node.cause is not None:
+                self.write(' from ')
+                self.visit(node.cause)
+        elif hasattr(node, 'type') and node.type is not None:
+            self.visit(node.type)
+            if node.inst is not None:
+                self.write(', ')
+                self.visit(node.inst)
+            if node.tback is not None:
+                self.write(', ')
+                self.visit(node.tback)
+
+    # Expressions
+
+    def visit_Attribute(self, node):
+        self.visit(node.value)
+        self.write('.' + node.attr)
+
+    def visit_Call(self, node):
+        want_comma = []
+        def write_comma():
+            if want_comma:
+                self.write(', ')
+            else:
+                want_comma.append(True)
+
+        self.visit(node.func)
+        self.write('(')
+        for arg in node.args:
+            write_comma()
+            self.visit(arg)
+        for keyword in node.keywords:
+            write_comma()
+            self.write(keyword.arg + '=')
+            self.visit(keyword.value)
+        if node.starargs is not None:
+            write_comma()
+            self.write('*')
+            self.visit(node.starargs)
+        if node.kwargs is not None:
+            write_comma()
+            self.write('**')
+            self.visit(node.kwargs)
+        self.write(')')
+
+    def visit_Name(self, node):
+        self.write(node.id)
+
+    def visit_Str(self, node):
+        self.write(repr(node.s))
+
+    def visit_Bytes(self, node):
+        self.write(repr(node.s))
+
+    def visit_Num(self, node):
+        self.write(repr(node.n))
+
+    def visit_Tuple(self, node):
+        self.write('(')
+        idx = -1
+        for idx, item in enumerate(node.elts):
+            if idx:
+                self.write(', ')
+            self.visit(item)
+        self.write(idx and ')' or ',)')
+
+    def sequence_visit(left, right):
+        def visit(self, node):
+            self.write(left)
+            for idx, item in enumerate(node.elts):
+                if idx:
+                    self.write(', ')
+                self.visit(item)
+            self.write(right)
+        return visit
+
+    visit_List = sequence_visit('[', ']')
+    visit_Set = sequence_visit('{', '}')
+    del sequence_visit
+
+    def visit_Dict(self, node):
+        self.write('{')
+        for idx, (key, value) in enumerate(zip(node.keys, node.values)):
+            if idx:
+                self.write(', ')
+            self.visit(key)
+            self.write(': ')
+            self.visit(value)
+        self.write('}')
+
+    def visit_BinOp(self, node):
+        self.write('(')
+        self.visit(node.left)
+        self.write(' %s ' % BINOP_SYMBOLS[type(node.op)])
+        self.visit(node.right)
+        self.write(')')
+
+    def visit_BoolOp(self, node):
+        self.write('(')
+        for idx, value in enumerate(node.values):
+            if idx:
+                self.write(' %s ' % BOOLOP_SYMBOLS[type(node.op)])
+            self.visit(value)
+        self.write(')')
+
+    def visit_Compare(self, node):
+        self.write('(')
+        self.visit(node.left)
+        for op, right in zip(node.ops, node.comparators):
+            self.write(' %s ' % CMPOP_SYMBOLS[type(op)])
+            self.visit(right)
+        self.write(')')
+
+    def visit_UnaryOp(self, node):
+        self.write('(')
+        op = UNARYOP_SYMBOLS[type(node.op)]
+        self.write(op)
+        if op == 'not':
+            self.write(' ')
+        self.visit(node.operand)
+        self.write(')')
+
+    def visit_Subscript(self, node):
+        self.visit(node.value)
+        self.write('[')
+        self.visit(node.slice)
+        self.write(']')
+
+    def visit_Slice(self, node):
+        if node.lower is not None:
+            self.visit(node.lower)
+        self.write(':')
+        if node.upper is not None:
+            self.visit(node.upper)
+        if node.step is not None:
+            self.write(':')
+            if not (isinstance(node.step, Name) and node.step.id == 'None'):
+                self.visit(node.step)
+
+    def visit_ExtSlice(self, node):
+        for idx, item in node.dims:
+            if idx:
+                self.write(', ')
+            self.visit(item)
+
+    def visit_Yield(self, node):
+        self.write('yield ')
+        self.visit(node.value)
+
+    def visit_Lambda(self, node):
+        self.write('lambda ')
+        self.signature(node.args)
+        self.write(': ')
+        self.visit(node.body)
+
+    def visit_Ellipsis(self, node):
+        self.write('Ellipsis')
+
+    def generator_visit(left, right):
+        def visit(self, node):
+            self.write(left)
+            self.visit(node.elt)
+            for comprehension in node.generators:
+                self.visit(comprehension)
+            self.write(right)
+        return visit
+
+    visit_ListComp = generator_visit('[', ']')
+    visit_GeneratorExp = generator_visit('(', ')')
+    visit_SetComp = generator_visit('{', '}')
+    del generator_visit
+
+    def visit_DictComp(self, node):
+        self.write('{')
+        self.visit(node.key)
+        self.write(': ')
+        self.visit(node.value)
+        for comprehension in node.generators:
+            self.visit(comprehension)
+        self.write('}')
+
+    def visit_IfExp(self, node):
+        self.visit(node.body)
+        self.write(' if ')
+        self.visit(node.test)
+        self.write(' else ')
+        self.visit(node.orelse)
+
+    def visit_Starred(self, node):
+        self.write('*')
+        self.visit(node.value)
+
+    def visit_Repr(self, node):
+        # XXX: python 2.6 only
+        self.write('`')
+        self.visit(node.value)
+        self.write('`')
+
+    # Helper Nodes
+
+    def visit_alias(self, node):
+        self.write(node.name)
+        if node.asname is not None:
+            self.write(' as ' + node.asname)
+
+    def visit_comprehension(self, node):
+        self.write(' for ')
+        self.visit(node.target)
+        self.write(' in ')
+        self.visit(node.iter)
+        if node.ifs:
+            for if_ in node.ifs:
+                self.write(' if ')
+                self.visit(if_)
+
+    def visit_excepthandler(self, node):
+        self.newline()
+        self.write('except')
+        if node.type is not None:
+            self.write(' ')
+            self.visit(node.type)
+            if node.name is not None:
+                self.write(' as ')
+                self.visit(node.name)
+        self.write(':')
+        self.body(node.body)
diff --git a/lib/mako/ast.py b/lib/mako/ast.py
index 3ed500f95ae0823a94ef42c6848233366a65f155..6d4ef029871f2f9a7bbfc06f0b43caa34617c639 100644
--- a/lib/mako/ast.py
+++ b/lib/mako/ast.py
@@ -6,18 +6,9 @@
 
 """utilities for analyzing expressions and blocks of Python code, as well as generating Python from AST nodes"""
 
-from compiler import ast, visitor
-from compiler import parse as compiler_parse
-from mako import util, exceptions
-from StringIO import StringIO
+from mako import exceptions, pyparser, util
 import re
 
-def parse(code, mode, **exception_kwargs):
-    try:
-        return compiler_parse(code, mode)
-    except SyntaxError, e:
-        raise exceptions.SyntaxException("(%s) %s (%s)" % (e.__class__.__name__, str(e), repr(code[0:50])), **exception_kwargs)
-    
 class PythonCode(object):
     """represents information about a string containing Python code"""
     def __init__(self, code, **exception_kwargs):
@@ -36,71 +27,12 @@ class PythonCode(object):
         # - AST is less likely to break with version changes (for example, the behavior of co_names changed a little bit
         # in python version 2.5)
         if isinstance(code, basestring):
-            expr = parse(code.lstrip(), "exec", **exception_kwargs)
+            expr = pyparser.parse(code.lstrip(), "exec", **exception_kwargs)
         else:
             expr = code
-        
-        class FindIdentifiers(object):
-            def __init__(self):
-                self.in_function = False
-                self.local_ident_stack = {}
-            def _add_declared(s, name):
-                if not s.in_function:
-                    self.declared_identifiers.add(name)
-            def visitClass(self, node, *args):
-                self._add_declared(node.name)
-            def visitAssName(self, node, *args):
-                self._add_declared(node.name)
-            def visitAssign(self, node, *args):
-                # flip around the visiting of Assign so the expression gets evaluated first, 
-                # in the case of a clause like "x=x+5" (x is undeclared)
-                self.visit(node.expr, *args)
-                for n in node.nodes:
-                    self.visit(n, *args)
-            def visitFunction(self,node, *args):
-                self._add_declared(node.name)
-                # push function state onto stack.  dont log any
-                # more identifiers as "declared" until outside of the function,
-                # but keep logging identifiers as "undeclared".
-                # track argument names in each function header so they arent counted as "undeclared"
-                saved = {}
-                inf = self.in_function
-                self.in_function = True
-                for arg in node.argnames:
-                    if arg in self.local_ident_stack:
-                        saved[arg] = True
-                    else:
-                        self.local_ident_stack[arg] = True
-                for n in node.getChildNodes():
-                    self.visit(n, *args)
-                self.in_function = inf
-                for arg in node.argnames:
-                    if arg not in saved:
-                        del self.local_ident_stack[arg]
-            def visitFor(self, node, *args):
-                # flip around visit
-                self.visit(node.list, *args)
-                self.visit(node.assign, *args)
-                self.visit(node.body, *args)
-            def visitName(s, node, *args):
-                if node.name not in __builtins__ and node.name not in self.declared_identifiers and node.name not in s.local_ident_stack:
-                    self.undeclared_identifiers.add(node.name)
-            def visitImport(self, node, *args):
-                for (mod, alias) in node.names:
-                    if alias is not None:
-                        self._add_declared(alias)
-                    else:
-                        self._add_declared(mod.split('.')[0])
-            def visitFrom(self, node, *args):
-                for (mod, alias) in node.names:
-                    if alias is not None:
-                        self._add_declared(alias)
-                    else:
-                        if mod == '*':
-                            raise exceptions.CompileException("'import *' is not supported, since all identifier names must be explicitly declared.  Please use the form 'from <modulename> import <name1>, <name2>, ...' instead.", **exception_kwargs)
-                        self._add_declared(mod)
-        f = FindIdentifiers()
-        visitor.walk(expr, f) #, walker=walker())
+
+        f = pyparser.FindIdentifiers(self, **exception_kwargs)
+        f.visit(expr)
 
 class ArgumentList(object):
     """parses a fragment of code as a comma-separated list of expressions"""
@@ -109,25 +41,17 @@ class ArgumentList(object):
         self.args = []
         self.declared_identifiers = util.Set()
         self.undeclared_identifiers = util.Set()
-        class FindTuple(object):
-            def visitTuple(s, node, *args):
-                for n in node.nodes:
-                    p = PythonCode(n, **exception_kwargs)
-                    self.codeargs.append(p)
-                    self.args.append(ExpressionGenerator(n).value())
-                    self.declared_identifiers = self.declared_identifiers.union(p.declared_identifiers)
-                    self.undeclared_identifiers = self.undeclared_identifiers.union(p.undeclared_identifiers)
         if isinstance(code, basestring):
             if re.match(r"\S", code) and not re.match(r",\s*$", code):
                 # if theres text and no trailing comma, insure its parsed
                 # as a tuple by adding a trailing comma
                 code  += ","
-            expr = parse(code, "exec", **exception_kwargs)
+            expr = pyparser.parse(code, "exec", **exception_kwargs)
         else:
             expr = code
 
-        f = FindTuple()
-        visitor.walk(expr, f)
+        f = pyparser.FindTuple(self, PythonCode, **exception_kwargs)
+        f.visit(expr)
         
 class PythonFragment(PythonCode):
     """extends PythonCode to provide identifier lookups in partial control statements
@@ -157,27 +81,15 @@ class PythonFragment(PythonCode):
             raise exceptions.CompileException("Unsupported control keyword: '%s'" % keyword, **exception_kwargs)
         super(PythonFragment, self).__init__(code, **exception_kwargs)
         
-class walker(visitor.ASTVisitor):
-    def dispatch(self, node, *args):
-        print "Node:", str(node)
-        #print "dir:", dir(node)
-        return visitor.ASTVisitor.dispatch(self, node, *args)
         
 class FunctionDecl(object):
     """function declaration"""
     def __init__(self, code, allow_kwargs=True, **exception_kwargs):
         self.code = code
-        expr = parse(code, "exec", **exception_kwargs)
-        class ParseFunc(object):
-            def visitFunction(s, node, *args):
-                self.funcname = node.name
-                self.argnames = node.argnames
-                self.defaults = node.defaults
-                self.varargs = node.varargs
-                self.kwargs = node.kwargs
+        expr = pyparser.parse(code, "exec", **exception_kwargs)
                 
-        f = ParseFunc()
-        visitor.walk(expr, f)
+        f = pyparser.ParseFunc(self, **exception_kwargs)
+        f.visit(expr)
         if not hasattr(self, 'funcname'):
             raise exceptions.CompileException("Code '%s' is not a function declaration" % code, **exception_kwargs)
         if not allow_kwargs and self.kwargs:
@@ -202,7 +114,7 @@ class FunctionDecl(object):
             else:
                 default = len(defaults) and defaults.pop() or None
             if include_defaults and default:
-                namedecls.insert(0, "%s=%s" % (arg, ExpressionGenerator(default).value()))
+                namedecls.insert(0, "%s=%s" % (arg, pyparser.ExpressionGenerator(default).value()))
             else:
                 namedecls.insert(0, arg)
         return namedecls
@@ -211,135 +123,3 @@ class FunctionArgs(FunctionDecl):
     """the argument portion of a function declaration"""
     def __init__(self, code, **kwargs):
         super(FunctionArgs, self).__init__("def ANON(%s):pass" % code, **kwargs)
-        
-            
-class ExpressionGenerator(object):
-    """given an AST node, generates an equivalent literal Python expression."""
-    def __init__(self, astnode):
-        self.buf = StringIO()
-        visitor.walk(astnode, self) #, walker=walker())
-    def value(self):
-        return self.buf.getvalue()        
-    def operator(self, op, node, *args):
-        self.buf.write("(")
-        self.visit(node.left, *args)
-        self.buf.write(" %s " % op)
-        self.visit(node.right, *args)
-        self.buf.write(")")
-    def booleanop(self, op, node, *args):
-        self.visit(node.nodes[0])
-        for n in node.nodes[1:]:
-            self.buf.write(" " + op + " ")
-            self.visit(n, *args)
-    def visitConst(self, node, *args):
-        self.buf.write(repr(node.value))
-    def visitAssName(self, node, *args):
-        # TODO: figure out OP_ASSIGN, other OP_s
-        self.buf.write(node.name)
-    def visitName(self, node, *args):
-        self.buf.write(node.name)
-    def visitMul(self, node, *args):
-        self.operator("*", node, *args)
-    def visitAnd(self, node, *args):
-        self.booleanop("and", node, *args)
-    def visitOr(self, node, *args):
-        self.booleanop("or", node, *args)
-    def visitBitand(self, node, *args):
-        self.booleanop("&", node, *args)
-    def visitBitor(self, node, *args):
-        self.booleanop("|", node, *args)
-    def visitBitxor(self, node, *args):
-        self.booleanop("^", node, *args)
-    def visitAdd(self, node, *args):
-        self.operator("+", node, *args)
-    def visitGetattr(self, node, *args):
-        self.visit(node.expr, *args)
-        self.buf.write(".%s" % node.attrname)
-    def visitSub(self, node, *args):
-        self.operator("-", node, *args)
-    def visitNot(self, node, *args):
-        self.buf.write("not ")
-        self.visit(node.expr)
-    def visitDiv(self, node, *args):
-        self.operator("/", node, *args)
-    def visitFloorDiv(self, node, *args):
-        self.operator("//", node, *args)
-    def visitSubscript(self, node, *args):
-        self.visit(node.expr)
-        self.buf.write("[")
-        [self.visit(x) for x in node.subs]
-        self.buf.write("]")
-    def visitUnarySub(self, node, *args):
-        self.buf.write("-")
-        self.visit(node.expr)
-    def visitUnaryAdd(self, node, *args):
-        self.buf.write("-")
-        self.visit(node.expr)
-    def visitSlice(self, node, *args):
-        self.visit(node.expr)
-        self.buf.write("[")
-        if node.lower is not None:
-            self.visit(node.lower)
-        self.buf.write(":")
-        if node.upper is not None:
-            self.visit(node.upper)
-        self.buf.write("]")
-    def visitDict(self, node):
-        self.buf.write("{")
-        c = node.getChildren()
-        for i in range(0, len(c), 2):
-            self.visit(c[i])
-            self.buf.write(": ")
-            self.visit(c[i+1])
-            if i<len(c) -2:
-                self.buf.write(", ")
-        self.buf.write("}")
-    def visitTuple(self, node):
-        self.buf.write("(")
-        c = node.getChildren()
-        for i in range(0, len(c)):
-            self.visit(c[i])
-            if i<len(c) - 1:
-                self.buf.write(", ")
-        self.buf.write(")")
-    def visitList(self, node):
-        self.buf.write("[")
-        c = node.getChildren()
-        for i in range(0, len(c)):
-            self.visit(c[i])
-            if i<len(c) - 1:
-                self.buf.write(", ")
-        self.buf.write("]")
-    def visitListComp(self, node):
-        self.buf.write("[")
-        self.visit(node.expr)
-        self.buf.write(" ")
-        for n in node.quals:
-            self.visit(n)
-        self.buf.write("]")
-    def visitListCompFor(self, node):
-        self.buf.write(" for ")
-        self.visit(node.assign)
-        self.buf.write(" in ")
-        self.visit(node.list)
-        for n in node.ifs:
-            self.visit(n)
-    def visitListCompIf(self, node):
-        self.buf.write(" if ")
-        self.visit(node.test)
-    def visitCompare(self, node):
-        self.visit(node.expr)
-        for tup in node.ops:
-            self.buf.write(tup[0])
-            self.visit(tup[1])
-    def visitCallFunc(self, node, *args):
-        self.visit(node.node)
-        self.buf.write("(")
-        if len(node.args):
-            self.visit(node.args[0])
-            for a in node.args[1:]:
-                self.buf.write(", ")
-                self.visit(a)
-        self.buf.write(")")
-        
-        
\ No newline at end of file
diff --git a/lib/mako/pyparser.py b/lib/mako/pyparser.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd20c6786414affa6bb47c3c8808cee7a4a899de
--- /dev/null
+++ b/lib/mako/pyparser.py
@@ -0,0 +1,369 @@
+# ast.py
+# Copyright (C) Mako developers
+#
+# This module is part of Mako and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+"""Handles parsing of Python code.
+
+Parsing to AST is done via _ast on Python > 2.5, otherwise the compiler
+module is used.
+"""
+import sys
+from StringIO import StringIO
+from mako import exceptions, util
+
+new_ast = sys.version_info > (2, 5)
+
+if new_ast:
+    import _ast
+    util.restore__ast(_ast)
+    import _ast_util
+else:
+    from compiler import parse as compiler_parse
+    from compiler import visitor
+
+
+def parse(code, mode='exec', **exception_kwargs):
+    """Parse an expression into AST"""
+    try:
+        if new_ast:
+            return _ast_util.parse(code, '<unknown>', mode)
+        else:
+            return compiler_parse(code, mode)
+    except Exception, e:
+        raise exceptions.SyntaxException("(%s) %s (%s)" % (e.__class__.__name__, str(e), repr(code[0:50])), **exception_kwargs)
+
+
+if new_ast:
+    class FindIdentifiers(_ast_util.NodeVisitor):
+        def __init__(self, listener, **exception_kwargs):
+            self.in_function = False
+            self.in_assign_targets = False
+            self.local_ident_stack = {}
+            self.listener = listener
+            self.exception_kwargs = exception_kwargs
+        def _add_declared(self, name):
+            if not self.in_function:
+                self.listener.declared_identifiers.add(name)
+        def visit_ClassDef(self, node):
+            self._add_declared(node.name)
+        def visit_Assign(self, node):
+            # flip around the visiting of Assign so the expression gets evaluated first, 
+            # in the case of a clause like "x=x+5" (x is undeclared)
+            self.visit(node.value)
+            in_a = self.in_assign_targets
+            self.in_assign_targets = True
+            for n in node.targets:
+                self.visit(n)
+            self.in_assign_targets = in_a
+        def visit_FunctionDef(self, node):
+            self._add_declared(node.name)
+            # push function state onto stack.  dont log any
+            # more identifiers as "declared" until outside of the function,
+            # but keep logging identifiers as "undeclared".
+            # track argument names in each function header so they arent counted as "undeclared"
+            saved = {}
+            inf = self.in_function
+            self.in_function = True
+            for arg in node.args.args:
+                if arg.id in self.local_ident_stack:
+                    saved[arg.id] = True
+                else:
+                    self.local_ident_stack[arg.id] = True
+            for n in node.body:
+                self.visit(n)
+            self.in_function = inf
+            for arg in node.args.args:
+                if arg.id not in saved:
+                    del self.local_ident_stack[arg.id]
+        def visit_For(self, node):
+            # flip around visit
+            self.visit(node.iter)
+            self.visit(node.target)
+            for statement in node.body:
+                self.visit(statement)
+            for statement in node.orelse:
+                self.visit(statement)
+        def visit_Name(self, node):
+            if isinstance(node.ctx, _ast.Store):
+                self._add_declared(node.id)
+            if node.id not in __builtins__ and node.id not in self.listener.declared_identifiers and node.id not in self.local_ident_stack:
+                self.listener.undeclared_identifiers.add(node.id)
+        def visit_Import(self, node):
+            for name in node.names:
+                if name.asname is not None:
+                    self._add_declared(name.asname)
+                else:
+                    self._add_declared(name.name.split('.')[0])
+        def visit_ImportFrom(self, node):
+            for name in node.names:
+                if name.asname is not None:
+                    self._add_declared(name.asname)
+                else:
+                    if name.name == '*':
+                        raise exceptions.CompileException("'import *' is not supported, since all identifier names must be explicitly declared.  Please use the form 'from <modulename> import <name1>, <name2>, ...' instead.", **self.exception_kwargs)
+                    self._add_declared(name.name)
+
+    class FindTuple(_ast_util.NodeVisitor):
+        def __init__(self, listener, code_factory, **exception_kwargs):
+            self.listener = listener
+            self.exception_kwargs = exception_kwargs
+            self.code_factory = code_factory
+        def visit_Tuple(self, node):
+            for n in node.elts:
+                p = self.code_factory(n, **self.exception_kwargs)
+                self.listener.codeargs.append(p)
+                self.listener.args.append(ExpressionGenerator(n).value())
+                self.listener.declared_identifiers = self.listener.declared_identifiers.union(p.declared_identifiers)
+                self.listener.undeclared_identifiers = self.listener.undeclared_identifiers.union(p.undeclared_identifiers)
+
+    class ParseFunc(_ast_util.NodeVisitor):
+        def __init__(self, listener, **exception_kwargs):
+            self.listener = listener
+            self.exception_kwargs = exception_kwargs
+        def visit_FunctionDef(self, node):
+            self.listener.funcname = node.name
+            argnames = [arg.id for arg in node.args.args]
+            if node.args.vararg:
+                argnames.append(node.args.vararg)
+            if node.args.kwarg:
+                argnames.append(node.args.kwarg)
+            self.listener.argnames = argnames
+            self.listener.defaults = node.args.defaults # ast
+            self.listener.varargs = node.args.vararg
+            self.listener.kwargs = node.args.kwarg
+
+    class ExpressionGenerator(object):
+        def __init__(self, astnode):
+            self.generator = _ast_util.SourceGenerator(' ' * 4)
+            self.generator.visit(astnode)
+        def value(self):
+            return ''.join(self.generator.result)
+else:
+    class FindIdentifiers(object):
+        def __init__(self, listener, **exception_kwargs):
+            self.in_function = False
+            self.local_ident_stack = {}
+            self.listener = listener
+            self.exception_kwargs = exception_kwargs
+        def _add_declared(self, name):
+            if not self.in_function:
+                self.listener.declared_identifiers.add(name)
+        def visitClass(self, node, *args):
+            self._add_declared(node.name)
+        def visitAssName(self, node, *args):
+            self._add_declared(node.name)
+        def visitAssign(self, node, *args):
+            # flip around the visiting of Assign so the expression gets evaluated first, 
+            # in the case of a clause like "x=x+5" (x is undeclared)
+            self.visit(node.expr, *args)
+            for n in node.nodes:
+                self.visit(n, *args)
+        def visitFunction(self,node, *args):
+            self._add_declared(node.name)
+            # push function state onto stack.  dont log any
+            # more identifiers as "declared" until outside of the function,
+            # but keep logging identifiers as "undeclared".
+            # track argument names in each function header so they arent counted as "undeclared"
+            saved = {}
+            inf = self.in_function
+            self.in_function = True
+            for arg in node.argnames:
+                if arg in self.local_ident_stack:
+                    saved[arg] = True
+                else:
+                    self.local_ident_stack[arg] = True
+            for n in node.getChildNodes():
+                self.visit(n, *args)
+            self.in_function = inf
+            for arg in node.argnames:
+                if arg not in saved:
+                    del self.local_ident_stack[arg]
+        def visitFor(self, node, *args):
+            # flip around visit
+            self.visit(node.list, *args)
+            self.visit(node.assign, *args)
+            self.visit(node.body, *args)
+        def visitName(self, node, *args):
+            if node.name not in __builtins__ and node.name not in self.listener.declared_identifiers and node.name not in self.local_ident_stack:
+                self.listener.undeclared_identifiers.add(node.name)
+        def visitImport(self, node, *args):
+            for (mod, alias) in node.names:
+                if alias is not None:
+                    self._add_declared(alias)
+                else:
+                    self._add_declared(mod.split('.')[0])
+        def visitFrom(self, node, *args):
+            for (mod, alias) in node.names:
+                if alias is not None:
+                    self._add_declared(alias)
+                else:
+                    if mod == '*':
+                        raise exceptions.CompileException("'import *' is not supported, since all identifier names must be explicitly declared.  Please use the form 'from <modulename> import <name1>, <name2>, ...' instead.", **self.exception_kwargs)
+                    self._add_declared(mod)
+        def visit(self, expr):
+            visitor.walk(expr, self) #, walker=walker())
+
+    class FindTuple(object):
+        def __init__(self, listener, code_factory, **exception_kwargs):
+            self.listener = listener
+            self.exception_kwargs = exception_kwargs
+            self.code_factory = code_factory
+        def visitTuple(self, node, *args):
+            for n in node.nodes:
+                p = self.code_factory(n, **self.exception_kwargs)
+                self.listener.codeargs.append(p)
+                self.listener.args.append(ExpressionGenerator(n).value())
+                self.listener.declared_identifiers = self.listener.declared_identifiers.union(p.declared_identifiers)
+                self.listener.undeclared_identifiers = self.listener.undeclared_identifiers.union(p.undeclared_identifiers)
+        def visit(self, expr):
+            visitor.walk(expr, self) #, walker=walker())
+
+    class ParseFunc(object):
+        def __init__(self, listener, **exception_kwargs):
+            self.listener = listener
+            self.exception_kwargs = exception_kwargs
+        def visitFunction(self, node, *args):
+            self.listener.funcname = node.name
+            self.listener.argnames = node.argnames
+            self.listener.defaults = node.defaults
+            self.listener.varargs = node.varargs
+            self.listener.kwargs = node.kwargs
+        def visit(self, expr):
+            visitor.walk(expr, self)
+
+    class ExpressionGenerator(object):
+        """given an AST node, generates an equivalent literal Python expression."""
+        def __init__(self, astnode):
+            self.buf = StringIO()
+            visitor.walk(astnode, self) #, walker=walker())
+        def value(self):
+            return self.buf.getvalue()        
+        def operator(self, op, node, *args):
+            self.buf.write("(")
+            self.visit(node.left, *args)
+            self.buf.write(" %s " % op)
+            self.visit(node.right, *args)
+            self.buf.write(")")
+        def booleanop(self, op, node, *args):
+            self.visit(node.nodes[0])
+            for n in node.nodes[1:]:
+                self.buf.write(" " + op + " ")
+                self.visit(n, *args)
+        def visitConst(self, node, *args):
+            self.buf.write(repr(node.value))
+        def visitAssName(self, node, *args):
+            # TODO: figure out OP_ASSIGN, other OP_s
+            self.buf.write(node.name)
+        def visitName(self, node, *args):
+            self.buf.write(node.name)
+        def visitMul(self, node, *args):
+            self.operator("*", node, *args)
+        def visitAnd(self, node, *args):
+            self.booleanop("and", node, *args)
+        def visitOr(self, node, *args):
+            self.booleanop("or", node, *args)
+        def visitBitand(self, node, *args):
+            self.booleanop("&", node, *args)
+        def visitBitor(self, node, *args):
+            self.booleanop("|", node, *args)
+        def visitBitxor(self, node, *args):
+            self.booleanop("^", node, *args)
+        def visitAdd(self, node, *args):
+            self.operator("+", node, *args)
+        def visitGetattr(self, node, *args):
+            self.visit(node.expr, *args)
+            self.buf.write(".%s" % node.attrname)
+        def visitSub(self, node, *args):
+            self.operator("-", node, *args)
+        def visitNot(self, node, *args):
+            self.buf.write("not ")
+            self.visit(node.expr)
+        def visitDiv(self, node, *args):
+            self.operator("/", node, *args)
+        def visitFloorDiv(self, node, *args):
+            self.operator("//", node, *args)
+        def visitSubscript(self, node, *args):
+            self.visit(node.expr)
+            self.buf.write("[")
+            [self.visit(x) for x in node.subs]
+            self.buf.write("]")
+        def visitUnarySub(self, node, *args):
+            self.buf.write("-")
+            self.visit(node.expr)
+        def visitUnaryAdd(self, node, *args):
+            self.buf.write("-")
+            self.visit(node.expr)
+        def visitSlice(self, node, *args):
+            self.visit(node.expr)
+            self.buf.write("[")
+            if node.lower is not None:
+                self.visit(node.lower)
+            self.buf.write(":")
+            if node.upper is not None:
+                self.visit(node.upper)
+            self.buf.write("]")
+        def visitDict(self, node):
+            self.buf.write("{")
+            c = node.getChildren()
+            for i in range(0, len(c), 2):
+                self.visit(c[i])
+                self.buf.write(": ")
+                self.visit(c[i+1])
+                if i<len(c) -2:
+                    self.buf.write(", ")
+            self.buf.write("}")
+        def visitTuple(self, node):
+            self.buf.write("(")
+            c = node.getChildren()
+            for i in range(0, len(c)):
+                self.visit(c[i])
+                if i<len(c) - 1:
+                    self.buf.write(", ")
+            self.buf.write(")")
+        def visitList(self, node):
+            self.buf.write("[")
+            c = node.getChildren()
+            for i in range(0, len(c)):
+                self.visit(c[i])
+                if i<len(c) - 1:
+                    self.buf.write(", ")
+            self.buf.write("]")
+        def visitListComp(self, node):
+            self.buf.write("[")
+            self.visit(node.expr)
+            self.buf.write(" ")
+            for n in node.quals:
+                self.visit(n)
+            self.buf.write("]")
+        def visitListCompFor(self, node):
+            self.buf.write(" for ")
+            self.visit(node.assign)
+            self.buf.write(" in ")
+            self.visit(node.list)
+            for n in node.ifs:
+                self.visit(n)
+        def visitListCompIf(self, node):
+            self.buf.write(" if ")
+            self.visit(node.test)
+        def visitCompare(self, node):
+            self.visit(node.expr)
+            for tup in node.ops:
+                self.buf.write(tup[0])
+                self.visit(tup[1])
+        def visitCallFunc(self, node, *args):
+            self.visit(node.node)
+            self.buf.write("(")
+            if len(node.args):
+                self.visit(node.args[0])
+                for a in node.args[1:]:
+                    self.buf.write(", ")
+                    self.visit(a)
+            self.buf.write(")")
+
+    class walker(visitor.ASTVisitor):
+        def dispatch(self, node, *args):
+            print "Node:", str(node)
+            #print "dir:", dir(node)
+            return visitor.ASTVisitor.dispatch(self, node, *args)
diff --git a/lib/mako/template.py b/lib/mako/template.py
index 36e4d43d829125200d154fe0fcbce87505cda671..aca51cd2ed297869b896e9781c151b2d4a44b6c2 100644
--- a/lib/mako/template.py
+++ b/lib/mako/template.py
@@ -10,8 +10,7 @@ as well as template runtime operations."""
 from mako.lexer import Lexer
 from mako import codegen
 from mako import runtime, util, exceptions
-import imp, time, weakref, tempfile, shutil,  os, stat, sys, re
-
+import imp, os, re, shutil, stat, sys, tempfile, time, types, weakref
 
     
 class Template(object):
@@ -205,7 +204,9 @@ def _compile_text(template, text, filename):
     source = codegen.compile(node, template.uri, filename, default_filters=template.default_filters, buffer_filters=template.buffer_filters, imports=template.imports, source_encoding=lexer.encoding, generate_unicode=not template.disable_unicode)
     #print source
     cid = identifier
-    module = imp.new_module(cid)
+    if isinstance(cid, unicode):
+        cid = cid.encode()
+    module = types.ModuleType(cid)
     code = compile(source, cid, 'exec')
     exec code in module.__dict__, module.__dict__
     return (source, module)
diff --git a/lib/mako/util.py b/lib/mako/util.py
index 0257464f2d5206f4f634c8a2af7c7722cefddd18..83bb0cfe8e021bd3cb15b78ac40e47b08f550d50 100644
--- a/lib/mako/util.py
+++ b/lib/mako/util.py
@@ -120,3 +120,78 @@ class LRUCache(dict):
                     # if we couldnt find a key, most likely some other thread broke in 
                     # on us. loop around and try again
                     break
+
+def restore__ast(_ast):
+    """Attempt to restore the required classes to the _ast module if it
+    appears to be missing them
+    """
+    if hasattr(_ast, 'AST'):
+        return
+    _ast.PyCF_ONLY_AST = 2 << 9
+    m = compile("""\
+def foo(): pass
+class Bar(object): pass
+if False: pass
+baz = 'mako'
+1 + 2 - 3 * 4 / 5
+6 // 7 % 8 << 9 >> 10
+11 & 12 ^ 13 | 14
+15 and 16 or 17
+-baz + (not +18) - ~17
+baz and 'foo' or 'bar'
+(mako is baz == baz) is not baz != mako
+mako > baz < mako >= baz <= mako
+mako in baz not in mako""", '<unknown>', 'exec', _ast.PyCF_ONLY_AST)
+    _ast.Module = type(m)
+
+    for cls in _ast.Module.__mro__:
+        if cls.__name__ == 'mod':
+            _ast.mod = cls
+        elif cls.__name__ == 'AST':
+            _ast.AST = cls
+
+    _ast.FunctionDef = type(m.body[0])
+    _ast.ClassDef = type(m.body[1])
+    _ast.If = type(m.body[2])
+
+    _ast.Name = type(m.body[3].targets[0])
+    _ast.Store = type(m.body[3].targets[0].ctx)
+    _ast.Str = type(m.body[3].value)
+
+    _ast.Sub = type(m.body[4].value.op)
+    _ast.Add = type(m.body[4].value.left.op)
+    _ast.Div = type(m.body[4].value.right.op)
+    _ast.Mult = type(m.body[4].value.right.left.op)
+
+    _ast.RShift = type(m.body[5].value.op)
+    _ast.LShift = type(m.body[5].value.left.op)
+    _ast.Mod = type(m.body[5].value.left.left.op)
+    _ast.FloorDiv = type(m.body[5].value.left.left.left.op)
+
+    _ast.BitOr = type(m.body[6].value.op)
+    _ast.BitXor = type(m.body[6].value.left.op)
+    _ast.BitAnd = type(m.body[6].value.left.left.op)
+
+    _ast.Or = type(m.body[7].value.op)
+    _ast.And = type(m.body[7].value.values[0].op)
+
+    _ast.Invert = type(m.body[8].value.right.op)
+    _ast.Not = type(m.body[8].value.left.right.op)
+    _ast.UAdd = type(m.body[8].value.left.right.operand.op)
+    _ast.USub = type(m.body[8].value.left.left.op)
+
+    _ast.Or = type(m.body[9].value.op)
+    _ast.And = type(m.body[9].value.values[0].op)
+
+    _ast.IsNot = type(m.body[10].value.ops[0])
+    _ast.NotEq = type(m.body[10].value.ops[1])
+    _ast.Is = type(m.body[10].value.left.ops[0])
+    _ast.Eq = type(m.body[10].value.left.ops[1])
+
+    _ast.Gt = type(m.body[11].value.ops[0])
+    _ast.Lt = type(m.body[11].value.ops[1])
+    _ast.GtE = type(m.body[11].value.ops[2])
+    _ast.LtE = type(m.body[11].value.ops[3])
+
+    _ast.In = type(m.body[12].value.ops[0])
+    _ast.NotIn = type(m.body[12].value.ops[1])
diff --git a/test/ast.py b/test/ast.py
index f8d604d9d028564b386ae17bb02dd309be7394f3..8f348e58360b377f2c9100b89f6f166a1bf58e10 100644
--- a/test/ast.py
+++ b/test/ast.py
@@ -1,7 +1,6 @@
 import unittest
 
-from mako import ast, util, exceptions
-from compiler import parse
+from mako import ast, exceptions, pyparser, util
 
 exception_kwargs = {'source':'', 'lineno':0, 'pos':0, 'filename':''}
 
@@ -183,8 +182,8 @@ import x as bar
         local_dict = dict(x=x, y=y, foo=F(), lala=lala)
         
         code = "str((x+7*y) / foo.bar(5,6)) + lala('ho')"
-        astnode = parse(code)
-        newcode = ast.ExpressionGenerator(astnode).value()
+        astnode = pyparser.parse(code)
+        newcode = pyparser.ExpressionGenerator(astnode).value()
         #print "newcode:" + newcode
         #print "result:" + eval(code, local_dict)
         assert (eval(code, local_dict) == eval(newcode, local_dict))
@@ -194,26 +193,26 @@ import x as bar
         g = [1,2,3,4,5]
         local_dict = dict(a=a,hoho=hoho,g=g)
         code = "a[2] + hoho['somevalue'] + repr(g[3:5]) + repr(g[3:]) + repr(g[:5])"
-        astnode = parse(code)
-        newcode = ast.ExpressionGenerator(astnode).value()
+        astnode = pyparser.parse(code)
+        newcode = pyparser.ExpressionGenerator(astnode).value()
         #print newcode
         #print "result:", eval(code, local_dict)
         assert(eval(code, local_dict) == eval(newcode, local_dict))
         
         local_dict={'f':lambda :9, 'x':7}
         code = "x+f()"
-        astnode = parse(code)
-        newcode = ast.ExpressionGenerator(astnode).value()
+        astnode = pyparser.parse(code)
+        newcode = pyparser.ExpressionGenerator(astnode).value()
         assert(eval(code, local_dict)) == eval(newcode, local_dict)
 
         for code in ["repr({'x':7,'y':18})", "repr([])", "repr({})", "repr([{3:[]}])", "repr({'x':37*2 + len([6,7,8])})", "repr([1, 2, {}, {'x':'7'}])", "repr({'x':-1})", "repr(((1,2,3), (4,5,6)))", "repr(1 and 2 and 3 and 4)", "repr(True and False or 55)", "repr(1 & 2 | 3)", "repr(3//5)", "repr(3^5)", "repr([q.endswith('e') for q in ['one', 'two', 'three']])", "repr([x for x in (5,6,7) if x == 6])", "repr(not False)"]:
             local_dict={}
-            astnode = parse(code)
-            newcode = ast.ExpressionGenerator(astnode).value()
+            astnode = pyparser.parse(code)
+            newcode = pyparser.ExpressionGenerator(astnode).value()
             #print code, newcode
             assert(eval(code, local_dict)) == eval(newcode, local_dict), "%s != %s" % (code, newcode)
 
 if __name__ == '__main__':
     unittest.main()
     
-    
\ No newline at end of file
+