march-agent 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- march_agent/__init__.py +52 -0
- march_agent/agent.py +341 -0
- march_agent/agent_state_client.py +149 -0
- march_agent/app.py +416 -0
- march_agent/artifact.py +58 -0
- march_agent/checkpoint_client.py +169 -0
- march_agent/checkpointer.py +16 -0
- march_agent/cli.py +139 -0
- march_agent/conversation.py +103 -0
- march_agent/conversation_client.py +86 -0
- march_agent/conversation_message.py +48 -0
- march_agent/exceptions.py +36 -0
- march_agent/extensions/__init__.py +1 -0
- march_agent/extensions/langgraph.py +526 -0
- march_agent/extensions/pydantic_ai.py +180 -0
- march_agent/gateway_client.py +506 -0
- march_agent/gateway_pb2.py +73 -0
- march_agent/gateway_pb2_grpc.py +101 -0
- march_agent/heartbeat.py +84 -0
- march_agent/memory.py +73 -0
- march_agent/memory_client.py +155 -0
- march_agent/message.py +80 -0
- march_agent/streamer.py +220 -0
- march_agent-0.1.1.dist-info/METADATA +503 -0
- march_agent-0.1.1.dist-info/RECORD +29 -0
- march_agent-0.1.1.dist-info/WHEEL +5 -0
- march_agent-0.1.1.dist-info/entry_points.txt +2 -0
- march_agent-0.1.1.dist-info/licenses/LICENSE +21 -0
- march_agent-0.1.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,526 @@
|
|
|
1
|
+
"""LangGraph extension for march_agent.
|
|
2
|
+
|
|
3
|
+
This module provides LangGraph-compatible components that integrate with
|
|
4
|
+
the march_agent framework.
|
|
5
|
+
|
|
6
|
+
Usage:
|
|
7
|
+
from march_agent import MarchAgentApp
from march_agent.extensions.langgraph import HTTPCheckpointSaver
from langgraph.graph import StateGraph
|
|
8
|
+
|
|
9
|
+
app = MarchAgentApp(gateway_url="agent-gateway:8080", api_key="key")
|
|
10
|
+
checkpointer = HTTPCheckpointSaver(app=app)
|
|
11
|
+
|
|
12
|
+
graph = StateGraph(...)
|
|
13
|
+
compiled = graph.compile(checkpointer=checkpointer)
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import asyncio
|
|
19
|
+
import base64
|
|
20
|
+
import logging
|
|
21
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
22
|
+
from datetime import datetime, timezone
|
|
23
|
+
from typing import (
|
|
24
|
+
TYPE_CHECKING,
|
|
25
|
+
Any,
|
|
26
|
+
AsyncIterator,
|
|
27
|
+
Dict,
|
|
28
|
+
Iterator,
|
|
29
|
+
Optional,
|
|
30
|
+
Sequence,
|
|
31
|
+
Set,
|
|
32
|
+
Tuple,
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
if TYPE_CHECKING:
|
|
36
|
+
from ..app import MarchAgentApp
|
|
37
|
+
|
|
38
|
+
logger = logging.getLogger(__name__)
|
|
39
|
+
|
|
40
|
+
# Try to import LangGraph types, but make them optional
|
|
41
|
+
try:
|
|
42
|
+
from langgraph.checkpoint.base import (
|
|
43
|
+
BaseCheckpointSaver,
|
|
44
|
+
ChannelVersions,
|
|
45
|
+
Checkpoint,
|
|
46
|
+
CheckpointMetadata,
|
|
47
|
+
CheckpointTuple,
|
|
48
|
+
)
|
|
49
|
+
from langchain_core.runnables import RunnableConfig
|
|
50
|
+
|
|
51
|
+
LANGGRAPH_AVAILABLE = True
|
|
52
|
+
except ImportError:
|
|
53
|
+
LANGGRAPH_AVAILABLE = False
|
|
54
|
+
# Define stub types for when langgraph is not installed
|
|
55
|
+
BaseCheckpointSaver = object
|
|
56
|
+
RunnableConfig = Dict[str, Any]
|
|
57
|
+
Checkpoint = Dict[str, Any]
|
|
58
|
+
CheckpointMetadata = Dict[str, Any]
|
|
59
|
+
CheckpointTuple = Tuple[Any, ...]
|
|
60
|
+
ChannelVersions = Dict[str, Any]
|
|
61
|
+
|
|
62
|
+
from ..checkpoint_client import CheckpointClient
|
|
63
|
+
from ..exceptions import APIException
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _generate_checkpoint_id() -> str:
|
|
67
|
+
"""Generate a unique checkpoint ID based on timestamp."""
|
|
68
|
+
return datetime.now(timezone.utc).isoformat()
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class HTTPCheckpointSaver(BaseCheckpointSaver if LANGGRAPH_AVAILABLE else object):
    """HTTP-based checkpoint saver for LangGraph.

    This checkpointer stores graph state via HTTP calls to the conversation-store
    checkpoint API, enabling distributed checkpoint storage without direct
    database access.

    The ``a*`` coroutine methods hold the implementation. Each synchronous
    method runs the matching coroutine with ``asyncio.run`` on a private
    single-thread executor and blocks until it completes.

    Note:
        ``put_writes`` / ``aput_writes`` are stubs: intermediate writes are
        logged at debug level and are not persisted.

    Example:
        ```python
        from march_agent import MarchAgentApp
        from march_agent.extensions.langgraph import HTTPCheckpointSaver
        from langgraph.graph import StateGraph

        app = MarchAgentApp(gateway_url="agent-gateway:8080", api_key="key")
        checkpointer = HTTPCheckpointSaver(app=app)

        graph = StateGraph(MyState)
        # ... define graph ...
        compiled = graph.compile(checkpointer=checkpointer)

        config = {"configurable": {"thread_id": "my-thread"}}
        result = compiled.invoke({"messages": [...]}, config)
        ```
    """

    def __init__(
        self,
        app: "MarchAgentApp",
        *,
        serde: Optional[Any] = None,
    ):
        """Initialize HTTP checkpoint saver.

        Args:
            app: MarchAgentApp instance; its
                ``gateway_client.conversation_store_url`` is used as the base
                URL of the checkpoint API.
            serde: Optional serializer/deserializer, forwarded to
                ``BaseCheckpointSaver`` when LangGraph is installed and
                ignored otherwise.
        """
        if LANGGRAPH_AVAILABLE:
            super().__init__(serde=serde)

        base_url = app.gateway_client.conversation_store_url
        self.client = CheckpointClient(base_url)
        # Only used by _get_loop.
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        # One worker thread hosts the event loop of every synchronous call.
        self._executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="checkpoint")

    def _get_loop(self) -> asyncio.AbstractEventLoop:
        """Return the running event loop, or a lazily created private one.

        Not used by the synchronous wrappers in this class, which run their
        coroutines through ``_run_sync``.
        """
        try:
            return asyncio.get_running_loop()
        except RuntimeError:
            if self._loop is None or self._loop.is_closed():
                self._loop = asyncio.new_event_loop()
            return self._loop

    async def close(self):
        """Close the HTTP client and shut down the sync-call executor.

        The executor is shut down even if closing the client raises.
        ``shutdown(wait=True)`` blocks until an in-flight synchronous call
        has finished.
        """
        try:
            await self.client.close()
        finally:
            self._executor.shutdown(wait=True)

    def _run_sync(self, coro_factory: Any, *args: Any, **kwargs: Any) -> Any:
        """Run ``coro_factory(*args, **kwargs)`` to completion on the worker thread.

        The coroutine object is created inside the worker thread, so nothing
        is left un-awaited if submission fails (e.g. after ``close()`` shut
        the executor down). Each call gets a fresh loop via ``asyncio.run``.
        """
        return self._executor.submit(
            lambda: asyncio.run(coro_factory(*args, **kwargs))
        ).result()

    # ==================== Config Helpers ====================

    @staticmethod
    def _get_thread_id(config: RunnableConfig) -> str:
        """Extract thread_id from config.

        Raises:
            ValueError: If ``configurable.thread_id`` is missing or empty.
        """
        configurable = config.get("configurable", {})
        thread_id = configurable.get("thread_id")
        if not thread_id:
            raise ValueError("Config must contain configurable.thread_id")
        return thread_id

    @staticmethod
    def _get_checkpoint_ns(config: RunnableConfig) -> str:
        """Extract checkpoint_ns from config (defaults to empty string)."""
        return config.get("configurable", {}).get("checkpoint_ns", "")

    @staticmethod
    def _get_checkpoint_id(config: RunnableConfig) -> Optional[str]:
        """Extract checkpoint_id from config (``None`` when absent)."""
        return config.get("configurable", {}).get("checkpoint_id")

    # ==================== Async Methods (Primary Implementation) ====================

    async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
        """Fetch a checkpoint tuple asynchronously.

        Returns:
            The checkpoint tuple, or ``None`` when the API returns nothing or
            the call fails with ``APIException`` (the failure is logged).

        Raises:
            ValueError: If ``config`` lacks ``configurable.thread_id``.
        """
        thread_id = self._get_thread_id(config)
        checkpoint_ns = self._get_checkpoint_ns(config)
        checkpoint_id = self._get_checkpoint_id(config)

        try:
            result = await self.client.get_tuple(
                thread_id=thread_id,
                checkpoint_ns=checkpoint_ns,
                checkpoint_id=checkpoint_id,
            )
        except APIException as e:
            logger.error("Failed to get checkpoint: %s", e)
            return None

        if not result:
            return None

        return self._response_to_tuple(result)

    async def alist(
        self,
        config: Optional[RunnableConfig],
        *,
        filter: Optional[Dict[str, Any]] = None,
        before: Optional[RunnableConfig] = None,
        limit: Optional[int] = None,
    ) -> AsyncIterator[CheckpointTuple]:
        """List checkpoints asynchronously.

        Args:
            config: Optional config whose ``configurable.thread_id`` and
                ``configurable.checkpoint_ns`` narrow the listing (``None`` is
                sent for whichever is absent).
            filter: Optional metadata filter. The API call takes no filter, so
                it is applied here: a checkpoint is yielded only if every key
                equals the corresponding (deserialized) metadata value.
            before: Optional config whose ``configurable.checkpoint_id`` is
                sent as the ``before`` cursor.
            limit: Maximum number of checkpoints to yield. Without ``filter``
                it is passed to the API; with ``filter`` it is applied after
                filtering so non-matching entries do not use up the limit.

        Yields:
            Checkpoint tuples built from the API response. If the API call
            fails with ``APIException``, the error is logged and nothing is
            yielded.
        """
        thread_id = None
        checkpoint_ns = None
        if config:
            configurable = config.get("configurable", {})
            thread_id = configurable.get("thread_id")
            checkpoint_ns = configurable.get("checkpoint_ns")

        before_id = before.get("configurable", {}).get("checkpoint_id") if before else None

        try:
            results = await self.client.list(
                thread_id=thread_id,
                checkpoint_ns=checkpoint_ns,
                before=before_id,
                # With a filter the limit must count matches, not raw rows.
                limit=None if filter else limit,
            )
        except APIException as e:
            logger.error("Failed to list checkpoints: %s", e)
            return

        remaining = limit if filter else None
        for result in results or []:
            if remaining is not None and remaining <= 0:
                return
            if filter and not self._metadata_matches(result, filter):
                continue
            tuple_result = self._response_to_tuple(result)
            if tuple_result is None:
                continue
            yield tuple_result
            if remaining is not None:
                remaining -= 1

    async def aput(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> RunnableConfig:
        """Store a checkpoint asynchronously.

        As in LangGraph's own savers, ``config["configurable"]["checkpoint_id"]``
        (when present) names the *parent* checkpoint. The id of the checkpoint
        being stored is ``checkpoint["id"]``. The parent id is forwarded
        unchanged in the request config. It is never replaced by the new
        checkpoint's own id, which would make a first checkpoint its own parent.

        Args:
            config: Config carrying ``thread_id``, optional ``checkpoint_ns``
                and optional parent ``checkpoint_id``.
            checkpoint: The checkpoint to store.
            metadata: Checkpoint metadata.
            new_versions: Channel versions written by this step.

        Returns:
            The ``config`` from the API response if it carries one. Otherwise,
            a config whose ``checkpoint_id`` is the newly stored checkpoint's id.

        Raises:
            ValueError: If ``config`` lacks ``configurable.thread_id``.
            APIException: If the API call fails (logged, then re-raised).
        """
        thread_id = self._get_thread_id(config)
        checkpoint_ns = self._get_checkpoint_ns(config)
        parent_checkpoint_id = self._get_checkpoint_id(config)

        checkpoint_data = self._checkpoint_to_api(checkpoint)
        # _checkpoint_to_api generates an id when the checkpoint has none;
        # reuse that single value so payload and returned config agree.
        checkpoint_id = checkpoint_data["id"]
        metadata_data = self._metadata_to_api(metadata)

        request_configurable: Dict[str, Any] = {
            "thread_id": thread_id,
            "checkpoint_ns": checkpoint_ns,
        }
        if parent_checkpoint_id:
            request_configurable["checkpoint_id"] = parent_checkpoint_id

        try:
            result = await self.client.put(
                config={"configurable": request_configurable},
                checkpoint=checkpoint_data,
                metadata=metadata_data,
                new_versions=dict(new_versions) if new_versions else {},
            )
        except APIException as e:
            logger.error("Failed to store checkpoint: %s", e)
            raise

        returned_config = result.get("config") if result else None
        if returned_config:
            return returned_config
        return {
            "configurable": {
                "thread_id": thread_id,
                "checkpoint_ns": checkpoint_ns,
                "checkpoint_id": checkpoint_id,
            }
        }

    async def aput_writes(
        self,
        config: RunnableConfig,
        writes: Sequence[Tuple[str, Any]],
        task_id: str,
        task_path: str = "",
    ) -> None:
        """Store intermediate writes asynchronously (stub).

        Writes are only logged at debug level; nothing is sent to the API.
        """
        logger.debug(
            "aput_writes called (not persisted): task_id=%s, writes_count=%d",
            task_id,
            len(writes),
        )

    async def adelete_thread(self, thread_id: str) -> None:
        """Delete all checkpoints for a thread asynchronously.

        Raises:
            APIException: If the API call fails (logged, then re-raised).
        """
        try:
            await self.client.delete_thread(thread_id)
        except APIException as e:
            logger.error("Failed to delete thread checkpoints: %s", e)
            raise

    # ==================== Sync Methods (Wrappers) ====================

    def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
        """Fetch a checkpoint tuple synchronously (blocks on the worker thread)."""
        return self._run_sync(self.aget_tuple, config)

    def list(
        self,
        config: Optional[RunnableConfig],
        *,
        filter: Optional[Dict[str, Any]] = None,
        before: Optional[RunnableConfig] = None,
        limit: Optional[int] = None,
    ) -> Iterator[CheckpointTuple]:
        """List checkpoints synchronously.

        All results are collected on the worker thread first, then yielded.
        """

        async def collect():
            return [
                item
                async for item in self.alist(
                    config, filter=filter, before=before, limit=limit
                )
            ]

        yield from self._run_sync(collect)

    def put(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> RunnableConfig:
        """Store a checkpoint synchronously (see ``aput`` for semantics)."""
        return self._run_sync(self.aput, config, checkpoint, metadata, new_versions)

    def put_writes(
        self,
        config: RunnableConfig,
        writes: Sequence[Tuple[str, Any]],
        task_id: str,
        task_path: str = "",
    ) -> None:
        """Store intermediate writes synchronously (stub; see ``aput_writes``)."""
        self._run_sync(self.aput_writes, config, writes, task_id, task_path)

    def delete_thread(self, thread_id: str) -> None:
        """Delete all checkpoints for a thread synchronously."""
        self._run_sync(self.adelete_thread, thread_id)

    # ==================== Data Conversion Helpers ====================

    def _serialize_value(
        self,
        value: Any,
        _visited: Optional[Set[int]] = None,
        _depth: int = 0
    ) -> Any:
        """Serialize a value for JSON transmission with cycle detection.

        Encodings (reversed by ``_deserialize_value``):

        - bytes become ``{"__bytes__": <base64>}``;
        - tuples become ``{"__tuple__": [...]}``;
        - dicts and lists are serialized element-wise;
        - None, str, int and bool (exact types) are returned unchanged;
        - other values go through ``self.serde.dumps_typed`` when available,
          giving ``{"__serde_type__": ..., "__serde_value__": <base64>}``;
        - failing that, ``model_dump`` / ``dict`` / ``to_dict`` are tried,
          and the value is returned as-is as a last resort.

        Args:
            value: Value to serialize
            _visited: Set of visited object IDs (for cycle detection)
            _depth: Current recursion depth (for depth limit)

        Returns:
            Serialized value safe for JSON
        """
        if _visited is None:
            _visited = set()

        # Depth protection (prevent stack overflow)
        MAX_DEPTH = 100
        if _depth > MAX_DEPTH:
            logger.warning(
                "Serialization depth limit reached (%d). Returning placeholder.",
                MAX_DEPTH,
            )
            return {"__max_depth_exceeded__": True}

        # Handle bytes (before cycle check, as bytes are immutable)
        if isinstance(value, bytes):
            return {"__bytes__": base64.b64encode(value).decode("ascii")}

        # JSON-native scalars go out unchanged. Without this, a configured
        # serde (BaseCheckpointSaver presumably installs a default when
        # serde=None) would wrap every string/int/bool/None in a base64
        # envelope. The exact-type check keeps subclasses such as Enum
        # members on the serde path, which preserves their type. Floats also
        # stay on it because NaN/inf are not valid JSON.
        if value is None or type(value) in (str, int, bool):
            return value

        # Cycle detection for container types
        if isinstance(value, (dict, list, tuple)):
            obj_id = id(value)
            if obj_id in _visited:
                logger.warning(
                    "Circular reference detected during serialization. "
                    "Returning placeholder."
                )
                return {"__circular_ref__": True}

            _visited.add(obj_id)
            try:
                if isinstance(value, dict):
                    return {
                        k: self._serialize_value(v, _visited, _depth + 1)
                        for k, v in value.items()
                    }
                if isinstance(value, list):
                    return [
                        self._serialize_value(item, _visited, _depth + 1)
                        for item in value
                    ]
                return {
                    "__tuple__": [
                        self._serialize_value(item, _visited, _depth + 1)
                        for item in value
                    ]
                }
            finally:
                # Un-mark after processing so the same object may appear in
                # sibling branches (DAG structure) without being a cycle.
                _visited.discard(obj_id)

        # Handle custom serialization with serde
        if LANGGRAPH_AVAILABLE and hasattr(self, "serde") and self.serde is not None:
            try:
                type_str, serialized = self.serde.dumps_typed(value)
                if isinstance(serialized, bytes):
                    serialized = base64.b64encode(serialized).decode("ascii")
                return {"__serde_type__": type_str, "__serde_value__": serialized}
            except Exception as e:
                logger.warning("Failed to serialize value with serde: %s", e)

        # Handle objects with serialization methods
        if hasattr(value, "model_dump"):
            return self._serialize_value(value.model_dump(), _visited, _depth + 1)
        if hasattr(value, "dict"):
            return self._serialize_value(value.dict(), _visited, _depth + 1)
        if hasattr(value, "to_dict"):
            return self._serialize_value(value.to_dict(), _visited, _depth + 1)

        # Anything else is returned as-is.
        return value

    def _serialize_channel_values(self, channel_values: Dict[str, Any]) -> Dict[str, Any]:
        """Serialize all channel values for API transmission."""
        return self._serialize_value(channel_values)

    def _deserialize_value(self, value: Any) -> Any:
        """Reverse ``_serialize_value``: decode bytes, tuples and serde envelopes.

        A serde envelope that fails to decode (or arrives when no serde is
        available) is returned unchanged as a dict.
        """
        if isinstance(value, dict):
            if "__bytes__" in value:
                return base64.b64decode(value["__bytes__"])
            if "__tuple__" in value:
                return tuple(self._deserialize_value(item) for item in value["__tuple__"])
            if "__serde_type__" in value and "__serde_value__" in value:
                if LANGGRAPH_AVAILABLE and hasattr(self, "serde") and self.serde is not None:
                    try:
                        serialized = value["__serde_value__"]
                        if isinstance(serialized, str):
                            serialized = base64.b64decode(serialized)
                        return self.serde.loads_typed((value["__serde_type__"], serialized))
                    except Exception as e:
                        logger.warning("Failed to deserialize value with serde: %s", e)
                return value
            return {k: self._deserialize_value(v) for k, v in value.items()}

        if isinstance(value, list):
            return [self._deserialize_value(item) for item in value]

        return value

    def _deserialize_checkpoint(self, checkpoint_data: Dict[str, Any]) -> Dict[str, Any]:
        """Deserialize checkpoint data received from API (channel_values only)."""
        if not checkpoint_data:
            return checkpoint_data

        result = dict(checkpoint_data)
        if "channel_values" in result:
            result["channel_values"] = self._deserialize_value(result["channel_values"])
        return result

    def _deserialize_metadata(self, metadata_data: Optional[Dict[str, Any]]) -> Any:
        """Deserialize metadata received from API.

        Only ``writes`` is decoded, mirroring ``_metadata_to_api``, which is
        the only metadata field it serializes.
        """
        if not metadata_data:
            return metadata_data

        result = dict(metadata_data)
        if result.get("writes") is not None:
            result["writes"] = self._deserialize_value(result["writes"])
        return result

    def _metadata_matches(
        self,
        response: Optional[Dict[str, Any]],
        criteria: Dict[str, Any],
    ) -> bool:
        """Return True if every ``criteria`` item equals the response's metadata value."""
        metadata = self._deserialize_metadata((response or {}).get("metadata")) or {}
        return all(metadata.get(key) == expected for key, expected in criteria.items())

    def _checkpoint_to_api(self, checkpoint: Checkpoint) -> Dict[str, Any]:
        """Convert LangGraph Checkpoint (dict or attribute object) to API format.

        A missing or empty ``id``/``ts`` is replaced with a generated
        timestamp. The generators run only when needed.
        """
        if isinstance(checkpoint, dict):
            channel_values = checkpoint.get("channel_values", {})
            return {
                "v": checkpoint.get("v", 1),
                "id": checkpoint.get("id") or _generate_checkpoint_id(),
                "ts": checkpoint.get("ts") or datetime.now(timezone.utc).isoformat(),
                "channel_values": self._serialize_channel_values(channel_values),
                "channel_versions": checkpoint.get("channel_versions", {}),
                "versions_seen": checkpoint.get("versions_seen", {}),
                "pending_sends": checkpoint.get("pending_sends", []),
            }
        channel_values = dict(getattr(checkpoint, "channel_values", {}))
        return {
            "v": getattr(checkpoint, "v", 1),
            "id": getattr(checkpoint, "id", None) or _generate_checkpoint_id(),
            "ts": getattr(checkpoint, "ts", None) or datetime.now(timezone.utc).isoformat(),
            "channel_values": self._serialize_channel_values(channel_values),
            "channel_versions": dict(getattr(checkpoint, "channel_versions", {})),
            "versions_seen": dict(getattr(checkpoint, "versions_seen", {})),
            "pending_sends": list(getattr(checkpoint, "pending_sends", [])),
        }

    def _serialize_writes(self, writes: Any) -> Any:
        """Serialize writes field which may contain LangChain objects."""
        if writes is None:
            return None
        return self._serialize_value(writes)

    def _metadata_to_api(self, metadata: CheckpointMetadata) -> Dict[str, Any]:
        """Convert LangGraph CheckpointMetadata to API format.

        Only ``source``, ``step``, ``writes`` and ``parents`` are sent; any
        other metadata keys are dropped.
        """
        if isinstance(metadata, dict):
            return {
                "source": metadata.get("source", "input"),
                "step": metadata.get("step", -1),
                "writes": self._serialize_writes(metadata.get("writes")),
                "parents": metadata.get("parents", {}),
            }
        return {
            "source": getattr(metadata, "source", "input"),
            "step": getattr(metadata, "step", -1),
            "writes": self._serialize_writes(getattr(metadata, "writes", None)),
            "parents": dict(getattr(metadata, "parents", {})),
        }

    def _response_to_tuple(self, response: Dict[str, Any]) -> Optional[CheckpointTuple]:
        """Convert API response to LangGraph CheckpointTuple.

        Without LangGraph installed, a plain 5-tuple
        ``(config, checkpoint, metadata, parent_config, pending_writes)`` is
        returned instead.
        """
        if not response:
            return None

        config = response.get("config", {})
        checkpoint_data = self._deserialize_checkpoint(response.get("checkpoint", {}))
        metadata_data = self._deserialize_metadata(response.get("metadata", {}))
        parent_config = response.get("parent_config")
        pending_writes = response.get("pending_writes")

        if not LANGGRAPH_AVAILABLE:
            return (config, checkpoint_data, metadata_data, parent_config, pending_writes)

        return CheckpointTuple(
            config=config,
            checkpoint=checkpoint_data,
            metadata=metadata_data,
            parent_config=parent_config,
            pending_writes=pending_writes or [],
        )
|
|
@@ -0,0 +1,180 @@
|
|
|
1
|
+
"""Pydantic AI extension for march_agent.
|
|
2
|
+
|
|
3
|
+
This module provides integration with Pydantic AI, enabling persistent
|
|
4
|
+
message history storage via the agent-state API.
|
|
5
|
+
|
|
6
|
+
Usage:
|
|
7
|
+
from march_agent import MarchAgentApp
|
|
8
|
+
from march_agent.extensions.pydantic_ai import PydanticAIMessageStore
|
|
9
|
+
from pydantic_ai import Agent
|
|
10
|
+
|
|
11
|
+
app = MarchAgentApp(gateway_url="agent-gateway:8080", api_key="key")
|
|
12
|
+
store = PydanticAIMessageStore(app=app)
|
|
13
|
+
|
|
14
|
+
my_agent = Agent('openai:gpt-4o', system_prompt="...")
|
|
15
|
+
|
|
16
|
+
@medical_agent.on_message
|
|
17
|
+
async def handle(message, sender):
|
|
18
|
+
# Load message history
|
|
19
|
+
history = await store.load(message.conversation_id)
|
|
20
|
+
|
|
21
|
+
# Run agent with streaming
|
|
22
|
+
async with medical_agent.streamer(message) as s:
|
|
23
|
+
async with my_agent.run_stream(message.content, message_history=history) as result:
|
|
24
|
+
async for chunk in result.stream_text():
|
|
25
|
+
s.stream(chunk)
|
|
26
|
+
|
|
27
|
+
# Save updated history
|
|
28
|
+
await store.save(message.conversation_id, result.all_messages())
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
from __future__ import annotations
|
|
32
|
+
|
|
33
|
+
import logging
|
|
34
|
+
from typing import TYPE_CHECKING, List, Any, Optional
|
|
35
|
+
|
|
36
|
+
if TYPE_CHECKING:
|
|
37
|
+
from ..app import MarchAgentApp
|
|
38
|
+
|
|
39
|
+
logger = logging.getLogger(__name__)
|
|
40
|
+
|
|
41
|
+
# Try to import Pydantic AI types, but make them optional
|
|
42
|
+
try:
|
|
43
|
+
from pydantic_ai.messages import (
|
|
44
|
+
ModelMessage,
|
|
45
|
+
ModelMessagesTypeAdapter,
|
|
46
|
+
)
|
|
47
|
+
|
|
48
|
+
PYDANTIC_AI_AVAILABLE = True
|
|
49
|
+
except ImportError:
|
|
50
|
+
PYDANTIC_AI_AVAILABLE = False
|
|
51
|
+
ModelMessage = Any
|
|
52
|
+
ModelMessagesTypeAdapter = None
|
|
53
|
+
|
|
54
|
+
from ..agent_state_client import AgentStateClient
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class PydanticAIMessageStore:
    """Persistent message store for Pydantic AI.

    Stores and retrieves Pydantic AI native message history using the
    agent-state API. Messages are serialized using Pydantic AI's built-in
    ModelMessagesTypeAdapter for full fidelity. Each conversation's history
    is stored under the ``NAMESPACE`` key as a ``{"messages": [...]}`` state.

    Example:
        ```python
        from march_agent import MarchAgentApp
        from march_agent.extensions.pydantic_ai import PydanticAIMessageStore
        from pydantic_ai import Agent

        app = MarchAgentApp(gateway_url="...", api_key="...")
        store = PydanticAIMessageStore(app=app)

        my_agent = Agent('openai:gpt-4o')

        # ``medical_agent`` stands for an agent registered on ``app``
        # (its registration is not shown here).
        @medical_agent.on_message
        async def handle(message, sender):
            history = await store.load(message.conversation_id)

            async with medical_agent.streamer(message) as s:
                async with my_agent.run_stream(
                    message.content,
                    message_history=history
                ) as result:
                    async for chunk in result.stream_text():
                        s.stream(chunk)

            await store.save(message.conversation_id, result.all_messages())
        ```
    """

    NAMESPACE = "pydantic_ai"

    def __init__(self, app: "MarchAgentApp"):
        """Initialize Pydantic AI message store.

        Args:
            app: MarchAgentApp instance; its
                ``gateway_client.conversation_store_url`` is used as the base
                URL of the agent-state API.

        Raises:
            ImportError: If pydantic-ai is not installed.
        """
        if not PYDANTIC_AI_AVAILABLE:
            raise ImportError(
                "pydantic-ai is required for PydanticAIMessageStore. "
                "Install it with: pip install march-agent[pydantic]"
            )

        base_url = app.gateway_client.conversation_store_url
        self.client = AgentStateClient(base_url)
        self._app = app

    async def load(self, conversation_id: str) -> List[ModelMessage]:
        """Load Pydantic AI message history for a conversation.

        Args:
            conversation_id: The conversation ID to load history for.

        Returns:
            List of ModelMessage objects. An empty list is returned when
            nothing is stored, when the stored state has no messages
            (missing or null), or when deserialization fails (the failure
            is logged with its traceback). Callers cannot tell these cases
            apart. Saving after a failed load will likely overwrite the
            stored history, depending on the agent-state API's put semantics.
        """
        result = await self.client.get(conversation_id, self.NAMESPACE)

        if not result:
            logger.debug("No message history found for conversation %s", conversation_id)
            return []

        # Tolerate explicit nulls as well as missing keys in the payload.
        state = result.get("state") or {}
        messages_data = state.get("messages") or []
        if not messages_data:
            return []

        # Deserialize using Pydantic AI's TypeAdapter
        try:
            messages = ModelMessagesTypeAdapter.validate_python(messages_data)
        except Exception as e:
            logger.exception("Failed to deserialize messages: %s", e)
            return []

        logger.debug(
            "Loaded %d messages for conversation %s", len(messages), conversation_id
        )
        return messages

    async def save(
        self,
        conversation_id: str,
        messages: List[ModelMessage],
    ) -> None:
        """Save Pydantic AI message history for a conversation.

        Args:
            conversation_id: The conversation ID to save history for.
            messages: List of ModelMessage objects to save.

        Raises:
            Exception: Whatever ModelMessagesTypeAdapter raises when the
                messages cannot be serialized (logged, then re-raised).
        """
        # Serialize using Pydantic AI's TypeAdapter (JSON-compatible output).
        try:
            serialized = ModelMessagesTypeAdapter.dump_python(messages, mode="json")
        except Exception as e:
            logger.error("Failed to serialize messages: %s", e)
            raise

        await self.client.put(
            conversation_id,
            self.NAMESPACE,
            {"messages": serialized},
        )

        logger.debug(
            "Saved %d messages for conversation %s", len(messages), conversation_id
        )

    async def clear(self, conversation_id: str) -> None:
        """Clear message history for a conversation.

        Args:
            conversation_id: The conversation ID to clear history for.
        """
        await self.client.delete(conversation_id, self.NAMESPACE)
        logger.debug("Cleared message history for conversation %s", conversation_id)

    async def close(self) -> None:
        """Close the HTTP client session."""
        await self.client.close()
|