Commit 06f91a66 authored by Tom Niget's avatar Tom Niget

Finish removing old error types

parent e0455237
import sys import sys
import math import math
x = [6]
x = 5 x = 5
#x: str = "str"
y = "ab" u = (math.abcd) # abcd
a = 5 if True else 3
def c(x: int):
return x
for v in 6:
g = 6
h = 7
i = 8
pass
if __name__ == "__main__": if __name__ == "__main__":
pass pass
\ No newline at end of file
...@@ -5,7 +5,7 @@ import importlib ...@@ -5,7 +5,7 @@ import importlib
import inspect import inspect
import os import os
os.environ["TERM"] = "xterm-256" #os.environ["TERM"] = "xterm-256"
import colorama import colorama
colorama.init() colorama.init()
...@@ -80,6 +80,10 @@ def exception_hook(exc_type, exc_value, tb): ...@@ -80,6 +80,10 @@ def exception_hook(exc_type, exc_value, tb):
old = hg[last_node.lineno - 1] old = hg[last_node.lineno - 1]
[start] = find_indices(old, [last_node.col_offset]) [start] = find_indices(old, [last_node.col_offset])
hg[last_node.lineno - 1] = old[:start] + "\x1b[4m" + old[start:] hg[last_node.lineno - 1] = old[:start] + "\x1b[4m" + old[start:]
for lineid in range(last_node.lineno, last_node.end_lineno - 1):
old = hg[lineid]
first_nonspace = len(old) - len(old.lstrip())
hg[lineid] = old[:first_nonspace] + "\x1b[4m" + old[first_nonspace:] + "\x1b[24m"
old = hg[last_node.end_lineno - 1] old = hg[last_node.end_lineno - 1]
first_nonspace = len(old) - len(old.lstrip()) first_nonspace = len(old) - len(old.lstrip())
[end] = find_indices(old, [last_node.end_col_offset]) [end] = find_indices(old, [last_node.end_col_offset])
...@@ -96,8 +100,10 @@ def exception_hook(exc_type, exc_value, tb): ...@@ -96,8 +100,10 @@ def exception_hook(exc_type, exc_value, tb):
print() print()
print(cf.red("Error:"), exc_value) print(cf.red("Error:"), exc_value)
if isinstance(exc_value, CompileError): if isinstance(exc_value, CompileError):
detail = inspect.cleandoc(exc_value.detail(last_node))
if detail:
print() print()
print(inspect.cleandoc(exc_value.detail(last_node))) print(detail)
print() print()
def find_indices(s, indices: list[int]) -> list[int]: def find_indices(s, indices: list[int]) -> list[int]:
......
...@@ -9,7 +9,7 @@ from transpiler.phases.typing.common import ScoperVisitor ...@@ -9,7 +9,7 @@ from transpiler.phases.typing.common import ScoperVisitor
from transpiler.phases.typing.expr import ScoperExprVisitor from transpiler.phases.typing.expr import ScoperExprVisitor
from transpiler.phases.typing.class_ import ScoperClassVisitor from transpiler.phases.typing.class_ import ScoperClassVisitor
from transpiler.phases.typing.scope import VarDecl, VarKind, ScopeKind, Scope from transpiler.phases.typing.scope import VarDecl, VarKind, ScopeKind, Scope
from transpiler.phases.typing.types import BaseType, TypeVariable, FunctionType, IncompatibleTypesError, \ from transpiler.phases.typing.types import BaseType, TypeVariable, FunctionType, \
Promise, TY_NONE, PromiseKind, TupleType, UserType, TypeType, ModuleType Promise, TY_NONE, PromiseKind, TupleType, UserType, TypeType, ModuleType
from transpiler.phases.utils import PlainBlock from transpiler.phases.utils import PlainBlock
...@@ -40,7 +40,8 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -40,7 +40,8 @@ class ScoperBlockVisitor(ScoperVisitor):
mod.type.is_python = True mod.type.is_python = True
self.scope.vars[name] = mod self.scope.vars[name] = mod
if mod is None: if mod is None:
raise NameError(name) from transpiler.phases.typing.exceptions import UnknownNameError
raise UnknownNameError(name)
assert isinstance(mod, VarDecl), mod assert isinstance(mod, VarDecl), mod
assert isinstance(mod.type, ModuleType), mod.type assert isinstance(mod.type, ModuleType), mod.type
return mod return mod
...@@ -81,10 +82,7 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -81,10 +82,7 @@ class ScoperBlockVisitor(ScoperVisitor):
raise NotImplementedError(node) raise NotImplementedError(node)
target = node.targets[0] target = node.targets[0]
ty = self.get_type(node.value) ty = self.get_type(node.value)
try:
node.is_declare = self.visit_assign_target(target, ty) node.is_declare = self.visit_assign_target(target, ty)
except IncompatibleTypesError as e:
raise IncompatibleTypesError(f"`{ast.unparse(node)}: {e}") from e
def visit_AnnAssign(self, node: ast.AnnAssign): def visit_AnnAssign(self, node: ast.AnnAssign):
# if node.value is not None: # if node.value is not None:
...@@ -116,8 +114,12 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -116,8 +114,12 @@ class ScoperBlockVisitor(ScoperVisitor):
self.root_decls[target.id] = VarDecl(VarKind.OUTER_DECL, decl_val) self.root_decls[target.id] = VarDecl(VarKind.OUTER_DECL, decl_val)
return True return True
elif isinstance(target, ast.Tuple): elif isinstance(target, ast.Tuple):
if not (isinstance(decl_val, TupleType) and len(target.elts) == len(decl_val.args)): if not isinstance(decl_val, TupleType):
raise IncompatibleTypesError(f"Cannot unpack {decl_val} into {target}") from transpiler.phases.typing.exceptions import InvalidUnpackError
raise InvalidUnpackError(decl_val)
if len(target.elts) != len(decl_val.args):
from transpiler.phases.typing.exceptions import InvalidUnpackCountError
raise InvalidUnpackCountError(decl_val, len(target.elts))
decls = [self.visit_assign_target(t, ty) for t, ty in zip(target.elts, decl_val.args)] # eager evaluated decls = [self.visit_assign_target(t, ty) for t, ty in zip(target.elts, decl_val.args)] # eager evaluated
return any(decls) return any(decls)
elif isinstance(target, ast.Attribute): elif isinstance(target, ast.Attribute):
...@@ -210,11 +212,13 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -210,11 +212,13 @@ class ScoperBlockVisitor(ScoperVisitor):
try: try:
iter_type = seq_type.methods["__iter__"].return_type iter_type = seq_type.methods["__iter__"].return_type
except: except:
raise IncompatibleTypesError(f"{seq_type} is not iterable in `{ast.unparse(node.iter)}`") from transpiler.phases.typing.exceptions import NotIterableError
raise NotIterableError(seq_type)
try: try:
next_type = iter_type.methods["__next__"].return_type next_type = iter_type.methods["__next__"].return_type
except: except:
raise IncompatibleTypesError(f"iter({iter_type}) is not an iterator") from transpiler.phases.typing.exceptions import NotIteratorError
raise NotIteratorError(iter_type)
var_var.unify(next_type) var_var.unify(next_type)
body_scope = scope.child(ScopeKind.FUNCTION_INNER) body_scope = scope.child(ScopeKind.FUNCTION_INNER)
body_visitor = ScoperBlockVisitor(body_scope, self.root_decls) body_visitor = ScoperBlockVisitor(body_scope, self.root_decls)
...@@ -228,7 +232,8 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -228,7 +232,8 @@ class ScoperBlockVisitor(ScoperVisitor):
def visit_Return(self, node: ast.Return): def visit_Return(self, node: ast.Return):
fct = self.scope.function fct = self.scope.function
if fct is None: if fct is None:
raise IncompatibleTypesError("Return outside function") from transpiler.phases.typing.exceptions import ReturnOutsideFunctionError
raise ReturnOutsideFunctionError()
ftype = fct.obj_type ftype = fct.obj_type
assert isinstance(ftype, FunctionType) assert isinstance(ftype, FunctionType)
vtype = self.expr().visit(node.value) if node.value else TY_NONE vtype = self.expr().visit(node.value) if node.value else TY_NONE
......
...@@ -180,6 +180,7 @@ class UnknownNameError(CompileError): ...@@ -180,6 +180,7 @@ class UnknownNameError(CompileError):
For example: For example:
{highlight('print(abcd)')} {highlight('print(abcd)')}
{highlight('import foobar')}
""" """
...@@ -198,3 +199,81 @@ class UnknownModuleMemberError(CompileError): ...@@ -198,3 +199,81 @@ class UnknownModuleMemberError(CompileError):
For example: For example:
{highlight('from math import abcd')} {highlight('from math import abcd')}
""" """
@dataclass
class InvalidUnpackCountError(CompileError):
value: BaseType
count: int
def __str__(self) -> str:
return f"Invalid unpack: {highlight(self.value)} cannot be unpacked into {self.count} variables"
def detail(self, last_node: ast.AST = None) -> str:
return f"""
This indicates that an attempt was made to unpack a value that cannot be unpacked into the given number of
variables.
For example:
{highlight('a, b, c = 1, 2')}
"""
@dataclass
class InvalidUnpackError(CompileError):
value: BaseType
def __str__(self) -> str:
return f"Invalid unpack: {highlight(self.value)} cannot be unpacked"
def detail(self, last_node: ast.AST = None) -> str:
return f"""
This indicates that an attempt was made to unpack a value that cannot be unpacked.
For example:
{highlight('a, b, c = 1')}
Moreover, currently typon only supports unpacking tuples.
"""
@dataclass
class NotIterableError(CompileError):
value: BaseType
def __str__(self) -> str:
return f"Not iterable: {highlight(self.value)} is not iterable"
def detail(self, last_node: ast.AST = None) -> str:
return f"""
This indicates that an attempt was made to iterate over a value that is not iterable.
For example:
{highlight('for x in 1: ...')}
Iterable types must implement the Python {highlight('Iterable')} protocol, which requires the presence of a
{highlight('__iter__')} method.
"""
@dataclass
class NotIteratorError(CompileError):
value: BaseType
def __str__(self) -> str:
return f"Not iterator: {highlight(self.value)} is not an iterator"
def detail(self, last_node: ast.AST = None) -> str:
return f"""
This indicates that an attempt was made to iterate over a value that is not an iterator.
For example:
{highlight('x = next(5)')}
Iterator types must implement the Python {highlight('Iterator')} protocol, which requires the presence of a
{highlight('__next__')} method.
"""
@dataclass
class ReturnOutsideFunctionError(CompileError):
def __str__(self) -> str:
return f"{highlight('return')} cannot be used outside of a function"
def detail(self, last_node: ast.AST = None) -> str:
return ""
\ No newline at end of file
...@@ -5,7 +5,7 @@ from typing import List ...@@ -5,7 +5,7 @@ from typing import List
from transpiler.phases.typing import ScopeKind, VarDecl, VarKind from transpiler.phases.typing import ScopeKind, VarDecl, VarKind
from transpiler.phases.typing.common import ScoperVisitor from transpiler.phases.typing.common import ScoperVisitor
from transpiler.phases.typing.types import IncompatibleTypesError, BaseType, TupleType, TY_STR, TY_BOOL, TY_INT, \ from transpiler.phases.typing.types import BaseType, TupleType, TY_STR, TY_BOOL, TY_INT, \
TY_COMPLEX, TY_NONE, FunctionType, PyList, TypeVariable, PySet, TypeType, PyDict, Promise, PromiseKind, UserType TY_COMPLEX, TY_NONE, FunctionType, PyList, TypeVariable, PySet, TypeType, PyDict, Promise, PromiseKind, UserType
DUNDER = { DUNDER = {
......
...@@ -136,7 +136,8 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -136,7 +136,8 @@ class StdlibVisitor(NodeVisitorSeq):
def visit_str(self, node: str) -> BaseType: def visit_str(self, node: str) -> BaseType:
if existing := self.scope.get(node): if existing := self.scope.get(node):
return existing.type return existing.type
raise NameError(node) from transpiler.phases.typing.exceptions import UnknownNameError
raise UnknownNameError(node)
def visit_Name(self, node: ast.Name) -> BaseType: def visit_Name(self, node: ast.Name) -> BaseType:
return self.visit_str(node.id) return self.visit_str(node.id)
\ No newline at end of file
...@@ -8,10 +8,6 @@ from typing import Dict, Optional, List, ClassVar, Callable ...@@ -8,10 +8,6 @@ from typing import Dict, Optional, List, ClassVar, Callable
from transpiler.utils import highlight from transpiler.utils import highlight
class IncompatibleTypesError(Exception):
pass
@dataclass(eq=False) @dataclass(eq=False)
class BaseType(ABC): class BaseType(ABC):
members: Dict[str, "BaseType"] = field(default_factory=dict, init=False) members: Dict[str, "BaseType"] = field(default_factory=dict, init=False)
...@@ -68,7 +64,8 @@ class MagicType(BaseType, typing.Generic[T]): ...@@ -68,7 +64,8 @@ class MagicType(BaseType, typing.Generic[T]):
def unify_internal(self, other: "BaseType"): def unify_internal(self, other: "BaseType"):
if type(self) != type(other) or self.val != other.val: if type(self) != type(other) or self.val != other.val:
raise IncompatibleTypesError() from transpiler.phases.typing.exceptions import TypeMismatchError, TypeMismatchKind
raise TypeMismatchError(self, other, TypeMismatchKind.DIFFERENT_TYPE)
def contains_internal(self, other: "BaseType") -> bool: def contains_internal(self, other: "BaseType") -> bool:
return False return False
...@@ -173,7 +170,7 @@ class TypeOperator(BaseType, ABC): ...@@ -173,7 +170,7 @@ class TypeOperator(BaseType, ABC):
def unify_internal(self, other: BaseType): def unify_internal(self, other: BaseType):
from transpiler.phases.typing.exceptions import TypeMismatchError, TypeMismatchKind from transpiler.phases.typing.exceptions import TypeMismatchError, TypeMismatchKind
if not isinstance(other, TypeOperator): if not isinstance(other, TypeOperator):
raise IncompatibleTypesError() raise TypeMismatchError(self, other, TypeMismatchKind.DIFFERENT_TYPE)
if other.is_protocol and not self.is_protocol: if other.is_protocol and not self.is_protocol:
return other.unify_internal(self) return other.unify_internal(self)
if self.is_protocol and not other.is_protocol: if self.is_protocol and not other.is_protocol:
...@@ -185,14 +182,14 @@ class TypeOperator(BaseType, ABC): ...@@ -185,14 +182,14 @@ class TypeOperator(BaseType, ABC):
for parent in other.get_parents(): for parent in other.get_parents():
try: try:
self.unify(parent) self.unify(parent)
except IncompatibleTypesError: except TypeMismatchError:
pass pass
else: else:
return return
for parent in self.get_parents(): for parent in self.get_parents():
try: try:
parent.unify(other) parent.unify(other)
except IncompatibleTypesError: except TypeMismatchError:
pass pass
else: else:
return return
...@@ -447,4 +444,5 @@ class UserType(TypeOperator): ...@@ -447,4 +444,5 @@ class UserType(TypeOperator):
def unify_internal(self, other: "BaseType"): def unify_internal(self, other: "BaseType"):
if type(self) != type(other): if type(self) != type(other):
raise IncompatibleTypesError() from transpiler.phases.typing.exceptions import TypeMismatchError, TypeMismatchKind
raise TypeMismatchError(self, other, TypeMismatchKind.DIFFERENT_TYPE)
...@@ -19,7 +19,6 @@ class NodeVisitorSeq: ...@@ -19,7 +19,6 @@ class NodeVisitorSeq:
return visitor(node) return visitor(node)
except Exception as e: except Exception as e:
raise raise
#raise IncompatibleTypesError(f"{e} in `{ast.unparse(node)}`")
else: else:
self.missing_impl(node) self.missing_impl(node)
......
...@@ -4,6 +4,8 @@ from dataclasses import dataclass ...@@ -4,6 +4,8 @@ from dataclasses import dataclass
from itertools import zip_longest from itertools import zip_longest
from typing import Union from typing import Union
import colorful as cf import colorful as cf
from pygments.token import *
# #
# from colorama import Fore, Back # from colorama import Fore, Back
# from colorama.ansi import AnsiCodes # from colorama.ansi import AnsiCodes
...@@ -23,6 +25,38 @@ import colorful as cf ...@@ -23,6 +25,38 @@ import colorful as cf
# #
# Style = AnsiStyle() # Style = AnsiStyle()
COLOR_SCHEME = {
Token: ('', ''),
Whitespace: ('gray', 'brightblack'),
Comment: ('brightblack', 'brightblack'),
Comment.Preproc: ('cyan', 'brightcyan'),
Keyword: ('brightblue', 'brightblue'),
Keyword.Type: ('cyan', 'brightcyan'),
Operator.Word: ('magenta', 'brightmagenta'),
Name.Builtin: ('cyan', 'brightcyan'),
Name.Function: ('green', 'brightgreen'),
Name.Namespace: ('brightcyan', 'brightcyan'),
Name.Class: ('green', 'brightgreen'),
Name.Exception: ('cyan', 'brightcyan'),
Name.Decorator: ('brightblack', 'gray'),
Name.Variable: ('red', 'brightred'),
Name.Constant: ('red', 'brightred'),
Name.Attribute: ('cyan', 'brightcyan'),
Name.Tag: ('brightblue', 'brightblue'),
String: ('yellow', 'yellow'),
Number: ('brightmagenta', 'brightblue'),
Generic.Deleted: ('brightred', 'brightred'),
Generic.Inserted: ('green', 'brightgreen'),
Generic.Heading: ('**', '**'),
Generic.Subheading: ('*magenta*', '*brightmagenta*'),
Generic.Prompt: ('**', '**'),
Generic.Error: ('brightred', 'brightred'),
Error: ('brightred', 'brightred'),
}
def highlight(code, full=False): def highlight(code, full=False):
""" """
Syntax highlights code as Python using colorama Syntax highlights code as Python using colorama
...@@ -42,14 +76,14 @@ def highlight(code, full=False): ...@@ -42,14 +76,14 @@ def highlight(code, full=False):
from pygments.formatters import TerminalFormatter from pygments.formatters import TerminalFormatter
lexer = get_lexer_by_name("python", stripnl=False) lexer = get_lexer_by_name("python", stripnl=False)
items = pyg_highlight(code, lexer, TerminalFormatter()).replace("\x1b[39;49;00m", "\x1b[39;24m") items = pyg_highlight(code, lexer, TerminalFormatter(colorscheme=COLOR_SCHEME)).replace("\x1b[39;49;00m", "\x1b[39m")
if full: if full:
return items return items
items = items.splitlines() items = items.splitlines()
res = items[0] res = items[0]
if len(items) > 1: if len(items) > 1:
res += cf.white(" [...]") res += cf.white(" [...]")
return f"\x1b[39;49m{cf.on_gray25(res)}" return f"\x1b[39;49m{cf.on_gray23(res)}"
def compare_ast(node1: Union[ast.expr, list[ast.expr]], node2: Union[ast.expr, list[ast.expr]]) -> bool: def compare_ast(node1: Union[ast.expr, list[ast.expr]], node2: Union[ast.expr, list[ast.expr]]) -> bool:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment