ag-ui-langgraph 0.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ag_ui_langgraph-0.0.1/PKG-INFO +21 -0
- ag_ui_langgraph-0.0.1/README.md +0 -0
- ag_ui_langgraph-0.0.1/ag_ui_langgraph/__init__.py +35 -0
- ag_ui_langgraph-0.0.1/ag_ui_langgraph/agent.py +682 -0
- ag_ui_langgraph-0.0.1/ag_ui_langgraph/endpoint.py +27 -0
- ag_ui_langgraph-0.0.1/ag_ui_langgraph/types.py +91 -0
- ag_ui_langgraph-0.0.1/ag_ui_langgraph/utils.py +179 -0
- ag_ui_langgraph-0.0.1/pyproject.toml +27 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: ag-ui-langgraph
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary:
|
|
5
|
+
Author: Ran Shem Tov
|
|
6
|
+
Author-email: ran@copilotkit.ai
|
|
7
|
+
Requires-Python: >=3.10,<3.14
|
|
8
|
+
Classifier: Programming Language :: Python :: 3
|
|
9
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
13
|
+
Provides-Extra: fastapi
|
|
14
|
+
Requires-Dist: ag-ui-protocol (==0.1.7)
|
|
15
|
+
Requires-Dist: fastapi (>=0.115.12,<0.116.0) ; extra == "fastapi"
|
|
16
|
+
Requires-Dist: langchain (>=0.3.0)
|
|
17
|
+
Requires-Dist: langchain-core (>=0.3.0)
|
|
18
|
+
Requires-Dist: langgraph (>=0.3.25,<=0.5.0)
|
|
19
|
+
Description-Content-Type: text/markdown
|
|
20
|
+
|
|
21
|
+
|
|
File without changes
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
from .agent import LangGraphAgent
|
|
2
|
+
from .types import (
|
|
3
|
+
LangGraphEventTypes,
|
|
4
|
+
CustomEventNames,
|
|
5
|
+
State,
|
|
6
|
+
SchemaKeys,
|
|
7
|
+
MessageInProgress,
|
|
8
|
+
RunMetadata,
|
|
9
|
+
MessagesInProgressRecord,
|
|
10
|
+
ToolCall,
|
|
11
|
+
BaseLangGraphPlatformMessage,
|
|
12
|
+
LangGraphPlatformResultMessage,
|
|
13
|
+
LangGraphPlatformActionExecutionMessage,
|
|
14
|
+
LangGraphPlatformMessage,
|
|
15
|
+
PredictStateTool
|
|
16
|
+
)
|
|
17
|
+
from .endpoint import add_langgraph_fastapi_endpoint
|
|
18
|
+
|
|
19
|
+
# Public API of the package, grouped by origin module (order preserved).
__all__ = [
    # agent
    "LangGraphAgent",
    # types
    "LangGraphEventTypes",
    "CustomEventNames",
    "State",
    "SchemaKeys",
    "MessageInProgress",
    "RunMetadata",
    "MessagesInProgressRecord",
    "ToolCall",
    "BaseLangGraphPlatformMessage",
    "LangGraphPlatformResultMessage",
    "LangGraphPlatformActionExecutionMessage",
    "LangGraphPlatformMessage",
    "PredictStateTool",
    # endpoint
    "add_langgraph_fastapi_endpoint",
]
|
|
@@ -0,0 +1,682 @@
|
|
|
1
|
+
import uuid
|
|
2
|
+
import json
|
|
3
|
+
from typing import Optional, List, Any, Union, AsyncGenerator, Generator
|
|
4
|
+
|
|
5
|
+
from langgraph.graph.state import CompiledStateGraph
|
|
6
|
+
from langchain.schema import BaseMessage, SystemMessage
|
|
7
|
+
from langchain_core.runnables import RunnableConfig, ensure_config
|
|
8
|
+
from langchain_core.messages import HumanMessage
|
|
9
|
+
from langgraph.types import Command
|
|
10
|
+
|
|
11
|
+
from .types import (
|
|
12
|
+
State,
|
|
13
|
+
LangGraphPlatformMessage,
|
|
14
|
+
MessagesInProgressRecord,
|
|
15
|
+
SchemaKeys,
|
|
16
|
+
MessageInProgress,
|
|
17
|
+
RunMetadata,
|
|
18
|
+
LangGraphEventTypes,
|
|
19
|
+
CustomEventNames,
|
|
20
|
+
LangGraphReasoning
|
|
21
|
+
)
|
|
22
|
+
from .utils import (
|
|
23
|
+
agui_messages_to_langchain,
|
|
24
|
+
DEFAULT_SCHEMA_KEYS,
|
|
25
|
+
filter_object_by_schema_keys,
|
|
26
|
+
get_stream_payload_input,
|
|
27
|
+
langchain_messages_to_agui,
|
|
28
|
+
resolve_reasoning_content,
|
|
29
|
+
resolve_message_content,
|
|
30
|
+
camel_to_snake
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
from ag_ui.core import (
|
|
34
|
+
EventType,
|
|
35
|
+
CustomEvent,
|
|
36
|
+
MessagesSnapshotEvent,
|
|
37
|
+
RawEvent,
|
|
38
|
+
RunAgentInput,
|
|
39
|
+
RunErrorEvent,
|
|
40
|
+
RunFinishedEvent,
|
|
41
|
+
RunStartedEvent,
|
|
42
|
+
StateDeltaEvent,
|
|
43
|
+
StateSnapshotEvent,
|
|
44
|
+
StepFinishedEvent,
|
|
45
|
+
StepStartedEvent,
|
|
46
|
+
TextMessageContentEvent,
|
|
47
|
+
TextMessageEndEvent,
|
|
48
|
+
TextMessageStartEvent,
|
|
49
|
+
ToolCallArgsEvent,
|
|
50
|
+
ToolCallEndEvent,
|
|
51
|
+
ToolCallStartEvent,
|
|
52
|
+
ThinkingTextMessageStartEvent,
|
|
53
|
+
ThinkingTextMessageContentEvent,
|
|
54
|
+
ThinkingTextMessageEndEvent,
|
|
55
|
+
ThinkingStartEvent,
|
|
56
|
+
ThinkingEndEvent,
|
|
57
|
+
)
|
|
58
|
+
from ag_ui.encoder import EventEncoder
|
|
59
|
+
|
|
60
|
+
# Every AG-UI event type LangGraphAgent may emit through _dispatch_event.
# Includes the THINKING_* events produced by handle_thinking_event, which the
# alias previously omitted even though they are dispatched the same way.
ProcessedEvents = Union[
    TextMessageStartEvent,
    TextMessageContentEvent,
    TextMessageEndEvent,
    ToolCallStartEvent,
    ToolCallArgsEvent,
    ToolCallEndEvent,
    StateSnapshotEvent,
    StateDeltaEvent,
    MessagesSnapshotEvent,
    RawEvent,
    CustomEvent,
    RunStartedEvent,
    RunFinishedEvent,
    RunErrorEvent,
    StepStartedEvent,
    StepFinishedEvent,
    ThinkingStartEvent,
    ThinkingEndEvent,
    ThinkingTextMessageStartEvent,
    ThinkingTextMessageContentEvent,
    ThinkingTextMessageEndEvent,
]
|
|
78
|
+
|
|
79
|
+
class LangGraphAgent:
    """Expose a compiled LangGraph graph as an AG-UI agent.

    ``run`` streams the graph with ``astream_events(version="v2")`` and
    translates the LangGraph/LangChain events into AG-UI events: run and step
    lifecycle, text messages, tool calls, thinking, state/messages snapshots,
    interrupts and RAW passthrough events.
    """

    def __init__(
        self,
        *,
        name: str,
        graph: CompiledStateGraph,
        description: Optional[str] = None,
        config: Union[Optional[RunnableConfig], dict] = None,
    ):
        """Create an agent.

        Args:
            name: Agent name.
            graph: Compiled LangGraph graph to run.
            description: Optional human-readable description.
            config: Base RunnableConfig. It is copied for every run, and the
                run's thread id is injected into ``configurable``.
        """
        self.name = name
        self.description = description
        self.graph = graph
        self.config = config or {}
        # Open (not yet ended) streamed message / tool call, keyed by active_run["id"].
        self.messages_in_process: MessagesInProgressRecord = {}
        # Bookkeeping for the run in progress; reset to None when a run finishes.
        self.active_run: Optional[RunMetadata] = None
        # Keys always appended to the input/output schema keys.
        self.constant_schema_keys = ['messages', 'tools']

    def _dispatch_event(self, event: ProcessedEvents) -> ProcessedEvents:
        """Return ``event`` unchanged.

        Every emitted event passes through this single hook. Encoding is left to
        the consumer of ``run`` (e.g. the FastAPI endpoint).
        """
        return event

    async def run(self, input: RunAgentInput) -> AsyncGenerator[ProcessedEvents, None]:
        """Run the graph for ``input`` and yield AG-UI event objects.

        Keys of ``input.forwarded_props`` are converted from camelCase to
        snake_case before the input is processed.
        """
        forwarded_props = {}
        if input.forwarded_props:
            forwarded_props = {
                camel_to_snake(k): v for k, v in input.forwarded_props.items()
            }
        normalized_input = input.model_copy(update={"forwarded_props": forwarded_props})
        async for event in self._handle_stream_events(normalized_input):
            yield event

    async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[ProcessedEvents, None]:
        """Drive one graph run and translate its stream into AG-UI events.

        The event sequence is:
        - RUN_STARTED.
        - For each stream event: STEP_STARTED/STEP_FINISHED on node changes,
          STATE_SNAPSHOT when state changes, RAW, and whatever
          ``_handle_single_event`` yields.
        - At the end: CUSTOM OnInterrupt per interrupt, STATE_SNAPSHOT,
          MESSAGES_SNAPSHOT, STEP_FINISHED and RUN_FINISHED.

        If ``prepare_stream`` returns ``events_to_dispatch`` (pending interrupt
        without a resume value), only those events are emitted. They carry
        their own RUN_STARTED/RUN_FINISHED.
        """
        thread_id = input.thread_id or str(uuid.uuid4())
        self.active_run = {
            "id": input.run_id,
            "thread_id": thread_id,
            "thinking_process": None,
        }

        messages = input.messages or []
        forwarded_props = input.forwarded_props
        node_name_input = forwarded_props.get('node_name', None) if forwarded_props else None

        self.active_run["manually_emitted_state"] = None
        # "__end__" means the previous run completed, so there is no node to continue from.
        self.active_run["node_name"] = None if node_name_input == "__end__" else node_name_input

        config = ensure_config(self.config.copy() if self.config else {})
        config["configurable"] = {**(config.get('configurable', {})), "thread_id": thread_id}

        agent_state = await self.graph.aget_state(config)
        # thread_id is always set here and "__end__" was normalized to None above,
        # so "continue" only depends on the client naming a node to resume from.
        self.active_run["mode"] = "continue" if self.active_run["node_name"] else "start"
        prepared_stream_response = await self.prepare_stream(input=input, agent_state=agent_state, config=config)

        langchain_messages = agui_messages_to_langchain(messages)
        non_system_messages = [msg for msg in langchain_messages if not isinstance(msg, SystemMessage)]

        # NOTE(review): prepare_stream() stores agent_state.values as
        # current_graph_state and update()s it in place, so "messages" here is the
        # merged list of *new* messages rather than the persisted thread history.
        # Confirm that this comparison is what is intended.
        if len(agent_state.values.get("messages", [])) > len(non_system_messages):
            # Regenerate from the most recent user message the client sent.
            last_user_message = next(
                (msg for msg in reversed(langchain_messages) if isinstance(msg, HumanMessage)),
                None,
            )
            if last_user_message is not None:
                regenerate_response = await self.prepare_regenerate_stream(
                    input=input,
                    message_checkpoint=last_user_message,
                    config=config,
                )
                # prepare_regenerate_stream may return None (no checkpoint found);
                # keep the regular stream instead of failing on None["state"].
                if regenerate_response is not None:
                    prepared_stream_response = regenerate_response

        events_to_dispatch = prepared_stream_response.get('events_to_dispatch', None)
        if events_to_dispatch:
            # These events already contain RUN_STARTED ... RUN_FINISHED. They are emitted
            # instead of the regular lifecycle, so RUN_STARTED is not sent twice.
            for event in events_to_dispatch:
                yield self._dispatch_event(event)
            self.active_run = None
            return

        state = prepared_stream_response["state"]
        stream = prepared_stream_response["stream"]
        config = prepared_stream_response["config"]

        yield self._dispatch_event(
            RunStartedEvent(type=EventType.RUN_STARTED, thread_id=thread_id, run_id=self.active_run["id"])
        )

        # NOTE(review): current_graph_state is the same dict object as state, so the
        # in-place update() calls below also change state. As a result has_state_diff
        # is only True once a manually emitted state differs. Confirm whether a copy
        # was intended.
        current_graph_state = state
        async for event in stream:
            if event["event"] == "error":
                yield self._dispatch_event(
                    RunErrorEvent(type=EventType.RUN_ERROR, message=event["data"]["message"], raw_event=event)
                )
                # RUN_ERROR ends the run: no STEP_/RUN_FINISHED may follow it.
                self.active_run = None
                return

            current_node_name = event.get("metadata", {}).get("langgraph_node")
            event_type = event.get("event")
            # NOTE(review): this replaces the id with each stream event's run_id, which
            # is what messages_in_process lookups and the final RUN_FINISHED then use
            # (instead of input.run_id). Confirm intended.
            self.active_run["id"] = event.get("run_id")
            exiting_node = False

            if event_type == "on_chain_end" and isinstance(event.get("data", {}).get("output"), dict):
                current_graph_state.update(event["data"]["output"])
                exiting_node = self.active_run["node_name"] == current_node_name

            if current_node_name and current_node_name != self.active_run.get("node_name"):
                if self.active_run["node_name"] and self.active_run["node_name"] != node_name_input:
                    yield self._dispatch_event(
                        StepFinishedEvent(type=EventType.STEP_FINISHED, step_name=self.active_run["node_name"])
                    )
                    self.active_run["node_name"] = None

                yield self._dispatch_event(
                    StepStartedEvent(type=EventType.STEP_STARTED, step_name=current_node_name)
                )
                self.active_run["node_name"] = current_node_name

            updated_state = self.active_run.get("manually_emitted_state") or current_graph_state
            has_state_diff = updated_state != state
            # Snapshot on node exit, or on any state change while no message is streaming.
            if exiting_node or (has_state_diff and not self.get_message_in_progress(self.active_run["id"])):
                state = updated_state
                self.active_run["prev_node_name"] = self.active_run["node_name"]
                current_graph_state.update(updated_state)
                yield self._dispatch_event(
                    StateSnapshotEvent(
                        type=EventType.STATE_SNAPSHOT,
                        snapshot=self.get_state_snapshot(state),
                        raw_event=event,
                    )
                )

            yield self._dispatch_event(
                RawEvent(type=EventType.RAW, event=event)
            )

            async for single_event in self._handle_single_event(event, state):
                yield single_event

        state = await self.graph.aget_state(config)

        interrupts = state.tasks[0].interrupts if state.tasks else []

        writes = (state.metadata or {}).get("writes", {}) or {}
        node_name = self.active_run["node_name"] if interrupts else next(iter(writes), None)
        next_nodes = state.next or ()
        is_end_node = len(next_nodes) == 0 and not interrupts

        node_name = "__end__" if is_end_node else node_name

        for interrupt in interrupts:
            yield self._dispatch_event(
                CustomEvent(
                    type=EventType.CUSTOM,
                    name=LangGraphEventTypes.OnInterrupt.value,
                    value=json.dumps(interrupt.value) if not isinstance(interrupt.value, str) else interrupt.value,
                    raw_event=interrupt,
                )
            )

        if self.active_run.get("node_name") != node_name:
            yield self._dispatch_event(
                StepFinishedEvent(type=EventType.STEP_FINISHED, step_name=self.active_run["node_name"])
            )
            self.active_run["node_name"] = node_name
            yield self._dispatch_event(
                StepStartedEvent(type=EventType.STEP_STARTED, step_name=self.active_run["node_name"])
            )

        # Fall back to an empty dict, not to the StateSnapshot itself: a snapshot
        # has no .get() and is not a state mapping.
        state_values = state.values or {}
        yield self._dispatch_event(
            StateSnapshotEvent(type=EventType.STATE_SNAPSHOT, snapshot=self.get_state_snapshot(state_values))
        )

        yield self._dispatch_event(
            MessagesSnapshotEvent(
                type=EventType.MESSAGES_SNAPSHOT,
                messages=langchain_messages_to_agui(state_values.get("messages", [])),
            )
        )

        yield self._dispatch_event(
            StepFinishedEvent(type=EventType.STEP_FINISHED, step_name=self.active_run["node_name"])
        )
        self.active_run["node_name"] = None

        yield self._dispatch_event(
            RunFinishedEvent(type=EventType.RUN_FINISHED, thread_id=thread_id, run_id=self.active_run["id"])
        )
        self.active_run = None

    async def prepare_stream(self, input: RunAgentInput, agent_state: State, config: RunnableConfig):
        """Merge the client input into the thread state and build the event stream.

        Returns a dict with ``stream`` (an ``astream_events`` iterator), ``state``
        and ``config``.

        If the thread has pending interrupts and no ``command.resume`` was
        forwarded, the graph is not run. In that case ``stream``, ``state`` and
        ``config`` are None, and ``events_to_dispatch`` holds RUN_STARTED, one
        CUSTOM OnInterrupt per interrupt, and RUN_FINISHED.

        In "continue" mode the merged state is written to the thread with
        ``aupdate_state`` as the active node.
        """
        # Copy so that setting "messages" below does not mutate the caller's input.state.
        state_input = dict(input.state or {})
        messages = input.messages or []
        tools = input.tools or []
        forwarded_props = input.forwarded_props or {}
        # Keep the thread id already in config (generated when the input had none)
        # rather than overwriting it with an empty value.
        thread_id = input.thread_id or config.get("configurable", {}).get("thread_id")

        state_input["messages"] = agent_state.values.get("messages", [])
        self.active_run["current_graph_state"] = agent_state.values
        langchain_messages = agui_messages_to_langchain(messages)
        state = self.langgraph_default_merge_state(state_input, langchain_messages, tools)
        self.active_run["current_graph_state"].update(state)
        config["configurable"]["thread_id"] = thread_id
        interrupts = agent_state.tasks[0].interrupts if agent_state.tasks else []
        has_active_interrupts = len(interrupts) > 0
        # `or {}` also covers an explicit "command": null in the forwarded props.
        resume_input = (forwarded_props.get('command') or {}).get('resume', None)

        events_to_dispatch = []
        if has_active_interrupts and not resume_input:
            events_to_dispatch.append(
                RunStartedEvent(type=EventType.RUN_STARTED, thread_id=thread_id, run_id=self.active_run["id"])
            )

            for interrupt in interrupts:
                events_to_dispatch.append(
                    CustomEvent(
                        type=EventType.CUSTOM,
                        name=LangGraphEventTypes.OnInterrupt.value,
                        value=json.dumps(interrupt.value) if not isinstance(interrupt.value, str) else interrupt.value,
                        raw_event=interrupt,
                    )
                )

            events_to_dispatch.append(
                RunFinishedEvent(type=EventType.RUN_FINISHED, thread_id=thread_id, run_id=self.active_run["id"])
            )
            return {
                "stream": None,
                "state": None,
                "config": None,
                "events_to_dispatch": events_to_dispatch,
            }

        if self.active_run["mode"] == "continue":
            await self.graph.aupdate_state(config, state, as_node=self.active_run.get("node_name"))

        self.active_run["schema_keys"] = self.get_schema_keys(config)

        if resume_input:
            stream_input = Command(resume=resume_input)
        else:
            payload_input = get_stream_payload_input(
                mode=self.active_run["mode"],
                state=state,
                schema_keys=self.active_run["schema_keys"],
            )
            stream_input = {**forwarded_props, **payload_input} if payload_input else None

        return {
            "stream": self.graph.astream_events(stream_input, config, version="v2"),
            "state": state,
            "config": config,
        }

    async def prepare_regenerate_stream(  # pylint: disable=too-many-arguments
        self,
        input: RunAgentInput,
        message_checkpoint: HumanMessage,
        config: RunnableConfig
    ):
        """Fork the thread before ``message_checkpoint`` and stream from there.

        The checkpoint preceding the message is re-applied with
        ``aupdate_state``, creating a fork. The graph is then streamed from that
        fork with the message (plus ``input.tools``) merged back in.

        Returns a dict with ``stream``, ``state`` (the checkpoint values) and
        ``config``, or None if no checkpoint was found.
        """
        tools = input.tools or []
        # Same fallback as prepare_stream: use the thread id already in config.
        thread_id = input.thread_id or config.get("configurable", {}).get("thread_id")

        time_travel_checkpoint = await self.get_checkpoint_before_message(message_checkpoint.id, thread_id)
        if time_travel_checkpoint is None:
            return None

        fork = await self.graph.aupdate_state(
            time_travel_checkpoint.config,
            time_travel_checkpoint.values,
            as_node=time_travel_checkpoint.next[0] if time_travel_checkpoint.next else "__start__"
        )

        stream_input = self.langgraph_default_merge_state(time_travel_checkpoint.values, [message_checkpoint], tools)
        stream = self.graph.astream_events(stream_input, fork, version="v2")

        return {
            "stream": stream,
            "state": time_travel_checkpoint.values,
            "config": config
        }

    def get_message_in_progress(self, run_id: str) -> Optional[MessageInProgress]:
        """Return the open message/tool call recorded for ``run_id``, if any."""
        return self.messages_in_process.get(run_id)

    def set_message_in_progress(self, run_id: str, data: MessageInProgress):
        """Merge ``data`` into the record stored for ``run_id``."""
        # Ended messages are stored as None, so `or {}` avoids `**None` on reuse.
        current_message_in_progress = self.messages_in_process.get(run_id) or {}
        self.messages_in_process[run_id] = {
            **current_message_in_progress,
            **data,
        }

    def get_schema_keys(self, config) -> SchemaKeys:
        """Collect property names of the graph's input, output and config schemas.

        The constant keys ("messages", "tools") are appended to the input and
        output key lists. If any schema lookup raises, only the constant keys
        are returned (with no config keys).
        """
        try:
            input_schema = self.graph.get_input_jsonschema(config)
            output_schema = self.graph.get_output_jsonschema(config)
            config_schema = self.graph.config_schema().schema()

            input_schema_keys = list(input_schema["properties"].keys()) if "properties" in input_schema else []
            output_schema_keys = list(output_schema["properties"].keys()) if "properties" in output_schema else []
            config_schema_keys = list(config_schema["properties"].keys()) if "properties" in config_schema else []

            return {
                "input": [*input_schema_keys, *self.constant_schema_keys],
                "output": [*output_schema_keys, *self.constant_schema_keys],
                "config": config_schema_keys,
            }
        except Exception:
            return {
                "input": self.constant_schema_keys,
                "output": self.constant_schema_keys,
                "config": [],
            }

    def langgraph_default_merge_state(self, state: State, messages: List[BaseMessage], tools: Any) -> State:
        """Return ``state`` updated with the incoming messages and tools.

        - A leading SystemMessage is dropped.
        - ``messages`` becomes only those incoming messages whose id is not
          already in ``state["messages"]``.
        - ``tools`` (converted to dicts via ``model_dump``/``dict`` when
          available) are appended to ``state["tools"]``.
        """
        if messages and isinstance(messages[0], SystemMessage):
            messages = messages[1:]

        existing_messages: List[LangGraphPlatformMessage] = state.get("messages", [])
        existing_message_ids = {msg.id for msg in existing_messages}

        new_messages = [msg for msg in messages if msg.id not in existing_message_ids]

        tools_as_dicts = []
        for tool in tools or []:
            if hasattr(tool, "model_dump"):
                tools_as_dicts.append(tool.model_dump())
            elif hasattr(tool, "dict"):
                tools_as_dicts.append(tool.dict())
            else:
                tools_as_dicts.append(tool)

        return {
            **state,
            "messages": new_messages,
            "tools": [*state.get("tools", []), *tools_as_dicts],
        }

    def get_state_snapshot(self, state: State) -> State:
        """Filter ``state`` to the default plus output-schema keys, when known."""
        # .get(): schema_keys is only set by prepare_stream after its interrupt
        # early-return, so it can be missing (e.g. on the regenerate path).
        schema_keys = self.active_run.get("schema_keys")
        if schema_keys and schema_keys.get("output"):
            state = filter_object_by_schema_keys(state, [*DEFAULT_SCHEMA_KEYS, *schema_keys["output"]])
        return state

    async def _handle_single_event(self, event: Any, state: State) -> AsyncGenerator[ProcessedEvents, None]:
        """Translate one LangGraph stream event into AG-UI events.

        Handles three kinds of event:
        - Chat-model stream chunks: text messages, tool calls, thinking and
          PredictState.
        - Chat-model end: closes an open message or tool call.
        - Custom events: manually emitted messages, tool calls and state, plus a
          CUSTOM passthrough for every custom event.
        """
        event_type = event.get("event")
        if event_type == LangGraphEventTypes.OnChatModelStream:
            chunk = event["data"]["chunk"]
            run_id = self.active_run["id"]
            should_emit_messages = event["metadata"].get("emit-messages", True)
            should_emit_tool_calls = event["metadata"].get("emit-tool-calls", True)

            # Chunks carrying a finish_reason are skipped; open streams are closed on OnChatModelEnd.
            if chunk.response_metadata.get('finish_reason', None):
                return

            current_stream = self.get_message_in_progress(run_id)
            has_current_stream = bool(current_stream and current_stream.get("id"))
            tool_call_data = chunk.tool_call_chunks[0] if chunk.tool_call_chunks else None
            predict_state_metadata = event["metadata"].get("predict_state", [])
            tool_call_used_to_predict_state = False
            if tool_call_data and tool_call_data.get("name") and predict_state_metadata:
                tool_call_used_to_predict_state = any(
                    predict_tool.get("tool") == tool_call_data["name"]
                    for predict_tool in predict_state_metadata
                )

            is_tool_call_start_event = not has_current_stream and tool_call_data and tool_call_data.get("name")
            is_tool_call_args_event = has_current_stream and current_stream.get("tool_call_id") and tool_call_data and tool_call_data.get("args")
            is_tool_call_end_event = has_current_stream and current_stream.get("tool_call_id") and not tool_call_data

            reasoning_data = resolve_reasoning_content(chunk) if chunk else None
            message_content = resolve_message_content(chunk.content) if chunk and chunk.content else None
            is_message_content_event = tool_call_data is None and message_content
            is_message_end_event = has_current_stream and not current_stream.get("tool_call_id") and not is_message_content_event

            if reasoning_data:
                # handle_thinking_event is a generator: it must be iterated to emit anything.
                for thinking_event in self.handle_thinking_event(reasoning_data):
                    yield thinking_event
                return

            # Reasoning stopped while a thinking block is open: close it.
            if reasoning_data is None and self.active_run.get('thinking_process', None) is not None:
                yield self._dispatch_event(
                    ThinkingTextMessageEndEvent(
                        type=EventType.THINKING_TEXT_MESSAGE_END,
                    )
                )
                yield self._dispatch_event(
                    ThinkingEndEvent(
                        type=EventType.THINKING_END,
                    )
                )
                self.active_run["thinking_process"] = None

            if tool_call_used_to_predict_state:
                yield self._dispatch_event(
                    CustomEvent(
                        type=EventType.CUSTOM,
                        name="PredictState",
                        value=predict_state_metadata,
                        raw_event=event
                    )
                )

            if is_tool_call_end_event:
                yield self._dispatch_event(
                    ToolCallEndEvent(type=EventType.TOOL_CALL_END, tool_call_id=current_stream["tool_call_id"], raw_event=event)
                )
                self.messages_in_process[run_id] = None
                return

            if is_message_end_event:
                yield self._dispatch_event(
                    TextMessageEndEvent(type=EventType.TEXT_MESSAGE_END, message_id=current_stream["id"], raw_event=event)
                )
                self.messages_in_process[run_id] = None
                # The chunk that closes a text message may itself open a tool call.
                # Fall through to the tool-call-start handling instead of dropping it.
                if not (tool_call_data and tool_call_data.get("name")):
                    return
                is_tool_call_start_event = True

            if is_tool_call_start_event and should_emit_tool_calls:
                yield self._dispatch_event(
                    ToolCallStartEvent(
                        type=EventType.TOOL_CALL_START,
                        tool_call_id=tool_call_data["id"],
                        tool_call_name=tool_call_data["name"],
                        parent_message_id=chunk.id,
                        raw_event=event,
                    )
                )
                self.set_message_in_progress(
                    run_id,
                    MessageInProgress(id=chunk.id, tool_call_id=tool_call_data["id"], tool_call_name=tool_call_data["name"])
                )
                return

            if is_tool_call_args_event and should_emit_tool_calls:
                yield self._dispatch_event(
                    ToolCallArgsEvent(
                        type=EventType.TOOL_CALL_ARGS,
                        tool_call_id=current_stream["tool_call_id"],
                        delta=tool_call_data["args"],
                        raw_event=event
                    )
                )
                return

            if is_message_content_event and should_emit_messages:
                if not has_current_stream:
                    yield self._dispatch_event(
                        TextMessageStartEvent(
                            type=EventType.TEXT_MESSAGE_START,
                            role="assistant",
                            message_id=chunk.id,
                            raw_event=event,
                        )
                    )
                    self.set_message_in_progress(
                        run_id,
                        MessageInProgress(
                            id=chunk.id,
                            tool_call_id=None,
                            tool_call_name=None
                        )
                    )
                    current_stream = self.get_message_in_progress(run_id)

                yield self._dispatch_event(
                    TextMessageContentEvent(
                        type=EventType.TEXT_MESSAGE_CONTENT,
                        message_id=current_stream["id"],
                        # Use the resolved content: raw chunk.content is what
                        # resolve_message_content() normalizes (presumably list-of-blocks content).
                        delta=message_content,
                        raw_event=event,
                    )
                )
                return

        elif event_type == LangGraphEventTypes.OnChatModelEnd:
            current_stream = self.get_message_in_progress(self.active_run["id"])
            if current_stream and current_stream.get("tool_call_id"):
                resolved = self._dispatch_event(
                    ToolCallEndEvent(type=EventType.TOOL_CALL_END, tool_call_id=current_stream["tool_call_id"], raw_event=event)
                )
                if resolved:
                    self.messages_in_process[self.active_run["id"]] = None
                yield resolved
            elif current_stream and current_stream.get("id"):
                resolved = self._dispatch_event(
                    TextMessageEndEvent(type=EventType.TEXT_MESSAGE_END, message_id=current_stream["id"], raw_event=event)
                )
                if resolved:
                    self.messages_in_process[self.active_run["id"]] = None
                yield resolved

        elif event_type == LangGraphEventTypes.OnCustomEvent:
            if event["name"] == CustomEventNames.ManuallyEmitMessage:
                yield self._dispatch_event(
                    TextMessageStartEvent(type=EventType.TEXT_MESSAGE_START, role="assistant", message_id=event["data"]["message_id"], raw_event=event)
                )
                yield self._dispatch_event(
                    TextMessageContentEvent(
                        type=EventType.TEXT_MESSAGE_CONTENT,
                        message_id=event["data"]["message_id"],
                        delta=event["data"]["message"],
                        raw_event=event,
                    )
                )
                yield self._dispatch_event(
                    TextMessageEndEvent(type=EventType.TEXT_MESSAGE_END, message_id=event["data"]["message_id"], raw_event=event)
                )

            elif event["name"] == CustomEventNames.ManuallyEmitToolCall:
                yield self._dispatch_event(
                    ToolCallStartEvent(
                        type=EventType.TOOL_CALL_START,
                        tool_call_id=event["data"]["id"],
                        tool_call_name=event["data"]["name"],
                        parent_message_id=event["data"]["id"],
                        raw_event=event,
                    )
                )
                yield self._dispatch_event(
                    ToolCallArgsEvent(type=EventType.TOOL_CALL_ARGS, tool_call_id=event["data"]["id"], delta=event["data"]["args"], raw_event=event)
                )
                yield self._dispatch_event(
                    ToolCallEndEvent(type=EventType.TOOL_CALL_END, tool_call_id=event["data"]["id"], raw_event=event)
                )

            elif event["name"] == CustomEventNames.ManuallyEmitState:
                self.active_run["manually_emitted_state"] = event["data"]
                # NOTE(review): this snapshots the passed-in `state`, not event["data"];
                # the manual state is picked up by the main loop on the next event.
                # Confirm intended.
                yield self._dispatch_event(
                    StateSnapshotEvent(type=EventType.STATE_SNAPSHOT, snapshot=self.get_state_snapshot(state), raw_event=event)
                )

            yield self._dispatch_event(
                CustomEvent(type=EventType.CUSTOM, name=event["name"], value=event["data"], raw_event=event)
            )

    def handle_thinking_event(self, reasoning_data: LangGraphReasoning) -> Generator[ProcessedEvents, None, None]:
        """Yield THINKING_* events for one reasoning chunk.

        The events are produced as follows:
        - If the active thinking block has a different ``index``, it is closed
          first.
        - If no block is active, one is opened (THINKING_START).
        - If the reasoning ``type`` changed, a THINKING_TEXT_MESSAGE_START is
          emitted.
        - The chunk's ``text`` is emitted as THINKING_TEXT_MESSAGE_CONTENT.

        This is a generator; callers must iterate it for events to be emitted.
        """
        if not reasoning_data or "type" not in reasoning_data or "text" not in reasoning_data:
            return

        thinking_step_index = reasoning_data.get("index")
        thinking_process = self.active_run.get("thinking_process")

        # `is not None` so that a block with index 0 is also closed when the index changes.
        if (
            thinking_process
            and thinking_process.get("index") is not None
            and thinking_process["index"] != thinking_step_index
        ):
            if thinking_process.get("type"):
                yield self._dispatch_event(
                    ThinkingTextMessageEndEvent(
                        type=EventType.THINKING_TEXT_MESSAGE_END,
                    )
                )
            yield self._dispatch_event(
                ThinkingEndEvent(
                    type=EventType.THINKING_END,
                )
            )
            self.active_run["thinking_process"] = None

        if not self.active_run.get("thinking_process"):
            yield self._dispatch_event(
                ThinkingStartEvent(
                    type=EventType.THINKING_START,
                )
            )
            self.active_run["thinking_process"] = {
                "index": thinking_step_index
            }

        thinking_process = self.active_run["thinking_process"]
        if thinking_process.get("type") != reasoning_data["type"]:
            yield self._dispatch_event(
                ThinkingTextMessageStartEvent(
                    type=EventType.THINKING_TEXT_MESSAGE_START,
                )
            )
            thinking_process["type"] = reasoning_data["type"]

        if thinking_process.get("type"):
            yield self._dispatch_event(
                ThinkingTextMessageContentEvent(
                    type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
                    delta=reasoning_data["text"]
                )
            )

    async def get_checkpoint_before_message(self, message_id: str, thread_id: str):
        """Return the checkpoint immediately before the first one containing ``message_id``.

        History is scanned oldest-first. If the message already appears in the
        oldest checkpoint, a copy of that checkpoint with ``messages`` emptied is
        returned.

        Raises:
            ValueError: if ``thread_id`` is empty or the message is not found.
        """
        if not thread_id:
            raise ValueError("Missing thread_id in config")

        history_list = [
            snapshot
            async for snapshot in self.graph.aget_state_history({"configurable": {"thread_id": thread_id}})
        ]

        history_list.reverse()
        for idx, snapshot in enumerate(history_list):
            messages = snapshot.values.get("messages", [])
            if any(getattr(m, "id", None) == message_id for m in messages):
                if idx == 0:
                    # No snapshot before this one: return a synthetic "empty before"
                    # copy instead of mutating the history snapshot in place.
                    return snapshot._replace(values={**snapshot.values, "messages": []})
                return history_list[idx - 1]  # the snapshot *before* the one that includes the message

        raise ValueError("Message ID not found in history")
|
|
682
|
+
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from fastapi import FastAPI, HTTPException, Request
|
|
2
|
+
from fastapi.responses import StreamingResponse
|
|
3
|
+
|
|
4
|
+
from ag_ui.core.types import RunAgentInput
|
|
5
|
+
from ag_ui.encoder import EventEncoder
|
|
6
|
+
|
|
7
|
+
from .agent import LangGraphAgent
|
|
8
|
+
|
|
9
|
+
def add_langgraph_fastapi_endpoint(app: FastAPI, agent: LangGraphAgent, path: str = "/"):
    """Register a POST route on *app* that runs *agent* and streams its events.

    Args:
        app: FastAPI application to register the route on.
        agent: Agent whose ``run(input_data)`` events are encoded and streamed.
        path: URL path of the route; defaults to ``"/"``.
    """

    async def langgraph_agent_endpoint(input_data: RunAgentInput, request: Request):
        # The encoder is configured from the request's Accept header and also
        # supplies the response media type.
        encoder = EventEncoder(accept=request.headers.get("accept"))

        async def stream_encoded_events():
            async for agui_event in agent.run(input_data):
                yield encoder.encode(agui_event)

        return StreamingResponse(
            stream_encoded_events(),
            media_type=encoder.get_content_type(),
        )

    # Same as decorating the handler with ``@app.post(path)``.
    app.post(path)(langgraph_agent_endpoint)
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
from typing import TypedDict, Optional, List, Any, Dict, Union, Literal
|
|
2
|
+
from typing_extensions import NotRequired
|
|
3
|
+
from enum import Enum
|
|
4
|
+
|
|
5
|
+
class LangGraphEventTypes(str, Enum):
    """String identifiers for streamed LangGraph event types (the ``on_*`` names).

    Members mix in ``str``, so each compares equal to its raw event name,
    e.g. ``LangGraphEventTypes.OnToolEnd == "on_tool_end"``.
    """

    # Chain events
    OnChainStart = "on_chain_start"
    OnChainStream = "on_chain_stream"
    OnChainEnd = "on_chain_end"
    # Chat model events
    OnChatModelStart = "on_chat_model_start"
    OnChatModelStream = "on_chat_model_stream"
    OnChatModelEnd = "on_chat_model_end"
    # Tool events
    OnToolStart = "on_tool_start"
    OnToolEnd = "on_tool_end"
    # Custom events and interrupts
    OnCustomEvent = "on_custom_event"
    OnInterrupt = "on_interrupt"
|
|
16
|
+
|
|
17
|
+
class CustomEventNames(str, Enum):
    """Names of the custom events handled by name (``manually_emit_*`` and ``exit``).

    Members mix in ``str``, so each compares equal to its raw event name.
    """

    # Manual emission of a message, a tool call, or state
    ManuallyEmitMessage = "manually_emit_message"
    ManuallyEmitToolCall = "manually_emit_tool_call"
    ManuallyEmitState = "manually_emit_state"
    # Exit signal
    Exit = "exit"
|
|
22
|
+
|
|
23
|
+
State = Dict[str, Any]
|
|
24
|
+
|
|
25
|
+
SchemaKeys = TypedDict("SchemaKeys", {
|
|
26
|
+
"input": NotRequired[Optional[List[str]]],
|
|
27
|
+
"output": NotRequired[Optional[List[str]]],
|
|
28
|
+
"config": NotRequired[Optional[List[str]]]
|
|
29
|
+
})
|
|
30
|
+
|
|
31
|
+
ThinkingProcess = TypedDict("ThinkingProcess", {
|
|
32
|
+
"index": int,
|
|
33
|
+
"type": NotRequired[Optional[Literal['text']]],
|
|
34
|
+
})
|
|
35
|
+
|
|
36
|
+
MessageInProgress = TypedDict("MessageInProgress", {
|
|
37
|
+
"id": str,
|
|
38
|
+
"tool_call_id": NotRequired[Optional[str]],
|
|
39
|
+
"tool_call_name": NotRequired[Optional[str]]
|
|
40
|
+
})
|
|
41
|
+
|
|
42
|
+
RunMetadata = TypedDict("RunMetadata", {
|
|
43
|
+
"id": str,
|
|
44
|
+
"schema_keys": NotRequired[Optional[SchemaKeys]],
|
|
45
|
+
"node_name": NotRequired[Optional[str]],
|
|
46
|
+
"prev_node_name": NotRequired[Optional[str]],
|
|
47
|
+
"exiting_node": NotRequired[bool],
|
|
48
|
+
"manually_emitted_state": NotRequired[Optional[State]],
|
|
49
|
+
"thread_id": NotRequired[Optional[ThinkingProcess]],
|
|
50
|
+
"thinking_process": NotRequired[Optional[str]]
|
|
51
|
+
})
|
|
52
|
+
|
|
53
|
+
MessagesInProgressRecord = Dict[str, Optional[MessageInProgress]]
|
|
54
|
+
|
|
55
|
+
ToolCall = TypedDict("ToolCall", {
|
|
56
|
+
"id": str,
|
|
57
|
+
"name": str,
|
|
58
|
+
"args": Dict[str, Any]
|
|
59
|
+
})
|
|
60
|
+
|
|
61
|
+
class BaseLangGraphPlatformMessage(TypedDict):
|
|
62
|
+
content: str
|
|
63
|
+
role: str
|
|
64
|
+
additional_kwargs: NotRequired[Dict[str, Any]]
|
|
65
|
+
type: str
|
|
66
|
+
id: str
|
|
67
|
+
|
|
68
|
+
class LangGraphPlatformResultMessage(BaseLangGraphPlatformMessage):
|
|
69
|
+
tool_call_id: str
|
|
70
|
+
name: str
|
|
71
|
+
|
|
72
|
+
class LangGraphPlatformActionExecutionMessage(BaseLangGraphPlatformMessage):
|
|
73
|
+
tool_calls: List[ToolCall]
|
|
74
|
+
|
|
75
|
+
LangGraphPlatformMessage = Union[
|
|
76
|
+
LangGraphPlatformActionExecutionMessage,
|
|
77
|
+
LangGraphPlatformResultMessage,
|
|
78
|
+
BaseLangGraphPlatformMessage,
|
|
79
|
+
]
|
|
80
|
+
|
|
81
|
+
class PredictStateTool(TypedDict):
    """Associates a tool's argument (``tool_argument``) with a state key (``state_key``).

    Presumably used to predict state from streamed tool-call arguments; confirm
    against the agent.
    """

    tool: str
    state_key: str
    tool_argument: str


class LangGraphReasoning(TypedDict):
    """A normalized reasoning fragment: content type, text, and step index."""

    type: str
    text: str
    index: int
|
|
@@ -0,0 +1,179 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import re
|
|
3
|
+
from typing import List, Any, Dict, Union
|
|
4
|
+
|
|
5
|
+
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage, ToolMessage
|
|
6
|
+
from ag_ui.core import (
|
|
7
|
+
Message as AGUIMessage,
|
|
8
|
+
UserMessage as AGUIUserMessage,
|
|
9
|
+
AssistantMessage as AGUIAssistantMessage,
|
|
10
|
+
SystemMessage as AGUISystemMessage,
|
|
11
|
+
ToolMessage as AGUIToolMessage,
|
|
12
|
+
ToolCall as AGUIToolCall,
|
|
13
|
+
FunctionCall as AGUIFunctionCall,
|
|
14
|
+
)
|
|
15
|
+
from .types import State, SchemaKeys, LangGraphReasoning
|
|
16
|
+
|
|
17
|
+
# State keys that are always kept in addition to a schema's declared input keys.
DEFAULT_SCHEMA_KEYS = ["tools"]

def filter_object_by_schema_keys(obj: Dict[str, Any], schema_keys: List[str]) -> Dict[str, Any]:
    """Return a new dict with only the entries of *obj* whose key is in *schema_keys*.

    A falsy *obj* (``None`` or empty) yields ``{}``. Key order follows *obj*.
    """
    kept: Dict[str, Any] = {}
    for key, value in (obj or {}).items():
        if key in schema_keys:
            kept[key] = value
    return kept
|
|
23
|
+
|
|
24
|
+
def get_stream_payload_input(
    *,
    mode: str,
    state: State,
    schema_keys: SchemaKeys,
) -> Union[State, None]:
    """Build the input payload for a stream call.

    Only ``mode == "start"`` sends input; any other mode yields ``None``. When
    *schema_keys* declares non-empty ``input`` keys, the state is narrowed to
    those keys plus ``DEFAULT_SCHEMA_KEYS``. Otherwise the state, including an
    empty one, is returned unchanged.
    """
    if mode != "start":
        return None
    if not state:
        return state
    input_keys = schema_keys.get("input") if schema_keys else None
    if not input_keys:
        return state
    return filter_object_by_schema_keys(state, [*DEFAULT_SCHEMA_KEYS, *input_keys])
|
|
34
|
+
|
|
35
|
+
def stringify_if_needed(item: Any) -> str:
    """Coerce *item* to text.

    ``None`` becomes ``''``, strings pass through unchanged, and anything else
    is JSON-encoded with ``json.dumps``.
    """
    if isinstance(item, str):
        return item
    return '' if item is None else json.dumps(item)
|
|
41
|
+
|
|
42
|
+
def langchain_messages_to_agui(messages: List[BaseMessage]) -> List[AGUIMessage]:
    """Convert LangChain messages to AG-UI protocol messages.

    The LangChain message classes map to AG-UI message models as follows:

    * ``HumanMessage`` becomes a user message.
    * ``AIMessage`` becomes an assistant message.
    * ``SystemMessage`` becomes a system message.
    * ``ToolMessage`` becomes a tool message.

    IDs are coerced with ``str()``. Content is reduced to text via
    ``resolve_message_content`` and then ``stringify_if_needed``. AI tool calls
    become AG-UI ``function`` tool calls whose arguments are JSON-encoded
    (``{}`` when absent).

    Raises:
        TypeError: If a message is of any other type.
    """
    converted: List[AGUIMessage] = []
    for msg in messages:
        if not isinstance(msg, (HumanMessage, AIMessage, SystemMessage, ToolMessage)):
            raise TypeError(f"Unsupported message type: {type(msg)}")

        msg_id = str(msg.id)
        text = stringify_if_needed(resolve_message_content(msg.content))

        if isinstance(msg, HumanMessage):
            converted.append(AGUIUserMessage(id=msg_id, role="user", content=text, name=msg.name))
        elif isinstance(msg, AIMessage):
            agui_tool_calls = None
            if msg.tool_calls:
                agui_tool_calls = [
                    AGUIToolCall(
                        id=str(call["id"]),
                        type="function",
                        function=AGUIFunctionCall(
                            name=call["name"],
                            arguments=json.dumps(call.get("args", {})),
                        ),
                    )
                    for call in msg.tool_calls
                ]
            converted.append(AGUIAssistantMessage(
                id=msg_id,
                role="assistant",
                content=text,
                tool_calls=agui_tool_calls,
                name=msg.name,
            ))
        elif isinstance(msg, SystemMessage):
            converted.append(AGUISystemMessage(id=msg_id, role="system", content=text, name=msg.name))
        else:  # ToolMessage
            converted.append(AGUIToolMessage(
                id=msg_id,
                role="tool",
                content=text,
                tool_call_id=msg.tool_call_id,
            ))
    return converted
|
|
91
|
+
|
|
92
|
+
def agui_messages_to_langchain(messages: List[AGUIMessage]) -> List[BaseMessage]:
|
|
93
|
+
langchain_messages = []
|
|
94
|
+
for message in messages:
|
|
95
|
+
role = message.role
|
|
96
|
+
if role == "user":
|
|
97
|
+
langchain_messages.append(HumanMessage(
|
|
98
|
+
id=message.id,
|
|
99
|
+
content=message.content,
|
|
100
|
+
name=message.name,
|
|
101
|
+
))
|
|
102
|
+
elif role == "assistant":
|
|
103
|
+
tool_calls = []
|
|
104
|
+
if hasattr(message, "tool_calls") and message.tool_calls:
|
|
105
|
+
for tc in message.tool_calls:
|
|
106
|
+
tool_calls.append({
|
|
107
|
+
"id": tc.id,
|
|
108
|
+
"name": tc.function.name,
|
|
109
|
+
"args": json.loads(tc.function.arguments) if hasattr(tc, "function") and tc.function.arguments else {},
|
|
110
|
+
"type": "tool_call",
|
|
111
|
+
})
|
|
112
|
+
langchain_messages.append(AIMessage(
|
|
113
|
+
id=message.id,
|
|
114
|
+
content=message.content or "",
|
|
115
|
+
tool_calls=tool_calls,
|
|
116
|
+
name=message.name,
|
|
117
|
+
))
|
|
118
|
+
elif role == "system":
|
|
119
|
+
langchain_messages.append(SystemMessage(
|
|
120
|
+
id=message.id,
|
|
121
|
+
content=message.content,
|
|
122
|
+
name=message.name,
|
|
123
|
+
))
|
|
124
|
+
elif role == "tool":
|
|
125
|
+
langchain_messages.append(ToolMessage(
|
|
126
|
+
id=message.id,
|
|
127
|
+
content=message.content,
|
|
128
|
+
tool_call_id=message.tool_call_id,
|
|
129
|
+
))
|
|
130
|
+
else:
|
|
131
|
+
raise ValueError(f"Unsupported message role: {role}")
|
|
132
|
+
return langchain_messages
|
|
133
|
+
|
|
134
|
+
def resolve_reasoning_content(chunk: Any) -> LangGraphReasoning | None:
    """Extract a reasoning ("thinking") fragment from a streamed model chunk.

    Two provider formats are recognised:

    * Anthropic: ``chunk.content`` is a list whose first block is a dict with a
      ``"thinking"`` string (and optionally an ``"index"``).
    * OpenAI: ``chunk.additional_kwargs["reasoning"]["summary"]`` is a list
      whose first entry is a dict with a ``"text"`` string (and optionally an
      ``"index"``). These chunks may have empty ``content``, so this format is
      checked regardless of ``chunk.content``.

    Args:
        chunk: A streamed message chunk. It is expected to expose ``content``
            and, optionally, ``additional_kwargs``.

    Returns:
        A ``LangGraphReasoning`` dict with ``type="text"``, the reasoning text
        and its step index (default ``0``), or ``None`` when the chunk carries
        no reasoning.
    """
    content = getattr(chunk, "content", None)

    # Anthropic reasoning response: the first content block is a thinking block.
    # Content lists may also hold plain strings, so check the block type first.
    if isinstance(content, list) and content:
        first_block = content[0]
        if isinstance(first_block, dict) and first_block.get("thinking"):
            return LangGraphReasoning(
                text=first_block["thinking"],
                type="text",
                index=first_block.get("index", 0),
            )

    # OpenAI reasoning response: the summary travels in additional_kwargs and
    # can arrive on chunks whose content is empty, so there is no early return
    # on empty content above.
    additional_kwargs = getattr(chunk, "additional_kwargs", None) or {}
    reasoning = additional_kwargs.get("reasoning") or {}
    if not isinstance(reasoning, dict):
        return None
    summary = reasoning.get("summary") or []
    if not summary:
        return None
    data = summary[0]
    if not isinstance(data, dict) or not data.get("text"):
        return None
    return LangGraphReasoning(
        type="text",
        text=data["text"],
        index=data.get("index", 0),
    )
|
|
164
|
+
|
|
165
|
+
def resolve_message_content(content: Any) -> str | None:
|
|
166
|
+
if not content:
|
|
167
|
+
return None
|
|
168
|
+
|
|
169
|
+
if isinstance(content, str):
|
|
170
|
+
return content
|
|
171
|
+
|
|
172
|
+
if isinstance(content, list) and content:
|
|
173
|
+
content_text = next((c.get("text") for c in content if isinstance(c, dict) and c.get("type") == "text"), None)
|
|
174
|
+
return content_text
|
|
175
|
+
|
|
176
|
+
return None
|
|
177
|
+
|
|
178
|
+
def camel_to_snake(name):
    """Convert a CamelCase or camelCase identifier to snake_case.

    An underscore is inserted before every ASCII capital letter except one at
    the very start, and the whole string is then lower-cased. Consecutive
    capitals are therefore split individually:
    ``"HTTPServer"`` becomes ``"h_t_t_p_server"``.
    """
    pieces = []
    for position, char in enumerate(name):
        if position and "A" <= char <= "Z":
            pieces.append("_")
        pieces.append(char)
    return "".join(pieces).lower()
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
[tool.poetry]
name = "ag-ui-langgraph"
version = "0.0.1"
description = "LangGraph integration for the AG-UI protocol"
authors = ["Ran Shem Tov <ran@copilotkit.ai>"]
readme = "README.md"
exclude = [
    "ag_ui_langgraph/examples/**",
]

[tool.poetry.dependencies]
python = "<3.14,>=3.10"
ag-ui-protocol = "==0.1.7"
fastapi = { version = "^0.115.12", optional = true }
langchain = ">=0.3.0"
langchain-core = ">=0.3.0"
langgraph = ">=0.3.25,<=0.5.0"

[tool.poetry.extras]
fastapi = ["fastapi"]

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

# No [tool.poetry.scripts] table: the former `dev = "ag_ui_langgraph.dojo:main"`
# entry pointed at a module that is not part of this package, so the installed
# `dev` command could only fail on import. Re-add it once `ag_ui_langgraph/dojo.py`
# is shipped.
|