Commit e84a9530 authored by Tom Niget's avatar Tom Niget

User generics attempt #5: monomorphization works

parent 101d8867
......@@ -5,9 +5,14 @@ from dataclasses import dataclass
T = TypeVar("T")
@dataclass
class Thing():
x: int
class Thing(Generic[T]):
x: T
if __name__ == "__main__":
a = Thing(1)
\ No newline at end of file
a = Thing[int](1)
b = Thing[str]("abc")
print(a)
print(b)
......@@ -3,8 +3,9 @@ import ast
from dataclasses import dataclass, field
from typing import Iterable, Optional
from transpiler.phases.typing.common import is_builtin
from transpiler.phases.typing.scope import Scope
from transpiler.phases.typing.types import BaseType, TY_INT, TY_BOOL, TypeVariable, Promise
from transpiler.phases.typing.types import BaseType, TY_INT, TY_BOOL, TypeVariable, Promise, TypeType
from transpiler.utils import compare_ast
from transpiler.phases.emit_cpp import NodeVisitor, CoroutineMode, flatmap, FunctionEmissionKind
from transpiler.phases.emit_cpp.expr import ExpressionVisitor
......@@ -166,6 +167,9 @@ class BlockVisitor(NodeVisitor):
def visit_Assign(self, node: ast.Assign) -> Iterable[str]:
if len(node.targets) != 1:
raise NotImplementedError(node)
if isinstance(node.targets[0].type, TypeType) and isinstance(node.targets[0].type.type_object, TypeVariable):
yield from ()
return
#if node.value.type
yield from self.visit_lvalue(node.targets[0], node.is_declare)
yield " = "
......
......@@ -9,6 +9,11 @@ from transpiler.phases.emit_cpp import NodeVisitor, FunctionEmissionKind
class ClassVisitor(NodeVisitor):
def visit_ClassDef(self, node: ast.ClassDef) -> Iterable[str]:
if gen_instances := getattr(node, "gen_instances", None):
for args, inst in gen_instances.items():
yield from self.visit_ClassDef(inst)
return
yield f"struct {node.name}_s;"
yield f"extern {node.name}_s {node.name};"
yield f"struct {node.name}_s {{"
......
......@@ -3,7 +3,8 @@ import ast
from dataclasses import dataclass, field
from typing import List, Iterable
from transpiler.phases.typing.types import UserType, FunctionType, Promise, TypeType
from transpiler.phases.typing.types import UserType, FunctionType, Promise, TypeType, GenericUserType, \
MonomorphizedUserType
from transpiler.phases.utils import make_lnd
from transpiler.utils import compare_ast, linenodata
from transpiler.phases.emit_cpp.consts import SYMBOLS, PRECEDENCE_LEVELS, DUNDER_SYMBOLS
......@@ -277,6 +278,9 @@ class ExpressionVisitor(NodeVisitor):
yield "{}"
def visit_Subscript(self, node: ast.Subscript) -> Iterable[str]:
if isinstance(node.type.type_object, MonomorphizedUserType):
yield node.type.type_object.name
return
yield from self.prec("[]").visit(node.value)
yield "["
yield from self.reset().visit(node.slice)
......
......@@ -128,6 +128,10 @@ class ModuleVisitorExt(NodeVisitor):
yield f'm.def("{node.name}", PROGRAMNS::{node.name});'
def visit_ClassDef(self, node: ast.ClassDef) -> Iterable[str]:
if gen_instances := getattr(node, "gen_instances", None):
for args, inst in gen_instances.items():
yield from self.visit_ClassDef(inst)
return
yield f"py::class_<PROGRAMNS::{node.name}_s::py_type>(m, \"{node.name}\")"
if init := node.type.fields.get("__init__", None):
init = init.type.resolve().remove_self()
......
......@@ -25,7 +25,7 @@ PRELUDE.vars.update({
"complex": VarDecl(VarKind.LOCAL, TypeType(TY_COMPLEX)),
"None": VarDecl(VarKind.LOCAL, TypeType(TY_NONE)),
"Callable": VarDecl(VarKind.LOCAL, TypeType(FunctionType)),
"TypeVar": VarDecl(VarKind.LOCAL, TypeType(TypeVariable)),
#"TypeVar": VarDecl(VarKind.LOCAL, TypeType(TypeVariable)),
"CppType": VarDecl(VarKind.LOCAL, TypeType(CppType)),
"list": VarDecl(VarKind.LOCAL, TypeType(PyList)),
"dict": VarDecl(VarKind.LOCAL, TypeType(PyDict)),
......
import ast
import copy
import dataclasses
import importlib
from dataclasses import dataclass
......@@ -12,7 +13,7 @@ from transpiler.phases.typing.class_ import ScoperClassVisitor
from transpiler.phases.typing.scope import VarDecl, VarKind, ScopeKind, Scope
from transpiler.phases.typing.types import BaseType, TypeVariable, FunctionType, \
Promise, TY_NONE, PromiseKind, TupleType, UserType, TypeType, ModuleType, BuiltinFeature, TY_INT, MemberDef, \
RuntimeValue
RuntimeValue, GenericUserType, MonomorphizedUserType
from transpiler.phases.utils import PlainBlock, AnnotationName
......@@ -150,44 +151,23 @@ class ScoperBlockVisitor(ScoperVisitor):
ftype.return_type = Promise(ftype.return_type, PromiseKind.TASK)
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ftype)
def visit_ClassDef(self, node: ast.ClassDef):
class NewUserType(UserType):
def __init__(self):
super().__init__(node.name)
#ctype = UserType(node.name)
ctype = NewUserType
cttype = TypeType(ctype)
typevars = []
bases_after = []
for base in node.bases:
if isinstance(base, ast.Subscript):
if isinstance(base.slice, ast.Name):
sliceval = [base.slice.id]
elif isinstance(base.slice, ast.Tuple):
sliceval = [n.id for n in base.slice.elts]
if is_builtin(self.expr().visit(base.slice), "Generic"):
typevars = sliceval
else:
bases_after.append(base)
if not typevars:
cttype.type_object = cttype.type_object()
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, cttype)
def process_class_ast(self, ctype: BaseType, node: ast.ClassDef, bases_after: list[ast.expr]):
scope = self.scope.child(ScopeKind.CLASS)
scope.obj_type = cttype.type_object
scope.obj_type = ctype
scope.class_ = scope
node.inner_scope = scope
node.type = cttype.type_object
visitor = ScoperClassVisitor(scope, cur_class=cttype)
node.type = ctype
visitor = ScoperClassVisitor(scope, cur_class=TypeType(ctype))
visitor.visit_block(node.body)
for base in bases_after:
base = self.expr().visit(base)
if is_builtin(base, "Enum"):
cttype.type_object.parents.append(TY_INT)
for k, m in cttype.type_object.fields.items():
m.type = cttype.type_object
ctype.parents.append(TY_INT)
for k, m in ctype.fields.items():
m.type = ctype
m.val = ast.literal_eval(m.val)
assert type(m.val) == int
cttype.type_object.fields["value"] = MemberDef(TY_INT)
ctype.fields["value"] = MemberDef(TY_INT)
lnd = linenodata(node)
init_method = ast.FunctionDef(
name="__init__",
......@@ -214,7 +194,7 @@ class ScoperBlockVisitor(ScoperVisitor):
_, rtype = visitor.visit_FunctionDef(init_method)
visitor.visit_function_definition(init_method, rtype)
node.body.append(init_method)
cttype.type_object.is_enum = True
ctype.is_enum = True
else:
raise NotImplementedError(base)
for deco in node.decorator_list:
......@@ -226,7 +206,7 @@ class ScoperBlockVisitor(ScoperVisitor):
init_method = ast.FunctionDef(
name="__init__",
args=ast.arguments(
args=[ast.arg(arg="self"), * [ast.arg(arg=n) for n in cttype.type_object.get_members()]],
args=[ast.arg(arg="self"), * [ast.arg(arg=n) for n in ctype.get_members()]],
defaults=[],
kw_defaults=[],
kwarg=None,
......@@ -238,7 +218,7 @@ class ScoperBlockVisitor(ScoperVisitor):
targets=[ast.Attribute(value=ast.Name(id="self"), attr=n)],
value=ast.Name(id=n),
**lnd
) for n in cttype.type_object.get_members()
) for n in ctype.get_members()
],
decorator_list=[],
returns=None,
......@@ -250,6 +230,59 @@ class ScoperBlockVisitor(ScoperVisitor):
node.body.append(init_method)
else:
raise NotImplementedError(deco)
return ctype
def visit_ClassDef(self, node: ast.ClassDef):
copied = copy.deepcopy(node)
class NewUserType(UserType):
def __init__(self):
super().__init__(node.name)
#ctype = UserType(node.name)
typevars = []
bases_after = []
for base in node.bases:
if isinstance(base, ast.Subscript):
if isinstance(base.slice, ast.Name):
sliceval = [base.slice.id]
elif isinstance(base.slice, ast.Tuple):
sliceval = [n.id for n in base.slice.elts]
if is_builtin(self.expr().visit(base.value), "Generic"):
typevars = sliceval
else:
bases_after.append(base)
if typevars:
# generic
#ctype = GenericUserType(node.name, typevars, node)
var_scope = self.scope.child(ScopeKind.GLOBAL)
var_visitor = ScoperBlockVisitor(var_scope, self.root_decls)
node.gen_instances = {}
class OurGenericType(GenericUserType):
# def __init__(self, *args):
# super().__init__(node.name)
# for tv, arg in zip(typevars, args):
# var_scope.declare_local(tv, arg)
# var_visitor.process_class_ast(self, node, bases_after)
def __new__(cls, *args, **kwargs):
res = MonomorphizedUserType(node.name + "$$" + "__".join(map(str, args)) + "$$")
for tv, arg in zip(typevars, args):
var_scope.declare_local(tv, arg)
new_node = copy.deepcopy(copied)
new_node.name = res.name
var_visitor.process_class_ast(res, new_node, bases_after)
node.gen_instances[tuple(args)] = new_node
return res
ctype = OurGenericType
else:
# not generic
ctype = self.process_class_ast(UserType(node.name), node, bases_after)
cttype = TypeType(OurGenericType)
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, cttype)
def visit_If(self, node: ast.If):
scope = self.scope.child(ScopeKind.FUNCTION_INNER)
......
......@@ -30,7 +30,7 @@ class ScoperVisitor(NodeVisitorSeq):
return res
def annotate_arg(self, arg: ast.arg) -> BaseType:
if arg.annotation is None:
if arg.annotation is None or isinstance(arg.annotation, AnnotationName):
res = TypeVariable()
arg.annotation = AnnotationName(res)
return res
......
......@@ -5,11 +5,11 @@ from dataclasses import dataclass, field
from typing import Optional, List, Dict
from logging import debug
from transpiler.phases.typing.annotations import TypeAnnotationVisitor
from transpiler.phases.typing.common import PRELUDE
from transpiler.phases.typing.common import PRELUDE, is_builtin
from transpiler.phases.typing.expr import ScoperExprVisitor
from transpiler.phases.typing.scope import Scope, VarDecl, VarKind, ScopeKind
from transpiler.phases.typing.types import BaseType, TypeOperator, FunctionType, TY_VARARG, TypeType, TypeVariable, \
MemberDef
MemberDef, BuiltinFeature
from transpiler.phases.utils import NodeVisitorSeq
......@@ -128,6 +128,8 @@ class StdlibVisitor(NodeVisitorSeq):
def visit_Call(self, node: ast.Call) -> BaseType:
ty_op = self.visit(node.func)
if is_builtin(ty_op, "TypeVar"):
return TypeType(TypeVariable(*[ast.literal_eval(arg) for arg in node.args]))
if isinstance(ty_op, TypeType):
return TypeType(ty_op.type_object(*[ast.literal_eval(arg) for arg in node.args]))
raise NotImplementedError(ast.unparse(node))
......@@ -142,4 +144,6 @@ class StdlibVisitor(NodeVisitorSeq):
raise UnknownNameError(node)
def visit_Name(self, node: ast.Name) -> BaseType:
if node.id == "TypeVar":
return BuiltinFeature("TypeVar")
return self.visit_str(node.id)
\ No newline at end of file
import ast
import dataclasses
import typing
from abc import ABC, abstractmethod
......@@ -321,6 +322,9 @@ class TypeOperator(BaseType, ABC):
def __str__(self):
return self.name + (f"<{', '.join(map(str, self.args))}>" if self.args else "")
def __repr__(self):
return self.__str__()
def __hash__(self):
return hash((self.name, tuple(self.args)))
......@@ -569,3 +573,8 @@ class UnionType(TypeOperator):
return (set(self.args) - {TY_NONE}).pop()
return False
class GenericUserType(UserType):
pass
class MonomorphizedUserType(UserType):
pass
\ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment