agent-lab-sdk 0.1.35__py3-none-any.whl → 0.1.49__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent_lab_sdk/langgraph/checkpoint/agw_saver.py +559 -70
- agent_lab_sdk/langgraph/checkpoint/serde.py +172 -0
- agent_lab_sdk/llm/llm.py +4 -1
- agent_lab_sdk/schema/__init__.py +5 -3
- agent_lab_sdk/schema/input_types.py +103 -50
- agent_lab_sdk/storage/__init__.py +2 -1
- agent_lab_sdk/storage/storage_v2.py +132 -0
- {agent_lab_sdk-0.1.35.dist-info → agent_lab_sdk-0.1.49.dist-info}/METADATA +123 -20
- {agent_lab_sdk-0.1.35.dist-info → agent_lab_sdk-0.1.49.dist-info}/RECORD +12 -10
- {agent_lab_sdk-0.1.35.dist-info → agent_lab_sdk-0.1.49.dist-info}/WHEEL +0 -0
- {agent_lab_sdk-0.1.35.dist-info → agent_lab_sdk-0.1.49.dist-info}/licenses/LICENSE +0 -0
- {agent_lab_sdk-0.1.35.dist-info → agent_lab_sdk-0.1.49.dist-info}/top_level.txt +0 -0
|
@@ -1,19 +1,21 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
import orjson
|
|
4
|
-
from random import random
|
|
5
|
-
from langgraph.checkpoint.serde.types import ChannelProtocol
|
|
6
3
|
import asyncio
|
|
4
|
+
import threading
|
|
7
5
|
import base64
|
|
8
|
-
from contextlib import asynccontextmanager
|
|
9
|
-
from typing import Any, AsyncIterator, Dict, Iterator, Optional, Sequence, Tuple
|
|
10
6
|
import logging
|
|
7
|
+
import os
|
|
8
|
+
from contextlib import asynccontextmanager
|
|
9
|
+
from random import random
|
|
10
|
+
from typing import Any, AsyncIterator, Dict, Iterable, Iterator, Optional, Sequence, Tuple
|
|
11
|
+
|
|
12
|
+
import orjson
|
|
13
|
+
from langgraph.checkpoint.serde.types import ChannelProtocol
|
|
11
14
|
|
|
12
15
|
import httpx
|
|
13
16
|
from langchain_core.runnables import RunnableConfig
|
|
14
17
|
|
|
15
18
|
from langgraph.checkpoint.base import (
|
|
16
|
-
WRITES_IDX_MAP,
|
|
17
19
|
BaseCheckpointSaver,
|
|
18
20
|
ChannelVersions,
|
|
19
21
|
Checkpoint,
|
|
@@ -23,11 +25,44 @@ from langgraph.checkpoint.base import (
|
|
|
23
25
|
get_checkpoint_metadata,
|
|
24
26
|
)
|
|
25
27
|
from langgraph.checkpoint.serde.base import SerializerProtocol
|
|
28
|
+
from langgraph.checkpoint.serde.encrypted import EncryptedSerializer
|
|
29
|
+
|
|
30
|
+
from .serde import Serializer
|
|
31
|
+
from agent_lab_sdk.metrics import get_metric
|
|
26
32
|
|
|
27
33
|
__all__ = ["AsyncAGWCheckpointSaver"]
|
|
28
34
|
|
|
29
35
|
logger = logging.getLogger(__name__)
|
|
30
36
|
|
|
37
|
+
# Label names shared by every AGW HTTP counter registered below.
AGW_METRIC_LABELS = ["method", "endpoint"]


def _agw_counter(name, description):
    """Register one AGW HTTP counter labelled by ``AGW_METRIC_LABELS``."""
    return get_metric("counter", name, description, labelnames=AGW_METRIC_LABELS)


# Requests that finished with an accepted status code.
AGW_HTTP_SUCCESS = _agw_counter(
    "agw_http_success_total",
    "Number of successful AGW HTTP requests",
)
# Every failed attempt (transport error or error status), retried or not.
AGW_HTTP_ERROR = _agw_counter(
    "agw_http_error_total",
    "Number of failed AGW HTTP request attempts",
)
# Requests given up on (no more retries).
AGW_HTTP_FINAL_ERROR = _agw_counter(
    "agw_http_final_error_total",
    "Number of AGW HTTP requests that failed after retries",
)

# Keys that identify a typed {"type": ..., "blob": ...} envelope.
TYPED_KEYS = ("type", "blob")
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _b64decode_strict(value: str) -> bytes | None:
|
|
61
|
+
try:
|
|
62
|
+
return base64.b64decode(value, validate=True)
|
|
63
|
+
except Exception:
|
|
64
|
+
return None
|
|
65
|
+
|
|
31
66
|
# ------------------------------------------------------------------ #
|
|
32
67
|
# helpers for Py < 3.10
|
|
33
68
|
# ------------------------------------------------------------------ #
|
|
@@ -53,10 +88,85 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
|
|
|
53
88
|
extra_headers: Dict[str, str] | None = None,
|
|
54
89
|
verify: bool = True,
|
|
55
90
|
):
|
|
91
|
+
if not serde:
|
|
92
|
+
base_serde: SerializerProtocol = Serializer()
|
|
93
|
+
aes_key = (
|
|
94
|
+
os.getenv("LANGGRAPH_AES_KEY")
|
|
95
|
+
or os.getenv("AGW_AES_KEY")
|
|
96
|
+
or os.getenv("AES_KEY")
|
|
97
|
+
)
|
|
98
|
+
if aes_key:
|
|
99
|
+
base_serde = EncryptedSerializer.from_pycryptodome_aes(
|
|
100
|
+
base_serde, key=aes_key
|
|
101
|
+
)
|
|
102
|
+
serde = base_serde
|
|
56
103
|
super().__init__(serde=serde)
|
|
57
104
|
self.base_url = base_url.rstrip("/")
|
|
58
105
|
self.timeout = timeout
|
|
59
|
-
|
|
106
|
+
# Фоновый loop для sync-обёрток
|
|
107
|
+
self._bg_loop: asyncio.AbstractEventLoop | None = None
|
|
108
|
+
self._bg_thread: threading.Thread | None = None
|
|
109
|
+
self._loop_lock = threading.Lock()
|
|
110
|
+
|
|
111
|
+
raw_attempts = os.getenv("AGW_HTTP_MAX_RETRIES")
|
|
112
|
+
if raw_attempts is None:
|
|
113
|
+
self.retry_max_attempts = 3
|
|
114
|
+
else:
|
|
115
|
+
try:
|
|
116
|
+
self.retry_max_attempts = max(int(raw_attempts), 1)
|
|
117
|
+
except ValueError:
|
|
118
|
+
logger.warning(
|
|
119
|
+
"Env %s expected int, got %r; using default %s",
|
|
120
|
+
"AGW_HTTP_MAX_RETRIES",
|
|
121
|
+
raw_attempts,
|
|
122
|
+
3,
|
|
123
|
+
)
|
|
124
|
+
self.retry_max_attempts = 3
|
|
125
|
+
|
|
126
|
+
raw_backoff_base = os.getenv("AGW_HTTP_RETRY_BACKOFF_BASE")
|
|
127
|
+
if raw_backoff_base is None:
|
|
128
|
+
self.retry_backoff_base = 0.5
|
|
129
|
+
else:
|
|
130
|
+
try:
|
|
131
|
+
self.retry_backoff_base = max(float(raw_backoff_base), 0.0)
|
|
132
|
+
except ValueError:
|
|
133
|
+
logger.warning(
|
|
134
|
+
"Env %s expected float, got %r; using default %.3f",
|
|
135
|
+
"AGW_HTTP_RETRY_BACKOFF_BASE",
|
|
136
|
+
raw_backoff_base,
|
|
137
|
+
0.5,
|
|
138
|
+
)
|
|
139
|
+
self.retry_backoff_base = 0.5
|
|
140
|
+
|
|
141
|
+
raw_backoff_max = os.getenv("AGW_HTTP_RETRY_BACKOFF_MAX")
|
|
142
|
+
if raw_backoff_max is None:
|
|
143
|
+
self.retry_backoff_max = 5.0
|
|
144
|
+
else:
|
|
145
|
+
try:
|
|
146
|
+
self.retry_backoff_max = max(float(raw_backoff_max), 0.0)
|
|
147
|
+
except ValueError:
|
|
148
|
+
logger.warning(
|
|
149
|
+
"Env %s expected float, got %r; using default %.3f",
|
|
150
|
+
"AGW_HTTP_RETRY_BACKOFF_MAX",
|
|
151
|
+
raw_backoff_max,
|
|
152
|
+
5.0,
|
|
153
|
+
)
|
|
154
|
+
self.retry_backoff_max = 5.0
|
|
155
|
+
|
|
156
|
+
raw_jitter = os.getenv("AGW_HTTP_RETRY_JITTER")
|
|
157
|
+
if raw_jitter is None:
|
|
158
|
+
self.retry_jitter = 0.25
|
|
159
|
+
else:
|
|
160
|
+
try:
|
|
161
|
+
self.retry_jitter = max(float(raw_jitter), 0.0)
|
|
162
|
+
except ValueError:
|
|
163
|
+
logger.warning(
|
|
164
|
+
"Env %s expected float, got %r; using default %.3f",
|
|
165
|
+
"AGW_HTTP_RETRY_JITTER",
|
|
166
|
+
raw_jitter,
|
|
167
|
+
0.25,
|
|
168
|
+
)
|
|
169
|
+
self.retry_jitter = 0.25
|
|
60
170
|
|
|
61
171
|
self.headers: Dict[str, str] = {
|
|
62
172
|
"Accept": "application/json",
|
|
@@ -66,70 +176,279 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
|
|
|
66
176
|
self.headers.update(extra_headers)
|
|
67
177
|
if api_key:
|
|
68
178
|
self.headers["Authorization"] = f"Bearer {api_key}"
|
|
69
|
-
|
|
70
|
-
self.
|
|
179
|
+
|
|
180
|
+
self._verify = verify
|
|
181
|
+
self._client: httpx.AsyncClient | None = None
|
|
182
|
+
self._client_loop: asyncio.AbstractEventLoop | None = None
|
|
183
|
+
|
|
184
|
+
def _create_client(self) -> httpx.AsyncClient:
    """Build a new ``httpx.AsyncClient`` from the saver's connection settings.

    Uses ``self.base_url``, ``self.headers``, ``self.timeout`` and
    ``self._verify``, and passes ``trust_env=True``.
    """
    options = dict(
        base_url=self.base_url,
        headers=self.headers,
        timeout=self.timeout,
        verify=self._verify,
        trust_env=True,
    )
    return httpx.AsyncClient(**options)
|
|
77
192
|
|
|
193
|
+
def _ensure_bg_loop(self) -> asyncio.AbstractEventLoop:
|
|
194
|
+
with self._loop_lock:
|
|
195
|
+
if self._bg_loop and self._bg_loop.is_running():
|
|
196
|
+
return self._bg_loop
|
|
197
|
+
|
|
198
|
+
loop = asyncio.new_event_loop()
|
|
199
|
+
|
|
200
|
+
def runner():
|
|
201
|
+
asyncio.set_event_loop(loop)
|
|
202
|
+
loop.run_forever()
|
|
203
|
+
|
|
204
|
+
t = threading.Thread(target=runner, name="agw-checkpoint-loop", daemon=True)
|
|
205
|
+
t.start()
|
|
206
|
+
|
|
207
|
+
self._bg_loop = loop
|
|
208
|
+
self._bg_thread = t
|
|
209
|
+
return loop
|
|
210
|
+
|
|
211
|
+
def _compute_retry_delay(self, attempt: int) -> float:
|
|
212
|
+
if attempt <= 0:
|
|
213
|
+
attempt = 1
|
|
214
|
+
if self.retry_backoff_base <= 0:
|
|
215
|
+
delay = 0.0
|
|
216
|
+
else:
|
|
217
|
+
delay = self.retry_backoff_base * (2 ** (attempt - 1))
|
|
218
|
+
if self.retry_backoff_max > 0:
|
|
219
|
+
delay = min(delay, self.retry_backoff_max)
|
|
220
|
+
if self.retry_jitter > 0:
|
|
221
|
+
delay += self.retry_jitter * random()
|
|
222
|
+
return delay
|
|
223
|
+
|
|
78
224
|
async def __aenter__(self): # noqa: D401
|
|
79
225
|
return self
|
|
80
226
|
|
|
81
227
|
async def __aexit__(self, exc_type, exc, tb): # noqa: D401
|
|
82
|
-
|
|
228
|
+
if self._client is not None:
|
|
229
|
+
try:
|
|
230
|
+
await self._client.aclose()
|
|
231
|
+
except Exception as close_exc: # pragma: no cover - best effort
|
|
232
|
+
logger.debug("Failed to close AGW httpx.AsyncClient: %s", close_exc)
|
|
233
|
+
finally:
|
|
234
|
+
self._client = None
|
|
235
|
+
self._client_loop = None
|
|
236
|
+
# останавливаем фоновый loop, если поднимали
|
|
237
|
+
if self._bg_loop is not None:
|
|
238
|
+
try:
|
|
239
|
+
self._bg_loop.call_soon_threadsafe(self._bg_loop.stop)
|
|
240
|
+
finally:
|
|
241
|
+
self._bg_loop = None
|
|
242
|
+
self._bg_thread = None
|
|
83
243
|
|
|
84
244
|
# ----------------------- universal dump/load ---------------------
|
|
85
|
-
# def _safe_dump(self, obj: Any) -> Any:
|
|
86
|
-
# """self.serde.dump → гарантированная JSON-строка."""
|
|
87
|
-
# dumped = self.serde.dumps(obj)
|
|
88
|
-
# if isinstance(dumped, (bytes, bytearray)):
|
|
89
|
-
# return base64.b64encode(dumped).decode() # str
|
|
90
|
-
# return dumped # уже json-совместимо
|
|
91
|
-
|
|
92
245
|
def _safe_dump(self, obj: Any) -> Any:
|
|
93
|
-
"""
|
|
94
|
-
|
|
246
|
+
"""
|
|
247
|
+
JSON-first сериализация:
|
|
248
|
+
1) Пытаемся через self.serde.dumps(obj).
|
|
249
|
+
- Если вернул bytes: пробуем декодировать в JSON-строку и распарсить.
|
|
250
|
+
- Если не JSON/не UTF-8 — заворачиваем как base64-строку.
|
|
251
|
+
- Если вернул dict/list/scalar — они уже JSON-совместимы.
|
|
252
|
+
2) Если self.serde.dumps(obj) бросает исключение (например, для Send),
|
|
253
|
+
делаем типизированный фолбэк {"type": str, "blob": base64 | None}.
|
|
254
|
+
"""
|
|
255
|
+
try:
|
|
256
|
+
dumped = self.serde.dumps(obj)
|
|
257
|
+
except Exception:
|
|
258
|
+
# typed fallback (как рекомендуют в LangGraph для нетривиальных типов)
|
|
259
|
+
# https://langchain-ai.github.io/langgraph/reference/checkpoints/
|
|
260
|
+
try:
|
|
261
|
+
t, b = self.serde.dumps_typed(obj)
|
|
262
|
+
except Exception:
|
|
263
|
+
# крайний случай: строковое представление
|
|
264
|
+
t, b = type(obj).__name__, str(obj).encode()
|
|
265
|
+
return {"type": t, "blob": base64.b64encode(b).decode() if b is not None else None}
|
|
266
|
+
|
|
95
267
|
if isinstance(dumped, (bytes, bytearray)):
|
|
96
268
|
try:
|
|
97
|
-
# 1) bytes → str
|
|
98
269
|
s = dumped.decode()
|
|
99
|
-
# 2) str JSON → python (list/dict/scalar)
|
|
100
270
|
return orjson.loads(s)
|
|
101
271
|
except (UnicodeDecodeError, orjson.JSONDecodeError):
|
|
102
|
-
# не UTF-8 или не JSON → base64
|
|
103
272
|
return base64.b64encode(dumped).decode()
|
|
104
273
|
return dumped
|
|
105
274
|
|
|
106
275
|
def _safe_load(self, obj: Any) -> Any:
    """Best-effort inverse of ``_safe_dump`` for one stored value.

    Decoding problems never raise. When every strategy fails, the input is
    returned unchanged.

    * ``None`` -> ``None``.
    * dict with both ``"type"`` and ``"blob"`` keys (typed envelope):
      - a ``None`` blob is passed as ``(type, None)`` to
        ``self.serde.loads_typed``; on failure the dict is returned as is;
      - a string blob is strictly base64-decoded and passed to
        ``loads_typed``;
      - otherwise the dict takes the generic route below.
    * any other dict: re-encoded with ``orjson`` and fed to ``self.serde.loads``.
    * 2-item list/tuple ``[type_name, blob-or-None]`` of strings: tried as a
      typed pair first; otherwise the sequence goes through ``serde.loads``.
    * str: tried as serialized data (``serde.loads(obj.encode())``), then as
      base64-encoded serialized data, then returned as is.
    * anything else: ``serde.loads(obj)`` or the value unchanged.

    NOTE(review): a plain string such as ``"42"`` or ``"true"`` is first fed
    to ``serde.loads``. If the serializer parses JSON (not visible here;
    confirm against ``.serde.Serializer``), it would come back as an int/bool
    instead of a str. Short words that happen to be valid base64 (e.g.
    ``"test"``) are also tried as base64 payloads.
    """
    if obj is None:
        return None

    if isinstance(obj, dict):
        # Typed envelope produced by _safe_dump's fallback path.
        if all(k in obj for k in TYPED_KEYS):
            t = obj.get("type")
            blob = obj.get("blob")
            if blob is None:
                try:
                    return self.serde.loads_typed((t, None))
                except Exception:
                    return obj
            if isinstance(blob, str):
                payload = _b64decode_strict(blob)
                if payload is not None:
                    try:
                        return self.serde.loads_typed((t, payload))
                    except Exception:
                        # fall back to generic handling below
                        pass
        # Generic dict route: round-trip through JSON bytes into serde.loads.
        try:
            return self.serde.loads(orjson.dumps(obj))
        except Exception:
            return obj

    if isinstance(obj, (list, tuple)):
        # A [type_name, blob] pair may be a typed value in sequence form.
        if (
            len(obj) == 2
            and isinstance(obj[0], str)
            and (obj[1] is None or isinstance(obj[1], str))
        ):
            blob = obj[1]
            if blob is None:
                try:
                    return self.serde.loads_typed((obj[0], None))
                except Exception:
                    pass
            elif isinstance(blob, str):
                payload = _b64decode_strict(blob)
                if payload is not None:
                    try:
                        return self.serde.loads_typed((obj[0], payload))
                    except Exception:
                        pass
        # Generic sequence route (tuples are converted to lists for orjson).
        try:
            return self.serde.loads(orjson.dumps(list(obj)))
        except Exception:
            return obj

    if isinstance(obj, str):
        # First as serialized text, then as base64 of serialized bytes.
        try:
            return self.serde.loads(obj.encode())
        except Exception:
            payload = _b64decode_strict(obj)
            if payload is not None:
                try:
                    return self.serde.loads(payload)
                except Exception:
                    pass
            return obj

    # Any other value (e.g. bytes) goes straight to serde.loads.
    try:
        return self.serde.loads(obj)
    except Exception:
        return obj
|
|
123
341
|
|
|
124
|
-
#
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
342
|
+
# ----------------------- deep dump/load (leaf-first) -------------
|
|
343
|
+
@staticmethod
|
|
344
|
+
def _is_json_scalar(x: Any) -> bool:
|
|
345
|
+
return x is None or isinstance(x, (str, int, float, bool))
|
|
346
|
+
|
|
347
|
+
@staticmethod
|
|
348
|
+
def _coerce_key(k: Any) -> str:
|
|
349
|
+
return k if isinstance(k, str) else str(k)
|
|
350
|
+
|
|
351
|
+
def _safe_dump_deep(self, obj: Any, _seen: set[int] | None = None) -> Any:
|
|
352
|
+
"""
|
|
353
|
+
Идём от листьев к корню:
|
|
354
|
+
- Для контейнеров рекурсируем внутрь и сохраняем форму контейнера.
|
|
355
|
+
- Для листьев вызываем _safe_dump (fallback на serde/typed + base64).
|
|
356
|
+
"""
|
|
357
|
+
if _seen is None:
|
|
358
|
+
_seen = set()
|
|
359
|
+
|
|
360
|
+
if self._is_json_scalar(obj):
|
|
361
|
+
return obj
|
|
362
|
+
|
|
363
|
+
if isinstance(obj, dict):
|
|
364
|
+
oid = id(obj)
|
|
365
|
+
if oid in _seen:
|
|
366
|
+
return {"type": "Cycle", "blob": None}
|
|
367
|
+
_seen.add(oid)
|
|
368
|
+
return {
|
|
369
|
+
self._coerce_key(k): self._safe_dump_deep(v, _seen) for k, v in obj.items()
|
|
370
|
+
}
|
|
371
|
+
|
|
372
|
+
if isinstance(obj, (list, tuple, set)):
|
|
373
|
+
oid = id(obj)
|
|
374
|
+
if oid in _seen:
|
|
375
|
+
return ["<cycle>"]
|
|
376
|
+
_seen.add(oid)
|
|
377
|
+
return [self._safe_dump_deep(v, _seen) for v in obj]
|
|
378
|
+
|
|
379
|
+
# лист: доверяем универсальному дамперу
|
|
380
|
+
return self._safe_dump(obj)
|
|
381
|
+
|
|
382
|
+
def _safe_load_deep(self, obj: Any, _seen: set[int] | None = None) -> Any:
|
|
383
|
+
"""
|
|
384
|
+
Обратная операция:
|
|
385
|
+
- Контейнеры сначала пробуем целиком скормить serde.loads(...).
|
|
386
|
+
Если вернулся НЕ JSON-контейнер (например, объект сообщения) — возвращаем его.
|
|
387
|
+
Иначе рекурсивно обходим внутрь и листья скармливаем _safe_load.
|
|
388
|
+
- typed {"type","blob"} обрабатываем как раньше.
|
|
389
|
+
"""
|
|
390
|
+
if _seen is None:
|
|
391
|
+
_seen = set()
|
|
392
|
+
|
|
393
|
+
# Примитивы: просто через _safe_load (декод base64/bytes и т.п.)
|
|
394
|
+
if self._is_json_scalar(obj):
|
|
395
|
+
return self._safe_load(obj)
|
|
396
|
+
|
|
397
|
+
# dict
|
|
398
|
+
if isinstance(obj, dict):
|
|
399
|
+
# типизированная обёртка — сразу разворачиваем
|
|
400
|
+
if all(k in obj for k in TYPED_KEYS):
|
|
401
|
+
return self._safe_load(obj)
|
|
402
|
+
|
|
403
|
+
# 1) parse-first: пробуем целиком восстановить объект через serde
|
|
404
|
+
try:
|
|
405
|
+
parsed = self.serde.loads(orjson.dumps(obj))
|
|
406
|
+
# если получили не-JSON-контейнер (объект), возвращаем
|
|
407
|
+
if not isinstance(parsed, (dict, list, tuple, str, int, float, bool, type(None))):
|
|
408
|
+
return parsed
|
|
409
|
+
except Exception:
|
|
410
|
+
pass
|
|
411
|
+
|
|
412
|
+
# 2) иначе — рекурсивно
|
|
413
|
+
oid = id(obj)
|
|
414
|
+
if oid in _seen:
|
|
415
|
+
return obj
|
|
416
|
+
_seen.add(oid)
|
|
417
|
+
return {k: self._safe_load_deep(v, _seen) for k, v in obj.items()}
|
|
418
|
+
|
|
419
|
+
# list
|
|
420
|
+
if isinstance(obj, list):
|
|
421
|
+
# parse-first: пытаемся восстановить весь список одной операцией
|
|
422
|
+
try:
|
|
423
|
+
parsed = self.serde.loads(orjson.dumps(obj))
|
|
424
|
+
if not isinstance(parsed, (dict, list, tuple, str, int, float, bool, type(None))):
|
|
425
|
+
return parsed
|
|
426
|
+
except Exception:
|
|
427
|
+
pass
|
|
428
|
+
|
|
429
|
+
oid = id(obj)
|
|
430
|
+
if oid in _seen:
|
|
431
|
+
return obj
|
|
432
|
+
_seen.add(oid)
|
|
433
|
+
return [self._safe_load_deep(v, _seen) for v in obj]
|
|
434
|
+
|
|
435
|
+
# tuple — аналогично list, но вернём list (JSON-совместимо)
|
|
436
|
+
if isinstance(obj, tuple):
|
|
437
|
+
try:
|
|
438
|
+
parsed = self.serde.loads(orjson.dumps(obj))
|
|
439
|
+
if not isinstance(parsed, (dict, list, tuple, str, int, float, bool, type(None))):
|
|
440
|
+
return parsed
|
|
441
|
+
except Exception:
|
|
442
|
+
pass
|
|
443
|
+
|
|
444
|
+
oid = id(obj)
|
|
445
|
+
if oid in _seen:
|
|
446
|
+
return obj
|
|
447
|
+
_seen.add(oid)
|
|
448
|
+
return [self._safe_load_deep(v, _seen) for v in obj]
|
|
449
|
+
|
|
450
|
+
# Всё остальное — лист, через _safe_load
|
|
451
|
+
return self._safe_load(obj)
|
|
133
452
|
|
|
134
453
|
# ----------------------- config <-> api --------------------------
|
|
135
454
|
def _to_api_config(self, cfg: RunnableConfig | None) -> Dict[str, Any]:
|
|
@@ -152,24 +471,44 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
|
|
|
152
471
|
"v": cp["v"],
|
|
153
472
|
"id": cp["id"],
|
|
154
473
|
"ts": cp["ts"],
|
|
155
|
-
"channelValues": {
|
|
474
|
+
"channelValues": {
|
|
475
|
+
k: self._safe_dump_deep(v) for k, v in cp["channel_values"].items()
|
|
476
|
+
},
|
|
156
477
|
"channelVersions": cp["channel_versions"],
|
|
157
478
|
"versionsSeen": cp["versions_seen"],
|
|
158
|
-
"pendingSends":
|
|
479
|
+
"pendingSends": [] # как в BasePostgresSaver, они внутри checkpoint не нужны
|
|
159
480
|
}
|
|
160
481
|
|
|
161
482
|
def _decode_cp(self, raw: Dict[str, Any]) -> Checkpoint:
    """Convert an AGW checkpoint payload (camelCase keys) into a LangGraph ``Checkpoint``.

    Channel values are restored with ``_safe_load_deep``. ``pendingSends`` is
    accepted if the server returns it, although this saver always writes an
    empty list. Each entry may be:
    - a ``{"channel", "value"}`` dict, or such a dict with the typed
      ``{"type", "blob"}`` keys at top level instead of ``"value"``;
    - a ``[channel, value, ...]`` sequence.
    Entries of any other shape are passed through untouched.
    """

    def decode_send(entry: Any) -> Any:
        # Normalise one pendingSends entry into (channel, value).
        if isinstance(entry, dict) and "channel" in entry:
            value = entry.get("value")
            if value is None and all(key in entry for key in TYPED_KEYS):
                value = {key: entry[key] for key in TYPED_KEYS}
            return (entry["channel"], self._safe_load_deep(value))
        if isinstance(entry, (list, tuple)) and len(entry) >= 2:
            return (entry[0], self._safe_load_deep(entry[1]))
        return entry

    sends: list[Tuple[str, Any]] = [
        decode_send(entry) for entry in (raw.get("pendingSends", []) or [])
    ]
    return Checkpoint(
        v=raw["v"],
        id=raw["id"],
        ts=raw["ts"],
        channel_values={
            name: self._safe_load_deep(value) for name, value in raw["channelValues"].items()
        },
        channel_versions=raw["channelVersions"],
        versions_seen=raw["versionsSeen"],
        pending_sends=sends,
    )
|
|
171
510
|
|
|
172
|
-
def _decode_config(self, raw: Dict[str, Any]) -> Optional[RunnableConfig]:
|
|
511
|
+
def _decode_config(self, raw: Dict[str, Any] | None) -> Optional[RunnableConfig]:
|
|
173
512
|
if not raw:
|
|
174
513
|
return None
|
|
175
514
|
return RunnableConfig(
|
|
@@ -179,7 +518,7 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
|
|
|
179
518
|
run_name=raw.get("run_name"),
|
|
180
519
|
max_concurrency=raw.get("max_concurrency"),
|
|
181
520
|
recursion_limit=raw.get("recursion_limit"),
|
|
182
|
-
configurable=self._decode_configurable(raw.get("configurable"))
|
|
521
|
+
configurable=self._decode_configurable(raw.get("configurable") or {}),
|
|
183
522
|
)
|
|
184
523
|
|
|
185
524
|
def _decode_configurable(self, raw: Dict[str, Any]) -> dict[str, Any]:
|
|
@@ -192,34 +531,157 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
|
|
|
192
531
|
|
|
193
532
|
# metadata (de)ser
|
|
194
533
|
def _enc_meta(self, md: CheckpointMetadata) -> CheckpointMetadata:
|
|
534
|
+
if not md:
|
|
535
|
+
return {}
|
|
195
536
|
out: CheckpointMetadata = {}
|
|
196
537
|
for k, v in md.items():
|
|
197
|
-
out[k] = self._enc_meta(v) if isinstance(v, dict) else self.
|
|
538
|
+
out[k] = self._enc_meta(v) if isinstance(v, dict) else self._safe_dump_deep(v) # type: ignore[assignment]
|
|
198
539
|
return out
|
|
199
540
|
|
|
200
541
|
def _dec_meta(self, md: Any) -> Any:
|
|
201
542
|
if isinstance(md, dict):
|
|
202
543
|
return {k: self._dec_meta(v) for k, v in md.items()}
|
|
203
|
-
return self.
|
|
544
|
+
return self._safe_load_deep(md)
|
|
204
545
|
|
|
205
546
|
# ------------------------ HTTP wrapper ---------------------------
|
|
206
|
-
async def _http(
    self,
    method: str,
    path: str,
    *,
    ok_statuses: Iterable[int] | None = None,
    label_path: str | None = None,
    **kw,
) -> httpx.Response:
    """Send one request to AGW, retrying transient failures, and return the response.

    Args:
        method: HTTP method.
        path: request path relative to the client's base URL.
        ok_statuses: status codes >= 400 that are nevertheless treated as
            success (returned without retrying).
        label_path: path template used as the ``endpoint`` metric label, so
            that ids embedded in *path* do not inflate label cardinality.
            Defaults to *path*.
        **kw: forwarded to ``client.request``. A ``json=`` payload is
            serialised with ``orjson`` and sent as ``data=``.

    Returns:
        The last response. Error statuses are returned rather than raised;
        callers decide via ``raise_for_status()``.

    Raises:
        httpx.RequestError: if the transport still fails on the final attempt.

    Up to ``self.retry_max_attempts`` attempts are made, sleeping
    ``_compute_retry_delay(attempt)`` between them:
    * transport errors (``httpx.RequestError``) are retried;
    * 5xx and 408/425/429 responses are retried;
    * any other 4xx (e.g. 400, 401, 404, 406, 409) is returned immediately,
      because repeating the same request cannot change the outcome.
    After each retried failure the shared client is closed and dropped, so
    the next attempt starts with fresh connections.
    """
    if "json" in kw:
        payload = kw.pop("json")
        kw["data"] = orjson.dumps(payload)
        # Avoid decoding a potentially large payload when DEBUG logging is off.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug("AGW HTTP payload: %s", kw["data"].decode())

    ok_set = set(ok_statuses) if ok_statuses is not None else set()
    metric_path = label_path or path
    # 4xx statuses that signal a transient condition and are worth retrying.
    retryable_client_errors = {408, 425, 429}

    async def drop_client() -> None:
        # Close (best effort) and forget the shared client.
        stale = self._client
        if stale is None:
            return
        try:
            await stale.aclose()
        except Exception as close_exc:  # pragma: no cover - best effort
            logger.debug("Failed to close AGW httpx.AsyncClient: %s", close_exc)
        finally:
            self._client = None
            self._client_loop = None

    attempt = 1
    while True:
        # The client must belong to the currently running loop.
        current_loop = asyncio.get_running_loop()
        client = self._client
        if client is None or client.is_closed or self._client_loop is not current_loop:
            if client is not None:
                try:
                    await client.aclose()
                except Exception as e:
                    # The error is passed through a %s placeholder; without one,
                    # logging cannot format the record.
                    logger.exception("ошибка при закрытии клиента: %s", e)
            client = self._create_client()
            self._client = client
            self._client_loop = current_loop

        try:
            resp = await client.request(method, path, **kw)
        except httpx.RequestError as exc:
            AGW_HTTP_ERROR.labels(method, metric_path).inc()
            logger.warning(
                "AGW request %s %s failed on attempt %d/%d: %s",
                method,
                path,
                attempt,
                self.retry_max_attempts,
                exc,
            )
            if attempt >= self.retry_max_attempts:
                AGW_HTTP_FINAL_ERROR.labels(method, metric_path).inc()
                await drop_client()
                raise
            await drop_client()
        else:
            status = resp.status_code
            if status < 400 or status in ok_set:
                AGW_HTTP_SUCCESS.labels(method, metric_path).inc()
                return resp

            AGW_HTTP_ERROR.labels(method, metric_path).inc()
            retryable = status >= 500 or status in retryable_client_errors
            if not retryable or attempt >= self.retry_max_attempts:
                AGW_HTTP_FINAL_ERROR.labels(method, metric_path).inc()
                return resp

            try:
                await resp.aclose()
            except Exception as exc:  # pragma: no cover - best effort
                logger.debug("Failed to close AGW httpx.Response before retry: %s", exc)
            await drop_client()

        delay = self._compute_retry_delay(attempt)
        if delay > 0:
            await asyncio.sleep(delay)
        attempt += 1
|
|
213
656
|
|
|
214
657
|
# -------------------- api -> CheckpointTuple ----------------------
|
|
215
658
|
def _to_tuple(self, node: Dict[str, Any]) -> CheckpointTuple:
    """Build a ``CheckpointTuple`` from one checkpoint node returned by AGW.

    ``pendingWrites`` entries are accepted in two shapes:
    - ``{"first", "second", "third"}`` dicts;
    - 2- or 3-item sequences (other lengths and other types are skipped).
    Each becomes ``(first, second, value)``, with the value (third element)
    restored by ``_safe_load_deep``. ``pending_writes`` is ``None`` when the
    node has no (or empty) ``pendingWrites``.
    """
    pending = None
    raw_writes = node.get("pendingWrites")
    if raw_writes:
        pending = []
        for w in node["pendingWrites"]:
            if isinstance(w, dict):
                first, second, third = w.get("first"), w.get("second"), w.get("third")
                # NOTE(review): for the two-field shape, a typed envelope in "second"
                # is used as the value but also stays in the second slot; confirm
                # which position the server uses for the channel name.
                if (
                    third is None
                    and isinstance(second, dict)
                    and all(key in second for key in TYPED_KEYS)
                ):
                    third = second
            elif isinstance(w, (list, tuple)) and len(w) in (2, 3):
                first, second = w[0], w[1]
                third = w[2] if len(w) == 3 else None
            else:
                continue
            pending.append((first, second, self._safe_load_deep(third)))

    return CheckpointTuple(
        config=self._decode_config(node.get("config")),
        checkpoint=self._decode_cp(node["checkpoint"]),
        metadata=self._dec_meta(node.get("metadata")),
        parent_config=self._decode_config(node.get("parentConfig")),
        pending_writes=pending,
    )
|
|
@@ -233,13 +695,15 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
|
|
|
233
695
|
tid = api_cfg["threadId"]
|
|
234
696
|
|
|
235
697
|
if cid:
|
|
698
|
+
path_template = "/checkpoint/{threadId}/{checkpointId}"
|
|
236
699
|
path = f"/checkpoint/{tid}/{cid}"
|
|
237
700
|
params = {"checkpointNs": api_cfg.get("checkpointNs", "")}
|
|
238
701
|
else:
|
|
702
|
+
path_template = "/checkpoint/{threadId}"
|
|
239
703
|
path = f"/checkpoint/{tid}"
|
|
240
704
|
params = None
|
|
241
705
|
|
|
242
|
-
resp = await self._http("GET", path, params=params)
|
|
706
|
+
resp = await self._http("GET", path, params=params, label_path=path_template)
|
|
243
707
|
logger.debug("AGW aget_tuple response: %s", resp.text)
|
|
244
708
|
|
|
245
709
|
if not resp.text:
|
|
@@ -263,7 +727,12 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
|
|
|
263
727
|
"before": self._to_api_config(before) if before else None,
|
|
264
728
|
"limit": limit,
|
|
265
729
|
}
|
|
266
|
-
resp = await self._http(
|
|
730
|
+
resp = await self._http(
|
|
731
|
+
"POST",
|
|
732
|
+
"/checkpoint/list",
|
|
733
|
+
json=payload,
|
|
734
|
+
label_path="/checkpoint/list",
|
|
735
|
+
)
|
|
267
736
|
logger.debug("AGW alist response: %s", resp.text)
|
|
268
737
|
resp.raise_for_status()
|
|
269
738
|
for item in resp.json():
|
|
@@ -282,7 +751,7 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
|
|
|
282
751
|
"metadata": self._enc_meta(get_checkpoint_metadata(cfg, metadata)),
|
|
283
752
|
"newVersions": new_versions,
|
|
284
753
|
}
|
|
285
|
-
resp = await self._http("POST", "/checkpoint", json=payload)
|
|
754
|
+
resp = await self._http("POST", "/checkpoint", json=payload, label_path="/checkpoint")
|
|
286
755
|
logger.debug("AGW aput response: %s", resp.text)
|
|
287
756
|
resp.raise_for_status()
|
|
288
757
|
return resp.json()["config"]
|
|
@@ -294,26 +763,38 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
|
|
|
294
763
|
task_id: str,
|
|
295
764
|
task_path: str = "",
|
|
296
765
|
) -> None:
|
|
297
|
-
enc = [{"first": ch, "second": self.
|
|
766
|
+
enc = [{"first": ch, "second": self._safe_dump_deep(v)} for ch, v in writes]
|
|
298
767
|
payload = {
|
|
299
768
|
"config": self._to_api_config(cfg),
|
|
300
769
|
"writes": enc,
|
|
301
770
|
"taskId": task_id,
|
|
302
771
|
"taskPath": task_path,
|
|
303
772
|
}
|
|
304
|
-
resp = await self._http(
|
|
773
|
+
resp = await self._http(
|
|
774
|
+
"POST",
|
|
775
|
+
"/checkpoint/writes",
|
|
776
|
+
json=payload,
|
|
777
|
+
label_path="/checkpoint/writes",
|
|
778
|
+
)
|
|
305
779
|
logger.debug("AGW aput_writes response: %s", resp.text)
|
|
306
780
|
resp.raise_for_status()
|
|
307
781
|
|
|
308
782
|
async def adelete_thread(self, thread_id: str) -> None:
    """Issue ``DELETE /checkpoint/{thread_id}`` for the given thread.

    ``raise_for_status()`` is called on the response, so an error status
    surfaces as an exception.
    """
    endpoint = f"/checkpoint/{thread_id}"
    response = await self._http("DELETE", endpoint, label_path="/checkpoint/{threadId}")
    response.raise_for_status()
|
|
311
789
|
|
|
312
790
|
# =================================================================
|
|
313
791
|
# sync-обёртки
|
|
314
792
|
# =================================================================
|
|
315
793
|
def _run(self, coro):
|
|
316
|
-
|
|
794
|
+
# sync-обёртки всегда выполняем в собственном loop в отдельном потоке
|
|
795
|
+
loop = self._ensure_bg_loop()
|
|
796
|
+
fut = asyncio.run_coroutine_threadsafe(coro, loop)
|
|
797
|
+
return fut.result()
|
|
317
798
|
|
|
318
799
|
def list(
|
|
319
800
|
self,
|
|
@@ -323,12 +804,14 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
|
|
|
323
804
|
before: RunnableConfig | None = None,
|
|
324
805
|
limit: int | None = None,
|
|
325
806
|
) -> Iterator[CheckpointTuple]:
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
807
|
+
async def _collect():
|
|
808
|
+
out = []
|
|
809
|
+
async for item in self.alist(cfg, filter=filter, before=before, limit=limit):
|
|
810
|
+
out.append(item)
|
|
811
|
+
return out
|
|
812
|
+
|
|
813
|
+
for item in self._run(_collect()):
|
|
814
|
+
yield item
|
|
332
815
|
|
|
333
816
|
def get_tuple(self, cfg: RunnableConfig) -> CheckpointTuple | None:
    """Synchronous counterpart of ``aget_tuple``, executed through ``_run``."""
    pending = self.aget_tuple(cfg)
    return self._run(pending)
|
|
@@ -381,4 +864,10 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
|
|
|
381
864
|
try:
|
|
382
865
|
yield saver
|
|
383
866
|
finally:
|
|
384
|
-
|
|
867
|
+
if saver._client is not None:
|
|
868
|
+
try:
|
|
869
|
+
await saver._client.aclose()
|
|
870
|
+
except Exception as close_exc: # pragma: no cover - best effort
|
|
871
|
+
logger.debug("Failed to close AGW httpx.AsyncClient: %s", close_exc)
|
|
872
|
+
finally:
|
|
873
|
+
saver._client = None
|