Commit d8a51700 authored by Tom Niget's avatar Tom Niget

Update stuff

parent 22a67b3a
......@@ -221,7 +221,7 @@ class ExpressionVisitor(NodeVisitor):
use_dot = None
if type(node.value.type) == TypeType:
use_dot = "dots"
elif isinstance(node.type, FunctionType) and not isinstance(node.value.type, Promise):
elif isinstance(node.type, FunctionType) and node.type.is_method and not isinstance(node.value.type, Promise):
if node.value.type.resolve().is_reference:
use_dot = "dotp"
else:
......
......@@ -285,3 +285,21 @@ class OutsideLoopError(CompileError):
def detail(self, last_node: ast.AST = None) -> str:
return ""
@dataclass
class MissingReturnError(CompileError):
node: ast.FunctionDef
def __str__(self) -> str:
return f"Missing return: not all code paths in {highlight(self.node)} return"
def detail(self, last_node: ast.AST = None) -> str:
return f"""
This indicates that a function is missing a {highlight('return')} statement in one or more of its code paths.
For example:
{highlight('def f(x: int):')}
{highlight(' if x > 0:')}
{highlight(' return 1')}
{highlight(' # if x <= 0, the function returns nothing')}
"""
\ No newline at end of file
......@@ -67,7 +67,7 @@ class ScoperExprVisitor(ScoperVisitor):
assert ftype.kind == PromiseKind.TASK
ftype.kind = PromiseKind.GENERATOR
ftype.return_type.unify(ytype)
self.scope.function.has_return = True
self.scope.function.has_yield = True
return TY_NONE
......
......@@ -54,7 +54,7 @@ class Scope:
vars: Dict[str, VarDecl] = field(default_factory=dict)
children: List["Scope"] = field(default_factory=list)
obj_type: Optional[BaseType] = None
has_return: bool = False
diverges: bool = False
class_: Optional["Scope"] = None
is_loop: Optional[ast.For | ast.While] = None
......
......@@ -105,6 +105,7 @@ class StdlibVisitor(NodeVisitorSeq):
ty.variadic = True
ty.optional_at = 1 + len(node.args.args) - len(node.args.defaults)
if self.cur_class:
ty.is_method = True
assert isinstance(self.cur_class, TypeType)
if isinstance(self.cur_class.type_object, ABCMeta):
self.cur_class.type_object.gen_methods[node.name] = lambda t: ty.gen_sub(t, self.typevars)
......
......@@ -289,12 +289,12 @@ class TypeOperator(BaseType, ABC):
vardict = dict(zip(typevars.keys(), this.args))
else:
vardict = typevars
for k in dataclasses.fields(self):
setattr(res, k.name, getattr(self, k.name))
for k, v in self.__dict__.items():
setattr(res, k, v)
res.args = [arg.resolve().gen_sub(this, vardict, cache) for arg in self.args]
res.methods = {k: v.gen_sub(this, vardict, cache) for k, v in self.methods.items()}
res.parents = [p.gen_sub(this, vardict, cache) for p in self.parents]
res.is_protocol = self.is_protocol
#res.is_protocol = self.is_protocol
return res
def to_list(self) -> List["BaseType"]:
......@@ -308,6 +308,7 @@ class ModuleType(TypeOperator):
class FunctionType(TypeOperator):
is_python_func: bool = False
python_func_used: bool = False
is_method: bool = False
def __iter__(self):
x = 5
......@@ -331,16 +332,19 @@ class FunctionType(TypeOperator):
def __str__(self):
ret, *args = map(str, self.args)
if self.optional_at is not None:
args = args[:self.optional_at] + [f"{x}=..." for x in args[self.optional_at:]]
if self.variadic:
args.append(f"*args")
args.append("*args")
if args:
args = f"({', '.join(args)})"
args = f"{', '.join(args)}"
else:
args = "()"
return f"{args} -> {ret}"
args = ""
return f"({args}) -> {ret}"
def remove_self(self):
res = FunctionType(self.parameters[1:], self.return_type)
res.is_method = self.is_method
res.variadic = self.variadic
res.optional_at = self.optional_at - 1 if self.optional_at is not None else None
return res
......@@ -460,8 +464,12 @@ class Promise(TypeOperator, ABC):
@kind.setter
def kind(self, value: PromiseKind):
if value == PromiseKind.GENERATOR:
self.methods["__iter__"] = FunctionType([], self)
self.methods["__next__"] = FunctionType([], self.return_type)
f_iter = FunctionType([], self)
f_iter.is_method = True
self.methods["__iter__"] = f_iter
f_next = FunctionType([], self.return_type)
f_next.is_method = True
self.methods["__next__"] = f_next
self.args[1].val = value
def __str__(self):
......@@ -506,4 +514,10 @@ class UserType(TypeOperator):
class UnionType(TypeOperator):
def __init__(self, *args: List[BaseType]):
super().__init__(args, "Union")
self.parents.extend(args)
self.parents.extend(set(args))
def is_optional(self):
if len(self.args) == 2 and TY_NONE in self.args:
return (set(self.args) - {TY_NONE}).pop()
return False
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment