ag-ui-langgraph 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,21 @@
1
+ Metadata-Version: 2.3
2
+ Name: ag-ui-langgraph
3
+ Version: 0.0.1
4
+ Summary:
5
+ Author: Ran Shem Tov
6
+ Author-email: ran@copilotkit.ai
7
+ Requires-Python: >=3.10,<3.14
8
+ Classifier: Programming Language :: Python :: 3
9
+ Classifier: Programming Language :: Python :: 3.10
10
+ Classifier: Programming Language :: Python :: 3.11
11
+ Classifier: Programming Language :: Python :: 3.12
12
+ Classifier: Programming Language :: Python :: 3.13
13
+ Provides-Extra: fastapi
14
+ Requires-Dist: ag-ui-protocol (==0.1.7)
15
+ Requires-Dist: fastapi (>=0.115.12,<0.116.0) ; extra == "fastapi"
16
+ Requires-Dist: langchain (>=0.3.0)
17
+ Requires-Dist: langchain-core (>=0.3.0)
18
+ Requires-Dist: langgraph (>=0.3.25,<=0.5.0)
19
+ Description-Content-Type: text/markdown
20
+
21
+
File without changes
@@ -0,0 +1,35 @@
1
+ from .agent import LangGraphAgent
2
+ from .types import (
3
+ LangGraphEventTypes,
4
+ CustomEventNames,
5
+ State,
6
+ SchemaKeys,
7
+ MessageInProgress,
8
+ RunMetadata,
9
+ MessagesInProgressRecord,
10
+ ToolCall,
11
+ BaseLangGraphPlatformMessage,
12
+ LangGraphPlatformResultMessage,
13
+ LangGraphPlatformActionExecutionMessage,
14
+ LangGraphPlatformMessage,
15
+ PredictStateTool
16
+ )
17
+ from .endpoint import add_langgraph_fastapi_endpoint
18
+
19
# Public API of the ag-ui-langgraph package: the agent, its type aliases,
# and the FastAPI endpoint helper.
__all__ = [
    "LangGraphAgent",
    "LangGraphEventTypes",
    "CustomEventNames",
    "State",
    "SchemaKeys",
    "MessageInProgress",
    "RunMetadata",
    "MessagesInProgressRecord",
    "ToolCall",
    "BaseLangGraphPlatformMessage",
    "LangGraphPlatformResultMessage",
    "LangGraphPlatformActionExecutionMessage",
    "LangGraphPlatformMessage",
    "PredictStateTool",
    "add_langgraph_fastapi_endpoint"
]
@@ -0,0 +1,682 @@
1
+ import uuid
2
+ import json
3
+ from typing import Optional, List, Any, Union, AsyncGenerator, Generator
4
+
5
+ from langgraph.graph.state import CompiledStateGraph
6
+ from langchain.schema import BaseMessage, SystemMessage
7
+ from langchain_core.runnables import RunnableConfig, ensure_config
8
+ from langchain_core.messages import HumanMessage
9
+ from langgraph.types import Command
10
+
11
+ from .types import (
12
+ State,
13
+ LangGraphPlatformMessage,
14
+ MessagesInProgressRecord,
15
+ SchemaKeys,
16
+ MessageInProgress,
17
+ RunMetadata,
18
+ LangGraphEventTypes,
19
+ CustomEventNames,
20
+ LangGraphReasoning
21
+ )
22
+ from .utils import (
23
+ agui_messages_to_langchain,
24
+ DEFAULT_SCHEMA_KEYS,
25
+ filter_object_by_schema_keys,
26
+ get_stream_payload_input,
27
+ langchain_messages_to_agui,
28
+ resolve_reasoning_content,
29
+ resolve_message_content,
30
+ camel_to_snake
31
+ )
32
+
33
+ from ag_ui.core import (
34
+ EventType,
35
+ CustomEvent,
36
+ MessagesSnapshotEvent,
37
+ RawEvent,
38
+ RunAgentInput,
39
+ RunErrorEvent,
40
+ RunFinishedEvent,
41
+ RunStartedEvent,
42
+ StateDeltaEvent,
43
+ StateSnapshotEvent,
44
+ StepFinishedEvent,
45
+ StepStartedEvent,
46
+ TextMessageContentEvent,
47
+ TextMessageEndEvent,
48
+ TextMessageStartEvent,
49
+ ToolCallArgsEvent,
50
+ ToolCallEndEvent,
51
+ ToolCallStartEvent,
52
+ ThinkingTextMessageStartEvent,
53
+ ThinkingTextMessageContentEvent,
54
+ ThinkingTextMessageEndEvent,
55
+ ThinkingStartEvent,
56
+ ThinkingEndEvent,
57
+ )
58
+ from ag_ui.encoder import EventEncoder
59
+
60
+ ProcessedEvents = Union[
61
+ TextMessageStartEvent,
62
+ TextMessageContentEvent,
63
+ TextMessageEndEvent,
64
+ ToolCallStartEvent,
65
+ ToolCallArgsEvent,
66
+ ToolCallEndEvent,
67
+ StateSnapshotEvent,
68
+ StateDeltaEvent,
69
+ MessagesSnapshotEvent,
70
+ RawEvent,
71
+ CustomEvent,
72
+ RunStartedEvent,
73
+ RunFinishedEvent,
74
+ RunErrorEvent,
75
+ StepStartedEvent,
76
+ StepFinishedEvent,
77
+ ]
78
+
79
class LangGraphAgent:
    """Runs a compiled LangGraph graph for AG-UI requests and streams protocol events."""

    def __init__(self, *, name: str, graph: CompiledStateGraph, description: Optional[str] = None, config: Union[Optional[RunnableConfig], dict] = None):
        """Wrap *graph* as an AG-UI agent.

        :param name: Agent name reported to clients.
        :param graph: The compiled LangGraph state graph to execute.
        :param description: Optional human-readable description.
        :param config: Base RunnableConfig (or plain dict) merged into every run.
        """
        self.name = name
        self.description = description
        self.graph = graph
        self.config = config or {}
        # Tracks the message currently being streamed, keyed by run id.
        self.messages_in_process: MessagesInProgressRecord = {}
        # Metadata of the run in flight; None when no run is active.
        self.active_run: Optional[RunMetadata] = None
        # State keys always included regardless of the graph's declared schemas.
        self.constant_schema_keys = ['messages', 'tools']
88
+
89
    def _dispatch_event(self, event: ProcessedEvents) -> str:
        # Pass-through hook: events are emitted unencoded here; the FastAPI
        # endpoint layer applies an EventEncoder on top of this stream.
        return event # Fallback if no encoder
91
+
92
    async def run(self, input: RunAgentInput) -> AsyncGenerator[str, None]:
        """Execute one AG-UI run request, yielding protocol events.

        Keys of ``forwarded_props`` are converted from camelCase to snake_case
        before being handed to the graph.
        """
        forwarded_props = {}
        if hasattr(input, "forwarded_props") and input.forwarded_props:
            forwarded_props = {
                camel_to_snake(k): v for k, v in input.forwarded_props.items()
            }
        # NOTE(review): pydantic-v1-style .copy(update=...) — presumably the
        # ag_ui models still accept it; confirm against the pinned pydantic.
        async for event_str in self._handle_stream_events(input.copy(update={"forwarded_props": forwarded_props})):
            yield event_str
100
+
101
    async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[str, None]:
        """Drive one full graph run and yield AG-UI lifecycle/step/state events."""
        thread_id = input.thread_id or str(uuid.uuid4())
        self.active_run = {
            "id": input.run_id,
            "thread_id": thread_id,
            "thinking_process": None,
        }

        messages = input.messages or []
        forwarded_props = input.forwarded_props
        node_name_input = forwarded_props.get('node_name', None) if forwarded_props else None

        self.active_run["manually_emitted_state"] = None
        self.active_run["node_name"] = node_name_input
        # "__end__" is a sentinel, not a resumable node.
        if self.active_run["node_name"] == "__end__":
            self.active_run["node_name"] = None

        config = ensure_config(self.config.copy() if self.config else {})
        config["configurable"] = {**(config.get('configurable', {})), "thread_id": thread_id}

        agent_state = await self.graph.aget_state(config)
        # Resume ("continue") only when the caller named a concrete node to resume from.
        self.active_run["mode"] = "continue" if thread_id and self.active_run.get("node_name") != "__end__" and self.active_run.get("node_name") else "start"
        prepared_stream_response = await self.prepare_stream(input=input, agent_state=agent_state, config=config)

        yield self._dispatch_event(
            RunStartedEvent(type=EventType.RUN_STARTED, thread_id=thread_id, run_id=self.active_run["id"])
        )

        langchain_messages = agui_messages_to_langchain(messages)
        non_system_messages = [msg for msg in langchain_messages if not isinstance(msg, SystemMessage)]

        # More stored messages than incoming ones means the client rewound
        # history; regenerate from the checkpoint before the last user message.
        if len(agent_state.values.get("messages", [])) > len(non_system_messages):
            # Find the last user message by working backwards from the last message
            last_user_message = None
            for i in range(len(langchain_messages) - 1, -1, -1):
                if isinstance(langchain_messages[i], HumanMessage):
                    last_user_message = langchain_messages[i]
                    break

            if last_user_message:
                prepared_stream_response = await self.prepare_regenerate_stream(
                    input=input,
                    message_checkpoint=last_user_message,
                    config=config
                )

        state = prepared_stream_response["state"]
        stream = prepared_stream_response["stream"]
        config = prepared_stream_response["config"]
        events_to_dispatch = prepared_stream_response.get('events_to_dispatch', None)

        # prepare_stream may short-circuit (pending interrupt with no resume input):
        # emit its canned events and stop.
        if events_to_dispatch is not None and len(events_to_dispatch) > 0:
            for event in events_to_dispatch:
                yield self._dispatch_event(event)
            return

        should_exit = False
        current_graph_state = state
        async for event in stream:
            if event["event"] == "error":
                yield self._dispatch_event(
                    RunErrorEvent(type=EventType.RUN_ERROR, message=event["data"]["message"], raw_event=event)
                )
                break

            current_node_name = event.get("metadata", {}).get("langgraph_node")
            event_type = event.get("event")
            self.active_run["id"] = event.get("run_id")
            exiting_node = False

            # Node produced output: fold it into the tracked graph state.
            if event_type == "on_chain_end" and isinstance(
                event.get("data", {}).get("output"), dict
            ):
                current_graph_state.update(event["data"]["output"])
                exiting_node = self.active_run["node_name"] == current_node_name

            should_exit = should_exit or (
                event_type == "on_custom_event" and
                event["name"] == "exit"
            )

            # Emit step boundaries whenever the active node changes.
            if current_node_name and current_node_name != self.active_run.get("node_name"):
                if self.active_run["node_name"] and self.active_run["node_name"] != node_name_input:
                    yield self._dispatch_event(
                        StepFinishedEvent(type=EventType.STEP_FINISHED, step_name=self.active_run["node_name"])
                    )
                    self.active_run["node_name"] = None

                yield self._dispatch_event(
                    StepStartedEvent(type=EventType.STEP_STARTED, step_name=current_node_name)
                )
                self.active_run["node_name"] = current_node_name

            # Manually emitted state (custom event) takes precedence over tracked state.
            updated_state = self.active_run.get("manually_emitted_state") or current_graph_state
            has_state_diff = updated_state != state
            # Snapshot only between streamed messages to avoid tearing partial output.
            if exiting_node or (has_state_diff and not self.get_message_in_progress(self.active_run["id"])):
                state = updated_state
                self.active_run["prev_node_name"] = self.active_run["node_name"]
                current_graph_state.update(updated_state)
                yield self._dispatch_event(
                    StateSnapshotEvent(
                        type=EventType.STATE_SNAPSHOT,
                        snapshot=self.get_state_snapshot(state),
                        raw_event=event,
                    )
                )

            yield self._dispatch_event(
                RawEvent(type=EventType.RAW, event=event)
            )

            async for single_event in self._handle_single_event(event, state):
                yield single_event

        # Stream finished: reconcile final state, interrupts, and step bookkeeping.
        state = await self.graph.aget_state(config)

        tasks = state.tasks if len(state.tasks) > 0 else None
        interrupts = tasks[0].interrupts if tasks else []

        writes = state.metadata.get("writes", {}) or {}
        node_name = self.active_run["node_name"] if interrupts else next(iter(writes), None)
        next_nodes = state.next or ()
        is_end_node = len(next_nodes) == 0 and not interrupts

        node_name = "__end__" if is_end_node else node_name

        for interrupt in interrupts:
            yield self._dispatch_event(
                CustomEvent(
                    type=EventType.CUSTOM,
                    name=LangGraphEventTypes.OnInterrupt.value,
                    value=json.dumps(interrupt.value) if not isinstance(interrupt.value, str) else interrupt.value,
                    raw_event=interrupt,
                )
            )

        if self.active_run.get("node_name") != node_name:
            yield self._dispatch_event(
                StepFinishedEvent(type=EventType.STEP_FINISHED, step_name=self.active_run["node_name"])
            )
            self.active_run["node_name"] = node_name
            yield self._dispatch_event(
                StepStartedEvent(type=EventType.STEP_STARTED, step_name=self.active_run["node_name"])
            )

        state_values = state.values if state.values else state
        yield self._dispatch_event(
            StateSnapshotEvent(type=EventType.STATE_SNAPSHOT, snapshot=self.get_state_snapshot(state_values))
        )

        yield self._dispatch_event(
            MessagesSnapshotEvent(
                type=EventType.MESSAGES_SNAPSHOT,
                messages=langchain_messages_to_agui(state_values.get("messages", [])),
            )
        )

        yield self._dispatch_event(
            StepFinishedEvent(type=EventType.STEP_FINISHED, step_name=self.active_run["node_name"])
        )
        self.active_run["node_name"] = None

        yield self._dispatch_event(
            RunFinishedEvent(type=EventType.RUN_FINISHED, thread_id=thread_id, run_id=self.active_run["id"])
        )
        self.active_run = None
267
+
268
+
269
    async def prepare_stream(self, input: RunAgentInput, agent_state: State, config: RunnableConfig):
        """Merge request state into the graph and build the event stream.

        Returns a dict with ``stream``/``state``/``config``; when the graph has
        pending interrupts and no resume input was forwarded, returns canned
        ``events_to_dispatch`` instead and no stream.
        """
        state_input = input.state or {}
        messages = input.messages or []
        tools = input.tools or []
        forwarded_props = input.forwarded_props or {}
        thread_id = input.thread_id

        state_input["messages"] = agent_state.values.get("messages", [])
        self.active_run["current_graph_state"] = agent_state.values
        langchain_messages = agui_messages_to_langchain(messages)
        state = self.langgraph_default_merge_state(state_input, langchain_messages, tools)
        self.active_run["current_graph_state"].update(state)
        config["configurable"]["thread_id"] = thread_id
        interrupts = agent_state.tasks[0].interrupts if agent_state.tasks and len(agent_state.tasks) > 0 else []
        has_active_interrupts = len(interrupts) > 0
        # Resume value supplied by the client to answer a pending interrupt.
        resume_input = forwarded_props.get('command', {}).get('resume', None)

        events_to_dispatch = []
        if has_active_interrupts and not resume_input:
            # Can't run while an interrupt is unanswered: replay the interrupt
            # to the client inside a minimal started/finished envelope.
            events_to_dispatch.append(
                RunStartedEvent(type=EventType.RUN_STARTED, thread_id=thread_id, run_id=self.active_run["id"])
            )

            for interrupt in interrupts:
                events_to_dispatch.append(
                    CustomEvent(
                        type=EventType.CUSTOM,
                        name=LangGraphEventTypes.OnInterrupt.value,
                        value=json.dumps(interrupt.value) if not isinstance(interrupt.value, str) else interrupt.value,
                        raw_event=interrupt,
                    )
                )

            events_to_dispatch.append(
                RunFinishedEvent(type=EventType.RUN_FINISHED, thread_id=thread_id, run_id=self.active_run["id"])
            )
            return {
                "stream": None,
                "state": None,
                "config": None,
                "events_to_dispatch": events_to_dispatch,
            }

        # Resuming an existing thread: write merged state back as the resume node.
        if self.active_run["mode"] == "continue":
            await self.graph.aupdate_state(config, state, as_node=self.active_run.get("node_name"))

        self.active_run["schema_keys"] = self.get_schema_keys(config)

        if resume_input:
            stream_input = Command(resume=resume_input)
        else:
            payload_input = get_stream_payload_input(
                mode=self.active_run["mode"],
                state=state,
                schema_keys=self.active_run["schema_keys"],
            )
            stream_input = {**forwarded_props, **payload_input} if payload_input else None

        return {
            "stream": self.graph.astream_events(stream_input, config, version="v2"),
            "state": state,
            "config": config
        }
332
+
333
    async def prepare_regenerate_stream( # pylint: disable=too-many-arguments
        self,
        input: RunAgentInput,
        message_checkpoint: HumanMessage,
        config: RunnableConfig
    ):
        """Time-travel: fork the thread just before *message_checkpoint* and re-run.

        Used when the client's message list is shorter than stored history
        (i.e. the user edited/regenerated an earlier message).
        """
        tools = input.tools or []
        thread_id = input.thread_id

        time_travel_checkpoint = await self.get_checkpoint_before_message(message_checkpoint.id, thread_id)
        # NOTE(review): get_checkpoint_before_message raises rather than
        # returning None in the visible code paths — this guard looks defensive.
        if time_travel_checkpoint is None:
            return None

        # Fork the thread at the checkpoint; subsequent writes land on the fork.
        fork = await self.graph.aupdate_state(
            time_travel_checkpoint.config,
            time_travel_checkpoint.values,
            as_node=time_travel_checkpoint.next[0] if time_travel_checkpoint.next else "__start__"
        )

        stream_input = self.langgraph_default_merge_state(time_travel_checkpoint.values, [message_checkpoint], tools)
        stream = self.graph.astream_events(stream_input, fork, version="v2")

        return {
            "stream": stream,
            "state": time_travel_checkpoint.values,
            "config": config
        }
360
+
361
+ def get_message_in_progress(self, run_id: str) -> Optional[MessageInProgress]:
362
+ return self.messages_in_process.get(run_id)
363
+
364
+ def set_message_in_progress(self, run_id: str, data: MessageInProgress):
365
+ current_message_in_progress = self.messages_in_process.get(run_id, {})
366
+ self.messages_in_process[run_id] = {
367
+ **current_message_in_progress,
368
+ **data,
369
+ }
370
+
371
    def get_schema_keys(self, config) -> SchemaKeys:
        """Collect the graph's input/output/config schema property names.

        The constant keys (``messages``, ``tools``) are always appended to the
        input/output lists; introspection failures fall back to just those.
        """
        try:
            input_schema = self.graph.get_input_jsonschema(config)
            output_schema = self.graph.get_output_jsonschema(config)
            config_schema = self.graph.config_schema().schema()

            input_schema_keys = list(input_schema["properties"].keys()) if "properties" in input_schema else []
            output_schema_keys = list(output_schema["properties"].keys()) if "properties" in output_schema else []
            config_schema_keys = list(config_schema["properties"].keys()) if "properties" in config_schema else []

            return {
                "input": [*input_schema_keys, *self.constant_schema_keys],
                "output": [*output_schema_keys, *self.constant_schema_keys],
                "config": config_schema_keys,
            }
        except Exception:
            # Best-effort: graphs without typed schemas can fail introspection.
            return {
                "input": self.constant_schema_keys,
                "output": self.constant_schema_keys,
                "config": [],
            }
392
+
393
+ def langgraph_default_merge_state(self, state: State, messages: List[BaseMessage], tools: Any) -> State:
394
+ if messages and isinstance(messages[0], SystemMessage):
395
+ messages = messages[1:]
396
+
397
+ existing_messages: List[LangGraphPlatformMessage] = state.get("messages", [])
398
+ existing_message_ids = {msg.id for msg in existing_messages}
399
+
400
+ new_messages = [msg for msg in messages if msg.id not in existing_message_ids]
401
+
402
+ tools_as_dicts = []
403
+ if tools:
404
+ for tool in tools:
405
+ if hasattr(tool, "model_dump"):
406
+ tools_as_dicts.append(tool.model_dump())
407
+ elif hasattr(tool, "dict"):
408
+ tools_as_dicts.append(tool.dict())
409
+ else:
410
+ tools_as_dicts.append(tool)
411
+
412
+ return {
413
+ **state,
414
+ "messages": new_messages,
415
+ "tools": [*state.get("tools", []), *tools_as_dicts],
416
+ }
417
+
418
+ def get_state_snapshot(self, state: State) -> State:
419
+ schema_keys = self.active_run["schema_keys"]
420
+ if schema_keys and schema_keys.get("output"):
421
+ state = filter_object_by_schema_keys(state, [*DEFAULT_SCHEMA_KEYS, *schema_keys["output"]])
422
+ return state
423
+
424
+ async def _handle_single_event(self, event: Any, state: State) -> AsyncGenerator[str, None]:
425
+ event_type = event.get("event")
426
+ if event_type == LangGraphEventTypes.OnChatModelStream:
427
+ should_emit_messages = event["metadata"].get("emit-messages", True)
428
+ should_emit_tool_calls = event["metadata"].get("emit-tool-calls", True)
429
+
430
+ if event["data"]["chunk"].response_metadata.get('finish_reason', None):
431
+ return
432
+
433
+ current_stream = self.get_message_in_progress(self.active_run["id"])
434
+ has_current_stream = bool(current_stream and current_stream.get("id"))
435
+ tool_call_data = event["data"]["chunk"].tool_call_chunks[0] if event["data"]["chunk"].tool_call_chunks else None
436
+ predict_state_metadata = event["metadata"].get("predict_state", [])
437
+ tool_call_used_to_predict_state = False
438
+ if tool_call_data and tool_call_data.get("name") and predict_state_metadata:
439
+ tool_call_used_to_predict_state = any(
440
+ predict_tool.get("tool") == tool_call_data["name"]
441
+ for predict_tool in predict_state_metadata
442
+ )
443
+
444
+ is_tool_call_start_event = not has_current_stream and tool_call_data and tool_call_data.get("name")
445
+ is_tool_call_args_event = has_current_stream and current_stream.get("tool_call_id") and tool_call_data and tool_call_data.get("args")
446
+ is_tool_call_end_event = has_current_stream and current_stream.get("tool_call_id") and not tool_call_data
447
+
448
+ reasoning_data = resolve_reasoning_content(event["data"]["chunk"]) if event["data"]["chunk"] else None
449
+ message_content = resolve_message_content(event["data"]["chunk"].content) if event["data"]["chunk"] and event["data"]["chunk"].content else None
450
+ is_message_content_event = tool_call_data is None and message_content
451
+ is_message_end_event = has_current_stream and not current_stream.get("tool_call_id") and not is_message_content_event
452
+
453
+ if reasoning_data:
454
+ self.handle_thinking_event(reasoning_data)
455
+ return
456
+
457
+ if reasoning_data is None and self.active_run.get('thinking_process', None) is not None:
458
+ yield self._dispatch_event(
459
+ ThinkingTextMessageEndEvent(
460
+ type=EventType.THINKING_TEXT_MESSAGE_END,
461
+ )
462
+ )
463
+ yield self._dispatch_event(
464
+ ThinkingEndEvent(
465
+ type=EventType.THINKING_END,
466
+ )
467
+ )
468
+ self.active_run["thinking_process"] = None
469
+
470
+ if tool_call_used_to_predict_state:
471
+ yield self._dispatch_event(
472
+ CustomEvent(
473
+ type=EventType.CUSTOM,
474
+ name="PredictState",
475
+ value=predict_state_metadata,
476
+ raw_event=event
477
+ )
478
+ )
479
+
480
+ if is_tool_call_end_event:
481
+ yield self._dispatch_event(
482
+ ToolCallEndEvent(type=EventType.TOOL_CALL_END, tool_call_id=current_stream["tool_call_id"], raw_event=event)
483
+ )
484
+ self.messages_in_process[self.active_run["id"]] = None
485
+ return
486
+
487
+
488
+ if is_message_end_event:
489
+ yield self._dispatch_event(
490
+ TextMessageEndEvent(type=EventType.TEXT_MESSAGE_END, message_id=current_stream["id"], raw_event=event)
491
+ )
492
+ self.messages_in_process[self.active_run["id"]] = None
493
+ return
494
+
495
+ if is_tool_call_start_event and should_emit_tool_calls:
496
+ yield self._dispatch_event(
497
+ ToolCallStartEvent(
498
+ type=EventType.TOOL_CALL_START,
499
+ tool_call_id=tool_call_data["id"],
500
+ tool_call_name=tool_call_data["name"],
501
+ parent_message_id=event["data"]["chunk"].id,
502
+ raw_event=event,
503
+ )
504
+ )
505
+ self.set_message_in_progress(
506
+ self.active_run["id"],
507
+ MessageInProgress(id=event["data"]["chunk"].id, tool_call_id=tool_call_data["id"], tool_call_name=tool_call_data["name"])
508
+ )
509
+ return
510
+
511
+ if is_tool_call_args_event and should_emit_tool_calls:
512
+ yield self._dispatch_event(
513
+ ToolCallArgsEvent(
514
+ type=EventType.TOOL_CALL_ARGS,
515
+ tool_call_id=current_stream["tool_call_id"],
516
+ delta=tool_call_data["args"],
517
+ raw_event=event
518
+ )
519
+ )
520
+ return
521
+
522
+ if is_message_content_event and should_emit_messages:
523
+ if bool(current_stream and current_stream.get("id")) == False:
524
+ yield self._dispatch_event(
525
+ TextMessageStartEvent(
526
+ type=EventType.TEXT_MESSAGE_START,
527
+ role="assistant",
528
+ message_id=event["data"]["chunk"].id,
529
+ raw_event=event,
530
+ )
531
+ )
532
+ self.set_message_in_progress(
533
+ self.active_run["id"],
534
+ MessageInProgress(
535
+ id=event["data"]["chunk"].id,
536
+ tool_call_id=None,
537
+ tool_call_name=None
538
+ )
539
+ )
540
+ current_stream = self.get_message_in_progress(self.active_run["id"])
541
+
542
+ yield self._dispatch_event(
543
+ TextMessageContentEvent(
544
+ type=EventType.TEXT_MESSAGE_CONTENT,
545
+ message_id=current_stream["id"],
546
+ delta=event["data"]["chunk"].content,
547
+ raw_event=event,
548
+ )
549
+ )
550
+ return
551
+
552
+ elif event_type == LangGraphEventTypes.OnChatModelEnd:
553
+ if self.get_message_in_progress(self.active_run["id"]) and self.get_message_in_progress(self.active_run["id"]).get("tool_call_id"):
554
+ resolved = self._dispatch_event(
555
+ ToolCallEndEvent(type=EventType.TOOL_CALL_END, tool_call_id=self.get_message_in_progress(self.active_run["id"])["tool_call_id"], raw_event=event)
556
+ )
557
+ if resolved:
558
+ self.messages_in_process[self.active_run["id"]] = None
559
+ yield resolved
560
+ elif self.get_message_in_progress(self.active_run["id"]) and self.get_message_in_progress(self.active_run["id"]).get("id"):
561
+ resolved = self._dispatch_event(
562
+ TextMessageEndEvent(type=EventType.TEXT_MESSAGE_END, message_id=self.get_message_in_progress(self.active_run["id"])["id"], raw_event=event)
563
+ )
564
+ if resolved:
565
+ self.messages_in_process[self.active_run["id"]] = None
566
+ yield resolved
567
+
568
+ elif event_type == LangGraphEventTypes.OnCustomEvent:
569
+ if event["name"] == CustomEventNames.ManuallyEmitMessage:
570
+ yield self._dispatch_event(
571
+ TextMessageStartEvent(type=EventType.TEXT_MESSAGE_START, role="assistant", message_id=event["data"]["message_id"], raw_event=event)
572
+ )
573
+ yield self._dispatch_event(
574
+ TextMessageContentEvent(
575
+ type=EventType.TEXT_MESSAGE_CONTENT,
576
+ message_id=event["data"]["message_id"],
577
+ delta=event["data"]["message"],
578
+ raw_event=event,
579
+ )
580
+ )
581
+ yield self._dispatch_event(
582
+ TextMessageEndEvent(type=EventType.TEXT_MESSAGE_END, message_id=event["data"]["message_id"], raw_event=event)
583
+ )
584
+
585
+ elif event["name"] == CustomEventNames.ManuallyEmitToolCall:
586
+ yield self._dispatch_event(
587
+ ToolCallStartEvent(
588
+ type=EventType.TOOL_CALL_START,
589
+ tool_call_id=event["data"]["id"],
590
+ tool_call_name=event["data"]["name"],
591
+ parent_message_id=event["data"]["id"],
592
+ raw_event=event,
593
+ )
594
+ )
595
+ yield self._dispatch_event(
596
+ ToolCallArgsEvent(type=EventType.TOOL_CALL_ARGS, tool_call_id=event["data"]["id"], delta=event["data"]["args"], raw_event=event)
597
+ )
598
+ yield self._dispatch_event(
599
+ ToolCallEndEvent(type=EventType.TOOL_CALL_END, tool_call_id=event["data"]["id"], raw_event=event)
600
+ )
601
+
602
+ elif event["name"] == CustomEventNames.ManuallyEmitState:
603
+ self.active_run["manually_emitted_state"] = event["data"]
604
+ yield self._dispatch_event(
605
+ StateSnapshotEvent(type=EventType.STATE_SNAPSHOT, snapshot=self.get_state_snapshot(state), raw_event=event)
606
+ )
607
+
608
+ yield self._dispatch_event(
609
+ CustomEvent(type=EventType.CUSTOM, name=event["name"], value=event["data"], raw_event=event)
610
+ )
611
+
612
    def handle_thinking_event(self, reasoning_data: LangGraphReasoning) -> Generator[str, Any, str | None]:
        """Yield thinking start/content/end events for one reasoning chunk.

        NOTE(review): this is a generator — callers must iterate it for any
        event to be emitted.
        """
        # Ignore malformed reasoning payloads.
        if not reasoning_data or "type" not in reasoning_data or "text" not in reasoning_data:
            return ""

        thinking_step_index = reasoning_data.get("index")

        # A new reasoning step started: close out the previous one first.
        # NOTE(review): the truthiness check on .get("index") treats index 0 as
        # "no index" — confirm indices never start at 0, or this never closes.
        if (self.active_run.get("thinking_process") and
            self.active_run["thinking_process"].get("index") and
            self.active_run["thinking_process"]["index"] != thinking_step_index):

            if self.active_run["thinking_process"].get("type"):
                yield self._dispatch_event(
                    ThinkingTextMessageEndEvent(
                        type=EventType.THINKING_TEXT_MESSAGE_END,
                    )
                )
            yield self._dispatch_event(
                ThinkingEndEvent(
                    type=EventType.THINKING_END,
                )
            )
            self.active_run["thinking_process"] = None

        # Open a thinking block if none is active.
        if not self.active_run.get("thinking_process"):
            yield self._dispatch_event(
                ThinkingStartEvent(
                    type=EventType.THINKING_START,
                )
            )
            self.active_run["thinking_process"] = {
                "index": thinking_step_index
            }

        # Open a thinking text stream whenever the chunk type changes.
        if self.active_run["thinking_process"].get("type") != reasoning_data["type"]:
            yield self._dispatch_event(
                ThinkingTextMessageStartEvent(
                    type=EventType.THINKING_TEXT_MESSAGE_START,
                )
            )
            self.active_run["thinking_process"]["type"] = reasoning_data["type"]

        if self.active_run["thinking_process"].get("type"):
            yield self._dispatch_event(
                ThinkingTextMessageContentEvent(
                    type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
                    delta=reasoning_data["text"]
                )
            )
660
+
661
+ async def get_checkpoint_before_message(self, message_id: str, thread_id: str):
662
+ if not thread_id:
663
+ raise ValueError("Missing thread_id in config")
664
+
665
+ history_list = []
666
+ async for snapshot in self.graph.aget_state_history({"configurable": {"thread_id": thread_id}}):
667
+ history_list.append(snapshot)
668
+
669
+ history_list.reverse()
670
+ for idx, snapshot in enumerate(history_list):
671
+ messages = snapshot.values.get("messages", [])
672
+ if any(getattr(m, "id", None) == message_id for m in messages):
673
+ if idx == 0:
674
+ # No snapshot before this
675
+ # Return synthetic "empty before" version
676
+ empty_snapshot = snapshot
677
+ empty_snapshot.values["messages"] = []
678
+ return empty_snapshot
679
+ return history_list[idx - 1] # return one snapshot *before* the one that includes the message
680
+
681
+ raise ValueError("Message ID not found in history")
682
+
@@ -0,0 +1,27 @@
1
+ from fastapi import FastAPI, HTTPException, Request
2
+ from fastapi.responses import StreamingResponse
3
+
4
+ from ag_ui.core.types import RunAgentInput
5
+ from ag_ui.encoder import EventEncoder
6
+
7
+ from .agent import LangGraphAgent
8
+
9
def add_langgraph_fastapi_endpoint(app: FastAPI, agent: LangGraphAgent, path: str = "/"):
    """Register a POST endpoint on *app* that runs *agent* and streams AG-UI events.

    :param app: FastAPI application to mount the route on.
    :param agent: The LangGraphAgent to execute for each request.
    :param path: Route path for the endpoint (default "/").
    """

    @app.post(path)
    async def langgraph_agent_endpoint(input_data: RunAgentInput, request: Request):
        # Get the accept header from the request
        accept_header = request.headers.get("accept")

        # Create an event encoder to properly format SSE events
        encoder = EventEncoder(accept=accept_header)

        async def event_generator():
            # Encode each agent event for the negotiated content type (e.g. SSE).
            async for event in agent.run(input_data):
                yield encoder.encode(event)

        return StreamingResponse(
            event_generator(),
            media_type=encoder.get_content_type()
        )
@@ -0,0 +1,91 @@
1
+ from typing import TypedDict, Optional, List, Any, Dict, Union, Literal
2
+ from typing_extensions import NotRequired
3
+ from enum import Enum
4
+
5
class LangGraphEventTypes(str, Enum):
    """Event names emitted by LangGraph's astream_events stream."""
    OnChainStart = "on_chain_start"
    OnChainStream = "on_chain_stream"
    OnChainEnd = "on_chain_end"
    OnChatModelStart = "on_chat_model_start"
    OnChatModelStream = "on_chat_model_stream"
    OnChatModelEnd = "on_chat_model_end"
    OnToolStart = "on_tool_start"
    OnToolEnd = "on_tool_end"
    OnCustomEvent = "on_custom_event"
    OnInterrupt = "on_interrupt"
16
+
17
class CustomEventNames(str, Enum):
    """Custom event names the agent recognizes for manual emission control."""
    ManuallyEmitMessage = "manually_emit_message"
    ManuallyEmitToolCall = "manually_emit_tool_call"
    ManuallyEmitState = "manually_emit_state"
    Exit = "exit"
22
+
23
+ State = Dict[str, Any]
24
+
25
+ SchemaKeys = TypedDict("SchemaKeys", {
26
+ "input": NotRequired[Optional[List[str]]],
27
+ "output": NotRequired[Optional[List[str]]],
28
+ "config": NotRequired[Optional[List[str]]]
29
+ })
30
+
31
+ ThinkingProcess = TypedDict("ThinkingProcess", {
32
+ "index": int,
33
+ "type": NotRequired[Optional[Literal['text']]],
34
+ })
35
+
36
+ MessageInProgress = TypedDict("MessageInProgress", {
37
+ "id": str,
38
+ "tool_call_id": NotRequired[Optional[str]],
39
+ "tool_call_name": NotRequired[Optional[str]]
40
+ })
41
+
42
+ RunMetadata = TypedDict("RunMetadata", {
43
+ "id": str,
44
+ "schema_keys": NotRequired[Optional[SchemaKeys]],
45
+ "node_name": NotRequired[Optional[str]],
46
+ "prev_node_name": NotRequired[Optional[str]],
47
+ "exiting_node": NotRequired[bool],
48
+ "manually_emitted_state": NotRequired[Optional[State]],
49
+ "thread_id": NotRequired[Optional[ThinkingProcess]],
50
+ "thinking_process": NotRequired[Optional[str]]
51
+ })
52
+
53
# Maps run id -> the message currently being streamed for that run
# (None once the message has finished).
MessagesInProgressRecord = Dict[str, Optional[MessageInProgress]]

# A fully-formed tool call parsed from model output.
ToolCall = TypedDict("ToolCall", {
    "id": str,
    "name": str,
    "args": Dict[str, Any]
})
60
+
61
class BaseLangGraphPlatformMessage(TypedDict):
    """Shared shape of messages stored on the LangGraph platform."""
    content: str
    role: str
    additional_kwargs: NotRequired[Dict[str, Any]]
    type: str
    id: str

class LangGraphPlatformResultMessage(BaseLangGraphPlatformMessage):
    """Tool-result message; links the result back to the originating tool call."""
    tool_call_id: str
    name: str

class LangGraphPlatformActionExecutionMessage(BaseLangGraphPlatformMessage):
    """Assistant message requesting one or more tool executions."""
    tool_calls: List[ToolCall]

# Union of every message shape the platform can return.
LangGraphPlatformMessage = Union[
    LangGraphPlatformActionExecutionMessage,
    LangGraphPlatformResultMessage,
    BaseLangGraphPlatformMessage,
]
80
+
81
# Declares that a tool's streamed argument should be mirrored into a state key
# (consumed via the "predict_state" metadata in LangGraphAgent).
PredictStateTool = TypedDict("PredictStateTool", {
    "tool": str,
    "state_key": str,
    "tool_argument": str
})

# One chunk of model reasoning ("thinking") output.
LangGraphReasoning = TypedDict("LangGraphReasoning", {
    "type": str,
    "text": str,
    "index": int
})
@@ -0,0 +1,179 @@
1
+ import json
2
+ import re
3
+ from typing import List, Any, Dict, Union
4
+
5
+ from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage, ToolMessage
6
+ from ag_ui.core import (
7
+ Message as AGUIMessage,
8
+ UserMessage as AGUIUserMessage,
9
+ AssistantMessage as AGUIAssistantMessage,
10
+ SystemMessage as AGUISystemMessage,
11
+ ToolMessage as AGUIToolMessage,
12
+ ToolCall as AGUIToolCall,
13
+ FunctionCall as AGUIFunctionCall,
14
+ )
15
+ from .types import State, SchemaKeys, LangGraphReasoning
16
+
17
# State keys always kept when filtering by a configured schema (see
# get_stream_payload_input), regardless of the schema's own key list.
DEFAULT_SCHEMA_KEYS = ["tools"]
18
+
19
def filter_object_by_schema_keys(obj: Dict[str, Any], schema_keys: List[str]) -> Dict[str, Any]:
    """Return a copy of ``obj`` restricted to ``schema_keys``; ``{}`` for falsy input."""
    if not obj:
        return {}
    allowed = set(schema_keys)
    return {key: value for key, value in obj.items() if key in allowed}
23
+
24
def get_stream_payload_input(
    *,
    mode: str,
    state: State,
    schema_keys: SchemaKeys,
) -> Union[State, None]:
    """Build the input payload for a graph stream request.

    Only a "start" run sends the state; any other mode sends None. When
    input schema keys are configured, the state is filtered down to those
    keys plus DEFAULT_SCHEMA_KEYS.
    """
    if mode != "start":
        return None
    payload = state
    allowed = schema_keys.get("input") if schema_keys else None
    if payload and allowed:
        payload = filter_object_by_schema_keys(payload, [*DEFAULT_SCHEMA_KEYS, *allowed])
    return payload
34
+
35
def stringify_if_needed(item: Any) -> str:
    """Coerce ``item`` to a string: pass strings through, map None to '',
    and JSON-encode everything else."""
    if isinstance(item, str):
        return item
    if item is None:
        return ''
    return json.dumps(item)
41
+
42
def langchain_messages_to_agui(messages: List[BaseMessage]) -> List[AGUIMessage]:
    """Convert LangChain messages into their AG-UI protocol equivalents.

    Handles Human/AI/System/Tool messages; raises TypeError for any other
    message type. Content is normalized through resolve_message_content and
    then stringified, so non-string content becomes a JSON string.
    """
    agui_messages: List[AGUIMessage] = []
    for message in messages:
        if isinstance(message, HumanMessage):
            agui_messages.append(AGUIUserMessage(
                id=str(message.id),
                role="user",
                content=stringify_if_needed(resolve_message_content(message.content)),
                name=message.name,
            ))
        elif isinstance(message, AIMessage):
            # Translate LangChain tool-call dicts into AG-UI tool calls;
            # the args dict is serialized into a JSON argument string.
            tool_calls = None
            if message.tool_calls:
                tool_calls = [
                    AGUIToolCall(
                        id=str(tc["id"]),
                        type="function",
                        function=AGUIFunctionCall(
                            name=tc["name"],
                            arguments=json.dumps(tc.get("args", {})),
                        ),
                    )
                    for tc in message.tool_calls
                ]

            agui_messages.append(AGUIAssistantMessage(
                id=str(message.id),
                role="assistant",
                content=stringify_if_needed(resolve_message_content(message.content)),
                tool_calls=tool_calls,
                name=message.name,
            ))
        elif isinstance(message, SystemMessage):
            agui_messages.append(AGUISystemMessage(
                id=str(message.id),
                role="system",
                content=stringify_if_needed(resolve_message_content(message.content)),
                name=message.name,
            ))
        elif isinstance(message, ToolMessage):
            # Note: no name= here — only the linking tool_call_id is forwarded.
            agui_messages.append(AGUIToolMessage(
                id=str(message.id),
                role="tool",
                content=stringify_if_needed(resolve_message_content(message.content)),
                tool_call_id=message.tool_call_id,
            ))
        else:
            raise TypeError(f"Unsupported message type: {type(message)}")
    return agui_messages
91
+
92
def agui_messages_to_langchain(messages: List[AGUIMessage]) -> List[BaseMessage]:
    """Convert AG-UI protocol messages into LangChain messages.

    Dispatches on the message ``role`` ("user"/"assistant"/"system"/"tool");
    raises ValueError for any other role.
    """
    langchain_messages = []
    for message in messages:
        role = message.role
        if role == "user":
            langchain_messages.append(HumanMessage(
                id=message.id,
                content=message.content,
                name=message.name,
            ))
        elif role == "assistant":
            # Rebuild LangChain-style tool-call dicts; AG-UI stores the
            # arguments as a JSON string, so parse them back into a dict.
            tool_calls = []
            if hasattr(message, "tool_calls") and message.tool_calls:
                for tc in message.tool_calls:
                    # NOTE(review): tc.function.name below is unguarded while the
                    # args line checks hasattr(tc, "function") — looks like tool
                    # calls are assumed to always carry a function; confirm.
                    tool_calls.append({
                        "id": tc.id,
                        "name": tc.function.name,
                        "args": json.loads(tc.function.arguments) if hasattr(tc, "function") and tc.function.arguments else {},
                        "type": "tool_call",
                    })
            langchain_messages.append(AIMessage(
                id=message.id,
                content=message.content or "",  # AIMessage content must not be None
                tool_calls=tool_calls,
                name=message.name,
            ))
        elif role == "system":
            langchain_messages.append(SystemMessage(
                id=message.id,
                content=message.content,
                name=message.name,
            ))
        elif role == "tool":
            langchain_messages.append(ToolMessage(
                id=message.id,
                content=message.content,
                tool_call_id=message.tool_call_id,
            ))
        else:
            raise ValueError(f"Unsupported message role: {role}")
    return langchain_messages
133
+
134
def resolve_reasoning_content(chunk: Any) -> LangGraphReasoning | None:
    """Extract streamed reasoning ("thinking") content from a message chunk.

    Handles two provider layouts:
    - Anthropic: ``chunk.content`` is a list whose dict items may carry a
      ``thinking`` field.
    - OpenAI: reasoning arrives under ``chunk.additional_kwargs["reasoning"]``
      as a list of summary entries with a ``text`` field.

    Returns None when the chunk carries no reasoning content.
    """
    content = chunk.content
    if not content:
        return None

    # Anthropic reasoning response
    if isinstance(content, list) and content and content[0]:
        first = content[0]
        # Fix: plain-text streams can deliver list items as bare strings; only
        # dict blocks can carry "thinking". Previously a non-dict first item
        # raised AttributeError on .get(); now it falls through to the
        # additional_kwargs check below.
        if isinstance(first, dict):
            if not first.get("thinking"):
                return None
            return LangGraphReasoning(
                text=first["thinking"],
                type="text",
                index=first.get("index", 0)
            )

    # OpenAI reasoning response
    if hasattr(chunk, "additional_kwargs"):
        reasoning = chunk.additional_kwargs.get("reasoning", {})
        summary = reasoning.get("summary", [])
        if summary:
            data = summary[0]
            if not data or not data.get("text"):
                return None
            return LangGraphReasoning(
                type="text",
                text=data["text"],
                index=data.get("index", 0)
            )

    return None
164
+
165
+ def resolve_message_content(content: Any) -> str | None:
166
+ if not content:
167
+ return None
168
+
169
+ if isinstance(content, str):
170
+ return content
171
+
172
+ if isinstance(content, list) and content:
173
+ content_text = next((c.get("text") for c in content if isinstance(c, dict) and c.get("type") == "text"), None)
174
+ return content_text
175
+
176
+ return None
177
+
178
def camel_to_snake(name):
    """Convert a CamelCase identifier to snake_case.

    Inserts an underscore before every non-leading ASCII uppercase letter,
    then lowercases (so acronym runs split: "HTTPServer" -> "h_t_t_p_server").
    """
    boundary = re.compile(r'(?<!^)(?=[A-Z])')
    return boundary.sub('_', name).lower()
@@ -0,0 +1,27 @@
1
+ [tool.poetry]
2
+ name = "ag-ui-langgraph"
3
+ version = "0.0.1"
4
+ description = ""
5
+ authors = ["Ran Shem Tov <ran@copilotkit.ai>"]
6
+ readme = "README.md"
7
+ exclude = [
8
+ "ag_ui_langgraph/examples/**",
9
+ ]
10
+
11
+ [tool.poetry.dependencies]
12
+ python = "<3.14,>=3.10"
13
+ ag-ui-protocol = "==0.1.7"
14
+ fastapi = { version = "^0.115.12", optional = true }
15
+ langchain = ">=0.3.0"
16
+ langchain-core = ">=0.3.0"
17
+ langgraph = ">=0.3.25,<=0.5.0"
18
+
19
+ [tool.poetry.extras]
20
+ fastapi = ["fastapi"]
21
+
22
+ [build-system]
23
+ requires = ["poetry-core"]
24
+ build-backend = "poetry.core.masonry.api"
25
+
26
+ [tool.poetry.scripts]
27
+ dev = "ag_ui_langgraph.dojo:main"