Commit 02310e9a authored by Tom Niget's avatar Tom Niget

Add support for future()

parent 1afe2cb5
from typon import future
def fibo(n: int) -> int:
if n < 2:
return n
a = future(lambda: fibo(n - 1))
b = future(lambda: fibo(n - 2))
return a + b
if __name__ == "__main__":
print(fibo(30)) # should display 832040
\ No newline at end of file
......@@ -34,7 +34,7 @@ class BlockVisitor(NodeVisitor):
def visit_Call(self, node: ast.Call):
func = node.func
if compare_ast(func, ast.parse('fork', mode="eval").body):
if compare_ast(func, ast.parse("fork", mode="eval").body):
yield CoroutineMode.JOIN
yield from ()
......
......@@ -74,11 +74,13 @@ class ExpressionVisitor(NodeVisitor):
def visit_Name(self, node: ast.Name) -> Iterable[str]:
res = self.fix_name(node.id)
if (decl := self.scope.get(res)):
if decl := self.scope.get(res):
if decl.kind == VarKind.SELF:
res = "(*this)"
elif decl.future and CoroutineMode.ASYNC in self.generator:
res = f"{res}.get()"
if decl.future == "future":
res = "co_await " + res
yield res
def visit_Compare(self, node: ast.Compare) -> Iterable[str]:
......@@ -98,23 +100,24 @@ class ExpressionVisitor(NodeVisitor):
if getattr(node, "kwargs", None):
raise NotImplementedError(node, "kwargs")
func = node.func
if compare_ast(func, ast.parse('fork', mode="eval").body):
for name in ("fork", "future"):
if compare_ast(func, ast.parse(name, mode="eval").body):
assert len(node.args) == 1
arg = node.args[0]
assert isinstance(arg, ast.Lambda)
node.is_future = True
node.is_future = name
vis = self.reset()
vis.generator = CoroutineMode.SYNC
# todo: bad code
if CoroutineMode.ASYNC in self.generator:
yield "co_await typon::fork("
yield f"co_await typon::{name}("
yield from vis.visit(arg.body)
yield ")"
return
elif CoroutineMode.FAKE in self.generator:
yield from self.visit(arg.body)
return
elif compare_ast(func, ast.parse('sync', mode="eval").body):
if compare_ast(func, ast.parse('sync', mode="eval").body):
if CoroutineMode.ASYNC in self.generator:
yield "co_await typon::Sync()"
elif CoroutineMode.FAKE in self.generator:
......
......@@ -8,6 +8,11 @@ def fork(f: Callable[[], T]) -> T:
return f()
def future(f: Callable[[], T]) -> T:
# stub
return f()
def sync():
# stub
pass
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment