Skip to content

Commit 0db6898

Browse files
committed
feat: add dynamic tests and docs
1 parent 669d850 commit 0db6898

File tree

6 files changed

+264
-22
lines changed

6 files changed

+264
-22
lines changed

drive_events/core.py

Lines changed: 40 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
11
import inspect
22
import asyncio
33
from typing import Callable, Optional, Union, Any, Tuple
4-
from .types import BaseEvent, EventFunction, EventGroup, EventInput
4+
from .types import (
5+
BaseEvent,
6+
EventFunction,
7+
EventGroup,
8+
EventInput,
9+
_SpecialEventReturn,
10+
ReturnBehavior,
11+
)
512
from .broker import BaseBroker
613
from .utils import (
714
logger,
815
string_to_md5_hash,
916
)
1017

1118

12-
def goto_events(group_markers: list[BaseEvent], any_return: Any):
13-
pass
14-
15-
1619
class EventEngineCls:
1720
def __init__(self, name="default", broker: Optional[BaseBroker] = None):
1821
self.name = name
@@ -23,6 +26,9 @@ def __init__(self, name="default", broker: Optional[BaseBroker] = None):
2326
def reset(self):
2427
self.__event_maps = {}
2528

29+
def get_event_from_id(self, event_id: str) -> Optional[BaseEvent]:
30+
return self.__event_maps.get(event_id)
31+
2632
def make_event(self, func: Union[EventFunction, BaseEvent]) -> BaseEvent:
2733
if isinstance(func, BaseEvent):
2834
self.__event_maps[func.id] = func
@@ -80,23 +86,38 @@ async def invoke_event(
8086
this_run_ctx = {}
8187
queue: list[Tuple[BaseEvent, EventInput]] = [(event, event_input)]
8288

83-
async def run_event(current_event, current_event_input):
89+
async def run_event(current_event: BaseEvent, current_event_input: Any):
8490
result = await current_event.solo_run(current_event_input, global_ctx)
8591
this_run_ctx[current_event.id] = result
86-
for cand_event in self.__event_maps.values():
87-
cand_event_parents = cand_event.parent_groups
88-
for group_hash, group in cand_event_parents.items():
89-
if current_event.id in group.events and all(
90-
[event_id in this_run_ctx for event_id in group.events]
91-
):
92-
this_group_returns = {
93-
event_id: this_run_ctx[event_id]
94-
for event_id in group.events
95-
}
96-
build_input = EventInput(
97-
group_name=group.name, results=this_group_returns
92+
if isinstance(result, _SpecialEventReturn):
93+
if result.behavior == ReturnBehavior.GOTO:
94+
group_markers, any_return = result.returns
95+
for group_marker in group_markers:
96+
this_group_returns = {current_event.id: any_return}
97+
build_input_goto = EventInput(
98+
group_name="$goto",
99+
results=this_group_returns,
100+
behavior=ReturnBehavior.GOTO,
98101
)
99-
queue.append((cand_event, build_input))
102+
queue.append((group_marker, build_input_goto))
103+
elif result.behavior == ReturnBehavior.ABORT:
104+
return
105+
else:
106+
# dispath to events who listen
107+
for cand_event in self.__event_maps.values():
108+
cand_event_parents = cand_event.parent_groups
109+
for group_hash, group in cand_event_parents.items():
110+
if current_event.id in group.events and all(
111+
[event_id in this_run_ctx for event_id in group.events]
112+
):
113+
this_group_returns = {
114+
event_id: this_run_ctx[event_id]
115+
for event_id in group.events
116+
}
117+
build_input = EventInput(
118+
group_name=group.name, results=this_group_returns
119+
)
120+
queue.append((cand_event, build_input))
100121

101122
while len(queue):
102123
this_batch_events = queue[:max_async_events] if max_async_events else queue

drive_events/dynamic.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from typing import Any
2+
from .types import (
3+
BaseEvent,
4+
_SpecialEventReturn,
5+
ReturnBehavior,
6+
)
7+
8+
9+
def goto_events(group_markers: list[BaseEvent], any_return: Any) -> _SpecialEventReturn:
10+
return _SpecialEventReturn(
11+
behavior=ReturnBehavior.GOTO, returns=(group_markers, any_return)
12+
)
13+
14+
15+
def abort_this():
16+
return _SpecialEventReturn(behavior=ReturnBehavior.ABORT, returns=None)

drive_events/types.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from enum import Enum
33
from dataclasses import dataclass, field
44
from datetime import datetime
5-
from typing import Callable, Any, Awaitable, Optional, TypeVar, Generic
5+
from typing import Any, Awaitable, Optional, Union, Callable
66

77
from .utils import (
88
string_to_md5_hash,
@@ -39,8 +39,22 @@ class EventInput(EventGroupInput):
3939
pass
4040

4141

42+
@dataclass
43+
class _SpecialEventReturn:
44+
behavior: ReturnBehavior
45+
returns: Any
46+
47+
def __post_init__(self):
48+
if not isinstance(self.behavior, ReturnBehavior):
49+
raise TypeError(
50+
f"behavior must be a ReturnBehavior, not {type(self.behavior)}"
51+
)
52+
53+
4254
# (group_event_results, global ctx set by user) -> result
43-
EventFunction = Callable[[Optional[EventInput], Optional[Any]], Awaitable[Any]]
55+
EventFunction = Callable[
56+
[Optional[EventInput], Optional[Any]], Awaitable[Union[Any, _SpecialEventReturn]]
57+
]
4458

4559

4660
@dataclass

readme.md

Lines changed: 124 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ from drive_events import EventInput, default_drive
5252
async def hello(event: EventInput, global_ctx):
5353
print("hello")
5454

55-
5655
@default_drive.listen_groups([hello])
5756
async def world(event: EventInput, global_ctx):
5857
print("world")
@@ -77,5 +76,129 @@ await default_drive.invoke_event(EVENT, EVENT_INPUT, GLOBAL_CTX)
7776

7877
Check out [examples](./examples) for more user cases!
7978

79+
### Multi-Recv
80+
81+
`drive_events` allow an event to be triggered only when a group of events are produced:
82+
83+
<details>
84+
<summary> code snippet</summary>
85+
86+
```python
87+
@default_drive.make_event
88+
async def start(event: EventInput, global_ctx):
89+
print("start")
90+
91+
@default_drive.listen_groups([start])
92+
async def hello(event: EventInput, global_ctx):
93+
return 1
94+
95+
96+
@default_drive.listen_groups([start])
97+
async def world(event: EventInput, global_ctx):
98+
return 2
99+
100+
101+
@default_drive.listen_groups([hello, world])
102+
async def adding(event: EventInput, global_ctx):
103+
results = event.results
104+
print("adding", hello, world)
105+
return results[hello.id] + results[world.id]
106+
107+
108+
results = asyncio.run(default_drive.invoke_event(start))
109+
assert results[adding.id] == 3
110+
```
111+
</details>
112+
113+
114+
### Parallel
115+
116+
`drive_events` is perfect for workflows that have many network IO that can be awaited in parallel. If two events are listened to the same group of events, then they will be triggered at the same time:
117+
118+
<details>
119+
<summary> code snippet</summary>
120+
121+
```python
122+
@default_drive.make_event
123+
async def start(event: EventInput, global_ctx):
124+
print("start")
125+
126+
@default_drive.listen_groups([start])
127+
async def hello(event: EventInput, global_ctx):
128+
print(datetime.now(), "hello")
129+
await asyncio.sleep(0.2)
130+
print(datetime.now(), "hello done")
131+
132+
133+
@default_drive.listen_groups([start])
134+
async def world(event: EventInput, global_ctx):
135+
print(datetime.now(), "world")
136+
await asyncio.sleep(0.2)
137+
print(datetime.now(), "world done")
138+
139+
140+
asyncio.run(default_drive.invoke_event(start))
141+
```
142+
143+
</details>
144+
145+
146+
147+
### Dynamic
148+
149+
`drive_events` is dynamic. You can use `goto` and `abort` to change the workflow at runtime:
150+
151+
<details>
152+
<summary> code snippet for abort</summary>
153+
154+
```python
155+
from drive_events.dynamic import goto_events, abort_this
156+
157+
@default_drive.make_event
158+
async def a(event: EventInput, global_ctx):
159+
return abort_this()
160+
161+
@default_drive.listen_groups([a])
162+
async def b(event: EventInput, global_ctx):
163+
assert False, "should not be called"
164+
165+
asyncio.run(default_drive.invoke_event(a))
166+
```
167+
168+
</details>
169+
170+
<details>
171+
<summary> code snippet for goto</summary>
172+
173+
```python
174+
call_a_count = 0
175+
176+
@default_drive.make_event
177+
async def a(event: EventInput, global_ctx):
178+
global call_a_count
179+
if call_a_count == 0:
180+
assert event is None
181+
elif call_a_count == 1:
182+
assert event.behavior == ReturnBehavior.GOTO
183+
assert event.results == {b.id: 2}
184+
return abort_this()
185+
call_a_count += 1
186+
return 1
187+
188+
@default_drive.listen_groups([a])
189+
async def b(event: EventInput, global_ctx):
190+
return goto_events([a], 2)
191+
192+
@default_drive.listen_groups([b])
193+
async def c(event: EventInput, global_ctx):
194+
assert False, "should not be called"
195+
196+
asyncio.run(default_drive.invoke_event(a))
197+
```
198+
199+
</details>
200+
201+
202+
80203

81204

tests/test_define.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,15 @@ async def a(event: EventInput, global_ctx):
5454
assert isinstance(a, BaseEvent)
5555

5656

57+
@pytest.mark.asyncio
58+
async def test_correct_get_id():
59+
@default_drive.make_event
60+
async def a(event: EventInput, global_ctx):
61+
return 1
62+
63+
assert default_drive.get_event_from_id(a.id) == a
64+
65+
5766
@pytest.mark.asyncio
5867
async def test_order():
5968
@default_drive.make_event

tests/test_dynamic_run.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import pytest
2+
from drive_events import default_drive, EventInput
3+
from drive_events.types import ReturnBehavior, _SpecialEventReturn
4+
from drive_events.dynamic import goto_events, abort_this
5+
6+
7+
class DeliberateExcepion(Exception):
8+
pass
9+
10+
11+
def test_special_event_init():
12+
with pytest.raises(TypeError):
13+
_SpecialEventReturn("fool", 1)
14+
15+
16+
@pytest.mark.asyncio
17+
async def test_abort():
18+
@default_drive.make_event
19+
async def a(event: EventInput, global_ctx):
20+
assert global_ctx == {"test_ctx": 1}
21+
return abort_this()
22+
23+
@default_drive.listen_groups([a])
24+
async def b(event: EventInput, global_ctx):
25+
assert False, "should not be called"
26+
27+
result = await default_drive.invoke_event(a, None, {"test_ctx": 1})
28+
print(result)
29+
30+
31+
@pytest.mark.asyncio
32+
async def test_goto():
33+
call_a_count = 0
34+
35+
@default_drive.make_event
36+
async def a(event: EventInput, global_ctx):
37+
nonlocal call_a_count
38+
if call_a_count == 0:
39+
assert event is None
40+
elif call_a_count == 1:
41+
assert event.behavior == ReturnBehavior.GOTO
42+
assert event.group_name == "$goto"
43+
assert event.results == {b.id: 2}
44+
return abort_this()
45+
else:
46+
raise ValueError("should not be called more than twice")
47+
call_a_count += 1
48+
return 1
49+
50+
@default_drive.listen_groups([a])
51+
async def b(event: EventInput, global_ctx):
52+
return goto_events([a], 2)
53+
54+
@default_drive.listen_groups([b])
55+
async def c(event: EventInput, global_ctx):
56+
assert False, "should not be called"
57+
58+
result = await default_drive.invoke_event(a, None, {"test_ctx": 1})
59+
print(result)

0 commit comments

Comments
 (0)