Commit 9dc8cec5 authored by Tom Niget's avatar Tom Niget

Fix first python interop

parent 7044ede9
...@@ -116,7 +116,53 @@ template <> struct std::hash<decltype(0_pi)> { ...@@ -116,7 +116,53 @@ template <> struct std::hash<decltype(0_pi)> {
namespace PYBIND11_NAMESPACE { namespace PYBIND11_NAMESPACE {
namespace detail { namespace detail {
template <> struct type_caster<decltype(0_pi)> : type_caster<int> {}; template <> struct type_caster<decltype(0_pi)> {
public:
/**
* This macro establishes the name 'inty' in
* function signatures and declares a local variable
* 'value' of type inty
*/
PYBIND11_TYPE_CASTER(decltype(0_pi), const_name("TyInt"));
/**
* Conversion part 1 (Python->C++): convert a PyObject into a inty
* instance or return false upon failure. The second argument
* indicates whether implicit conversions should be applied.
*/
bool load(handle src, bool) {
/* Extract PyObject from handle */
PyObject *source = src.ptr();
/* Try converting into a Python integer value */
PyObject *tmp = PyNumber_Long(source);
if (!tmp)
return false;
/* Now try to convert into a C++ int */
dot(value, value) = PyLong_AsLong(tmp);
Py_DECREF(tmp);
/* Ensure return code was OK (to avoid out-of-range errors etc) */
return !(dot(value, value) == -1 && !PyErr_Occurred());
}
/**
* Conversion part 2 (C++ -> Python): convert an inty instance into
* a Python object. The second and third arguments are used to
* indicate the return value policy and parent object (for
* ``return_value_policy::reference_internal``) and are generally
* ignored by implicit casters.
*/
static handle cast(auto src, return_value_policy /* policy */, handle /* parent */) {
return PyLong_FromLong(dot(src, value));
}
};
} // namespace detail } // namespace detail
} // namespace PYBIND11_NAMESPACE } // namespace PYBIND11_NAMESPACE
......
# test numpy interop # test numpy interop
from numpy import square #from numpy import square
import math import math
if __name__ == "__main__": if __name__ == "__main__":
x = [1, 2, 3, 4] x = [1, 2, 3, 4]
y: list[int] = square(x) # y: list[int] = square(x)
print(x, y) # print(x, y)
f: int = math.factorial(5) f: int = math.factorial(5)
print("5! =", f) print("5! =", f)
\ No newline at end of file
...@@ -3,7 +3,8 @@ from typing import Iterable ...@@ -3,7 +3,8 @@ from typing import Iterable
from transpiler.phases.emit_cpp.class_ import emit_class from transpiler.phases.emit_cpp.class_ import emit_class
from transpiler.phases.emit_cpp.function import emit_function from transpiler.phases.emit_cpp.function import emit_function
from transpiler.phases.typing.modules import ModuleType from transpiler.phases.emit_cpp.visitors import NodeVisitor
from transpiler.phases.typing.modules import ModuleType, TyponModuleType, PythonModuleType
from transpiler.phases.typing.types import CallableInstanceType, ClassTypeType, TypeVariable, BaseType, GenericType, \ from transpiler.phases.typing.types import CallableInstanceType, ClassTypeType, TypeVariable, BaseType, GenericType, \
GenericInstanceType, UserGenericType GenericInstanceType, UserGenericType
...@@ -12,15 +13,71 @@ def emit_module(mod: ModuleType) -> Iterable[str]: ...@@ -12,15 +13,71 @@ def emit_module(mod: ModuleType) -> Iterable[str]:
__TB_NODE__ = mod.block_data.node __TB_NODE__ = mod.block_data.node
yield "#include <python/builtins.hpp>" yield "#include <python/builtins.hpp>"
yield "#include <python/sys.hpp>" yield "#include <python/sys.hpp>"
emitted = set()
def emit(mod_obj: ModuleType):
if mod_obj in emitted:
return
emitted.add(mod_obj)
name = mod_obj.name()
match mod_obj:
case TyponModuleType():
yield f"#include <python/{name}.hpp>"
case PythonModuleType():
yield f"namespace py_{name} {{"
yield "template <typename _Unused = void>"
yield f"struct {name}__oo : referencemodel::moduletype<{name}__oo<>> {{"
for fname, obj in mod_obj.fields.items():
obj = obj.type.resolve()
if type(obj) is TypeVariable:
continue # unused python function
assert isinstance(obj, CallableInstanceType)
yield "struct : referencemodel::function {"
yield "auto operator()("
for i, argty in enumerate(obj.parameters):
if i != 0:
yield ", "
yield "lvalue_or_rvalue<"
yield from NodeVisitor().visit_BaseType(argty)
yield f"> arg{i}"
yield ") const {"
yield "InterpGuard guard{};"
yield "try {"
yield f"return py::module_::import(\"{name}\").attr(\"{fname}\")("
for i, argty in enumerate(obj.parameters):
if i != 0:
yield ", "
yield f"*arg{i}"
yield ").cast<"
yield from NodeVisitor().visit_BaseType(obj.return_type)
yield ">();"
yield "} catch (py::error_already_set& e) {"
yield 'std::cerr << "Python exception: " << e.what() << std::endl;'
yield "throw;"
yield "}"
yield "}"
yield f"}} static constexpr {fname} {{}};"
yield "};"
yield f"{name}__oo<> all;"
yield "}"
incl_vars = [] incl_vars = []
for node in mod.block_data.node.body: for node in mod.block_data.node.body:
match node: match node:
case ast.Import(names): case ast.Import(names):
for alias in names: for alias in names:
yield f"#include <python/{alias.name}.hpp>" yield from emit(alias.module_obj)
incl_vars.append(f"auto& {alias.asname or alias.name} = py_{alias.name}::all;") incl_vars.append(f"auto& {alias.asname or alias.name} = py_{alias.name}::all;")
case ast.ImportFrom(module, names, _): case ast.ImportFrom(module, names, _):
yield f"#include <python/{module}.hpp>" yield from emit(node.module_obj)
for alias in names: for alias in names:
incl_vars.append(f"auto& {alias.asname or alias.name} = py_{module}::all.{alias.name};") incl_vars.append(f"auto& {alias.asname or alias.name} = py_{module}::all.{alias.name};")
yield "namespace PROGRAMNS {" yield "namespace PROGRAMNS {"
......
...@@ -110,8 +110,8 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -110,8 +110,8 @@ class ScoperExprVisitor(ScoperVisitor):
ty = obj.type.resolve() ty = obj.type.resolve()
# if isinstance(ty, TypeType) and isinstance(ty.type_object, TypeVariable): # if isinstance(ty, TypeType) and isinstance(ty.type_object, TypeVariable):
# raise NameError(f"Use of type variable") # todo: when does this happen exactly? # raise NameError(f"Use of type variable") # todo: when does this happen exactly?
if getattr(ty, "is_python_func", False): # if getattr(ty, "is_python_func", False):
ty.python_func_used = True # ty.python_func_used = True
return ty return ty
def visit_BoolOp(self, node: ast.BoolOp) -> BaseType: def visit_BoolOp(self, node: ast.BoolOp) -> BaseType:
...@@ -149,6 +149,13 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -149,6 +149,13 @@ class ScoperExprVisitor(ScoperVisitor):
# assert isinstance(ftype, CallableInstanceType) TODO # assert isinstance(ftype, CallableInstanceType) TODO
if isinstance(ftype, TypeVariable) and ftype.python_func_placeholder:
ret = TypeVariable()
new_ftype = CallableInstanceType(arguments, ret)
new_ftype.is_native = True
ftype.unify(new_ftype)
return ret
if not isinstance(ftype, CallableInstanceType): if not isinstance(ftype, CallableInstanceType):
return TypeVariable() return TypeVariable()
...@@ -272,8 +279,8 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -272,8 +279,8 @@ class ScoperExprVisitor(ScoperVisitor):
# return meth # return meth
if field := ltype.fields.get(name): if field := ltype.fields.get(name):
ty = field.type.resolve() ty = field.type.resolve()
if getattr(ty, "is_python_func", False): # if getattr(ty, "is_python_func", False):
ty.python_func_used = True # ty.python_func_used = True
if isinstance(ty, MethodType): if isinstance(ty, MethodType):
if bound and field.in_class_def and type(field.val) != RuntimeValue: if bound and field.in_class_def and type(field.val) != RuntimeValue:
return ty.remove_self(ltype) return ty.remove_self(ltype)
......
import ast import ast
import importlib
from pathlib import Path from pathlib import Path
from logging import debug from logging import debug
from transpiler.phases.typing import PRELUDE from transpiler.phases.typing import PRELUDE
from transpiler.phases.typing.scope import Scope, VarKind, VarDecl, ScopeKind from transpiler.phases.typing.scope import Scope, VarKind, VarDecl, ScopeKind
from transpiler.phases.typing.types import MemberDef, ResolvedConcreteType, UniqueTypeMixin, BlockData from transpiler.phases.typing.types import MemberDef, ResolvedConcreteType, UniqueTypeMixin, BlockData, TypeVariable
class ModuleType(UniqueTypeMixin, ResolvedConcreteType): class ModuleType(UniqueTypeMixin, ResolvedConcreteType):
pass pass
def make_module(name: str, scope: Scope) -> ModuleType: class TyponModuleType(ModuleType):
class CreatedType(ModuleType): pass
class PythonModuleType(ModuleType):
pass
def make_module(name: str, scope: Scope) -> TyponModuleType:
class CreatedType(TyponModuleType):
def name(self): def name(self):
return name return name
ty = CreatedType() ty = CreatedType()
...@@ -36,7 +43,49 @@ def parse_module(mod_name: str, python_path: list[Path], scope=None, preprocess= ...@@ -36,7 +43,49 @@ def parse_module(mod_name: str, python_path: list[Path], scope=None, preprocess=
break break
else: else:
raise FileNotFoundError(f"Could not find {mod_name}") """
py_mod = importlib.import_module(name)
mod_scope = Scope()
# copy all functions to mod_scope
for fname, obj in py_mod.__dict__.items():
if callable(obj):
# fty = FunctionType([], TypeVariable())
# fty.is_python_func = True
fty = TypeVariable()
fty.is_python_func = True
mod_scope.vars[fname] = VarDecl(VarKind.LOCAL, fty)
mod = make_mod_decl(name, mod_scope)
mod.type.is_python = True
self.scope.vars[name] = mod
"""
try:
py_mod = importlib.import_module(mod_name)
except ModuleNotFoundError:
raise FileNotFoundError(f"Could not find {mod_name}")
else:
if mod := visited_modules.get(py_mod):
return mod.type
try:
class OurModule(PythonModuleType):
def name(self):
return mod_name
mod = OurModule()
# copy all functions to mod_scope
for fname, obj in py_mod.__dict__.items():
if callable(obj):
# fty = FunctionType([], TypeVariable())
# fty.is_python_func = True
fty = TypeVariable(python_func_placeholder = True)
#fty.is_python_func = True
mod.fields[fname] = MemberDef(fty)
visited_modules[py_mod] = VarDecl(VarKind.LOCAL, mod)
return mod
except:
raise NotImplementedError(f"Could not process python module {mod_name}")
if path.is_dir(): if path.is_dir():
path = path / "__init__.py" path = path / "__init__.py"
......
...@@ -120,6 +120,7 @@ class TypeVariable(ConcreteType): ...@@ -120,6 +120,7 @@ class TypeVariable(ConcreteType):
resolved: Optional[ConcreteType] = None resolved: Optional[ConcreteType] = None
emit_as_is: bool = False emit_as_is: bool = False
decltype_str: Optional[str] = None decltype_str: Optional[str] = None
python_func_placeholder: bool = False
def resolve(self) -> ConcreteType: def resolve(self) -> ConcreteType:
if self.resolved is None: if self.resolved is None:
...@@ -577,6 +578,10 @@ class CallableInstanceType(GenericInstanceType, MethodType): ...@@ -577,6 +578,10 @@ class CallableInstanceType(GenericInstanceType, MethodType):
def __post_init__(self): def __post_init__(self):
if self.optional_at is None and self.parameters is not None: if self.optional_at is None and self.parameters is not None:
self.optional_at = len(self.parameters) self.optional_at = len(self.parameters)
if not hasattr(self, "generic_args") and self.parameters is not None:
self.generic_args = [*self.parameters, self.return_type]
if not hasattr(self, "generic_parent"):
self.generic_parent = None
def remove_self(self, self_type): def remove_self(self, self_type):
assert self.parameters[0].try_assign(self_type) assert self.parameters[0].try_assign(self_type)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment