march-agent 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- march_agent/__init__.py +52 -0
- march_agent/agent.py +341 -0
- march_agent/agent_state_client.py +149 -0
- march_agent/app.py +416 -0
- march_agent/artifact.py +58 -0
- march_agent/checkpoint_client.py +169 -0
- march_agent/checkpointer.py +16 -0
- march_agent/cli.py +139 -0
- march_agent/conversation.py +103 -0
- march_agent/conversation_client.py +86 -0
- march_agent/conversation_message.py +48 -0
- march_agent/exceptions.py +36 -0
- march_agent/extensions/__init__.py +1 -0
- march_agent/extensions/langgraph.py +526 -0
- march_agent/extensions/pydantic_ai.py +180 -0
- march_agent/gateway_client.py +506 -0
- march_agent/gateway_pb2.py +73 -0
- march_agent/gateway_pb2_grpc.py +101 -0
- march_agent/heartbeat.py +84 -0
- march_agent/memory.py +73 -0
- march_agent/memory_client.py +155 -0
- march_agent/message.py +80 -0
- march_agent/streamer.py +220 -0
- march_agent-0.1.1.dist-info/METADATA +503 -0
- march_agent-0.1.1.dist-info/RECORD +29 -0
- march_agent-0.1.1.dist-info/WHEEL +5 -0
- march_agent-0.1.1.dist-info/entry_points.txt +2 -0
- march_agent-0.1.1.dist-info/licenses/LICENSE +21 -0
- march_agent-0.1.1.dist-info/top_level.txt +1 -0
march_agent/app.py
ADDED
|
@@ -0,0 +1,416 @@
|
|
|
1
|
+
"""Main application class for March AI Agent framework (async)."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import logging
|
|
5
|
+
import signal
|
|
6
|
+
import threading
|
|
7
|
+
import time
|
|
8
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
9
|
+
from typing import Optional, Dict, Any, List, Set
|
|
10
|
+
|
|
11
|
+
from .agent import Agent
|
|
12
|
+
from .gateway_client import GatewayClient
|
|
13
|
+
from .conversation_client import ConversationClient
|
|
14
|
+
from .memory_client import MemoryClient
|
|
15
|
+
from .exceptions import RegistrationError, ConfigurationError
|
|
16
|
+
|
|
17
|
+
logger = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class MarchAgentApp:
|
|
21
|
+
"""
|
|
22
|
+
Main application class for March AI Agent framework.
|
|
23
|
+
|
|
24
|
+
Example:
|
|
25
|
+
from march_agent import MarchAgentApp
|
|
26
|
+
|
|
27
|
+
app = MarchAgentApp(
|
|
28
|
+
gateway_url="agent-gateway:8080",
|
|
29
|
+
api_key="agent-key-1"
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
# Register agents
|
|
33
|
+
medical_agent = app.register_me(
|
|
34
|
+
name="medical-qa-agent",
|
|
35
|
+
about="Medical Q&A",
|
|
36
|
+
document="Answers medical questions"
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
@medical_agent.on_message
|
|
40
|
+
def handle_message(message, sender):
|
|
41
|
+
# Access conversation history
|
|
42
|
+
history = message.conversation.get_history(limit=5)
|
|
43
|
+
|
|
44
|
+
# Stream response
|
|
45
|
+
with medical_agent.streamer(message) as s:
|
|
46
|
+
s.stream("Processing...")
|
|
47
|
+
s.stream("Done!")
|
|
48
|
+
|
|
49
|
+
app.run()
|
|
50
|
+
"""
|
|
51
|
+
|
|
52
|
+
def __init__(
    self,
    gateway_url: str,
    api_key: str,
    heartbeat_interval: int = 60,
    max_concurrent_tasks: int = 100,
    error_message_template: str = (
        "I encountered an error while processing your message. "
        "Please try again or contact support if the issue persists."
    ),
    secure: bool = False,
):
    """
    Build the application object and the clients it shares with its agents.

    Args:
        gateway_url: Gateway endpoint (e.g., "agent-gateway:8080")
        api_key: API key for gateway authentication
        heartbeat_interval: Heartbeat interval in seconds, handed to each Agent
        max_concurrent_tasks: Upper bound on message-handler tasks in flight (default: 100)
        error_message_template: Text handed to each Agent for error replies
        secure: Forwarded to GatewayClient; per the original docs, True selects
            TLS for gRPC and HTTPS for HTTP requests
    """
    # Keep the raw settings so register_me() can pass them on later.
    self.gateway_url = gateway_url
    self.api_key = api_key
    self.heartbeat_interval = heartbeat_interval
    self.max_concurrent_tasks = max_concurrent_tasks
    self.error_message_template = error_message_template
    self.secure = secure

    # The gateway client comes first: the other two clients take their
    # base URLs from attributes it exposes.
    gateway = GatewayClient(gateway_url, api_key, secure=secure)
    self.gateway_client = gateway
    self.conversation_client = ConversationClient(gateway.conversation_store_url)
    self.memory_client = MemoryClient(gateway.ai_memory_url)

    # Runtime bookkeeping used by run() and the consume loop.
    self._agents: List[Agent] = []
    self._running = False
    self._connected = False
    self._active_tasks: Set[asyncio.Task] = set()
    self._executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="consume")
    self._task_start_times: Dict[asyncio.Task, float] = {}
    self._hung_task_threshold = 300.0  # 5 minutes

    logger.info(f"MarchAgentApp initialized (gateway: {gateway_url}, max_concurrent: {max_concurrent_tasks})")
|
|
104
|
+
|
|
105
|
+
def register_me(
    self,
    name: str,
    about: str,
    document: str,
    representation_name: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
    related_pages: Optional[List[Dict[str, str]]] = None,
) -> Agent:
    """
    Register an agent with the backend.

    Example:
        medical_agent = app.register_me(
            name="medical-qa",
            about="Medical question answering",
            document="Answers medical questions using AI",
            representation_name="Medical Q&A Bot",
            related_pages=[
                {"name": "Dashboard", "endpoint": "/dashboard"},
                {"name": "Reports", "endpoint": "/reports"},
            ]
        )

        @medical_agent.on_message
        def handle_message(message, sender):
            with medical_agent.streamer(message) as s:
                s.stream("Processing...")

    Args:
        name: Unique agent name (used for routing)
        about: Short description
        document: Detailed documentation
        representation_name: Display name (optional, falls back to ``name``)
        metadata: Additional metadata (optional, omitted from the payload when empty)
        related_pages: List of related pages with 'name' and 'endpoint' keys (optional)
            Example: [{"name": "Dashboard", "endpoint": "/dashboard"}]

    Returns:
        Agent instance, also appended to this app's agent list.

    Raises:
        RegistrationError: If the HTTP call fails, the gateway answers with
            anything other than 201, or the response body cannot be decoded.
    """
    # Register agent with backend via gateway proxy
    logger.info(f"Registering agent '{name}'...")

    payload = {
        "name": name,
        "about": about,
        "document": document,
        "representationName": representation_name or name,
    }
    if metadata:
        payload["metadata"] = metadata
    if related_pages:
        payload["relatedPages"] = related_pages

    # Transport failures are wrapped once and chained to their cause.
    try:
        response = self.gateway_client.http_post(
            "ai-inventory",
            "/api/v1/agents/register",
            json=payload,
            timeout=30,
        )
    except Exception as e:
        raise RegistrationError(f"Failed to register agent '{name}': {e}") from e

    # Raised outside the try block so this error is not caught and wrapped
    # in a second RegistrationError. The message text is unchanged.
    if response.status_code != 201:
        raise RegistrationError(
            f"Failed to register agent '{name}': Registration failed: {response.text}"
        )

    try:
        agent_data = response.json()
    except Exception as e:
        raise RegistrationError(f"Failed to register agent '{name}': {e}") from e
    logger.info(f"Agent '{name}' registered successfully")

    # Create agent instance
    agent = Agent(
        name=name,
        gateway_client=self.gateway_client,
        agent_data=agent_data,
        heartbeat_interval=self.heartbeat_interval,
        conversation_client=self.conversation_client,
        memory_client=self.memory_client,
        error_message_template=self.error_message_template,
    )

    self._agents.append(agent)
    logger.info(f"Agent '{name}' ready")

    return agent
|
|
190
|
+
|
|
191
|
+
def run(self):
    """Start all registered agents and block until shutdown.

    Connects the gateway client for every registered agent name, installs
    SIGINT/SIGTERM handlers, initialises each agent, starts one daemon
    thread per agent running ``start_consuming``, then runs the consume
    loop in the calling thread.

    Raises:
        ConfigurationError: If no agent is registered or the gateway
            connection cannot be established.
    """
    if not self._agents:
        raise ConfigurationError(
            "No agents registered. Use app.register_me() to register agents."
        )

    logger.info(f"Starting {len(self._agents)} agent(s)...")
    self._running = True

    # Connect to gateway with all agent names
    agent_names = [agent.name for agent in self._agents]
    try:
        self.gateway_client.connect(agent_names)
    except Exception as e:
        # Do not leave the app flagged as running when start-up failed.
        self._running = False
        raise ConfigurationError(f"Failed to connect to gateway: {e}") from e
    self._connected = True

    # From here on the gateway connection is open, so any failure,
    # including one during agent initialisation, must go through _shutdown().
    try:
        # Setup signal handlers
        signal.signal(signal.SIGINT, self._shutdown_handler)
        signal.signal(signal.SIGTERM, self._shutdown_handler)

        # Initialize all agents with the connected gateway
        for agent in self._agents:
            agent._initialize_with_gateway()

        # Start all agents in background threads (they just stay alive).
        # Daemon threads need no handle, as they are never joined.
        for agent in self._agents:
            threading.Thread(target=agent.start_consuming, daemon=True).start()

        logger.info("All agents started. Press Ctrl+C to stop.")

        # Run the single consume loop in the calling thread
        try:
            self._consume_loop()
        except KeyboardInterrupt:
            logger.info("Interrupted by user")
    finally:
        self._shutdown()
|
|
233
|
+
|
|
234
|
+
def _consume_loop(self):
    """Blocking entry point: drive the async consume loop to completion."""
    # asyncio.run creates a fresh event loop, runs the coroutine on it and
    # closes the loop when the coroutine returns.
    consume_coro = self._consume_loop_async()
    asyncio.run(consume_coro)
|
|
238
|
+
|
|
239
|
+
async def _consume_loop_async(self):
    """Poll the gateway and dispatch each message to its agent concurrently.

    Messages are fetched with the blocking ``consume_one`` call on the
    app's thread pool, mapped to an agent through the ``"<name>.inbox"``
    topic, and handled in their own asyncio task. At most
    ``max_concurrent_tasks`` handler tasks are kept in flight; when the
    limit is reached the loop waits for one to finish before dispatching.

    The loop runs until ``self._running`` becomes False and then waits for
    all in-flight handler tasks before returning.
    """
    # Build topic -> agent mapping
    topic_to_agent = {
        f"{agent.name}.inbox": agent for agent in self._agents
    }

    logger.info(
        f"Starting consume loop for topics: {list(topic_to_agent.keys())} "
        f"(max_concurrent: {self.max_concurrent_tasks})"
    )

    # Inside a coroutine the running loop is the one we want.
    loop = asyncio.get_running_loop()

    def _forget_task(finished: asyncio.Task) -> None:
        # Single clean-up point for a finished task. It reads the attributes
        # at call time and drops both the active-set entry and the start
        # time, so neither container keeps Task objects alive.
        self._active_tasks.discard(finished)
        self._task_start_times.pop(finished, None)

    try:
        while self._running:
            try:
                # Run sync consume_one in executor to not block the event loop
                msg = await loop.run_in_executor(
                    self._executor,
                    lambda: self.gateway_client.consume_one(timeout=0.5),
                )
                if not msg:
                    continue

                agent = topic_to_agent.get(msg.topic)
                if not agent:
                    logger.warning(f"No agent for topic: {msg.topic}")
                    continue

                # Wait if we've hit max concurrency
                while (
                    self._active_tasks
                    and len(self._active_tasks) >= self.max_concurrent_tasks
                ):
                    # Wait on a snapshot. The set itself is kept and only
                    # mutated in place, so done-callbacks stay valid.
                    done, _ = await asyncio.wait(
                        set(self._active_tasks),
                        return_when=asyncio.FIRST_COMPLETED,
                        timeout=30.0,  # 30 second timeout
                    )

                    # Check for errors in completed tasks
                    for task in done:
                        _forget_task(task)
                        # exception() raises on a cancelled task
                        if task.cancelled():
                            continue
                        if task.exception():
                            logger.error(
                                f"Task failed with exception: {task.exception()}"
                            )

                    # If timeout with no completions, check for hung tasks
                    if not done:
                        current_time = time.time()
                        hung_tasks = [
                            task for task in self._active_tasks
                            if current_time - self._task_start_times.get(task, current_time) > self._hung_task_threshold
                        ]

                        if hung_tasks:
                            # Hung tasks are reported only, never cancelled.
                            logger.error(
                                f"Detected {len(hung_tasks)} hung tasks (>{self._hung_task_threshold}s). "
                                f"Total active: {len(self._active_tasks)}"
                            )
                        else:
                            logger.warning(
                                f"No tasks completed in 30s. Active tasks: {len(self._active_tasks)} "
                                f"(may indicate slow handlers or blocking code)"
                            )

                # Create task for concurrent processing
                task = asyncio.create_task(
                    self._handle_message_safe(agent, msg)
                )
                self._active_tasks.add(task)
                self._task_start_times[task] = time.time()
                # Auto-remove from both containers when done
                task.add_done_callback(_forget_task)

                logger.debug(
                    f"Dispatched message to {agent.name}, "
                    f"active tasks: {len(self._active_tasks)}"
                )

            except Exception as e:
                if self._running:
                    logger.error(f"Error in consume loop: {e}")
                    # Try to reconnect if connection lost
                    await asyncio.sleep(1.0)
                    try:
                        # NOTE(review): _reconnect is synchronous and runs on
                        # the event-loop thread, so handlers stall while it runs.
                        self._reconnect()
                    except Exception as re:
                        logger.error(f"Reconnect failed: {re}")
                        await asyncio.sleep(5.0)

    finally:
        # Wait for all active tasks to complete on shutdown
        if self._active_tasks:
            logger.info(
                f"Waiting for {len(self._active_tasks)} active tasks to complete..."
            )
            await asyncio.gather(*self._active_tasks, return_exceptions=True)
            logger.info("All tasks completed")
|
|
343
|
+
|
|
344
|
+
async def _handle_message_safe(self, agent: Agent, msg) -> None:
    """Run the agent's async handler and log, rather than propagate, any Exception."""
    try:
        await agent._handle_message_async(msg)
    except Exception as exc:
        # exc_info=True attaches the traceback of the handler failure.
        failure = f"Error handling message for {agent.name}: {exc}"
        logger.error(failure, exc_info=True)
|
|
353
|
+
|
|
354
|
+
def _reconnect(self):
    """Close the gateway client and connect it again for all registered agents.

    Any exception from ``close`` or ``connect`` propagates to the caller.
    """
    logger.info("Attempting to reconnect to gateway...")
    names = [registered.name for registered in self._agents]
    client = self.gateway_client
    client.close()
    client.connect(names)
    logger.info("Reconnected to gateway successfully")
|
|
361
|
+
|
|
362
|
+
def _shutdown_handler(self, signum, frame):
    """Signal callback: log the signal number and ask the consume loop to stop.

    ``frame`` is accepted because the signal module passes it; it is unused.
    """
    notice = f"Received signal {signum}, initiating shutdown..."
    logger.info(notice)
    # The consume loop checks this flag on every iteration.
    self._running = False
|
|
366
|
+
|
|
367
|
+
def _shutdown(self):
    """Shutdown all agents gracefully.

    Every step is best-effort: a failure is logged and the remaining steps
    still run. Order: stop the executor, close the async HTTP sessions,
    shut down each agent, close the gateway connection.
    """
    logger.info("Shutting down all agents...")
    self._running = False

    # Shutdown the executor
    try:
        self._executor.shutdown(wait=False)
    except Exception as e:
        logger.error(f"Error shutting down executor: {e}")

    # Close async sessions. The loop used by the consume loop is already
    # closed, so asyncio.run() supplies a fresh one. Unlike a manual
    # new_event_loop()/set_event_loop() pair, it also does not leave a
    # closed loop installed as the thread's current event loop.
    try:
        asyncio.run(self._close_async_sessions())
    except Exception as e:
        logger.error(f"Error closing async sessions: {e}")

    for agent in self._agents:
        try:
            agent.shutdown()
        except Exception as e:
            logger.error(f"Error shutting down agent {agent.name}: {e}")

    # Close gateway connection, guarded like every other step so a failing
    # close cannot mask the reason we are shutting down.
    if self._connected:
        try:
            self.gateway_client.close()
        except Exception as e:
            logger.error(f"Error closing gateway connection: {e}")
        finally:
            self._connected = False

    logger.info("All agents shut down successfully")
|
|
400
|
+
|
|
401
|
+
async def _close_async_sessions(self):
    """Close the gateway, conversation and memory async sessions, in that order.

    Each close is attempted independently; a failure is logged and does not
    prevent the following ones.
    """
    # Lambdas defer the attribute lookup into the try block below, so a
    # missing attribute is logged like any other close failure.
    closers = (
        ("gateway async session", lambda: self.gateway_client.close_async()),
        ("conversation client", lambda: self.conversation_client.close()),
        ("memory client", lambda: self.memory_client.close()),
    )
    for label, closer in closers:
        try:
            await closer()
        except Exception as exc:
            logger.error(f"Error closing {label}: {exc}")
|
march_agent/artifact.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
"""Artifact class for attaching files/URLs to messages."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Optional, Dict, Any
|
|
5
|
+
from enum import Enum
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class ArtifactType(str, Enum):
    """Closed set of attachment categories; each member is also a plain ``str``."""

    # PDF, DOC, etc.
    DOCUMENT = "document"
    # PNG, JPG, GIF, etc.
    IMAGE = "image"
    # Embeddable content (maps, charts)
    IFRAME = "iframe"
    # Video files or embeds
    VIDEO = "video"
    # Audio files
    AUDIO = "audio"
    # Code snippets with syntax highlighting
    CODE = "code"
    # External links with preview
    LINK = "link"
    # Generic file download
    FILE = "file"
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
class Artifact:
    """A URL attachment that an agent can add to a message.

    The artifact points at a file, image, iframe or other resource and
    carries optional display information.

    Example:
        artifact = Artifact(
            url="https://example.com/report.pdf",
            type=ArtifactType.DOCUMENT,
            title="Monthly Report",
            metadata={"size": 1024000, "mimeType": "application/pdf"}
        )
    """

    # Required: where the resource lives and which category it belongs to.
    url: str
    type: ArtifactType
    # Optional extras; to_dict() leaves out the ones that are not set.
    title: Optional[str] = None
    description: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None
    position: Optional[int] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable dict with the enum reduced to its value.

        ``title``, ``description`` and ``metadata`` are included only when
        truthy; ``position`` is included whenever it is not None (so 0 is kept).
        """
        payload: Dict[str, Any] = {"url": self.url, "type": self.type.value}
        truthy_only = (
            ("title", self.title),
            ("description", self.description),
            ("metadata", self.metadata),
        )
        for key, value in truthy_only:
            if value:
                payload[key] = value
        if self.position is not None:
            payload["position"] = self.position
        return payload
|
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
"""HTTP client for checkpoint storage API."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from typing import List, Dict, Any, Optional
|
|
5
|
+
import aiohttp
|
|
6
|
+
|
|
7
|
+
from .exceptions import APIException
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class CheckpointClient:
|
|
13
|
+
"""Async HTTP client for checkpoint-store API.
|
|
14
|
+
|
|
15
|
+
This client communicates with the conversation-store's checkpoint endpoints
|
|
16
|
+
to store and retrieve LangGraph-compatible checkpoints.
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
def __init__(self, base_url: str):
    """Remember the API base URL; the HTTP session is created on first use.

    Args:
        base_url: Base URL for the checkpoint API (e.g., http://gateway/s/conversation-store).
            Trailing slashes are stripped so paths can be appended with "/".
    """
    normalized = base_url.rstrip("/")
    self.base_url = normalized
    # None until _get_session() builds the session.
    self._session: Optional[aiohttp.ClientSession] = None
|
|
27
|
+
|
|
28
|
+
async def _get_session(self) -> aiohttp.ClientSession:
    """Return the cached aiohttp session, creating a new one if missing or closed.

    New sessions are built with a 30 second total timeout.
    """
    existing = self._session
    if existing is not None and not existing.closed:
        return existing

    self._session = aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=30.0)
    )
    return self._session
|
|
34
|
+
|
|
35
|
+
async def close(self):
    """Close the aiohttp session if one exists and is still open."""
    session = self._session
    if not session or session.closed:
        # Nothing was ever opened, or it is already closed.
        return
    await session.close()
|
|
39
|
+
|
|
40
|
+
async def put(
    self,
    config: Dict[str, Any],
    checkpoint: Dict[str, Any],
    metadata: Dict[str, Any],
    new_versions: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Store a checkpoint.

    Sends a PUT to ``{base_url}/checkpoints/`` with the four arguments as
    the JSON body.

    Args:
        config: RunnableConfig with configurable containing thread_id, checkpoint_ns, checkpoint_id
        checkpoint: Checkpoint data (channel_values, channel_versions, etc.)
        metadata: Checkpoint metadata (source, step, parents, writes)
        new_versions: New channel versions (optional, sent as ``{}`` when omitted)

    Returns:
        The decoded JSON response (per the original docs, the config of the
        stored checkpoint).

    Raises:
        APIException: On an HTTP status >= 400 or an aiohttp client error.
    """
    url = f"{self.base_url}/checkpoints/"
    payload = {
        "config": config,
        "checkpoint": checkpoint,
        "metadata": metadata,
        "new_versions": new_versions or {},
    }

    session = await self._get_session()
    try:
        async with session.put(url, json=payload) as response:
            if response.status >= 400:
                error_text = await response.text()
                raise APIException(f"Failed to store checkpoint: {response.status} - {error_text}")
            return await response.json()
    except aiohttp.ClientError as e:
        # Chain the transport error so its traceback is kept as the cause.
        raise APIException(f"Failed to store checkpoint: {e}") from e
|
|
75
|
+
|
|
76
|
+
async def get_tuple(
    self,
    thread_id: str,
    checkpoint_ns: str = "",
    checkpoint_id: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
    """Get a checkpoint tuple.

    Sends a GET to ``{base_url}/checkpoints/{thread_id}`` with
    ``checkpoint_ns`` (and ``checkpoint_id`` when given) as query parameters.

    Args:
        thread_id: Thread identifier
        checkpoint_ns: Checkpoint namespace (default "")
        checkpoint_id: Specific checkpoint ID (latest if not provided)

    Returns:
        CheckpointTuple dict, or None on a 404 or an empty/null body.

    Raises:
        APIException: On any other HTTP status >= 400 or an aiohttp client error.
    """
    from urllib.parse import quote

    # Percent-encode the id so characters such as "/", "?" or "#" cannot
    # change which path the request is routed to.
    url = f"{self.base_url}/checkpoints/{quote(str(thread_id), safe='')}"
    params = {"checkpoint_ns": checkpoint_ns}
    if checkpoint_id:
        params["checkpoint_id"] = checkpoint_id

    session = await self._get_session()
    try:
        async with session.get(url, params=params) as response:
            if response.status == 404:
                return None
            if response.status >= 400:
                error_text = await response.text()
                raise APIException(f"Failed to get checkpoint: {response.status} - {error_text}")
            result = await response.json()
            # API returns null for not found
            return result if result else None
    except aiohttp.ClientError as e:
        # Chain the transport error so its traceback is kept as the cause.
        raise APIException(f"Failed to get checkpoint: {e}") from e
|
|
110
|
+
|
|
111
|
+
async def list(
    self,
    thread_id: Optional[str] = None,
    checkpoint_ns: Optional[str] = None,
    before: Optional[str] = None,
    limit: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """List checkpoints.

    Sends a GET to ``{base_url}/checkpoints/``; only the filters that are
    set become query parameters.

    Args:
        thread_id: Filter by thread ID
        checkpoint_ns: Filter by namespace (an empty string is still sent)
        before: Return checkpoints before this checkpoint_id
        limit: Maximum number of checkpoints to return (0 is treated as unset)

    Returns:
        The decoded JSON response (per the original docs, a list of
        CheckpointTuple dicts).

    Raises:
        APIException: On an HTTP status >= 400 or an aiohttp client error.
    """
    url = f"{self.base_url}/checkpoints/"
    params = {}
    if thread_id:
        params["thread_id"] = thread_id
    # "" is a valid namespace, so test against None rather than truthiness.
    if checkpoint_ns is not None:
        params["checkpoint_ns"] = checkpoint_ns
    if before:
        params["before"] = before
    if limit:
        params["limit"] = limit

    session = await self._get_session()
    try:
        async with session.get(url, params=params) as response:
            if response.status >= 400:
                error_text = await response.text()
                raise APIException(f"Failed to list checkpoints: {response.status} - {error_text}")
            return await response.json()
    except aiohttp.ClientError as e:
        # Chain the transport error so its traceback is kept as the cause.
        raise APIException(f"Failed to list checkpoints: {e}") from e
|
|
149
|
+
|
|
150
|
+
async def delete_thread(self, thread_id: str) -> Dict[str, Any]:
    """Delete all checkpoints for a thread.

    Sends a DELETE to ``{base_url}/checkpoints/{thread_id}``.

    Args:
        thread_id: Thread identifier

    Returns:
        The decoded JSON response (per the original docs, a dict with
        thread_id and deleted count).

    Raises:
        APIException: On an HTTP status >= 400 or an aiohttp client error.
    """
    from urllib.parse import quote

    # Percent-encode the id so a value containing "/", "?" or "#" cannot
    # point the DELETE at a different path.
    url = f"{self.base_url}/checkpoints/{quote(str(thread_id), safe='')}"

    session = await self._get_session()
    try:
        async with session.delete(url) as response:
            if response.status >= 400:
                error_text = await response.text()
                raise APIException(f"Failed to delete checkpoints: {response.status} - {error_text}")
            return await response.json()
    except aiohttp.ClientError as e:
        # Chain the transport error so its traceback is kept as the cause.
        raise APIException(f"Failed to delete checkpoints: {e}") from e
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""Checkpoint client module.
|
|
2
|
+
|
|
3
|
+
This module provides the low-level HTTP client for checkpoint operations.
|
|
4
|
+
For the LangGraph-compatible checkpointer, see march_agent.extensions.langgraph.
|
|
5
|
+
|
|
6
|
+
Usage:
|
|
7
|
+
from march_agent.extensions.langgraph import HTTPCheckpointSaver
|
|
8
|
+
|
|
9
|
+
app = MarchAgentApp(gateway_url="agent-gateway:8080", api_key="key")
|
|
10
|
+
checkpointer = HTTPCheckpointSaver(app=app)
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
# Re-export CheckpointClient for backwards compatibility
|
|
14
|
+
from .checkpoint_client import CheckpointClient
|
|
15
|
+
|
|
16
|
+
__all__ = ["CheckpointClient"]
|