Commit ea4cd6e0 authored by Marius Wachtler's avatar Marius Wachtler

BoxIteratorGeneric: call PyIter_Next() lazily

We were fetching to eagerly which caused issues if the iterator got used for other stuff too
parent 14ea4c01
......@@ -23,22 +23,15 @@ class BoxIteratorGeneric : public BoxIteratorImpl {
private:
Box* iterator;
Box* value;
bool need_to_fetch_value;
public:
BoxIteratorGeneric(Box* container) : iterator(nullptr), value(nullptr) {
BoxIteratorGeneric(Box* container) : iterator(nullptr), value(nullptr), need_to_fetch_value(false) {
if (container) {
// TODO: this should probably call getPystonIter
iterator = getiter(container);
if (iterator) {
// try catch block to manually decref the iterator because if the constructor throwes the destructor
// won't get called
// but we should probably just change the code to not call next inside the constructor...
try {
next();
} catch (ExcInfo e) {
Py_CLEAR(iterator);
throw e;
}
need_to_fetch_value = true;
} else
*this = *end();
}
......@@ -50,21 +43,13 @@ public:
}
void next() override {
STAT_TIMER(t0, "us_timer_iteratorgeneric_next", 0);
assert(!value);
Box* next = PyIter_Next(iterator);
if (next) {
value = next;
} else {
if (PyErr_Occurred())
throwCAPIException();
Py_CLEAR(iterator);
*this = *end();
}
assert(!need_to_fetch_value);
need_to_fetch_value = true;
}
Box* getValue() override {
if (need_to_fetch_value)
fetchNextValue();
Box* r = value;
assert(r);
value = NULL;
......@@ -73,6 +58,9 @@ public:
bool isSame(const BoxIteratorImpl* _rhs) override {
const BoxIteratorGeneric* rhs = (const BoxIteratorGeneric*)_rhs;
assert(!rhs->need_to_fetch_value); // we can't fetch the value here because rhs is const
if (need_to_fetch_value)
fetchNextValue();
return iterator == rhs->iterator && value == rhs->value;
}
......@@ -87,6 +75,23 @@ public:
static BoxIteratorGeneric _end(nullptr);
return &_end;
}
private:
void fetchNextValue() {
STAT_TIMER(t0, "us_timer_iteratorgeneric_next", 0);
assert(!value);
assert(need_to_fetch_value);
Box* next = PyIter_Next(iterator);
need_to_fetch_value = false;
if (next) {
value = next;
} else {
if (PyErr_Occurred())
throwCAPIException();
Py_CLEAR(iterator);
*this = *end();
}
}
};
......
......@@ -89,7 +89,7 @@ except:
try:
test_helper.run_test(['sh', '-c', '. %s/bin/activate && python %s/numpy/tools/test-installed-numpy.py' % (ENV_DIR, ENV_DIR)],
ENV_NAME, [dict(ran=5781, errors=2, failures=2)])
ENV_NAME, [dict(ran=5781, errors=2, failures=1)])
finally:
if USE_CUSTOM_PATCHES:
print_progress_header("Unpatching NumPy...")
......
......@@ -4,3 +4,8 @@ print list(enumerate(range(100), sys.maxint-50))
# cycle collection:
print enumerate(range(100)).next()
it = iter(range(5))
e = enumerate(it)
print e.next(), e.next()
print list(it)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment