Commit 50a42e46 authored by Tom Niget's avatar Tom Niget

Add custom exception hook, add support for Python module calling

parent f3dade02
...@@ -9,6 +9,36 @@ from transpiler.phases.if_main import IfMainVisitor ...@@ -9,6 +9,36 @@ from transpiler.phases.if_main import IfMainVisitor
from transpiler.phases.typing.block import ScoperBlockVisitor from transpiler.phases.typing.block import ScoperBlockVisitor
from transpiler.phases.typing.scope import Scope from transpiler.phases.typing.scope import Scope
import sys
from colorama import Fore
import colorama
colorama.init()
def exception_hook(exc_type, exc_value, tb):
while tb:
local_vars = tb.tb_frame.f_locals
if local_vars.get("TB_SKIP", None) and tb.tb_next:
tb = tb.tb_next
continue
filename = tb.tb_frame.f_code.co_filename
name = tb.tb_frame.f_code.co_name
line_no = tb.tb_lineno
print(f"{Fore.RED}File \"{filename}\", line {line_no}, in {name}", end="")
if info := local_vars.get("TB", None):
print(f", while {Fore.MAGENTA}{info}")
else:
print()
tb = tb.tb_next
# Exception type and value
print(f"{exc_type.__name__}, Message: {exc_value}")
sys.excepthook = exception_hook
def transpile(source): def transpile(source):
res = ast.parse(source, type_comments=True) res = ast.parse(source, type_comments=True)
......
...@@ -9,11 +9,15 @@ from transpiler.phases.emit_cpp.consts import MAPPINGS ...@@ -9,11 +9,15 @@ from transpiler.phases.emit_cpp.consts import MAPPINGS
from transpiler.phases.typing import TypeVariable from transpiler.phases.typing import TypeVariable
from transpiler.phases.typing.types import BaseType, TY_INT, TY_BOOL, TY_NONE, Promise, PromiseKind, TY_STR, UserType, \ from transpiler.phases.typing.types import BaseType, TY_INT, TY_BOOL, TY_NONE, Promise, PromiseKind, TY_STR, UserType, \
TypeType, TypeOperator, TY_FLOAT TypeType, TypeOperator, TY_FLOAT
from transpiler.utils import UnsupportedNodeError from transpiler.utils import UnsupportedNodeError, highlight
class UniversalVisitor: class UniversalVisitor:
def visit(self, node): def visit(self, node):
"""Visit a node.""" """Visit a node."""
TB = f"emitting C++ code for {highlight(node)}"
#TB_SKIP = True
if type(node) == list: if type(node) == list:
for n in node: for n in node:
yield from self.visit(n) yield from self.visit(n)
...@@ -83,6 +87,7 @@ class NodeVisitor(UniversalVisitor): ...@@ -83,6 +87,7 @@ class NodeVisitor(UniversalVisitor):
yield from self.visit(node.return_type) yield from self.visit(node.return_type)
yield ">" yield ">"
elif isinstance(node, TypeVariable): elif isinstance(node, TypeVariable):
#yield f"TYPEVAR_{node.name}";return
raise NotImplementedError(f"Not unified type variable {node}") raise NotImplementedError(f"Not unified type variable {node}")
elif isinstance(node, TypeOperator): elif isinstance(node, TypeOperator):
yield "Py" + node.name.title() yield "Py" + node.name.title()
......
...@@ -9,6 +9,8 @@ from transpiler.phases.emit_cpp.module import ModuleVisitor, ModuleVisitor2 ...@@ -9,6 +9,8 @@ from transpiler.phases.emit_cpp.module import ModuleVisitor, ModuleVisitor2
# noinspection PyPep8Naming # noinspection PyPep8Naming
class FileVisitor(BlockVisitor): class FileVisitor(BlockVisitor):
def visit_Module(self, node: ast.Module) -> Iterable[str]: def visit_Module(self, node: ast.Module) -> Iterable[str]:
TB = "emitting C++ code for Python module"
stmt: ast.AST stmt: ast.AST
yield "#include <python/builtins.hpp>" yield "#include <python/builtins.hpp>"
yield "#include <python/sys.hpp>" yield "#include <python/sys.hpp>"
......
...@@ -3,12 +3,14 @@ import ast ...@@ -3,12 +3,14 @@ import ast
from typing import Iterable from typing import Iterable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from transpiler.phases.typing import FunctionType
from transpiler.phases.typing.scope import Scope from transpiler.phases.typing.scope import Scope
from transpiler.phases.emit_cpp import CoroutineMode, FunctionEmissionKind, NodeVisitor from transpiler.phases.emit_cpp import CoroutineMode, FunctionEmissionKind, NodeVisitor
from transpiler.phases.emit_cpp.block import BlockVisitor from transpiler.phases.emit_cpp.block import BlockVisitor
from transpiler.phases.emit_cpp.class_ import ClassVisitor from transpiler.phases.emit_cpp.class_ import ClassVisitor
from transpiler.phases.emit_cpp.function import FunctionVisitor from transpiler.phases.emit_cpp.function import FunctionVisitor
from transpiler.utils import compare_ast from transpiler.utils import compare_ast, highlight
# noinspection PyPep8Naming # noinspection PyPep8Naming
...@@ -16,19 +18,64 @@ from transpiler.utils import compare_ast ...@@ -16,19 +18,64 @@ from transpiler.utils import compare_ast
class ModuleVisitor(BlockVisitor): class ModuleVisitor(BlockVisitor):
includes: list[str] = field(default_factory=list) includes: list[str] = field(default_factory=list)
def visit_Import(self, node: ast.Import) -> Iterable[str]: def visit_Import(self, node: ast.Import) -> Iterable[str]:
TB = f"emitting C++ code for {highlight(node)}"
for alias in node.names: for alias in node.names:
if alias.name in {"typon", "typing", "__future__"}: concrete = alias.asname or alias.name
if alias.module_obj.is_python:
yield f"namespace py_{concrete} {{"
yield f"struct {concrete}_t {{"
for name, obj in alias.module_obj.members.items():
if obj.python_func_used:
yield from self.emit_python_func(alias.name, name, name, obj)
yield "} all;"
yield f"auto& get_all() {{ return all; }}"
yield "}"
yield f'auto& {concrete} = py_{concrete}::get_all();'
elif alias.name in {"typon", "typing", "__future__"}:
yield "" yield ""
else: else:
yield from self.import_module(alias.name) yield from self.import_module(alias.name)
yield f'auto& {alias.asname or alias.name} = py_{alias.name}::get_all();' yield f'auto& {concrete} = py_{alias.name}::get_all();'
def import_module(self, name: str) -> Iterable[str]: def import_module(self, name: str) -> Iterable[str]:
self.includes.append(f'#include <python/{name}.hpp>') self.includes.append(f'#include <python/{name}.hpp>')
yield "" yield ""
def emit_python_func(self, mod: str, name: str, alias: str, fty: FunctionType) -> Iterable[str]:
TB = f"emitting C++ code for Python function {highlight(f'{mod}.{name}')}"
yield f"auto {alias}("
for i, argty in enumerate(fty.parameters):
if i != 0:
yield ", "
yield "lvalue_or_rvalue<"
yield from self.visit(argty)
yield f"> arg{i}"
yield ") {"
yield "py::scoped_interpreter guard{};"
yield f"return py::module_::import(\"{mod}\").attr(\"{name}\")("
for i, argty in enumerate(fty.parameters):
if i != 0:
yield ", "
yield f"*arg{i}"
yield ").cast<"
yield from self.visit(fty.return_type)
yield ">();"
yield "}"
def visit_ImportFrom(self, node: ast.ImportFrom) -> Iterable[str]: def visit_ImportFrom(self, node: ast.ImportFrom) -> Iterable[str]:
if node.module in {"typon", "typing", "__future__"}: if node.module_obj.is_python:
for alias in node.names:
fty = alias.item_obj
assert isinstance(fty, FunctionType)
yield from self.emit_python_func(node.module, alias.name, alias.asname or alias.name, fty)
yield "// python"
elif node.module in {"typon", "typing", "__future__"}:
yield "" yield ""
else: else:
yield from self.import_module(node.module) yield from self.import_module(node.module)
......
...@@ -50,7 +50,7 @@ def discover_module(path: Path, scope): ...@@ -50,7 +50,7 @@ def discover_module(path: Path, scope):
if child.is_dir(): if child.is_dir():
mod_scope = PRELUDE.child(ScopeKind.GLOBAL) mod_scope = PRELUDE.child(ScopeKind.GLOBAL)
discover_module(child, mod_scope) discover_module(child, mod_scope)
scope.vars[child.name] = make_mod_decl(child, mod_scope) scope.vars[child.name] = make_mod_decl(child.name, mod_scope)
elif child.name == "__init__.py": elif child.name == "__init__.py":
StdlibVisitor(scope).visit(ast.parse(child.read_text())) StdlibVisitor(scope).visit(ast.parse(child.read_text()))
print(f"Visited {child}") print(f"Visited {child}")
...@@ -59,12 +59,12 @@ def discover_module(path: Path, scope): ...@@ -59,12 +59,12 @@ def discover_module(path: Path, scope):
StdlibVisitor(mod_scope).visit(ast.parse(child.read_text())) StdlibVisitor(mod_scope).visit(ast.parse(child.read_text()))
if child.stem[-1] == "_": if child.stem[-1] == "_":
child = child.with_name(child.stem[:-1]) child = child.with_name(child.stem[:-1])
scope.vars[child.stem] = make_mod_decl(child, mod_scope) scope.vars[child.stem] = make_mod_decl(child.name, mod_scope)
print(f"Visited {child}") print(f"Visited {child}")
def make_mod_decl(child, mod_scope): def make_mod_decl(child, mod_scope):
return VarDecl(VarKind.MODULE, make_module(child.name, mod_scope), {k: v.type for k, v in mod_scope.vars.items()}) return VarDecl(VarKind.MODULE, make_module(child, mod_scope), {k: v.type for k, v in mod_scope.vars.items()})
discover_module(typon_std, PRELUDE) discover_module(typon_std, PRELUDE)
......
import ast import ast
import dataclasses import dataclasses
import importlib
from dataclasses import dataclass from dataclasses import dataclass
from transpiler.phases.typing import make_mod_decl
from transpiler.phases.typing.common import ScoperVisitor from transpiler.phases.typing.common import ScoperVisitor
from transpiler.phases.typing.expr import ScoperExprVisitor from transpiler.phases.typing.expr import ScoperExprVisitor
from transpiler.phases.typing.class_ import ScoperClassVisitor from transpiler.phases.typing.class_ import ScoperClassVisitor
from transpiler.phases.typing.scope import VarDecl, VarKind, ScopeKind from transpiler.phases.typing.scope import VarDecl, VarKind, ScopeKind, Scope
from transpiler.phases.typing.types import BaseType, TypeVariable, FunctionType, IncompatibleTypesError, \ from transpiler.phases.typing.types import BaseType, TypeVariable, FunctionType, IncompatibleTypesError, \
Promise, TY_NONE, PromiseKind, TupleType, UserType, TypeType, ModuleType Promise, TY_NONE, PromiseKind, TupleType, UserType, TypeType, ModuleType
from transpiler.phases.utils import PlainBlock from transpiler.phases.utils import PlainBlock
...@@ -21,26 +23,43 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -21,26 +23,43 @@ class ScoperBlockVisitor(ScoperVisitor):
def visit_Pass(self, node: ast.Pass): def visit_Pass(self, node: ast.Pass):
pass pass
def get_module(self, name: str) -> VarDecl:
mod = self.scope.get(name, VarKind.MODULE)
if mod is None:
# try lookup with importlib
py_mod = importlib.import_module(name)
mod_scope = Scope()
# copy all functions to mod_scope
for fname, obj in py_mod.__dict__.items():
if callable(obj):
fty = FunctionType([], TypeVariable())
fty.is_python_func = True
mod_scope.vars[fname] = VarDecl(VarKind.LOCAL, fty)
mod = make_mod_decl(name, mod_scope)
mod.type.is_python = True
self.scope.vars[name] = mod
if mod is None:
raise NameError(name)
assert isinstance(mod, VarDecl), mod
assert isinstance(mod.type, ModuleType), mod.type
return mod
def visit_Import(self, node: ast.Import): def visit_Import(self, node: ast.Import):
for alias in node.names: for alias in node.names:
mod = self.scope.get(alias.name, VarKind.MODULE) mod = self.get_module(alias.name)
if mod is None: alias.module_obj = mod.type
raise NameError(alias.name)
assert isinstance(mod, VarDecl), mod
self.scope.vars[alias.asname or alias.name] = dataclasses.replace(mod, kind=VarKind.LOCAL) self.scope.vars[alias.asname or alias.name] = dataclasses.replace(mod, kind=VarKind.LOCAL)
def visit_ImportFrom(self, node: ast.ImportFrom): def visit_ImportFrom(self, node: ast.ImportFrom):
if node.module in {"typing", "__future__"}: if node.module in {"typing", "__future__"}:
return return
module = self.scope.get(node.module, VarKind.MODULE) module = self.get_module(node.module)
if not module: node.module_obj = module.type
raise NameError(node.module)
if not isinstance(module.type, ModuleType):
raise IncompatibleTypesError(f"{node.module} is not a module")
for alias in node.names: for alias in node.names:
thing = module.val.get(alias.name) thing = module.val.get(alias.name)
if not thing: if not thing:
raise NameError(alias.name) raise NameError(alias.name)
alias.item_obj = thing
self.scope.vars[alias.asname or alias.name] = VarDecl(VarKind.LOCAL, thing) self.scope.vars[alias.asname or alias.name] = VarDecl(VarKind.LOCAL, thing)
def visit_Module(self, node: ast.Module): def visit_Module(self, node: ast.Module):
...@@ -66,8 +85,8 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -66,8 +85,8 @@ class ScoperBlockVisitor(ScoperVisitor):
raise IncompatibleTypesError(f"`{ast.unparse(node)}: {e}") from e raise IncompatibleTypesError(f"`{ast.unparse(node)}: {e}") from e
def visit_AnnAssign(self, node: ast.AnnAssign): def visit_AnnAssign(self, node: ast.AnnAssign):
if node.value is not None: # if node.value is not None:
raise NotImplementedError(node) # raise NotImplementedError(node)
if node.simple != 1: if node.simple != 1:
raise NotImplementedError(node) raise NotImplementedError(node)
if not isinstance(node.target, ast.Name): if not isinstance(node.target, ast.Name):
...@@ -77,6 +96,11 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -77,6 +96,11 @@ class ScoperBlockVisitor(ScoperVisitor):
node.is_declare = self.visit_assign_target(node.target, ty) node.is_declare = self.visit_assign_target(node.target, ty)
except IncompatibleTypesError as e: except IncompatibleTypesError as e:
raise IncompatibleTypesError(f"`{ast.unparse(node)}: {e}") raise IncompatibleTypesError(f"`{ast.unparse(node)}: {e}")
ty_val = self.get_type(node.value)
try:
ty.unify(ty_val)
except IncompatibleTypesError as e:
raise IncompatibleTypesError(f"`{ast.unparse(node)}: {e}")
def visit_assign_target(self, target, decl_val: BaseType) -> bool: def visit_assign_target(self, target, decl_val: BaseType) -> bool:
if isinstance(target, ast.Name): if isinstance(target, ast.Name):
......
...@@ -77,6 +77,8 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -77,6 +77,8 @@ class ScoperExprVisitor(ScoperVisitor):
raise NameError(f"Name {node.id} is not defined") raise NameError(f"Name {node.id} is not defined")
if isinstance(obj.type, TypeType) and isinstance(obj.type.type_object, TypeVariable): if isinstance(obj.type, TypeType) and isinstance(obj.type.type_object, TypeVariable):
raise NameError(f"Use of type variable") raise NameError(f"Use of type variable")
if getattr(obj, "is_python_func", False):
obj.python_func_used = True
return obj.type return obj.type
def visit_Compare(self, node: ast.Compare) -> BaseType: def visit_Compare(self, node: ast.Compare) -> BaseType:
...@@ -165,6 +167,8 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -165,6 +167,8 @@ class ScoperExprVisitor(ScoperVisitor):
raise NotImplementedError("I don't know how to handle this type") raise NotImplementedError("I don't know how to handle this type")
ltype = ltype(*(TypeVariable() for _ in args)) ltype = ltype(*(TypeVariable() for _ in args))
if attr := ltype.members.get(name): if attr := ltype.members.get(name):
if getattr(attr, "is_python_func", False):
attr.python_func_used = True
return attr return attr
if meth := ltype.methods.get(name): if meth := ltype.methods.get(name):
if bound: if bound:
......
...@@ -201,7 +201,11 @@ class TypeOperator(BaseType, ABC): ...@@ -201,7 +201,11 @@ class TypeOperator(BaseType, ABC):
if self.optional_at is not None and i >= self.optional_at: if self.optional_at is not None and i >= self.optional_at:
continue continue
else: else:
raise IncompatibleTypesError(f"Cannot unify {self} and {other}, not enough arguments") if getattr(other, "is_python_func", False):
other.args.append(a)
continue
else:
raise IncompatibleTypesError(f"Cannot unify {self} and {other}, not enough arguments")
if isinstance(a, BaseType) and isinstance(b, BaseType): if isinstance(a, BaseType) and isinstance(b, BaseType):
a.unify(b) a.unify(b)
...@@ -243,10 +247,13 @@ class TypeOperator(BaseType, ABC): ...@@ -243,10 +247,13 @@ class TypeOperator(BaseType, ABC):
@dataclass @dataclass
class ModuleType(TypeOperator): class ModuleType(TypeOperator):
pass is_python: bool = False
class FunctionType(TypeOperator): class FunctionType(TypeOperator):
is_python_func: bool = False
python_func_used: bool = False
def __init__(self, args: List[BaseType], ret: BaseType): def __init__(self, args: List[BaseType], ret: BaseType):
super().__init__([ret, *args]) super().__init__([ret, *args])
......
...@@ -4,6 +4,28 @@ from dataclasses import dataclass ...@@ -4,6 +4,28 @@ from dataclasses import dataclass
from itertools import zip_longest from itertools import zip_longest
from typing import Union from typing import Union
from colorama import Fore
def highlight(code):
"""
Syntax highlights code as Python using colorama
"""
from transpiler.phases.typing import BaseType
if isinstance(code, ast.AST):
return f"{Fore.WHITE}[{type(code).__name__}] " + highlight(ast.unparse(code))
elif isinstance(code, BaseType):
return f"{Fore.WHITE}[{type(code).__name__}] " + highlight(str(code))
from pygments import highlight as pyg_highlight
from pygments.lexers import PythonLexer
from pygments.formatters import TerminalFormatter
items = pyg_highlight(code, PythonLexer(), TerminalFormatter()).splitlines()
res = items[0]
if len(items) > 1:
res += Fore.WHITE + " [...]"
return Fore.RESET + res
def compare_ast(node1: Union[ast.expr, list[ast.expr]], node2: Union[ast.expr, list[ast.expr]]) -> bool: def compare_ast(node1: Union[ast.expr, list[ast.expr]], node2: Union[ast.expr, list[ast.expr]]) -> bool:
if type(node1) is not type(node2): if type(node1) is not type(node2):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment