diff --git a/gevent_pipeline/__init__.py b/gevent_pipeline/__init__.py index b710e98..491fca0 100644 --- a/gevent_pipeline/__init__.py +++ b/gevent_pipeline/__init__.py @@ -1,2 +1,2 @@ -from .closablequeue import ClosableQueue # noqa +from .closablequeue import ClosableQueue, ClosablePriorityQueue # noqa from .pipeline import Pipeline, worker, forward_input # noqa diff --git a/gevent_pipeline/closablequeue.py b/gevent_pipeline/closablequeue.py index baa6e6f..95d73fc 100644 --- a/gevent_pipeline/closablequeue.py +++ b/gevent_pipeline/closablequeue.py @@ -86,3 +86,19 @@ def get(self, *args, **kwargs): return super().get(block=False) except Exception: return StopIteration + +class ClosablePriorityQueue(queue.PriorityQueue, ClosableQueue): + """ + Mixes gevent's PriorityQueue with the ClosableQueue + + This can be useful for ordering output of a pipeline stage. + + Example: + >>> from gevent_pipeline import Pipeline + >>> cpq = ClosablePriorityQueue() + >>> random_array = [random.randint(1,50) for _ in range(10)] + >>> output = list(Pipeline().from_iter(random_array, q_out=cpq)) + >>> sorted(random_array) == output + True + """ + pass diff --git a/gevent_pipeline/pipeline.py b/gevent_pipeline/pipeline.py index 3946639..c33ad33 100644 --- a/gevent_pipeline/pipeline.py +++ b/gevent_pipeline/pipeline.py @@ -295,7 +295,7 @@ def g(q_in, q_out, q_done): raise RuntimeError("Unexpected data on fold output channel") return result - def join(self): + def joinall(self): """ Wait for the greenlets to finish Wrapper around gevent.joinall @@ -306,4 +306,4 @@ def join(self): # TODO pass argumetns to .joinall and remove greenlets in done from _greenlets self._greenlets = [] - return done + return self diff --git a/tests/test_closablequeue.py b/tests/test_closablequeue.py index ae448a2..6b047ab 100644 --- a/tests/test_closablequeue.py +++ b/tests/test_closablequeue.py @@ -1,4 +1,4 @@ -from gevent_pipeline import ClosableQueue +from gevent_pipeline import ClosableQueue, ClosablePriorityQueue import gevent from gevent import queue @@ -224,3 +224,15 @@ def putter(): # No items lost n_ok_get = n_left_on_queue + n_got assert n_ok_get == n_ok_put + +def test_cpq_order_matches(): + # Order of this list matches comparator order for strings. + # Will be randomized and then reordered by the Queue + ordered = list('abcdefghijklmnopqrstuvwxyz') + randomized = ordered.copy() + random.shuffle(randomized) + cpq = ClosablePriorityQueue(fuzz=0.01) + for letter in randomized: + cpq.put(letter) + for orig_el, queued in zip(ordered, cpq): + assert orig_el == queued diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index ec9eae1..1527bfc 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,4 +1,4 @@ -from gevent_pipeline import Pipeline, ClosableQueue, worker, forward_input +from gevent_pipeline import Pipeline, ClosableQueue, ClosablePriorityQueue, worker, forward_input import gevent from gevent import queue @@ -125,7 +125,7 @@ def doubler(x): l = sorted(p) assert l == [i*i for i in range(10)] - p.join() + p.joinall() def test_pipeline_sloppy_map(): @@ -143,3 +143,14 @@ def f(x): s_odd = sum(range(1, 100, 2)) s_even = sum(2*i for i in range(0, 100, 2)) assert sum(p) == s_odd + s_even + +def test_cpq_out_join_matches_order(): + cpq = ClosablePriorityQueue() + original = [random.randint(0,1000) for _ in range(100)] + p = Pipeline()\ + .from_iter(enumerate(original),n_workers=5)\ + .map(lambda x: (x[0], x[1]*2), n_workers=10, q_out=cpq)\ + .joinall()\ + .map(lambda x: x[1], n_workers=1) + result = list(p) + assert result == [x*2 for x in original]