Commit 9a209460 authored by Tom Niget's avatar Tom Niget

Make inference run in multiple passes

parent 06f91a66
...@@ -57,8 +57,8 @@ path = Path(args.input[0]) ...@@ -57,8 +57,8 @@ path = Path(args.input[0])
with open(path, "r", encoding="utf-8") as f: with open(path, "r", encoding="utf-8") as f:
code = f.read() code = f.read()
from .transpiler import transpile from transpiler import transpile
from .transpiler.format import format_code from transpiler.format import format_code
raw_cpp = transpile(code, path.name, path) raw_cpp = transpile(code, path.name, path)
formatted = format_code(raw_cpp) formatted = format_code(raw_cpp)
......
import sys import sys
import math import math
x = [6] x = [6]
x = 5
u = (math.abcd) # abcd
a = 5 if True else 3
def c(x: int):
return x
for v in 6:
g = 6
h = 7
i = 8
pass
if __name__ == "__main__": if __name__ == "__main__":
pass pass
\ No newline at end of file
...@@ -29,7 +29,7 @@ def read_file(path): ...@@ -29,7 +29,7 @@ def read_file(path):
fd.close() fd.close()
return content return content
def handle_connection(connfd: socket, filepath): def handle_connection(connfd, filepath):
buf = connfd.recv(1024).decode("utf-8") buf = connfd.recv(1024).decode("utf-8")
length = buf.find("\r\n\r\n") length = buf.find("\r\n\r\n")
content = read_file(filepath) content = read_file(filepath)
...@@ -37,7 +37,7 @@ def handle_connection(connfd: socket, filepath): ...@@ -37,7 +37,7 @@ def handle_connection(connfd: socket, filepath):
connfd.send(response.encode("utf-8")) connfd.send(response.encode("utf-8"))
connfd.close() connfd.close()
def server_loop(sockfd: socket, filepath): def server_loop(sockfd, filepath):
while True: while True:
connfd, _ = sockfd.accept() connfd, _ = sockfd.accept()
......
...@@ -25,8 +25,8 @@ class BlockVisitor(NodeVisitor): ...@@ -25,8 +25,8 @@ class BlockVisitor(NodeVisitor):
def visit_Pass(self, node: ast.Pass) -> Iterable[str]: def visit_Pass(self, node: ast.Pass) -> Iterable[str]:
yield ";" yield ";"
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]: # def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
yield from self.visit_free_func(node) # yield from self.visit_free_func(node)
def visit_free_func(self, node: ast.FunctionDef, emission: FunctionEmissionKind) -> Iterable[str]: def visit_free_func(self, node: ast.FunctionDef, emission: FunctionEmissionKind) -> Iterable[str]:
if getattr(node, "is_main", False): if getattr(node, "is_main", False):
......
...@@ -12,5 +12,6 @@ class IfMainVisitor(ast.NodeVisitor): ...@@ -12,5 +12,6 @@ class IfMainVisitor(ast.NodeVisitor):
new_node = ast.parse("def main(): pass").body[0] new_node = ast.parse("def main(): pass").body[0]
new_node.body = stmt.body new_node.body = stmt.body
new_node.is_main = True new_node.is_main = True
node.main_if = new_node
node.body[i] = new_node node.body[i] = new_node
return return
\ No newline at end of file
...@@ -28,12 +28,29 @@ class ScoperVisitor(NodeVisitorSeq): ...@@ -28,12 +28,29 @@ class ScoperVisitor(NodeVisitorSeq):
self.fdecls = [] self.fdecls = []
for b in block: for b in block:
self.visit(b) self.visit(b)
for node, rtype in self.fdecls: if self.fdecls:
for b in node.body: old_list = self.fdecls
decls = {} exc = None
visitor = ScoperBlockVisitor(node.inner_scope, decls) while True:
visitor.visit(b) new_list = []
b.decls = decls for node, rtype in old_list:
if not node.inner_scope.has_return: from transpiler.exceptions import CompileError
rtype.unify(TY_NONE) # todo: properly indicate missing return try:
for b in node.body:
decls = {}
visitor = ScoperBlockVisitor(node.inner_scope, decls)
visitor.visit(b)
b.decls = decls
if not node.inner_scope.has_return:
rtype.unify(TY_NONE) # todo: properly indicate missing return
except CompileError as e:
new_list.append((node, rtype))
if not exc:
exc = e
if len(new_list) == len(old_list):
raise exc
if not new_list:
break
old_list = new_list
exc = None
...@@ -34,7 +34,7 @@ DUNDER = { ...@@ -34,7 +34,7 @@ DUNDER = {
class ScoperExprVisitor(ScoperVisitor): class ScoperExprVisitor(ScoperVisitor):
def visit(self, node) -> BaseType: def visit(self, node) -> BaseType:
if existing := getattr(node, "type", None): if existing := getattr(node, "type", None):
return existing return existing.resolve()
res = super().visit(node) res = super().visit(node)
if not res: if not res:
raise NotImplementedError(f"`{ast.unparse(node)}` {type(node)}") raise NotImplementedError(f"`{ast.unparse(node)}` {type(node)}")
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment