Commit 6e8e1738 authored by Tom Niget's avatar Tom Niget

Handle recursive functions

parent 39b718cc
......@@ -386,9 +386,7 @@ using InterpGuard = py::scoped_interpreter;
#endif
template <typename T>
concept HasSync = requires(T t) {
{ t.sync() } -> std::same_as<T>;
};
concept HasSync = requires(T t) { typename T::has_sync; };
/*auto call_sync(auto f, auto... args) {
if constexpr (HasSync<decltype(f)>) {
......
......@@ -43,7 +43,9 @@ def emit_function(name: str, func: CallableInstanceType, base="function", gen_p=
yield ")>::type"
yield var
yield ";"
yield from BlockVisitor(func.block_data.scope, generator=mode).visit(func.block_data.node.body)
vis = BlockVisitor(func.block_data.scope, generator=mode)
for stmt in func.block_data.node.body:
yield from vis.visit(stmt)
if not getattr(func.block_data.scope, "has_return", False):
if mode == CoroutineMode.SYNC:
yield "return"
......@@ -53,10 +55,13 @@ def emit_function(name: str, func: CallableInstanceType, base="function", gen_p=
yield "}"
rty = func.return_type.generic_args[0]
has_sync = False
try:
rty_code = " ".join(NodeVisitor().visit_BaseType(func.return_type))
except:
yield from emit_body("sync", CoroutineMode.SYNC, None)
has_sync = True
yield "using has_sync = std::true_type;"
def task_type():
yield from NodeVisitor().visit_BaseType(func.return_type.generic_parent)
yield "<"
......@@ -70,6 +75,8 @@ def emit_function(name: str, func: CallableInstanceType, base="function", gen_p=
yield from emit_body("operator()", CoroutineMode.TASK, rty_code)
yield f"}} static constexpr {name} {{}};"
if has_sync:
yield f"static_assert(HasSync<decltype({name})>);"
yield f"static_assert(sizeof {name} == 1);"
......
......@@ -43,6 +43,7 @@ def emit_module(mod: ModuleType) -> Iterable[str]:
x = 5
match ty:
case CallableInstanceType():
ty.generic_parent.instance_cache = []
ScoperExprVisitor(ty.block_data.scope).visit_function_call(ty, [TypeVariable() for _ in ty.parameters])
yield from emit_function(name, ty, gen_p=gen_p)
case GenericInstanceType() if isinstance(ty.generic_parent, UserGenericType):
......
......@@ -168,16 +168,20 @@ class ScoperExprVisitor(ScoperVisitor):
ftype.block_data.scope.declare_local(pname, b)
if not ftype.is_native:
from transpiler.phases.typing.block import ScoperBlockVisitor
scope = ftype.block_data.scope
vis = ScoperBlockVisitor(scope)
for stmt in ftype.block_data.node.body:
vis.visit(stmt)
if not getattr(scope.function, "has_return", False):
stmt = ast.Return()
ftype.block_data.node.body.append(stmt)
vis.visit(stmt)
#ftype.generic_parent.cache_instance(ftype)
existing = ftype.generic_parent.find_cached_instance(ftype.generic_args)
if not existing:
ftype.generic_parent.cache_instance(ftype.generic_args, ftype)
from transpiler.phases.typing.block import ScoperBlockVisitor
scope = ftype.block_data.scope
vis = ScoperBlockVisitor(scope)
for stmt in ftype.block_data.node.body:
vis.visit(stmt)
if not getattr(scope.function, "has_return", False):
stmt = ast.Return()
ftype.block_data.node.body.append(stmt)
vis.visit(stmt)
else:
return existing.return_type.resolve()
return ftype.return_type.resolve()
# if isinstance(ftype, TypeType):# and isinstance(ftype.type_object, UserType):
# init: FunctionType = self.visit_getattr(ftype, "__init__").remove_self()
......
......@@ -320,7 +320,7 @@ class GenericConstraint:
@dataclass(eq=False)
class GenericType(BaseType):
parameters: list[GenericParameter] = field(default_factory=list, init=False)
instance_cache: dict[object, GenericInstanceType] = field(default_factory=dict, init=False)
instance_cache: list[(object, GenericInstanceType)] = field(default_factory=list, init=False)
def constraints(self, args: list[ConcreteType]) -> list[GenericConstraint]:
return []
......@@ -348,11 +348,17 @@ class GenericType(BaseType):
def deref(self):
return self.instantiate_default().deref()
def cache_instance(self, instance):
if not hasattr(self, "instance_cache"):
self.instance_cache = {}
self.instance_cache[tuple(instance.generic_args)] = instance
def find_cached_instance(self, args):
for inst_args, inst in self.instance_cache:
if all(inst_arg.try_assign(arg) for inst_arg, arg in zip(inst_args, args)):
return inst
return None
def cache_instance(self, args, instance):
if not hasattr(self, "instance_cache"):
self.instance_cache = []
if not self.find_cached_instance(args):
self.instance_cache.append((tuple(args), instance))
@dataclass(eq=False, init=False)
class BuiltinGenericType(UniqueTypeMixin, GenericType):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment