agent-lab-sdk 0.1.36__tar.gz → 0.1.38__tar.gz

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of agent-lab-sdk might be problematic. Click here for more details.

Files changed (27)
  1. {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/PKG-INFO +1 -1
  2. agent_lab_sdk-0.1.38/agent_lab_sdk/langgraph/checkpoint/agw_saver.py +706 -0
  3. {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk.egg-info/PKG-INFO +1 -1
  4. {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/pyproject.toml +1 -1
  5. agent_lab_sdk-0.1.36/agent_lab_sdk/langgraph/checkpoint/agw_saver.py +0 -444
  6. {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/LICENSE +0 -0
  7. {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/README.md +0 -0
  8. {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk/__init__.py +0 -0
  9. {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk/langgraph/checkpoint/__init__.py +0 -0
  10. {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk/langgraph/checkpoint/serde.py +0 -0
  11. {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk/llm/__init__.py +0 -0
  12. {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk/llm/agw_token_manager.py +0 -0
  13. {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk/llm/gigachat_token_manager.py +0 -0
  14. {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk/llm/llm.py +0 -0
  15. {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk/llm/throttled.py +0 -0
  16. {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk/metrics/__init__.py +0 -0
  17. {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk/metrics/metrics.py +0 -0
  18. {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk/schema/__init__.py +0 -0
  19. {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk/schema/input_types.py +0 -0
  20. {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk/schema/log_message.py +0 -0
  21. {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk/storage/__init__.py +0 -0
  22. {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk/storage/storage.py +0 -0
  23. {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk.egg-info/SOURCES.txt +0 -0
  24. {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk.egg-info/dependency_links.txt +0 -0
  25. {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk.egg-info/requires.txt +0 -0
  26. {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/agent_lab_sdk.egg-info/top_level.txt +0 -0
  27. {agent_lab_sdk-0.1.36 → agent_lab_sdk-0.1.38}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: agent-lab-sdk
3
- Version: 0.1.36
3
+ Version: 0.1.38
4
4
  Summary: SDK для работы с Agent Lab
5
5
  Author-email: Andrew Ohurtsov <andermirik@yandex.com>
6
6
  License: Proprietary and Confidential — All Rights Reserved
@@ -0,0 +1,706 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import base64
5
+ import logging
6
+ import os
7
+ from contextlib import asynccontextmanager
8
+ from random import random
9
+ from typing import Any, AsyncIterator, Dict, Iterable, Iterator, Optional, Sequence, Tuple
10
+
11
+ import orjson
12
+ from langgraph.checkpoint.serde.types import ChannelProtocol
13
+
14
+ import httpx
15
+ from langchain_core.runnables import RunnableConfig
16
+
17
+ from langgraph.checkpoint.base import (
18
+ BaseCheckpointSaver,
19
+ ChannelVersions,
20
+ Checkpoint,
21
+ CheckpointMetadata,
22
+ CheckpointTuple,
23
+ get_checkpoint_id,
24
+ get_checkpoint_metadata,
25
+ )
26
+ from langgraph.checkpoint.serde.base import SerializerProtocol
27
+ from langgraph.checkpoint.serde.encrypted import EncryptedSerializer
28
+
29
+ from .serde import Serializer
30
+ from agent_lab_sdk.metrics import get_metric
31
+
32
+ __all__ = ["AsyncAGWCheckpointSaver"]
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+ AGW_METRIC_LABELS = ["method", "endpoint"]
37
+ AGW_HTTP_SUCCESS = get_metric(
38
+ "counter",
39
+ "agw_http_success_total",
40
+ "Number of successful AGW HTTP requests",
41
+ labelnames=AGW_METRIC_LABELS,
42
+ )
43
+ AGW_HTTP_ERROR = get_metric(
44
+ "counter",
45
+ "agw_http_error_total",
46
+ "Number of failed AGW HTTP request attempts",
47
+ labelnames=AGW_METRIC_LABELS,
48
+ )
49
+ AGW_HTTP_FINAL_ERROR = get_metric(
50
+ "counter",
51
+ "agw_http_final_error_total",
52
+ "Number of AGW HTTP requests that failed after retries",
53
+ labelnames=AGW_METRIC_LABELS,
54
+ )
55
+
56
+ TYPED_KEYS = ("type", "blob")
57
+
58
+
59
+ def _b64decode_strict(value: str) -> bytes | None:
60
+ try:
61
+ return base64.b64decode(value, validate=True)
62
+ except Exception:
63
+ return None
64
+
65
+ # ------------------------------------------------------------------ #
66
+ # helpers for Py < 3.10
67
+ # ------------------------------------------------------------------ #
68
+ try:
69
+ anext # type: ignore[name-defined]
70
+ except NameError: # pragma: no cover
71
+
72
+ async def anext(it):
73
+ return await it.__anext__()
74
+
75
+
76
+ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
77
+ """Persist checkpoints in Agent-Gateway с помощью `httpx` async client."""
78
+
79
+ # ---------------------------- init / ctx -------------------------
80
+ def __init__(
81
+ self,
82
+ base_url: str = "http://localhost",
83
+ *,
84
+ serde: SerializerProtocol | None = None,
85
+ timeout: int | float = 10,
86
+ api_key: str | None = None,
87
+ extra_headers: Dict[str, str] | None = None,
88
+ verify: bool = True,
89
+ ):
90
+ if not serde:
91
+ base_serde: SerializerProtocol = Serializer()
92
+ aes_key = (
93
+ os.getenv("LANGGRAPH_AES_KEY")
94
+ or os.getenv("AGW_AES_KEY")
95
+ or os.getenv("AES_KEY")
96
+ )
97
+ if aes_key:
98
+ base_serde = EncryptedSerializer.from_pycryptodome_aes(
99
+ base_serde, key=aes_key
100
+ )
101
+ serde = base_serde
102
+ super().__init__(serde=serde)
103
+ self.base_url = base_url.rstrip("/")
104
+ self.timeout = timeout
105
+ self.loop = asyncio.get_running_loop()
106
+
107
+ raw_attempts = os.getenv("AGW_HTTP_MAX_RETRIES")
108
+ if raw_attempts is None:
109
+ self.retry_max_attempts = 3
110
+ else:
111
+ try:
112
+ self.retry_max_attempts = max(int(raw_attempts), 1)
113
+ except ValueError:
114
+ logger.warning(
115
+ "Env %s expected int, got %r; using default %s",
116
+ "AGW_HTTP_MAX_RETRIES",
117
+ raw_attempts,
118
+ 3,
119
+ )
120
+ self.retry_max_attempts = 3
121
+
122
+ raw_backoff_base = os.getenv("AGW_HTTP_RETRY_BACKOFF_BASE")
123
+ if raw_backoff_base is None:
124
+ self.retry_backoff_base = 0.5
125
+ else:
126
+ try:
127
+ self.retry_backoff_base = max(float(raw_backoff_base), 0.0)
128
+ except ValueError:
129
+ logger.warning(
130
+ "Env %s expected float, got %r; using default %.3f",
131
+ "AGW_HTTP_RETRY_BACKOFF_BASE",
132
+ raw_backoff_base,
133
+ 0.5,
134
+ )
135
+ self.retry_backoff_base = 0.5
136
+
137
+ raw_backoff_max = os.getenv("AGW_HTTP_RETRY_BACKOFF_MAX")
138
+ if raw_backoff_max is None:
139
+ self.retry_backoff_max = 5.0
140
+ else:
141
+ try:
142
+ self.retry_backoff_max = max(float(raw_backoff_max), 0.0)
143
+ except ValueError:
144
+ logger.warning(
145
+ "Env %s expected float, got %r; using default %.3f",
146
+ "AGW_HTTP_RETRY_BACKOFF_MAX",
147
+ raw_backoff_max,
148
+ 5.0,
149
+ )
150
+ self.retry_backoff_max = 5.0
151
+
152
+ raw_jitter = os.getenv("AGW_HTTP_RETRY_JITTER")
153
+ if raw_jitter is None:
154
+ self.retry_jitter = 0.25
155
+ else:
156
+ try:
157
+ self.retry_jitter = max(float(raw_jitter), 0.0)
158
+ except ValueError:
159
+ logger.warning(
160
+ "Env %s expected float, got %r; using default %.3f",
161
+ "AGW_HTTP_RETRY_JITTER",
162
+ raw_jitter,
163
+ 0.25,
164
+ )
165
+ self.retry_jitter = 0.25
166
+
167
+ self.headers: Dict[str, str] = {
168
+ "Accept": "application/json",
169
+ "Content-Type": "application/json",
170
+ }
171
+ if extra_headers:
172
+ self.headers.update(extra_headers)
173
+ if api_key:
174
+ self.headers["Authorization"] = f"Bearer {api_key}"
175
+
176
+ self._verify = verify
177
+ self._client: httpx.AsyncClient | None = None
178
+
179
+ def _create_client(self) -> httpx.AsyncClient:
180
+ return httpx.AsyncClient(
181
+ base_url=self.base_url,
182
+ headers=self.headers,
183
+ timeout=self.timeout,
184
+ verify=self._verify,
185
+ trust_env=True,
186
+ )
187
+
188
+ def _ensure_client(self) -> httpx.AsyncClient:
189
+ client = self._client
190
+ if client is None or client.is_closed:
191
+ if client is not None and client.is_closed:
192
+ logger.debug("Recreating closed httpx.AsyncClient for AGW")
193
+ client = self._create_client()
194
+ self._client = client
195
+ return client
196
+
197
+ def _compute_retry_delay(self, attempt: int) -> float:
198
+ if attempt <= 0:
199
+ attempt = 1
200
+ if self.retry_backoff_base <= 0:
201
+ delay = 0.0
202
+ else:
203
+ delay = self.retry_backoff_base * (2 ** (attempt - 1))
204
+ if self.retry_backoff_max > 0:
205
+ delay = min(delay, self.retry_backoff_max)
206
+ if self.retry_jitter > 0:
207
+ delay += self.retry_jitter * random()
208
+ return delay
209
+
210
+ async def __aenter__(self): # noqa: D401
211
+ return self
212
+
213
+ async def __aexit__(self, exc_type, exc, tb): # noqa: D401
214
+ if self._client is not None:
215
+ try:
216
+ await self._client.aclose()
217
+ except Exception as close_exc: # pragma: no cover - best effort
218
+ logger.debug("Failed to close AGW httpx.AsyncClient: %s", close_exc)
219
+ finally:
220
+ self._client = None
221
+
222
+ # ----------------------- universal dump/load ---------------------
223
+ # def _safe_dump(self, obj: Any) -> Any:
224
+ # """self.serde.dump → гарантированная JSON-строка."""
225
+ # dumped = self.serde.dumps(obj)
226
+ # if isinstance(dumped, (bytes, bytearray)):
227
+ # return base64.b64encode(dumped).decode() # str
228
+ # return dumped # уже json-совместимо
229
+
230
+ def _safe_dump(self, obj: Any) -> Any:
231
+ """bytes → python-object; fallback base64 для реально бинарных данных."""
232
+ dumped = self.serde.dumps(obj)
233
+ if isinstance(dumped, (bytes, bytearray)):
234
+ try:
235
+ # 1) bytes → str
236
+ s = dumped.decode()
237
+ # 2) str JSON → python (list/dict/scalar)
238
+ return orjson.loads(s)
239
+ except (UnicodeDecodeError, orjson.JSONDecodeError):
240
+ # не UTF-8 или не JSON → base64
241
+ return base64.b64encode(dumped).decode()
242
+ return dumped
243
+
244
+ def _safe_load(self, obj: Any) -> Any:
245
+ if obj is None:
246
+ return None
247
+
248
+ if isinstance(obj, dict):
249
+ if all(k in obj for k in TYPED_KEYS):
250
+ t = obj.get("type")
251
+ blob = obj.get("blob")
252
+ if blob is None:
253
+ try:
254
+ return self.serde.loads_typed((t, None))
255
+ except Exception:
256
+ return obj
257
+ if isinstance(blob, str):
258
+ payload = _b64decode_strict(blob)
259
+ if payload is not None:
260
+ try:
261
+ return self.serde.loads_typed((t, payload))
262
+ except Exception:
263
+ # fall back to generic handling below
264
+ pass
265
+ try:
266
+ return self.serde.loads(orjson.dumps(obj))
267
+ except Exception:
268
+ return obj
269
+
270
+ if isinstance(obj, (list, tuple)):
271
+ if (
272
+ len(obj) == 2
273
+ and isinstance(obj[0], str)
274
+ and (obj[1] is None or isinstance(obj[1], str))
275
+ ):
276
+ blob = obj[1]
277
+ if blob is None:
278
+ try:
279
+ return self.serde.loads_typed((obj[0], None))
280
+ except Exception:
281
+ pass
282
+ elif isinstance(blob, str):
283
+ payload = _b64decode_strict(blob)
284
+ if payload is not None:
285
+ try:
286
+ return self.serde.loads_typed((obj[0], payload))
287
+ except Exception:
288
+ pass
289
+ try:
290
+ return self.serde.loads(orjson.dumps(list(obj)))
291
+ except Exception:
292
+ return obj
293
+
294
+ if isinstance(obj, str):
295
+ try:
296
+ return self.serde.loads(obj.encode())
297
+ except Exception:
298
+ payload = _b64decode_strict(obj)
299
+ if payload is not None:
300
+ try:
301
+ return self.serde.loads(payload)
302
+ except Exception:
303
+ pass
304
+ return obj
305
+
306
+ try:
307
+ return self.serde.loads(obj)
308
+ except Exception:
309
+ return obj
310
+
311
+ # def _safe_load(self, obj: Any) -> Any:
312
+ # """Обратная операция к _safe_dump."""
313
+ # if isinstance(obj, str):
314
+ # try:
315
+ # return self.serde.load(base64.b64decode(obj))
316
+ # except Exception:
317
+ # # не base64 — обычная строка
318
+ # return self.serde.load(obj)
319
+ # return self.serde.load(obj)
320
+
321
+ # ----------------------- config <-> api --------------------------
322
+ def _to_api_config(self, cfg: RunnableConfig | None) -> Dict[str, Any]:
323
+ if not cfg:
324
+ return {}
325
+ c = cfg.get("configurable", {})
326
+ res: Dict[str, Any] = {
327
+ "threadId": c.get("thread_id", ""),
328
+ "checkpointNs": c.get("checkpoint_ns", ""),
329
+ }
330
+ if cid := c.get("checkpoint_id"):
331
+ res["checkpointId"] = cid
332
+ if ts := c.get("thread_ts"):
333
+ res["threadTs"] = ts
334
+ return res
335
+
336
+ # --------------------- checkpoint (de)ser ------------------------
337
+ def _encode_cp(self, cp: Checkpoint) -> Dict[str, Any]:
338
+ pending: list[Any] = []
339
+ for item in cp.get("pending_sends", []) or []:
340
+ try:
341
+ channel, value = item
342
+ except Exception:
343
+ pending.append(item)
344
+ continue
345
+ pending.append([channel, self._safe_dump(value)])
346
+ return {
347
+ "v": cp["v"],
348
+ "id": cp["id"],
349
+ "ts": cp["ts"],
350
+ "channelValues": {k: self._safe_dump(v) for k, v in cp["channel_values"].items()},
351
+ "channelVersions": cp["channel_versions"],
352
+ "versionsSeen": cp["versions_seen"],
353
+ "pendingSends": pending,
354
+ }
355
+
356
+ def _decode_cp(self, raw: Dict[str, Any]) -> Checkpoint:
357
+ pending_sends: list[Tuple[str, Any]] = []
358
+ for obj in raw.get("pendingSends", []) or []:
359
+ if isinstance(obj, dict) and "channel" in obj:
360
+ channel = obj["channel"]
361
+ value_payload: Any = obj.get("value")
362
+ if value_payload is None and all(k in obj for k in TYPED_KEYS):
363
+ value_payload = {k: obj[k] for k in TYPED_KEYS}
364
+ pending_sends.append((channel, self._safe_load(value_payload)))
365
+ elif isinstance(obj, (list, tuple)) and len(obj) >= 2:
366
+ channel = obj[0]
367
+ value_payload = obj[1]
368
+ pending_sends.append((channel, self._safe_load(value_payload)))
369
+ else:
370
+ pending_sends.append(obj) # сохраняем как есть, если формат неизвестен
371
+ return Checkpoint(
372
+ v=raw["v"],
373
+ id=raw["id"],
374
+ ts=raw["ts"],
375
+ channel_values={k: self._safe_load(v) for k, v in raw["channelValues"].items()},
376
+ channel_versions=raw["channelVersions"],
377
+ versions_seen=raw["versionsSeen"],
378
+ pending_sends=pending_sends,
379
+ )
380
+
381
+ def _decode_config(self, raw: Dict[str, Any] | None) -> Optional[RunnableConfig]:
382
+ if not raw:
383
+ return None
384
+ return RunnableConfig(
385
+ tags=raw.get("tags"),
386
+ metadata=raw.get("metadata"),
387
+ callbacks=raw.get("callbacks"),
388
+ run_name=raw.get("run_name"),
389
+ max_concurrency=raw.get("max_concurrency"),
390
+ recursion_limit=raw.get("recursion_limit"),
391
+ configurable=self._decode_configurable(raw.get("configurable") or {}),
392
+ )
393
+
394
+ def _decode_configurable(self, raw: Dict[str, Any]) -> dict[str, Any]:
395
+ return {
396
+ "thread_id": raw.get("threadId"),
397
+ "thread_ts": raw.get("threadTs"),
398
+ "checkpoint_ns": raw.get("checkpointNs"),
399
+ "checkpoint_id": raw.get("checkpointId")
400
+ }
401
+
402
+ # metadata (de)ser
403
+ def _enc_meta(self, md: CheckpointMetadata) -> CheckpointMetadata:
404
+ if not md:
405
+ return {}
406
+ out: CheckpointMetadata = {}
407
+ for k, v in md.items():
408
+ out[k] = self._enc_meta(v) if isinstance(v, dict) else self._safe_dump(v) # type: ignore[assignment]
409
+ return out
410
+
411
+ def _dec_meta(self, md: Any) -> Any:
412
+ if isinstance(md, dict):
413
+ return {k: self._dec_meta(v) for k, v in md.items()}
414
+ return self._safe_load(md)
415
+
416
+ # ------------------------ HTTP wrapper ---------------------------
417
+ async def _http(
418
+ self,
419
+ method: str,
420
+ path: str,
421
+ *,
422
+ ok_statuses: Iterable[int] | None = None,
423
+ **kw,
424
+ ) -> httpx.Response:
425
+ if "json" in kw:
426
+ payload = kw.pop("json")
427
+ kw["data"] = orjson.dumps(payload)
428
+ logger.debug("AGW HTTP payload: %s", kw["data"].decode())
429
+
430
+ ok_set = set(ok_statuses) if ok_statuses is not None else set()
431
+
432
+ attempt = 1
433
+ while True:
434
+ client = self._ensure_client()
435
+ try:
436
+ resp = await client.request(method, path, **kw)
437
+ except httpx.RequestError as exc:
438
+ AGW_HTTP_ERROR.labels(method, path).inc()
439
+ logger.warning(
440
+ "AGW request %s %s failed on attempt %d/%d: %s",
441
+ method,
442
+ path,
443
+ attempt,
444
+ self.retry_max_attempts,
445
+ exc,
446
+ )
447
+ if attempt >= self.retry_max_attempts:
448
+ AGW_HTTP_FINAL_ERROR.labels(method, path).inc()
449
+ if self._client is not None:
450
+ try:
451
+ await self._client.aclose()
452
+ except Exception as close_exc: # pragma: no cover
453
+ logger.debug(
454
+ "Failed to close AGW httpx.AsyncClient: %s",
455
+ close_exc,
456
+ )
457
+ finally:
458
+ self._client = None
459
+ raise
460
+
461
+ if self._client is not None:
462
+ try:
463
+ await self._client.aclose()
464
+ except Exception as close_exc: # pragma: no cover
465
+ logger.debug(
466
+ "Failed to close AGW httpx.AsyncClient: %s",
467
+ close_exc,
468
+ )
469
+ finally:
470
+ self._client = None
471
+ delay = self._compute_retry_delay(attempt)
472
+ if delay > 0:
473
+ await asyncio.sleep(delay)
474
+ attempt += 1
475
+ continue
476
+
477
+ status = resp.status_code
478
+ if status < 400 or status in ok_set:
479
+ AGW_HTTP_SUCCESS.labels(method, path).inc()
480
+ return resp
481
+
482
+ AGW_HTTP_ERROR.labels(method, path).inc()
483
+ if status in (404, 406):
484
+ AGW_HTTP_FINAL_ERROR.labels(method, path).inc()
485
+ return resp
486
+
487
+ if attempt >= self.retry_max_attempts:
488
+ AGW_HTTP_FINAL_ERROR.labels(method, path).inc()
489
+ return resp
490
+
491
+ try:
492
+ await resp.aclose()
493
+ except Exception as exc: # pragma: no cover - best effort
494
+ logger.debug("Failed to close AGW httpx.Response before retry: %s", exc)
495
+
496
+ if self._client is not None:
497
+ try:
498
+ await self._client.aclose()
499
+ except Exception as close_exc: # pragma: no cover
500
+ logger.debug(
501
+ "Failed to close AGW httpx.AsyncClient: %s",
502
+ close_exc,
503
+ )
504
+ finally:
505
+ self._client = None
506
+ delay = self._compute_retry_delay(attempt)
507
+ if delay > 0:
508
+ await asyncio.sleep(delay)
509
+ attempt += 1
510
+
511
+ # -------------------- api -> CheckpointTuple ----------------------
512
+ def _to_tuple(self, node: Dict[str, Any]) -> CheckpointTuple:
513
+ pending = None
514
+ if node.get("pendingWrites"):
515
+ pending = []
516
+ for w in node["pendingWrites"]:
517
+ if isinstance(w, dict):
518
+ first = w.get("first")
519
+ second = w.get("second")
520
+ third = w.get("third")
521
+ if third is None and isinstance(second, dict) and all(
522
+ k in second for k in TYPED_KEYS
523
+ ):
524
+ third = second
525
+ pending.append((first, second, self._safe_load(third)))
526
+ elif isinstance(w, (list, tuple)):
527
+ if len(w) == 3:
528
+ first, second, third = w
529
+ elif len(w) == 2:
530
+ first, second = w
531
+ third = None
532
+ else:
533
+ continue
534
+ pending.append((first, second, self._safe_load(third)))
535
+ return CheckpointTuple(
536
+ config=self._decode_config(node.get("config")),
537
+ checkpoint=self._decode_cp(node["checkpoint"]),
538
+ metadata=self._dec_meta(node.get("metadata")),
539
+ parent_config=self._decode_config(node.get("parentConfig")),
540
+ pending_writes=pending,
541
+ )
542
+
543
+ # =================================================================
544
+ # async-методы BaseCheckpointSaver
545
+ # =================================================================
546
+ async def aget_tuple(self, cfg: RunnableConfig) -> CheckpointTuple | None:
547
+ cid = get_checkpoint_id(cfg)
548
+ api_cfg = self._to_api_config(cfg)
549
+ tid = api_cfg["threadId"]
550
+
551
+ if cid:
552
+ path = f"/checkpoint/{tid}/{cid}"
553
+ params = {"checkpointNs": api_cfg.get("checkpointNs", "")}
554
+ else:
555
+ path = f"/checkpoint/{tid}"
556
+ params = None
557
+
558
+ resp = await self._http("GET", path, params=params)
559
+ logger.debug("AGW aget_tuple response: %s", resp.text)
560
+
561
+ if not resp.text:
562
+ return None
563
+ if resp.status_code in (404, 406):
564
+ return None
565
+ resp.raise_for_status()
566
+ return self._to_tuple(resp.json())
567
+
568
+ async def alist(
569
+ self,
570
+ cfg: RunnableConfig | None,
571
+ *,
572
+ filter: Dict[str, Any] | None = None,
573
+ before: RunnableConfig | None = None,
574
+ limit: int | None = None,
575
+ ) -> AsyncIterator[CheckpointTuple]:
576
+ payload = {
577
+ "config": self._to_api_config(cfg) if cfg else None,
578
+ "filter": filter,
579
+ "before": self._to_api_config(before) if before else None,
580
+ "limit": limit,
581
+ }
582
+ resp = await self._http("POST", "/checkpoint/list", json=payload)
583
+ logger.debug("AGW alist response: %s", resp.text)
584
+ resp.raise_for_status()
585
+ for item in resp.json():
586
+ yield self._to_tuple(item)
587
+
588
+ async def aput(
589
+ self,
590
+ cfg: RunnableConfig,
591
+ cp: Checkpoint,
592
+ metadata: CheckpointMetadata,
593
+ new_versions: ChannelVersions,
594
+ ) -> RunnableConfig:
595
+ payload = {
596
+ "config": self._to_api_config(cfg),
597
+ "checkpoint": self._encode_cp(cp),
598
+ "metadata": self._enc_meta(get_checkpoint_metadata(cfg, metadata)),
599
+ "newVersions": new_versions,
600
+ }
601
+ resp = await self._http("POST", "/checkpoint", json=payload)
602
+ logger.debug("AGW aput response: %s", resp.text)
603
+ resp.raise_for_status()
604
+ return resp.json()["config"]
605
+
606
+ async def aput_writes(
607
+ self,
608
+ cfg: RunnableConfig,
609
+ writes: Sequence[Tuple[str, Any]],
610
+ task_id: str,
611
+ task_path: str = "",
612
+ ) -> None:
613
+ enc = [{"first": ch, "second": self._safe_dump(v)} for ch, v in writes]
614
+ payload = {
615
+ "config": self._to_api_config(cfg),
616
+ "writes": enc,
617
+ "taskId": task_id,
618
+ "taskPath": task_path,
619
+ }
620
+ resp = await self._http("POST", "/checkpoint/writes", json=payload)
621
+ logger.debug("AGW aput_writes response: %s", resp.text)
622
+ resp.raise_for_status()
623
+
624
+ async def adelete_thread(self, thread_id: str) -> None:
625
+ resp = await self._http("DELETE", f"/checkpoint/{thread_id}")
626
+ resp.raise_for_status()
627
+
628
+ # =================================================================
629
+ # sync-обёртки
630
+ # =================================================================
631
+ def _run(self, coro):
632
+ return asyncio.run_coroutine_threadsafe(coro, self.loop).result()
633
+
634
+ def list(
635
+ self,
636
+ cfg: RunnableConfig | None,
637
+ *,
638
+ filter: Dict[str, Any] | None = None,
639
+ before: RunnableConfig | None = None,
640
+ limit: int | None = None,
641
+ ) -> Iterator[CheckpointTuple]:
642
+ aiter_ = self.alist(cfg, filter=filter, before=before, limit=limit)
643
+ while True:
644
+ try:
645
+ yield self._run(anext(aiter_))
646
+ except StopAsyncIteration:
647
+ break
648
+
649
+ def get_tuple(self, cfg: RunnableConfig) -> CheckpointTuple | None:
650
+ return self._run(self.aget_tuple(cfg))
651
+
652
+ def put(
653
+ self,
654
+ cfg: RunnableConfig,
655
+ cp: Checkpoint,
656
+ metadata: CheckpointMetadata,
657
+ new_versions: ChannelVersions,
658
+ ) -> RunnableConfig:
659
+ return self._run(self.aput(cfg, cp, metadata, new_versions))
660
+
661
+ def put_writes(
662
+ self,
663
+ cfg: RunnableConfig,
664
+ writes: Sequence[Tuple[str, Any]],
665
+ task_id: str,
666
+ task_path: str = "",
667
+ ) -> None:
668
+ self._run(self.aput_writes(cfg, writes, task_id, task_path))
669
+
670
+ def delete_thread(self, thread_id: str) -> None:
671
+ self._run(self.adelete_thread(thread_id))
672
+
673
+ def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> str:
674
+ if current is None:
675
+ current_v = 0
676
+ elif isinstance(current, int):
677
+ current_v = current
678
+ else:
679
+ current_v = int(current.split(".")[0])
680
+ next_v = current_v + 1
681
+ next_h = random()
682
+ return f"{next_v:032}.{next_h:016}"
683
+
684
+ # ------------------------------------------------------------------ #
685
+ # Convenience factory #
686
+ # ------------------------------------------------------------------ #
687
+ @classmethod
688
+ @asynccontextmanager
689
+ async def from_base_url(
690
+ cls,
691
+ base_url: str,
692
+ *,
693
+ api_key: str | None = None,
694
+ **kwargs: Any,
695
+ ) -> AsyncIterator["AsyncAGWCheckpointSaver"]:
696
+ saver = cls(base_url, api_key=api_key, **kwargs)
697
+ try:
698
+ yield saver
699
+ finally:
700
+ if saver._client is not None:
701
+ try:
702
+ await saver._client.aclose()
703
+ except Exception as close_exc: # pragma: no cover - best effort
704
+ logger.debug("Failed to close AGW httpx.AsyncClient: %s", close_exc)
705
+ finally:
706
+ saver._client = None
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: agent-lab-sdk
3
- Version: 0.1.36
3
+ Version: 0.1.38
4
4
  Summary: SDK для работы с Agent Lab
5
5
  Author-email: Andrew Ohurtsov <andermirik@yandex.com>
6
6
  License: Proprietary and Confidential — All Rights Reserved
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "agent-lab-sdk"
7
- version = "0.1.36"
7
+ version = "0.1.38"
8
8
  description = "SDK для работы с Agent Lab"
9
9
  readme = "README.md"
10
10
  license = { text = "Proprietary and Confidential — All Rights Reserved" }
@@ -1,444 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import asyncio
4
- import base64
5
- import logging
6
- import os
7
- from contextlib import asynccontextmanager
8
- from random import random
9
- from typing import Any, AsyncIterator, Dict, Iterator, Optional, Sequence, Tuple
10
-
11
- import httpx
12
- import orjson
13
- from langchain_core.runnables import RunnableConfig
14
-
15
- from langgraph.checkpoint.base import (
16
- BaseCheckpointSaver,
17
- ChannelVersions,
18
- Checkpoint,
19
- CheckpointMetadata,
20
- CheckpointTuple,
21
- get_checkpoint_id,
22
- get_checkpoint_metadata,
23
- )
24
- from langgraph.checkpoint.serde.base import SerializerProtocol
25
- from langgraph.checkpoint.serde.encrypted import EncryptedSerializer
26
- from langgraph.checkpoint.serde.types import ChannelProtocol
27
- from .serde import Serializer
28
-
29
- __all__ = ["AsyncAGWCheckpointSaver"]
30
-
31
- logger = logging.getLogger(__name__)
32
-
33
- TYPED_KEYS = ("type", "blob")
34
-
35
-
36
- def _to_b64(b: bytes | None) -> str | None:
37
- return base64.b64encode(b).decode() if b is not None else None
38
-
39
-
40
- def _b64decode_strict(s: str) -> bytes | None:
41
- """Возвращает bytes только если строка действительно корректная base64."""
42
- try:
43
- return base64.b64decode(s, validate=True)
44
- except Exception:
45
- return None
46
-
47
-
48
- class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
49
- """Persist checkpoints in Agent-Gateway с помощью `httpx` async client."""
50
-
51
- # ---------------------------- init / ctx -------------------------
52
- def __init__(
53
- self,
54
- base_url: str = "http://localhost",
55
- *,
56
- serde: SerializerProtocol | None = None,
57
- timeout: int | float = 10,
58
- api_key: str | None = None,
59
- extra_headers: Dict[str, str] | None = None,
60
- verify: bool = True,
61
- ):
62
- if not serde:
63
- base_serde: SerializerProtocol = Serializer()
64
- # опционально оборачиваем в AES по ENV
65
- _aes_key = (
66
- os.getenv("LANGGRAPH_AES_KEY")
67
- or os.getenv("AGW_AES_KEY")
68
- or os.getenv("AES_KEY")
69
- )
70
- if _aes_key:
71
- base_serde = EncryptedSerializer.from_pycryptodome_aes(
72
- base_serde, key=_aes_key
73
- )
74
- serde = base_serde
75
- super().__init__(serde=serde)
76
- self.base_url = base_url.rstrip("/")
77
- self.timeout = timeout
78
- self.loop = asyncio.get_running_loop()
79
-
80
- self.headers: Dict[str, str] = {
81
- "Accept": "application/json",
82
- "Content-Type": "application/json",
83
- }
84
- if extra_headers:
85
- self.headers.update(extra_headers)
86
- if api_key:
87
- self.headers["Authorization"] = f"Bearer {api_key}"
88
-
89
- self._client = httpx.AsyncClient(
90
- base_url=self.base_url,
91
- headers=self.headers,
92
- timeout=self.timeout,
93
- verify=verify,
94
- trust_env=True,
95
- )
96
-
97
- async def __aenter__(self): # noqa: D401
98
- return self
99
-
100
- async def __aexit__(self, exc_type, exc, tb): # noqa: D401
101
- await self._client.aclose()
102
-
103
- # ----------------------- typed (de)serialize ---------------------
104
- def _encode_typed(self, value: Any) -> dict[str, Any]:
105
- """value -> {"type": str, "blob": base64str | null}"""
106
- t, b = self.serde.dumps_typed(value)
107
- return {"type": t, "blob": _to_b64(b)}
108
-
109
- def _decode_typed(self, obj: Any) -> Any:
110
- """{type, blob} | [type, blob] | legacy -> python."""
111
- # Новый формат: dict с ключами type/blob — только если blob валидная base64 или None
112
- if isinstance(obj, dict) and all(k in obj for k in TYPED_KEYS):
113
- t = obj.get("type")
114
- b64 = obj.get("blob")
115
- if b64 is None:
116
- return self.serde.loads_typed((t, None))
117
- if isinstance(b64, str):
118
- b = _b64decode_strict(b64)
119
- if b is not None:
120
- return self.serde.loads_typed((t, b))
121
- # если невалидно — падаем ниже на общую обработку
122
-
123
- # Допускаем tuple/list вида [type, base64] — только при валидной base64
124
- if isinstance(obj, (list, tuple)) and len(obj) == 2 and isinstance(obj[0], str):
125
- t, b64 = obj
126
- if b64 is None and t == "empty":
127
- return self.serde.loads_typed((t, None))
128
- if isinstance(b64, str):
129
- b = _b64decode_strict(b64)
130
- if b is not None:
131
- return self.serde.loads_typed((t, b))
132
- # иначе это не typed-пара
133
-
134
- # Если это строка — пробуем как base64 строго, затем как JSON-строку
135
- if isinstance(obj, str):
136
- b = _b64decode_strict(obj)
137
- if b is not None:
138
- try:
139
- return self.serde.loads(b)
140
- except Exception:
141
- pass
142
- try:
143
- return self.serde.loads(obj.encode())
144
- except Exception:
145
- return obj
146
-
147
- # dict/list -> считаем это уже JSON и грузим через serde
148
- if isinstance(obj, (dict, list)):
149
- try:
150
- return self.serde.loads(orjson.dumps(obj))
151
- except Exception:
152
- return obj
153
-
154
- # как есть пробуем через serde
155
- try:
156
- return self.serde.loads(obj)
157
- except Exception:
158
- return obj
159
-
160
- # ----------------------- config <-> api --------------------------
161
- def _to_api_config(self, cfg: RunnableConfig | None) -> Dict[str, Any]:
162
- if not cfg:
163
- return {}
164
- c = cfg.get("configurable", {})
165
- res: Dict[str, Any] = {
166
- "threadId": c.get("thread_id", ""),
167
- "checkpointNs": c.get("checkpoint_ns", ""),
168
- }
169
- if cid := c.get("checkpoint_id"):
170
- res["checkpointId"] = cid
171
- if ts := c.get("thread_ts"):
172
- res["threadTs"] = ts
173
- return res
174
-
175
- # --------------------- checkpoint (de)ser ------------------------
176
- def _encode_cp(self, cp: Checkpoint) -> Dict[str, Any]:
177
- channel_values = {
178
- k: self._encode_typed(v) for k, v in cp.get("channel_values", {}).items()
179
- }
180
- pending = []
181
- for item in cp.get("pending_sends", []) or []:
182
- try:
183
- channel, value = item
184
- pending.append({"channel": channel, **self._encode_typed(value)})
185
- except Exception:
186
- continue
187
- return {
188
- "v": cp["v"],
189
- "id": cp["id"],
190
- "ts": cp["ts"],
191
- "channelValues": channel_values,
192
- "channelVersions": cp["channel_versions"],
193
- "versionsSeen": cp["versions_seen"],
194
- "pendingSends": pending,
195
- }
196
-
197
- def _decode_cp(self, raw: Dict[str, Any]) -> Checkpoint:
198
- cv_raw = raw.get("channelValues") or {}
199
- channel_values = {k: self._decode_typed(v) for k, v in cv_raw.items()}
200
- ps_raw = raw.get("pendingSends") or []
201
- pending_sends = []
202
- for obj in ps_raw:
203
- # ожидаем {channel, type, blob}
204
- if isinstance(obj, dict) and "channel" in obj:
205
- ch = obj["channel"]
206
- typed = {k: obj[k] for k in obj.keys() if k in TYPED_KEYS}
207
- val = self._decode_typed(typed)
208
- pending_sends.append((ch, val))
209
- elif isinstance(obj, (list, tuple)) and len(obj) == 2:
210
- ch, val = obj
211
- pending_sends.append((ch, self._decode_typed(val)))
212
-
213
- return Checkpoint(
214
- v=raw["v"],
215
- id=raw["id"],
216
- ts=raw["ts"],
217
- channel_values=channel_values,
218
- channel_versions=raw["channelVersions"],
219
- versions_seen=raw["versionsSeen"],
220
- pending_sends=pending_sends,
221
- )
222
-
223
- def _decode_config(self, raw: Dict[str, Any] | None) -> Optional[RunnableConfig]:
224
- if not raw:
225
- return None
226
- return RunnableConfig(
227
- tags=raw.get("tags"),
228
- metadata=raw.get("metadata"),
229
- callbacks=raw.get("callbacks"),
230
- run_name=raw.get("run_name"),
231
- max_concurrency=raw.get("max_concurrency"),
232
- recursion_limit=raw.get("recursion_limit"),
233
- configurable=self._decode_configurable(raw.get("configurable") or {}),
234
- )
235
-
236
- def _decode_configurable(self, raw: Dict[str, Any]) -> dict[str, Any]:
237
- return {
238
- "thread_id": raw.get("threadId"),
239
- "thread_ts": raw.get("threadTs"),
240
- "checkpoint_ns": raw.get("checkpointNs"),
241
- "checkpoint_id": raw.get("checkpointId"),
242
- }
243
-
244
- # metadata (de)ser — передаём как есть (JSON-совместимый словарь)
245
- def _enc_meta(self, md: CheckpointMetadata) -> CheckpointMetadata:
246
- return md or {}
247
-
248
- def _dec_meta(self, md: Any) -> Any:
249
- return md
250
-
251
- # ------------------------ HTTP wrapper ---------------------------
252
- async def _http(self, method: str, path: str, **kw) -> httpx.Response:
253
- if "json" in kw:
254
- payload = kw.pop("json")
255
- kw["data"] = orjson.dumps(payload)
256
- logger.debug("AGW HTTP payload: %s", kw["data"].decode())
257
- return await self._client.request(method, path, **kw)
258
-
259
- # -------------------- api -> CheckpointTuple ----------------------
260
- def _to_tuple(self, node: Dict[str, Any]) -> CheckpointTuple:
261
- pending_writes = None
262
- raw_pw = node.get("pendingWrites")
263
- if raw_pw:
264
- decoded: list[tuple[str, str, Any]] = []
265
- for w in raw_pw:
266
- if isinstance(w, dict) and "first" in w and "second" in w:
267
- # ожидаем формат, который возвращает бек: first=task_id, second=channel, third=typed
268
- task_id = w["first"]
269
- channel = w["second"]
270
- tv = w.get("third")
271
- value = self._decode_typed(tv)
272
- decoded.append((task_id, channel, value))
273
- elif isinstance(w, (list, tuple)):
274
- try:
275
- first, channel, tv = w
276
- decoded.append((first, channel, self._decode_typed(tv)))
277
- except Exception: # pragma: no cover
278
- continue
279
- pending_writes = decoded
280
-
281
- return CheckpointTuple(
282
- config=self._decode_config(node.get("config")),
283
- checkpoint=self._decode_cp(node["checkpoint"]),
284
- metadata=self._dec_meta(node.get("metadata")),
285
- parent_config=self._decode_config(node.get("parentConfig")),
286
- pending_writes=pending_writes,
287
- )
288
-
289
- # =================================================================
290
- # async-методы BaseCheckpointSaver
291
- # =================================================================
292
- async def aget_tuple(self, cfg: RunnableConfig) -> CheckpointTuple | None:
293
- cid = get_checkpoint_id(cfg)
294
- api_cfg = self._to_api_config(cfg)
295
- tid = api_cfg.get("threadId")
296
-
297
- if cid:
298
- path = f"/checkpoint/{tid}/{cid}"
299
- params = {"checkpointNs": api_cfg.get("checkpointNs", "")}
300
- else:
301
- path = f"/checkpoint/{tid}"
302
- params = None
303
-
304
- resp = await self._http("GET", path, params=params)
305
- logger.debug("AGW aget_tuple response: %s", resp.text)
306
-
307
- if not resp.text or resp.status_code in (404, 406):
308
- return None
309
- resp.raise_for_status()
310
- return self._to_tuple(resp.json())
311
-
312
- async def alist(
313
- self,
314
- cfg: RunnableConfig | None,
315
- *,
316
- filter: Dict[str, Any] | None = None,
317
- before: RunnableConfig | None = None,
318
- limit: int | None = None,
319
- ) -> AsyncIterator[CheckpointTuple]:
320
- payload = {
321
- "config": self._to_api_config(cfg) if cfg else None,
322
- "filter": filter,
323
- "before": self._to_api_config(before) if before else None,
324
- "limit": limit,
325
- }
326
- resp = await self._http("POST", "/checkpoint/list", json=payload)
327
- logger.debug("AGW alist response: %s", resp.text)
328
- resp.raise_for_status()
329
- for item in resp.json():
330
- yield self._to_tuple(item)
331
-
332
- async def aput(
333
- self,
334
- cfg: RunnableConfig,
335
- cp: Checkpoint,
336
- metadata: CheckpointMetadata,
337
- new_versions: ChannelVersions,
338
- ) -> RunnableConfig:
339
- payload = {
340
- "config": self._to_api_config(cfg),
341
- "checkpoint": self._encode_cp(cp),
342
- "metadata": get_checkpoint_metadata(cfg, metadata),
343
- "newVersions": new_versions,
344
- }
345
- resp = await self._http("POST", "/checkpoint", json=payload)
346
- logger.debug("AGW aput response: %s", resp.text)
347
- resp.raise_for_status()
348
- return resp.json()["config"]
349
-
350
- async def aput_writes(
351
- self,
352
- cfg: RunnableConfig,
353
- writes: Sequence[Tuple[str, Any]],
354
- task_id: str,
355
- task_path: str = "",
356
- ) -> None:
357
- enc = [{"first": ch, "second": self._encode_typed(v)} for ch, v in writes]
358
- payload = {
359
- "config": self._to_api_config(cfg),
360
- "writes": enc,
361
- "taskId": task_id,
362
- "taskPath": task_path,
363
- }
364
- resp = await self._http("POST", "/checkpoint/writes", json=payload)
365
- logger.debug("AGW aput_writes response: %s", resp.text)
366
- resp.raise_for_status()
367
-
368
- async def adelete_thread(self, thread_id: str) -> None:
369
- resp = await self._http("DELETE", f"/checkpoint/{thread_id}")
370
- resp.raise_for_status()
371
-
372
- # =================================================================
373
- # sync-обёртки
374
- # =================================================================
375
- def _run(self, coro):
376
- return asyncio.run_coroutine_threadsafe(coro, self.loop).result()
377
-
378
- def list(
379
- self,
380
- cfg: RunnableConfig | None,
381
- *,
382
- filter: Dict[str, Any] | None = None,
383
- before: RunnableConfig | None = None,
384
- limit: int | None = None,
385
- ) -> Iterator[CheckpointTuple]:
386
- aiter_ = self.alist(cfg, filter=filter, before=before, limit=limit)
387
- while True:
388
- try:
389
- yield self._run(anext(aiter_))
390
- except StopAsyncIteration:
391
- break
392
-
393
- def get_tuple(self, cfg: RunnableConfig) -> CheckpointTuple | None:
394
- return self._run(self.aget_tuple(cfg))
395
-
396
- def put(
397
- self,
398
- cfg: RunnableConfig,
399
- cp: Checkpoint,
400
- metadata: CheckpointMetadata,
401
- new_versions: ChannelVersions,
402
- ) -> RunnableConfig:
403
- return self._run(self.aput(cfg, cp, metadata, new_versions))
404
-
405
- def put_writes(
406
- self,
407
- cfg: RunnableConfig,
408
- writes: Sequence[Tuple[str, Any]],
409
- task_id: str,
410
- task_path: str = "",
411
- ) -> None:
412
- self._run(self.aput_writes(cfg, writes, task_id, task_path))
413
-
414
- def delete_thread(self, thread_id: str) -> None:
415
- self._run(self.adelete_thread(thread_id))
416
-
417
- def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> str:
418
- if current is None:
419
- current_v = 0
420
- elif isinstance(current, int):
421
- current_v = current
422
- else:
423
- current_v = int(current.split(".")[0])
424
- next_v = current_v + 1
425
- next_h = random()
426
- return f"{next_v:032}.{next_h:016}"
427
-
428
- # ------------------------------------------------------------------ #
429
- # Convenience factory #
430
- # ------------------------------------------------------------------ #
431
- @classmethod
432
- @asynccontextmanager
433
- async def from_base_url(
434
- cls,
435
- base_url: str,
436
- *,
437
- api_key: str | None = None,
438
- **kwargs: Any,
439
- ) -> AsyncIterator["AsyncAGWCheckpointSaver"]:
440
- saver = cls(base_url, api_key=api_key, **kwargs)
441
- try:
442
- yield saver
443
- finally:
444
- await saver._client.aclose()
File without changes
File without changes
File without changes