agent-lab-sdk 0.1.35__py3-none-any.whl → 0.1.49__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,19 +1,21 @@
  from __future__ import annotations

- import orjson
- from random import random
- from langgraph.checkpoint.serde.types import ChannelProtocol
  import asyncio
+ import threading
  import base64
- from contextlib import asynccontextmanager
- from typing import Any, AsyncIterator, Dict, Iterator, Optional, Sequence, Tuple
  import logging
+ import os
+ from contextlib import asynccontextmanager
+ from random import random
+ from typing import Any, AsyncIterator, Dict, Iterable, Iterator, Optional, Sequence, Tuple
+
+ import orjson
+ from langgraph.checkpoint.serde.types import ChannelProtocol

  import httpx
  from langchain_core.runnables import RunnableConfig

  from langgraph.checkpoint.base import (
- WRITES_IDX_MAP,
  BaseCheckpointSaver,
  ChannelVersions,
  Checkpoint,
@@ -23,11 +25,44 @@ from langgraph.checkpoint.base import (
  get_checkpoint_metadata,
  )
  from langgraph.checkpoint.serde.base import SerializerProtocol
+ from langgraph.checkpoint.serde.encrypted import EncryptedSerializer
+
+ from .serde import Serializer
+ from agent_lab_sdk.metrics import get_metric

  __all__ = ["AsyncAGWCheckpointSaver"]

  logger = logging.getLogger(__name__)

+ AGW_METRIC_LABELS = ["method", "endpoint"]
+ AGW_HTTP_SUCCESS = get_metric(
+ "counter",
+ "agw_http_success_total",
+ "Number of successful AGW HTTP requests",
+ labelnames=AGW_METRIC_LABELS,
+ )
+ AGW_HTTP_ERROR = get_metric(
+ "counter",
+ "agw_http_error_total",
+ "Number of failed AGW HTTP request attempts",
+ labelnames=AGW_METRIC_LABELS,
+ )
+ AGW_HTTP_FINAL_ERROR = get_metric(
+ "counter",
+ "agw_http_final_error_total",
+ "Number of AGW HTTP requests that failed after retries",
+ labelnames=AGW_METRIC_LABELS,
+ )
+
+ TYPED_KEYS = ("type", "blob")
+
+
+ def _b64decode_strict(value: str) -> bytes | None:
+ try:
+ return base64.b64decode(value, validate=True)
+ except Exception:
+ return None
+

  # ------------------------------------------------------------------ #
  # helpers for Py < 3.10
  # ------------------------------------------------------------------ #
@@ -53,10 +88,85 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
  extra_headers: Dict[str, str] | None = None,
  verify: bool = True,
  ):
+ if not serde:
+ base_serde: SerializerProtocol = Serializer()
+ aes_key = (
+ os.getenv("LANGGRAPH_AES_KEY")
+ or os.getenv("AGW_AES_KEY")
+ or os.getenv("AES_KEY")
+ )
+ if aes_key:
+ base_serde = EncryptedSerializer.from_pycryptodome_aes(
+ base_serde, key=aes_key
+ )
+ serde = base_serde
  super().__init__(serde=serde)
  self.base_url = base_url.rstrip("/")
  self.timeout = timeout
- self.loop = asyncio.get_running_loop()
+ # Background event loop for the sync wrappers
+ self._bg_loop: asyncio.AbstractEventLoop | None = None
+ self._bg_thread: threading.Thread | None = None
+ self._loop_lock = threading.Lock()
+
+ raw_attempts = os.getenv("AGW_HTTP_MAX_RETRIES")
+ if raw_attempts is None:
+ self.retry_max_attempts = 3
+ else:
+ try:
+ self.retry_max_attempts = max(int(raw_attempts), 1)
+ except ValueError:
+ logger.warning(
+ "Env %s expected int, got %r; using default %s",
+ "AGW_HTTP_MAX_RETRIES",
+ raw_attempts,
+ 3,
+ )
+ self.retry_max_attempts = 3
+
+ raw_backoff_base = os.getenv("AGW_HTTP_RETRY_BACKOFF_BASE")
+ if raw_backoff_base is None:
+ self.retry_backoff_base = 0.5
+ else:
+ try:
+ self.retry_backoff_base = max(float(raw_backoff_base), 0.0)
+ except ValueError:
+ logger.warning(
+ "Env %s expected float, got %r; using default %.3f",
+ "AGW_HTTP_RETRY_BACKOFF_BASE",
+ raw_backoff_base,
+ 0.5,
+ )
+ self.retry_backoff_base = 0.5
+
+ raw_backoff_max = os.getenv("AGW_HTTP_RETRY_BACKOFF_MAX")
+ if raw_backoff_max is None:
+ self.retry_backoff_max = 5.0
+ else:
+ try:
+ self.retry_backoff_max = max(float(raw_backoff_max), 0.0)
+ except ValueError:
+ logger.warning(
+ "Env %s expected float, got %r; using default %.3f",
+ "AGW_HTTP_RETRY_BACKOFF_MAX",
+ raw_backoff_max,
+ 5.0,
+ )
+ self.retry_backoff_max = 5.0
+
+ raw_jitter = os.getenv("AGW_HTTP_RETRY_JITTER")
+ if raw_jitter is None:
+ self.retry_jitter = 0.25
+ else:
+ try:
+ self.retry_jitter = max(float(raw_jitter), 0.0)
+ except ValueError:
+ logger.warning(
+ "Env %s expected float, got %r; using default %.3f",
+ "AGW_HTTP_RETRY_JITTER",
+ raw_jitter,
+ 0.25,
+ )
+ self.retry_jitter = 0.25

  self.headers: Dict[str, str] = {
  "Accept": "application/json",
@@ -66,70 +176,279 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
  self.headers.update(extra_headers)
  if api_key:
  self.headers["Authorization"] = f"Bearer {api_key}"
-
- self._client = httpx.AsyncClient(
+
+ self._verify = verify
+ self._client: httpx.AsyncClient | None = None
+ self._client_loop: asyncio.AbstractEventLoop | None = None
+
+ def _create_client(self) -> httpx.AsyncClient:
+ return httpx.AsyncClient(
  base_url=self.base_url,
  headers=self.headers,
  timeout=self.timeout,
- verify=verify,
- trust_env=True
+ verify=self._verify,
+ trust_env=True,
  )

+ def _ensure_bg_loop(self) -> asyncio.AbstractEventLoop:
+ with self._loop_lock:
+ if self._bg_loop and self._bg_loop.is_running():
+ return self._bg_loop
+
+ loop = asyncio.new_event_loop()
+
+ def runner():
+ asyncio.set_event_loop(loop)
+ loop.run_forever()
+
+ t = threading.Thread(target=runner, name="agw-checkpoint-loop", daemon=True)
+ t.start()
+
+ self._bg_loop = loop
+ self._bg_thread = t
+ return loop
+
+ def _compute_retry_delay(self, attempt: int) -> float:
+ if attempt <= 0:
+ attempt = 1
+ if self.retry_backoff_base <= 0:
+ delay = 0.0
+ else:
+ delay = self.retry_backoff_base * (2 ** (attempt - 1))
+ if self.retry_backoff_max > 0:
+ delay = min(delay, self.retry_backoff_max)
+ if self.retry_jitter > 0:
+ delay += self.retry_jitter * random()
+ return delay
+
  async def __aenter__(self): # noqa: D401
  return self

  async def __aexit__(self, exc_type, exc, tb): # noqa: D401
- await self._client.aclose()
+ if self._client is not None:
+ try:
+ await self._client.aclose()
+ except Exception as close_exc: # pragma: no cover - best effort
+ logger.debug("Failed to close AGW httpx.AsyncClient: %s", close_exc)
+ finally:
+ self._client = None
+ self._client_loop = None
+ # stop the background loop if we started one
+ if self._bg_loop is not None:
+ try:
+ self._bg_loop.call_soon_threadsafe(self._bg_loop.stop)
+ finally:
+ self._bg_loop = None
+ self._bg_thread = None

  # ----------------------- universal dump/load ---------------------
- # def _safe_dump(self, obj: Any) -> Any:
- # """self.serde.dump → a guaranteed JSON string."""
- # dumped = self.serde.dumps(obj)
- # if isinstance(dumped, (bytes, bytearray)):
- # return base64.b64encode(dumped).decode() # str
- # return dumped # already JSON-compatible
-
  def _safe_dump(self, obj: Any) -> Any:
- """bytes → Python object; base64 fallback for genuinely binary data."""
- dumped = self.serde.dumps(obj)
+ """
+ JSON-first serialization:
+ 1) Try self.serde.dumps(obj).
+ - If it returns bytes: try to decode into a JSON string and parse it.
+ - If it is not JSON / not UTF-8, wrap it as a base64 string.
+ - If it returns a dict/list/scalar, it is already JSON-compatible.
+ 2) If self.serde.dumps(obj) raises (e.g. for Send),
+ fall back to the typed wrapper {"type": str, "blob": base64 | None}.
+ """
+ try:
+ dumped = self.serde.dumps(obj)
+ except Exception:
+ # typed fallback (as LangGraph recommends for non-trivial types)
+ # https://langchain-ai.github.io/langgraph/reference/checkpoints/
+ try:
+ t, b = self.serde.dumps_typed(obj)
+ except Exception:
+ # last resort: the string representation
+ t, b = type(obj).__name__, str(obj).encode()
+ return {"type": t, "blob": base64.b64encode(b).decode() if b is not None else None}
+
  if isinstance(dumped, (bytes, bytearray)):
  try:
- # 1) bytes → str
  s = dumped.decode()
- # 2) str JSON → python (list/dict/scalar)
  return orjson.loads(s)
  except (UnicodeDecodeError, orjson.JSONDecodeError):
- # not UTF-8 or not JSON → base64
  return base64.b64encode(dumped).decode()
  return dumped

  def _safe_load(self, obj: Any) -> Any:
- if isinstance(obj, (dict, list)): # already-unpacked JSON
- return self.serde.loads(orjson.dumps(obj))
+ if obj is None:
+ return None
+
+ if isinstance(obj, dict):
+ if all(k in obj for k in TYPED_KEYS):
+ t = obj.get("type")
+ blob = obj.get("blob")
+ if blob is None:
+ try:
+ return self.serde.loads_typed((t, None))
+ except Exception:
+ return obj
+ if isinstance(blob, str):
+ payload = _b64decode_strict(blob)
+ if payload is not None:
+ try:
+ return self.serde.loads_typed((t, payload))
+ except Exception:
+ # fall back to generic handling below
+ pass
+ try:
+ return self.serde.loads(orjson.dumps(obj))
+ except Exception:
+ return obj
+
+ if isinstance(obj, (list, tuple)):
+ if (
+ len(obj) == 2
+ and isinstance(obj[0], str)
+ and (obj[1] is None or isinstance(obj[1], str))
+ ):
+ blob = obj[1]
+ if blob is None:
+ try:
+ return self.serde.loads_typed((obj[0], None))
+ except Exception:
+ pass
+ elif isinstance(blob, str):
+ payload = _b64decode_strict(blob)
+ if payload is not None:
+ try:
+ return self.serde.loads_typed((obj[0], payload))
+ except Exception:
+ pass
+ try:
+ return self.serde.loads(orjson.dumps(list(obj)))
+ except Exception:
+ return obj
+
  if isinstance(obj, str):
- # a plain JSON string first
  try:
  return self.serde.loads(obj.encode())
  except Exception:
- # possibly base64
- try:
- return self.serde.loads(base64.b64decode(obj))
- except Exception:
- return obj
+ payload = _b64decode_strict(obj)
+ if payload is not None:
+ try:
+ return self.serde.loads(payload)
+ except Exception:
+ pass
+ return obj
+
  try:
  return self.serde.loads(obj)
  except Exception:
  return obj

- # def _safe_load(self, obj: Any) -> Any:
- # """Inverse of _safe_dump."""
- # if isinstance(obj, str):
- # try:
- # return self.serde.load(base64.b64decode(obj))
- # except Exception:
- # # not base64, just a plain string
- # return self.serde.load(obj)
- # return self.serde.load(obj)
+ # ----------------------- deep dump/load (leaf-first) -------------
+ @staticmethod
+ def _is_json_scalar(x: Any) -> bool:
+ return x is None or isinstance(x, (str, int, float, bool))
+
+ @staticmethod
+ def _coerce_key(k: Any) -> str:
+ return k if isinstance(k, str) else str(k)
+
+ def _safe_dump_deep(self, obj: Any, _seen: set[int] | None = None) -> Any:
+ """
+ Walk from the leaves up to the root:
+ - For containers, recurse inside and preserve the container's shape.
+ - For leaves, call _safe_dump (serde/typed fallback + base64).
+ """
+ if _seen is None:
+ _seen = set()
+
+ if self._is_json_scalar(obj):
+ return obj
+
+ if isinstance(obj, dict):
+ oid = id(obj)
+ if oid in _seen:
+ return {"type": "Cycle", "blob": None}
+ _seen.add(oid)
+ return {
+ self._coerce_key(k): self._safe_dump_deep(v, _seen) for k, v in obj.items()
+ }
+
+ if isinstance(obj, (list, tuple, set)):
+ oid = id(obj)
+ if oid in _seen:
+ return ["<cycle>"]
+ _seen.add(oid)
+ return [self._safe_dump_deep(v, _seen) for v in obj]
+
+ # leaf: delegate to the universal dumper
+ return self._safe_dump(obj)
+
+ def _safe_load_deep(self, obj: Any, _seen: set[int] | None = None) -> Any:
+ """
+ Inverse operation:
+ - For containers, first try to feed the whole thing to serde.loads(...).
+ If the result is NOT a JSON container (e.g. a message object), return it.
+ Otherwise recurse inside and feed the leaves to _safe_load.
+ - The typed {"type","blob"} wrapper is handled as before.
+ """
+ if _seen is None:
+ _seen = set()
+
+ # Primitives: just go through _safe_load (base64/bytes decoding, etc.)
+ if self._is_json_scalar(obj):
+ return self._safe_load(obj)
+
+ # dict
+ if isinstance(obj, dict):
+ # typed wrapper: unwrap it right away
+ if all(k in obj for k in TYPED_KEYS):
+ return self._safe_load(obj)
+
+ # 1) parse-first: try to restore the whole object via serde
+ try:
+ parsed = self.serde.loads(orjson.dumps(obj))
+ # if we got a non-JSON container (an object), return it
+ if not isinstance(parsed, (dict, list, tuple, str, int, float, bool, type(None))):
+ return parsed
+ except Exception:
+ pass
+
+ # 2) otherwise, recurse
+ oid = id(obj)
+ if oid in _seen:
+ return obj
+ _seen.add(oid)
+ return {k: self._safe_load_deep(v, _seen) for k, v in obj.items()}
+
+ # list
+ if isinstance(obj, list):
+ # parse-first: try to restore the whole list in one pass
+ try:
+ parsed = self.serde.loads(orjson.dumps(obj))
+ if not isinstance(parsed, (dict, list, tuple, str, int, float, bool, type(None))):
+ return parsed
+ except Exception:
+ pass
+
+ oid = id(obj)
+ if oid in _seen:
+ return obj
+ _seen.add(oid)
+ return [self._safe_load_deep(v, _seen) for v in obj]
+
+ # tuple: same as list, but return a list (JSON-compatible)
+ if isinstance(obj, tuple):
+ try:
+ parsed = self.serde.loads(orjson.dumps(obj))
+ if not isinstance(parsed, (dict, list, tuple, str, int, float, bool, type(None))):
+ return parsed
+ except Exception:
+ pass
+
+ oid = id(obj)
+ if oid in _seen:
+ return obj
+ _seen.add(oid)
+ return [self._safe_load_deep(v, _seen) for v in obj]
+
+ # everything else is a leaf: go through _safe_load
+ return self._safe_load(obj)

  # ----------------------- config <-> api --------------------------
  def _to_api_config(self, cfg: RunnableConfig | None) -> Dict[str, Any]:
@@ -152,24 +471,44 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
  "v": cp["v"],
  "id": cp["id"],
  "ts": cp["ts"],
- "channelValues": {k: self._safe_dump(v) for k, v in cp["channel_values"].items()},
+ "channelValues": {
+ k: self._safe_dump_deep(v) for k, v in cp["channel_values"].items()
+ },
  "channelVersions": cp["channel_versions"],
  "versionsSeen": cp["versions_seen"],
- "pendingSends": cp.get("pending_sends", []),
+ "pendingSends": [] # as in BasePostgresSaver, these are not needed inside the checkpoint
  }

  def _decode_cp(self, raw: Dict[str, Any]) -> Checkpoint:
+ # Accept pendingSends if the server returns them,
+ # but never send them ourselves on write.
+ pending_sends: list[Tuple[str, Any]] = []
+ for obj in raw.get("pendingSends", []) or []:
+ if isinstance(obj, dict) and "channel" in obj:
+ channel = obj["channel"]
+ value_payload: Any = obj.get("value")
+ if value_payload is None and all(k in obj for k in TYPED_KEYS):
+ value_payload = {k: obj[k] for k in TYPED_KEYS}
+ pending_sends.append((channel, self._safe_load_deep(value_payload)))
+ elif isinstance(obj, (list, tuple)) and len(obj) >= 2:
+ channel = obj[0]
+ value_payload = obj[1]
+ pending_sends.append((channel, self._safe_load_deep(value_payload)))
+ else:
+ pending_sends.append(obj)
  return Checkpoint(
  v=raw["v"],
  id=raw["id"],
  ts=raw["ts"],
- channel_values={k: self._safe_load(v) for k, v in raw["channelValues"].items()},
+ channel_values={
+ k: self._safe_load_deep(v) for k, v in raw["channelValues"].items()
+ },
  channel_versions=raw["channelVersions"],
  versions_seen=raw["versionsSeen"],
- pending_sends=raw.get("pendingSends", []),
+ pending_sends=pending_sends,
  )

- def _decode_config(self, raw: Dict[str, Any]) -> Optional[RunnableConfig]:
+ def _decode_config(self, raw: Dict[str, Any] | None) -> Optional[RunnableConfig]:
  if not raw:
  return None
  return RunnableConfig(
@@ -179,7 +518,7 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
  run_name=raw.get("run_name"),
  max_concurrency=raw.get("max_concurrency"),
  recursion_limit=raw.get("recursion_limit"),
- configurable=self._decode_configurable(raw.get("configurable"))
+ configurable=self._decode_configurable(raw.get("configurable") or {}),
  )

  def _decode_configurable(self, raw: Dict[str, Any]) -> dict[str, Any]:
@@ -192,34 +531,157 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):

  # metadata (de)ser
  def _enc_meta(self, md: CheckpointMetadata) -> CheckpointMetadata:
+ if not md:
+ return {}
  out: CheckpointMetadata = {}
  for k, v in md.items():
- out[k] = self._enc_meta(v) if isinstance(v, dict) else self._safe_dump(v) # type: ignore[assignment]
+ out[k] = self._enc_meta(v) if isinstance(v, dict) else self._safe_dump_deep(v) # type: ignore[assignment]
  return out

  def _dec_meta(self, md: Any) -> Any:
  if isinstance(md, dict):
  return {k: self._dec_meta(v) for k, v in md.items()}
- return self._safe_load(md)
+ return self._safe_load_deep(md)

  # ------------------------ HTTP wrapper ---------------------------
- async def _http(self, method: str, path: str, **kw) -> httpx.Response:
+ async def _http(
+ self,
+ method: str,
+ path: str,
+ *,
+ ok_statuses: Iterable[int] | None = None,
+ label_path: str | None = None,
+ **kw,
+ ) -> httpx.Response:
  if "json" in kw:
  payload = kw.pop("json")
  kw["data"] = orjson.dumps(payload)
- logger.info(kw["data"].decode())
-
- return await self._client.request(method, path, **kw)
+ logger.debug("AGW HTTP payload: %s", kw["data"].decode())
+
+ ok_set = set(ok_statuses) if ok_statuses is not None else set()
+ metric_path = label_path or path
+
+ attempt = 1
+ while True:
+ # the client must belong to the current event loop
+ current_loop = asyncio.get_running_loop()
+ client = self._client
+ if client is None or client.is_closed or self._client_loop is not current_loop:
+ if client is not None:
+ try:
+ await client.aclose()
+ except Exception as e:
+ logger.exception("Error while closing AGW client: %s", e)
+ client = self._create_client()
+ self._client = client
+ self._client_loop = current_loop
+ try:
+ resp = await client.request(method, path, **kw)
+ except httpx.RequestError as exc:
+ AGW_HTTP_ERROR.labels(method, metric_path).inc()
+ logger.warning(
+ "AGW request %s %s failed on attempt %d/%d: %s",
+ method,
+ path,
+ attempt,
+ self.retry_max_attempts,
+ exc,
+ )
+ if attempt >= self.retry_max_attempts:
+ AGW_HTTP_FINAL_ERROR.labels(method, metric_path).inc()
+ if self._client is not None:
+ try:
+ await self._client.aclose()
+ except Exception as close_exc: # pragma: no cover
+ logger.debug(
+ "Failed to close AGW httpx.AsyncClient: %s",
+ close_exc,
+ )
+ finally:
+ self._client = None
+ self._client_loop = None
+ raise
+
+ if self._client is not None:
+ try:
+ await self._client.aclose()
+ except Exception as close_exc: # pragma: no cover
+ logger.debug(
+ "Failed to close AGW httpx.AsyncClient: %s",
+ close_exc,
+ )
+ finally:
+ self._client = None
+ self._client_loop = None
+ delay = self._compute_retry_delay(attempt)
+ if delay > 0:
+ await asyncio.sleep(delay)
+ attempt += 1
+ continue
+
+ status = resp.status_code
+ if status < 400 or status in ok_set:
+ AGW_HTTP_SUCCESS.labels(method, metric_path).inc()
+ return resp
+
+ AGW_HTTP_ERROR.labels(method, metric_path).inc()
+ if status in (404, 406):
+ AGW_HTTP_FINAL_ERROR.labels(method, metric_path).inc()
+ return resp
+
+ if attempt >= self.retry_max_attempts:
+ AGW_HTTP_FINAL_ERROR.labels(method, metric_path).inc()
+ return resp
+
+ try:
+ await resp.aclose()
+ except Exception as exc: # pragma: no cover - best effort
+ logger.debug("Failed to close AGW httpx.Response before retry: %s", exc)
+
+ if self._client is not None:
+ try:
+ await self._client.aclose()
+ except Exception as close_exc: # pragma: no cover
+ logger.debug(
+ "Failed to close AGW httpx.AsyncClient: %s",
+ close_exc,
+ )
+ finally:
+ self._client = None
+ self._client_loop = None
+ delay = self._compute_retry_delay(attempt)
+ if delay > 0:
+ await asyncio.sleep(delay)
+ attempt += 1

  # -------------------- api -> CheckpointTuple ----------------------
  def _to_tuple(self, node: Dict[str, Any]) -> CheckpointTuple:
  pending = None
  if node.get("pendingWrites"):
- pending = [(w["first"], w["second"], self._safe_load(w["third"])) for w in node["pendingWrites"]]
+ pending = []
+ for w in node["pendingWrites"]:
+ if isinstance(w, dict):
+ first = w.get("first")
+ second = w.get("second")
+ third = w.get("third")
+ if third is None and isinstance(second, dict) and all(
+ k in second for k in TYPED_KEYS
+ ):
+ third = second
+ pending.append((first, second, self._safe_load_deep(third)))
+ elif isinstance(w, (list, tuple)):
+ if len(w) == 3:
+ first, second, third = w
+ elif len(w) == 2:
+ first, second = w
+ third = None
+ else:
+ continue
+ pending.append((first, second, self._safe_load_deep(third)))
  return CheckpointTuple(
- config=self._decode_config(node["config"]),
+ config=self._decode_config(node.get("config")),
  checkpoint=self._decode_cp(node["checkpoint"]),
- metadata=self._dec_meta(node["metadata"]),
+ metadata=self._dec_meta(node.get("metadata")),
  parent_config=self._decode_config(node.get("parentConfig")),
  pending_writes=pending,
  )
@@ -233,13 +695,15 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
  tid = api_cfg["threadId"]

  if cid:
+ path_template = "/checkpoint/{threadId}/{checkpointId}"
  path = f"/checkpoint/{tid}/{cid}"
  params = {"checkpointNs": api_cfg.get("checkpointNs", "")}
  else:
+ path_template = "/checkpoint/{threadId}"
  path = f"/checkpoint/{tid}"
  params = None

- resp = await self._http("GET", path, params=params)
+ resp = await self._http("GET", path, params=params, label_path=path_template)
  logger.debug("AGW aget_tuple response: %s", resp.text)

  if not resp.text:
@@ -263,7 +727,12 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
  "before": self._to_api_config(before) if before else None,
  "limit": limit,
  }
- resp = await self._http("POST", "/checkpoint/list", json=payload)
+ resp = await self._http(
+ "POST",
+ "/checkpoint/list",
+ json=payload,
+ label_path="/checkpoint/list",
+ )
  logger.debug("AGW alist response: %s", resp.text)
  resp.raise_for_status()
  for item in resp.json():
@@ -282,7 +751,7 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
  "metadata": self._enc_meta(get_checkpoint_metadata(cfg, metadata)),
  "newVersions": new_versions,
  }
- resp = await self._http("POST", "/checkpoint", json=payload)
+ resp = await self._http("POST", "/checkpoint", json=payload, label_path="/checkpoint")
  logger.debug("AGW aput response: %s", resp.text)
  resp.raise_for_status()
  return resp.json()["config"]
@@ -294,26 +763,38 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
  task_id: str,
  task_path: str = "",
  ) -> None:
- enc = [{"first": ch, "second": self._safe_dump(v)} for ch, v in writes]
+ enc = [{"first": ch, "second": self._safe_dump_deep(v)} for ch, v in writes]
  payload = {
  "config": self._to_api_config(cfg),
  "writes": enc,
  "taskId": task_id,
  "taskPath": task_path,
  }
- resp = await self._http("POST", "/checkpoint/writes", json=payload)
+ resp = await self._http(
+ "POST",
+ "/checkpoint/writes",
+ json=payload,
+ label_path="/checkpoint/writes",
+ )
  logger.debug("AGW aput_writes response: %s", resp.text)
  resp.raise_for_status()

  async def adelete_thread(self, thread_id: str) -> None:
- resp = await self._http("DELETE", f"/checkpoint/{thread_id}")
+ resp = await self._http(
+ "DELETE",
+ f"/checkpoint/{thread_id}",
+ label_path="/checkpoint/{threadId}",
+ )
  resp.raise_for_status()

  # =================================================================
  # sync wrappers
  # =================================================================
  def _run(self, coro):
- return asyncio.run_coroutine_threadsafe(coro, self.loop).result()
+ # sync wrappers always run on a dedicated loop in a separate thread
+ loop = self._ensure_bg_loop()
+ fut = asyncio.run_coroutine_threadsafe(coro, loop)
+ return fut.result()

  def list(
  self,
@@ -323,12 +804,14 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
  before: RunnableConfig | None = None,
  limit: int | None = None,
  ) -> Iterator[CheckpointTuple]:
- aiter_ = self.alist(cfg, filter=filter, before=before, limit=limit)
- while True:
- try:
- yield self._run(anext(aiter_))
- except StopAsyncIteration:
- break
+ async def _collect():
+ out = []
+ async for item in self.alist(cfg, filter=filter, before=before, limit=limit):
+ out.append(item)
+ return out
+
+ for item in self._run(_collect()):
+ yield item

  def get_tuple(self, cfg: RunnableConfig) -> CheckpointTuple | None:
  return self._run(self.aget_tuple(cfg))
@@ -381,4 +864,10 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
  try:
  yield saver
  finally:
- await saver._client.aclose()
+ if saver._client is not None:
+ try:
+ await saver._client.aclose()
+ except Exception as close_exc: # pragma: no cover - best effort
+ logger.debug("Failed to close AGW httpx.AsyncClient: %s", close_exc)
+ finally:
+ saver._client = None
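
Illustrative usage (editor's sketch, not part of the published diff). The snippet below exercises the configuration surface added in 0.1.49: the AGW_HTTP_* retry variables, the optional AES key that enables EncryptedSerializer, and the async context-manager lifecycle. Environment variable names, defaults, and method names are taken from the diff above; the import path, the exact constructor keywords, and the thread_id config convention are assumptions.

import asyncio
import os

# Import path assumed for illustration; the diff shows only the module body.
from agent_lab_sdk.checkpoint import AsyncAGWCheckpointSaver

# Retry behaviour is read from the environment when the saver is constructed
# (defaults per the diff: 3 attempts, 0.5 s backoff base, 5.0 s cap, 0.25 s jitter).
os.environ.setdefault("AGW_HTTP_MAX_RETRIES", "5")
os.environ.setdefault("AGW_HTTP_RETRY_BACKOFF_BASE", "1.0")
# Setting LANGGRAPH_AES_KEY / AGW_AES_KEY / AES_KEY switches the default
# serializer to an EncryptedSerializer.from_pycryptodome_aes(...) wrapper.


async def main() -> None:
    # base_url, api_key and timeout are parameters referenced in the diff;
    # the concrete values here are placeholders.
    async with AsyncAGWCheckpointSaver(
        base_url="https://agw.example.internal",
        api_key=os.getenv("AGW_API_KEY"),
        timeout=30.0,
    ) as saver:
        # Usual LangGraph-style config (assumed convention for _to_api_config).
        cfg = {"configurable": {"thread_id": "demo-thread"}}
        print(await saver.aget_tuple(cfg))


asyncio.run(main())

The synchronous wrappers (get_tuple, list, put, put_writes, delete_thread) delegate to the same coroutines via _run(), which per the diff schedules them on a dedicated background event loop in a daemon thread rather than on the caller's loop.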