Commit ea4cd6e0 authored by Marius Wachtler's avatar Marius Wachtler

BoxIteratorGeneric: call PyIter_Next() lazily

We were fetching to eagerly which caused issues if the iterator got used for other stuff too
parent 14ea4c01
...@@ -23,22 +23,15 @@ class BoxIteratorGeneric : public BoxIteratorImpl { ...@@ -23,22 +23,15 @@ class BoxIteratorGeneric : public BoxIteratorImpl {
private: private:
Box* iterator; Box* iterator;
Box* value; Box* value;
bool need_to_fetch_value;
public: public:
BoxIteratorGeneric(Box* container) : iterator(nullptr), value(nullptr) { BoxIteratorGeneric(Box* container) : iterator(nullptr), value(nullptr), need_to_fetch_value(false) {
if (container) { if (container) {
// TODO: this should probably call getPystonIter // TODO: this should probably call getPystonIter
iterator = getiter(container); iterator = getiter(container);
if (iterator) { if (iterator) {
// try catch block to manually decref the iterator because if the constructor throwes the destructor need_to_fetch_value = true;
// won't get called
// but we should probably just change the code to not call next inside the constructor...
try {
next();
} catch (ExcInfo e) {
Py_CLEAR(iterator);
throw e;
}
} else } else
*this = *end(); *this = *end();
} }
...@@ -50,21 +43,13 @@ public: ...@@ -50,21 +43,13 @@ public:
} }
void next() override { void next() override {
STAT_TIMER(t0, "us_timer_iteratorgeneric_next", 0); assert(!need_to_fetch_value);
assert(!value); need_to_fetch_value = true;
Box* next = PyIter_Next(iterator);
if (next) {
value = next;
} else {
if (PyErr_Occurred())
throwCAPIException();
Py_CLEAR(iterator);
*this = *end();
}
} }
Box* getValue() override { Box* getValue() override {
if (need_to_fetch_value)
fetchNextValue();
Box* r = value; Box* r = value;
assert(r); assert(r);
value = NULL; value = NULL;
...@@ -73,6 +58,9 @@ public: ...@@ -73,6 +58,9 @@ public:
bool isSame(const BoxIteratorImpl* _rhs) override { bool isSame(const BoxIteratorImpl* _rhs) override {
const BoxIteratorGeneric* rhs = (const BoxIteratorGeneric*)_rhs; const BoxIteratorGeneric* rhs = (const BoxIteratorGeneric*)_rhs;
assert(!rhs->need_to_fetch_value); // we can't fetch the value here because rhs is const
if (need_to_fetch_value)
fetchNextValue();
return iterator == rhs->iterator && value == rhs->value; return iterator == rhs->iterator && value == rhs->value;
} }
...@@ -87,6 +75,23 @@ public: ...@@ -87,6 +75,23 @@ public:
static BoxIteratorGeneric _end(nullptr); static BoxIteratorGeneric _end(nullptr);
return &_end; return &_end;
} }
private:
void fetchNextValue() {
STAT_TIMER(t0, "us_timer_iteratorgeneric_next", 0);
assert(!value);
assert(need_to_fetch_value);
Box* next = PyIter_Next(iterator);
need_to_fetch_value = false;
if (next) {
value = next;
} else {
if (PyErr_Occurred())
throwCAPIException();
Py_CLEAR(iterator);
*this = *end();
}
}
}; };
......
...@@ -89,7 +89,7 @@ except: ...@@ -89,7 +89,7 @@ except:
try: try:
test_helper.run_test(['sh', '-c', '. %s/bin/activate && python %s/numpy/tools/test-installed-numpy.py' % (ENV_DIR, ENV_DIR)], test_helper.run_test(['sh', '-c', '. %s/bin/activate && python %s/numpy/tools/test-installed-numpy.py' % (ENV_DIR, ENV_DIR)],
ENV_NAME, [dict(ran=5781, errors=2, failures=2)]) ENV_NAME, [dict(ran=5781, errors=2, failures=1)])
finally: finally:
if USE_CUSTOM_PATCHES: if USE_CUSTOM_PATCHES:
print_progress_header("Unpatching NumPy...") print_progress_header("Unpatching NumPy...")
......
...@@ -4,3 +4,8 @@ print list(enumerate(range(100), sys.maxint-50)) ...@@ -4,3 +4,8 @@ print list(enumerate(range(100), sys.maxint-50))
# cycle collection: # cycle collection:
print enumerate(range(100)).next() print enumerate(range(100)).next()
it = iter(range(5))
e = enumerate(it)
print e.next(), e.next()
print list(it)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment