Skip to content

Commit 1f11ecf

Browse files
committed
Merge branch 'release/4.0.0-beta1'
2 parents 04a8492 + 919ad20 commit 1f11ecf

File tree

262 files changed

+1246
-339
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

262 files changed

+1246
-339
lines changed

README.md

Lines changed: 76 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,10 @@
2323
- [x] Async support
2424
- [x] Streaming support
2525
- [x] Subgraph support
26-
- [ ] Checkpoints (_save and replay feature_)
27-
- [ ] Threads (_checkpointing of multiple different runs_)
28-
- [ ] Update state (_interact with the state directly and update it_)
29-
- [ ] Breakpoints (_pause and resume feature_)
30-
- [ ] Graph migration
26+
- [x] Checkpoints (_save and replay feature_)
27+
- [x] Threads (_checkpointing of multiple different runs_)
28+
- [x] Update state (_interact with the state directly and update it_)
29+
- [x] Breakpoints (_pause and resume feature_)
3130
- [ ] Graph visualization
3231
- [ ] [PlantUML]
3332
- [ ] [Mermaid]
@@ -183,12 +182,83 @@ In the [LangChainDemo](LangChainDemo) project, you can find the porting of [Agen
183182

184183
let app = try workflow.compile()
185184

186-
let result = try await app.invoke(inputs: [ "input": input, "chat_history": [] ])
185+
let result = try await app.invoke(inputs: .args([ "input": input, "chat_history": [] ]) )
187186

188187
print( result )
189188

190189
```
191190

191+
## User interruptions
192+
193+
LangGraph support pause and resume of execution. You must provide a `CheckpointSaver` through `CompileConfig` and a `ThreadId`(aka Session) through `RunnableConfig` to enable it. Below an example
194+
195+
```swift
196+
197+
// Create a memory-based checkpoint saver
198+
let saver = MemoryCheckpointSaver()
199+
200+
// Build the workflow with an initial state
201+
let workflow = try StateGraph { BinaryOpState($0) }
202+
203+
// Add node "agent_1" that returns "add1": 37
204+
.addNode("agent_1") { state in
205+
print( "agent_1", state )
206+
return ["add1": 37]
207+
}
208+
// Add node "agent_2" that returns "add2": 10
209+
.addNode("agent_2") { state in
210+
print( "agent_2", state )
211+
return ["add2": 10]
212+
}
213+
// Add node "sum" that sums add1 and add2
214+
.addNode("sum") { state in
215+
print( "sum", state )
216+
guard let add1 = state.add1, let add2 = state.add2 else {
217+
throw CompiledGraphError.executionError("agent state is not valid! expect 'add1', 'add2'")
218+
}
219+
return ["result": add1 + add2 ]
220+
}
221+
// Define the edges between nodes
222+
.addEdge(sourceId: "agent_1", targetId: "agent_2")
223+
.addEdge(sourceId: "agent_2", targetId: "sum")
224+
.addEdge( sourceId: START, targetId: "agent_1")
225+
.addEdge(sourceId: "sum", targetId: END )
226+
227+
// Compile the workflow, instructing it to interrupt before executing "sum"
228+
let app = try workflow.compile( config: CompileConfig(checkpointSaver: saver, interruptionsBefore: ["sum"]) )
229+
230+
// Start a new run in a different thread
231+
let runnableConfig = RunnableConfig( threadId: "T1" )
232+
233+
let initValue:( lastState:BinaryOpState?, nodes:[String]) = ( nil, [] )
234+
235+
let result = try await app.stream( .args([:]), config: runnableConfig ).reduce( initValue, { partialResult, output in
236+
print( output )
237+
return ( output.state, partialResult.1 + [output.node ] )
238+
})
239+
240+
// This run is also interrupted before "sum"
241+
#expect( dictionaryOfAnyEqual( ["add1": 37, "add2": 10 ], result.lastState!.data) )
242+
243+
// Resume the third run with updated state: change add2 from 10 to 13
244+
let lastCheckpoint2 = try #require( saver.last( config: runnableConfig ) )
245+
var runnableConfig2 = runnableConfig.with { $0.checkpointId = lastCheckpoint2.id }
246+
runnableConfig2 = try await app.updateState(config: runnableConfig2, values: ["add2": 13] )
247+
248+
// Resume and complete execution with updated value
249+
let initValue2:( lastState:BinaryOpState?, nodes:[String]) = ( nil, [] )
250+
let result2 = try await app.stream( .resume, config: runnableConfig2 ).reduce( initValue2, { partialResult, output in
251+
print( output )
252+
return ( output.state, partialResult.1 + [output.node ] )
253+
})
254+
255+
// Verify that "result" now reflects the updated input
256+
#expect( dictionaryOfAnyEqual( ["add1": 37, "add2": 13, "result": 50 ], result2.lastState!.data) )
257+
258+
```
259+
260+
261+
192262
# References
193263

194264
* [AI Agent on iOS with LangGraph for Swift](https://dev.to/bsorrentino/ai-agent-on-ios-with-langgraph-for-swift-1740)

Sources/LangGraph/Checkpoints.swift

Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
import Foundation
2+
3+
4+
public enum CheckpointError: Error, LocalizedError {
5+
6+
case missingThreadIdentifier(String)
7+
8+
public var errorDescription: String? {
9+
switch self {
10+
case .missingThreadIdentifier(let message):
11+
return message
12+
}
13+
}
14+
15+
}
16+
17+
/// Represents a checkpoint of an agent state.
18+
///
19+
/// The checkpoint is an immutable object that holds an agent state
20+
/// and a string that represents the next state.
21+
/// It is designed to be serializable and restorable.
22+
public struct Checkpoint : Equatable {
23+
24+
public static func == (lhs: Checkpoint, rhs: Checkpoint) -> Bool {
25+
lhs.id == rhs.id
26+
}
27+
28+
let id: UUID
29+
var state: [String: Any]
30+
var nodeId: String
31+
var nextNodeId: String
32+
33+
public init( state: [String: Any], nodeId: String, nextNodeId: String) {
34+
self.id = UUID()
35+
self.state = state
36+
self.nodeId = nodeId
37+
self.nextNodeId = nextNodeId
38+
}
39+
40+
func updateState(values: PartialAgentState, channels: Channels) throws -> Self {
41+
42+
var editable = self
43+
editable.state = try LangGraph.updateState(currentState: self.state , partialState: values , channels: channels)
44+
return editable
45+
}
46+
}
47+
48+
extension Checkpoint: Codable {
49+
private enum CodingKeys: String, CodingKey {
50+
case id
51+
case state
52+
case nodeId
53+
case nextNodeId
54+
}
55+
56+
public init(from decoder: any Decoder) throws {
57+
let container = try decoder.container(keyedBy: CodingKeys.self)
58+
59+
id = try container.decode(UUID.self, forKey: .id)
60+
state = try container.decode([String: LangGraph.AnyDecodable].self, forKey: .state).mapValues { $0.value }
61+
nodeId = try container.decode(String.self, forKey: .nodeId)
62+
nextNodeId = try container.decode(String.self, forKey: .nextNodeId)
63+
}
64+
65+
public func encode(to encoder: any Encoder) throws {
66+
67+
var container = encoder.container(keyedBy: CodingKeys.self)
68+
try container.encode(id, forKey: .id)
69+
try container.encode( toEncodableStateData(data: state), forKey: .state)
70+
try container.encode(nodeId, forKey: .nodeId)
71+
try container.encode(nextNodeId, forKey: .nextNodeId)
72+
73+
}
74+
}
75+
76+
public struct Tag {
77+
let threadId: String
78+
let checkpoints: AnyCollection<Checkpoint>
79+
}
80+
81+
/// A protocol that defines an interface for saving and retrieving `Checkpoint` instances.
82+
///
83+
/// Conforming types manage checkpoint data associated with a specific `RunnableConfig`,
84+
/// allowing retrieval, update, listing, and release of checkpoints. This protocol enables
85+
/// persistence strategies to be customized for different runtime environments or threading models.
86+
public protocol CheckpointSaver {
87+
/// Returns all checkpoints associated with the provided `RunnableConfig`.
88+
///
89+
/// - Parameter config: The configuration that identifies the context or thread.
90+
/// - Returns: A collection of `Checkpoint` instances.
91+
func list(config: RunnableConfig) -> AnyCollection<Checkpoint>;
92+
93+
/// Retrieves a specific checkpoint based on the provided `RunnableConfig`.
94+
///
95+
/// If `checkpointId` is set in the configuration, the corresponding checkpoint is returned.
96+
/// Otherwise, returns the latest checkpoint for the context.
97+
///
98+
/// - Parameter config: The configuration that may include a specific checkpoint identifier.
99+
/// - Returns: The requested `Checkpoint` instance, or `nil` if not found.
100+
func get(config: RunnableConfig) -> Checkpoint?;
101+
102+
/// Persists a new checkpoint or updates an existing one based on the configuration.
103+
///
104+
/// If the `checkpointId` is set in the configuration, the checkpoint with that ID is updated.
105+
/// Otherwise, the new checkpoint is added to the thread's checkpoint stack.
106+
///
107+
/// - Parameters:
108+
/// - config: The configuration identifying the context or thread.
109+
/// - checkpoint: The checkpoint to be saved or updated.
110+
/// - Returns: A modified `RunnableConfig` reflecting the new checkpoint ID.
111+
/// - Throws: An error if the checkpoint cannot be persisted.
112+
func put(config: RunnableConfig, checkpoint: Checkpoint) throws -> RunnableConfig;
113+
114+
/// Releases all checkpoints associated with the provided `RunnableConfig`.
115+
///
116+
/// This method is responsible for cleanup of checkpoints tied to a specific context or thread.
117+
///
118+
/// - Parameter config: The configuration identifying the context or thread.
119+
/// - Returns: A `Tag` representing the final state of the thread's checkpoints.
120+
/// - Throws: An error if the release operation fails.
121+
func release(config: RunnableConfig) throws -> Tag;
122+
}
123+
124+
extension CheckpointSaver {
125+
@inline(__always) func THREAD_ID_DEFAULT() -> String { "$default" };
126+
127+
@inline(__always) func last( config: RunnableConfig ) -> Checkpoint? {
128+
list( config: config ).first
129+
}
130+
}
131+
132+
struct Stack<T> {
133+
var elements: [T] = []
134+
135+
var isEmpty: Bool {
136+
return elements.isEmpty
137+
}
138+
139+
var count: Int {
140+
return elements.count
141+
}
142+
143+
mutating func push(_ value: T) {
144+
elements.append(value)
145+
}
146+
147+
mutating func pop() -> T? {
148+
return elements.popLast()
149+
}
150+
151+
func peek() -> T? {
152+
return elements.last
153+
}
154+
}
155+
156+
157+
extension Stack: Sequence {
158+
public func makeIterator() -> IndexingIterator<Array<T>> {
159+
return elements.reversed().makeIterator()
160+
}
161+
}
162+
163+
extension Stack {
164+
165+
subscript(id: UUID) -> T? where T == Checkpoint {
166+
get {
167+
return elements.first { $0.id == id }
168+
}
169+
set(newValue) {
170+
guard let newValue else {
171+
fatalError( "Cannot set checkpoint with id \(id) to nil")
172+
}
173+
174+
guard let index = elements.firstIndex(where: { $0.id == id }) else {
175+
fatalError( "Cannot find checkpoint with id \(id)" )
176+
}
177+
178+
elements[index] = newValue
179+
}
180+
}
181+
}
182+
183+
public class MemoryCheckpointSaver: CheckpointSaver {
184+
var checkpointsByThread: [String: Stack<Checkpoint>] = [:];
185+
186+
private func checkpoints(config: RunnableConfig ) -> Stack<Checkpoint> {
187+
let threadId = config.threadId ?? THREAD_ID_DEFAULT()
188+
189+
guard let result = self.checkpointsByThread[threadId] else {
190+
let result = Stack<Checkpoint>();
191+
self.checkpointsByThread[threadId] = result;
192+
return result;
193+
}
194+
return result
195+
196+
}
197+
198+
private func updateCheckpoint( config: RunnableConfig, checkpoints: Stack<Checkpoint> ) {
199+
let threadId = config.threadId ?? THREAD_ID_DEFAULT()
200+
201+
self.checkpointsByThread[threadId] = checkpoints
202+
}
203+
204+
public func get(config: RunnableConfig) -> Checkpoint? {
205+
let checkpoints = checkpoints(config: config);
206+
207+
guard let checkpointId = config.checkpointId else {
208+
return checkpoints.peek()
209+
}
210+
211+
return checkpoints.first(where: { $0.id == checkpointId })
212+
}
213+
214+
public func put(config: RunnableConfig, checkpoint: Checkpoint) throws -> RunnableConfig {
215+
var checkpoints = checkpoints(config: config);
216+
217+
if let checkpointId = config.checkpointId, checkpointId == checkpoint.id {
218+
219+
checkpoints[checkpointId] = checkpoint
220+
}
221+
222+
checkpoints.push(checkpoint)
223+
224+
updateCheckpoint( config: config, checkpoints: checkpoints )
225+
226+
return config.with {
227+
$0.checkpointId = checkpoint.id
228+
}
229+
230+
}
231+
232+
@discardableResult
233+
public func release(config: RunnableConfig) throws -> Tag {
234+
let threadId = config.threadId ?? THREAD_ID_DEFAULT()
235+
236+
guard let removedCheckpoints = self.checkpointsByThread.removeValue(forKey: threadId) else {
237+
throw CheckpointError.missingThreadIdentifier("No checkpoint found for thread \(threadId)")
238+
}
239+
240+
return Tag( threadId: threadId, checkpoints: AnyCollection(removedCheckpoints.elements) )
241+
}
242+
243+
public func list(config: RunnableConfig) -> AnyCollection<Checkpoint> {
244+
let checkpoints = checkpoints(config: config);
245+
246+
return AnyCollection(checkpoints.elements.reversed())
247+
}
248+
}
249+

0 commit comments

Comments
 (0)