march_agent-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,526 @@
1
+ """LangGraph extension for march_agent.
2
+
3
+ This module provides LangGraph-compatible components that integrate with
4
+ the march_agent framework.
5
+
6
+ Usage:
7
+     from march_agent import MarchAgentApp
+     from march_agent.extensions.langgraph import HTTPCheckpointSaver
+     from langgraph.graph import StateGraph
8
+
9
+ app = MarchAgentApp(gateway_url="agent-gateway:8080", api_key="key")
10
+ checkpointer = HTTPCheckpointSaver(app=app)
11
+
12
+ graph = StateGraph(...)
13
+ compiled = graph.compile(checkpointer=checkpointer)
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import asyncio
19
+ import base64
20
+ import logging
21
+ from concurrent.futures import ThreadPoolExecutor
22
+ from datetime import datetime, timezone
23
+ from typing import (
24
+ TYPE_CHECKING,
25
+ Any,
26
+ AsyncIterator,
27
+ Dict,
28
+ Iterator,
29
+ Optional,
30
+ Sequence,
31
+ Set,
32
+ Tuple,
33
+ )
34
+
35
+ if TYPE_CHECKING:
36
+ from ..app import MarchAgentApp
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+ # Try to import LangGraph types, but make them optional
41
+ try:
42
+ from langgraph.checkpoint.base import (
43
+ BaseCheckpointSaver,
44
+ ChannelVersions,
45
+ Checkpoint,
46
+ CheckpointMetadata,
47
+ CheckpointTuple,
48
+ )
49
+ from langchain_core.runnables import RunnableConfig
50
+
51
+ LANGGRAPH_AVAILABLE = True
52
+ except ImportError:
53
+ LANGGRAPH_AVAILABLE = False
54
+ # Define stub types for when langgraph is not installed
55
+ BaseCheckpointSaver = object
56
+ RunnableConfig = Dict[str, Any]
57
+ Checkpoint = Dict[str, Any]
58
+ CheckpointMetadata = Dict[str, Any]
59
+ CheckpointTuple = Tuple[Any, ...]
60
+ ChannelVersions = Dict[str, Any]
61
+
62
+ from ..checkpoint_client import CheckpointClient
63
+ from ..exceptions import APIException
64
+
65
+
66
+ def _generate_checkpoint_id() -> str:
67
+ """Generate a unique checkpoint ID based on timestamp."""
68
+ return datetime.now(timezone.utc).isoformat()
69
+
70
+
71
+ class HTTPCheckpointSaver(BaseCheckpointSaver if LANGGRAPH_AVAILABLE else object):
72
+ """HTTP-based checkpoint saver for LangGraph.
73
+
74
+ This checkpointer stores graph state via HTTP calls to the conversation-store
75
+ checkpoint API, enabling distributed checkpoint storage without direct
76
+ database access.
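+
+     Note: intermediate writes (put_writes / aput_writes) are currently logged
+     but not persisted by this saver.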
77
+
78
+ Example:
79
+ ```python
80
+ from march_agent import MarchAgentApp
81
+ from march_agent.extensions.langgraph import HTTPCheckpointSaver
82
+ from langgraph.graph import StateGraph
83
+
84
+ app = MarchAgentApp(gateway_url="agent-gateway:8080", api_key="key")
85
+ checkpointer = HTTPCheckpointSaver(app=app)
86
+
87
+ graph = StateGraph(MyState)
88
+ # ... define graph ...
89
+ compiled = graph.compile(checkpointer=checkpointer)
90
+
91
+ config = {"configurable": {"thread_id": "my-thread"}}
92
+ result = compiled.invoke({"messages": [...]}, config)
93
+ ```
94
+ """
95
+
96
+ def __init__(
97
+ self,
98
+ app: "MarchAgentApp",
99
+ *,
100
+ serde: Optional[Any] = None,
101
+ ):
102
+ """Initialize HTTP checkpoint saver.
103
+
104
+ Args:
105
+ app: MarchAgentApp instance to get the gateway client from.
106
+ serde: Optional serializer/deserializer (for LangGraph compatibility)
107
+ """
108
+ if LANGGRAPH_AVAILABLE:
109
+ super().__init__(serde=serde)
110
+
111
+ base_url = app.gateway_client.conversation_store_url
112
+ self.client = CheckpointClient(base_url)
113
+ self._loop: Optional[asyncio.AbstractEventLoop] = None
114
+ self._executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="checkpoint")
115
+
116
+ def _get_loop(self) -> asyncio.AbstractEventLoop:
117
+ """Get or create an event loop for sync operations."""
118
+ try:
119
+ return asyncio.get_running_loop()
120
+ except RuntimeError:
121
+ if self._loop is None or self._loop.is_closed():
122
+ self._loop = asyncio.new_event_loop()
123
+ return self._loop
124
+
125
+ async def close(self):
126
+ """Close the HTTP client session and executor."""
127
+ await self.client.close()
128
+ self._executor.shutdown(wait=True)
129
+
130
+ # ==================== Config Helpers ====================
131
+
132
+ @staticmethod
133
+ def _get_thread_id(config: RunnableConfig) -> str:
134
+ """Extract thread_id from config."""
135
+ configurable = config.get("configurable", {})
136
+ thread_id = configurable.get("thread_id")
137
+ if not thread_id:
138
+ raise ValueError("Config must contain configurable.thread_id")
139
+ return thread_id
140
+
141
+ @staticmethod
142
+ def _get_checkpoint_ns(config: RunnableConfig) -> str:
143
+ """Extract checkpoint_ns from config (defaults to empty string)."""
144
+ return config.get("configurable", {}).get("checkpoint_ns", "")
145
+
146
+ @staticmethod
147
+ def _get_checkpoint_id(config: RunnableConfig) -> Optional[str]:
148
+ """Extract checkpoint_id from config."""
149
+ return config.get("configurable", {}).get("checkpoint_id")
150
+
151
+ # ==================== Async Methods (Primary Implementation) ====================
152
+
153
+ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
154
+ """Fetch a checkpoint tuple asynchronously."""
155
+ thread_id = self._get_thread_id(config)
156
+ checkpoint_ns = self._get_checkpoint_ns(config)
157
+ checkpoint_id = self._get_checkpoint_id(config)
158
+
159
+ try:
160
+ result = await self.client.get_tuple(
161
+ thread_id=thread_id,
162
+ checkpoint_ns=checkpoint_ns,
163
+ checkpoint_id=checkpoint_id,
164
+ )
165
+ except APIException as e:
166
+ logger.error(f"Failed to get checkpoint: {e}")
167
+ return None
168
+
169
+ if not result:
170
+ return None
171
+
172
+ return self._response_to_tuple(result)
173
+
174
+ async def alist(
175
+ self,
176
+ config: Optional[RunnableConfig],
177
+ *,
178
+ filter: Optional[Dict[str, Any]] = None,
179
+ before: Optional[RunnableConfig] = None,
180
+ limit: Optional[int] = None,
181
+ ) -> AsyncIterator[CheckpointTuple]:
182
+ """List checkpoints asynchronously."""
183
+ thread_id = None
184
+ checkpoint_ns = None
185
+
186
+ if config:
187
+ thread_id = config.get("configurable", {}).get("thread_id")
188
+ checkpoint_ns = config.get("configurable", {}).get("checkpoint_ns")
189
+
190
+ before_id = None
191
+ if before:
192
+ before_id = before.get("configurable", {}).get("checkpoint_id")
193
+
194
+ try:
195
+ results = await self.client.list(
196
+ thread_id=thread_id,
197
+ checkpoint_ns=checkpoint_ns,
198
+ before=before_id,
199
+ limit=limit,
200
+ )
201
+ except APIException as e:
202
+ logger.error(f"Failed to list checkpoints: {e}")
203
+ return
204
+
205
+ for result in results:
206
+ tuple_result = self._response_to_tuple(result)
207
+ if tuple_result:
208
+ yield tuple_result
209
+
210
+ async def aput(
211
+ self,
212
+ config: RunnableConfig,
213
+ checkpoint: Checkpoint,
214
+ metadata: CheckpointMetadata,
215
+ new_versions: ChannelVersions,
216
+ ) -> RunnableConfig:
217
+ """Store a checkpoint asynchronously."""
218
+ thread_id = self._get_thread_id(config)
219
+ checkpoint_ns = self._get_checkpoint_ns(config)
220
+
221
+ checkpoint_id = self._get_checkpoint_id(config)
222
+ if not checkpoint_id:
223
+ checkpoint_id = checkpoint.get("id", _generate_checkpoint_id())
224
+
225
+ api_config = {
226
+ "configurable": {
227
+ "thread_id": thread_id,
228
+ "checkpoint_ns": checkpoint_ns,
229
+ "checkpoint_id": checkpoint_id,
230
+ }
231
+ }
232
+
233
+ checkpoint_data = self._checkpoint_to_api(checkpoint)
234
+ metadata_data = self._metadata_to_api(metadata)
235
+
236
+ try:
237
+ result = await self.client.put(
238
+ config=api_config,
239
+ checkpoint=checkpoint_data,
240
+ metadata=metadata_data,
241
+ new_versions=dict(new_versions) if new_versions else {},
242
+ )
243
+ except APIException as e:
244
+ logger.error(f"Failed to store checkpoint: {e}")
245
+ raise
246
+
247
+ return result.get("config", api_config)
248
+
249
+ async def aput_writes(
250
+ self,
251
+ config: RunnableConfig,
252
+ writes: Sequence[Tuple[str, Any]],
253
+ task_id: str,
254
+ task_path: str = "",
255
+ ) -> None:
256
+ """Store intermediate writes asynchronously (stub)."""
257
+ logger.debug(
258
+ f"aput_writes called (not persisted): task_id={task_id}, "
259
+ f"writes_count={len(writes)}"
260
+ )
261
+
262
+ async def adelete_thread(self, thread_id: str) -> None:
263
+ """Delete all checkpoints for a thread asynchronously."""
264
+ try:
265
+ await self.client.delete_thread(thread_id)
266
+ except APIException as e:
267
+ logger.error(f"Failed to delete thread checkpoints: {e}")
268
+ raise
269
+
270
+ # ==================== Sync Methods (Wrappers) ====================
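+     # Each wrapper submits asyncio.run(<coroutine>) to the dedicated
+     # single-worker executor thread and blocks on the result, keeping the
+     # async methods above as the single implementation.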
271
+
272
+ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
273
+ """Fetch a checkpoint tuple synchronously (thread-safe)."""
274
+ return self._executor.submit(asyncio.run, self.aget_tuple(config)).result()
275
+
276
+ def list(
277
+ self,
278
+ config: Optional[RunnableConfig],
279
+ *,
280
+ filter: Optional[Dict[str, Any]] = None,
281
+ before: Optional[RunnableConfig] = None,
282
+ limit: Optional[int] = None,
283
+ ) -> Iterator[CheckpointTuple]:
284
+ """List checkpoints synchronously (thread-safe)."""
285
+
286
+ async def collect():
287
+ results = []
288
+ async for item in self.alist(config, filter=filter, before=before, limit=limit):
289
+ results.append(item)
290
+ return results
291
+
292
+ results = self._executor.submit(asyncio.run, collect()).result()
293
+ yield from results
294
+
295
+ def put(
296
+ self,
297
+ config: RunnableConfig,
298
+ checkpoint: Checkpoint,
299
+ metadata: CheckpointMetadata,
300
+ new_versions: ChannelVersions,
301
+ ) -> RunnableConfig:
302
+ """Store a checkpoint synchronously (thread-safe)."""
303
+ return self._executor.submit(
304
+ asyncio.run,
305
+ self.aput(config, checkpoint, metadata, new_versions)
306
+ ).result()
307
+
308
+ def put_writes(
309
+ self,
310
+ config: RunnableConfig,
311
+ writes: Sequence[Tuple[str, Any]],
312
+ task_id: str,
313
+ task_path: str = "",
314
+ ) -> None:
315
+ """Store intermediate writes synchronously (thread-safe)."""
316
+ self._executor.submit(
317
+ asyncio.run,
318
+ self.aput_writes(config, writes, task_id, task_path)
319
+ ).result()
320
+
321
+ def delete_thread(self, thread_id: str) -> None:
322
+ """Delete all checkpoints for a thread synchronously (thread-safe)."""
323
+ self._executor.submit(asyncio.run, self.adelete_thread(thread_id)).result()
324
+
325
+ # ==================== Data Conversion Helpers ====================
326
+
327
+ def _serialize_value(
328
+ self,
329
+ value: Any,
330
+ _visited: Optional[Set[int]] = None,
331
+ _depth: int = 0
332
+ ) -> Any:
333
+ """Serialize a value for JSON transmission with cycle detection.
334
+
335
+ Args:
336
+ value: Value to serialize
337
+ _visited: Set of visited object IDs (for cycle detection)
338
+ _depth: Current recursion depth (for depth limit)
339
+
340
+ Returns:
341
+ Serialized value safe for JSON
342
+ """
343
+ # Initialize visited set on first call
344
+ if _visited is None:
345
+ _visited = set()
346
+
347
+ # Depth protection (prevent stack overflow)
348
+ MAX_DEPTH = 100
349
+ if _depth > MAX_DEPTH:
350
+ logger.warning(
351
+ f"Serialization depth limit reached ({MAX_DEPTH}). "
352
+ "Returning placeholder."
353
+ )
354
+ return {"__max_depth_exceeded__": True}
355
+
356
+ # Handle bytes (before cycle check, as bytes are immutable)
357
+ if isinstance(value, bytes):
358
+ return {"__bytes__": base64.b64encode(value).decode("ascii")}
359
+
360
+ # Cycle detection for container types
361
+ if isinstance(value, (dict, list, tuple)):
362
+ obj_id = id(value)
363
+ if obj_id in _visited:
364
+ logger.warning(
365
+ "Circular reference detected during serialization. "
366
+ "Returning placeholder."
367
+ )
368
+ return {"__circular_ref__": True}
369
+
370
+ # Mark as visited
371
+ _visited.add(obj_id)
372
+
373
+ try:
374
+ # Serialize based on type
375
+ if isinstance(value, dict):
376
+ return {
377
+ k: self._serialize_value(v, _visited, _depth + 1)
378
+ for k, v in value.items()
379
+ }
380
+
381
+ if isinstance(value, list):
382
+ return [
383
+ self._serialize_value(item, _visited, _depth + 1)
384
+ for item in value
385
+ ]
386
+
387
+ if isinstance(value, tuple):
388
+ return {
389
+ "__tuple__": [
390
+ self._serialize_value(item, _visited, _depth + 1)
391
+ for item in value
392
+ ]
393
+ }
394
+ finally:
395
+ # Remove from visited after processing
396
+             # This allows the same object to appear in different branches (DAG structure)
397
+ _visited.discard(obj_id)
398
+
399
+ # Handle custom serialization with serde
400
+ if LANGGRAPH_AVAILABLE and hasattr(self, "serde") and self.serde is not None:
401
+ try:
402
+ type_str, serialized = self.serde.dumps_typed(value)
403
+ if isinstance(serialized, bytes):
404
+ serialized = base64.b64encode(serialized).decode("ascii")
405
+ return {"__serde_type__": type_str, "__serde_value__": serialized}
406
+ except Exception as e:
407
+ logger.warning(f"Failed to serialize value with serde: {e}")
408
+
409
+ # Handle objects with serialization methods
410
+ if hasattr(value, "model_dump"):
411
+ return self._serialize_value(value.model_dump(), _visited, _depth + 1)
412
+ if hasattr(value, "dict"):
413
+ return self._serialize_value(value.dict(), _visited, _depth + 1)
414
+ if hasattr(value, "to_dict"):
415
+ return self._serialize_value(value.to_dict(), _visited, _depth + 1)
416
+
417
+ # Return primitives as-is
418
+ return value
419
+
420
+ def _serialize_channel_values(self, channel_values: Dict[str, Any]) -> Dict[str, Any]:
421
+ """Serialize all channel values for API transmission."""
422
+ return self._serialize_value(channel_values)
423
+
424
+ def _deserialize_value(self, value: Any) -> Any:
425
+ """Deserialize a value, decoding base64 bytes and reconstructing tuples."""
426
+ if isinstance(value, dict):
427
+ if "__bytes__" in value:
428
+ return base64.b64decode(value["__bytes__"])
429
+ if "__tuple__" in value:
430
+ return tuple(self._deserialize_value(item) for item in value["__tuple__"])
431
+ if "__serde_type__" in value and "__serde_value__" in value:
432
+ if LANGGRAPH_AVAILABLE and hasattr(self, "serde") and self.serde is not None:
433
+ try:
434
+ serialized = value["__serde_value__"]
435
+ if isinstance(serialized, str):
436
+ serialized = base64.b64decode(serialized)
437
+ return self.serde.loads_typed((value["__serde_type__"], serialized))
438
+ except Exception as e:
439
+ logger.warning(f"Failed to deserialize value with serde: {e}")
440
+ return value
441
+ return {k: self._deserialize_value(v) for k, v in value.items()}
442
+
443
+ if isinstance(value, list):
444
+ return [self._deserialize_value(item) for item in value]
445
+
446
+ return value
447
+
448
+ def _deserialize_checkpoint(self, checkpoint_data: Dict[str, Any]) -> Dict[str, Any]:
449
+ """Deserialize checkpoint data received from API."""
450
+ if not checkpoint_data:
451
+ return checkpoint_data
452
+
453
+ result = dict(checkpoint_data)
454
+ if "channel_values" in result:
455
+ result["channel_values"] = self._deserialize_value(result["channel_values"])
456
+ return result
457
+
458
+ def _checkpoint_to_api(self, checkpoint: Checkpoint) -> Dict[str, Any]:
459
+ """Convert LangGraph Checkpoint to API format."""
460
+ if isinstance(checkpoint, dict):
461
+ channel_values = checkpoint.get("channel_values", {})
462
+ return {
463
+ "v": checkpoint.get("v", 1),
464
+ "id": checkpoint.get("id", _generate_checkpoint_id()),
465
+ "ts": checkpoint.get("ts", datetime.now(timezone.utc).isoformat()),
466
+ "channel_values": self._serialize_channel_values(channel_values),
467
+ "channel_versions": checkpoint.get("channel_versions", {}),
468
+ "versions_seen": checkpoint.get("versions_seen", {}),
469
+ "pending_sends": checkpoint.get("pending_sends", []),
470
+ }
471
+ channel_values = dict(getattr(checkpoint, "channel_values", {}))
472
+ return {
473
+ "v": getattr(checkpoint, "v", 1),
474
+ "id": getattr(checkpoint, "id", _generate_checkpoint_id()),
475
+ "ts": getattr(checkpoint, "ts", datetime.now(timezone.utc).isoformat()),
476
+ "channel_values": self._serialize_channel_values(channel_values),
477
+ "channel_versions": dict(getattr(checkpoint, "channel_versions", {})),
478
+ "versions_seen": dict(getattr(checkpoint, "versions_seen", {})),
479
+ "pending_sends": list(getattr(checkpoint, "pending_sends", [])),
480
+ }
481
+
482
+ def _serialize_writes(self, writes: Any) -> Any:
483
+ """Serialize writes field which may contain LangChain objects."""
484
+ if writes is None:
485
+ return None
486
+ return self._serialize_value(writes)
487
+
488
+ def _metadata_to_api(self, metadata: CheckpointMetadata) -> Dict[str, Any]:
489
+ """Convert LangGraph CheckpointMetadata to API format."""
490
+ if isinstance(metadata, dict):
491
+ return {
492
+ "source": metadata.get("source", "input"),
493
+ "step": metadata.get("step", -1),
494
+ "writes": self._serialize_writes(metadata.get("writes")),
495
+ "parents": metadata.get("parents", {}),
496
+ }
497
+ return {
498
+ "source": getattr(metadata, "source", "input"),
499
+ "step": getattr(metadata, "step", -1),
500
+ "writes": self._serialize_writes(getattr(metadata, "writes", None)),
501
+ "parents": dict(getattr(metadata, "parents", {})),
502
+ }
503
+
504
+ def _response_to_tuple(self, response: Dict[str, Any]) -> Optional[CheckpointTuple]:
505
+ """Convert API response to LangGraph CheckpointTuple."""
506
+ if not response:
507
+ return None
508
+
509
+ config = response.get("config", {})
510
+ checkpoint_data = response.get("checkpoint", {})
511
+ metadata_data = response.get("metadata", {})
512
+ parent_config = response.get("parent_config")
513
+ pending_writes = response.get("pending_writes")
514
+
515
+ checkpoint_data = self._deserialize_checkpoint(checkpoint_data)
516
+
517
+ if not LANGGRAPH_AVAILABLE:
518
+ return (config, checkpoint_data, metadata_data, parent_config, pending_writes)
519
+
520
+ return CheckpointTuple(
521
+ config=config,
522
+ checkpoint=checkpoint_data,
523
+ metadata=metadata_data,
524
+ parent_config=parent_config,
525
+ pending_writes=pending_writes or [],
526
+ )
@@ -0,0 +1,180 @@
1
+ """Pydantic AI extension for march_agent.
2
+
3
+ This module provides integration with Pydantic AI, enabling persistent
4
+ message history storage via the agent-state API.
5
+
6
+ Usage:
7
+ from march_agent import MarchAgentApp
8
+ from march_agent.extensions.pydantic_ai import PydanticAIMessageStore
9
+ from pydantic_ai import Agent
10
+
11
+ app = MarchAgentApp(gateway_url="agent-gateway:8080", api_key="key")
12
+ store = PydanticAIMessageStore(app=app)
13
+
14
+ my_agent = Agent('openai:gpt-4o', system_prompt="...")
15
+
16
+ @medical_agent.on_message
17
+ async def handle(message, sender):
18
+ # Load message history
19
+ history = await store.load(message.conversation_id)
20
+
21
+ # Run agent with streaming
22
+ async with medical_agent.streamer(message) as s:
23
+ async with my_agent.run_stream(message.content, message_history=history) as result:
24
+ async for chunk in result.stream_text():
25
+ s.stream(chunk)
26
+
27
+ # Save updated history
28
+ await store.save(message.conversation_id, result.all_messages())
29
+ """
30
+
31
+ from __future__ import annotations
32
+
33
+ import logging
34
+ from typing import TYPE_CHECKING, List, Any
35
+
36
+ if TYPE_CHECKING:
37
+ from ..app import MarchAgentApp
38
+
39
+ logger = logging.getLogger(__name__)
40
+
41
+ # Try to import Pydantic AI types, but make them optional
42
+ try:
43
+ from pydantic_ai.messages import (
44
+ ModelMessage,
45
+ ModelMessagesTypeAdapter,
46
+ )
47
+
48
+ PYDANTIC_AI_AVAILABLE = True
49
+ except ImportError:
50
+ PYDANTIC_AI_AVAILABLE = False
51
+ ModelMessage = Any
52
+ ModelMessagesTypeAdapter = None
53
+
54
+ from ..agent_state_client import AgentStateClient
55
+
56
+
57
+ class PydanticAIMessageStore:
58
+ """Persistent message store for Pydantic AI.
59
+
60
+ Stores and retrieves Pydantic AI native message history using the
61
+ agent-state API. Messages are serialized using Pydantic AI's built-in
62
+ ModelMessagesTypeAdapter for full fidelity.
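+
+     State is stored under the "pydantic_ai" namespace, keyed by
+     conversation_id, with the serialized messages under a "messages" key.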
63
+
64
+ Example:
65
+ ```python
66
+ from march_agent import MarchAgentApp
67
+ from march_agent.extensions.pydantic_ai import PydanticAIMessageStore
68
+ from pydantic_ai import Agent
69
+
70
+ app = MarchAgentApp(gateway_url="...", api_key="...")
71
+ store = PydanticAIMessageStore(app=app)
72
+
73
+ my_agent = Agent('openai:gpt-4o')
74
+
75
+ @medical_agent.on_message
76
+ async def handle(message, sender):
77
+ history = await store.load(message.conversation_id)
78
+
79
+ async with medical_agent.streamer(message) as s:
80
+ async with my_agent.run_stream(
81
+ message.content,
82
+ message_history=history
83
+ ) as result:
84
+ async for chunk in result.stream_text():
85
+ s.stream(chunk)
86
+
87
+ await store.save(message.conversation_id, result.all_messages())
88
+ ```
89
+ """
90
+
91
+ NAMESPACE = "pydantic_ai"
92
+
93
+ def __init__(self, app: "MarchAgentApp"):
94
+ """Initialize Pydantic AI message store.
95
+
96
+ Args:
97
+ app: MarchAgentApp instance to get the gateway client from.
98
+ """
99
+ if not PYDANTIC_AI_AVAILABLE:
100
+ raise ImportError(
101
+ "pydantic-ai is required for PydanticAIMessageStore. "
102
+ "Install it with: pip install march-agent[pydantic]"
103
+ )
104
+
105
+ base_url = app.gateway_client.conversation_store_url
106
+ self.client = AgentStateClient(base_url)
107
+ self._app = app
108
+
109
+ async def load(self, conversation_id: str) -> List[ModelMessage]:
110
+ """Load Pydantic AI message history for a conversation.
111
+
112
+ Args:
113
+ conversation_id: The conversation ID to load history for.
114
+
115
+ Returns:
116
+ List of ModelMessage objects (empty list if no history).
117
+ """
118
+ result = await self.client.get(conversation_id, self.NAMESPACE)
119
+
120
+ if not result:
121
+ logger.debug(f"No message history found for conversation {conversation_id}")
122
+ return []
123
+
124
+ state = result.get("state", {})
125
+ messages_data = state.get("messages", [])
126
+
127
+ if not messages_data:
128
+ return []
129
+
130
+ # Deserialize using Pydantic AI's TypeAdapter
131
+ try:
132
+ messages = ModelMessagesTypeAdapter.validate_python(messages_data)
133
+ logger.debug(
134
+ f"Loaded {len(messages)} messages for conversation {conversation_id}"
135
+ )
136
+ return messages
137
+ except Exception as e:
138
+ logger.error(f"Failed to deserialize messages: {e}")
139
+ return []
140
+
141
+ async def save(
142
+ self,
143
+ conversation_id: str,
144
+ messages: List[ModelMessage],
145
+ ) -> None:
146
+ """Save Pydantic AI message history for a conversation.
147
+
148
+ Args:
149
+ conversation_id: The conversation ID to save history for.
150
+ messages: List of ModelMessage objects to save.
151
+ """
152
+ # Serialize using Pydantic AI's TypeAdapter
153
+ try:
154
+ serialized = ModelMessagesTypeAdapter.dump_python(messages, mode="json")
155
+ except Exception as e:
156
+ logger.error(f"Failed to serialize messages: {e}")
157
+ raise
158
+
159
+ await self.client.put(
160
+ conversation_id,
161
+ self.NAMESPACE,
162
+ {"messages": serialized},
163
+ )
164
+
165
+ logger.debug(
166
+ f"Saved {len(messages)} messages for conversation {conversation_id}"
167
+ )
168
+
169
+ async def clear(self, conversation_id: str) -> None:
170
+ """Clear message history for a conversation.
171
+
172
+ Args:
173
+ conversation_id: The conversation ID to clear history for.
174
+ """
175
+ await self.client.delete(conversation_id, self.NAMESPACE)
176
+ logger.debug(f"Cleared message history for conversation {conversation_id}")
177
+
178
+ async def close(self) -> None:
179
+ """Close the HTTP client session."""
180
+ await self.client.close()