Commit 72920c17 authored by Tom Niget's avatar Tom Niget

Update remote repository

parent a45dfe8e
//
// Created by Tom on 24/03/2023.
//
#ifndef TYPON_BASEDEF_HPP
#define TYPON_BASEDEF_HPP
template<typename Self>
class TyBuiltin {
template <typename... Args>
auto sync_wrapper(Args&&... args)
{
return static_cast<Self*>(this)->sync(std::forward<Args>(args)...);
}
public:
template <typename... Args>
auto operator()(Args&&... args) -> decltype(sync_wrapper(std::forward<Args>(args)...))
{
return sync_wrapper(std::forward<Args>(args)...);
}
};
#endif // TYPON_BASEDEF_HPP
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <string> #include <string>
#include <typon/typon.hpp> #include <typon/typon.hpp>
#include <python/basedef.hpp>
#ifdef __cpp_lib_unreachable #ifdef __cpp_lib_unreachable
#include <utility> #include <utility>
......
...@@ -7,25 +7,16 @@ ...@@ -7,25 +7,16 @@
#include <ranges> #include <ranges>
#include <typon/typon.hpp> #include <python/basedef.hpp>
// todo: proper range support // todo: proper range support
struct { struct range_s : TyBuiltin<range_s>
template <typename T> auto sync(T stop) { return std::views::iota(0, stop); } {
template <typename T> auto sync(T start, T stop) {
return std::views::iota(start, stop);
}
template <typename T> template <typename T>
auto operator()(T stop) -> typon::Task<decltype(sync(stop))> { auto sync(T stop) { return std::views::iota(0, stop); }
co_return sync(stop);
}
template <typename T> template <typename T>
auto operator()(T start, T stop) -> typon::Task<decltype(sync(start, stop))> { auto sync(T start, T stop) { return std::views::iota(start, stop); }
co_return sync(start, stop);
}
} range; } range;
#endif // TYPON_RANGE_HPP #endif // TYPON_RANGE_HPP
ALT_RUNNER="clang++-16 -O3 -Wno-return-type -Wno-unused-result -I../rt/include -std=c++20 -o {name_bin} {name_cpp_posix} -pthread -luring && {name_bin}"
\ No newline at end of file
__pycache__ __pycache__
\ No newline at end of file .env
\ No newline at end of file
Bclang-format==15.0.7 Bclang-format==15.0.7
......
from typing import Self, TypeVar, Generic
class int:
def __add__(self, other: Self) -> Self: ...
def __sub__(self, other: Self) -> Self: ...
def __mul__(self, other: Self) -> Self: ...
def __and__(self, other: Self) -> Self: ...
U = TypeVar("U")
class list(Generic[U]):
def __add__(self, other: Self) -> Self: ...
def __mul__(self, other: int) -> Self: ...
def first(self) -> U: ...
def print(*args) -> None: ...
stdout: CppType["auto&"]
from typing import Callable, TypeVar from typing import Callable, TypeVar, Generic
T = TypeVar("T") T = TypeVar("T")
class Fork(Generic[T]):
def get(self) -> T: ...
def fork(f: Callable[[], T]) -> T:
def fork(f: Callable[[], T]) -> Fork[T]:
# stub # stub
return f() class Res:
get = f
return Res
def future(f: Callable[[], T]) -> T: def future(f: Callable[[], T]) -> T:
...@@ -13,10 +18,10 @@ def future(f: Callable[[], T]) -> T: ...@@ -13,10 +18,10 @@ def future(f: Callable[[], T]) -> T:
return f() return f()
def sync(): def sync() -> None:
# stub # stub
pass pass
def is_cpp(): def is_cpp() -> bool:
return False return False
# coding: utf-8 # coding: utf-8
from os import system from os import system, environ
from pathlib import Path from pathlib import Path
from transpiler import transpile from transpiler import transpile
from transpiler.format import format_code from transpiler.format import format_code
# print(format_code("int x = 2 + ((3 * 5));;;")) # load .env file
from dotenv import load_dotenv
load_dotenv()
def run_tests(): def run_tests():
for path in Path('tests').glob('*.py'): for path in Path('tests').glob('*.py'):
...@@ -15,22 +18,21 @@ def run_tests(): ...@@ -15,22 +18,21 @@ def run_tests():
continue continue
with open(path, "r", encoding="utf-8") as f: with open(path, "r", encoding="utf-8") as f:
res = format_code(transpile(f.read())) res = format_code(transpile(f.read()))
#print(res)
name_cpp = path.with_suffix('.cpp') name_cpp = path.with_suffix('.cpp')
with open(name_cpp, "w", encoding="utf-8") as fcpp: with open(name_cpp, "w", encoding="utf-8") as fcpp:
fcpp.write(res) fcpp.write(res)
name_bin = path.with_suffix('').as_posix() name_bin = path.with_suffix('').as_posix()
cmd = f"bash -c 'g++ -I../rt/include -std=c++20 -o {name_bin} {name_cpp.as_posix()}'"
commands = [ commands = [
cmd, f"bash -c 'PYTHONPATH=stdlib python3 ./{path.as_posix()}'",
f"bash -c 'PYTHONPATH=. python3 ./{path.as_posix()}'",
f"bash -c './{name_bin}'",
f"scp ./{name_bin} tom@192.168.139.128:/tmp/{name_bin}", # TODO: temporary test suite. Will fix.
f"bash -c 'ssh tom@192.168.139.128 \"chmod +x /tmp/{name_bin} && /tmp/{name_bin}\"'"
] ]
if alt := environ.get("ALT_RUNNER"):
commands.append(alt.format(name_bin=name_bin, name_cpp_posix=name_cpp.as_posix()))
for cmd in commands: for cmd in commands:
if system(cmd) != 0: if system(cmd) != 0:
print(f"Error running command: {cmd}") print(f"Error running command: {cmd}")
break break
#exit()
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -4,27 +4,28 @@ from typon import is_cpp ...@@ -4,27 +4,28 @@ from typon import is_cpp
import sys as sis import sys as sis
from sys import stdout as truc from sys import stdout as truc
foo = 123
test = (2 + 3) * 4 test = (2 + 3) * 4
glob = 5 glob = 5
def g(): # def g():
a = 8 # a = 8
if True: # if True:
b = 9 # b = 9
if True: # if True:
c = 10 # c = 10
if True: # if True:
d = a + b + c # d = a + b + c
if True: # if True:
e = d + 1 # e = d + 1
print(e) # print(e)
def f(x): def f(x: int):
return x + 1 return x + 1
def fct(param): def fct(param):
loc = 456 loc = f(456)
global glob global glob
loc = 789 loc = 789
glob = 123 glob = 123
......
...@@ -17,7 +17,7 @@ def parallel_fibo(n: int) -> int: ...@@ -17,7 +17,7 @@ def parallel_fibo(n: int) -> int:
x = fork(lambda: fibo(n - 1)) x = fork(lambda: fibo(n - 1))
y = fork(lambda: fibo(n - 2)) y = fork(lambda: fibo(n - 2))
sync() sync()
return x + y return x.get() + y.get()
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -6,8 +6,8 @@ def fibo(n: int) -> int: ...@@ -6,8 +6,8 @@ def fibo(n: int) -> int:
a = fork(lambda: fibo(n - 1)) a = fork(lambda: fibo(n - 1))
b = fork(lambda: fibo(n - 2)) b = fork(lambda: fibo(n - 2))
sync() sync()
return a + b return a.get() + b.get()
if __name__ == "__main__": if __name__ == "__main__":
print(fibo(30)) # should display 832040 print(fibo(20)) # should display 832040
\ No newline at end of file \ No newline at end of file
def fibo(n): def fibo(n: int):
if n < 2: if n < 2:
return n return n
a = fibo(n - 1) a = fibo(n - 1)
......
...@@ -2,10 +2,31 @@ ...@@ -2,10 +2,31 @@
import ast import ast
from transpiler.consts import MAPPINGS from transpiler.consts import MAPPINGS
from transpiler.scope import Scope #from transpiler.phases import initial_pytype
from transpiler.visitors.file import FileVisitor from transpiler.phases.emit_cpp.file import FileVisitor
from transpiler.phases.if_main import IfMainVisitor
from transpiler.phases.typing.block import ScoperBlockVisitor
from transpiler.phases.typing.scope import Scope
def transpile(source): def transpile(source):
tree = ast.parse(source) res = ast.parse(source)
return "\n".join(filter(None, map(str, FileVisitor(Scope()).visit(tree)))) #res = initial_pytype.run(source, res)
IfMainVisitor().visit(res)
ScoperBlockVisitor().visit(res)
#print(res.scope)
# display each scope
def disp_scope(scope, indent=0):
print(" " * indent, scope.kind)
for child in scope.children:
disp_scope(child, indent + 1)
for var in scope.vars.items():
print(" " * (indent + 1), var)
#disp_scope(res.scope)
code = "\n".join(filter(None, map(str, FileVisitor(Scope()).visit(res))))
return code
# coding: utf-8 # coding: utf-8
import ast import ast
from dataclasses import dataclass
from enum import Flag from enum import Flag
from itertools import zip_longest, chain from itertools import chain
from typing import Iterable, Union from typing import Iterable
from transpiler import MAPPINGS from transpiler.phases.emit_cpp.consts import MAPPINGS
from transpiler.phases.typing import TypeVariable
from transpiler.phases.typing.types import BaseType, TY_INT, TY_BOOL, ForkResult
from transpiler.utils import UnsupportedNodeError
class UniversalVisitor:
class NodeVisitor:
def visit(self, node): def visit(self, node):
"""Visit a node.""" """Visit a node."""
if type(node) == list: if type(node) == list:
...@@ -25,6 +26,23 @@ class NodeVisitor: ...@@ -25,6 +26,23 @@ class NodeVisitor:
def missing_impl(self, node): def missing_impl(self, node):
raise UnsupportedNodeError(node) raise UnsupportedNodeError(node)
class TypeVisitor(UniversalVisitor):
def visit_TypeVariable(self, node: TypeVariable) -> Iterable[str]:
yield str(node)
#raise ValueError("Unresolved type variable")
def visit_BaseType(self, node: BaseType) -> Iterable[str]:
node = node.resolve()
if node is TY_INT:
yield "int"
elif node is TY_BOOL:
yield "bool"
elif isinstance(node, TypeVariable):
raise NotImplementedError(f"Not unified type variable {node}")
else:
raise NotImplementedError(node)
class NodeVisitor(UniversalVisitor):
def process_args(self, node: ast.arguments) -> (str, str, str): def process_args(self, node: ast.arguments) -> (str, str, str):
for field in ("posonlyargs", "vararg", "kwonlyargs", "kw_defaults", "kwarg", "defaults"): for field in ("posonlyargs", "vararg", "kwonlyargs", "kw_defaults", "kwarg", "defaults"):
if getattr(node, field, None): if getattr(node, field, None):
...@@ -44,14 +62,20 @@ class NodeVisitor: ...@@ -44,14 +62,20 @@ class NodeVisitor:
return f"py_{name[2:-2]}" return f"py_{name[2:-2]}"
return MAPPINGS.get(name, name) return MAPPINGS.get(name, name)
def visit_BaseType(self, node: BaseType) -> Iterable[str]:
@dataclass node = node.resolve()
class UnsupportedNodeError(Exception): if node is TY_INT:
node: ast.AST yield "int"
elif node is TY_BOOL:
def __str__(self) -> str: yield "bool"
return f"Unsupported node: {self.node.__class__.__mro__} {ast.dump(self.node)}" elif isinstance(node, ForkResult):
yield "Forked<"
yield from self.visit(node.return_type)
yield ">"
elif isinstance(node, TypeVariable):
raise NotImplementedError(f"Not unified type variable {node}")
else:
raise NotImplementedError(node)
class CoroutineMode(Flag): class CoroutineMode(Flag):
SYNC = 1 SYNC = 1
...@@ -61,7 +85,6 @@ class CoroutineMode(Flag): ...@@ -61,7 +85,6 @@ class CoroutineMode(Flag):
TASK = 16 | ASYNC TASK = 16 | ASYNC
JOIN = 32 | ASYNC JOIN = 32 | ASYNC
def join(sep: str, items: Iterable[Iterable[str]]) -> Iterable[str]: def join(sep: str, items: Iterable[Iterable[str]]) -> Iterable[str]:
items = iter(items) items = iter(items)
try: try:
...@@ -72,24 +95,5 @@ def join(sep: str, items: Iterable[Iterable[str]]) -> Iterable[str]: ...@@ -72,24 +95,5 @@ def join(sep: str, items: Iterable[Iterable[str]]) -> Iterable[str]:
except StopIteration: except StopIteration:
return return
def compare_ast(node1: Union[ast.expr, list[ast.expr]], node2: Union[ast.expr, list[ast.expr]]) -> bool:
if type(node1) is not type(node2):
return False
if isinstance(node1, ast.AST):
for k, v in vars(node1).items():
if k in {"lineno", "end_lineno", "col_offset", "end_col_offset", "ctx"}:
continue
if not compare_ast(v, getattr(node2, k)):
return False
return True
elif isinstance(node1, list) and isinstance(node2, list):
return all(compare_ast(n1, n2) for n1, n2 in zip_longest(node1, node2))
else:
return node1 == node2
def flatmap(f, items): def flatmap(f, items):
return chain.from_iterable(map(f, items)) return chain.from_iterable(map(f, items))
\ No newline at end of file
...@@ -3,10 +3,14 @@ import ast ...@@ -3,10 +3,14 @@ import ast
from dataclasses import dataclass from dataclasses import dataclass
from typing import Iterable, Optional from typing import Iterable, Optional
from transpiler.scope import VarDecl, VarKind, Scope from transpiler.phases.typing.scope import Scope
from transpiler.visitors import CoroutineMode, NodeVisitor, flatmap, compare_ast from transpiler.phases.typing.types import BaseType, TY_INT, TY_BOOL, TypeVariable
from transpiler.visitors.expr import ExpressionVisitor from transpiler.utils import compare_ast
from transpiler.visitors.search import SearchVisitor from transpiler.phases.emit_cpp import NodeVisitor, CoroutineMode, flatmap
from transpiler.phases.emit_cpp.expr import ExpressionVisitor
from transpiler.phases.emit_cpp.search import SearchVisitor
#from transpiler.scope import VarDecl, VarKind, Scope
# noinspection PyPep8Naming # noinspection PyPep8Naming
...@@ -19,6 +23,25 @@ class BlockVisitor(NodeVisitor): ...@@ -19,6 +23,25 @@ class BlockVisitor(NodeVisitor):
return ExpressionVisitor(self.scope, self.generator) return ExpressionVisitor(self.scope, self.generator)
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]: def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
if getattr(node, "is_main", False):
# Special case handling for Python's interesting way of defining an entry point.
# I mean, it's not *that* bad, it's just an attempt at retrofitting an "entry point" logic in a scripting
# language that, by essence, uses "the start of the file" as the implicit entry point, since files are
# read and executed line-by-line, contrary to usual structured languages that mark a distinction between
# declarations (functions, classes, modules, ...) and code.
# Also, for nitpickers, the C++ standard explicitly allows for omitting a `return` statement in the `main`.
# 0 is returned by default.
yield "typon::Root root()"
def block():
yield from node.body
yield ast.Return()
from transpiler.phases.emit_cpp.function import FunctionVisitor
yield from FunctionVisitor(self.scope, CoroutineMode.TASK).emit_block(node.scope, block())
yield "int main() { root().call(); }"
return
yield "struct {" yield "struct {"
yield from self.visit_func(node, CoroutineMode.FAKE) yield from self.visit_func(node, CoroutineMode.FAKE)
...@@ -56,7 +79,6 @@ class BlockVisitor(NodeVisitor): ...@@ -56,7 +79,6 @@ class BlockVisitor(NodeVisitor):
yield f"}} {node.name};" yield f"}} {node.name};"
def visit_func(self, node: ast.FunctionDef, generator: CoroutineMode) -> Iterable[str]: def visit_func(self, node: ast.FunctionDef, generator: CoroutineMode) -> Iterable[str]:
from transpiler.visitors.function import FunctionVisitor
templ, args, names = self.process_args(node.args) templ, args, names = self.process_args(node.args)
if templ: if templ:
yield "template" yield "template"
...@@ -98,7 +120,8 @@ class BlockVisitor(NodeVisitor): ...@@ -98,7 +120,8 @@ class BlockVisitor(NodeVisitor):
yield "Join" yield "Join"
yield f"<decltype(sync({', '.join(names)}))>" yield f"<decltype(sync({', '.join(names)}))>"
yield "{" yield "{"
inner_scope = self.scope.function(vars={node.name: VarDecl(VarKind.SELF, None)}) inner_scope = node.scope
for child in node.body: for child in node.body:
# Python uses module- and function- level scoping. Blocks, like conditionals and loops, do not form scopes # Python uses module- and function- level scoping. Blocks, like conditionals and loops, do not form scopes
# on their own. Variables are still accessible in the remainder of the parent function or in the global # on their own. Variables are still accessible in the remainder of the parent function or in the global
...@@ -138,23 +161,32 @@ class BlockVisitor(NodeVisitor): ...@@ -138,23 +161,32 @@ class BlockVisitor(NodeVisitor):
# auto y = 2; # auto y = 2;
# } # }
# ``` # ```
child_visitor = FunctionVisitor(inner_scope.child(), generator) from transpiler.phases.emit_cpp.function import FunctionVisitor
child_visitor = FunctionVisitor(inner_scope, generator)
# We need to do this in two-passes. This unfortunately breaks our nice generator state-machine architecture.
# Fair enough. if True:
[*child_code] = child_visitor.visit(child) for name, decl in getattr(child, "decls", {}).items():
#yield f"decltype({' '.join(self.expr().visit(decl.type))}) {name};"
# Hoist inner variables to the root scope. yield from self.visit(decl.type)
for var, decl in child_visitor.scope.vars.items(): yield f" {name};"
if decl.kind == VarKind.LOCAL: # Nested declarations become `decltype` declarations. yield from child_visitor.visit(child)
if getattr(decl.val[1], "in_await", False): else:
# TODO(zdimension): really? # We need to do this in two-passes. This unfortunately breaks our nice generator state-machine architecture.
yield f"decltype({decl.val[0][9:]}.operator co_await().await_resume()) {var};" # Fair enough.
else: # TODO(zdimension): break this in two visitors
yield f"decltype({decl.val[0]}) {var};" [*child_code] = child_visitor.visit(child)
elif decl.kind in (VarKind.GLOBAL, VarKind.NONLOCAL): # `global` and `nonlocal` just get hoisted as-is.
inner_scope.vars[var] = decl # Hoist inner variables to the root scope.
yield from child_code # Yeet back the child node code. for var, decl in child_visitor.scope.vars.items():
if decl.kind == VarKind.LOCAL: # Nested declarations become `decltype` declarations.
if getattr(decl.val[1], "in_await", False):
# TODO(zdimension): really?
yield f"decltype({decl.val[0][9:]}.operator co_await().await_resume()) {var};"
else:
yield f"decltype({decl.val[0]}) {var};"
elif decl.kind in (VarKind.GLOBAL, VarKind.NONLOCAL): # `global` and `nonlocal` just get hoisted as-is.
inner_scope.vars[var] = decl
yield from child_code # Yeet back the child node code.
if CoroutineMode.FAKE in generator: if CoroutineMode.FAKE in generator:
yield "TYPON_UNREACHABLE();" # So the compiler doesn't complain about missing return statements. yield "TYPON_UNREACHABLE();" # So the compiler doesn't complain about missing return statements.
elif CoroutineMode.ASYNC in generator and CoroutineMode.GENERATOR not in generator: elif CoroutineMode.ASYNC in generator and CoroutineMode.GENERATOR not in generator:
...@@ -162,15 +194,17 @@ class BlockVisitor(NodeVisitor): ...@@ -162,15 +194,17 @@ class BlockVisitor(NodeVisitor):
yield "co_return;" yield "co_return;"
yield "}" yield "}"
def visit_lvalue(self, lvalue: ast.expr, val: Optional[ast.AST] = None) -> Iterable[str]: def visit_lvalue(self, lvalue: ast.expr, declare: bool = False) -> Iterable[str]:
if isinstance(lvalue, ast.Tuple): if isinstance(lvalue, ast.Tuple):
yield f"std::tie({', '.join(flatmap(self.expr().visit, lvalue.elts))})" yield f"std::tie({', '.join(flatmap(self.expr().visit, lvalue.elts))})"
elif isinstance(lvalue, ast.Name): elif isinstance(lvalue, ast.Name):
name = self.fix_name(lvalue.id) name = self.fix_name(lvalue.id)
# if name not in self._scope.vars: # if name not in self._scope.vars:
if not self.scope.exists_local(name): # if not self.scope.exists_local(name):
yield self.scope.declare(name, (" ".join(self.expr().visit(val)), val) if val else None, # yield self.scope.declare(name, (" ".join(self.expr().visit(val)), val) if val else None,
getattr(val, "is_future", False)) # getattr(val, "is_future", False))
if declare:
yield from self.visit(lvalue.type)
yield name yield name
elif isinstance(lvalue, ast.Subscript): elif isinstance(lvalue, ast.Subscript):
yield from self.expr().visit(lvalue) yield from self.expr().visit(lvalue)
...@@ -180,7 +214,7 @@ class BlockVisitor(NodeVisitor): ...@@ -180,7 +214,7 @@ class BlockVisitor(NodeVisitor):
def visit_Assign(self, node: ast.Assign) -> Iterable[str]: def visit_Assign(self, node: ast.Assign) -> Iterable[str]:
if len(node.targets) != 1: if len(node.targets) != 1:
raise NotImplementedError(node) raise NotImplementedError(node)
yield from self.visit_lvalue(node.targets[0], node.value) yield from self.visit_lvalue(node.targets[0], node.is_declare)
yield " = " yield " = "
yield from self.expr().visit(node.value) yield from self.expr().visit(node.value)
yield ";" yield ";"
...@@ -188,7 +222,7 @@ class BlockVisitor(NodeVisitor): ...@@ -188,7 +222,7 @@ class BlockVisitor(NodeVisitor):
def visit_AnnAssign(self, node: ast.AnnAssign) -> Iterable[str]: def visit_AnnAssign(self, node: ast.AnnAssign) -> Iterable[str]:
if node.value is None: if node.value is None:
raise NotImplementedError(node, "empty value") raise NotImplementedError(node, "empty value")
yield from self.visit_lvalue(node.target, node.value) yield from self.visit_lvalue(node.target)
yield " = " yield " = "
yield from self.expr().visit(node.value) yield from self.expr().visit(node.value)
yield ";" yield ";"
# coding: utf-8
import ast
SYMBOLS = {
ast.Eq: "==",
ast.NotEq: '!=',
ast.Pass: '/* pass */',
ast.Mult: '*',
ast.Add: '+',
ast.Sub: '-',
ast.Div: '/',
ast.FloorDiv: '/', # TODO
ast.Mod: '%',
ast.Lt: '<',
ast.Gt: '>',
ast.GtE: '>=',
ast.LtE: '<=',
ast.LShift: '<<',
ast.RShift: '>>',
ast.BitXor: '^',
ast.BitOr: '|',
ast.BitAnd: '&',
ast.Not: '!',
ast.IsNot: '!=',
ast.USub: '-',
ast.And: '&&',
ast.Or: '||'
}
"""Mapping of Python AST nodes to C++ symbols."""
PRECEDENCE = [
("()", "[]", ".",),
("unary", "co_await"),
("*", "/", "%",),
("+", "-"),
("<<", ">>"),
("<", "<=", ">", ">="),
("==", "!="),
("&",),
("^",),
("|",),
("&&",),
("||",),
("?:", "co_yield"),
(",",)
]
"""Precedence of C++ operators."""
PRECEDENCE_LEVELS = {op: i for i, ops in enumerate(PRECEDENCE) for op in ops}
"""Mapping of C++ operators to their precedence level."""
MAPPINGS = {
"True": "true",
"False": "false",
"None": "nullptr"
}
"""Mapping of Python builtin constants to C++ equivalents."""
...@@ -3,9 +3,10 @@ import ast ...@@ -3,9 +3,10 @@ import ast
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import List, Iterable from typing import List, Iterable
from transpiler.utils import compare_ast
from transpiler.consts import SYMBOLS, PRECEDENCE_LEVELS from transpiler.consts import SYMBOLS, PRECEDENCE_LEVELS
from transpiler.scope import VarKind, Scope from transpiler.phases.emit_cpp import CoroutineMode, join, NodeVisitor
from transpiler.visitors import CoroutineMode, NodeVisitor, join, compare_ast from transpiler.phases.typing.scope import Scope, VarKind
class PrecedenceContext: class PrecedenceContext:
...@@ -74,7 +75,7 @@ class ExpressionVisitor(NodeVisitor): ...@@ -74,7 +75,7 @@ class ExpressionVisitor(NodeVisitor):
def visit_Name(self, node: ast.Name) -> Iterable[str]: def visit_Name(self, node: ast.Name) -> Iterable[str]:
res = self.fix_name(node.id) res = self.fix_name(node.id)
if decl := self.scope.get(res): if False and (decl := self.scope.get(res)):
if decl.kind == VarKind.SELF: if decl.kind == VarKind.SELF:
res = "(*this)" res = "(*this)"
elif decl.future and CoroutineMode.ASYNC in self.generator: elif decl.future and CoroutineMode.ASYNC in self.generator:
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
import ast import ast
from typing import Iterable from typing import Iterable
from transpiler.visitors.block import BlockVisitor from transpiler.phases.emit_cpp.block import BlockVisitor
from transpiler.visitors.module import ModuleVisitor from transpiler.phases.emit_cpp.module import ModuleVisitor
# noinspection PyPep8Naming # noinspection PyPep8Naming
......
...@@ -4,9 +4,9 @@ from dataclasses import dataclass ...@@ -4,9 +4,9 @@ from dataclasses import dataclass
from typing import Iterable from typing import Iterable
from transpiler.consts import SYMBOLS from transpiler.consts import SYMBOLS
from transpiler.scope import VarDecl, VarKind from transpiler.phases.emit_cpp import CoroutineMode
from transpiler.visitors import CoroutineMode from transpiler.phases.emit_cpp.block import BlockVisitor
from transpiler.visitors.block import BlockVisitor from transpiler.phases.typing.scope import Scope
# noinspection PyPep8Naming # noinspection PyPep8Naming
...@@ -17,7 +17,7 @@ class FunctionVisitor(BlockVisitor): ...@@ -17,7 +17,7 @@ class FunctionVisitor(BlockVisitor):
yield ";" yield ";"
def visit_AugAssign(self, node: ast.AugAssign) -> Iterable[str]: def visit_AugAssign(self, node: ast.AugAssign) -> Iterable[str]:
yield from self.visit_lvalue(node.target) yield from self.visit_lvalue(node.target, False)
yield SYMBOLS[type(node.op)] + "=" yield SYMBOLS[type(node.op)] + "="
yield from self.expr().visit(node.value) yield from self.expr().visit(node.value)
yield ";" yield ";"
...@@ -28,7 +28,7 @@ class FunctionVisitor(BlockVisitor): ...@@ -28,7 +28,7 @@ class FunctionVisitor(BlockVisitor):
yield f"for (auto {node.target.id} : " yield f"for (auto {node.target.id} : "
yield from self.expr().visit(node.iter) yield from self.expr().visit(node.iter)
yield ")" yield ")"
yield from self.emit_block(node.body) yield from self.emit_block(node.scope, node.body)
if node.orelse: if node.orelse:
raise NotImplementedError(node, "orelse") raise NotImplementedError(node, "orelse")
...@@ -36,13 +36,13 @@ class FunctionVisitor(BlockVisitor): ...@@ -36,13 +36,13 @@ class FunctionVisitor(BlockVisitor):
yield "if (" yield "if ("
yield from self.expr().visit(node.test) yield from self.expr().visit(node.test)
yield ")" yield ")"
yield from self.emit_block(node.body) yield from self.emit_block(node.scope, node.body)
if node.orelse: if node.orelse:
yield "else " yield "else "
if isinstance(node.orelse, ast.If): if isinstance(node.orelse, ast.If):
yield from self.visit(node.orelse) yield from self.visit(node.orelse)
else: else:
yield from self.emit_block(node.orelse) yield from self.emit_block(node.orelse.scope, node.orelse)
def visit_Return(self, node: ast.Return) -> Iterable[str]: def visit_Return(self, node: ast.Return) -> Iterable[str]:
if CoroutineMode.ASYNC in self.generator: if CoroutineMode.ASYNC in self.generator:
...@@ -57,29 +57,24 @@ class FunctionVisitor(BlockVisitor): ...@@ -57,29 +57,24 @@ class FunctionVisitor(BlockVisitor):
yield "while (" yield "while ("
yield from self.expr().visit(node.test) yield from self.expr().visit(node.test)
yield ")" yield ")"
yield from self.emit_block(node.body) yield from self.emit_block(node.scope, node.body)
if node.orelse: if node.orelse:
raise NotImplementedError(node, "orelse") raise NotImplementedError(node, "orelse")
def visit_Global(self, node: ast.Global) -> Iterable[str]: def visit_Global(self, node: ast.Global) -> Iterable[str]:
for name in map(self.fix_name, node.names):
self.scope.vars[name] = VarDecl(VarKind.GLOBAL, None)
yield "" yield ""
def visit_Nonlocal(self, node: ast.Nonlocal) -> Iterable[str]: def visit_Nonlocal(self, node: ast.Nonlocal) -> Iterable[str]:
for name in map(self.fix_name, node.names):
self.scope.vars[name] = VarDecl(VarKind.NONLOCAL, None)
yield "" yield ""
def block(self) -> "FunctionVisitor": def block2(self) -> "FunctionVisitor":
# See the comments in visit_FunctionDef. # See the comments in visit_FunctionDef.
# A Python code block does not introduce a new scope, so we create a new `Scope` object that shares the same # A Python code block does not introduce a new scope, so we create a new `Scope` object that shares the same
# variables as the parent scope. # variables as the parent scope.
return FunctionVisitor(self.scope.child_share(), self.generator) return FunctionVisitor(self.scope.child_share(), self.generator)
def emit_block(self, items: Iterable[ast.stmt]) -> Iterable[str]: def emit_block(self, scope: Scope, items: Iterable[ast.stmt]) -> Iterable[str]:
yield "{" yield "{"
block = self.block()
for child in items: for child in items:
yield from block.visit(child) yield from FunctionVisitor(scope, self.generator).visit(child)
yield "}" yield "}"
...@@ -2,9 +2,10 @@ ...@@ -2,9 +2,10 @@
import ast import ast
from typing import Iterable from typing import Iterable
from transpiler.visitors import CoroutineMode, compare_ast from transpiler.phases.emit_cpp import CoroutineMode
from transpiler.visitors.block import BlockVisitor from transpiler.phases.emit_cpp.block import BlockVisitor
from transpiler.visitors.function import FunctionVisitor from transpiler.phases.emit_cpp.function import FunctionVisitor
from transpiler.utils import compare_ast
# noinspection PyPep8Naming # noinspection PyPep8Naming
...@@ -27,24 +28,3 @@ class ModuleVisitor(BlockVisitor): ...@@ -27,24 +28,3 @@ class ModuleVisitor(BlockVisitor):
yield from self.import_module(node.module) yield from self.import_module(node.module)
for alias in node.names: for alias in node.names:
yield f"auto& {alias.asname or alias.name} = py_{node.module}::all.{alias.name};" yield f"auto& {alias.asname or alias.name} = py_{node.module}::all.{alias.name};"
def visit_If(self, node: ast.If) -> Iterable[str]:
if not node.orelse and compare_ast(node.test, ast.parse('__name__ == "__main__"', mode="eval").body):
# Special case handling for Python's interesting way of defining an entry point.
# I mean, it's not *that* bad, it's just an attempt at retrofitting an "entry point" logic in a scripting
# language that, by essence, uses "the start of the file" as the implicit entry point, since files are
# read and executed line-by-line, contrary to usual structured languages that mark a distinction between
# declarations (functions, classes, modules, ...) and code.
# Also, for nitpickers, the C++ standard explicitly allows for omitting a `return` statement in the `main`.
# 0 is returned by default.
yield "typon::Root root()"
def block():
yield from node.body
yield ast.Return()
yield from FunctionVisitor(self.scope.function(), CoroutineMode.TASK).emit_block(block())
yield "int main() { root().call(); }"
return
raise NotImplementedError(node, "global scope if")
# coding: utf-8 # coding: utf-8
import ast import ast
from transpiler.visitors import NodeVisitor from transpiler.phases.emit_cpp import NodeVisitor
class SearchVisitor(NodeVisitor): class SearchVisitor(NodeVisitor):
......
import ast
from transpiler.utils import compare_ast
NAME_MAIN = ast.parse('__name__ == "__main__"', mode="eval").body
class IfMainVisitor(ast.NodeVisitor):
def visit_Module(self, node: ast.Module):
for i, stmt in enumerate(node.body):
if isinstance(stmt, ast.If):
if not stmt.orelse and compare_ast(stmt.test, NAME_MAIN):
new_node = ast.FunctionDef(
name="main",
args=ast.arguments(args=[]),
body=stmt.body,
decorator_list=[],
returns=None
)
new_node.is_main = True
node.body[i] = new_node
return
\ No newline at end of file
import ast
import pytype.config
from pytype import io
from pytype.pytd import pytd_utils
from pytype.tools.traces import traces
def run(source: str, module: ast.Module) -> ast.Module:
opt = pytype.config.Options.create(None, no_return_any=True, precise_return=True)
source_code = infer_types(source, opt)
visitor = AnnotateAstVisitor(source_code, ast)
visitor.visit(module)
return module
def infer_types(source: str, options: pytype.config.Options) -> "source.Code":
with io.wrap_pytype_exceptions(PytypeError, filename=options.input):
return traces.trace(source, options)
class AnnotateAstVisitor(traces.MatchAstVisitor):
def visit_Name(self, node):
self._maybe_annotate(node)
def visit_Attribute(self, node):
self._maybe_annotate(node)
def visit_FunctionDef(self, node):
self._maybe_annotate(node)
def _maybe_annotate(self, node):
"""Annotates a node."""
try:
ops = self.match(node)
except NotImplementedError:
return
# For lack of a better option, take the first one.
unused_loc, entry = next(iter(ops), (None, None))
self._maybe_set_type(node, entry)
def _maybe_set_type(self, node, trace):
"""Sets type information on the node, if there is any to set."""
if not trace:
return
node.resolved_type = trace.types[-1]
node.resolved_annotation = _annotation_str_from_type_def(trace.types[-1])
class PytypeError(Exception):
"""Wrap exceptions raised by Pytype."""
def _annotation_str_from_type_def(type_def):
return pytd_utils.Print(type_def)
import ast
from pathlib import Path
from transpiler.phases.typing.scope import VarKind, VarDecl, ScopeKind
from transpiler.phases.typing.stdlib import PRELUDE, StdlibVisitor
from transpiler.phases.typing.types import TY_TYPE, TY_INT, TY_STR, TY_BOOL, TY_COMPLEX, TY_NONE, FunctionType, \
TypeVariable, TY_MODULE, CppType, PyList, TypeType, ForkResult
PRELUDE.vars.update({
# "int": VarDecl(VarKind.LOCAL, TY_TYPE, TY_INT),
# "str": VarDecl(VarKind.LOCAL, TY_TYPE, TY_STR),
# "bool": VarDecl(VarKind.LOCAL, TY_TYPE, TY_BOOL),
# "complex": VarDecl(VarKind.LOCAL, TY_TYPE, TY_COMPLEX),
# "None": VarDecl(VarKind.LOCAL, TY_NONE, None),
# "Callable": VarDecl(VarKind.LOCAL, TY_TYPE, FunctionType),
# "TypeVar": VarDecl(VarKind.LOCAL, TY_TYPE, TypeVariable),
# "CppType": VarDecl(VarKind.LOCAL, TY_TYPE, CppType),
# "list": VarDecl(VarKind.LOCAL, TY_TYPE, PyList),
"int": VarDecl(VarKind.LOCAL, TypeType(TY_INT)),
"str": VarDecl(VarKind.LOCAL, TypeType(TY_STR)),
"bool": VarDecl(VarKind.LOCAL, TypeType(TY_BOOL)),
"complex": VarDecl(VarKind.LOCAL, TypeType(TY_COMPLEX)),
"None": VarDecl(VarKind.LOCAL, TypeType(TY_NONE)),
"Callable": VarDecl(VarKind.LOCAL, FunctionType),
"TypeVar": VarDecl(VarKind.LOCAL, TypeVariable),
"CppType": VarDecl(VarKind.LOCAL, CppType),
"list": VarDecl(VarKind.LOCAL, PyList),
"Fork": VarDecl(VarKind.LOCAL, ForkResult),
})
typon_std = Path(__file__).parent.parent.parent.parent / "stdlib"
def discover_module(path: Path, scope):
for child in path.iterdir():
if child.is_dir():
mod_scope = PRELUDE.child(ScopeKind.GLOBAL)
discover_module(child, mod_scope)
scope.vars[child.name] = VarDecl(VarKind.LOCAL, TY_MODULE, {k: v.type for k, v in mod_scope.vars.items()})
elif child.name == "__init__.py":
StdlibVisitor(scope).visit(ast.parse(child.read_text()))
print(f"Visited {child}")
elif child.suffix == ".py":
mod_scope = PRELUDE.child(ScopeKind.GLOBAL)
StdlibVisitor(mod_scope).visit(ast.parse(child.read_text()))
scope.vars[child.stem] = VarDecl(VarKind.LOCAL, TY_MODULE, {k: v.type for k, v in mod_scope.vars.items()})
discover_module(typon_std, PRELUDE)
print("Stdlib visited!")
import ast
from dataclasses import dataclass
from typing import Optional, List
from transpiler.phases.typing.scope import Scope
from transpiler.phases.typing.types import BaseType, TY_NONE, TypeType, TY_SELF
from transpiler.phases.utils import NodeVisitorSeq
@dataclass
class TypeAnnotationVisitor(NodeVisitorSeq):
scope: Scope
cur_class: Optional[TypeType] = None
def visit_str(self, node: str) -> BaseType:
if node in ("Self", "self") and self.cur_class:
return TY_SELF
if existing := self.scope.get(node):
ty = existing.type
if isinstance(ty, TypeType):
return ty.type_object
return ty
raise NameError(node)
def visit_Name(self, node: ast.Name) -> BaseType:
return self.visit_str(node.id)
def visit_Constant(self, node: ast.Constant) -> BaseType:
if node.value is None:
return TY_NONE
if type(node.value) == str:
return node.value
raise NotImplementedError
def visit_Subscript(self, node: ast.Subscript) -> BaseType:
ty_op = self.visit(node.value)
args = list(node.slice.elts) if type(node.slice) == ast.Tuple else [node.slice]
args = [self.visit(arg) for arg in args]
return ty_op(*args)
# return TypeOperator([self.visit(node.value)], self.visit(node.slice.value))
def visit_List(self, node: ast.List) -> List[BaseType]:
return [self.visit(elt) for elt in node.elts]
import ast
from dataclasses import dataclass
from typing import Optional
from transpiler.phases.typing.annotations import TypeAnnotationVisitor
from transpiler.phases.typing.common import ScoperVisitor
from transpiler.phases.typing.expr import ScoperExprVisitor
from transpiler.phases.typing.scope import VarDecl, VarKind, ScopeKind
from transpiler.phases.typing.types import BaseType, TypeVariable, FunctionType, IncompatibleTypesError, TY_MODULE
@dataclass
class ScoperBlockVisitor(ScoperVisitor):
stdlib: bool = False
def expr(self) -> ScoperExprVisitor:
return ScoperExprVisitor(self.scope, self.root_decls)
def visit_Import(self, node: ast.Import):
for alias in node.names:
self.scope.vars[alias.asname or alias.name] = VarDecl(VarKind.LOCAL, None)
def visit_ImportFrom(self, node: ast.ImportFrom):
module = self.scope.get(node.module)
if not module:
raise NameError(node.module)
if module.type is not TY_MODULE:
raise IncompatibleTypesError(f"{node.module} is not a module")
for alias in node.names:
thing = module.val.get(alias.name)
if not thing:
raise NameError(alias.name)
self.scope.vars[alias.asname or alias.name] = VarDecl(VarKind.LOCAL, thing)
def visit_Module(self, node: ast.Module):
for stmt in node.body:
self.visit(stmt)
def get_type(self, node: ast.expr) -> BaseType:
if type := getattr(node, "type", None):
return type
self.expr().visit(node)
return node.type
# ntype = TypeVariable()
# node.type = ntype
# return ntype
def visit_Assign(self, node: ast.Assign):
if len(node.targets) != 1:
raise NotImplementedError(node)
target = node.targets[0]
ty = self.get_type(node.value)
target.type = ty
node.is_declare = self.visit_assign_target(target, ty)
def visit_assign_target(self, target, decl_val: BaseType) -> bool:
if isinstance(target, ast.Name):
if vdecl := self.scope.get(target.id):
vdecl.type.unify(decl_val)
return False
else:
self.scope.vars[target.id] = VarDecl(VarKind.LOCAL, decl_val)
if self.scope.kind == ScopeKind.FUNCTION_INNER:
self.root_decls[target.id] = VarDecl(VarKind.OUTER_DECL, decl_val)
return True
else:
raise NotImplementedError(target)
def anno(self) -> "TypeAnnotationVisitor":
return TypeAnnotationVisitor(self.scope)
def visit_annotation(self, expr: Optional[ast.expr]) -> BaseType:
return self.anno().visit(expr) if expr else TypeVariable()
def visit_FunctionDef(self, node: ast.FunctionDef):
argtypes = [self.visit_annotation(arg.annotation) for arg in node.args.args]
rtype = self.visit_annotation(node.returns)
ftype = FunctionType(argtypes, rtype)
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ftype)
scope = self.scope.child(ScopeKind.FUNCTION)
scope.obj_type = ftype
scope.function = scope
node.scope = scope
for arg, ty in zip(node.args.args, argtypes):
scope.vars[arg.arg] = VarDecl(VarKind.LOCAL, ty)
for b in node.body:
decls = {}
visitor = ScoperBlockVisitor(scope, decls)
visitor.visit(b)
b.decls = decls
def visit_If(self, node: ast.If):
scope = self.scope.child(ScopeKind.FUNCTION_INNER)
scope.function = self.scope.function
node.scope = scope
visitor = ScoperBlockVisitor(scope, self.root_decls)
for b in node.body:
visitor.visit(b)
def visit_Expr(self, node: ast.Expr):
self.expr().visit(node.value)
def visit_Return(self, node: ast.Return):
fct = self.scope.function
if fct is None:
raise IncompatibleTypesError("Return outside function")
ftype = fct.obj_type
assert isinstance(ftype, FunctionType)
vtype = self.expr().visit(node.value) if node.value else None
vtype.unify(ftype.return_type)
def visit_Global(self, node: ast.Global):
for name in node.names:
self.scope.function.vars[name] = VarDecl(VarKind.GLOBAL, None)
if name not in self.scope.global_scope.vars:
self.scope.global_scope.vars[name] = VarDecl(VarKind.LOCAL, None)
def visit_AugAssign(self, node: ast.AugAssign):
equivalent = ast.Assign(targets=[node.target], value=ast.BinOp(left=node.target, op=node.op, right=node.value))
self.visit(equivalent)
def visit(self, node: ast.AST):
if isinstance(node, ast.AST):
super().visit(node)
node.scope = self.scope
from dataclasses import dataclass, field
from typing import Dict
from transpiler.phases.typing.scope import Scope, ScopeKind, VarDecl
from transpiler.phases.utils import NodeVisitorSeq
PRELUDE = Scope.make_global()
@dataclass
class ScoperVisitor(NodeVisitorSeq):
scope: Scope = field(default_factory=lambda: PRELUDE.child(ScopeKind.GLOBAL))
root_decls: Dict[str, VarDecl] = field(default_factory=dict)
\ No newline at end of file
import ast
from typing import List
from transpiler.phases.typing import ScopeKind, VarDecl, VarKind
from transpiler.phases.typing.common import ScoperVisitor
from transpiler.phases.typing.types import IncompatibleTypesError, BaseType, TupleType, TY_STR, TY_BOOL, TY_INT, \
TY_COMPLEX, TY_NONE, FunctionType, PyList, TypeVariable, PySet, TypeType, PyDict
DUNDER = {
ast.Eq: "eq",
ast.NotEq: "ne",
ast.Mult: "mul",
ast.Add: "add",
ast.Sub: "sub",
ast.Div: "truediv",
ast.FloorDiv: "floordiv",
ast.Mod: "mod",
ast.Lt: "lt",
ast.Gt: "gt",
ast.GtE: "ge",
ast.LtE: "le",
ast.LShift: "lshift",
ast.RShift: "rshift",
ast.BitXor: "xor",
ast.BitOr: "or",
ast.BitAnd: "and",
ast.USub: "neg",
ast.UAdd: "pos",
ast.Invert: "invert",
}
class ScoperExprVisitor(ScoperVisitor):
def visit(self, node) -> BaseType:
if existing := getattr(node, "type", None):
return existing
res = super().visit(node)
if not res:
raise NotImplementedError(f"`{ast.unparse(node)}` {type(node)}")
res = res.resolve()
node.type = res
return res
def visit_Tuple(self, node: ast.Tuple) -> BaseType:
return TupleType([self.visit(e) for e in node.elts])
def visit_Constant(self, node: ast.Constant) -> BaseType:
if isinstance(node.value, str):
return TY_STR
elif isinstance(node.value, bool):
return TY_BOOL
elif isinstance(node.value, int):
return TY_INT
elif isinstance(node.value, complex):
return TY_COMPLEX
elif node.value is None:
return TY_NONE
else:
raise NotImplementedError(node, type(node))
def visit_Name(self, node: ast.Name) -> BaseType:
obj = self.scope.get(node.id)
if not obj:
raise NameError(f"Name {node.id} is not defined")
return obj.type
def visit_Compare(self, node: ast.Compare) -> BaseType:
# todo:
self.visit(node.left)
for op, right in zip(node.ops, node.comparators):
self.visit(right)
return TY_BOOL
#raise NotImplementedError(node)
def visit_Call(self, node: ast.Call) -> BaseType:
ftype = self.visit(node.func)
return self.visit_function_call(ftype, [self.visit(arg) for arg in node.args])
def visit_function_call(self, ftype: BaseType, arguments: List[BaseType]):
if not isinstance(ftype, FunctionType):
raise IncompatibleTypesError(f"Cannot call {ftype}")
#is_generic = any(isinstance(arg, TypeVariable) for arg in ftype.to_list())
equivalent = FunctionType(arguments, ftype.return_type)
try:
ftype.unify(equivalent)
except:
raise IncompatibleTypesError(f"Cannot call {ftype} with {equivalent}")
return ftype.return_type
def visit_Lambda(self, node: ast.Lambda) -> BaseType:
argtypes = [TypeVariable() for _ in node.args.args]
rtype = TypeVariable()
ftype = FunctionType(argtypes, rtype)
scope = self.scope.child(ScopeKind.FUNCTION)
scope.obj_type = ftype
scope.function = scope
node.scope = scope
for arg, ty in zip(node.args.args, argtypes):
scope.vars[arg.arg] = VarDecl(VarKind.LOCAL, ty)
decls = {}
visitor = ScoperExprVisitor(scope, decls)
rtype.unify(visitor.visit(node.body))
node.body.decls = decls
return ftype
def visit_BinOp(self, node: ast.BinOp) -> BaseType:
left, right = map(self.visit, (node.left, node.right))
try:
return self.visit_function_call(
self.visit_getattr(TypeType(left), f"__{DUNDER[type(node.op)]}__"),
[left, right]
)
except IncompatibleTypesError as e:
raise IncompatibleTypesError(f"{e} in `{ast.unparse(node)}`")
def visit_Attribute(self, node: ast.Attribute) -> BaseType:
ltype = self.visit(node.value)
try:
return self.visit_getattr(ltype, node.attr)
except IncompatibleTypesError as e:
raise IncompatibleTypesError(f"{e} in `{ast.unparse(node)}`")
def visit_getattr(self, ltype: BaseType, name: str):
bound = True
if isinstance(ltype, TypeType):
ltype = ltype.type_object
bound = False
if attr := ltype.members.get(name):
return attr
if meth := ltype.methods.get(name):
if bound:
return FunctionType(meth.parameters[1:], meth.return_type)
else:
return meth
raise IncompatibleTypesError(f"Type {ltype} has no attribute {name}")
def visit_List(self, node: ast.List) -> BaseType:
if not node.elts:
return PyList(TypeVariable())
elems = [self.visit(e) for e in node.elts]
if len(set(elems)) != 1:
raise NotImplementedError("List with different types not handled yet")
return PyList(elems[0])
def visit_Set(self, node: ast.Set) -> BaseType:
if not node.elts:
return PySet(TypeVariable())
elems = [self.visit(e) for e in node.elts]
if len(set(elems)) != 1:
raise NotImplementedError("Set with different types not handled yet")
return PySet(elems[0])
def visit_Dict(self, node: ast.Dict) -> BaseType:
if not node.keys:
return PyDict(TypeVariable())
keys = [self.visit(e) for e in node.keys]
values = [self.visit(e) for e in node.values]
if len(set(keys)) != 1 or len(set(values)) != 1:
raise NotImplementedError("Dict with different types not handled yet")
return PyDict(keys[0], values[0])
def visit_Subscript(self, node: ast.Subscript) -> BaseType:
raise NotImplementedError(node)
def visit_UnaryOp(self, node: ast.UnaryOp) -> BaseType:
raise NotImplementedError(node)
def visit_IfExp(self, node: ast.IfExp) -> BaseType:
self.visit(node.test)
then = self.visit(node.body)
else_ = self.visit(node.orelse)
if then != else_:
raise NotImplementedError("IfExp with different types not handled yet")
return then
def visit_Yield(self, node: ast.Yield) -> BaseType:
raise NotImplementedError(node)
from dataclasses import field, dataclass
from enum import Enum
from typing import Optional, Dict, List, Any
from transpiler.phases.typing.types import BaseType
class VarKind(Enum):
"""Kind of variable."""
LOCAL = 1
"""`xxx = ...`"""
GLOBAL = 2
"""`global xxx"""
NONLOCAL = 3
"""`nonlocal xxx`"""
SELF = 4
OUTER_DECL = 5
class VarType:
pass
class RuntimeValue:
pass
@dataclass
class VarDecl:
kind: VarKind
type: BaseType
val: Any = RuntimeValue()
class ScopeKind(Enum):
GLOBAL = 1
"""Global (module) scope"""
FUNCTION = 2
"""Function scope"""
FUNCTION_INNER = 3
"""Block (if, for, ...) scope inside a function"""
CLASS = 4
"""Class scope"""
@dataclass
class Scope:
parent: Optional["Scope"] = None
kind: ScopeKind = ScopeKind.GLOBAL
function: Optional["Scope"] = None
global_scope: Optional["Scope"] = None
vars: Dict[str, VarDecl] = field(default_factory=dict)
children: List["Scope"] = field(default_factory=list)
obj_type: Optional[BaseType] = None
@staticmethod
def make_global():
res = Scope()
res.global_scope = res
return res
def child(self, kind: ScopeKind):
res = Scope(self, kind, self.function, self.global_scope)
self.children.append(res)
return res
def declare_local(self, name: str):
"""Declares a local variable"""
def get(self, name: str) -> Optional[VarDecl]:
"""
Gets the variable declaration of a variable in the current scope or any parent scope.
"""
if (res := self.vars.get(name)) and res.kind == VarKind.LOCAL:
return res
if self.parent is not None:
return self.parent.get(name)
return None
import ast
import dataclasses
from dataclasses import dataclass, field
from typing import Optional, List, Dict
from transpiler.phases.typing.annotations import TypeAnnotationVisitor
from transpiler.phases.typing.common import PRELUDE
from transpiler.phases.typing.expr import ScoperExprVisitor
from transpiler.phases.typing.scope import Scope, VarDecl, VarKind, ScopeKind
from transpiler.phases.typing.types import BaseType, TypeOperator, FunctionType, TY_VARARG, TypeType, TypeVariable
from transpiler.phases.utils import NodeVisitorSeq
@dataclass
class StdlibVisitor(NodeVisitorSeq):
scope: Scope = field(default_factory=lambda: PRELUDE)
cur_class: Optional[BaseType] = None
typevars: Dict[str, BaseType] = field(default_factory=dict)
def visit_Module(self, node: ast.Module):
for stmt in node.body:
self.visit(stmt)
def visit_Assign(self, node: ast.Assign):
self.scope.vars[node.targets[0].id] = VarDecl(VarKind.LOCAL, self.visit(node.value))
def visit_AnnAssign(self, node: ast.AnnAssign):
self.scope.vars[node.target.id] = VarDecl(VarKind.LOCAL, self.anno().visit(node.annotation))
def visit_ImportFrom(self, node: ast.ImportFrom):
pass
def visit_Import(self, node: ast.Import):
pass
def visit_ClassDef(self, node: ast.ClassDef):
typevars = []
for b in node.bases:
if isinstance(b, ast.Subscript) and isinstance(b.value, ast.Name) and b.value.id == "Generic":
if isinstance(b.slice, ast.Index):
typevars = [b.slice.value.id]
elif isinstance(b.slice, ast.Tuple):
typevars = [n.id for n in b.slice.value.elts]
if existing := self.scope.get(node.name):
ty = existing.type
else:
ty = TypeOperator([], node.name)
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ty)
cl_scope = self.scope.child(ScopeKind.CLASS)
visitor = StdlibVisitor(cl_scope, ty)
for var in typevars:
visitor.typevars[var] = TypeType(TypeVariable(var))
for stmt in node.body:
visitor.visit(stmt)
def visit_FunctionDef(self, node: ast.FunctionDef):
arg_visitor = TypeAnnotationVisitor(self.scope.child(ScopeKind.FUNCTION), self.cur_class)
arg_types = [arg_visitor.visit(arg.annotation or arg.arg) for arg in node.args.args]
ret_type = arg_visitor.visit(node.returns)
ty = FunctionType(arg_types, ret_type)
if node.args.vararg:
ty.variadic = True
#arg_types.append(TY_VARARG)
if self.cur_class:
if isinstance(self.cur_class, TypeType):
# ty_inst = FunctionType(arg_types[1:], ret_type)
# self.cur_class.args[0].add_inst_member(node.name, ty_inst)
self.cur_class.type_object.methods[node.name] = ty.gen_sub(self.cur_class.type_object, self.typevars)
else:
self.cur_class.gen_methods[node.name] = lambda t: ty.gen_sub(t, self.typevars)
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ty)
def visit_Assert(self, node: ast.Assert):
print("Type of", ast.unparse(node.test), ":=", ScoperExprVisitor().visit(node.test))
def visit_Call(self, node: ast.Call) -> BaseType:
ty_op = self.visit(node.func)
if isinstance(ty_op, type):
return ty_op(*[ast.literal_eval(arg) for arg in node.args])
raise NotImplementedError
def anno(self) -> "TypeAnnotationVisitor":
return TypeAnnotationVisitor(self.scope, self.cur_class)
def visit_str(self, node: str) -> BaseType:
if existing := self.scope.get(node):
return existing.type
raise NameError(node)
def visit_Name(self, node: ast.Name) -> BaseType:
return self.visit_str(node.id)
\ No newline at end of file
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Dict, Optional, List, ClassVar, Callable
class IncompatibleTypesError(Exception):
pass
@dataclass
class BaseType(ABC):
members: Dict[str, "BaseType"] = field(default_factory=dict, init=False)
methods: Dict[str, "FunctionType"] = field(default_factory=dict, init=False)
def resolve(self) -> "BaseType":
return self
@abstractmethod
def unify_internal(self, other: "BaseType"):
pass
def unify(self, other: "BaseType"):
a, b = self.resolve(), other.resolve()
if isinstance(b, TypeVariable):
a, b = b, a
a.unify_internal(b)
def contains(self, other: "BaseType") -> bool:
needle, haystack = other.resolve(), self.resolve()
return (needle is haystack) or haystack.contains_internal(needle)
@abstractmethod
def contains_internal(self, other: "BaseType") -> bool:
pass
def gen_sub(self, this: "BaseType", typevars) -> "Self":
return self
def __repr__(self):
return str(self)
def to_list(self) -> List["BaseType"]:
return [self]
cur_var = 0
@dataclass
class TypeVariable(BaseType):
name: str = field(default_factory=lambda: chr(ord('a') + cur_var))
resolved: Optional[BaseType] = None
def __str__(self):
if self.resolved is None:
return self.name
return str(self.resolved)
def resolve(self) -> BaseType:
if self.resolved is None:
return self
return self.resolved.resolve()
def unify_internal(self, other: BaseType):
if self is not other:
if other.contains(self):
raise ValueError(f"Recursive type: {self} and {other}")
self.resolved = other
def contains_internal(self, other: BaseType) -> bool:
return self is other
def gen_sub(self, this: "BaseType", typevars) -> "Self":
if match := typevars.get(self.name):
return match
return self
GenMethodFactory = Callable[["BaseType"], "FunctionType"]
@dataclass
class TypeOperator(BaseType, ABC):
args: List[BaseType]
name: str = None
variadic: bool = False
gen_methods: ClassVar[Dict[str, GenMethodFactory]] = {}
def __post_init__(self):
if self.name is None:
self.name = self.__class__.__name__
for name, factory in self.gen_methods.items():
self.methods[name] = factory(self)
def unify_internal(self, other: BaseType):
if not isinstance(other, TypeOperator):
raise IncompatibleTypesError()
if len(self.args) != len(other.args) and not (self.variadic or other.variadic):
raise IncompatibleTypesError(f"Cannot unify {self} and {other} with different number of arguments")
for a, b in zip(self.args, other.args):
a.unify(b)
def contains_internal(self, other: "BaseType") -> bool:
return any(arg.contains(other) for arg in self.args)
def __str__(self):
return self.name + (f"<{', '.join(map(str, self.args))}>" if self.args else "")
def __hash__(self):
return hash((self.name, tuple(self.args)))
def gen_sub(self, this: BaseType, typevars) -> "Self":
res = object.__new__(self.__class__)
if isinstance(this, TypeOperator):
vardict = dict(zip(typevars.keys(), this.args))
else:
vardict = {}
res.args = [arg.resolve().gen_sub(this, vardict) for arg in self.args]
res.name = self.name
res.variadic = self.variadic
return res
def to_list(self) -> List["BaseType"]:
return [self, *self.args]
class FunctionType(TypeOperator):
def __init__(self, args: List[BaseType], ret: BaseType):
super().__init__([ret, *args])
@property
def parameters(self):
return self.args[1:]
@property
def return_type(self):
return self.args[0]
def __str__(self):
ret, *args = map(str, self.args)
if self.variadic:
args.append(f"*args")
if args:
args = f"({', '.join(args)})"
else:
args = "()"
return f"{args} -> {ret}"
class CppType(TypeOperator):
def __init__(self, name: str):
super().__init__([name], name)
def __str__(self):
return self.name
class Union(TypeOperator):
def __init__(self, left: BaseType, right: BaseType):
super().__init__([left, right], "Union")
class TypeType(TypeOperator):
def __init__(self, arg: BaseType):
super().__init__([arg], "Type")
@property
def type_object(self) -> BaseType:
return self.args[0]
TY_TYPE = TypeOperator([], "type")
TY_INT = TypeOperator([], "int")
TY_STR = TypeOperator([], "str")
TY_BOOL = TypeOperator([], "bool")
TY_COMPLEX = TypeOperator([], "complex")
TY_NONE = TypeOperator([], "NoneType")
TY_MODULE = TypeOperator([], "module")
TY_VARARG = TypeOperator([], "vararg")
TY_SELF = TypeOperator([], "Self")
TY_SELF.gen_sub = lambda this, typevars: this
class PyList(TypeOperator):
def __init__(self, arg: BaseType):
super().__init__([arg], "list")
@property
def element_type(self):
return self.args[0]
class PySet(TypeOperator):
def __init__(self, arg: BaseType):
super().__init__([arg], "set")
@property
def element_type(self):
return self.args[0]
class PyDict(TypeOperator):
def __init__(self, key: BaseType, value: BaseType):
super().__init__([key, value], "dict")
@property
def key_type(self):
return self.args[0]
@property
def value_type(self):
return self.args[1]
class TupleType(TypeOperator):
def __init__(self, args: List[BaseType]):
super().__init__(args, "tuple")
class ForkResult(TypeOperator):
def __init__(self, args: BaseType):
super().__init__([args], "ForkResult")
@property
def return_type(self):
return self.args[0]
from transpiler.utils import UnsupportedNodeError
class NodeVisitorSeq:
def visit(self, node):
"""Visit a node."""
if type(node) == list:
for n in node:
self.visit(n)
else:
for parent in node.__class__.__mro__:
if visitor := getattr(self, 'visit_' + parent.__name__, None):
return visitor(node)
else:
self.missing_impl(node)
def missing_impl(self, node):
raise UnsupportedNodeError(node)
# coding: utf-8
import ast
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional, Dict, Tuple
class VarKind(Enum):
"""Kind of variable."""
LOCAL = 1
GLOBAL = 2
NONLOCAL = 3
SELF = 4
@dataclass
class VarDecl:
kind: VarKind
val: Optional[Tuple[str, ast.AST]]
future: bool = False
@dataclass
class Scope:
parent: Optional["Scope"] = None
is_function: bool = False
vars: Dict[str, VarDecl] = field(default_factory=dict)
def is_global(self) -> bool:
"""
Determines whether this scope is the global scope. The global scope is the only scope to have no parent.
"""
return self.parent is None
def exists(self, name: str) -> bool:
"""
Determines whether a variable exists in the current scope or any parent scope.
"""
return name in self.vars or (self.parent is not None and self.parent.exists(name))
def get(self, name: str) -> Optional[VarDecl]:
"""
Gets the variable declaration of a variable in the current scope or any parent scope.
"""
if res := self.vars.get(name):
return res
if self.parent is not None:
return self.parent.get(name)
return None
def exists_local(self, name: str) -> bool:
"""
Determines whether a variable exists in the current function or global scope.
The check does not cross function boundaries; i.e. global variables are not taken into account from inside
functions.
"""
return name in self.vars or (
not self.is_function and self.parent is not None and self.parent.exists_local(name))
def child(self) -> "Scope":
"""
Creates a child scope with a new variable dictionary.
This is used for first-level elements of a function.
"""
return Scope(self, False, {})
def child_share(self) -> "Scope":
"""
Creates a child scope sharing the variable dictionary with the parent scope.
This is used for Python blocks, which share the variable scope with their parent block.
"""
return Scope(self, False, self.vars)
def function(self, **kwargs) -> "Scope":
"""
Creates a function scope.
"""
return Scope(self, True, **kwargs)
def is_root(self) -> Optional[Dict[str, VarDecl]]:
"""
Determines whether this scope is a root scope.
A root scope is either the global scope, or the first inner scope of a function.
Variable declarations in the generated code only ever appear in root scopes.
:return: `None` if this scope is not a root scope, otherwise the variable dictionary of the root scope.
"""
if self.parent is None:
return self.vars
if self.parent.is_function:
return self.parent.vars
return None
def declare(self, name: str, val: Optional[Tuple[str, ast.AST]] = None, future: bool = False) -> Optional[str]:
if self.exists_local(name):
# If the variable already exists in the current function or global scope, we don't need to declare it again.
# This is simply an assignment.
return None
vdict, prefix = self.vars, ""
if (root_vars := self.is_root()) is not None:
vdict, prefix = root_vars, "auto " # Root scope declarations can use `auto`.
vdict[name] = VarDecl(VarKind.LOCAL, val, future)
return prefix
# coding: utf-8
import ast
from dataclasses import dataclass
from itertools import zip_longest
from typing import Union
def compare_ast(node1: Union[ast.expr, list[ast.expr]], node2: Union[ast.expr, list[ast.expr]]) -> bool:
if type(node1) is not type(node2):
return False
if isinstance(node1, ast.AST):
for k, v in vars(node1).items():
if k in {"lineno", "end_lineno", "col_offset", "end_col_offset", "ctx", "type"}:
continue
if not compare_ast(v, getattr(node2, k)):
return False
return True
elif isinstance(node1, list) and isinstance(node2, list):
return all(compare_ast(n1, n2) for n1, n2 in zip_longest(node1, node2))
else:
return node1 == node2
@dataclass
class UnsupportedNodeError(Exception):
node: ast.AST
def __str__(self) -> str:
return f"Unsupported node: {self.node.__class__.__mro__} {ast.dump(self.node)}"
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment