basion-agent 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41):
  1. basion_agent/__init__.py +62 -0
  2. basion_agent/agent.py +360 -0
  3. basion_agent/agent_state_client.py +149 -0
  4. basion_agent/app.py +502 -0
  5. basion_agent/artifact.py +58 -0
  6. basion_agent/attachment_client.py +153 -0
  7. basion_agent/checkpoint_client.py +169 -0
  8. basion_agent/checkpointer.py +16 -0
  9. basion_agent/cli.py +139 -0
  10. basion_agent/conversation.py +103 -0
  11. basion_agent/conversation_client.py +86 -0
  12. basion_agent/conversation_message.py +48 -0
  13. basion_agent/exceptions.py +36 -0
  14. basion_agent/extensions/__init__.py +1 -0
  15. basion_agent/extensions/langgraph.py +526 -0
  16. basion_agent/extensions/pydantic_ai.py +180 -0
  17. basion_agent/gateway_client.py +531 -0
  18. basion_agent/gateway_pb2.py +73 -0
  19. basion_agent/gateway_pb2_grpc.py +101 -0
  20. basion_agent/heartbeat.py +84 -0
  21. basion_agent/loki_handler.py +355 -0
  22. basion_agent/memory.py +73 -0
  23. basion_agent/memory_client.py +155 -0
  24. basion_agent/message.py +333 -0
  25. basion_agent/py.typed +0 -0
  26. basion_agent/streamer.py +184 -0
  27. basion_agent/structural/__init__.py +6 -0
  28. basion_agent/structural/artifact.py +94 -0
  29. basion_agent/structural/base.py +71 -0
  30. basion_agent/structural/stepper.py +125 -0
  31. basion_agent/structural/surface.py +90 -0
  32. basion_agent/structural/text_block.py +96 -0
  33. basion_agent/tools/__init__.py +19 -0
  34. basion_agent/tools/container.py +46 -0
  35. basion_agent/tools/knowledge_graph.py +306 -0
  36. basion_agent-0.4.0.dist-info/METADATA +880 -0
  37. basion_agent-0.4.0.dist-info/RECORD +41 -0
  38. basion_agent-0.4.0.dist-info/WHEEL +5 -0
  39. basion_agent-0.4.0.dist-info/entry_points.txt +2 -0
  40. basion_agent-0.4.0.dist-info/licenses/LICENSE +21 -0
  41. basion_agent-0.4.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,48 @@
1
+ """Typed message from conversation-store."""
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Dict, Any, Optional, List
5
+
6
+
7
@dataclass
class ConversationMessage:
    """A message from conversation-store.

    Represents a stored message in a conversation with full metadata.

    ``id`` and ``conversation_id`` are always strings, and ``sequence_number``
    is always set. When building from an API payload via :meth:`from_dict`,
    missing *or null* values are replaced with ``""`` / ``0`` so the
    non-optional fields keep their declared types.
    """

    id: str
    conversation_id: str
    role: str  # "user", "assistant", "system"
    content: str
    sequence_number: int
    from_: Optional[str] = None  # Sender (agent name or "user")
    to_: Optional[str] = None  # Recipient
    metadata: Optional[Dict[str, Any]] = None
    schema: Optional[Dict[str, Any]] = None
    response_schema: Optional[Dict[str, Any]] = None
    created_at: Optional[str] = None
    updated_at: Optional[str] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ConversationMessage":
        """Create from a conversation-store API response.

        Args:
            data: One message object as returned by the API.

        Returns:
            A ``ConversationMessage``. ``id`` / ``conversation_id`` are
            converted with ``str()``. A missing or null value becomes ``""``
            rather than the string ``"None"``. A missing or null ``role`` /
            ``content`` becomes ``""``; otherwise the value is kept as-is.
            A missing or null ``sequence_number`` becomes ``0``. The optional
            fields are copied through unchanged.
        """

        def as_id(key: str) -> str:
            # str(None) would yield the misleading literal "None".
            value = data.get(key)
            return "" if value is None else str(value)

        def or_default(key: str, default: Any) -> Any:
            # dict.get(key, default) only covers *missing* keys; JSON nulls
            # arrive as None and must fall back to the default as well.
            value = data.get(key)
            return default if value is None else value

        return cls(
            id=as_id("id"),
            conversation_id=as_id("conversation_id"),
            role=or_default("role", ""),
            content=or_default("content", ""),
            sequence_number=or_default("sequence_number", 0),
            from_=data.get("from_"),
            to_=data.get("to_"),
            metadata=data.get("metadata"),
            schema=data.get("schema"),
            response_schema=data.get("response_schema"),
            created_at=data.get("created_at"),
            updated_at=data.get("updated_at"),
        )

    @classmethod
    def from_list(
        cls, data: Optional[List[Dict[str, Any]]]
    ) -> List["ConversationMessage"]:
        """Create a list of messages from a conversation-store API response.

        ``None`` (or an empty list) yields an empty list.
        """
        return [cls.from_dict(item) for item in data or []]
@@ -0,0 +1,36 @@
1
+ """Custom exceptions for Basion AI Agent framework."""
2
+
3
+
4
class BasionAgentError(Exception):
    """Root of the Basion Agent exception hierarchy.

    All custom exceptions defined in this module derive from this class,
    so ``except BasionAgentError`` catches any of them.
    """
7
+
8
+
9
class RegistrationError(BasionAgentError):
    """Signals that registering the agent did not succeed."""
12
+
13
+
14
class KafkaError(BasionAgentError):
    """Signals that a Kafka operation did not succeed."""
17
+
18
+
19
class HeartbeatError(BasionAgentError):
    """Signals that a heartbeat operation did not succeed."""
22
+
23
+
24
class MessageHandlerError(BasionAgentError):
    """Signals that running a message handler did not succeed."""
27
+
28
+
29
class ConfigurationError(BasionAgentError):
    """Signals that the supplied configuration is not valid."""
32
+
33
+
34
class APIException(BasionAgentError):
    """Signals that a request to the conversation-store API did not succeed."""
@@ -0,0 +1 @@
1
+ # Extensions package for basion_agent
@@ -0,0 +1,526 @@
1
+ """LangGraph extension for basion_agent.
2
+
3
+ This module provides LangGraph-compatible components that integrate with
4
+ the basion_agent framework.
5
+
6
+ Usage:
7
+ from basion_agent.extensions.langgraph import HTTPCheckpointSaver
8
+
9
+ app = BasionAgentApp(gateway_url="agent-gateway:8080", api_key="key")
10
+ checkpointer = HTTPCheckpointSaver(app=app)
11
+
12
+ graph = StateGraph(...)
13
+ compiled = graph.compile(checkpointer=checkpointer)
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import asyncio
19
+ import base64
20
+ import logging
21
+ from concurrent.futures import ThreadPoolExecutor
22
+ from datetime import datetime, timezone
23
+ from typing import (
24
+ TYPE_CHECKING,
25
+ Any,
26
+ AsyncIterator,
27
+ Dict,
28
+ Iterator,
29
+ Optional,
30
+ Sequence,
31
+ Set,
32
+ Tuple,
33
+ )
34
+
35
+ if TYPE_CHECKING:
36
+ from ..app import BasionAgentApp
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+ # Try to import LangGraph types, but make them optional
41
+ try:
42
+ from langgraph.checkpoint.base import (
43
+ BaseCheckpointSaver,
44
+ ChannelVersions,
45
+ Checkpoint,
46
+ CheckpointMetadata,
47
+ CheckpointTuple,
48
+ )
49
+ from langchain_core.runnables import RunnableConfig
50
+
51
+ LANGGRAPH_AVAILABLE = True
52
+ except ImportError:
53
+ LANGGRAPH_AVAILABLE = False
54
+ # Define stub types for when langgraph is not installed
55
+ BaseCheckpointSaver = object
56
+ RunnableConfig = Dict[str, Any]
57
+ Checkpoint = Dict[str, Any]
58
+ CheckpointMetadata = Dict[str, Any]
59
+ CheckpointTuple = Tuple[Any, ...]
60
+ ChannelVersions = Dict[str, Any]
61
+
62
+ from ..checkpoint_client import CheckpointClient
63
+ from ..exceptions import APIException
64
+
65
+
66
+ def _generate_checkpoint_id() -> str:
67
+ """Generate a unique checkpoint ID based on timestamp."""
68
+ return datetime.now(timezone.utc).isoformat()
69
+
70
+
71
class HTTPCheckpointSaver(BaseCheckpointSaver if LANGGRAPH_AVAILABLE else object):
    """HTTP-based checkpoint saver for LangGraph.

    This checkpointer stores graph state via HTTP calls to the conversation-store
    checkpoint API, enabling distributed checkpoint storage without direct
    database access.

    The async methods (``aget_tuple``, ``alist``, ``aput``, ...) are the primary
    implementation. The sync methods run them on a single private worker
    thread that reuses one event loop for all calls (see ``_run_coro_sync``).

    Limitations of this implementation:
        * ``aput_writes`` / ``put_writes`` do not persist pending writes.
        * Only ``source``, ``step``, ``writes`` and ``parents`` of the
          checkpoint metadata are sent to the API.
        * ``filter`` in ``alist`` / ``list`` is evaluated client-side against
          the stored metadata.

    Example:
        ```python
        from basion_agent import BasionAgentApp
        from basion_agent.extensions.langgraph import HTTPCheckpointSaver
        from langgraph.graph import StateGraph

        app = BasionAgentApp(gateway_url="agent-gateway:8080", api_key="key")
        checkpointer = HTTPCheckpointSaver(app=app)

        graph = StateGraph(MyState)
        # ... define graph ...
        compiled = graph.compile(checkpointer=checkpointer)

        config = {"configurable": {"thread_id": "my-thread"}}
        result = compiled.invoke({"messages": [...]}, config)
        ```
    """

    def __init__(
        self,
        app: "BasionAgentApp",
        *,
        serde: Optional[Any] = None,
    ):
        """Initialize HTTP checkpoint saver.

        Args:
            app: BasionAgentApp instance. Its
                ``gateway_client.conversation_store_url`` is used as the base
                URL of the checkpoint API.
            serde: Optional serializer/deserializer (for LangGraph
                compatibility). Only forwarded to ``BaseCheckpointSaver`` when
                LangGraph is installed.
        """
        if LANGGRAPH_AVAILABLE:
            super().__init__(serde=serde)

        base_url = app.gateway_client.conversation_store_url
        self.client = CheckpointClient(base_url)
        # Event loop used by the sync wrappers. It is created lazily on the
        # worker thread by _get_loop(), reused across calls and closed in
        # close().
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        # Exactly one worker, so every sync call runs on the same thread and
        # therefore on the same event loop.
        self._executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="checkpoint")

    def _get_loop(self) -> asyncio.AbstractEventLoop:
        """Return the running loop, or the saver's private loop (created on demand).

        On the worker thread no loop is running, so this returns
        ``self._loop``. It creates a new loop only if there is none yet or
        the previous one was closed.
        """
        try:
            return asyncio.get_running_loop()
        except RuntimeError:
            if self._loop is None or self._loop.is_closed():
                self._loop = asyncio.new_event_loop()
            return self._loop

    def _run_coro_sync(self, make_coro: Any) -> Any:
        """Run ``make_coro()`` to completion on the worker thread and return its result.

        Args:
            make_coro: Zero-argument callable returning a coroutine. The
                coroutine is created inside the worker thread, so no
                un-awaited coroutine is left behind if submission fails
                (e.g. after ``close()`` shut the executor down).

        Unlike ``asyncio.run``, this does not create and close a fresh event
        loop for every call. Any loop-bound resources used by the coroutines
        (presumably the HTTP session inside ``CheckpointClient``; not visible
        here, so confirm) stay attached to a live loop between calls.
        """

        def _worker() -> Any:
            loop = self._get_loop()
            return loop.run_until_complete(make_coro())

        return self._executor.submit(_worker).result()

    async def close(self):
        """Close the HTTP client session, the executor and the private event loop."""
        await self.client.close()
        self._executor.shutdown(wait=True)
        # The worker thread has exited, so its loop is idle and can be closed.
        loop = self._loop
        if loop is not None and not loop.is_closed() and not loop.is_running():
            loop.close()

    # ==================== Config Helpers ====================

    @staticmethod
    def _get_thread_id(config: RunnableConfig) -> str:
        """Extract thread_id from config.

        Raises:
            ValueError: If ``configurable.thread_id`` is missing or empty.
        """
        configurable = config.get("configurable", {})
        thread_id = configurable.get("thread_id")
        if not thread_id:
            raise ValueError("Config must contain configurable.thread_id")
        return thread_id

    @staticmethod
    def _get_checkpoint_ns(config: RunnableConfig) -> str:
        """Extract checkpoint_ns from config (defaults to empty string)."""
        return config.get("configurable", {}).get("checkpoint_ns", "")

    @staticmethod
    def _get_checkpoint_id(config: RunnableConfig) -> Optional[str]:
        """Extract checkpoint_id from config (``None`` when absent)."""
        return config.get("configurable", {}).get("checkpoint_id")

    @staticmethod
    def _metadata_matches(response: Dict[str, Any], criteria: Dict[str, Any]) -> bool:
        """Return True if every ``criteria`` item equals the response's stored metadata entry."""
        metadata = (response or {}).get("metadata") or {}
        return all(metadata.get(key) == value for key, value in criteria.items())

    # ==================== Async Methods (Primary Implementation) ====================

    async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
        """Fetch a checkpoint tuple asynchronously.

        Returns:
            The checkpoint tuple. Returns ``None`` if the API has none, or if
            the request fails (the error is logged, not raised).
        """
        thread_id = self._get_thread_id(config)
        checkpoint_ns = self._get_checkpoint_ns(config)
        checkpoint_id = self._get_checkpoint_id(config)

        try:
            result = await self.client.get_tuple(
                thread_id=thread_id,
                checkpoint_ns=checkpoint_ns,
                checkpoint_id=checkpoint_id,
            )
        except APIException as e:
            logger.error("Failed to get checkpoint: %s", e)
            return None

        if not result:
            return None

        return self._response_to_tuple(result)

    async def alist(
        self,
        config: Optional[RunnableConfig],
        *,
        filter: Optional[Dict[str, Any]] = None,
        before: Optional[RunnableConfig] = None,
        limit: Optional[int] = None,
    ) -> AsyncIterator[CheckpointTuple]:
        """List checkpoints asynchronously.

        Args:
            config: Optional config. Its ``configurable.thread_id`` /
                ``configurable.checkpoint_ns`` scope the listing.
            filter: Optional metadata filter. The checkpoint API has no filter
                parameter, so it is applied here: a checkpoint is yielded only
                if every key/value pair equals the same key in its stored
                metadata.
            before: Optional config. Its ``checkpoint_id`` is forwarded as the
                ``before`` cursor.
            limit: Maximum number of checkpoints to yield.

        API errors are logged and end the iteration without raising.
        """
        thread_id = None
        checkpoint_ns = None

        if config:
            configurable = config.get("configurable", {})
            thread_id = configurable.get("thread_id")
            checkpoint_ns = configurable.get("checkpoint_ns")

        before_id = None
        if before:
            before_id = before.get("configurable", {}).get("checkpoint_id")

        # A server-side limit is applied *before* our client-side filter and
        # could drop matching checkpoints. With a filter, limit locally.
        server_limit = None if filter else limit
        remaining = limit if filter else None

        try:
            results = await self.client.list(
                thread_id=thread_id,
                checkpoint_ns=checkpoint_ns,
                before=before_id,
                limit=server_limit,
            )
        except APIException as e:
            logger.error("Failed to list checkpoints: %s", e)
            return

        for result in results or []:
            if remaining is not None and remaining <= 0:
                return
            if filter and not self._metadata_matches(result, filter):
                continue
            tuple_result = self._response_to_tuple(result)
            if tuple_result:
                yield tuple_result
                if remaining is not None:
                    remaining -= 1

    async def aput(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> RunnableConfig:
        """Store a checkpoint asynchronously.

        Returns:
            The ``config`` returned by the API. Falls back to the config that
            was sent when the API returns none.

        Raises:
            ValueError: If ``config`` lacks ``configurable.thread_id``.
            APIException: If the API request fails (logged, then re-raised).
        """
        thread_id = self._get_thread_id(config)
        checkpoint_ns = self._get_checkpoint_ns(config)

        # Convert first so the id placed in the config and the id inside the
        # payload come from the same source, even when the checkpoint has no
        # "id" and one must be generated.
        checkpoint_data = self._checkpoint_to_api(checkpoint)
        metadata_data = self._metadata_to_api(metadata)

        # NOTE(review): LangGraph passes the *parent* checkpoint's id in
        # config.configurable.checkpoint_id. It is forwarded unchanged, and
        # the checkpoint's own id is used only when it is absent. Confirm
        # this matches how the conversation-store API interprets it.
        checkpoint_id = self._get_checkpoint_id(config) or checkpoint_data["id"]

        api_config = {
            "configurable": {
                "thread_id": thread_id,
                "checkpoint_ns": checkpoint_ns,
                "checkpoint_id": checkpoint_id,
            }
        }

        try:
            result = await self.client.put(
                config=api_config,
                checkpoint=checkpoint_data,
                metadata=metadata_data,
                new_versions=dict(new_versions) if new_versions else {},
            )
        except APIException as e:
            logger.error("Failed to store checkpoint: %s", e)
            raise

        return (result or {}).get("config") or api_config

    async def aput_writes(
        self,
        config: RunnableConfig,
        writes: Sequence[Tuple[str, Any]],
        task_id: str,
        task_path: str = "",
    ) -> None:
        """Store intermediate writes asynchronously (stub).

        Writes are NOT persisted; the call is only logged at debug level.
        """
        logger.debug(
            "aput_writes called (not persisted): task_id=%s, writes_count=%s",
            task_id,
            len(writes),
        )

    async def adelete_thread(self, thread_id: str) -> None:
        """Delete all checkpoints for a thread asynchronously.

        Raises:
            APIException: If the API request fails (logged, then re-raised).
        """
        try:
            await self.client.delete_thread(thread_id)
        except APIException as e:
            logger.error("Failed to delete thread checkpoints: %s", e)
            raise

    # ==================== Sync Methods (Wrappers) ====================

    def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
        """Fetch a checkpoint tuple synchronously (runs ``aget_tuple`` on the worker thread)."""
        return self._run_coro_sync(lambda: self.aget_tuple(config))

    def list(
        self,
        config: Optional[RunnableConfig],
        *,
        filter: Optional[Dict[str, Any]] = None,
        before: Optional[RunnableConfig] = None,
        limit: Optional[int] = None,
    ) -> Iterator[CheckpointTuple]:
        """List checkpoints synchronously (runs ``alist`` on the worker thread).

        All results are collected on first iteration, then yielded.
        """

        async def collect():
            return [
                item
                async for item in self.alist(
                    config, filter=filter, before=before, limit=limit
                )
            ]

        yield from self._run_coro_sync(collect)

    def put(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> RunnableConfig:
        """Store a checkpoint synchronously (runs ``aput`` on the worker thread)."""
        return self._run_coro_sync(
            lambda: self.aput(config, checkpoint, metadata, new_versions)
        )

    def put_writes(
        self,
        config: RunnableConfig,
        writes: Sequence[Tuple[str, Any]],
        task_id: str,
        task_path: str = "",
    ) -> None:
        """Store intermediate writes synchronously (stub; see ``aput_writes``)."""
        self._run_coro_sync(
            lambda: self.aput_writes(config, writes, task_id, task_path)
        )

    def delete_thread(self, thread_id: str) -> None:
        """Delete all checkpoints for a thread synchronously (runs ``adelete_thread``)."""
        self._run_coro_sync(lambda: self.adelete_thread(thread_id))

    # ==================== Data Conversion Helpers ====================

    def _serialize_value(
        self,
        value: Any,
        _visited: Optional[Set[int]] = None,
        _depth: int = 0
    ) -> Any:
        """Serialize a value for JSON transmission with cycle detection.

        Encodings (reversed by ``_deserialize_value``):
            * ``bytes`` -> ``{"__bytes__": <base64>}``
            * ``tuple`` -> ``{"__tuple__": [...]}``
            * values handled by ``self.serde`` ->
              ``{"__serde_type__": ..., "__serde_value__": <base64 or raw>}``

        Objects exposing ``model_dump`` / ``dict`` / ``to_dict`` are converted
        with that method when serde is unavailable or fails. Anything else is
        returned unchanged.

        Args:
            value: Value to serialize
            _visited: Set of visited object IDs (for cycle detection)
            _depth: Current recursion depth (for depth limit)

        Returns:
            Serialized value safe for JSON
        """
        # Initialize visited set on first call
        if _visited is None:
            _visited = set()

        # Depth protection (prevent stack overflow)
        MAX_DEPTH = 100
        if _depth > MAX_DEPTH:
            logger.warning(
                f"Serialization depth limit reached ({MAX_DEPTH}). "
                "Returning placeholder."
            )
            return {"__max_depth_exceeded__": True}

        # Handle bytes (before cycle check, as bytes are immutable)
        if isinstance(value, bytes):
            return {"__bytes__": base64.b64encode(value).decode("ascii")}

        # Cycle detection for container types
        if isinstance(value, (dict, list, tuple)):
            obj_id = id(value)
            if obj_id in _visited:
                logger.warning(
                    "Circular reference detected during serialization. "
                    "Returning placeholder."
                )
                return {"__circular_ref__": True}

            # Mark as visited
            _visited.add(obj_id)

            try:
                # Serialize based on type
                if isinstance(value, dict):
                    return {
                        k: self._serialize_value(v, _visited, _depth + 1)
                        for k, v in value.items()
                    }

                if isinstance(value, list):
                    return [
                        self._serialize_value(item, _visited, _depth + 1)
                        for item in value
                    ]

                if isinstance(value, tuple):
                    return {
                        "__tuple__": [
                            self._serialize_value(item, _visited, _depth + 1)
                            for item in value
                        ]
                    }
            finally:
                # Remove from visited after processing.
                # This allows the same object in different branches (DAG structure).
                _visited.discard(obj_id)

        # Handle custom serialization with serde
        if LANGGRAPH_AVAILABLE and hasattr(self, "serde") and self.serde is not None:
            try:
                type_str, serialized = self.serde.dumps_typed(value)
                if isinstance(serialized, bytes):
                    serialized = base64.b64encode(serialized).decode("ascii")
                return {"__serde_type__": type_str, "__serde_value__": serialized}
            except Exception as e:
                logger.warning("Failed to serialize value with serde: %s", e)

        # Handle objects with serialization methods
        if hasattr(value, "model_dump"):
            return self._serialize_value(value.model_dump(), _visited, _depth + 1)
        if hasattr(value, "dict"):
            return self._serialize_value(value.dict(), _visited, _depth + 1)
        if hasattr(value, "to_dict"):
            return self._serialize_value(value.to_dict(), _visited, _depth + 1)

        # Return primitives as-is
        return value

    def _serialize_channel_values(self, channel_values: Dict[str, Any]) -> Dict[str, Any]:
        """Serialize all channel values for API transmission."""
        return self._serialize_value(channel_values)

    def _deserialize_value(self, value: Any) -> Any:
        """Deserialize a value, decoding base64 bytes, tuples and serde payloads.

        Serde payloads that cannot be decoded (serde unavailable or failing)
        are returned as the raw encoded dict.
        """
        if isinstance(value, dict):
            if "__bytes__" in value:
                return base64.b64decode(value["__bytes__"])
            if "__tuple__" in value:
                return tuple(self._deserialize_value(item) for item in value["__tuple__"])
            if "__serde_type__" in value and "__serde_value__" in value:
                if LANGGRAPH_AVAILABLE and hasattr(self, "serde") and self.serde is not None:
                    try:
                        serialized = value["__serde_value__"]
                        if isinstance(serialized, str):
                            serialized = base64.b64decode(serialized)
                        return self.serde.loads_typed((value["__serde_type__"], serialized))
                    except Exception as e:
                        logger.warning("Failed to deserialize value with serde: %s", e)
                return value
            return {k: self._deserialize_value(v) for k, v in value.items()}

        if isinstance(value, list):
            return [self._deserialize_value(item) for item in value]

        return value

    def _deserialize_checkpoint(self, checkpoint_data: Dict[str, Any]) -> Dict[str, Any]:
        """Deserialize checkpoint data received from API (decodes ``channel_values``)."""
        if not checkpoint_data:
            return checkpoint_data

        result = dict(checkpoint_data)
        if "channel_values" in result:
            result["channel_values"] = self._deserialize_value(result["channel_values"])
        return result

    def _deserialize_metadata(self, metadata_data: Any) -> Any:
        """Decode the ``writes`` entry that ``_metadata_to_api`` encoded on the way out."""
        if not isinstance(metadata_data, dict) or metadata_data.get("writes") is None:
            return metadata_data

        result = dict(metadata_data)
        result["writes"] = self._deserialize_value(result["writes"])
        return result

    def _checkpoint_to_api(self, checkpoint: Checkpoint) -> Dict[str, Any]:
        """Convert a LangGraph Checkpoint (dict or attribute object) to API format.

        Missing ``id`` / ``ts`` are filled with the current UTC timestamp.
        """
        if isinstance(checkpoint, dict):
            channel_values = checkpoint.get("channel_values", {})
            return {
                "v": checkpoint.get("v", 1),
                "id": checkpoint.get("id", _generate_checkpoint_id()),
                "ts": checkpoint.get("ts", datetime.now(timezone.utc).isoformat()),
                "channel_values": self._serialize_channel_values(channel_values),
                "channel_versions": checkpoint.get("channel_versions", {}),
                "versions_seen": checkpoint.get("versions_seen", {}),
                "pending_sends": checkpoint.get("pending_sends", []),
            }
        channel_values = dict(getattr(checkpoint, "channel_values", {}))
        return {
            "v": getattr(checkpoint, "v", 1),
            "id": getattr(checkpoint, "id", _generate_checkpoint_id()),
            "ts": getattr(checkpoint, "ts", datetime.now(timezone.utc).isoformat()),
            "channel_values": self._serialize_channel_values(channel_values),
            "channel_versions": dict(getattr(checkpoint, "channel_versions", {})),
            "versions_seen": dict(getattr(checkpoint, "versions_seen", {})),
            "pending_sends": list(getattr(checkpoint, "pending_sends", [])),
        }

    def _serialize_writes(self, writes: Any) -> Any:
        """Serialize the metadata ``writes`` field, which may contain LangChain objects."""
        if writes is None:
            return None
        return self._serialize_value(writes)

    def _metadata_to_api(self, metadata: CheckpointMetadata) -> Dict[str, Any]:
        """Convert LangGraph CheckpointMetadata to API format.

        Only ``source``, ``step``, ``writes`` and ``parents`` are kept. Any
        other metadata keys are not sent.
        """
        if isinstance(metadata, dict):
            return {
                "source": metadata.get("source", "input"),
                "step": metadata.get("step", -1),
                "writes": self._serialize_writes(metadata.get("writes")),
                "parents": metadata.get("parents", {}),
            }
        return {
            "source": getattr(metadata, "source", "input"),
            "step": getattr(metadata, "step", -1),
            "writes": self._serialize_writes(getattr(metadata, "writes", None)),
            "parents": dict(getattr(metadata, "parents", {})),
        }

    def _response_to_tuple(self, response: Dict[str, Any]) -> Optional[CheckpointTuple]:
        """Convert an API response to a LangGraph CheckpointTuple.

        Reverses the encodings applied on write: the checkpoint's
        ``channel_values`` and the metadata's ``writes`` are decoded. Without
        LangGraph installed, a plain tuple
        ``(config, checkpoint, metadata, parent_config, pending_writes)``
        is returned.
        """
        if not response:
            return None

        config = response.get("config") or {}
        checkpoint_data = self._deserialize_checkpoint(response.get("checkpoint", {}))
        metadata_data = self._deserialize_metadata(response.get("metadata") or {})
        parent_config = response.get("parent_config")
        pending_writes = response.get("pending_writes")

        if not LANGGRAPH_AVAILABLE:
            return (config, checkpoint_data, metadata_data, parent_config, pending_writes)

        return CheckpointTuple(
            config=config,
            checkpoint=checkpoint_data,
            metadata=metadata_data,
            parent_config=parent_config,
            pending_writes=pending_writes or [],
        )