Commit a002ad59, authored by Mark Florisson and committed by Dag Sverre Seljebotn

Fix openmp pxd & support num_threads clause

parent 8dcd1394
......@@ -5864,6 +5864,10 @@ class ParallelStatNode(StatNode, ParallelNode):
privatization_insertion_point a code insertion point used to make temps
private (esp. the "nsteps" temp)
args tuple the arguments passed to the parallel construct
kwargs DictNode the keyword arguments passed to the parallel
construct (replaced by its compile time value)
"""
child_attrs = ['body']
......@@ -5888,6 +5892,17 @@ class ParallelStatNode(StatNode, ParallelNode):
def analyse_declarations(self, env):
    """
    Analyse the declarations of the body and unpack the keyword arguments
    passed to the parallel construct.

    self.kwargs is replaced by its compile time value (an empty dict when
    no keyword arguments were given).  Every keyword that is listed in
    self.valid_keyword_arguments is stored as an attribute of the same
    name on this node; any other keyword is reported as an error at
    self.pos.
    """
    self.body.analyse_declarations(env)

    if self.kwargs:
        self.kwargs = self.kwargs.compile_time_value(env)
    else:
        self.kwargs = {}

    # items() rather than iteritems(): the latter does not exist on
    # Python 3 dicts, while items() works on both Python 2 and Python 3.
    for kw, val in self.kwargs.items():
        if kw not in self.valid_keyword_arguments:
            error(self.pos, "Invalid keyword argument: %s" % kw)
        else:
            setattr(self, kw, val)
def analyse_expressions(self, env):
self.body.analyse_expressions(env)
self.analyse_sharing_attributes(env)
......@@ -6044,6 +6059,17 @@ class ParallelStatNode(StatNode, ParallelNode):
code.putln("%s = %s;" % (entry.cname,
entry.type.cast_code(invalid_value)))
def put_num_threads(self, code):
    """
    Write the OpenMP num_threads clause if self.num_threads is set.

    code    the code writer; the clause is appended through code.put()

    Nothing is written when self.num_threads is None.  A value that is
    not an integer, or that is not greater than 0, is reported as an
    error at self.pos instead of being written.
    """
    # Imported locally so that this method does not rely on the imports
    # at the top of the module.
    import numbers

    if self.num_threads is None:
        return

    # numbers.Integral matches int on Python 3 and both int and long on
    # Python 2; the bare name 'long' is not defined on Python 3.
    if not isinstance(self.num_threads, numbers.Integral):
        error(self.pos, "Invalid value for num_threads argument, "
                        "expected an int")
    elif self.num_threads <= 0:
        # OpenMP requires the num_threads value to be a positive integer.
        error(self.pos, "argument to num_threads must be greater than 0")
    else:
        code.put(" num_threads(%d)" % (self.num_threads,))
def declare_closure_privates(self, code):
"""
Set self.privates to a dict mapping C variable names that are to be
......@@ -6081,6 +6107,16 @@ class ParallelWithBlockNode(ParallelStatNode):
nogil_check = None
valid_keyword_arguments = ['num_threads']
num_threads = None
def analyse_declarations(self, env):
    """
    Run the base class declaration analysis, then reject positional
    arguments: if self.args is non-empty an error is reported at self.pos.
    """
    super(ParallelWithBlockNode, self).analyse_declarations(env)

    if not self.args:
        return

    message = ("cython.parallel.parallel() does not take "
               "positional arguments")
    error(self.pos, message)
