Commit 85b27174 authored by Tom Niget's avatar Tom Niget

Add preliminary support for list comprehensions

parent 00f0474f
...@@ -294,4 +294,28 @@ template <class T> auto begin(std::shared_ptr<T> &obj) { return dotp(obj, begin) ...@@ -294,4 +294,28 @@ template <class T> auto begin(std::shared_ptr<T> &obj) { return dotp(obj, begin)
template <class T> auto end(std::shared_ptr<T> &obj) { return dotp(obj, end)(); } template <class T> auto end(std::shared_ptr<T> &obj) { return dotp(obj, end)(); }
} }
template <typename T>
struct AlwaysTrue { // (1)
constexpr bool operator()(const T&) const {
return true;
}
};
template <typename Seq>
struct ValueTypeEx {
using type = decltype(*std::begin(std::declval<Seq&>()));
};
// (2)
template <typename Map, typename Seq, typename Filt = AlwaysTrue<typename ValueTypeEx<Seq>::type>>
auto mapFilter(Map map, Seq seq, Filt filt = Filt()) {
//typedef typename Seq::value_type value_type;
using value_type = typename ValueTypeEx<Seq>::type;
using return_type = decltype(map(std::declval<value_type>()));
std::vector<return_type> result{};
for (auto i : seq | std::views::filter(filt)
| std::views::transform(map)) result.push_back(i);
return typon::PyList(std::move(result));
}
#endif // TYPON_BUILTINS_HPP #endif // TYPON_BUILTINS_HPP
...@@ -15,6 +15,7 @@ namespace typon { ...@@ -15,6 +15,7 @@ namespace typon {
template <typename T> class PyList { template <typename T> class PyList {
public: public:
using value_type = T;
PyList(std::shared_ptr<std::vector<T>> &&v) : _v(std::move(v)) {} PyList(std::shared_ptr<std::vector<T>> &&v) : _v(std::move(v)) {}
PyList(std::vector<T> &&v) PyList(std::vector<T> &&v)
: _v(std::move(std::make_shared<std::vector<T>>(std::move(v)))) {} : _v(std::move(std::make_shared<std::vector<T>>(std::move(v)))) {}
......
...@@ -17,6 +17,7 @@ class int: ...@@ -17,6 +17,7 @@ class int:
def __init__(self, x: str) -> None: ... def __init__(self, x: str) -> None: ...
def __lt__(self, other: Self) -> bool: ... def __lt__(self, other: Self) -> bool: ...
def __gt__(self, other: Self) -> bool: ... def __gt__(self, other: Self) -> bool: ...
def __mod__(self, other: Self) -> Self: ...
assert int.__add__ assert int.__add__
...@@ -71,6 +72,7 @@ class list(Generic[U]): ...@@ -71,6 +72,7 @@ class list(Generic[U]):
def __len__(self) -> int: ... def __len__(self) -> int: ...
def append(self, value: U) -> None: ... def append(self, value: U) -> None: ...
def __contains__(self, item: U) -> bool: ... def __contains__(self, item: U) -> bool: ...
def __init__(self, it: Iterator[U]) -> None: ...
assert [1, 2].__iter__() assert [1, 2].__iter__()
assert list[int].__iter__ assert list[int].__iter__
......
import sys import sys
import math import math
def gàé():
return 1,2,3
if __name__ == "__main__": if __name__ == "__main__":
if True: a = [n for n in range(10)]
a, b, c = gàé() # abc b = [x for x in a if x % 2 == 0]
\ No newline at end of file c = [y * y for y in b]
print(a, b, c)
\ No newline at end of file
from dataclasses import dataclass
from typing import Any, Callable
from enum import Enum
from itertools import groupby
import operator
import string
@dataclass
class BinOperator:
symbol: str
priority: int
perform: Callable[[float, float], float]
OPERATORS = [
BinOperator("+", 0, operator.add),
BinOperator("-", 0, operator.sub),
BinOperator("*", 1, operator.mul),
BinOperator("/", 1, operator.truediv)
]
ops_by_priority = [list(it) for _, it in groupby(OPERATORS, lambda op: op.priority)]
MAX_PRIORITY = len(ops_by_priority)
ops_syms = [op.symbol for op in OPERATORS]
class TokenType(Enum):
NUMBER = 1
PARENTHESIS = 2
OPERATION = 3
@dataclass
class Token:
type: TokenType
val: Any
def tokenize(inp: str):
tokens = []
index = 0
def skip_spaces():
nonlocal index
while inp[index].isspace():
index += 1
def has():
return index < len(inp)
def peek():
return inp[index]
def read():
nonlocal index
index += 1
return inp[index - 1]
def read_number():
res = ""
while True:
res += read()
if not has() or peek() not in "0123456789.":
break
return Token(TokenType.NUMBER, float(res) if "." in res else int(res))
while has():
skip_spaces()
next = peek()
if next in ops_syms:
tok = Token(TokenType.OPERATION, read())
elif next in "()":
tok = Token(TokenType.PARENTHESIS, read())
elif next in "0123456789.":
tok = read_number()
else:
raise Exception(f"invalid character '{next}'", index)
tokens.append(tok)
return tokens
def parse(tokens):
index = 0
def has():
return index < len(tokens)
def current():
if not has():
raise Exception("expected token, got EOL")
return tokens[index]
def match(type: TokenType, val: Any = None):
return has() and tokens[index].type == type and (val is None or tokens[index].val == val)
def accept(type: TokenType, val: Any = None):
nonlocal index
if match(type, val):
index += 1
return True
return False
def expect(type: TokenType, val: Any = None):
nonlocal index
if match(type, val):
index += 1
return tokens[index - 1]
if not has():
raise Exception(f"expected {type}, got EOL")
else:
raise Exception(f"expected {type}, got {current().type}")
def parse_bin(priority=0):
if priority >= MAX_PRIORITY:
return parse_term()
left = parse_bin(priority + 1)
ops = ops_by_priority[priority]
while has() and current().type == TokenType.OPERATION:
for op in ops:
if accept(TokenType.OPERATION, op.symbol):
right = parse_bin(priority + 1)
left = op.perform(left, right)
break
else:
break
return left
def parse_term():
token = current()
if token.type == TokenType.NUMBER:
return expect(TokenType.NUMBER).val
elif accept(TokenType.PARENTHESIS, "("):
val = parse_expr()
expect(TokenType.PARENTHESIS, ")")
return val
else:
raise Exception(f"expected term, got {token.type}")
def parse_expr():
return parse_bin()
return parse_expr()
if __name__ == "__main__":
while True:
inp = input("> ")
try:
tok = tokenize(inp)
res = parse(tok)
print(res)
except Exception as e:
print(e)
print()
...@@ -110,6 +110,9 @@ class ExpressionVisitor(NodeVisitor): ...@@ -110,6 +110,9 @@ class ExpressionVisitor(NodeVisitor):
# yield from self.visit_binary_operation(op, left, right, make_lnd(left, right)) # yield from self.visit_binary_operation(op, left, right, make_lnd(left, right))
def visit_BoolOp(self, node: ast.BoolOp) -> Iterable[str]: def visit_BoolOp(self, node: ast.BoolOp) -> Iterable[str]:
if len(node.values) == 1:
yield from self.visit(node.values[0])
return
cpp_op = { cpp_op = {
ast.And: "&&", ast.And: "&&",
ast.Or: "||" ast.Or: "||"
...@@ -297,3 +300,34 @@ class ExpressionVisitor(NodeVisitor): ...@@ -297,3 +300,34 @@ class ExpressionVisitor(NodeVisitor):
# raise NotImplementedError(node) # raise NotImplementedError(node)
yield "co_yield" yield "co_yield"
yield from self.prec("co_yield").visit(node.value) yield from self.prec("co_yield").visit(node.value)
def visit_ListComp(self, node: ast.ListComp) -> Iterable[str]:
if len(node.generators) != 1:
raise NotImplementedError("Multiple generators not handled yet")
gen: ast.comprehension = node.generators[0]
yield "mapFilter([]("
yield from self.visit(node.input_item_type)
yield from self.visit(gen.target)
yield ") { return "
yield from self.visit(node.elt)
yield "; }, "
yield from self.visit(gen.iter)
if gen.ifs:
yield ", "
yield "[]("
yield from self.visit(node.input_item_type)
yield from self.visit(gen.target)
yield ") { return "
yield from self.visit(gen.ifs_node)
yield "; }"
yield ")"
# iter_type = get_iter(self.visit(gen.iter))
# next_type = get_next(iter_type)
# virt_scope = self.scope.child(ScopeKind.FUNCTION_INNER)
# from transpiler import ScoperBlockVisitor
# visitor = ScoperBlockVisitor(virt_scope)
# visitor.visit_assign_target(gen.target, next_type)
# res_item_type = visitor.expr().visit(node.elt)
# for if_ in gen.ifs:
# visitor.expr().visit(if_)
# return PyList(res_item_type)
\ No newline at end of file
...@@ -63,3 +63,18 @@ class ScoperVisitor(NodeVisitorSeq): ...@@ -63,3 +63,18 @@ class ScoperVisitor(NodeVisitorSeq):
if not node.inner_scope.has_return: if not node.inner_scope.has_return:
rtype.unify(TY_NONE) # todo: properly indicate missing return rtype.unify(TY_NONE) # todo: properly indicate missing return
def get_iter(seq_type):
try:
iter_type = seq_type.methods["__iter__"].return_type
except:
from transpiler.phases.typing.exceptions import NotIterableError
raise NotIterableError(seq_type)
return iter_type
def get_next(iter_type):
try:
next_type = iter_type.methods["__next__"].return_type
except:
from transpiler.phases.typing.exceptions import NotIteratorError
raise NotIteratorError(iter_type)
return next_type
\ No newline at end of file
...@@ -4,7 +4,7 @@ import inspect ...@@ -4,7 +4,7 @@ import inspect
from typing import List from typing import List
from transpiler.phases.typing import ScopeKind, VarDecl, VarKind from transpiler.phases.typing import ScopeKind, VarDecl, VarKind
from transpiler.phases.typing.common import ScoperVisitor from transpiler.phases.typing.common import ScoperVisitor, get_iter, get_next
from transpiler.phases.typing.types import BaseType, TupleType, TY_STR, TY_BOOL, TY_INT, \ from transpiler.phases.typing.types import BaseType, TupleType, TY_STR, TY_BOOL, TY_INT, \
TY_COMPLEX, TY_NONE, FunctionType, PyList, TypeVariable, PySet, TypeType, PyDict, Promise, PromiseKind, UserType, \ TY_COMPLEX, TY_NONE, FunctionType, PyList, TypeVariable, PySet, TypeType, PyDict, Promise, PromiseKind, UserType, \
TY_SLICE TY_SLICE
...@@ -249,3 +249,19 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -249,3 +249,19 @@ class ScoperExprVisitor(ScoperVisitor):
self.visit_getattr(TypeType(args[0]), f"__{name}__"), self.visit_getattr(TypeType(args[0]), f"__{name}__"),
args args
) )
def visit_ListComp(self, node: ast.ListComp) -> BaseType:
if len(node.generators) != 1:
raise NotImplementedError("Multiple generators not handled yet")
gen: ast.comprehension = node.generators[0]
iter_type = get_iter(self.visit(gen.iter))
node.input_item_type = get_next(iter_type)
virt_scope = self.scope.child(ScopeKind.FUNCTION_INNER)
from transpiler import ScoperBlockVisitor
visitor = ScoperBlockVisitor(virt_scope)
visitor.visit_assign_target(gen.target, node.input_item_type)
node.item_type = visitor.expr().visit(node.elt)
for if_ in gen.ifs:
visitor.expr().visit(if_)
gen.ifs_node = ast.BoolOp(ast.And(), gen.ifs, **linenodata(node))
return PyList(node.item_type)
\ No newline at end of file
...@@ -218,7 +218,7 @@ class TypeOperator(BaseType, ABC): ...@@ -218,7 +218,7 @@ class TypeOperator(BaseType, ABC):
if other.is_protocol and not self.is_protocol: if other.is_protocol and not self.is_protocol:
return other.unify_internal(self) return other.unify_internal(self)
if self.is_protocol and not other.is_protocol: if self.is_protocol and not other.is_protocol:
return other.matches_protocol(self) return other.matches_protocol(self) # TODO: doesn't print the correct type in the error message
if len(self.args) < len(other.args): if len(self.args) < len(other.args):
return other.unify_internal(self) return other.unify_internal(self)
assert self.is_protocol == other.is_protocol assert self.is_protocol == other.is_protocol
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment