Skip to content

Add filter class to dask and do the tests for it #283

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 61 additions & 3 deletions streamz/dask.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import absolute_import, division, print_function
from functools import wraps

from .core import _truthy
from operator import getitem

from tornado import gen
Expand All @@ -10,6 +12,36 @@
from .core import Stream
from . import core, sources

try:
    from collections.abc import Sequence  # Python 3.3+
except ImportError:  # pragma: no cover - Python 2 fallback
    from collections import Sequence


# Sentinel emitted in place of elements that a distributed filter rejected.
# Downstream nodes recognise it and drop the element instead of processing it.
NULL_COMPUTE = "~~NULL_COMPUTE~~"


def return_null(func):
    """Wrap a predicate so truthy results pass the *input* through.

    The wrapped callable returns ``x`` unchanged when ``func(x, ...)`` is
    truthy, and the ``NULL_COMPUTE`` sentinel otherwise, so that filtered
    elements can be recognised and discarded downstream.
    """

    @wraps(func)
    def inner(value, *args, **kwargs):
        # Forward the original value, not the predicate's (boolean) result.
        return value if func(value, *args, **kwargs) else NULL_COMPUTE

    return inner


def filter_null_wrapper(func):
    """Make ``func`` short-circuit on the ``NULL_COMPUTE`` sentinel.

    If any positional or keyword argument is the sentinel, ``func`` is not
    called and the sentinel is propagated; otherwise ``func`` runs normally.
    This lets map-like nodes skip elements a prior filter rejected.
    """

    @wraps(func)
    def inner(*args, **kwargs):
        saw_null = any(a is NULL_COMPUTE for a in args)
        saw_null = saw_null or any(
            v is NULL_COMPUTE for v in kwargs.values()
        )
        if saw_null:
            return NULL_COMPUTE
        return func(*args, **kwargs)

    return inner


class DaskStream(Stream):
""" A Parallel stream using Dask
Expand Down Expand Up @@ -46,7 +78,7 @@ def __init__(self, *args, **kwargs):
@DaskStream.register_api()
class map(DaskStream):
def __init__(self, upstream, func, *args, **kwargs):
self.func = func
self.func = filter_null_wrapper(func)
self.kwargs = kwargs
self.args = args

Expand Down Expand Up @@ -117,12 +149,20 @@ class gather(core.Stream):
buffer
scatter
"""

@gen.coroutine
def update(self, x, who=None):
    """Gather a future (or batch of futures) and emit the concrete result.

    Results equal to the ``NULL_COMPUTE`` sentinel -- or sequences that
    contain it -- come from upstream filtered-out elements and are
    silently dropped (nothing is emitted; the coroutine returns None).
    """
    client = default_client()
    result = yield client.gather(x, asynchronous=True)
    # De Morgan of the emit condition: drop when the result is the
    # sentinel itself, or a batch containing the sentinel.
    dropped = result is NULL_COMPUTE or (
        isinstance(result, Sequence)
        and any(item is NULL_COMPUTE for item in result)
    )
    if not dropped:
        downstream = yield self._emit(result)
        raise gen.Return(downstream)


@DaskStream.register_api()
Expand All @@ -140,6 +180,24 @@ def update(self, x, who=None):
return self._emit(result)


@DaskStream.register_api()
class filter(DaskStream):
    """Filter stream elements on the Dask cluster.

    The predicate is submitted to the scheduler for each element; elements
    for which it is falsey are replaced by the ``NULL_COMPUTE`` sentinel,
    which downstream nodes (e.g. ``gather``) recognise and drop. A ``None``
    predicate falls back to plain truthiness (``_truthy``).
    """

    def __init__(self, upstream, predicate, *args, **kwargs):
        chosen = _truthy if predicate is None else predicate
        # return_null turns a boolean predicate into pass-through-or-sentinel.
        self.predicate = return_null(chosen)
        stream_name = kwargs.pop("stream_name", None)
        self.kwargs = kwargs
        self.args = args
        DaskStream.__init__(self, upstream, stream_name=stream_name)

    def update(self, x, who=None):
        # Run the wrapped predicate remotely; the resulting future resolves
        # to either the original element or the NULL_COMPUTE sentinel.
        client = default_client()
        result = client.submit(self.predicate, x, *self.args, **self.kwargs)
        return self._emit(result)


@DaskStream.register_api()
class buffer(DaskStream, core.buffer):
pass
Expand Down
61 changes: 61 additions & 0 deletions streamz/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,67 @@ def test_buffer(c, s, a, b):
assert source.loop == c.loop


@pytest.mark.slow
@gen_cluster(client=True)
def test_filter(c, s, a, b):
    """Even elements pass the distributed filter; odd ones are dropped."""
    # NOTE(review): the original was a bare generator function -- pytest never
    # drives plain generators, so its asserts silently never ran. gen_cluster
    # (used by test_buffer(c, s, a, b) in this module) executes the coroutine
    # against a test cluster -- confirm it is imported at the top of the file.
    source = Stream(asynchronous=True)
    futures = scatter(source).filter(lambda x: x % 2 == 0)
    futures_L = futures.sink_to_list()
    L = futures.gather().sink_to_list()

    for i in range(5):
        yield source.emit(i)

    assert L == [0, 2, 4]
    assert all(isinstance(f, Future) for f in futures_L)


@pytest.mark.slow
@gen_cluster(client=True)
def test_filter_buffer(c, s, a, b):
    """Filtered sentinels survive a buffer node and are still dropped at gather."""
    # NOTE(review): the original was a bare generator function -- pytest never
    # drives plain generators, so its asserts silently never ran. gen_cluster
    # (cf. test_buffer(c, s, a, b) in this module) runs the coroutine properly.
    source = Stream(asynchronous=True)
    futures = scatter(source).filter(lambda x: x % 2 == 0)
    futures_L = futures.sink_to_list()
    L = futures.buffer(10).gather().sink_to_list()

    for i in range(5):
        yield source.emit(i)
    # Buffered results arrive asynchronously; poll until all three land.
    while len(L) < 3:
        yield gen.sleep(.01)

    assert L == [0, 2, 4]
    assert all(isinstance(f, Future) for f in futures_L)


@pytest.mark.slow
@gen_cluster(client=True)
def test_filter_map(c, s, a, b):
    """map after filter skips sentinel elements (inc applied to 0, 2, 4 only)."""
    # NOTE(review): the original was a bare generator function -- pytest never
    # drives plain generators, so its asserts silently never ran. gen_cluster
    # (cf. test_buffer(c, s, a, b) in this module) runs the coroutine properly.
    source = Stream(asynchronous=True)
    futures = (
        scatter(source).filter(lambda x: x % 2 == 0).map(inc)
    )
    futures_L = futures.sink_to_list()
    L = futures.gather().sink_to_list()

    for i in range(5):
        yield source.emit(i)

    assert L == [1, 3, 5]
    assert all(isinstance(f, Future) for f in futures_L)


@pytest.mark.slow
@gen_cluster(client=True)
def test_filter_starmap(c, s, a, b):
    """starmap after filter skips sentinels (add on (0,0), (2,2), (4,4))."""
    # NOTE(review): the original was a bare generator function -- pytest never
    # drives plain generators, so its asserts silently never ran. gen_cluster
    # (cf. test_buffer(c, s, a, b) in this module) runs the coroutine properly.
    source = Stream(asynchronous=True)
    futures1 = scatter(source).filter(lambda x: x[1] % 2 == 0)
    futures = futures1.starmap(add)
    futures_L = futures.sink_to_list()
    L = futures.gather().sink_to_list()

    for i in range(5):
        yield source.emit((i, i))

    assert L == [0, 4, 8]
    assert all(isinstance(f, Future) for f in futures_L)


@pytest.mark.slow
def test_buffer_sync(loop): # noqa: F811
with cluster() as (s, [a, b]):
Expand Down