Commit c9cb938b authored by Tom Niget's avatar Tom Niget

Fork Sync works

parent 19d31953
......@@ -399,7 +399,7 @@ concept HasSync = requires(T t) { typename T::has_sync; };
auto call_sync(auto f) {
if constexpr (HasSync<decltype(f)>) {
return [f](auto... args) {
return f.sync(std::forward<decltype(args)>(args)...);
return f.typon$$sync(std::forward<decltype(args)>(args)...);
};
} else {
return f;
......
from typon import fork, sync
def fibo(n):
if n < 2:
......@@ -8,46 +7,47 @@ def fibo(n):
sync()
return a.get() + b.get()
"""
def fibo(n: int) -> int:
if n < 2:
return n
with sync(): # {
a = fork(lambda: fibo(n - 1))
b = fork(lambda: fibo(n - 2))
# }
return a + b
"""
"""
Task<int> fibo(int n) {
if (n < 2) {
return n;
}
Forked<int> a;
Forked<int> b;
{
a = fork(fibo(n - 1));
// cvcvc
b = fork(fibo(n - 2));
co_await sync();
}
co_return a.get() + b.get();
"""
"""
Task<int> fibo(int n) {
int a, b;
co_return []() -> Join<int> {
if (n < 2) {
return n;
}
co_await fork(fibo(n - 1), a);
co_await fork(fibo(n - 2), b);
co_await Sync();
co_return a + b;
}();
}
"""
# """
# def fibo(n: int) -> int:
# if n < 2:
# return n
# with sync(): # {
# a = fork(lambda: fibo(n - 1))
# b = fork(lambda: fibo(n - 2))
# # }
# return a + b
# """
# """
# Task<int> fibo(int n) {
# if (n < 2) {
# return n;
# }
# Forked<int> a;
# Forked<int> b;
# {
# a = fork(fibo(n - 1));
# // cvcvc
# b = fork(fibo(n - 2));
# co_await sync();
# }
# co_return a.get() + b.get();
# """
#
# """
# Task<int> fibo(int n) {
# int a, b;
# co_return []() -> Join<int> {
# if (n < 2) {
# return n;
# }
# co_await fork(fibo(n - 1), a);
# co_await fork(fibo(n - 2), b);
# co_await Sync();
# co_return a + b;
# }();
# }
# """
......
......@@ -5,7 +5,7 @@ from typing import Iterable
from transpiler.phases.emit_cpp.visitors import NodeVisitor, CoroutineMode, join
from transpiler.phases.typing.scope import Scope
from transpiler.phases.typing.types import ClassTypeType, TupleInstanceType, TY_FUTURE, ResolvedConcreteType
from transpiler.phases.typing.types import ClassTypeType, TupleInstanceType, TY_FUTURE, ResolvedConcreteType, TY_FORKED
from transpiler.phases.utils import make_lnd
from transpiler.utils import linenodata
......@@ -138,31 +138,46 @@ class ExpressionVisitor(NodeVisitor):
yield from self.visit(arg.body)
return
if isinstance(node.func, ast.Name) and node.func.id == "sync":
if self.generator != CoroutineMode.SYNC:
yield "co_await typon::Sync()"
else:
yield "(void)0"
return
is_get = isinstance(node.func, ast.Attribute) and node.func.attr == "get"
# async : co_await f(args)
# sync : call_sync(f, args)
if self.generator != CoroutineMode.SYNC:
nty = node.type.resolve()
if not (isinstance(nty, ResolvedConcreteType) and nty.inherits(TY_FUTURE)):
yield "co_await"
if isinstance(nty, ResolvedConcreteType) and (
nty.inherits(TY_FUTURE) or (
is_get and nty.inherits(TY_FORKED)
)
):
pass
else:
yield "call_sync"
if isinstance(node.func, ast.Attribute) and node.func.attr == "get" and node.func.value.type.inherits(TY_FUTURE):
yield "("
if self.generator == CoroutineMode.SYNC:
yield from self.visit(node.func.value)
yield "co_await"
else:
yield "("
if is_get and node.func.value.type.inherits(TY_FUTURE, TY_FORKED):
yield from self.visit(node.func.value)
yield ").get()"
yield ")"
return
yield "call_sync"
yield "("
if is_get and node.func.value.type.inherits(TY_FUTURE, TY_FORKED):
yield "("
yield from self.visit(node.func.value)
yield ").get"
else:
yield from self.visit(node.func)
yield ")("
yield from join(", ", map(self.visit, node.args))
yield ")"
......
......@@ -59,13 +59,13 @@ def emit_function(name: str, func: CallableInstanceType, base="function", gen_p=
try:
rty_code = " ".join(NodeVisitor().visit_BaseType(func.return_type))
except:
yield from emit_body("sync", CoroutineMode.SYNC, None)
yield from emit_body("typon$$sync", CoroutineMode.SYNC, None)
has_sync = True
yield "using has_sync = std::true_type;"
def task_type():
yield from NodeVisitor().visit_BaseType(func.return_type.generic_parent)
yield "<"
yield"decltype(sync("
yield"decltype(typon$$sync("
yield from join(",", (arg.arg for arg in func.block_data.node.args.args))
yield "))"
yield ">"
......
......@@ -197,8 +197,8 @@ class ResolvedConcreteType(ConcreteType):
return [self] + merge(*[p.get_mro() for p in self.parents], self.parents)
def inherits(self, parent: BaseType):
return self == parent or any(p.inherits(parent) for p in self.parents)
def inherits(self, *parent: BaseType):
return self in parent or any(p.inherits(*parent) for p in self.parents)
def try_assign_internal(self, other: BaseType) -> bool:
if self == other:
......@@ -264,8 +264,8 @@ class GenericInstanceType(ResolvedConcreteType):
def __init__(self):
super().__init__()
def inherits(self, parent: BaseType):
return self.generic_parent == parent or super().inherits(parent)
def inherits(self, *parent: BaseType):
return self.generic_parent in parent or super().inherits(*parent)
def __eq__(self, other):
if isinstance(other, GenericInstanceType):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment