agent-lab-sdk 0.1.36__tar.gz → 0.1.38__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of agent-lab-sdk might be problematic. Click here for more details.
- {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/PKG-INFO +1 -1
- agent_lab_sdk-0.1.38/agent_lab_sdk/langgraph/checkpoint/agw_saver.py +706 -0
- {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk.egg-info/PKG-INFO +1 -1
- {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/pyproject.toml +1 -1
- agent_lab_sdk-0.1.36/agent_lab_sdk/langgraph/checkpoint/agw_saver.py +0 -444
- {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/LICENSE +0 -0
- {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/README.md +0 -0
- {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk/__init__.py +0 -0
- {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk/langgraph/checkpoint/__init__.py +0 -0
- {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk/langgraph/checkpoint/serde.py +0 -0
- {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk/llm/__init__.py +0 -0
- {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk/llm/agw_token_manager.py +0 -0
- {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk/llm/gigachat_token_manager.py +0 -0
- {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk/llm/llm.py +0 -0
- {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk/llm/throttled.py +0 -0
- {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk/metrics/__init__.py +0 -0
- {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk/metrics/metrics.py +0 -0
- {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk/schema/__init__.py +0 -0
- {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk/schema/input_types.py +0 -0
- {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk/schema/log_message.py +0 -0
- {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk/storage/__init__.py +0 -0
- {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk/storage/storage.py +0 -0
- {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk.egg-info/SOURCES.txt +0 -0
- {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk.egg-info/dependency_links.txt +0 -0
- {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk.egg-info/requires.txt +0 -0
- {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk.egg-info/top_level.txt +0 -0
- {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/setup.cfg +0 -0
|
@@ -0,0 +1,706 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import base64
|
|
5
|
+
import logging
|
|
6
|
+
import os
|
|
7
|
+
from contextlib import asynccontextmanager
|
|
8
|
+
from random import random
|
|
9
|
+
from typing import Any, AsyncIterator, Dict, Iterable, Iterator, Optional, Sequence, Tuple
|
|
10
|
+
|
|
11
|
+
import orjson
|
|
12
|
+
from langgraph.checkpoint.serde.types import ChannelProtocol
|
|
13
|
+
|
|
14
|
+
import httpx
|
|
15
|
+
from langchain_core.runnables import RunnableConfig
|
|
16
|
+
|
|
17
|
+
from langgraph.checkpoint.base import (
|
|
18
|
+
BaseCheckpointSaver,
|
|
19
|
+
ChannelVersions,
|
|
20
|
+
Checkpoint,
|
|
21
|
+
CheckpointMetadata,
|
|
22
|
+
CheckpointTuple,
|
|
23
|
+
get_checkpoint_id,
|
|
24
|
+
get_checkpoint_metadata,
|
|
25
|
+
)
|
|
26
|
+
from langgraph.checkpoint.serde.base import SerializerProtocol
|
|
27
|
+
from langgraph.checkpoint.serde.encrypted import EncryptedSerializer
|
|
28
|
+
|
|
29
|
+
from .serde import Serializer
|
|
30
|
+
from agent_lab_sdk.metrics import get_metric
|
|
31
|
+
|
|
32
|
+
__all__ = ["AsyncAGWCheckpointSaver"]
|
|
33
|
+
|
|
34
|
+
logger = logging.getLogger(__name__)
|
|
35
|
+
|
|
36
|
+
AGW_METRIC_LABELS = ["method", "endpoint"]
|
|
37
|
+
AGW_HTTP_SUCCESS = get_metric(
|
|
38
|
+
"counter",
|
|
39
|
+
"agw_http_success_total",
|
|
40
|
+
"Number of successful AGW HTTP requests",
|
|
41
|
+
labelnames=AGW_METRIC_LABELS,
|
|
42
|
+
)
|
|
43
|
+
AGW_HTTP_ERROR = get_metric(
|
|
44
|
+
"counter",
|
|
45
|
+
"agw_http_error_total",
|
|
46
|
+
"Number of failed AGW HTTP request attempts",
|
|
47
|
+
labelnames=AGW_METRIC_LABELS,
|
|
48
|
+
)
|
|
49
|
+
AGW_HTTP_FINAL_ERROR = get_metric(
|
|
50
|
+
"counter",
|
|
51
|
+
"agw_http_final_error_total",
|
|
52
|
+
"Number of AGW HTTP requests that failed after retries",
|
|
53
|
+
labelnames=AGW_METRIC_LABELS,
|
|
54
|
+
)
|
|
55
|
+
|
|
56
|
+
TYPED_KEYS = ("type", "blob")
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _b64decode_strict(value: str) -> bytes | None:
|
|
60
|
+
try:
|
|
61
|
+
return base64.b64decode(value, validate=True)
|
|
62
|
+
except Exception:
|
|
63
|
+
return None
|
|
64
|
+
|
|
65
|
+
# ------------------------------------------------------------------ #
|
|
66
|
+
# helpers for Py < 3.10
|
|
67
|
+
# ------------------------------------------------------------------ #
|
|
68
|
+
try:
|
|
69
|
+
anext # type: ignore[name-defined]
|
|
70
|
+
except NameError: # pragma: no cover
|
|
71
|
+
|
|
72
|
+
async def anext(it):
|
|
73
|
+
return await it.__anext__()
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
|
|
77
|
+
"""Persist checkpoints in Agent-Gateway с помощью `httpx` async client."""
|
|
78
|
+
|
|
79
|
+
# ---------------------------- init / ctx -------------------------
|
|
80
|
+
def __init__(
|
|
81
|
+
self,
|
|
82
|
+
base_url: str = "http://localhost",
|
|
83
|
+
*,
|
|
84
|
+
serde: SerializerProtocol | None = None,
|
|
85
|
+
timeout: int | float = 10,
|
|
86
|
+
api_key: str | None = None,
|
|
87
|
+
extra_headers: Dict[str, str] | None = None,
|
|
88
|
+
verify: bool = True,
|
|
89
|
+
):
|
|
90
|
+
if not serde:
|
|
91
|
+
base_serde: SerializerProtocol = Serializer()
|
|
92
|
+
aes_key = (
|
|
93
|
+
os.getenv("LANGGRAPH_AES_KEY")
|
|
94
|
+
or os.getenv("AGW_AES_KEY")
|
|
95
|
+
or os.getenv("AES_KEY")
|
|
96
|
+
)
|
|
97
|
+
if aes_key:
|
|
98
|
+
base_serde = EncryptedSerializer.from_pycryptodome_aes(
|
|
99
|
+
base_serde, key=aes_key
|
|
100
|
+
)
|
|
101
|
+
serde = base_serde
|
|
102
|
+
super().__init__(serde=serde)
|
|
103
|
+
self.base_url = base_url.rstrip("/")
|
|
104
|
+
self.timeout = timeout
|
|
105
|
+
self.loop = asyncio.get_running_loop()
|
|
106
|
+
|
|
107
|
+
raw_attempts = os.getenv("AGW_HTTP_MAX_RETRIES")
|
|
108
|
+
if raw_attempts is None:
|
|
109
|
+
self.retry_max_attempts = 3
|
|
110
|
+
else:
|
|
111
|
+
try:
|
|
112
|
+
self.retry_max_attempts = max(int(raw_attempts), 1)
|
|
113
|
+
except ValueError:
|
|
114
|
+
logger.warning(
|
|
115
|
+
"Env %s expected int, got %r; using default %s",
|
|
116
|
+
"AGW_HTTP_MAX_RETRIES",
|
|
117
|
+
raw_attempts,
|
|
118
|
+
3,
|
|
119
|
+
)
|
|
120
|
+
self.retry_max_attempts = 3
|
|
121
|
+
|
|
122
|
+
raw_backoff_base = os.getenv("AGW_HTTP_RETRY_BACKOFF_BASE")
|
|
123
|
+
if raw_backoff_base is None:
|
|
124
|
+
self.retry_backoff_base = 0.5
|
|
125
|
+
else:
|
|
126
|
+
try:
|
|
127
|
+
self.retry_backoff_base = max(float(raw_backoff_base), 0.0)
|
|
128
|
+
except ValueError:
|
|
129
|
+
logger.warning(
|
|
130
|
+
"Env %s expected float, got %r; using default %.3f",
|
|
131
|
+
"AGW_HTTP_RETRY_BACKOFF_BASE",
|
|
132
|
+
raw_backoff_base,
|
|
133
|
+
0.5,
|
|
134
|
+
)
|
|
135
|
+
self.retry_backoff_base = 0.5
|
|
136
|
+
|
|
137
|
+
raw_backoff_max = os.getenv("AGW_HTTP_RETRY_BACKOFF_MAX")
|
|
138
|
+
if raw_backoff_max is None:
|
|
139
|
+
self.retry_backoff_max = 5.0
|
|
140
|
+
else:
|
|
141
|
+
try:
|
|
142
|
+
self.retry_backoff_max = max(float(raw_backoff_max), 0.0)
|
|
143
|
+
except ValueError:
|
|
144
|
+
logger.warning(
|
|
145
|
+
"Env %s expected float, got %r; using default %.3f",
|
|
146
|
+
"AGW_HTTP_RETRY_BACKOFF_MAX",
|
|
147
|
+
raw_backoff_max,
|
|
148
|
+
5.0,
|
|
149
|
+
)
|
|
150
|
+
self.retry_backoff_max = 5.0
|
|
151
|
+
|
|
152
|
+
raw_jitter = os.getenv("AGW_HTTP_RETRY_JITTER")
|
|
153
|
+
if raw_jitter is None:
|
|
154
|
+
self.retry_jitter = 0.25
|
|
155
|
+
else:
|
|
156
|
+
try:
|
|
157
|
+
self.retry_jitter = max(float(raw_jitter), 0.0)
|
|
158
|
+
except ValueError:
|
|
159
|
+
logger.warning(
|
|
160
|
+
"Env %s expected float, got %r; using default %.3f",
|
|
161
|
+
"AGW_HTTP_RETRY_JITTER",
|
|
162
|
+
raw_jitter,
|
|
163
|
+
0.25,
|
|
164
|
+
)
|
|
165
|
+
self.retry_jitter = 0.25
|
|
166
|
+
|
|
167
|
+
self.headers: Dict[str, str] = {
|
|
168
|
+
"Accept": "application/json",
|
|
169
|
+
"Content-Type": "application/json",
|
|
170
|
+
}
|
|
171
|
+
if extra_headers:
|
|
172
|
+
self.headers.update(extra_headers)
|
|
173
|
+
if api_key:
|
|
174
|
+
self.headers["Authorization"] = f"Bearer {api_key}"
|
|
175
|
+
|
|
176
|
+
self._verify = verify
|
|
177
|
+
self._client: httpx.AsyncClient | None = None
|
|
178
|
+
|
|
179
|
+
def _create_client(self) -> httpx.AsyncClient:
|
|
180
|
+
return httpx.AsyncClient(
|
|
181
|
+
base_url=self.base_url,
|
|
182
|
+
headers=self.headers,
|
|
183
|
+
timeout=self.timeout,
|
|
184
|
+
verify=self._verify,
|
|
185
|
+
trust_env=True,
|
|
186
|
+
)
|
|
187
|
+
|
|
188
|
+
def _ensure_client(self) -> httpx.AsyncClient:
|
|
189
|
+
client = self._client
|
|
190
|
+
if client is None or client.is_closed:
|
|
191
|
+
if client is not None and client.is_closed:
|
|
192
|
+
logger.debug("Recreating closed httpx.AsyncClient for AGW")
|
|
193
|
+
client = self._create_client()
|
|
194
|
+
self._client = client
|
|
195
|
+
return client
|
|
196
|
+
|
|
197
|
+
def _compute_retry_delay(self, attempt: int) -> float:
|
|
198
|
+
if attempt <= 0:
|
|
199
|
+
attempt = 1
|
|
200
|
+
if self.retry_backoff_base <= 0:
|
|
201
|
+
delay = 0.0
|
|
202
|
+
else:
|
|
203
|
+
delay = self.retry_backoff_base * (2 ** (attempt - 1))
|
|
204
|
+
if self.retry_backoff_max > 0:
|
|
205
|
+
delay = min(delay, self.retry_backoff_max)
|
|
206
|
+
if self.retry_jitter > 0:
|
|
207
|
+
delay += self.retry_jitter * random()
|
|
208
|
+
return delay
|
|
209
|
+
|
|
210
|
+
async def __aenter__(self): # noqa: D401
|
|
211
|
+
return self
|
|
212
|
+
|
|
213
|
+
async def __aexit__(self, exc_type, exc, tb): # noqa: D401
|
|
214
|
+
if self._client is not None:
|
|
215
|
+
try:
|
|
216
|
+
await self._client.aclose()
|
|
217
|
+
except Exception as close_exc: # pragma: no cover - best effort
|
|
218
|
+
logger.debug("Failed to close AGW httpx.AsyncClient: %s", close_exc)
|
|
219
|
+
finally:
|
|
220
|
+
self._client = None
|
|
221
|
+
|
|
222
|
+
# ----------------------- universal dump/load ---------------------
|
|
223
|
+
# def _safe_dump(self, obj: Any) -> Any:
|
|
224
|
+
# """self.serde.dump → гарантированная JSON-строка."""
|
|
225
|
+
# dumped = self.serde.dumps(obj)
|
|
226
|
+
# if isinstance(dumped, (bytes, bytearray)):
|
|
227
|
+
# return base64.b64encode(dumped).decode() # str
|
|
228
|
+
# return dumped # уже json-совместимо
|
|
229
|
+
|
|
230
|
+
def _safe_dump(self, obj: Any) -> Any:
|
|
231
|
+
"""bytes → python-object; fallback base64 для реально бинарных данных."""
|
|
232
|
+
dumped = self.serde.dumps(obj)
|
|
233
|
+
if isinstance(dumped, (bytes, bytearray)):
|
|
234
|
+
try:
|
|
235
|
+
# 1) bytes → str
|
|
236
|
+
s = dumped.decode()
|
|
237
|
+
# 2) str JSON → python (list/dict/scalar)
|
|
238
|
+
return orjson.loads(s)
|
|
239
|
+
except (UnicodeDecodeError, orjson.JSONDecodeError):
|
|
240
|
+
# не UTF-8 или не JSON → base64
|
|
241
|
+
return base64.b64encode(dumped).decode()
|
|
242
|
+
return dumped
|
|
243
|
+
|
|
244
|
+
def _safe_load(self, obj: Any) -> Any:
|
|
245
|
+
if obj is None:
|
|
246
|
+
return None
|
|
247
|
+
|
|
248
|
+
if isinstance(obj, dict):
|
|
249
|
+
if all(k in obj for k in TYPED_KEYS):
|
|
250
|
+
t = obj.get("type")
|
|
251
|
+
blob = obj.get("blob")
|
|
252
|
+
if blob is None:
|
|
253
|
+
try:
|
|
254
|
+
return self.serde.loads_typed((t, None))
|
|
255
|
+
except Exception:
|
|
256
|
+
return obj
|
|
257
|
+
if isinstance(blob, str):
|
|
258
|
+
payload = _b64decode_strict(blob)
|
|
259
|
+
if payload is not None:
|
|
260
|
+
try:
|
|
261
|
+
return self.serde.loads_typed((t, payload))
|
|
262
|
+
except Exception:
|
|
263
|
+
# fall back to generic handling below
|
|
264
|
+
pass
|
|
265
|
+
try:
|
|
266
|
+
return self.serde.loads(orjson.dumps(obj))
|
|
267
|
+
except Exception:
|
|
268
|
+
return obj
|
|
269
|
+
|
|
270
|
+
if isinstance(obj, (list, tuple)):
|
|
271
|
+
if (
|
|
272
|
+
len(obj) == 2
|
|
273
|
+
and isinstance(obj[0], str)
|
|
274
|
+
and (obj[1] is None or isinstance(obj[1], str))
|
|
275
|
+
):
|
|
276
|
+
blob = obj[1]
|
|
277
|
+
if blob is None:
|
|
278
|
+
try:
|
|
279
|
+
return self.serde.loads_typed((obj[0], None))
|
|
280
|
+
except Exception:
|
|
281
|
+
pass
|
|
282
|
+
elif isinstance(blob, str):
|
|
283
|
+
payload = _b64decode_strict(blob)
|
|
284
|
+
if payload is not None:
|
|
285
|
+
try:
|
|
286
|
+
return self.serde.loads_typed((obj[0], payload))
|
|
287
|
+
except Exception:
|
|
288
|
+
pass
|
|
289
|
+
try:
|
|
290
|
+
return self.serde.loads(orjson.dumps(list(obj)))
|
|
291
|
+
except Exception:
|
|
292
|
+
return obj
|
|
293
|
+
|
|
294
|
+
if isinstance(obj, str):
|
|
295
|
+
try:
|
|
296
|
+
return self.serde.loads(obj.encode())
|
|
297
|
+
except Exception:
|
|
298
|
+
payload = _b64decode_strict(obj)
|
|
299
|
+
if payload is not None:
|
|
300
|
+
try:
|
|
301
|
+
return self.serde.loads(payload)
|
|
302
|
+
except Exception:
|
|
303
|
+
pass
|
|
304
|
+
return obj
|
|
305
|
+
|
|
306
|
+
try:
|
|
307
|
+
return self.serde.loads(obj)
|
|
308
|
+
except Exception:
|
|
309
|
+
return obj
|
|
310
|
+
|
|
311
|
+
# def _safe_load(self, obj: Any) -> Any:
|
|
312
|
+
# """Обратная операция к _safe_dump."""
|
|
313
|
+
# if isinstance(obj, str):
|
|
314
|
+
# try:
|
|
315
|
+
# return self.serde.load(base64.b64decode(obj))
|
|
316
|
+
# except Exception:
|
|
317
|
+
# # не base64 — обычная строка
|
|
318
|
+
# return self.serde.load(obj)
|
|
319
|
+
# return self.serde.load(obj)
|
|
320
|
+
|
|
321
|
+
# ----------------------- config <-> api --------------------------
|
|
322
|
+
def _to_api_config(self, cfg: RunnableConfig | None) -> Dict[str, Any]:
|
|
323
|
+
if not cfg:
|
|
324
|
+
return {}
|
|
325
|
+
c = cfg.get("configurable", {})
|
|
326
|
+
res: Dict[str, Any] = {
|
|
327
|
+
"threadId": c.get("thread_id", ""),
|
|
328
|
+
"checkpointNs": c.get("checkpoint_ns", ""),
|
|
329
|
+
}
|
|
330
|
+
if cid := c.get("checkpoint_id"):
|
|
331
|
+
res["checkpointId"] = cid
|
|
332
|
+
if ts := c.get("thread_ts"):
|
|
333
|
+
res["threadTs"] = ts
|
|
334
|
+
return res
|
|
335
|
+
|
|
336
|
+
# --------------------- checkpoint (de)ser ------------------------
|
|
337
|
+
def _encode_cp(self, cp: Checkpoint) -> Dict[str, Any]:
|
|
338
|
+
pending: list[Any] = []
|
|
339
|
+
for item in cp.get("pending_sends", []) or []:
|
|
340
|
+
try:
|
|
341
|
+
channel, value = item
|
|
342
|
+
except Exception:
|
|
343
|
+
pending.append(item)
|
|
344
|
+
continue
|
|
345
|
+
pending.append([channel, self._safe_dump(value)])
|
|
346
|
+
return {
|
|
347
|
+
"v": cp["v"],
|
|
348
|
+
"id": cp["id"],
|
|
349
|
+
"ts": cp["ts"],
|
|
350
|
+
"channelValues": {k: self._safe_dump(v) for k, v in cp["channel_values"].items()},
|
|
351
|
+
"channelVersions": cp["channel_versions"],
|
|
352
|
+
"versionsSeen": cp["versions_seen"],
|
|
353
|
+
"pendingSends": pending,
|
|
354
|
+
}
|
|
355
|
+
|
|
356
|
+
def _decode_cp(self, raw: Dict[str, Any]) -> Checkpoint:
|
|
357
|
+
pending_sends: list[Tuple[str, Any]] = []
|
|
358
|
+
for obj in raw.get("pendingSends", []) or []:
|
|
359
|
+
if isinstance(obj, dict) and "channel" in obj:
|
|
360
|
+
channel = obj["channel"]
|
|
361
|
+
value_payload: Any = obj.get("value")
|
|
362
|
+
if value_payload is None and all(k in obj for k in TYPED_KEYS):
|
|
363
|
+
value_payload = {k: obj[k] for k in TYPED_KEYS}
|
|
364
|
+
pending_sends.append((channel, self._safe_load(value_payload)))
|
|
365
|
+
elif isinstance(obj, (list, tuple)) and len(obj) >= 2:
|
|
366
|
+
channel = obj[0]
|
|
367
|
+
value_payload = obj[1]
|
|
368
|
+
pending_sends.append((channel, self._safe_load(value_payload)))
|
|
369
|
+
else:
|
|
370
|
+
pending_sends.append(obj) # сохраняем как есть, если формат неизвестен
|
|
371
|
+
return Checkpoint(
|
|
372
|
+
v=raw["v"],
|
|
373
|
+
id=raw["id"],
|
|
374
|
+
ts=raw["ts"],
|
|
375
|
+
channel_values={k: self._safe_load(v) for k, v in raw["channelValues"].items()},
|
|
376
|
+
channel_versions=raw["channelVersions"],
|
|
377
|
+
versions_seen=raw["versionsSeen"],
|
|
378
|
+
pending_sends=pending_sends,
|
|
379
|
+
)
|
|
380
|
+
|
|
381
|
+
def _decode_config(self, raw: Dict[str, Any] | None) -> Optional[RunnableConfig]:
|
|
382
|
+
if not raw:
|
|
383
|
+
return None
|
|
384
|
+
return RunnableConfig(
|
|
385
|
+
tags=raw.get("tags"),
|
|
386
|
+
metadata=raw.get("metadata"),
|
|
387
|
+
callbacks=raw.get("callbacks"),
|
|
388
|
+
run_name=raw.get("run_name"),
|
|
389
|
+
max_concurrency=raw.get("max_concurrency"),
|
|
390
|
+
recursion_limit=raw.get("recursion_limit"),
|
|
391
|
+
configurable=self._decode_configurable(raw.get("configurable") or {}),
|
|
392
|
+
)
|
|
393
|
+
|
|
394
|
+
def _decode_configurable(self, raw: Dict[str, Any]) -> dict[str, Any]:
|
|
395
|
+
return {
|
|
396
|
+
"thread_id": raw.get("threadId"),
|
|
397
|
+
"thread_ts": raw.get("threadTs"),
|
|
398
|
+
"checkpoint_ns": raw.get("checkpointNs"),
|
|
399
|
+
"checkpoint_id": raw.get("checkpointId")
|
|
400
|
+
}
|
|
401
|
+
|
|
402
|
+
# metadata (de)ser
|
|
403
|
+
def _enc_meta(self, md: CheckpointMetadata) -> CheckpointMetadata:
|
|
404
|
+
if not md:
|
|
405
|
+
return {}
|
|
406
|
+
out: CheckpointMetadata = {}
|
|
407
|
+
for k, v in md.items():
|
|
408
|
+
out[k] = self._enc_meta(v) if isinstance(v, dict) else self._safe_dump(v) # type: ignore[assignment]
|
|
409
|
+
return out
|
|
410
|
+
|
|
411
|
+
def _dec_meta(self, md: Any) -> Any:
|
|
412
|
+
if isinstance(md, dict):
|
|
413
|
+
return {k: self._dec_meta(v) for k, v in md.items()}
|
|
414
|
+
return self._safe_load(md)
|
|
415
|
+
|
|
416
|
+
# ------------------------ HTTP wrapper ---------------------------
|
|
417
|
+
async def _http(
|
|
418
|
+
self,
|
|
419
|
+
method: str,
|
|
420
|
+
path: str,
|
|
421
|
+
*,
|
|
422
|
+
ok_statuses: Iterable[int] | None = None,
|
|
423
|
+
**kw,
|
|
424
|
+
) -> httpx.Response:
|
|
425
|
+
if "json" in kw:
|
|
426
|
+
payload = kw.pop("json")
|
|
427
|
+
kw["data"] = orjson.dumps(payload)
|
|
428
|
+
logger.debug("AGW HTTP payload: %s", kw["data"].decode())
|
|
429
|
+
|
|
430
|
+
ok_set = set(ok_statuses) if ok_statuses is not None else set()
|
|
431
|
+
|
|
432
|
+
attempt = 1
|
|
433
|
+
while True:
|
|
434
|
+
client = self._ensure_client()
|
|
435
|
+
try:
|
|
436
|
+
resp = await client.request(method, path, **kw)
|
|
437
|
+
except httpx.RequestError as exc:
|
|
438
|
+
AGW_HTTP_ERROR.labels(method, path).inc()
|
|
439
|
+
logger.warning(
|
|
440
|
+
"AGW request %s %s failed on attempt %d/%d: %s",
|
|
441
|
+
method,
|
|
442
|
+
path,
|
|
443
|
+
attempt,
|
|
444
|
+
self.retry_max_attempts,
|
|
445
|
+
exc,
|
|
446
|
+
)
|
|
447
|
+
if attempt >= self.retry_max_attempts:
|
|
448
|
+
AGW_HTTP_FINAL_ERROR.labels(method, path).inc()
|
|
449
|
+
if self._client is not None:
|
|
450
|
+
try:
|
|
451
|
+
await self._client.aclose()
|
|
452
|
+
except Exception as close_exc: # pragma: no cover
|
|
453
|
+
logger.debug(
|
|
454
|
+
"Failed to close AGW httpx.AsyncClient: %s",
|
|
455
|
+
close_exc,
|
|
456
|
+
)
|
|
457
|
+
finally:
|
|
458
|
+
self._client = None
|
|
459
|
+
raise
|
|
460
|
+
|
|
461
|
+
if self._client is not None:
|
|
462
|
+
try:
|
|
463
|
+
await self._client.aclose()
|
|
464
|
+
except Exception as close_exc: # pragma: no cover
|
|
465
|
+
logger.debug(
|
|
466
|
+
"Failed to close AGW httpx.AsyncClient: %s",
|
|
467
|
+
close_exc,
|
|
468
|
+
)
|
|
469
|
+
finally:
|
|
470
|
+
self._client = None
|
|
471
|
+
delay = self._compute_retry_delay(attempt)
|
|
472
|
+
if delay > 0:
|
|
473
|
+
await asyncio.sleep(delay)
|
|
474
|
+
attempt += 1
|
|
475
|
+
continue
|
|
476
|
+
|
|
477
|
+
status = resp.status_code
|
|
478
|
+
if status < 400 or status in ok_set:
|
|
479
|
+
AGW_HTTP_SUCCESS.labels(method, path).inc()
|
|
480
|
+
return resp
|
|
481
|
+
|
|
482
|
+
AGW_HTTP_ERROR.labels(method, path).inc()
|
|
483
|
+
if status in (404, 406):
|
|
484
|
+
AGW_HTTP_FINAL_ERROR.labels(method, path).inc()
|
|
485
|
+
return resp
|
|
486
|
+
|
|
487
|
+
if attempt >= self.retry_max_attempts:
|
|
488
|
+
AGW_HTTP_FINAL_ERROR.labels(method, path).inc()
|
|
489
|
+
return resp
|
|
490
|
+
|
|
491
|
+
try:
|
|
492
|
+
await resp.aclose()
|
|
493
|
+
except Exception as exc: # pragma: no cover - best effort
|
|
494
|
+
logger.debug("Failed to close AGW httpx.Response before retry: %s", exc)
|
|
495
|
+
|
|
496
|
+
if self._client is not None:
|
|
497
|
+
try:
|
|
498
|
+
await self._client.aclose()
|
|
499
|
+
except Exception as close_exc: # pragma: no cover
|
|
500
|
+
logger.debug(
|
|
501
|
+
"Failed to close AGW httpx.AsyncClient: %s",
|
|
502
|
+
close_exc,
|
|
503
|
+
)
|
|
504
|
+
finally:
|
|
505
|
+
self._client = None
|
|
506
|
+
delay = self._compute_retry_delay(attempt)
|
|
507
|
+
if delay > 0:
|
|
508
|
+
await asyncio.sleep(delay)
|
|
509
|
+
attempt += 1
|
|
510
|
+
|
|
511
|
+
# -------------------- api -> CheckpointTuple ----------------------
|
|
512
|
+
def _to_tuple(self, node: Dict[str, Any]) -> CheckpointTuple:
|
|
513
|
+
pending = None
|
|
514
|
+
if node.get("pendingWrites"):
|
|
515
|
+
pending = []
|
|
516
|
+
for w in node["pendingWrites"]:
|
|
517
|
+
if isinstance(w, dict):
|
|
518
|
+
first = w.get("first")
|
|
519
|
+
second = w.get("second")
|
|
520
|
+
third = w.get("third")
|
|
521
|
+
if third is None and isinstance(second, dict) and all(
|
|
522
|
+
k in second for k in TYPED_KEYS
|
|
523
|
+
):
|
|
524
|
+
third = second
|
|
525
|
+
pending.append((first, second, self._safe_load(third)))
|
|
526
|
+
elif isinstance(w, (list, tuple)):
|
|
527
|
+
if len(w) == 3:
|
|
528
|
+
first, second, third = w
|
|
529
|
+
elif len(w) == 2:
|
|
530
|
+
first, second = w
|
|
531
|
+
third = None
|
|
532
|
+
else:
|
|
533
|
+
continue
|
|
534
|
+
pending.append((first, second, self._safe_load(third)))
|
|
535
|
+
return CheckpointTuple(
|
|
536
|
+
config=self._decode_config(node.get("config")),
|
|
537
|
+
checkpoint=self._decode_cp(node["checkpoint"]),
|
|
538
|
+
metadata=self._dec_meta(node.get("metadata")),
|
|
539
|
+
parent_config=self._decode_config(node.get("parentConfig")),
|
|
540
|
+
pending_writes=pending,
|
|
541
|
+
)
|
|
542
|
+
|
|
543
|
+
# =================================================================
|
|
544
|
+
# async-методы BaseCheckpointSaver
|
|
545
|
+
# =================================================================
|
|
546
|
+
async def aget_tuple(self, cfg: RunnableConfig) -> CheckpointTuple | None:
|
|
547
|
+
cid = get_checkpoint_id(cfg)
|
|
548
|
+
api_cfg = self._to_api_config(cfg)
|
|
549
|
+
tid = api_cfg["threadId"]
|
|
550
|
+
|
|
551
|
+
if cid:
|
|
552
|
+
path = f"/checkpoint/{tid}/{cid}"
|
|
553
|
+
params = {"checkpointNs": api_cfg.get("checkpointNs", "")}
|
|
554
|
+
else:
|
|
555
|
+
path = f"/checkpoint/{tid}"
|
|
556
|
+
params = None
|
|
557
|
+
|
|
558
|
+
resp = await self._http("GET", path, params=params)
|
|
559
|
+
logger.debug("AGW aget_tuple response: %s", resp.text)
|
|
560
|
+
|
|
561
|
+
if not resp.text:
|
|
562
|
+
return None
|
|
563
|
+
if resp.status_code in (404, 406):
|
|
564
|
+
return None
|
|
565
|
+
resp.raise_for_status()
|
|
566
|
+
return self._to_tuple(resp.json())
|
|
567
|
+
|
|
568
|
+
async def alist(
|
|
569
|
+
self,
|
|
570
|
+
cfg: RunnableConfig | None,
|
|
571
|
+
*,
|
|
572
|
+
filter: Dict[str, Any] | None = None,
|
|
573
|
+
before: RunnableConfig | None = None,
|
|
574
|
+
limit: int | None = None,
|
|
575
|
+
) -> AsyncIterator[CheckpointTuple]:
|
|
576
|
+
payload = {
|
|
577
|
+
"config": self._to_api_config(cfg) if cfg else None,
|
|
578
|
+
"filter": filter,
|
|
579
|
+
"before": self._to_api_config(before) if before else None,
|
|
580
|
+
"limit": limit,
|
|
581
|
+
}
|
|
582
|
+
resp = await self._http("POST", "/checkpoint/list", json=payload)
|
|
583
|
+
logger.debug("AGW alist response: %s", resp.text)
|
|
584
|
+
resp.raise_for_status()
|
|
585
|
+
for item in resp.json():
|
|
586
|
+
yield self._to_tuple(item)
|
|
587
|
+
|
|
588
|
+
async def aput(
|
|
589
|
+
self,
|
|
590
|
+
cfg: RunnableConfig,
|
|
591
|
+
cp: Checkpoint,
|
|
592
|
+
metadata: CheckpointMetadata,
|
|
593
|
+
new_versions: ChannelVersions,
|
|
594
|
+
) -> RunnableConfig:
|
|
595
|
+
payload = {
|
|
596
|
+
"config": self._to_api_config(cfg),
|
|
597
|
+
"checkpoint": self._encode_cp(cp),
|
|
598
|
+
"metadata": self._enc_meta(get_checkpoint_metadata(cfg, metadata)),
|
|
599
|
+
"newVersions": new_versions,
|
|
600
|
+
}
|
|
601
|
+
resp = await self._http("POST", "/checkpoint", json=payload)
|
|
602
|
+
logger.debug("AGW aput response: %s", resp.text)
|
|
603
|
+
resp.raise_for_status()
|
|
604
|
+
return resp.json()["config"]
|
|
605
|
+
|
|
606
|
+
async def aput_writes(
|
|
607
|
+
self,
|
|
608
|
+
cfg: RunnableConfig,
|
|
609
|
+
writes: Sequence[Tuple[str, Any]],
|
|
610
|
+
task_id: str,
|
|
611
|
+
task_path: str = "",
|
|
612
|
+
) -> None:
|
|
613
|
+
enc = [{"first": ch, "second": self._safe_dump(v)} for ch, v in writes]
|
|
614
|
+
payload = {
|
|
615
|
+
"config": self._to_api_config(cfg),
|
|
616
|
+
"writes": enc,
|
|
617
|
+
"taskId": task_id,
|
|
618
|
+
"taskPath": task_path,
|
|
619
|
+
}
|
|
620
|
+
resp = await self._http("POST", "/checkpoint/writes", json=payload)
|
|
621
|
+
logger.debug("AGW aput_writes response: %s", resp.text)
|
|
622
|
+
resp.raise_for_status()
|
|
623
|
+
|
|
624
|
+
async def adelete_thread(self, thread_id: str) -> None:
|
|
625
|
+
resp = await self._http("DELETE", f"/checkpoint/{thread_id}")
|
|
626
|
+
resp.raise_for_status()
|
|
627
|
+
|
|
628
|
+
# =================================================================
|
|
629
|
+
# sync-обёртки
|
|
630
|
+
# =================================================================
|
|
631
|
+
def _run(self, coro):
|
|
632
|
+
return asyncio.run_coroutine_threadsafe(coro, self.loop).result()
|
|
633
|
+
|
|
634
|
+
def list(
|
|
635
|
+
self,
|
|
636
|
+
cfg: RunnableConfig | None,
|
|
637
|
+
*,
|
|
638
|
+
filter: Dict[str, Any] | None = None,
|
|
639
|
+
before: RunnableConfig | None = None,
|
|
640
|
+
limit: int | None = None,
|
|
641
|
+
) -> Iterator[CheckpointTuple]:
|
|
642
|
+
aiter_ = self.alist(cfg, filter=filter, before=before, limit=limit)
|
|
643
|
+
while True:
|
|
644
|
+
try:
|
|
645
|
+
yield self._run(anext(aiter_))
|
|
646
|
+
except StopAsyncIteration:
|
|
647
|
+
break
|
|
648
|
+
|
|
649
|
+
def get_tuple(self, cfg: RunnableConfig) -> CheckpointTuple | None:
|
|
650
|
+
return self._run(self.aget_tuple(cfg))
|
|
651
|
+
|
|
652
|
+
def put(
|
|
653
|
+
self,
|
|
654
|
+
cfg: RunnableConfig,
|
|
655
|
+
cp: Checkpoint,
|
|
656
|
+
metadata: CheckpointMetadata,
|
|
657
|
+
new_versions: ChannelVersions,
|
|
658
|
+
) -> RunnableConfig:
|
|
659
|
+
return self._run(self.aput(cfg, cp, metadata, new_versions))
|
|
660
|
+
|
|
661
|
+
def put_writes(
|
|
662
|
+
self,
|
|
663
|
+
cfg: RunnableConfig,
|
|
664
|
+
writes: Sequence[Tuple[str, Any]],
|
|
665
|
+
task_id: str,
|
|
666
|
+
task_path: str = "",
|
|
667
|
+
) -> None:
|
|
668
|
+
self._run(self.aput_writes(cfg, writes, task_id, task_path))
|
|
669
|
+
|
|
670
|
+
def delete_thread(self, thread_id: str) -> None:
|
|
671
|
+
self._run(self.adelete_thread(thread_id))
|
|
672
|
+
|
|
673
|
+
def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> str:
|
|
674
|
+
if current is None:
|
|
675
|
+
current_v = 0
|
|
676
|
+
elif isinstance(current, int):
|
|
677
|
+
current_v = current
|
|
678
|
+
else:
|
|
679
|
+
current_v = int(current.split(".")[0])
|
|
680
|
+
next_v = current_v + 1
|
|
681
|
+
next_h = random()
|
|
682
|
+
return f"{next_v:032}.{next_h:016}"
|
|
683
|
+
|
|
684
|
+
# ------------------------------------------------------------------ #
|
|
685
|
+
# Convenience factory #
|
|
686
|
+
# ------------------------------------------------------------------ #
|
|
687
|
+
@classmethod
|
|
688
|
+
@asynccontextmanager
|
|
689
|
+
async def from_base_url(
|
|
690
|
+
cls,
|
|
691
|
+
base_url: str,
|
|
692
|
+
*,
|
|
693
|
+
api_key: str | None = None,
|
|
694
|
+
**kwargs: Any,
|
|
695
|
+
) -> AsyncIterator["AsyncAGWCheckpointSaver"]:
|
|
696
|
+
saver = cls(base_url, api_key=api_key, **kwargs)
|
|
697
|
+
try:
|
|
698
|
+
yield saver
|
|
699
|
+
finally:
|
|
700
|
+
if saver._client is not None:
|
|
701
|
+
try:
|
|
702
|
+
await saver._client.aclose()
|
|
703
|
+
except Exception as close_exc: # pragma: no cover - best effort
|
|
704
|
+
logger.debug("Failed to close AGW httpx.AsyncClient: %s", close_exc)
|
|
705
|
+
finally:
|
|
706
|
+
saver._client = None
|
|
@@ -1,444 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
import asyncio
|
|
4
|
-
import base64
|
|
5
|
-
import logging
|
|
6
|
-
import os
|
|
7
|
-
from contextlib import asynccontextmanager
|
|
8
|
-
from random import random
|
|
9
|
-
from typing import Any, AsyncIterator, Dict, Iterator, Optional, Sequence, Tuple
|
|
10
|
-
|
|
11
|
-
import httpx
|
|
12
|
-
import orjson
|
|
13
|
-
from langchain_core.runnables import RunnableConfig
|
|
14
|
-
|
|
15
|
-
from langgraph.checkpoint.base import (
|
|
16
|
-
BaseCheckpointSaver,
|
|
17
|
-
ChannelVersions,
|
|
18
|
-
Checkpoint,
|
|
19
|
-
CheckpointMetadata,
|
|
20
|
-
CheckpointTuple,
|
|
21
|
-
get_checkpoint_id,
|
|
22
|
-
get_checkpoint_metadata,
|
|
23
|
-
)
|
|
24
|
-
from langgraph.checkpoint.serde.base import SerializerProtocol
|
|
25
|
-
from langgraph.checkpoint.serde.encrypted import EncryptedSerializer
|
|
26
|
-
from langgraph.checkpoint.serde.types import ChannelProtocol
|
|
27
|
-
from .serde import Serializer
|
|
28
|
-
|
|
29
|
-
__all__ = ["AsyncAGWCheckpointSaver"]
|
|
30
|
-
|
|
31
|
-
logger = logging.getLogger(__name__)
|
|
32
|
-
|
|
33
|
-
TYPED_KEYS = ("type", "blob")
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
def _to_b64(b: bytes | None) -> str | None:
|
|
37
|
-
return base64.b64encode(b).decode() if b is not None else None
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
def _b64decode_strict(s: str) -> bytes | None:
|
|
41
|
-
"""Возвращает bytes только если строка действительно корректная base64."""
|
|
42
|
-
try:
|
|
43
|
-
return base64.b64decode(s, validate=True)
|
|
44
|
-
except Exception:
|
|
45
|
-
return None
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
    """Persist LangGraph checkpoints in Agent-Gateway via an `httpx` async client.

    Values are serialized through ``self.serde`` (``dumps_typed`` /
    ``loads_typed``) and transported as ``{"type": str, "blob": base64}``
    pairs in camelCase JSON payloads. Sync methods are thin wrappers that
    schedule the async counterparts on the event loop captured at
    construction time.
    """

    # ---------------------------- init / ctx -------------------------
    def __init__(
        self,
        base_url: str = "http://localhost",
        *,
        serde: SerializerProtocol | None = None,
        timeout: int | float = 10,
        api_key: str | None = None,
        extra_headers: Dict[str, str] | None = None,
        verify: bool = True,
    ):
        # If no serializer is supplied, build the default one and
        # optionally wrap it in AES encryption configured via ENV.
        if not serde:
            base_serde: SerializerProtocol = Serializer()
            # First matching env var wins: LANGGRAPH_AES_KEY, AGW_AES_KEY, AES_KEY.
            _aes_key = (
                os.getenv("LANGGRAPH_AES_KEY")
                or os.getenv("AGW_AES_KEY")
                or os.getenv("AES_KEY")
            )
            if _aes_key:
                base_serde = EncryptedSerializer.from_pycryptodome_aes(
                    base_serde, key=_aes_key
                )
            serde = base_serde
        super().__init__(serde=serde)
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        # NOTE: requires a running event loop at construction time; the
        # captured loop is what the sync wrappers submit coroutines to.
        self.loop = asyncio.get_running_loop()

        self.headers: Dict[str, str] = {
            "Accept": "application/json",
            "Content-Type": "application/json",
        }
        if extra_headers:
            self.headers.update(extra_headers)
        if api_key:
            self.headers["Authorization"] = f"Bearer {api_key}"

        self._client = httpx.AsyncClient(
            base_url=self.base_url,
            headers=self.headers,
            timeout=self.timeout,
            verify=verify,
            trust_env=True,
        )

    async def __aenter__(self):  # noqa: D401
        return self

    async def __aexit__(self, exc_type, exc, tb):  # noqa: D401
        # Close the underlying HTTP client; exceptions are not suppressed.
        await self._client.aclose()

    # ----------------------- typed (de)serialize ---------------------
    def _encode_typed(self, value: Any) -> dict[str, Any]:
        """value -> {"type": str, "blob": base64str | null}"""
        t, b = self.serde.dumps_typed(value)
        return {"type": t, "blob": _to_b64(b)}

    def _decode_typed(self, obj: Any) -> Any:
        """{type, blob} | [type, blob] | legacy -> python."""
        # New format: a dict with type/blob keys — accepted only if blob
        # is valid base64 or None.
        if isinstance(obj, dict) and all(k in obj for k in TYPED_KEYS):
            t = obj.get("type")
            b64 = obj.get("blob")
            if b64 is None:
                return self.serde.loads_typed((t, None))
            if isinstance(b64, str):
                b = _b64decode_strict(b64)
                if b is not None:
                    return self.serde.loads_typed((t, b))
            # invalid blob — fall through to the generic handling below

        # Also accept a tuple/list of the form [type, base64] — only when
        # the base64 part is strictly valid.
        if isinstance(obj, (list, tuple)) and len(obj) == 2 and isinstance(obj[0], str):
            t, b64 = obj
            if b64 is None and t == "empty":
                return self.serde.loads_typed((t, None))
            if isinstance(b64, str):
                b = _b64decode_strict(b64)
                if b is not None:
                    return self.serde.loads_typed((t, b))
            # otherwise this is not a typed pair

        # Plain string: try strict base64 first, then treat it as a JSON
        # string; return it as-is if neither works.
        if isinstance(obj, str):
            b = _b64decode_strict(obj)
            if b is not None:
                try:
                    return self.serde.loads(b)
                except Exception:
                    pass
            try:
                return self.serde.loads(obj.encode())
            except Exception:
                return obj

        # dict/list -> assume it is already JSON and load it through serde.
        if isinstance(obj, (dict, list)):
            try:
                return self.serde.loads(orjson.dumps(obj))
            except Exception:
                return obj

        # Last resort: feed the object to serde as-is.
        try:
            return self.serde.loads(obj)
        except Exception:
            return obj

    # ----------------------- config <-> api --------------------------
    def _to_api_config(self, cfg: RunnableConfig | None) -> Dict[str, Any]:
        """Map a RunnableConfig's ``configurable`` section to the camelCase API shape."""
        if not cfg:
            return {}
        c = cfg.get("configurable", {})
        res: Dict[str, Any] = {
            "threadId": c.get("thread_id", ""),
            "checkpointNs": c.get("checkpoint_ns", ""),
        }
        # Optional keys are emitted only when present and truthy.
        if cid := c.get("checkpoint_id"):
            res["checkpointId"] = cid
        if ts := c.get("thread_ts"):
            res["threadTs"] = ts
        return res

    # --------------------- checkpoint (de)ser ------------------------
    def _encode_cp(self, cp: Checkpoint) -> Dict[str, Any]:
        """Encode a Checkpoint into the camelCase JSON body expected by the API."""
        channel_values = {
            k: self._encode_typed(v) for k, v in cp.get("channel_values", {}).items()
        }
        pending = []
        for item in cp.get("pending_sends", []) or []:
            try:
                channel, value = item
                pending.append({"channel": channel, **self._encode_typed(value)})
            except Exception:
                # Skip malformed entries (anything not a 2-item pair).
                continue
        return {
            "v": cp["v"],
            "id": cp["id"],
            "ts": cp["ts"],
            "channelValues": channel_values,
            "channelVersions": cp["channel_versions"],
            "versionsSeen": cp["versions_seen"],
            "pendingSends": pending,
        }

    def _decode_cp(self, raw: Dict[str, Any]) -> Checkpoint:
        """Decode an API checkpoint payload back into a Checkpoint."""
        cv_raw = raw.get("channelValues") or {}
        channel_values = {k: self._decode_typed(v) for k, v in cv_raw.items()}
        ps_raw = raw.get("pendingSends") or []
        pending_sends = []
        for obj in ps_raw:
            # expected shape: {channel, type, blob}
            if isinstance(obj, dict) and "channel" in obj:
                ch = obj["channel"]
                typed = {k: obj[k] for k in obj.keys() if k in TYPED_KEYS}
                val = self._decode_typed(typed)
                pending_sends.append((ch, val))
            elif isinstance(obj, (list, tuple)) and len(obj) == 2:
                # legacy pair form: [channel, typed_value]
                ch, val = obj
                pending_sends.append((ch, self._decode_typed(val)))

        return Checkpoint(
            v=raw["v"],
            id=raw["id"],
            ts=raw["ts"],
            channel_values=channel_values,
            channel_versions=raw["channelVersions"],
            versions_seen=raw["versionsSeen"],
            pending_sends=pending_sends,
        )

    def _decode_config(self, raw: Dict[str, Any] | None) -> Optional[RunnableConfig]:
        """Decode an API config payload into a RunnableConfig (None if empty)."""
        if not raw:
            return None
        return RunnableConfig(
            tags=raw.get("tags"),
            metadata=raw.get("metadata"),
            callbacks=raw.get("callbacks"),
            run_name=raw.get("run_name"),
            max_concurrency=raw.get("max_concurrency"),
            recursion_limit=raw.get("recursion_limit"),
            configurable=self._decode_configurable(raw.get("configurable") or {}),
        )

    def _decode_configurable(self, raw: Dict[str, Any]) -> dict[str, Any]:
        # camelCase API keys -> snake_case configurable keys; missing ones
        # come back as None.
        return {
            "thread_id": raw.get("threadId"),
            "thread_ts": raw.get("threadTs"),
            "checkpoint_ns": raw.get("checkpointNs"),
            "checkpoint_id": raw.get("checkpointId"),
        }

    # metadata (de)ser — passed through as-is (a JSON-compatible dict)
    def _enc_meta(self, md: CheckpointMetadata) -> CheckpointMetadata:
        return md or {}

    def _dec_meta(self, md: Any) -> Any:
        return md

    # ------------------------ HTTP wrapper ---------------------------
    async def _http(self, method: str, path: str, **kw) -> httpx.Response:
        """Issue a request on the shared client, serializing ``json=`` via orjson."""
        if "json" in kw:
            # Serialize ourselves with orjson and send as raw body so the
            # wire format matches orjson's output exactly.
            payload = kw.pop("json")
            kw["data"] = orjson.dumps(payload)
            logger.debug("AGW HTTP payload: %s", kw["data"].decode())
        return await self._client.request(method, path, **kw)

    # -------------------- api -> CheckpointTuple ----------------------
    def _to_tuple(self, node: Dict[str, Any]) -> CheckpointTuple:
        """Convert one API node into a CheckpointTuple, decoding pending writes."""
        pending_writes = None
        raw_pw = node.get("pendingWrites")
        if raw_pw:
            decoded: list[tuple[str, str, Any]] = []
            for w in raw_pw:
                if isinstance(w, dict) and "first" in w and "second" in w:
                    # Backend shape: first=task_id, second=channel, third=typed value.
                    task_id = w["first"]
                    channel = w["second"]
                    tv = w.get("third")
                    value = self._decode_typed(tv)
                    decoded.append((task_id, channel, value))
                elif isinstance(w, (list, tuple)):
                    try:
                        first, channel, tv = w
                        decoded.append((first, channel, self._decode_typed(tv)))
                    except Exception:  # pragma: no cover
                        continue
            pending_writes = decoded

        return CheckpointTuple(
            config=self._decode_config(node.get("config")),
            checkpoint=self._decode_cp(node["checkpoint"]),
            metadata=self._dec_meta(node.get("metadata")),
            parent_config=self._decode_config(node.get("parentConfig")),
            pending_writes=pending_writes,
        )

    # =================================================================
    # async BaseCheckpointSaver methods
    # =================================================================
    async def aget_tuple(self, cfg: RunnableConfig) -> CheckpointTuple | None:
        """Fetch one checkpoint tuple; latest for the thread unless an id is given."""
        cid = get_checkpoint_id(cfg)
        api_cfg = self._to_api_config(cfg)
        tid = api_cfg.get("threadId")

        if cid:
            path = f"/checkpoint/{tid}/{cid}"
            params = {"checkpointNs": api_cfg.get("checkpointNs", "")}
        else:
            path = f"/checkpoint/{tid}"
            params = None

        resp = await self._http("GET", path, params=params)
        logger.debug("AGW aget_tuple response: %s", resp.text)

        # Empty body or 404/406 means "no checkpoint" rather than an error.
        if not resp.text or resp.status_code in (404, 406):
            return None
        resp.raise_for_status()
        return self._to_tuple(resp.json())

    async def alist(
        self,
        cfg: RunnableConfig | None,
        *,
        filter: Dict[str, Any] | None = None,
        before: RunnableConfig | None = None,
        limit: int | None = None,
    ) -> AsyncIterator[CheckpointTuple]:
        """Yield checkpoint tuples matching the given config/filter/before/limit."""
        payload = {
            "config": self._to_api_config(cfg) if cfg else None,
            "filter": filter,
            "before": self._to_api_config(before) if before else None,
            "limit": limit,
        }
        resp = await self._http("POST", "/checkpoint/list", json=payload)
        logger.debug("AGW alist response: %s", resp.text)
        resp.raise_for_status()
        for item in resp.json():
            yield self._to_tuple(item)

    async def aput(
        self,
        cfg: RunnableConfig,
        cp: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> RunnableConfig:
        """Store a checkpoint and return the config echoed back by the server."""
        payload = {
            "config": self._to_api_config(cfg),
            "checkpoint": self._encode_cp(cp),
            "metadata": get_checkpoint_metadata(cfg, metadata),
            "newVersions": new_versions,
        }
        resp = await self._http("POST", "/checkpoint", json=payload)
        logger.debug("AGW aput response: %s", resp.text)
        resp.raise_for_status()
        return resp.json()["config"]

    async def aput_writes(
        self,
        cfg: RunnableConfig,
        writes: Sequence[Tuple[str, Any]],
        task_id: str,
        task_path: str = "",
    ) -> None:
        """Store intermediate (channel, value) writes for a task."""
        enc = [{"first": ch, "second": self._encode_typed(v)} for ch, v in writes]
        payload = {
            "config": self._to_api_config(cfg),
            "writes": enc,
            "taskId": task_id,
            "taskPath": task_path,
        }
        resp = await self._http("POST", "/checkpoint/writes", json=payload)
        logger.debug("AGW aput_writes response: %s", resp.text)
        resp.raise_for_status()

    async def adelete_thread(self, thread_id: str) -> None:
        """Delete all checkpoints stored for *thread_id*."""
        resp = await self._http("DELETE", f"/checkpoint/{thread_id}")
        resp.raise_for_status()

    # =================================================================
    # sync wrappers
    # =================================================================
    def _run(self, coro):
        # Submit the coroutine to the loop captured in __init__ and block
        # for the result. NOTE(review): calling this from the same thread
        # that runs that loop would deadlock — confirm callers use a
        # different thread.
        return asyncio.run_coroutine_threadsafe(coro, self.loop).result()

    def list(
        self,
        cfg: RunnableConfig | None,
        *,
        filter: Dict[str, Any] | None = None,
        before: RunnableConfig | None = None,
        limit: int | None = None,
    ) -> Iterator[CheckpointTuple]:
        """Synchronous counterpart of :meth:`alist`."""
        aiter_ = self.alist(cfg, filter=filter, before=before, limit=limit)
        # Drain the async iterator one item at a time via the captured loop.
        while True:
            try:
                yield self._run(anext(aiter_))
            except StopAsyncIteration:
                break

    def get_tuple(self, cfg: RunnableConfig) -> CheckpointTuple | None:
        """Synchronous counterpart of :meth:`aget_tuple`."""
        return self._run(self.aget_tuple(cfg))

    def put(
        self,
        cfg: RunnableConfig,
        cp: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> RunnableConfig:
        """Synchronous counterpart of :meth:`aput`."""
        return self._run(self.aput(cfg, cp, metadata, new_versions))

    def put_writes(
        self,
        cfg: RunnableConfig,
        writes: Sequence[Tuple[str, Any]],
        task_id: str,
        task_path: str = "",
    ) -> None:
        """Synchronous counterpart of :meth:`aput_writes`."""
        self._run(self.aput_writes(cfg, writes, task_id, task_path))

    def delete_thread(self, thread_id: str) -> None:
        """Synchronous counterpart of :meth:`adelete_thread`."""
        self._run(self.adelete_thread(thread_id))

    def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> str:
        """Produce the next monotonically increasing channel version string.

        Format: zero-padded integer part, then a random fractional tail
        (``"<v:032>.<rand:016>"``); accepts None, int, or a previous
        version string as *current*.
        """
        if current is None:
            current_v = 0
        elif isinstance(current, int):
            current_v = current
        else:
            current_v = int(current.split(".")[0])
        next_v = current_v + 1
        next_h = random()
        return f"{next_v:032}.{next_h:016}"

    # ------------------------------------------------------------------ #
    # Convenience factory                                                #
    # ------------------------------------------------------------------ #
    @classmethod
    @asynccontextmanager
    async def from_base_url(
        cls,
        base_url: str,
        *,
        api_key: str | None = None,
        **kwargs: Any,
    ) -> AsyncIterator["AsyncAGWCheckpointSaver"]:
        """Async context manager that builds a saver and closes its client on exit."""
        saver = cls(base_url, api_key=api_key, **kwargs)
        try:
            yield saver
        finally:
            await saver._client.aclose()
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk/langgraph/checkpoint/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|