def generate_execution_code(self, code):
self.declare_closure_privates(code)
......@@ -6092,8 +6128,9 @@ class ParallelWithBlockNode(ParallelStatNode):
'private(%s)' % ', '.join([e.cname for e in self.privates]))
self.privatization_insertion_point = code.insertion_point()
self.put_num_threads(code)
code.putln("")
code.putln("#endif /* _OPENMP */")
code.begin_block()
......@@ -6110,11 +6147,6 @@ class ParallelRangeNode(ParallelStatNode):
target NameNode the target iteration variable
else_clause Node or None the else clause of this loop
args tuple the arguments passed to prange()
kwargs DictNode the keyword arguments passed to prange()
(replaced by its compile time value)
is_nogil bool indicates whether this is a nogil prange() node
"""
child_attrs = ['body', 'target', 'else_clause', 'args']
......@@ -6124,7 +6156,12 @@ class ParallelRangeNode(ParallelStatNode):
start = stop = step = None
is_prange = True
is_nogil = False
nogil = False
schedule = None
num_threads = None
valid_keyword_arguments = ['schedule', 'nogil', 'num_threads']
def analyse_declarations(self, env):
super(ParallelRangeNode, self).analyse_declarations(env)
......@@ -6143,14 +6180,6 @@ class ParallelRangeNode(ParallelStatNode):
else:
self.start, self.stop, self.step = self.args
if self.kwargs:
self.kwargs = self.kwargs.compile_time_value(env)
else:
self.kwargs = {}
self.is_nogil = self.kwargs.pop('nogil', False)
self.schedule = self.kwargs.pop('schedule', None)
if hasattr(self.schedule, 'decode'):
self.schedule = self.schedule.decode('ascii')
......@@ -6159,9 +6188,6 @@ class ParallelRangeNode(ParallelStatNode):
error(self.pos, "Invalid schedule argument to prange: %s" %
(self.schedule,))
for kw in self.kwargs:
error(self.pos, "Invalid keyword argument to prange: %s" % kw)
def analyse_expressions(self, env):
if self.target is None:
error(self.pos, "prange() can only be used as part of a for loop")
......@@ -6349,6 +6375,8 @@ class ParallelRangeNode(ParallelStatNode):
c = self.parent.privatization_insertion_point
c.put(" private(%(nsteps)s)" % fmt_dict)
self.put_num_threads(code)
self.privatization_insertion_point = code.insertion_point()
code.putln("")
......
......@@ -2094,8 +2094,8 @@ class GilCheck(VisitorTransform):
return node
def visit_ParallelRangeNode(self, node):
if node.is_nogil:
node.is_nogil = False
if node.nogil:
node.nogil = False
node = Nodes.GILStatNode(node.pos, state='nogil', body=node)
return self.visit_GILStatNode(node)
......
......@@ -8,42 +8,42 @@ cdef extern from "omp.h":
omp_sched_guided = 3,
omp_sched_auto = 4
extern void omp_set_num_threads(int)
extern int omp_get_num_threads()
extern int omp_get_max_threads()
extern int omp_get_thread_num()
extern int omp_get_num_procs()
extern int omp_in_parallel()
extern void omp_set_dynamic(int)
extern int omp_get_dynamic()
extern void omp_set_nested(int)
extern int omp_get_nested()
extern void omp_init_lock(omp_lock_t *)
extern void omp_destroy_lock(omp_lock_t *)
extern void omp_set_lock(omp_lock_t *)
extern void omp_unset_lock(omp_lock_t *)
extern int omp_test_lock(omp_lock_t *)
extern void omp_init_nest_lock(omp_nest_lock_t *)
extern void omp_destroy_nest_lock(omp_nest_lock_t *)
extern void omp_set_nest_lock(omp_nest_lock_t *)
extern void omp_unset_nest_lock(omp_nest_lock_t *)
extern int omp_test_nest_lock(omp_nest_lock_t *)
extern double omp_get_wtime()
extern double omp_get_wtick()
void omp_set_schedule(omp_sched_t, int)
void omp_get_schedule(omp_sched_t *, int *)
int omp_get_thread_limit()
void omp_set_max_active_levels(int)
int omp_get_max_active_levels()
int omp_get_level()
int omp_get_ancestor_thread_num(int)
int omp_get_team_size(int)
int omp_get_active_level()
extern void omp_set_num_threads(int) nogil
extern int omp_get_num_threads() nogil
extern int omp_get_max_threads() nogil
extern int omp_get_thread_num() nogil
extern int omp_get_num_procs() nogil
extern int omp_in_parallel() nogil
extern void omp_set_dynamic(int) nogil
extern int omp_get_dynamic() nogil
extern void omp_set_nested(int) nogil
extern int omp_get_nested() nogil
extern void omp_init_lock(omp_lock_t *) nogil
extern void omp_destroy_lock(omp_lock_t *) nogil
extern void omp_set_lock(omp_lock_t *) nogil
extern void omp_unset_lock(omp_lock_t *) nogil
extern int omp_test_lock(omp_lock_t *) nogil
extern void omp_init_nest_lock(omp_nest_lock_t *) nogil
extern void omp_destroy_nest_lock(omp_nest_lock_t *) nogil
extern void omp_set_nest_lock(omp_nest_lock_t *) nogil
extern void omp_unset_nest_lock(omp_nest_lock_t *) nogil
extern int omp_test_nest_lock(omp_nest_lock_t *) nogil
extern double omp_get_wtime() nogil
extern double omp_get_wtick() nogil
void omp_set_schedule(omp_sched_t, int) nogil
void omp_get_schedule(omp_sched_t *, int *) nogil
int omp_get_thread_limit() nogil
void omp_set_max_active_levels(int) nogil
int omp_get_max_active_levels() nogil
int omp_get_level() nogil
int omp_get_ancestor_thread_num(int) nogil
int omp_get_team_size(int) nogil
int omp_get_active_level() nogil
......@@ -59,6 +59,12 @@ for i in prange(10, nogil=True):
y += i
y *= i
with nogil, cython.parallel.parallel("invalid"):
pass
with nogil, cython.parallel.parallel(invalid=True):
pass
_ERRORS = u"""
e_cython_parallel.pyx:3:8: cython.parallel.parallel is not a module
e_cython_parallel.pyx:4:0: No such directive: cython.parallel.something
......@@ -76,4 +82,6 @@ e_cython_parallel.pyx:39:12: The parallel directive must be called
e_cython_parallel.pyx:45:10: Expression value depends on previous loop iteration, cannot execute in parallel
e_cython_parallel.pyx:55:9: Expression depends on an uninitialized thread-private variable
e_cython_parallel.pyx:60:6: Reduction operator '*' is inconsistent with previous reduction operator '+'
e_cython_parallel.pyx:62:36: cython.parallel.parallel() does not take positional arguments
e_cython_parallel.pyx:65:36: Invalid keyword argument: invalid
"""
......@@ -24,4 +24,22 @@ def test_parallel():
free(buf)
def test_num_threads():
    """
    >>> test_num_threads()
    1
    """
    # Remember the current dynamic-adjustment setting so it can be restored.
    cdef int saved_dynamic = openmp.omp_get_dynamic()
    cdef int num_threads
    # NOTE(review): presumably the result is written through a pointer so
    # that it is still visible after the parallel block -- confirm against
    # the Cython parallelism documentation.
    cdef int *num_threads_ptr = &num_threads

    # Disable dynamic adjustment of the team size for the duration of the test.
    openmp.omp_set_dynamic(0)

    with nogil, cython.parallel.parallel(num_threads=1):
        num_threads_ptr[0] = openmp.omp_get_num_threads()

    # Put the dynamic-adjustment setting back to what it was on entry.
    openmp.omp_set_dynamic(saved_dynamic)
    return num_threads
include "sequential_parallel.pyx"
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment