Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
T
typon-compiler
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
typon
typon-compiler
Commits
13764aba
Commit
13764aba
authored
Aug 13, 2023
by
Tom Niget
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add proper implementation of Compare nodes
parent
6374abcf
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
96 additions
and
31 deletions
+96
-31
trans/transpiler/phases/desugar_compare/__init__.py
trans/transpiler/phases/desugar_compare/__init__.py
+27
-0
trans/transpiler/phases/emit_cpp/expr.py
trans/transpiler/phases/emit_cpp/expr.py
+28
-15
trans/transpiler/phases/typing/exceptions.py
trans/transpiler/phases/typing/exceptions.py
+2
-2
trans/transpiler/phases/typing/expr.py
trans/transpiler/phases/typing/expr.py
+5
-6
trans/transpiler/phases/typing/types.py
trans/transpiler/phases/typing/types.py
+25
-7
trans/transpiler/phases/utils.py
trans/transpiler/phases/utils.py
+9
-1
No files found.
trans/transpiler/phases/desugar_compare/__init__.py
0 → 100644
View file @
13764aba
# coding: utf-8
import
ast
from
transpiler.phases.typing.expr
import
DUNDER
from
transpiler.phases.utils
import
make_lnd
from
transpiler.utils
import
linenodata
class
DesugarCompare
(
ast
.
NodeTransformer
):
def
visit_Compare
(
self
,
node
:
ast
.
Compare
):
res
=
ast
.
BoolOp
(
ast
.
And
(),
[],
**
linenodata
(
node
))
for
left
,
op
,
right
in
zip
([
node
.
left
]
+
node
.
comparators
,
node
.
ops
,
node
.
comparators
):
lnd
=
make_lnd
(
left
,
right
)
if
type
(
op
)
in
(
ast
.
In
,
ast
.
NotIn
):
left
,
right
=
right
,
left
call
=
ast
.
Call
(
ast
.
Attribute
(
left
,
f"__
{
DUNDER
[
type
(
op
)]
}
__"
,
**
lnd
),
[
right
],
[],
**
lnd
)
if
type
(
op
)
==
ast
.
NotIn
:
call
=
ast
.
UnaryOp
(
ast
.
Not
(),
call
,
**
lnd
)
res
.
values
.
append
(
call
)
if
len
(
res
.
values
)
==
1
:
return
res
.
values
[
0
]
return
res
trans/transpiler/phases/emit_cpp/expr.py
View file @
13764aba
...
...
@@ -4,6 +4,7 @@ from dataclasses import dataclass, field
from
typing
import
List
,
Iterable
from
transpiler.phases.typing.types
import
UserType
,
FunctionType
from
transpiler.phases.utils
import
make_lnd
from
transpiler.utils
import
compare_ast
,
linenodata
from
transpiler.consts
import
SYMBOLS
,
PRECEDENCE_LEVELS
from
transpiler.phases.emit_cpp
import
CoroutineMode
,
join
,
NodeVisitor
...
...
@@ -91,22 +92,33 @@ class ExpressionVisitor(NodeVisitor):
# res = "co_await " + res
yield
res
def
visit_Compare
(
self
,
node
:
ast
.
Compare
)
->
Iterable
[
str
]:
def
make_lnd
(
op1
,
op2
):
return
{
"lineno"
:
op1
.
lineno
,
"col_offset"
:
op1
.
col_offset
,
"end_lineno"
:
op2
.
end_lineno
,
"end_col_offset"
:
op2
.
end_col_offset
}
# def visit_Compare(self, node: ast.Compare) -> Iterable[str]:
# def make_lnd(op1, op2):
# return {
# "lineno": op1.lineno,
# "col_offset": op1.col_offset,
# "end_lineno": op2.end_lineno,
# "end_col_offset": op2.end_col_offset
# }
#
# operands = [node.left, *node.comparators]
# with self.prec_ctx("&&"):
# yield from self.visit_binary_operation(node.ops[0], operands[0], operands[1], make_lnd(operands[0], operands[1]))
# for (left, right), op in zip(zip(operands[1:], operands[2:]), node.ops[1:]):
# # TODO: cleaner code
# yield " && "
# yield from self.visit_binary_operation(op, left, right, make_lnd(left, right))
operands
=
[
node
.
left
,
*
node
.
comparators
]
with
self
.
prec_ctx
(
"&&"
):
yield
from
self
.
visit_binary_operation
(
node
.
ops
[
0
],
operands
[
0
],
operands
[
1
],
make_lnd
(
operands
[
0
],
operands
[
1
]))
for
(
left
,
right
),
op
in
zip
(
zip
(
operands
[
1
:],
operands
[
2
:]),
node
.
ops
[
1
:]):
# TODO: cleaner code
yield
" && "
yield
from
self
.
visit_binary_operation
(
op
,
left
,
right
,
make_lnd
(
left
,
right
))
def
visit_BoolOp
(
self
,
node
:
ast
.
BoolOp
)
->
Iterable
[
str
]:
cpp_op
=
{
ast
.
And
:
"&&"
,
ast
.
Or
:
"||"
}[
type
(
node
.
op
)]
with
self
.
prec_ctx
(
cpp_op
):
yield
from
self
.
visit_binary_operation
(
node
.
op
,
node
.
values
[
0
],
node
.
values
[
1
],
make_lnd
(
node
.
values
[
0
],
node
.
values
[
1
]))
for
left
,
right
in
zip
(
node
.
values
[
1
:],
node
.
values
[
2
:]):
yield
f"
{
cpp_op
}
"
yield
from
self
.
visit_binary_operation
(
node
.
op
,
left
,
right
,
make_lnd
(
left
,
right
))
def
visit_Call
(
self
,
node
:
ast
.
Call
)
->
Iterable
[
str
]:
# TODO
...
...
@@ -173,6 +185,7 @@ class ExpressionVisitor(NodeVisitor):
call
=
ast
.
Call
(
ast
.
Attribute
(
right
,
"__contains__"
,
**
lnd
),
[
left
],
[],
**
lnd
)
call
.
is_await
=
False
yield
from
self
.
visit_Call
(
call
)
print
(
call
.
func
.
type
)
return
op
=
SYMBOLS
[
type
(
op
)]
# TODO: handle precedence locally since only binops really need it
...
...
trans/transpiler/phases/typing/exceptions.py
View file @
13764aba
...
...
@@ -114,10 +114,10 @@ class ArgumentCountMismatchError(CompileError):
class
ProtocolMismatchError
(
CompileError
):
value
:
BaseType
protocol
:
BaseType
reason
:
Exception
reason
:
Exception
|
str
def
__str__
(
self
)
->
str
:
return
f"Protocol mismatch:
{
highlight
(
self
.
value
)
}
does not implement
{
highlight
(
self
.
protocol
)
}
"
return
f"Protocol mismatch:
{
str
(
self
.
value
)
}
does not implement
{
str
(
self
.
protocol
)
}
"
def
detail
(
self
,
last_node
:
ast
.
AST
=
None
)
->
str
:
return
f"""
...
...
trans/transpiler/phases/typing/expr.py
View file @
13764aba
...
...
@@ -8,6 +8,7 @@ from transpiler.phases.typing.common import ScoperVisitor
from
transpiler.phases.typing.types
import
BaseType
,
TupleType
,
TY_STR
,
TY_BOOL
,
TY_INT
,
\
TY_COMPLEX
,
TY_NONE
,
FunctionType
,
PyList
,
TypeVariable
,
PySet
,
TypeType
,
PyDict
,
Promise
,
PromiseKind
,
UserType
,
\
TY_SLICE
from
transpiler.utils
import
linenodata
DUNDER
=
{
ast
.
Eq
:
"eq"
,
...
...
@@ -30,6 +31,7 @@ DUNDER = {
ast
.
USub
:
"neg"
,
ast
.
UAdd
:
"pos"
,
ast
.
Invert
:
"invert"
,
ast
.
In
:
"contains"
,
}
class
ScoperExprVisitor
(
ScoperVisitor
):
...
...
@@ -94,13 +96,10 @@ class ScoperExprVisitor(ScoperVisitor):
obj
.
python_func_used
=
True
return
obj
.
type
def
visit_Compare
(
self
,
node
:
ast
.
Compare
)
->
BaseType
:
# todo:
self
.
visit
(
node
.
left
)
for
op
,
right
in
zip
(
node
.
ops
,
node
.
comparators
):
self
.
visit
(
right
)
def
visit_BoolOp
(
self
,
node
:
ast
.
BoolOp
)
->
BaseType
:
for
value
in
node
.
values
:
self
.
visit
(
value
)
return
TY_BOOL
#raise NotImplementedError(node)
def
visit_Call
(
self
,
node
:
ast
.
Call
)
->
BaseType
:
ftype
=
self
.
visit
(
node
.
func
)
...
...
trans/transpiler/phases/typing/types.py
View file @
13764aba
...
...
@@ -147,6 +147,7 @@ class TypeOperator(BaseType, ABC):
cls
.
gen_parents
=
[]
def
__post_init__
(
self
):
assert
all
(
x
is
not
None
for
x
in
self
.
args
)
if
self
.
name
is
None
:
self
.
name
=
self
.
__class__
.
__name__
for
name
,
factory
in
self
.
gen_methods
.
items
():
...
...
@@ -157,19 +158,30 @@ class TypeOperator(BaseType, ABC):
self
.
parents
.
append
(
gp
)
self
.
methods
=
{
**
gp
.
methods
,
**
self
.
methods
}
self
.
is_protocol
=
self
.
is_protocol
or
self
.
is_protocol_gen
self
.
_add_default_eq
()
def
_add_default_eq
(
self
):
if
"__eq__"
not
in
self
.
methods
:
if
"DEFAULT_EQ"
in
globals
():
self
.
methods
[
"__eq__"
]
=
DEFAULT_EQ
def
matches_protocol
(
self
,
protocol
:
"TypeOperator"
):
if
hash
(
protocol
)
in
self
.
match_cache
:
return
from
transpiler.phases.typing.exceptions
import
ProtocolMismatchError
,
TypeMismatchError
try
:
dupl
=
protocol
.
gen_sub
(
self
,
{
v
.
name
:
(
TypeVariable
(
v
.
name
)
if
isinstance
(
v
.
resolve
(),
TypeVariable
)
else
v
)
for
v
in
protocol
.
args
})
self
.
match_cache
.
add
(
hash
(
protocol
))
for
name
,
ty
in
dupl
.
methods
.
items
():
if
name
==
"__eq__"
:
continue
if
name
not
in
self
.
methods
:
raise
ProtocolMismatchError
(
self
,
protocol
,
f"missing method
{
name
}
"
)
corresp
=
self
.
methods
[
name
]
corresp
.
remove_self
().
unify
(
ty
.
remove_self
())
except
Exception
as
e
:
self
.
match_cache
.
remove
(
hash
(
protocol
))
from
transpiler.phases.typing.exceptions
import
ProtocolMismatchError
except
TypeMismatchError
as
e
:
if
hash
(
protocol
)
in
self
.
match_cache
:
self
.
match_cache
.
remove
(
hash
(
protocol
))
raise
ProtocolMismatchError
(
self
,
protocol
,
e
)
def
unify_internal
(
self
,
other
:
BaseType
):
...
...
@@ -331,18 +343,24 @@ class TypeType(TypeOperator):
self
.
args
[
0
]
=
value
TY_SELF
=
TypeOperator
.
make_type
(
"Self"
)
def
self_gen_sub
(
this
,
typevars
,
_
):
assert
this
is
not
None
return
this
TY_SELF
.
gen_sub
=
self_gen_sub
TY_BOOL
=
TypeOperator
.
make_type
(
"bool"
)
DEFAULT_EQ
=
FunctionType
([
TY_SELF
,
TY_SELF
],
TY_BOOL
)
TY_BOOL
.
_add_default_eq
()
TY_TYPE
=
TypeOperator
.
make_type
(
"type"
)
TY_INT
=
TypeOperator
.
make_type
(
"int"
)
TY_FLOAT
=
TypeOperator
.
make_type
(
"float"
)
TY_STR
=
TypeOperator
.
make_type
(
"str"
)
TY_BYTES
=
TypeOperator
.
make_type
(
"bytes"
)
TY_BOOL
=
TypeOperator
.
make_type
(
"bool"
)
TY_COMPLEX
=
TypeOperator
.
make_type
(
"complex"
)
TY_NONE
=
TypeOperator
.
make_type
(
"NoneType"
)
#TY_MODULE = TypeOperator([], "module")
TY_VARARG
=
TypeOperator
.
make_type
(
"vararg"
)
TY_SELF
=
TypeOperator
.
make_type
(
"Self"
)
TY_SELF
.
gen_sub
=
lambda
this
,
typevars
,
_
:
this
TY_SLICE
=
TypeOperator
.
make_type
(
"slice"
)
...
...
@@ -460,4 +478,4 @@ class UserType(TypeOperator):
class
UnionType
(
TypeOperator
):
def
__init__
(
self
,
*
args
:
List
[
BaseType
]):
super
().
__init__
(
args
,
"Union"
)
self
.
parents
.
extend
(
args
)
\ No newline at end of file
self
.
parents
.
extend
(
args
)
trans/transpiler/phases/utils.py
View file @
13764aba
...
...
@@ -43,4 +43,12 @@ class AnnotationName:
def
id
(
self
):
return
str
(
self
.
inner
)
AnnotationName
.
__name__
=
"Name"
\ No newline at end of file
AnnotationName
.
__name__
=
"Name"
def
make_lnd
(
op1
,
op2
):
return
{
"lineno"
:
op1
.
lineno
,
"col_offset"
:
op1
.
col_offset
,
"end_lineno"
:
op2
.
end_lineno
,
"end_col_offset"
:
op2
.
end_col_offset
}
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment