Skip to content

Add filter class to dask and add tests for it #283

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions streamz/clients.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# NOTE: these ABCs must come from ``collections.abc`` — the aliases in the
# ``collections`` namespace were deprecated since Python 3.3 and removed in 3.10.
from collections.abc import Sequence, MutableMapping
from concurrent.futures import ThreadPoolExecutor, Future
from functools import wraps

from distributed import default_client as dask_default_client
from tornado import gen

from .core import identity


# Node fill colors used when visualizing a stream graph, keyed by backend name.
FILL_COLOR_LOOKUP = {"dask": "cornflowerblue", "threads": "coral"}


def result_maybe(future_maybe):
    """Recursively resolve futures into concrete values.

    Parameters
    ----------
    future_maybe : object
        A ``Future``, a sequence or mutable mapping possibly containing
        futures (resolved recursively), or any plain value.

    Returns
    -------
    object
        The resolved value. Sequences come back as lists (tuples stay
        tuples); mappings are resolved **in place** and returned; plain
        values are returned unchanged.
    """
    if isinstance(future_maybe, Future):
        # Blocks until the future completes.
        return future_maybe.result()
    # Strings are sequences too, but must be treated as scalars.
    if isinstance(future_maybe, Sequence) and not isinstance(future_maybe, str):
        resolved = [result_maybe(item) for item in future_maybe]
        if isinstance(future_maybe, tuple):
            return tuple(resolved)
        return resolved
    if isinstance(future_maybe, MutableMapping):
        for key, value in future_maybe.items():
            future_maybe[key] = result_maybe(value)
        return future_maybe
    # BUG FIX: the original fell through here and implicitly returned None,
    # turning every plain scalar/str argument into None.
    return future_maybe


