Commit a63f48d7 authored by Kevin Modzelewski's avatar Kevin Modzelewski

I guess we need to be checking __new__ and __init__ for descriptors

parent 6c340adf
......@@ -601,7 +601,7 @@ Box* Box::getattr(const std::string& attr, GetattrRewriteArgs* rewrite_args) {
}
// TODO should centralize all of these:
static const std::string _call_str("__call__"), _new_str("__new__"), _init_str("__init__");
static const std::string _call_str("__call__"), _new_str("__new__"), _init_str("__init__"), _get_str("__get__");
void Box::setattr(const std::string& attr, Box* val, SetattrRewriteArgs* rewrite_args) {
assert(cls->instancesHaveAttrs());
......@@ -951,10 +951,14 @@ return gotten;
// Does a simple call of the descriptor's __get__ if it exists;
// this function is useful for custom getattribute implementations that already know whether the descriptor
// came from the class or not.
Box* processDescriptorOrNull(Box* obj, Box* inst, Box* owner) {
Box* descr_r = callattrInternal(obj, &_get_str, LookupScope::CLASS_ONLY, NULL, ArgPassSpec(2), inst, owner, NULL,
NULL, NULL);
return descr_r;
}
Box* processDescriptor(Box* obj, Box* inst, Box* owner) {
static const std::string get_str("__get__");
Box* descr_r
= callattrInternal(obj, &get_str, LookupScope::CLASS_ONLY, NULL, ArgPassSpec(2), inst, owner, NULL, NULL, NULL);
Box* descr_r = processDescriptorOrNull(obj, inst, owner);
if (descr_r)
return descr_r;
return obj;
......@@ -1088,14 +1092,14 @@ Box* getattrInternalGeneral(Box* obj, const std::string& attr, GetattrRewriteArg
= r_descr.getAttr(BOX_CLS_OFFSET, RewriterVarUsage::NoKill, Location::any());
GetattrRewriteArgs grewrite_args(rewrite_args->rewriter, std::move(r_descr_cls), Location::any(),
false);
_get_ = typeLookup(descr->cls, "__get__", &grewrite_args);
_get_ = typeLookup(descr->cls, _get_str, &grewrite_args);
if (!grewrite_args.out_success) {
rewrite_args = NULL;
} else if (_get_) {
r_get = std::move(grewrite_args.out_rtn);
}
} else {
_get_ = typeLookup(descr->cls, "__get__", NULL);
_get_ = typeLookup(descr->cls, _get_str, NULL);
}
// As an optimization, don't check for __set__ if we're in cls_only mode, since it won't matter.
......@@ -1223,14 +1227,14 @@ Box* getattrInternalGeneral(Box* obj, const std::string& attr, GetattrRewriteArg
= r_val.getAttr(BOX_CLS_OFFSET, RewriterVarUsage::NoKill, Location::any());
GetattrRewriteArgs grewrite_args(rewrite_args->rewriter, std::move(r_val_cls), Location::any(),
false);
local_get = typeLookup(val->cls, "__get__", &grewrite_args);
local_get = typeLookup(val->cls, _get_str, &grewrite_args);
if (!grewrite_args.out_success) {
rewrite_args = NULL;
} else if (local_get) {
r_get = std::move(grewrite_args.out_rtn);
}
} else {
local_get = typeLookup(val->cls, "__get__", NULL);
local_get = typeLookup(val->cls, _get_str, NULL);
}
// Call __get__(val, None, obj)
......@@ -3419,8 +3423,18 @@ Box* typeCallInternal(BoxedFunction* f, CallRewriteArgs* rewrite_args, ArgPassSp
r_new.addGuard((intptr_t)new_attr);
}
}
// Special-case functions to allow them to still rewrite:
if (new_attr->cls != function_cls) {
Box* descr_r = processDescriptorOrNull(new_attr, None, cls);
if (descr_r) {
new_attr = descr_r;
rewrite_args = NULL;
}
}
} else {
new_attr = typeLookup(cls, _new_str, NULL);
new_attr = processDescriptor(new_attr, None, cls);
}
assert(new_attr && "This should always resolve");
......@@ -3492,8 +3506,12 @@ Box* typeCallInternal(BoxedFunction* f, CallRewriteArgs* rewrite_args, ArgPassSp
getNameOfClass(cls)->c_str());
if (init_attr && init_attr != typeLookup(object_cls, _init_str, NULL)) {
// TODO apply the same descriptor special-casing as in callattr?
Box* initrtn;
if (rewrite_args) {
// Attempt to rewrite the basic case:
if (rewrite_args && init_attr->cls == function_cls) {
// Note: this code path includes the descriptor logic
CallRewriteArgs srewrite_args(rewrite_args->rewriter, std::move(r_init), rewrite_args->destination, false);
if (npassed_args >= 1)
srewrite_args.arg1 = r_made.addUse();
......@@ -3517,9 +3535,19 @@ Box* typeCallInternal(BoxedFunction* f, CallRewriteArgs* rewrite_args, ArgPassSp
.setDoneUsing();
}
} else {
// initrtn = callattrInternal(cls, &_init_str, INST_ONLY, NULL, argspec, made, arg2, arg3, args,
// keyword_names);
initrtn = runtimeCallInternal(init_attr, NULL, argspec, made, arg2, arg3, args, keyword_names);
init_attr = processDescriptor(init_attr, made, cls);
ArgPassSpec init_argspec = argspec;
init_argspec.num_args--;
int passed = init_argspec.totalPassed();
// If we weren't passed the args array, it's not safe to index into it
if (passed <= 2)
initrtn = runtimeCallInternal(init_attr, NULL, init_argspec, arg2, arg3, NULL, NULL, keyword_names);
else
initrtn
= runtimeCallInternal(init_attr, NULL, init_argspec, arg2, arg3, args[0], &args[1], keyword_names);
}
assertInitNone(initrtn);
} else {
......
......@@ -2,26 +2,74 @@
# Descriptors get processed when fetched as part of a dunder lookup
class D(object):
def __init__(self, n):
self.n = n
def __get__(self, obj, cls):
print "__get__()", obj is None, self.n
def desc(*args):
print "desc()", len(args)
def f1():
class D(object):
def __init__(self, n):
self.n = n
def __get__(self, obj, cls):
print "__get__()", obj is None, self.n
def desc(*args):
print "desc()", len(args)
return self.n
return desc
def __call__(self):
print "D.call"
return self.n
return desc
def __call__(self):
print "D.call"
return self.n
class C(object):
__hash__ = D(1)
__add__ = D(2)
__init__ = D(None)
print C.__init__()
c = C()
print C.__hash__()
print c.__hash__()
print hash(c)
print c + c
f1()
def f2():
print "\nf2"
class D(object):
def __call__(self, subcl):
print "call", subcl
return object.__new__(subcl)
def get(self, inst, owner):
print "__get__", inst, owner
def new(self):
print "new"
return object.__new__(owner)
return new
class C(object):
__new__ = D()
print type(C())
D.__get__ = get
print type(C())
f2()
def f3():
print "\nf3"
class D(object):
def __call__(self):
print "call"
return None
def get(self, inst, owner):
print "__get__", type(inst), owner
def init():
print "init"
return None
return init
class C(object):
__hash__ = D(1)
__add__ = D(2)
class C(object):
__init__ = D()
c = C()
print C.__hash__()
print c.__hash__()
print hash(c)
print c + c
print type(C())
D.__get__ = get
print type(C())
f3()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment