Commit 29793d0f authored by Tom Niget's avatar Tom Niget

Generic functions works

parent ed9984de
def f[T](x: T):
return x
if __name__ == "__main__":
#a = 5
print(f("abc"))
print(f(6))
import ast
from typing import Iterable
from transpiler.phases.emit_cpp.function import emit_function
from transpiler.phases.typing.types import ConcreteType
......@@ -15,6 +16,7 @@ def emit_class(name: str, node: ConcreteType) -> Iterable[str]:
yield f"struct Obj : referencemodel::instance<{name}__oo<>, Obj> {{"
# inner = ClassInnerVisitor4(node.inner_scope)
# for stmt in node.body:
# yield from inner.visit(stmt)
......@@ -27,6 +29,11 @@ def emit_class(name: str, node: ConcreteType) -> Iterable[str]:
yield "};"
for name, mdef in node.fields.items():
if isinstance(mdef.val, ast.FunctionDef):
yield from emit_function(name, mdef.type.deref(), "method")
yield "template <typename... T>"
yield "auto operator() (T&&... args) const {"
yield "return referencemodel::rc(Obj{std::forward<T>(args)...});"
......
......@@ -6,16 +6,26 @@ from transpiler.phases.emit_cpp.expr import ExpressionVisitor
from transpiler.phases.typing.common import IsDeclare
from transpiler.phases.typing.scope import Scope
from transpiler.phases.emit_cpp.visitors import NodeVisitor, flatmap, CoroutineMode, join
from transpiler.phases.typing.types import CallableInstanceType, BaseType
from transpiler.phases.typing.types import CallableInstanceType, BaseType, TypeVariable
def emit_function(name: str, func: CallableInstanceType) -> Iterable[str]:
yield f"struct : referencemodel::function {{"
def emit_function(name: str, func: CallableInstanceType, base="function", gen_p=None) -> Iterable[str]:
yield f"struct : referencemodel::{base} {{"
def emit_body(name: str, mode: CoroutineMode, rty):
#real_params = [p for p in func.generic_parent.parameters if not p.name.startswith("AutoVar$")]
real_params = func.generic_parent.parameters
if real_params:
yield "template<"
yield from join(",", (f"typename {p.name} = void" for p in real_params))
yield ">"
yield "auto"
yield name
yield "("
def emit_arg(arg, ty):
if isinstance(ty, TypeVariable) and ty.emit_as_is:
yield ty.var_name
else:
raise NotImplementedError("can this happen?")
yield "auto"
yield arg.arg
......
......@@ -4,7 +4,7 @@ from typing import Iterable
from transpiler.phases.emit_cpp.class_ import emit_class
from transpiler.phases.emit_cpp.function import emit_function
from transpiler.phases.typing.modules import ModuleType
from transpiler.phases.typing.types import CallableInstanceType, ClassTypeType, TypeVariable
from transpiler.phases.typing.types import CallableInstanceType, ClassTypeType, TypeVariable, BaseType
def emit_module(mod: ModuleType) -> Iterable[str]:
......@@ -29,13 +29,15 @@ def emit_module(mod: ModuleType) -> Iterable[str]:
for name, field in mod.fields.items():
if not field.in_class_def:
continue
ty = field.type.deref()
gen_p = [TypeVariable(p.name, emit_as_is=True) for p in field.type.parameters]
ty = field.type.instantiate(gen_p)
from transpiler.phases.typing.expr import ScoperExprVisitor
x = 5
match ty:
case CallableInstanceType():
ScoperExprVisitor(ty.block_data.scope).visit_function_call(ty, [TypeVariable() for _ in ty.parameters])
yield from emit_function(name, ty)
parameters_ = [TypeVariable() for _ in ty.parameters]
ScoperExprVisitor(ty.block_data.scope).visit_function_call(ty, parameters_)
yield from emit_function(name, ty, gen_p=gen_p)
case ClassTypeType(inner_type):
yield from emit_class(name, inner_type)
case _:
......
......@@ -71,7 +71,10 @@ class NodeVisitor(UniversalVisitor):
yield "typon::TyNone"
case types.TY_STR:
yield 'decltype(""_ps)'
case types.TypeVariable(name):
case types.TypeVariable(name, emit_as_is=em):
if em:
yield name
else:
yield f"$VAR__{name}"
#raise UnresolvedTypeVariableError(node)
......
import ast
import copy
import dataclasses
from abc import ABCMeta
from dataclasses import dataclass, field
......@@ -145,7 +146,7 @@ class StdlibVisitor(NodeVisitorSeq):
def visit_nongeneric(scope: Scope, output: ResolvedConcreteType):
cl_scope = scope.child(ScopeKind.CLASS)
cl_scope.declare_local("Self", output.type_type())
output.block_data = BlockData(node, scope)
output.block_data = BlockData(copy.deepcopy(node), scope)
visitor = StdlibVisitor(self.python_path, cl_scope, output, self.is_native)
bases = [self.anno().visit(base) for base in node.bases]
match bases:
......@@ -169,8 +170,10 @@ class StdlibVisitor(NodeVisitorSeq):
scope.function = scope
scope.obj_type = output
arg_visitor = TypeAnnotationVisitor(scope)
output.block_data = BlockData(node, scope)
output.block_data = BlockData(copy.deepcopy(node), scope)
output.parameters = [arg_visitor.visit(arg.annotation) for arg in node.args.args]
for arg, ty in zip(node.args.args, output.parameters):
scope.declare_local(arg.arg, ty)
output.return_type = arg_visitor.visit(node.returns)
output.optional_at = len(node.args.args) - len(node.args.defaults)
output.is_variadic = args.vararg is not None
......@@ -199,7 +202,7 @@ class StdlibVisitor(NodeVisitorSeq):
if i == 0 and self.cur_class is not None:
arg_name = "Self"
else:
arg_name = f"AutoVar${hash(arg.arg)}"
arg_name = f"AutoVar${abs(hash(arg.arg))}"
node.type_params.append(ast.TypeVar(arg_name, None)) # todo: bounds
arg.annotation = ast.Name(arg_name, ast.Load())
else:
......@@ -210,7 +213,7 @@ class StdlibVisitor(NodeVisitorSeq):
# annotation is type variable so we keep it
pass
else:
arg_name = f"AutoBoundedVar${hash(arg.arg)}"
arg_name = f"AutoBoundedVar${abs(hash(arg.arg))}"
node.type_params.append(ast.TypeVar(arg_name, arg.annotation))
arg.annotation = ast.Name(arg_name, ast.Load())
......
......@@ -118,6 +118,7 @@ class ConcreteType(BaseType):
class TypeVariable(ConcreteType):
var_name: str = field(default_factory=lambda: next_var_id())
resolved: Optional[ConcreteType] = None
emit_as_is: bool = False
def resolve(self) -> ConcreteType:
if self.resolved is None:
......@@ -557,11 +558,13 @@ class CallableInstanceType(GenericInstanceType, MethodType):
def remove_self(self, self_type):
assert self.parameters[0].try_assign(self_type)
return dataclasses.replace(
res = dataclasses.replace(
self,
parameters=self.parameters[1:],
optional_at=self.optional_at - 1,
optional_at=self.optional_at - 1
)
res.block_data = self.block_data
return res
def __str__(self):
return f"({", ".join(map(str, self.parameters + (["*args"] if self.is_variadic else [])))}) -> {self.return_type}"
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment