Skip to content

Feat: allow conditional cycles in graph #571

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 16 additions & 22 deletions src/strands/multiagent/graph.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""Directed Acyclic Graph (DAG) Multi-Agent Pattern Implementation.
"""Directed Graph Multi-Agent Pattern Implementation.

This module provides a deterministic DAG-based agent orchestration system where
This module provides a deterministic graph-based agent orchestration system where
agents or MultiAgentBase instances (like Swarm or Graph) are nodes in a graph,
executed according to edge dependencies, with output from one node passed as input
to connected nodes.

Key Features:
- Agents and MultiAgentBase instances (Swarm, Graph, etc.) as graph nodes
- Deterministic execution order based on DAG structure
- Deterministic execution order based on graph structure
- Cycles are permitted only if at least one edge is conditional.
- Output propagation along edges
- Topological sort for execution ordering
- Clear dependency management
Expand Down Expand Up @@ -253,18 +254,21 @@ def has_cycle_from(node_id: str) -> bool:
colors[node_id] = GRAY
# Check all outgoing edges for cycles
for edge in self.edges:
if edge.from_node.node_id == node_id and has_cycle_from(edge.to_node.node_id):
if not edge.condition and edge.from_node.node_id == node_id and has_cycle_from(edge.to_node.node_id):
return True
colors[node_id] = BLACK
return False

# Check for cycles from each unvisited node
if any(colors[node_id] == WHITE and has_cycle_from(node_id) for node_id in self.nodes):
raise ValueError("Graph contains cycles - must be a directed acyclic graph")
raise ValueError(
"Graph contains unconditional cycles — it must either be a Directed Acyclic Graph (DAG) "
"or contain at least one conditional edge within each cycle."
)


class Graph(MultiAgentBase):
"""Directed Acyclic Graph multi-agent orchestration."""
"""Directed graph multi-agent orchestration."""

def __init__(self, nodes: dict[str, GraphNode], edges: set[GraphEdge], entry_points: set[GraphNode]) -> None:
"""Initialize Graph."""
Expand Down Expand Up @@ -332,45 +336,35 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None:

async def _execute_graph(self) -> None:
"""Unified execution flow with conditional routing."""
ready_nodes = list(self.entry_points)
ready_nodes = set(self.entry_points)

while ready_nodes:
current_batch = ready_nodes.copy()
ready_nodes.clear()

# Execute current batch of ready nodes concurrently
tasks = [
asyncio.create_task(self._execute_node(node))
for node in current_batch
if node not in self.state.completed_nodes
]
tasks = [asyncio.create_task(self._execute_node(node)) for node in ready_nodes]

for task in tasks:
await task

# Find newly ready nodes after batch execution
ready_nodes.extend(self._find_newly_ready_nodes())
ready_nodes = self._find_newly_ready_nodes(ready_nodes)

def _find_newly_ready_nodes(self) -> list["GraphNode"]:
def _find_newly_ready_nodes(self, executed_nodes: set["GraphNode"]) -> set["GraphNode"]:
"""Find nodes that became ready after the last execution."""
newly_ready = []
for _node_id, node in self.nodes.items():
if (
node not in self.state.completed_nodes
node.dependencies & executed_nodes
and node not in self.state.failed_nodes
and self._is_node_ready_with_conditions(node)
):
newly_ready.append(node)
return newly_ready
return set(newly_ready)

def _is_node_ready_with_conditions(self, node: GraphNode) -> bool:
"""Check if a node is ready considering conditional edges."""
# Get incoming edges to this node
incoming_edges = [edge for edge in self.edges if edge.to_node == node]

if not incoming_edges:
return node in self.entry_points

# Check if at least one incoming edge condition is satisfied
for edge in incoming_edges:
if edge.from_node in self.state.completed_nodes:
Expand Down
20 changes: 19 additions & 1 deletion tests/strands/multiagent/test_graph.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import pytest
Expand Down Expand Up @@ -378,9 +379,26 @@ def test_graph_builder_validation():
builder.add_edge("c", "a") # Creates cycle
builder.set_entry_point("a")

with pytest.raises(ValueError, match="Graph contains cycles"):
with pytest.raises(
ValueError,
match=re.escape(
"Graph contains unconditional cycles — it must either be a Directed Acyclic Graph (DAG) "
"or contain at least one conditional edge within each cycle."
),
):
builder.build()

# Test cycle detection with back edge condition
builder = GraphBuilder()
builder.add_node(agent1, "a")
builder.add_node(agent2, "b")
builder.add_node(create_mock_agent("agent3"), "c")
builder.add_edge("a", "b")
builder.add_edge("b", "c")
builder.add_edge("c", "a", condition=lambda _: True) # Creates cycle
builder.set_entry_point("a")
builder.build()

# Test auto-detection of entry points
builder = GraphBuilder()
builder.add_node(agent1, "entry")
Expand Down