Commit ae6126f2 authored by Tom Niget's avatar Tom Niget

Fix desugaring and internal handling of bin/unops

parent 2cb61dd6
......@@ -11,6 +11,7 @@ import traceback
import colorama
from transpiler.phases.desugar_compare import DesugarCompare
from transpiler.phases.desugar_op import DesugarOp
colorama.init()
......@@ -177,6 +178,7 @@ def transpile(source, name="<module>", path=None):
IfMainVisitor().visit(res)
res = DesugarWith().visit(res)
res = DesugarCompare().visit(res)
res = DesugarOp().visit(res)
ScoperBlockVisitor().visit(res)
# print(res.scope)
......
......@@ -28,6 +28,28 @@ SYMBOLS = {
}
"""Mapping of Python AST nodes to C++ symbols."""
DUNDER_SYMBOLS = {
"__eq__": "==",
"__ne__": "!=",
"__lt__": "<",
"__gt__": ">",
"__ge__": ">=",
"__le__": "<=",
"__add__": "+",
"__sub__": "-",
"__mul__": "*",
"__div__": "/",
"__mod__": "%",
"__lshift__": "<<",
"__rshift__": ">>",
"__xor__": "^",
"__or__": "|",
"__and__": "&",
"__invert__": "~",
"__neg__": "-",
"__pos__": "+",
}
PRECEDENCE = [
("()", "[]", ".",),
("unary", "co_await"),
......
# coding: utf-8
import ast
from transpiler.phases.typing.expr import DUNDER
from transpiler.phases.utils import make_lnd
from transpiler.utils import linenodata
DUNDER = {
ast.Eq: "eq",
ast.NotEq: "ne",
ast.Lt: "lt",
ast.Gt: "gt",
ast.GtE: "ge",
ast.LtE: "le",
ast.In: "contains",
ast.NotIn: "contains",
}
class DesugarCompare(ast.NodeTransformer):
def visit_Compare(self, node: ast.Compare):
res = ast.BoolOp(ast.And(), [], **linenodata(node))
for left, op, right in zip([node.left] + node.comparators, node.ops, node.comparators):
operands = list(map(self.visit, [node.left, *node.comparators]))
for left, op, right in zip(operands, node.ops, operands[1:]):
lnd = make_lnd(left, right)
if type(op) in (ast.In, ast.NotIn):
left, right = right, left
......@@ -25,3 +35,21 @@ class DesugarCompare(ast.NodeTransformer):
if len(res.values) == 1:
return res.values[0]
return res
# def visit_Compare(self, node: ast.Compare):
# res = ast.BoolOp(ast.And(), [], **linenodata(node))
# operands = list(map(self.visit, [node.left, *node.comparators]))
# for left, op, right in zip(operands, node.ops, operands[1:]):
# lnd = make_lnd(left, right)
# call = ast.Compare(
# left,
# [op],
# [right],
# **lnd
# )
# if type(op) == ast.NotIn:
# call = ast.UnaryOp(ast.Not(), call, **lnd)
# res.values.append(call)
# if len(res.values) == 1:
# return res.values[0]
# return res
# coding: utf-8
import ast
from transpiler.utils import linenodata
DUNDER = {
ast.Mult: "mul",
ast.Add: "add",
ast.Sub: "sub",
ast.Div: "truediv",
ast.FloorDiv: "floordiv",
ast.Mod: "mod",
ast.LShift: "lshift",
ast.RShift: "rshift",
ast.BitXor: "xor",
ast.BitOr: "or",
ast.BitAnd: "and",
ast.USub: "neg",
ast.UAdd: "pos",
ast.Invert: "invert",
}
class DesugarOp(ast.NodeTransformer):
def visit_BinOp(self, node: ast.BinOp):
lnd = linenodata(node)
return ast.Call(
func=ast.Attribute(
value=self.visit(node.left),
attr=f"__{DUNDER[type(node.op)]}__",
ctx=ast.Load(),
**lnd
),
args=[self.visit(node.right)],
keywords={},
**lnd
)
def visit_UnaryOp(self, node: ast.UnaryOp):
lnd = linenodata(node)
if type(node.op) == ast.Not:
return ast.UnaryOp(
operand=self.visit(node.operand),
op=node.op,
**lnd
)
return ast.Call(
func=ast.Attribute(
value=self.visit(node.operand),
attr=f"__{DUNDER[type(node.op)]}__",
ctx=ast.Load(),
**lnd
),
args=[],
keywords={},
**lnd
)
# def visit_AugAssign(self, node: ast.AugAssign):
# return
......@@ -31,4 +31,7 @@ def process(items: list[ast.withitem], body: list[ast.stmt]) -> PlainBlock:
class DesugarWith(ast.NodeTransformer):
def visit_With(self, node: ast.With):
return process(node.items, node.body)
return process(
list(map(self.visit, node.items)),
list(map(self.visit, node.body))
)
......@@ -6,7 +6,7 @@ from typing import List, Iterable
from transpiler.phases.typing.types import UserType, FunctionType
from transpiler.phases.utils import make_lnd
from transpiler.utils import compare_ast, linenodata
from transpiler.consts import SYMBOLS, PRECEDENCE_LEVELS
from transpiler.consts import SYMBOLS, PRECEDENCE_LEVELS, DUNDER_SYMBOLS
from transpiler.phases.emit_cpp import CoroutineMode, join, NodeVisitor
from transpiler.phases.typing.scope import Scope, VarKind
......@@ -115,10 +115,10 @@ class ExpressionVisitor(NodeVisitor):
ast.Or: "||"
}[type(node.op)]
with self.prec_ctx(cpp_op):
yield from self.visit_binary_operation(node.op, node.values[0], node.values[1], make_lnd(node.values[0], node.values[1]))
yield from self.visit_binary_operation(cpp_op, node.values[0], node.values[1], make_lnd(node.values[0], node.values[1]))
for left, right in zip(node.values[1:], node.values[2:]):
yield f" {cpp_op} "
yield from self.visit_binary_operation(node.op, left, right, make_lnd(left, right))
yield from self.visit_binary_operation(cpp_op, left, right, make_lnd(left, right))
def visit_Call(self, node: ast.Call) -> Iterable[str]:
# TODO
......@@ -129,6 +129,13 @@ class ExpressionVisitor(NodeVisitor):
if getattr(node, "kwargs", None):
raise NotImplementedError(node, "kwargs")
func = node.func
if isinstance(func, ast.Attribute):
if sym := DUNDER_SYMBOLS.get(func.attr, None):
if len(node.args) == 1:
yield from self.visit_binary_operation(sym, func.value, node.args[0], linenodata(node))
else:
yield from self.visit_unary_operation(sym, func.value)
return
for name in ("fork", "future"):
if compare_ast(func, ast.parse(name, mode="eval").body):
assert len(node.args) == 1
......@@ -180,13 +187,17 @@ class ExpressionVisitor(NodeVisitor):
def visit_BinOp(self, node: ast.BinOp) -> Iterable[str]:
yield from self.visit_binary_operation(node.op, node.left, node.right, linenodata(node))
def visit_Compare(self, node: ast.Compare) -> Iterable[str]:
yield from self.visit_binary_operation(node.ops[0], node.left, node.comparators[0], linenodata(node))
def visit_binary_operation(self, op, left: ast.AST, right: ast.AST, lnd: dict) -> Iterable[str]:
if type(op) == ast.In:
call = ast.Call(ast.Attribute(right, "__contains__", **lnd), [left], [], **lnd)
call.is_await = False
yield from self.visit_Call(call)
print(call.func.type)
return
# if type(op) == ast.In:
# call = ast.Call(ast.Attribute(right, "__contains__", **lnd), [left], [], **lnd)
# call.is_await = False
# yield from self.visit_Call(call)
# print(call.func.type)
# return
if type(op) != str:
op = SYMBOLS[type(op)]
# TODO: handle precedence locally since only binops really need it
# we could just store the history of traversed nodes and check if the last one was a binop
......@@ -206,9 +217,9 @@ class ExpressionVisitor(NodeVisitor):
yield "dotp"
else:
yield "dot"
yield "("
yield "(("
yield from self.visit(node.value)
yield ", "
yield "), "
yield self.fix_name(node.attr)
yield ")"
else:
......@@ -261,8 +272,11 @@ class ExpressionVisitor(NodeVisitor):
yield "]"
def visit_UnaryOp(self, node: ast.UnaryOp) -> Iterable[str]:
yield from self.visit(node.op)
yield from self.prec("unary").visit(node.operand)
yield from self.visit_unary_operation(node.op, node.operand)
def visit_unary_operation(self, op, operand) -> Iterable[str]:
yield op
yield from self.prec("unary").visit(operand)
def visit_IfExp(self, node: ast.IfExp) -> Iterable[str]:
with self.prec_ctx("?:"):
......
......@@ -3,10 +3,11 @@ import dataclasses
import importlib
from dataclasses import dataclass
from transpiler.exceptions import CompileError
from transpiler.utils import highlight, linenodata
from transpiler.phases.typing import make_mod_decl
from transpiler.phases.typing.common import ScoperVisitor
from transpiler.phases.typing.expr import ScoperExprVisitor
from transpiler.phases.typing.expr import ScoperExprVisitor, DUNDER
from transpiler.phases.typing.class_ import ScoperClassVisitor
from transpiler.phases.typing.scope import VarDecl, VarKind, ScopeKind, Scope
from transpiler.phases.typing.types import BaseType, TypeVariable, FunctionType, \
......@@ -249,11 +250,16 @@ class ScoperBlockVisitor(ScoperVisitor):
self.scope.global_scope.vars[name] = VarDecl(VarKind.LOCAL, None)
def visit_AugAssign(self, node: ast.AugAssign):
equivalent = ast.Assign(
targets=[node.target],
value=ast.BinOp(left=node.target, op=node.op, right=node.value, **linenodata(node)),
**linenodata(node))
self.visit(equivalent)
target, value = map(self.get_type, (node.target, node.value))
try:
self.expr().make_dunder([target, value], "i" + DUNDER[type(node.op)])
except CompileError as e:
self.visit_assign_target(node.target, self.expr().make_dunder([target, value], DUNDER[type(node.op)]))
# equivalent = ast.Assign(
# targets=[node.target],
# value=ast.BinOp(left=node.target, op=node.op, right=node.value, **linenodata(node)),
# **linenodata(node))
# self.visit(equivalent)
def visit(self, node: ast.AST):
if isinstance(node, ast.AST):
......
......@@ -149,9 +149,16 @@ class ScoperExprVisitor(ScoperVisitor):
node.body.decls = decls
return ftype
def visit_BinOp(self, node: ast.BinOp) -> BaseType:
left, right = map(self.visit, (node.left, node.right))
return self.make_dunder([left, right], DUNDER[type(node.op)])
# def visit_BinOp(self, node: ast.BinOp) -> BaseType:
# left, right = map(self.visit, (node.left, node.right))
# return self.make_dunder([left, right], DUNDER[type(node.op)])
# def visit_Compare(self, node: ast.Compare) -> BaseType:
# left, right = map(self.visit, (node.left, node.comparators[0]))
# op = node.ops[0]
# if type(op) == ast.In:
# left, right = right, left
# return self.make_dunder([left, right], DUNDER[type(op)])
def visit_Attribute(self, node: ast.Attribute) -> BaseType:
ltype = self.visit(node.value)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment