Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
T
typon-compiler
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
typon
typon-compiler
Commits
57dfb3cf
Commit
57dfb3cf
authored
Jul 29, 2023
by
Tom Niget
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Continue work on pretty errors
parent
84a03ffc
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
286 additions
and
52 deletions
+286
-52
trans/requirements.txt
trans/requirements.txt
+1
-1
trans/tests/a_a_a_errtest.py
trans/tests/a_a_a_errtest.py
+6
-2
trans/transpiler/__init__.py
trans/transpiler/__init__.py
+69
-4
trans/transpiler/phases/typing/__init__.py
trans/transpiler/phases/typing/__init__.py
+1
-1
trans/transpiler/phases/typing/annotations.py
trans/transpiler/phases/typing/annotations.py
+2
-2
trans/transpiler/phases/typing/block.py
trans/transpiler/phases/typing/block.py
+2
-1
trans/transpiler/phases/typing/common.py
trans/transpiler/phases/typing/common.py
+1
-1
trans/transpiler/phases/typing/exceptions.py
trans/transpiler/phases/typing/exceptions.py
+166
-1
trans/transpiler/phases/typing/expr.py
trans/transpiler/phases/typing/expr.py
+14
-25
trans/transpiler/phases/typing/types.py
trans/transpiler/phases/typing/types.py
+14
-6
trans/transpiler/phases/utils.py
trans/transpiler/phases/utils.py
+2
-1
trans/transpiler/utils.py
trans/transpiler/utils.py
+8
-7
No files found.
trans/requirements.txt
View file @
57dfb3cf
...
@@ -4,4 +4,4 @@ dataclasses~=0.6
...
@@ -4,4 +4,4 @@ dataclasses~=0.6
python-dotenv~=1.0.0
python-dotenv~=1.0.0
colorama~=0.4.6
colorama~=0.4.6
numpy~=1.25.1
numpy~=1.25.1
pygments~=2.15.1
colorful~=0.5.5
\ No newline at end of file
\ No newline at end of file
trans/tests/a_a_a_errtest.py
View file @
57dfb3cf
def
f
(
x
):
import
sys
import
math
;
x
=
(
math
.
abcd
)
def
f
(
x
:
int
):
return
x
return
x
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
y
=
f
(
f
)
y
=
(
6
).
x
\ No newline at end of file
\ No newline at end of file
trans/transpiler/__init__.py
View file @
57dfb3cf
# coding: utf-8
# coding: utf-8
import
ast
import
ast
import
builtins
import
builtins
import
importlib
import
inspect
import
inspect
import
os
import
os
...
@@ -17,7 +18,7 @@ from transpiler.phases.if_main import IfMainVisitor
...
@@ -17,7 +18,7 @@ from transpiler.phases.if_main import IfMainVisitor
from
transpiler.phases.typing.block
import
ScoperBlockVisitor
from
transpiler.phases.typing.block
import
ScoperBlockVisitor
from
transpiler.phases.typing.scope
import
Scope
from
transpiler.phases.typing.scope
import
Scope
from
itertools
import
islice
import
sys
import
sys
import
colorful
as
cf
import
colorful
as
cf
...
@@ -48,17 +49,40 @@ def exception_hook(exc_type, exc_value, tb):
...
@@ -48,17 +49,40 @@ def exception_hook(exc_type, exc_value, tb):
filename
=
tb
.
tb_frame
.
f_code
.
co_filename
filename
=
tb
.
tb_frame
.
f_code
.
co_filename
line_no
=
tb
.
tb_lineno
line_no
=
tb
.
tb_lineno
print
(
cf
.
red
(
f"File
\
"
{
filename
}\
"
, line
{
line_no
}
, in
{
name
}
"
),
end
=
""
)
print
(
cf
.
red
(
f"File
\
"
{
filename
}\
"
, line
{
line_no
}
, in
{
cf
.
green
(
name
)
}
"
),
end
=
""
)
if
info
:
=
local_vars
.
get
(
"TB"
,
None
):
if
info
:
=
local_vars
.
get
(
"TB"
,
None
):
print
(
f"
, while
{
cf
.
magenta
(
info
)
}
"
)
print
(
f"
:
{
cf
.
magenta
(
info
)
}\
x1b
[24m
"
)
else
:
else
:
print
()
print
()
tb
=
tb
.
tb_next
tb
=
tb
.
tb_next
if
last_node
is
not
None
:
if
last_node
is
not
None
:
print
()
print
(
f"In file
\
"
{
cf
.
white
(
last_file
)
}\
"
, line
{
last_node
.
lineno
}
"
)
print
(
f"In file
\
"
{
cf
.
white
(
last_file
)
}\
"
, line
{
last_node
.
lineno
}
"
)
print
(
"
\
t
"
+
highlight
(
ast
.
unparse
(
last_node
)))
with
open
(
last_file
,
"r"
,
encoding
=
"utf-8"
)
as
f
:
code
=
f
.
read
()
hg
=
str
(
highlight
(
code
,
True
)).
replace
(
"
\
x1b
[04m"
,
""
).
replace
(
"
\
x1b
[24m"
,
""
).
splitlines
()
if
last_node
.
lineno
==
last_node
.
end_lineno
:
old
=
hg
[
last_node
.
lineno
-
1
]
start
,
end
=
find_indices
(
old
,
[
last_node
.
col_offset
,
last_node
.
end_col_offset
])
hg
[
last_node
.
lineno
-
1
]
=
old
[:
start
]
+
"
\
x1b
[4m"
+
old
[
start
:
end
]
+
"
\
x1b
[24m"
+
old
[
end
:]
else
:
old
=
hg
[
last_node
.
lineno
-
1
]
[
start
]
=
find_indices
(
old
,
[
last_node
.
col_offset
])
hg
[
last_node
.
lineno
-
1
]
=
old
[:
start
]
+
"
\
x1b
[4m"
+
old
[
start
:]
old
=
hg
[
last_node
.
end_lineno
-
1
]
first_nonspace
=
len
(
old
)
-
len
(
old
.
lstrip
())
[
end
]
=
find_indices
(
old
,
[
last_node
.
end_col_offset
])
hg
[
last_node
.
end_lineno
-
1
]
=
old
[:
first_nonspace
]
+
"
\
x1b
[4m"
+
old
[
first_nonspace
:
end
]
+
"
\
x1b
[24m"
+
old
[
end
:]
CONTEXT_SIZE
=
2
start
=
max
(
0
,
last_node
.
lineno
-
CONTEXT_SIZE
-
1
)
offset
=
start
+
1
for
i
,
line
in
enumerate
(
hg
[
start
:
last_node
.
end_lineno
+
CONTEXT_SIZE
]):
erroneous
=
last_node
.
lineno
<=
offset
+
i
<=
last_node
.
end_lineno
indicator
=
cf
.
white
(
"**>"
)
if
erroneous
else
" "
print
(
f"
\
x1b
[24m
{
indicator
}{
cf
.
white
}{
(
offset
+
i
):
>
4
}{
cf
.
red
if
erroneous
else
cf
.
reset
}
|
{
cf
.
reset
}{
line
}\
x1b
[24m"
)
print
()
print
()
print
(
cf
.
red
(
"Error:"
),
exc_value
)
print
(
cf
.
red
(
"Error:"
),
exc_value
)
if
isinstance
(
exc_value
,
CompileError
):
if
isinstance
(
exc_value
,
CompileError
):
...
@@ -66,9 +90,50 @@ def exception_hook(exc_type, exc_value, tb):
...
@@ -66,9 +90,50 @@ def exception_hook(exc_type, exc_value, tb):
print
(
inspect
.
cleandoc
(
exc_value
.
detail
(
last_node
)))
print
(
inspect
.
cleandoc
(
exc_value
.
detail
(
last_node
)))
print
()
print
()
def
find_indices
(
s
,
indices
:
list
[
int
])
->
list
[
int
]:
"""
Matches indices to an ANSI-colored string
"""
results
=
set
()
i
=
0
j
=
0
it
=
iter
(
set
(
indices
))
current
=
next
(
it
)
while
i
<=
len
(
s
):
if
i
!=
len
(
s
)
and
s
[
i
]
==
"
\
x1b
"
:
i
+=
1
while
s
[
i
]
!=
"m"
:
i
+=
1
i
+=
1
continue
if
j
==
current
:
results
.
add
(
i
)
try
:
current
=
next
(
it
)
except
StopIteration
:
break
i
+=
1
j
+=
1
assert
len
(
results
)
==
len
(
indices
),
(
results
,
indices
,
s
)
return
sorted
(
list
(
results
))
assert
find_indices
(
"
\
x1b
[48;5;237mmath.abcd
\
x1b
[37m
\
x1b
[39m
\
x1b
[49m"
,
[
0
,
9
])
==
[
11
,
35
],
find_indices
(
"
\
x1b
[48;5;237mmath.abcd
\
x1b
[37m
\
x1b
[39m
\
x1b
[49m"
,
[
0
,
9
])
assert
find_indices
(
"abcdef"
,
[
2
,
5
])
==
[
2
,
5
]
assert
find_indices
(
"abc
\
x1b
[32mdef"
,
[
2
,
5
])
==
[
2
,
10
],
find_indices
(
"abc
\
x1b
[32mdef"
,
[
2
,
5
])
assert
find_indices
(
"math.abcd
\
x1b
[37m
\
x1b
[39m"
,
[
0
,
9
])
==
[
0
,
19
],
find_indices
(
"math.abcd
\
x1b
[37m
\
x1b
[39m"
,
[
0
,
9
])
sys
.
excepthook
=
exception_hook
sys
.
excepthook
=
exception_hook
try
:
pydevd
=
importlib
.
import_module
(
"_pydevd_bundle.pydevd_breakpoints"
)
except
ImportError
:
pass
else
:
pydevd
.
_fallback_excepthook
=
sys
.
excepthook
pydevd
.
original_excepthook
=
sys
.
excepthook
def
transpile
(
source
,
name
=
"<module>"
,
path
=
None
):
def
transpile
(
source
,
name
=
"<module>"
,
path
=
None
):
TB
=
f"transpiling module
{
cf
.
white
(
name
)
}
"
TB
=
f"transpiling module
{
cf
.
white
(
name
)
}
"
...
...
trans/transpiler/phases/typing/__init__.py
View file @
57dfb3cf
...
@@ -39,7 +39,7 @@ PRELUDE.vars.update({
...
@@ -39,7 +39,7 @@ PRELUDE.vars.update({
typon_std
=
Path
(
__file__
).
parent
.
parent
.
parent
.
parent
/
"stdlib"
typon_std
=
Path
(
__file__
).
parent
.
parent
.
parent
.
parent
/
"stdlib"
def
make_module
(
name
:
str
,
scope
:
Scope
)
->
BaseType
:
def
make_module
(
name
:
str
,
scope
:
Scope
)
->
BaseType
:
ty
=
ModuleType
([],
f"
module$
{
name
}
"
)
ty
=
ModuleType
([],
f"
{
name
}
"
)
for
n
,
v
in
scope
.
vars
.
items
():
for
n
,
v
in
scope
.
vars
.
items
():
ty
.
members
[
n
]
=
v
.
type
ty
.
members
[
n
]
=
v
.
type
return
ty
return
ty
...
...
trans/transpiler/phases/typing/annotations.py
View file @
57dfb3cf
...
@@ -32,7 +32,8 @@ class TypeAnnotationVisitor(NodeVisitorSeq):
...
@@ -32,7 +32,8 @@ class TypeAnnotationVisitor(NodeVisitorSeq):
return
ty
.
type_object
return
ty
.
type_object
return
ty
return
ty
raise
NameError
(
node
)
from
transpiler.phases.typing.exceptions
import
UnknownNameError
raise
UnknownNameError
(
node
)
def
visit_Name
(
self
,
node
:
ast
.
Name
)
->
BaseType
:
def
visit_Name
(
self
,
node
:
ast
.
Name
)
->
BaseType
:
return
self
.
visit_str
(
node
.
id
)
return
self
.
visit_str
(
node
.
id
)
...
@@ -59,4 +60,3 @@ class TypeAnnotationVisitor(NodeVisitorSeq):
...
@@ -59,4 +60,3 @@ class TypeAnnotationVisitor(NodeVisitorSeq):
res
=
left
.
members
[
node
.
attr
]
res
=
left
.
members
[
node
.
attr
]
assert
isinstance
(
res
,
TypeType
)
assert
isinstance
(
res
,
TypeType
)
return
res
.
type_object
return
res
.
type_object
raise
NotImplementedError
(
ast
.
unparse
(
node
))
trans/transpiler/phases/typing/block.py
View file @
57dfb3cf
...
@@ -58,7 +58,8 @@ class ScoperBlockVisitor(ScoperVisitor):
...
@@ -58,7 +58,8 @@ class ScoperBlockVisitor(ScoperVisitor):
for
alias
in
node
.
names
:
for
alias
in
node
.
names
:
thing
=
module
.
val
.
get
(
alias
.
name
)
thing
=
module
.
val
.
get
(
alias
.
name
)
if
not
thing
:
if
not
thing
:
raise
NameError
(
alias
.
name
)
from
transpiler.phases.typing.exceptions
import
UnknownModuleMemberError
raise
UnknownModuleMemberError
(
node
.
module
,
alias
.
name
)
alias
.
item_obj
=
thing
alias
.
item_obj
=
thing
self
.
scope
.
vars
[
alias
.
asname
or
alias
.
name
]
=
VarDecl
(
VarKind
.
LOCAL
,
thing
)
self
.
scope
.
vars
[
alias
.
asname
or
alias
.
name
]
=
VarDecl
(
VarKind
.
LOCAL
,
thing
)
...
...
trans/transpiler/phases/typing/common.py
View file @
57dfb3cf
...
@@ -35,5 +35,5 @@ class ScoperVisitor(NodeVisitorSeq):
...
@@ -35,5 +35,5 @@ class ScoperVisitor(NodeVisitorSeq):
visitor
.
visit
(
b
)
visitor
.
visit
(
b
)
b
.
decls
=
decls
b
.
decls
=
decls
if
not
node
.
inner_scope
.
has_return
:
if
not
node
.
inner_scope
.
has_return
:
rtype
.
unify
(
TY_NONE
)
rtype
.
unify
(
TY_NONE
)
# todo: properly indicate missing return
trans/transpiler/phases/typing/exceptions.py
View file @
57dfb3cf
import
ast
import
ast
import
enum
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
transpiler.utils
import
highlight
from
transpiler.utils
import
highlight
from
transpiler.exceptions
import
CompileError
from
transpiler.exceptions
import
CompileError
from
transpiler.phases.typing.types
import
TypeVariable
,
BaseType
from
transpiler.phases.typing.types
import
TypeVariable
,
BaseType
,
TypeOperator
@
dataclass
@
dataclass
...
@@ -34,6 +35,7 @@ class UnresolvedTypeVariableError(CompileError):
...
@@ -34,6 +35,7 @@ class UnresolvedTypeVariableError(CompileError):
{
highlight
(
'def f(x: int):'
)
}
{
highlight
(
'def f(x: int):'
)
}
"""
"""
@
dataclass
@
dataclass
class
RecursiveTypeUnificationError
(
CompileError
):
class
RecursiveTypeUnificationError
(
CompileError
):
needle
:
BaseType
needle
:
BaseType
...
@@ -52,3 +54,166 @@ class RecursiveTypeUnificationError(CompileError):
...
@@ -52,3 +54,166 @@ class RecursiveTypeUnificationError(CompileError):
In the current case,
{
highlight
(
self
.
haystack
)
}
contains type
{
highlight
(
self
.
needle
)
}
, but an attempt was made to
In the current case,
{
highlight
(
self
.
haystack
)
}
contains type
{
highlight
(
self
.
needle
)
}
, but an attempt was made to
unify them.
unify them.
"""
"""
@
dataclass
class
InvalidCallError
(
CompileError
):
callee
:
BaseType
args
:
list
[
BaseType
]
def
__str__
(
self
)
->
str
:
return
f"Invalid call:
{
highlight
(
self
.
callee
)
}
with arguments
{
highlight
(
self
.
args
)
}
"
def
detail
(
self
,
last_node
:
ast
.
AST
=
None
)
->
str
:
return
f"""
This generally indicates a type error in a function call.
For example:
{
highlight
(
'def f(x: int): pass'
)
}
{
highlight
(
'f("hello")'
)
}
In the current case,
{
highlight
(
self
.
callee
)
}
was called with arguments
{
highlight
(
self
.
args
)
}
, but the function
expects arguments of type
{
highlight
(
self
.
callee
.
args
)
}
.
"""
class
TypeMismatchKind
(
enum
.
Enum
):
NO_COMMON_PARENT
=
enum
.
auto
()
DIFFERENT_TYPE
=
enum
.
auto
()
@
dataclass
class
TypeMismatchError
(
CompileError
):
expected
:
BaseType
got
:
BaseType
reason
:
TypeMismatchKind
def
__str__
(
self
)
->
str
:
return
f"Type mismatch: expected
{
highlight
(
self
.
expected
)
}
, got
{
highlight
(
self
.
got
)
}
"
def
detail
(
self
,
last_node
:
ast
.
AST
=
None
)
->
str
:
return
f"""
This generally indicates a type error.
For example:
{
highlight
(
'def f(x: int): pass'
)
}
{
highlight
(
'f("hello")'
)
}
In the current case, the compiler expected an expression of type
{
highlight
(
self
.
expected
)
}
, but instead got
an expression of type
{
highlight
(
self
.
got
)
}
.
"""
@
dataclass
class
ArgumentCountMismatchError
(
CompileError
):
func
:
TypeOperator
arguments
:
TypeOperator
def
__setattr__
(
self
,
key
,
value
):
print
(
key
,
value
)
super
().
__setattr__
(
key
,
value
)
def
__str__
(
self
)
->
str
:
fcount
=
str
(
len
(
self
.
func
.
args
))
if
self
.
func
.
variadic
:
fcount
=
f"at least
{
fcount
}
"
return
f"Argument count mismatch: expected
{
fcount
}
, got
{
len
(
self
.
arguments
.
args
)
}
"
def
detail
(
self
,
last_node
:
ast
.
AST
=
None
)
->
str
:
return
f"""
This indicates missing or extraneous arguments in a function call or type instantiation.
The called or instantiated signature was
{
highlight
(
self
.
func
)
}
.
Other examples:
{
highlight
(
'def f(x: int): pass'
)
}
{
highlight
(
'f(1, 2)'
)
}
Here, the function
{
highlight
(
'f'
)
}
expects one argument, but was called with two.
{
highlight
(
'x: list[int, str]'
)
}
Here, the type
{
highlight
(
'list'
)
}
expects one argument, but was instantiated with two.
"""
@
dataclass
class
ProtocolMismatchError
(
CompileError
):
value
:
BaseType
protocol
:
BaseType
reason
:
Exception
def
__str__
(
self
)
->
str
:
return
f"Protocol mismatch:
{
highlight
(
self
.
value
)
}
does not implement
{
highlight
(
self
.
protocol
)
}
"
def
detail
(
self
,
last_node
:
ast
.
AST
=
None
)
->
str
:
return
f"""
This generally indicates a type error.
For example:
{
highlight
(
'def f(x: Iterable[int]): pass'
)
}
{
highlight
(
'f("hello")'
)
}
In the current case, the compiler expected an expression whose type implements
{
highlight
(
self
.
protocol
)
}
, but
instead got an expression of type
{
highlight
(
self
.
value
)
}
.
"""
@
dataclass
class
NotCallableError
(
CompileError
):
value
:
BaseType
def
__str__
(
self
)
->
str
:
return
f"Trying to call a non-function type:
{
highlight
(
self
.
value
)
}
"
def
detail
(
self
,
last_node
:
ast
.
AST
=
None
)
->
str
:
return
f"""
This indicates that an attempt was made to call an object that is not a function.
For example:
{
highlight
(
'x = 1'
)
}
{
highlight
(
'x()'
)
}
"""
@
dataclass
class
MissingAttributeError
(
CompileError
):
value
:
BaseType
attribute
:
str
def
__str__
(
self
)
->
str
:
return
f"Missing attribute:
{
highlight
(
self
.
value
)
}
has no attribute
{
highlight
(
self
.
attribute
)
}
"
def
detail
(
self
,
last_node
:
ast
.
AST
=
None
)
->
str
:
return
f"""
This indicates that an attempt was made to access an attribute that does not exist.
For example:
{
highlight
(
'x = 1'
)
}
{
highlight
(
'print(x.y)'
)
}
"""
@
dataclass
class
UnknownNameError
(
CompileError
):
name
:
str
def
__str__
(
self
)
->
str
:
return
f"Unknown name:
{
highlight
(
self
.
name
)
}
"
def
detail
(
self
,
last_node
:
ast
.
AST
=
None
)
->
str
:
return
f"""
This indicates that an attempt was made to access a name that does not exist.
For example:
{
highlight
(
'print(abcd)'
)
}
"""
@
dataclass
class
UnknownModuleMemberError
(
CompileError
):
module
:
str
name
:
str
def
__str__
(
self
)
->
str
:
return
f"Unknown module member: Module
{
highlight
(
self
.
module
)
}
does not contain
{
highlight
(
self
.
name
)
}
"
def
detail
(
self
,
last_node
:
ast
.
AST
=
None
)
->
str
:
return
f"""
This indicates that an attempt was made to import
For example:
{
highlight
(
'from math import abcd'
)
}
"""
\ No newline at end of file
trans/transpiler/phases/typing/expr.py
View file @
57dfb3cf
...
@@ -74,9 +74,10 @@ class ScoperExprVisitor(ScoperVisitor):
...
@@ -74,9 +74,10 @@ class ScoperExprVisitor(ScoperVisitor):
def
visit_Name
(
self
,
node
:
ast
.
Name
)
->
BaseType
:
def
visit_Name
(
self
,
node
:
ast
.
Name
)
->
BaseType
:
obj
=
self
.
scope
.
get
(
node
.
id
)
obj
=
self
.
scope
.
get
(
node
.
id
)
if
not
obj
:
if
not
obj
:
raise
NameError
(
f"Name
{
node
.
id
}
is not defined"
)
from
transpiler.phases.typing.exceptions
import
UnknownNameError
raise
UnknownNameError
(
node
.
id
)
if
isinstance
(
obj
.
type
,
TypeType
)
and
isinstance
(
obj
.
type
.
type_object
,
TypeVariable
):
if
isinstance
(
obj
.
type
,
TypeType
)
and
isinstance
(
obj
.
type
.
type_object
,
TypeVariable
):
raise
NameError
(
f"Use of type variable"
)
raise
NameError
(
f"Use of type variable"
)
# todo: when does this happen exactly?
if
getattr
(
obj
,
"is_python_func"
,
False
):
if
getattr
(
obj
,
"is_python_func"
,
False
):
obj
.
python_func_used
=
True
obj
.
python_func_used
=
True
return
obj
.
type
return
obj
.
type
...
@@ -93,10 +94,7 @@ class ScoperExprVisitor(ScoperVisitor):
...
@@ -93,10 +94,7 @@ class ScoperExprVisitor(ScoperVisitor):
ftype
=
self
.
visit
(
node
.
func
)
ftype
=
self
.
visit
(
node
.
func
)
if
ftype
.
typevars
:
if
ftype
.
typevars
:
ftype
=
ftype
.
gen_sub
(
None
,
{
v
.
name
:
TypeVariable
(
v
.
name
)
for
v
in
ftype
.
typevars
})
ftype
=
ftype
.
gen_sub
(
None
,
{
v
.
name
:
TypeVariable
(
v
.
name
)
for
v
in
ftype
.
typevars
})
try
:
rtype
=
self
.
visit_function_call
(
ftype
,
[
self
.
visit
(
arg
)
for
arg
in
node
.
args
])
rtype
=
self
.
visit_function_call
(
ftype
,
[
self
.
visit
(
arg
)
for
arg
in
node
.
args
])
except
IncompatibleTypesError
as
e
:
raise
IncompatibleTypesError
(
f"`
{
ast
.
unparse
(
node
)
}
`:
{
e
}
"
)
actual
=
rtype
actual
=
rtype
node
.
is_await
=
False
node
.
is_await
=
False
if
isinstance
(
actual
,
Promise
)
and
actual
.
kind
!=
PromiseKind
.
GENERATOR
:
if
isinstance
(
actual
,
Promise
)
and
actual
.
kind
!=
PromiseKind
.
GENERATOR
:
...
@@ -115,13 +113,12 @@ class ScoperExprVisitor(ScoperVisitor):
...
@@ -115,13 +113,12 @@ class ScoperExprVisitor(ScoperVisitor):
init
.
return_type
=
ftype
.
type_object
init
.
return_type
=
ftype
.
type_object
return
self
.
visit_function_call
(
init
,
arguments
)
return
self
.
visit_function_call
(
init
,
arguments
)
if
not
isinstance
(
ftype
,
FunctionType
):
if
not
isinstance
(
ftype
,
FunctionType
):
raise
IncompatibleTypesError
(
f"Cannot call
{
ftype
}
"
)
from
transpiler.phases.typing.exceptions
import
NotCallableError
raise
NotCallableError
(
ftype
)
#is_generic = any(isinstance(arg, TypeVariable) for arg in ftype.to_list())
#is_generic = any(isinstance(arg, TypeVariable) for arg in ftype.to_list())
equivalent
=
FunctionType
(
arguments
,
ftype
.
return_type
)
equivalent
=
FunctionType
(
arguments
,
ftype
.
return_type
)
try
:
equivalent
.
is_intermediary
=
True
ftype
.
unify
(
equivalent
)
ftype
.
unify
(
equivalent
)
except
IncompatibleTypesError
as
e
:
raise
IncompatibleTypesError
(
f"Cannot call
{
ftype
}
with (
{
(
', '
.
join
(
map
(
str
,
arguments
)))
}
):
{
e
}
"
)
return
ftype
.
return_type
return
ftype
.
return_type
def
visit_Lambda
(
self
,
node
:
ast
.
Lambda
)
->
BaseType
:
def
visit_Lambda
(
self
,
node
:
ast
.
Lambda
)
->
BaseType
:
...
@@ -143,17 +140,11 @@ class ScoperExprVisitor(ScoperVisitor):
...
@@ -143,17 +140,11 @@ class ScoperExprVisitor(ScoperVisitor):
def
visit_BinOp
(
self
,
node
:
ast
.
BinOp
)
->
BaseType
:
def
visit_BinOp
(
self
,
node
:
ast
.
BinOp
)
->
BaseType
:
left
,
right
=
map
(
self
.
visit
,
(
node
.
left
,
node
.
right
))
left
,
right
=
map
(
self
.
visit
,
(
node
.
left
,
node
.
right
))
try
:
return
self
.
make_dunder
([
left
,
right
],
DUNDER
[
type
(
node
.
op
)])
return
self
.
make_dunder
([
left
,
right
],
DUNDER
[
type
(
node
.
op
)])
except
IncompatibleTypesError
as
e
:
raise
IncompatibleTypesError
(
f"
{
e
}
in `
{
ast
.
unparse
(
node
)
}
`"
)
def
visit_Attribute
(
self
,
node
:
ast
.
Attribute
)
->
BaseType
:
def
visit_Attribute
(
self
,
node
:
ast
.
Attribute
)
->
BaseType
:
try
:
ltype
=
self
.
visit
(
node
.
value
)
ltype
=
self
.
visit
(
node
.
value
)
return
self
.
visit_getattr
(
ltype
,
node
.
attr
)
return
self
.
visit_getattr
(
ltype
,
node
.
attr
)
except
Exception
as
e
:
raise
IncompatibleTypesError
(
f"
{
e
}
in `
{
ast
.
unparse
(
node
)
}
`"
)
def
visit_getattr
(
self
,
ltype
:
BaseType
,
name
:
str
):
def
visit_getattr
(
self
,
ltype
:
BaseType
,
name
:
str
):
bound
=
True
bound
=
True
...
@@ -175,7 +166,8 @@ class ScoperExprVisitor(ScoperVisitor):
...
@@ -175,7 +166,8 @@ class ScoperExprVisitor(ScoperVisitor):
return
meth
.
remove_self
()
return
meth
.
remove_self
()
else
:
else
:
return
meth
return
meth
raise
IncompatibleTypesError
(
f"Type
{
ltype
}
has no attribute
{
name
}
"
)
from
transpiler.phases.typing.exceptions
import
MissingAttributeError
raise
MissingAttributeError
(
ltype
,
name
)
def
visit_List
(
self
,
node
:
ast
.
List
)
->
BaseType
:
def
visit_List
(
self
,
node
:
ast
.
List
)
->
BaseType
:
if
not
node
.
elts
:
if
not
node
.
elts
:
...
@@ -216,10 +208,7 @@ class ScoperExprVisitor(ScoperVisitor):
...
@@ -216,10 +208,7 @@ class ScoperExprVisitor(ScoperVisitor):
val
=
self
.
visit
(
node
.
operand
)
val
=
self
.
visit
(
node
.
operand
)
if
isinstance
(
node
.
op
,
ast
.
Not
):
if
isinstance
(
node
.
op
,
ast
.
Not
):
return
TY_BOOL
return
TY_BOOL
try
:
return
self
.
make_dunder
([
val
],
DUNDER
[
type
(
node
.
op
)])
return
self
.
make_dunder
([
val
],
DUNDER
[
type
(
node
.
op
)])
except
IncompatibleTypesError
as
e
:
raise
IncompatibleTypesError
(
f"
{
e
}
in `
{
ast
.
unparse
(
node
)
}
`"
)
def
visit_IfExp
(
self
,
node
:
ast
.
IfExp
)
->
BaseType
:
def
visit_IfExp
(
self
,
node
:
ast
.
IfExp
)
->
BaseType
:
self
.
visit
(
node
.
test
)
self
.
visit
(
node
.
test
)
...
...
trans/transpiler/phases/typing/types.py
View file @
57dfb3cf
...
@@ -5,7 +5,6 @@ from dataclasses import dataclass, field
...
@@ -5,7 +5,6 @@ from dataclasses import dataclass, field
from
enum
import
Enum
from
enum
import
Enum
from
itertools
import
zip_longest
from
itertools
import
zip_longest
from
typing
import
Dict
,
Optional
,
List
,
ClassVar
,
Callable
from
typing
import
Dict
,
Optional
,
List
,
ClassVar
,
Callable
from
transpiler.utils
import
highlight
from
transpiler.utils
import
highlight
...
@@ -131,6 +130,7 @@ class TypeOperator(BaseType, ABC):
...
@@ -131,6 +130,7 @@ class TypeOperator(BaseType, ABC):
is_protocol_gen
:
ClassVar
[
bool
]
=
False
is_protocol_gen
:
ClassVar
[
bool
]
=
False
match_cache
:
set
[
"TypeOperator"
]
=
field
(
default_factory
=
set
,
init
=
False
)
match_cache
:
set
[
"TypeOperator"
]
=
field
(
default_factory
=
set
,
init
=
False
)
is_reference
:
bool
=
False
is_reference
:
bool
=
False
is_intermediary
:
bool
=
False
@
staticmethod
@
staticmethod
def
make_type
(
name
:
str
):
def
make_type
(
name
:
str
):
...
@@ -167,9 +167,11 @@ class TypeOperator(BaseType, ABC):
...
@@ -167,9 +167,11 @@ class TypeOperator(BaseType, ABC):
corresp
.
remove_self
().
unify
(
ty
.
remove_self
())
corresp
.
remove_self
().
unify
(
ty
.
remove_self
())
except
Exception
as
e
:
except
Exception
as
e
:
self
.
match_cache
.
remove
(
hash
(
protocol
))
self
.
match_cache
.
remove
(
hash
(
protocol
))
raise
IncompatibleTypesError
(
f"Type
{
self
}
doesn't implement protocol
{
protocol
}
:
{
e
}
"
)
from
transpiler.phases.typing.exceptions
import
ProtocolMismatchError
raise
ProtocolMismatchError
(
self
,
protocol
,
e
)
def
unify_internal
(
self
,
other
:
BaseType
):
def
unify_internal
(
self
,
other
:
BaseType
):
from
transpiler.phases.typing.exceptions
import
TypeMismatchError
,
TypeMismatchKind
if
not
isinstance
(
other
,
TypeOperator
):
if
not
isinstance
(
other
,
TypeOperator
):
raise
IncompatibleTypesError
()
raise
IncompatibleTypesError
()
if
other
.
is_protocol
and
not
self
.
is_protocol
:
if
other
.
is_protocol
and
not
self
.
is_protocol
:
...
@@ -194,10 +196,10 @@ class TypeOperator(BaseType, ABC):
...
@@ -194,10 +196,10 @@ class TypeOperator(BaseType, ABC):
pass
pass
else
:
else
:
return
return
raise
IncompatibleTypesError
(
f"Cannot unify
{
self
}
and
{
other
}
with different type and no common parents"
)
raise
TypeMismatchError
(
self
,
other
,
TypeMismatchKind
.
DIFFERENT_TYPE
)
if
len
(
self
.
args
)
==
0
:
if
len
(
self
.
args
)
==
0
:
if
self
.
name
!=
other
.
name
:
if
self
.
name
!=
other
.
name
:
raise
IncompatibleTypesError
(
f"Cannot unify
{
self
}
and
{
other
}
"
)
raise
TypeMismatchError
(
self
,
other
,
TypeMismatchKind
.
DIFFERENT_TYPE
)
for
i
,
(
a
,
b
)
in
enumerate
(
zip_longest
(
self
.
args
,
other
.
args
)):
for
i
,
(
a
,
b
)
in
enumerate
(
zip_longest
(
self
.
args
,
other
.
args
)):
if
a
is
None
and
self
.
variadic
or
b
is
None
and
other
.
variadic
:
if
a
is
None
and
self
.
variadic
or
b
is
None
and
other
.
variadic
:
continue
continue
...
@@ -210,13 +212,14 @@ class TypeOperator(BaseType, ABC):
...
@@ -210,13 +212,14 @@ class TypeOperator(BaseType, ABC):
other
.
args
.
append
(
a
)
other
.
args
.
append
(
a
)
continue
continue
else
:
else
:
raise
IncompatibleTypesError
(
f"Cannot unify
{
self
}
and
{
other
}
, not enough arguments"
)
from
transpiler.phases.typing.exceptions
import
ArgumentCountMismatchError
raise
ArgumentCountMismatchError
(
*
sorted
((
self
,
other
),
key
=
lambda
x
:
x
.
is_intermediary
))
if
isinstance
(
a
,
BaseType
)
and
isinstance
(
b
,
BaseType
):
if
isinstance
(
a
,
BaseType
)
and
isinstance
(
b
,
BaseType
):
a
.
unify
(
b
)
a
.
unify
(
b
)
else
:
else
:
if
a
!=
b
:
if
a
!=
b
:
raise
IncompatibleTypesError
(
f"Cannot unify
{
a
}
and
{
b
}
"
)
raise
TypeMismatchError
(
a
,
b
,
TypeMismatchKind
.
DIFFERENT_TYPE
)
def
contains_internal
(
self
,
other
:
"BaseType"
)
->
bool
:
def
contains_internal
(
self
,
other
:
"BaseType"
)
->
bool
:
return
any
(
arg
.
contains
(
other
)
for
arg
in
self
.
args
)
return
any
(
arg
.
contains
(
other
)
for
arg
in
self
.
args
)
...
@@ -259,6 +262,11 @@ class FunctionType(TypeOperator):
...
@@ -259,6 +262,11 @@ class FunctionType(TypeOperator):
is_python_func
:
bool
=
False
is_python_func
:
bool
=
False
python_func_used
:
bool
=
False
python_func_used
:
bool
=
False
def
__iter__
(
self
):
x
=
5
pass
return
iter
([
str
(
self
)])
def
__init__
(
self
,
args
:
List
[
BaseType
],
ret
:
BaseType
):
def
__init__
(
self
,
args
:
List
[
BaseType
],
ret
:
BaseType
):
super
().
__init__
([
ret
,
*
args
])
super
().
__init__
([
ret
,
*
args
])
...
...
trans/transpiler/phases/utils.py
View file @
57dfb3cf
...
@@ -2,11 +2,12 @@ import ast
...
@@ -2,11 +2,12 @@ import ast
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
transpiler.utils
import
UnsupportedNodeError
from
transpiler.utils
import
UnsupportedNodeError
,
highlight
class
NodeVisitorSeq
:
class
NodeVisitorSeq
:
def
visit
(
self
,
node
):
def
visit
(
self
,
node
):
TB
=
f"running type analysis on
{
highlight
(
node
)
}
"
"""Visit a node."""
"""Visit a node."""
if
type
(
node
)
==
list
:
if
type
(
node
)
==
list
:
for
n
in
node
:
for
n
in
node
:
...
...
trans/transpiler/utils.py
View file @
57dfb3cf
...
@@ -29,21 +29,22 @@ def highlight(code, full=False):
...
@@ -29,21 +29,22 @@ def highlight(code, full=False):
"""
"""
from
transpiler.phases.typing
import
BaseType
from
transpiler.phases.typing
import
BaseType
if
isinstance
(
code
,
ast
.
AST
):
if
isinstance
(
code
,
ast
.
AST
):
return
cf
.
italic_
darkGrey
(
f"[
{
type
(
code
).
__name__
}
] "
)
+
highlight
(
ast
.
unparse
(
code
))
return
cf
.
italic_
grey60
(
f"[
{
type
(
code
).
__name__
}
] "
)
+
highlight
(
ast
.
unparse
(
code
))
elif
isinstance
(
code
,
BaseType
):
elif
isinstance
(
code
,
BaseType
):
return
cf
.
italic_grey
5
0
(
f"[
{
type
(
code
).
__name__
}
] "
)
+
highlight
(
str
(
code
))
return
cf
.
italic_grey
6
0
(
f"[
{
type
(
code
).
__name__
}
] "
)
+
highlight
(
str
(
code
))
from
pygments
import
highlight
as
pyg_highlight
from
pygments
import
highlight
as
pyg_highlight
from
pygments.lexers
import
PythonLexer
from
pygments.lexers
import
get_lexer_by_name
from
pygments.formatters
import
TerminalFormatter
from
pygments.formatters
import
TerminalFormatter
items
=
pyg_highlight
(
code
,
PythonLexer
(),
TerminalFormatter
()).
replace
(
"
\
x1b
[39;49;00m"
,
"
\
x1b
[39m"
).
splitlines
()
lexer
=
get_lexer_by_name
(
"python"
,
stripnl
=
False
)
items
=
pyg_highlight
(
code
,
lexer
,
TerminalFormatter
()).
replace
(
"
\
x1b
[39;49;00m"
,
"
\
x1b
[39;24m"
)
if
full
:
if
full
:
return
"
\
n
"
.
join
(
items
)
return
items
items
=
items
.
splitlines
()
res
=
items
[
0
]
res
=
items
[
0
]
if
len
(
items
)
>
1
:
if
len
(
items
)
>
1
:
res
+=
cf
.
white
(
" [...]"
)
res
+=
cf
.
white
(
" [...]"
)
#return Back.LIGHTBLACK_EX + Fore.RESET + res + Back.RESET
return
cf
.
on_gray25
(
res
)
return
cf
.
on_gray30
(
res
)
def
compare_ast
(
node1
:
Union
[
ast
.
expr
,
list
[
ast
.
expr
]],
node2
:
Union
[
ast
.
expr
,
list
[
ast
.
expr
]])
->
bool
:
def
compare_ast
(
node1
:
Union
[
ast
.
expr
,
list
[
ast
.
expr
]],
node2
:
Union
[
ast
.
expr
,
list
[
ast
.
expr
]])
->
bool
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment