Skip to content

Commit db38481

Browse files
committed
feat: invoking events with streaming
1 parent 7ec3d18 commit db38481

File tree

2 files changed

+24
-9
lines changed

2 files changed

+24
-9
lines changed

drive_flow/core.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -119,13 +119,26 @@ async def run_event(current_event: BaseEvent, current_event_input: Any):
119119
)
120120
queue.append((cand_event, build_input))
121121

122-
while len(queue):
123-
this_batch_events = queue[:max_async_events] if max_async_events else queue
124-
queue = queue[max_async_events:] if max_async_events else []
125-
logger.debug(
126-
f"Running a turn with {len(this_batch_events)} event tasks, left {len(queue)} event tasks in queue"
127-
)
128-
await asyncio.gather(
129-
*[run_event(*run_event_input) for run_event_input in this_batch_events]
130-
)
122+
tasks = set()
123+
try:
124+
while len(queue) or len(tasks):
125+
this_batch_events = (
126+
queue[:max_async_events] if max_async_events else queue
127+
)
128+
queue = queue[max_async_events:] if max_async_events else []
129+
new_tasks = {
130+
asyncio.create_task(run_event(*run_event_input))
131+
for run_event_input in this_batch_events
132+
}
133+
tasks.update(new_tasks)
134+
done, tasks = await asyncio.wait(
135+
tasks, return_when=asyncio.FIRST_COMPLETED
136+
)
137+
for task in done:
138+
await task # Handle any exceptions
139+
except asyncio.CancelledError:
140+
for task in tasks:
141+
task.cancel()
142+
await asyncio.gather(*tasks, return_exceptions=True)
143+
raise
131144
return this_run_ctx

tests/test_run.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import pytest
23
from drive_flow import default_drive, EventInput
34
from drive_flow.types import ReturnBehavior
@@ -70,6 +71,7 @@ async def a(event: EventInput, global_ctx):
7071

7172
@default_drive.listen_group([start])
7273
async def b(event: EventInput, global_ctx):
74+
await asyncio.sleep(0.2)
7375
return 2
7476

7577
@default_drive.listen_group([a, b])

0 commit comments

Comments
 (0)