def delayed_execution(func):
    """Wrap *func* so that any ``Future`` arguments are resolved to their
    concrete results (via :func:`result_maybe`) before *func* is invoked."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        resolved_args = tuple(result_maybe(arg) for arg in args)
        resolved_kwargs = {name: result_maybe(value) for name, value in kwargs.items()}
        return func(*resolved_args, **resolved_kwargs)

    return wrapper


def executor_to_client(executor):
    """Patch a ``concurrent.futures`` executor, in place, so it exposes the
    subset of the ``distributed.Client`` interface streamz needs
    (``submit``, ``scatter``, ``gather``) and return it.

    Existing ``scatter``/``gather`` attributes (e.g. on a real dask client)
    are left untouched; the shims below are only installed as fallbacks.
    """
    # Keep a handle on the raw submit before monkey-patching it.
    executor._submit = executor.submit

    @wraps(executor.submit)
    def inner(fn, *args, **kwargs):
        # Resolve any Future arguments before fn runs, mirroring how dask
        # resolves upstream futures passed to submit.
        wfn = delayed_execution(fn)
        return executor._submit(wfn, *args, **kwargs)

    executor.submit = inner

    @gen.coroutine
    def scatter(x, asynchronous=True):
        # "Scatter" by submitting identity, yielding a Future wrapping x.
        # NOTE(review): the ``asynchronous`` flag is accepted but unused here.
        f = executor.submit(identity, x)
        return f

    # Install the shim only if the executor has no native scatter.
    executor.scatter = getattr(executor, "scatter", scatter)

    @gen.coroutine
    def gather(x, asynchronous=True):
        # If we have a sequence of futures await each one
        if isinstance(x, Sequence):
            final_result = []
            for sub_x in x:
                yx = yield sub_x
                final_result.append(yx)
            # Rebuild the same container type as the input sequence.
            result = type(x)(final_result)
        else:
            result = yield x
        return result

    # Install the shim only if the executor has no native gather.
    executor.gather = getattr(executor, "gather", gather)
    return executor


# Module-level cache holding (at most) one live thread-backed client.
thread_ex_list = []


def thread_default_client():
    """Return the shared thread-pool client, replacing it if it was shut down.

    Mirrors ``distributed.default_client`` for the "thread" backend: a single
    ``ThreadPoolExecutor`` wrapped by :func:`executor_to_client` is cached and
    reused across calls.
    """
    if thread_ex_list and not thread_ex_list[0]._shutdown:
        return thread_ex_list[0]
    if thread_ex_list:
        # Drop the dead executor before creating a replacement.
        thread_ex_list.pop()
    client = executor_to_client(ThreadPoolExecutor())
    thread_ex_list.append(client)
    return client


# Map backend names to zero-argument factories that return a client;
# looked up by DaskStream nodes when resolving their ``default_client``.
DEFAULT_BACKENDS = {
    "dask": dask_default_client,
    "thread": thread_default_client,
}
76 changes: 73 additions & 3 deletions streamz/dask.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from __future__ import absolute_import, division, print_function
from functools import wraps

from .core import _truthy
from .core import get_io_loop
from .clients import DEFAULT_BACKENDS
from operator import getitem

from tornado import gen
Expand All @@ -10,6 +14,23 @@
from .core import Stream
from . import core, sources

from collections import Sequence


# Sentinel emitted in place of items that fail a filter predicate; the
# gather node recognizes it and drops such results instead of emitting them.
NULL_COMPUTE = "~~NULL_COMPUTE~~"


def return_null(func):
    """Adapt predicate *func* for remote filtering: pass ``x`` through when
    the predicate is truthy, otherwise return the ``NULL_COMPUTE`` sentinel."""

    @wraps(func)
    def wrapper(x, *args, **kwargs):
        return x if func(x, *args, **kwargs) else NULL_COMPUTE

    return wrapper


class DaskStream(Stream):
""" A Parallel stream using Dask
Expand Down Expand Up @@ -117,12 +138,43 @@ class gather(core.Stream):
buffer
scatter
"""

def __init__(self, *args, backend="dask", **kwargs):
    """Set up the gather node and resolve which backend client it uses.

    Parameters
    ----------
    backend : str or callable, default "dask"
        Key into ``DEFAULT_BACKENDS`` ("dask"/"thread"); an unrecognized
        value is assumed to already be a zero-argument client factory and
        is used as-is. Ignored when an upstream node supplies a
        ``default_client``.

    Raises
    ------
    RuntimeError
        If upstream nodes declare more than one distinct backend.
    """
    super().__init__(*args, **kwargs)
    # Collect backends declared upstream; None marks nodes (e.g. plain
    # core streams) that have no ``default_client`` attribute.
    upstream_backends = set(
        [getattr(u, "default_client", None) for u in self.upstreams]
    )
    if None in upstream_backends:
        upstream_backends.remove(None)
    if len(upstream_backends) > 1:
        # A pipeline must run on a single backend end-to-end.
        raise RuntimeError("Mixing backends is not supported")
    elif upstream_backends:
        # Inherit the client factory from the (single) upstream backend.
        self.default_client = upstream_backends.pop()
    else:
        self.default_client = DEFAULT_BACKENDS.get(backend, backend)
    # Adopt the client's IO loop unless the caller supplied one.
    # NOTE(review): ``self.default_client()`` is invoked twice here —
    # presumably cheap/cached; confirm against the client factories.
    if "loop" not in kwargs and getattr(
        self.default_client(), "loop", None
    ):
        loop = self.default_client().loop
        self._set_loop(loop)
    # Fallbacks from the Stream base class to guarantee a usable loop.
    if kwargs.get("ensure_io_loop", False) and not self.loop:
        self._set_asynchronous(False)
    if self.loop is None and self.asynchronous is not None:
        self._set_loop(get_io_loop(self.asynchronous))

@gen.coroutine
def update(self, x, who=None):
    """Resolve the incoming future(s) ``x`` and emit the concrete result.

    Results equal to ``NULL_COMPUTE`` — or sequences containing it — come
    from ``filter``'s sentinel-wrapped predicate and are dropped rather
    than emitted downstream.
    """
    # Use the backend resolved in __init__, not the global dask
    # default_client, so thread-backed pipelines work too.
    client = self.default_client()
    result = yield client.gather(x, asynchronous=True)
    # Suppress sentinel results produced by filtered-out items.
    if (
        not (
            isinstance(result, Sequence)
            and any(r == NULL_COMPUTE for r in result)
        )
        and result != NULL_COMPUTE
    ):
        result2 = yield self._emit(result)
        raise gen.Return(result2)


@DaskStream.register_api()
Expand All @@ -140,6 +192,24 @@ def update(self, x, who=None):
return self._emit(result)


@DaskStream.register_api()
class filter(DaskStream):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you also need the modifications to the gather and other nodes as well.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made changes to gather already. I also compared the other nodes. There is a slight difference between dask.starmap and parallel.starmap. Do I need to change that one?

def __init__(self, upstream, predicate, *args, **kwargs):
    """Create a filter node whose predicate runs on the backend client.

    Parameters
    ----------
    upstream : Stream
        The upstream (scattered) stream to filter.
    predicate : callable or None
        Called as ``predicate(x, *args, **kwargs)``; items for which it is
        falsy are dropped. ``None`` falls back to plain truthiness
        (``_truthy``).
    *args, **kwargs :
        Extra arguments forwarded to ``predicate`` on every submit.
        ``stream_name`` is popped out and used for the stream itself.
    """
    if predicate is None:
        predicate = _truthy
    # Wrap the predicate so failing items become the NULL_COMPUTE
    # sentinel, which the downstream gather node drops.
    self.predicate = return_null(predicate)
    stream_name = kwargs.pop("stream_name", None)
    # NOTE(review): the remaining kwargs go to the predicate, not to
    # DaskStream.__init__ — so stream options such as ``backend=`` cannot
    # be passed here; confirm this is intended.
    self.kwargs = kwargs
    self.args = args

    DaskStream.__init__(self, upstream, stream_name=stream_name)

def update(self, x, who=None):
    """Submit the sentinel-wrapped predicate on ``x`` and emit the future.

    The future resolves either to ``x`` itself (kept) or to the
    ``NULL_COMPUTE`` sentinel (dropped later by gather).
    """
    future = self.default_client().submit(
        self.predicate, x, *self.args, **self.kwargs
    )
    return self._emit(future)


@DaskStream.register_api()
class buffer(DaskStream, core.buffer):
pass
Expand Down
61 changes: 61 additions & 0 deletions streamz/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,67 @@ def test_buffer(c, s, a, b):
assert source.loop == c.loop


@pytest.mark.slow
def test_filter():
    # NOTE(review): generator-style test (uses ``yield``) with no visible
    # async test decorator (gen_test/gen_cluster); plain pytest does not
    # execute generator tests — confirm the harness decoration.
    source = Stream(asynchronous=True)
    # Keep only even numbers; the predicate runs on the scattered futures.
    futures = scatter(source).filter(lambda x: x % 2 == 0)
    futures_L = futures.sink_to_list()
    L = futures.gather().sink_to_list()

    for i in range(5):
        yield source.emit(i)

    # Odd values were dropped before gather emitted concrete results.
    assert L == [0, 2, 4]
    # The pre-gather stream carries futures, not concrete values.
    assert all(isinstance(f, Future) for f in futures_L)


@pytest.mark.slow
def test_filter_buffer(backend):
    # NOTE(review): generator-style test with no visible async test
    # decorator — confirm the harness decoration (see test_filter).
    source = Stream(asynchronous=True)
    # filter should compose with buffer on either backend fixture.
    futures = scatter(source, backend=backend).filter(lambda x: x % 2 == 0)
    futures_L = futures.sink_to_list()
    L = futures.buffer(10).gather().sink_to_list()

    for i in range(5):
        yield source.emit(i)
    # The buffer delivers asynchronously; poll until all three survivors land.
    while len(L) < 3:
        yield gen.sleep(.01)

    assert L == [0, 2, 4]
    assert all(isinstance(f, Future) for f in futures_L)


@pytest.mark.slow
def test_filter_map(backend):
    # NOTE(review): generator-style test with no visible async test
    # decorator — confirm the harness decoration (see test_filter).
    source = Stream(asynchronous=True)
    # Evens survive the filter, then inc maps 0,2,4 -> 1,3,5.
    futures = (
        scatter(source, backend=backend).filter(lambda x: x % 2 == 0).map(inc)
    )
    futures_L = futures.sink_to_list()
    L = futures.gather().sink_to_list()

    for i in range(5):
        yield source.emit(i)

    assert L == [1, 3, 5]
    assert all(isinstance(f, Future) for f in futures_L)


@pytest.mark.slow
def test_filter_starmap(backend):
    # NOTE(review): generator-style test with no visible async test
    # decorator — confirm the harness decoration (see test_filter).
    source = Stream(asynchronous=True)
    # Filter on the tuple's second element, then add the pair elementwise:
    # kept pairs (0,0),(2,2),(4,4) -> sums 0,4,8.
    futures1 = scatter(source, backend=backend).filter(lambda x: x[1] % 2 == 0)
    futures = futures1.starmap(add)
    futures_L = futures.sink_to_list()
    L = futures.gather().sink_to_list()

    for i in range(5):
        yield source.emit((i, i))

    assert L == [0, 4, 8]
    assert all(isinstance(f, Future) for f in futures_L)


@pytest.mark.slow
def test_buffer_sync(loop): # noqa: F811
with cluster() as (s, [a, b]):
Expand Down