Commit e84a9530 authored by Tom Niget's avatar Tom Niget

User generics attempt #5: monomorphization works

parent 101d8867
...@@ -5,9 +5,14 @@ from dataclasses import dataclass ...@@ -5,9 +5,14 @@ from dataclasses import dataclass
T = TypeVar("T") T = TypeVar("T")
@dataclass @dataclass
class Thing(): class Thing(Generic[T]):
x: int x: T
if __name__ == "__main__": if __name__ == "__main__":
a = Thing(1) a = Thing[int](1)
\ No newline at end of file b = Thing[str]("abc")
print(a)
print(b)
...@@ -3,8 +3,9 @@ import ast ...@@ -3,8 +3,9 @@ import ast
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Iterable, Optional from typing import Iterable, Optional
from transpiler.phases.typing.common import is_builtin
from transpiler.phases.typing.scope import Scope from transpiler.phases.typing.scope import Scope
from transpiler.phases.typing.types import BaseType, TY_INT, TY_BOOL, TypeVariable, Promise from transpiler.phases.typing.types import BaseType, TY_INT, TY_BOOL, TypeVariable, Promise, TypeType
from transpiler.utils import compare_ast from transpiler.utils import compare_ast
from transpiler.phases.emit_cpp import NodeVisitor, CoroutineMode, flatmap, FunctionEmissionKind from transpiler.phases.emit_cpp import NodeVisitor, CoroutineMode, flatmap, FunctionEmissionKind
from transpiler.phases.emit_cpp.expr import ExpressionVisitor from transpiler.phases.emit_cpp.expr import ExpressionVisitor
...@@ -166,6 +167,9 @@ class BlockVisitor(NodeVisitor): ...@@ -166,6 +167,9 @@ class BlockVisitor(NodeVisitor):
def visit_Assign(self, node: ast.Assign) -> Iterable[str]: def visit_Assign(self, node: ast.Assign) -> Iterable[str]:
if len(node.targets) != 1: if len(node.targets) != 1:
raise NotImplementedError(node) raise NotImplementedError(node)
if isinstance(node.targets[0].type, TypeType) and isinstance(node.targets[0].type.type_object, TypeVariable):
yield from ()
return
#if node.value.type #if node.value.type
yield from self.visit_lvalue(node.targets[0], node.is_declare) yield from self.visit_lvalue(node.targets[0], node.is_declare)
yield " = " yield " = "
......
...@@ -9,6 +9,11 @@ from transpiler.phases.emit_cpp import NodeVisitor, FunctionEmissionKind ...@@ -9,6 +9,11 @@ from transpiler.phases.emit_cpp import NodeVisitor, FunctionEmissionKind
class ClassVisitor(NodeVisitor): class ClassVisitor(NodeVisitor):
def visit_ClassDef(self, node: ast.ClassDef) -> Iterable[str]: def visit_ClassDef(self, node: ast.ClassDef) -> Iterable[str]:
if gen_instances := getattr(node, "gen_instances", None):
for args, inst in gen_instances.items():
yield from self.visit_ClassDef(inst)
return
yield f"struct {node.name}_s;" yield f"struct {node.name}_s;"
yield f"extern {node.name}_s {node.name};" yield f"extern {node.name}_s {node.name};"
yield f"struct {node.name}_s {{" yield f"struct {node.name}_s {{"
......
...@@ -3,7 +3,8 @@ import ast ...@@ -3,7 +3,8 @@ import ast
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import List, Iterable from typing import List, Iterable
from transpiler.phases.typing.types import UserType, FunctionType, Promise, TypeType from transpiler.phases.typing.types import UserType, FunctionType, Promise, TypeType, GenericUserType, \
MonomorphizedUserType
from transpiler.phases.utils import make_lnd from transpiler.phases.utils import make_lnd
from transpiler.utils import compare_ast, linenodata from transpiler.utils import compare_ast, linenodata
from transpiler.phases.emit_cpp.consts import SYMBOLS, PRECEDENCE_LEVELS, DUNDER_SYMBOLS from transpiler.phases.emit_cpp.consts import SYMBOLS, PRECEDENCE_LEVELS, DUNDER_SYMBOLS
...@@ -277,6 +278,9 @@ class ExpressionVisitor(NodeVisitor): ...@@ -277,6 +278,9 @@ class ExpressionVisitor(NodeVisitor):
yield "{}" yield "{}"
def visit_Subscript(self, node: ast.Subscript) -> Iterable[str]: def visit_Subscript(self, node: ast.Subscript) -> Iterable[str]:
if isinstance(node.type.type_object, MonomorphizedUserType):
yield node.type.type_object.name
return
yield from self.prec("[]").visit(node.value) yield from self.prec("[]").visit(node.value)
yield "[" yield "["
yield from self.reset().visit(node.slice) yield from self.reset().visit(node.slice)
......
...@@ -128,6 +128,10 @@ class ModuleVisitorExt(NodeVisitor): ...@@ -128,6 +128,10 @@ class ModuleVisitorExt(NodeVisitor):
yield f'm.def("{node.name}", PROGRAMNS::{node.name});' yield f'm.def("{node.name}", PROGRAMNS::{node.name});'
def visit_ClassDef(self, node: ast.ClassDef) -> Iterable[str]: def visit_ClassDef(self, node: ast.ClassDef) -> Iterable[str]:
if gen_instances := getattr(node, "gen_instances", None):
for args, inst in gen_instances.items():
yield from self.visit_ClassDef(inst)
return
yield f"py::class_<PROGRAMNS::{node.name}_s::py_type>(m, \"{node.name}\")" yield f"py::class_<PROGRAMNS::{node.name}_s::py_type>(m, \"{node.name}\")"
if init := node.type.fields.get("__init__", None): if init := node.type.fields.get("__init__", None):
init = init.type.resolve().remove_self() init = init.type.resolve().remove_self()
......
...@@ -25,7 +25,7 @@ PRELUDE.vars.update({ ...@@ -25,7 +25,7 @@ PRELUDE.vars.update({
"complex": VarDecl(VarKind.LOCAL, TypeType(TY_COMPLEX)), "complex": VarDecl(VarKind.LOCAL, TypeType(TY_COMPLEX)),
"None": VarDecl(VarKind.LOCAL, TypeType(TY_NONE)), "None": VarDecl(VarKind.LOCAL, TypeType(TY_NONE)),
"Callable": VarDecl(VarKind.LOCAL, TypeType(FunctionType)), "Callable": VarDecl(VarKind.LOCAL, TypeType(FunctionType)),
"TypeVar": VarDecl(VarKind.LOCAL, TypeType(TypeVariable)), #"TypeVar": VarDecl(VarKind.LOCAL, TypeType(TypeVariable)),
"CppType": VarDecl(VarKind.LOCAL, TypeType(CppType)), "CppType": VarDecl(VarKind.LOCAL, TypeType(CppType)),
"list": VarDecl(VarKind.LOCAL, TypeType(PyList)), "list": VarDecl(VarKind.LOCAL, TypeType(PyList)),
"dict": VarDecl(VarKind.LOCAL, TypeType(PyDict)), "dict": VarDecl(VarKind.LOCAL, TypeType(PyDict)),
......
import ast import ast
import copy
import dataclasses import dataclasses
import importlib import importlib
from dataclasses import dataclass from dataclasses import dataclass
...@@ -12,7 +13,7 @@ from transpiler.phases.typing.class_ import ScoperClassVisitor ...@@ -12,7 +13,7 @@ from transpiler.phases.typing.class_ import ScoperClassVisitor
from transpiler.phases.typing.scope import VarDecl, VarKind, ScopeKind, Scope from transpiler.phases.typing.scope import VarDecl, VarKind, ScopeKind, Scope
from transpiler.phases.typing.types import BaseType, TypeVariable, FunctionType, \ from transpiler.phases.typing.types import BaseType, TypeVariable, FunctionType, \
Promise, TY_NONE, PromiseKind, TupleType, UserType, TypeType, ModuleType, BuiltinFeature, TY_INT, MemberDef, \ Promise, TY_NONE, PromiseKind, TupleType, UserType, TypeType, ModuleType, BuiltinFeature, TY_INT, MemberDef, \
RuntimeValue RuntimeValue, GenericUserType, MonomorphizedUserType
from transpiler.phases.utils import PlainBlock, AnnotationName from transpiler.phases.utils import PlainBlock, AnnotationName
...@@ -150,44 +151,23 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -150,44 +151,23 @@ class ScoperBlockVisitor(ScoperVisitor):
ftype.return_type = Promise(ftype.return_type, PromiseKind.TASK) ftype.return_type = Promise(ftype.return_type, PromiseKind.TASK)
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ftype) self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ftype)
def visit_ClassDef(self, node: ast.ClassDef): def process_class_ast(self, ctype: BaseType, node: ast.ClassDef, bases_after: list[ast.expr]):
class NewUserType(UserType):
def __init__(self):
super().__init__(node.name)
#ctype = UserType(node.name)
ctype = NewUserType
cttype = TypeType(ctype)
typevars = []
bases_after = []
for base in node.bases:
if isinstance(base, ast.Subscript):
if isinstance(base.slice, ast.Name):
sliceval = [base.slice.id]
elif isinstance(base.slice, ast.Tuple):
sliceval = [n.id for n in base.slice.elts]
if is_builtin(self.expr().visit(base.slice), "Generic"):
typevars = sliceval
else:
bases_after.append(base)
if not typevars:
cttype.type_object = cttype.type_object()
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, cttype)
scope = self.scope.child(ScopeKind.CLASS) scope = self.scope.child(ScopeKind.CLASS)
scope.obj_type = cttype.type_object scope.obj_type = ctype
scope.class_ = scope scope.class_ = scope
node.inner_scope = scope node.inner_scope = scope
node.type = cttype.type_object node.type = ctype
visitor = ScoperClassVisitor(scope, cur_class=cttype) visitor = ScoperClassVisitor(scope, cur_class=TypeType(ctype))
visitor.visit_block(node.body) visitor.visit_block(node.body)
for base in bases_after: for base in bases_after:
base = self.expr().visit(base) base = self.expr().visit(base)
if is_builtin(base, "Enum"): if is_builtin(base, "Enum"):
cttype.type_object.parents.append(TY_INT) ctype.parents.append(TY_INT)
for k, m in cttype.type_object.fields.items(): for k, m in ctype.fields.items():
m.type = cttype.type_object m.type = ctype
m.val = ast.literal_eval(m.val) m.val = ast.literal_eval(m.val)
assert type(m.val) == int assert type(m.val) == int
cttype.type_object.fields["value"] = MemberDef(TY_INT) ctype.fields["value"] = MemberDef(TY_INT)
lnd = linenodata(node) lnd = linenodata(node)
init_method = ast.FunctionDef( init_method = ast.FunctionDef(
name="__init__", name="__init__",
...@@ -214,7 +194,7 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -214,7 +194,7 @@ class ScoperBlockVisitor(ScoperVisitor):
_, rtype = visitor.visit_FunctionDef(init_method) _, rtype = visitor.visit_FunctionDef(init_method)
visitor.visit_function_definition(init_method, rtype) visitor.visit_function_definition(init_method, rtype)
node.body.append(init_method) node.body.append(init_method)
cttype.type_object.is_enum = True ctype.is_enum = True
else: else:
raise NotImplementedError(base) raise NotImplementedError(base)
for deco in node.decorator_list: for deco in node.decorator_list:
...@@ -226,7 +206,7 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -226,7 +206,7 @@ class ScoperBlockVisitor(ScoperVisitor):
init_method = ast.FunctionDef( init_method = ast.FunctionDef(
name="__init__", name="__init__",
args=ast.arguments( args=ast.arguments(
args=[ast.arg(arg="self"), * [ast.arg(arg=n) for n in cttype.type_object.get_members()]], args=[ast.arg(arg="self"), * [ast.arg(arg=n) for n in ctype.get_members()]],
defaults=[], defaults=[],
kw_defaults=[], kw_defaults=[],
kwarg=None, kwarg=None,
...@@ -238,7 +218,7 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -238,7 +218,7 @@ class ScoperBlockVisitor(ScoperVisitor):
targets=[ast.Attribute(value=ast.Name(id="self"), attr=n)], targets=[ast.Attribute(value=ast.Name(id="self"), attr=n)],
value=ast.Name(id=n), value=ast.Name(id=n),
**lnd **lnd
) for n in cttype.type_object.get_members() ) for n in ctype.get_members()
], ],
decorator_list=[], decorator_list=[],
returns=None, returns=None,
...@@ -250,6 +230,59 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -250,6 +230,59 @@ class ScoperBlockVisitor(ScoperVisitor):
node.body.append(init_method) node.body.append(init_method)
else: else:
raise NotImplementedError(deco) raise NotImplementedError(deco)
return ctype
def visit_ClassDef(self, node: ast.ClassDef):
copied = copy.deepcopy(node)
class NewUserType(UserType):
def __init__(self):
super().__init__(node.name)
#ctype = UserType(node.name)
typevars = []
bases_after = []
for base in node.bases:
if isinstance(base, ast.Subscript):
if isinstance(base.slice, ast.Name):
sliceval = [base.slice.id]
elif isinstance(base.slice, ast.Tuple):
sliceval = [n.id for n in base.slice.elts]
if is_builtin(self.expr().visit(base.value), "Generic"):
typevars = sliceval
else:
bases_after.append(base)
if typevars:
# generic
#ctype = GenericUserType(node.name, typevars, node)
var_scope = self.scope.child(ScopeKind.GLOBAL)
var_visitor = ScoperBlockVisitor(var_scope, self.root_decls)
node.gen_instances = {}
class OurGenericType(GenericUserType):
# def __init__(self, *args):
# super().__init__(node.name)
# for tv, arg in zip(typevars, args):
# var_scope.declare_local(tv, arg)
# var_visitor.process_class_ast(self, node, bases_after)
def __new__(cls, *args, **kwargs):
res = MonomorphizedUserType(node.name + "$$" + "__".join(map(str, args)) + "$$")
for tv, arg in zip(typevars, args):
var_scope.declare_local(tv, arg)
new_node = copy.deepcopy(copied)
new_node.name = res.name
var_visitor.process_class_ast(res, new_node, bases_after)
node.gen_instances[tuple(args)] = new_node
return res
ctype = OurGenericType
else:
# not generic
ctype = self.process_class_ast(UserType(node.name), node, bases_after)
cttype = TypeType(OurGenericType)
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, cttype)
def visit_If(self, node: ast.If): def visit_If(self, node: ast.If):
scope = self.scope.child(ScopeKind.FUNCTION_INNER) scope = self.scope.child(ScopeKind.FUNCTION_INNER)
......
...@@ -30,7 +30,7 @@ class ScoperVisitor(NodeVisitorSeq): ...@@ -30,7 +30,7 @@ class ScoperVisitor(NodeVisitorSeq):
return res return res
def annotate_arg(self, arg: ast.arg) -> BaseType: def annotate_arg(self, arg: ast.arg) -> BaseType:
if arg.annotation is None: if arg.annotation is None or isinstance(arg.annotation, AnnotationName):
res = TypeVariable() res = TypeVariable()
arg.annotation = AnnotationName(res) arg.annotation = AnnotationName(res)
return res return res
......
...@@ -5,11 +5,11 @@ from dataclasses import dataclass, field ...@@ -5,11 +5,11 @@ from dataclasses import dataclass, field
from typing import Optional, List, Dict from typing import Optional, List, Dict
from logging import debug from logging import debug
from transpiler.phases.typing.annotations import TypeAnnotationVisitor from transpiler.phases.typing.annotations import TypeAnnotationVisitor
from transpiler.phases.typing.common import PRELUDE from transpiler.phases.typing.common import PRELUDE, is_builtin
from transpiler.phases.typing.expr import ScoperExprVisitor from transpiler.phases.typing.expr import ScoperExprVisitor
from transpiler.phases.typing.scope import Scope, VarDecl, VarKind, ScopeKind from transpiler.phases.typing.scope import Scope, VarDecl, VarKind, ScopeKind
from transpiler.phases.typing.types import BaseType, TypeOperator, FunctionType, TY_VARARG, TypeType, TypeVariable, \ from transpiler.phases.typing.types import BaseType, TypeOperator, FunctionType, TY_VARARG, TypeType, TypeVariable, \
MemberDef MemberDef, BuiltinFeature
from transpiler.phases.utils import NodeVisitorSeq from transpiler.phases.utils import NodeVisitorSeq
...@@ -128,6 +128,8 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -128,6 +128,8 @@ class StdlibVisitor(NodeVisitorSeq):
def visit_Call(self, node: ast.Call) -> BaseType: def visit_Call(self, node: ast.Call) -> BaseType:
ty_op = self.visit(node.func) ty_op = self.visit(node.func)
if is_builtin(ty_op, "TypeVar"):
return TypeType(TypeVariable(*[ast.literal_eval(arg) for arg in node.args]))
if isinstance(ty_op, TypeType): if isinstance(ty_op, TypeType):
return TypeType(ty_op.type_object(*[ast.literal_eval(arg) for arg in node.args])) return TypeType(ty_op.type_object(*[ast.literal_eval(arg) for arg in node.args]))
raise NotImplementedError(ast.unparse(node)) raise NotImplementedError(ast.unparse(node))
...@@ -142,4 +144,6 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -142,4 +144,6 @@ class StdlibVisitor(NodeVisitorSeq):
raise UnknownNameError(node) raise UnknownNameError(node)
def visit_Name(self, node: ast.Name) -> BaseType: def visit_Name(self, node: ast.Name) -> BaseType:
if node.id == "TypeVar":
return BuiltinFeature("TypeVar")
return self.visit_str(node.id) return self.visit_str(node.id)
\ No newline at end of file
import ast
import dataclasses import dataclasses
import typing import typing
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
...@@ -321,6 +322,9 @@ class TypeOperator(BaseType, ABC): ...@@ -321,6 +322,9 @@ class TypeOperator(BaseType, ABC):
def __str__(self): def __str__(self):
return self.name + (f"<{', '.join(map(str, self.args))}>" if self.args else "") return self.name + (f"<{', '.join(map(str, self.args))}>" if self.args else "")
def __repr__(self):
return self.__str__()
def __hash__(self): def __hash__(self):
return hash((self.name, tuple(self.args))) return hash((self.name, tuple(self.args)))
...@@ -569,3 +573,8 @@ class UnionType(TypeOperator): ...@@ -569,3 +573,8 @@ class UnionType(TypeOperator):
return (set(self.args) - {TY_NONE}).pop() return (set(self.args) - {TY_NONE}).pop()
return False return False
class GenericUserType(UserType):
pass
class MonomorphizedUserType(UserType):
pass
\ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment