agent-lab-sdk 0.1.36__py3-none-any.whl → 0.1.38__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of agent-lab-sdk might be problematic. Click here for more details.
- agent_lab_sdk/langgraph/checkpoint/agw_saver.py +375 -113
- {agent_lab_sdk-0.1.36.dist-info → agent_lab_sdk-0.1.38.dist-info}/METADATA +1 -1
- {agent_lab_sdk-0.1.36.dist-info → agent_lab_sdk-0.1.38.dist-info}/RECORD +6 -6
- {agent_lab_sdk-0.1.36.dist-info → agent_lab_sdk-0.1.38.dist-info}/WHEEL +0 -0
- {agent_lab_sdk-0.1.36.dist-info → agent_lab_sdk-0.1.38.dist-info}/licenses/LICENSE +0 -0
- {agent_lab_sdk-0.1.36.dist-info → agent_lab_sdk-0.1.38.dist-info}/top_level.txt +0 -0
|
@@ -6,10 +6,12 @@ import logging
|
|
|
6
6
|
import os
|
|
7
7
|
from contextlib import asynccontextmanager
|
|
8
8
|
from random import random
|
|
9
|
-
from typing import Any, AsyncIterator, Dict, Iterator, Optional, Sequence, Tuple
|
|
9
|
+
from typing import Any, AsyncIterator, Dict, Iterable, Iterator, Optional, Sequence, Tuple
|
|
10
10
|
|
|
11
|
-
import httpx
|
|
12
11
|
import orjson
|
|
12
|
+
from langgraph.checkpoint.serde.types import ChannelProtocol
|
|
13
|
+
|
|
14
|
+
import httpx
|
|
13
15
|
from langchain_core.runnables import RunnableConfig
|
|
14
16
|
|
|
15
17
|
from langgraph.checkpoint.base import (
|
|
@@ -23,27 +25,53 @@ from langgraph.checkpoint.base import (
|
|
|
23
25
|
)
|
|
24
26
|
from langgraph.checkpoint.serde.base import SerializerProtocol
|
|
25
27
|
from langgraph.checkpoint.serde.encrypted import EncryptedSerializer
|
|
26
|
-
|
|
28
|
+
|
|
27
29
|
from .serde import Serializer
|
|
30
|
+
from agent_lab_sdk.metrics import get_metric
|
|
28
31
|
|
|
29
32
|
__all__ = ["AsyncAGWCheckpointSaver"]
|
|
30
33
|
|
|
31
34
|
logger = logging.getLogger(__name__)
|
|
32
35
|
|
|
33
|
-
|
|
34
|
-
|
|
36
|
+
AGW_METRIC_LABELS = ["method", "endpoint"]
|
|
37
|
+
AGW_HTTP_SUCCESS = get_metric(
|
|
38
|
+
"counter",
|
|
39
|
+
"agw_http_success_total",
|
|
40
|
+
"Number of successful AGW HTTP requests",
|
|
41
|
+
labelnames=AGW_METRIC_LABELS,
|
|
42
|
+
)
|
|
43
|
+
AGW_HTTP_ERROR = get_metric(
|
|
44
|
+
"counter",
|
|
45
|
+
"agw_http_error_total",
|
|
46
|
+
"Number of failed AGW HTTP request attempts",
|
|
47
|
+
labelnames=AGW_METRIC_LABELS,
|
|
48
|
+
)
|
|
49
|
+
AGW_HTTP_FINAL_ERROR = get_metric(
|
|
50
|
+
"counter",
|
|
51
|
+
"agw_http_final_error_total",
|
|
52
|
+
"Number of AGW HTTP requests that failed after retries",
|
|
53
|
+
labelnames=AGW_METRIC_LABELS,
|
|
54
|
+
)
|
|
35
55
|
|
|
36
|
-
|
|
37
|
-
return base64.b64encode(b).decode() if b is not None else None
|
|
56
|
+
TYPED_KEYS = ("type", "blob")
|
|
38
57
|
|
|
39
58
|
|
|
40
|
-
def _b64decode_strict(
|
|
41
|
-
"""Возвращает bytes только если строка действительно корректная base64."""
|
|
59
|
+
def _b64decode_strict(value: str) -> bytes | None:
|
|
42
60
|
try:
|
|
43
|
-
return base64.b64decode(
|
|
61
|
+
return base64.b64decode(value, validate=True)
|
|
44
62
|
except Exception:
|
|
45
63
|
return None
|
|
46
64
|
|
|
65
|
+
# ------------------------------------------------------------------ #
|
|
66
|
+
# helpers for Py < 3.10
|
|
67
|
+
# ------------------------------------------------------------------ #
|
|
68
|
+
try:
|
|
69
|
+
anext # type: ignore[name-defined]
|
|
70
|
+
except NameError: # pragma: no cover
|
|
71
|
+
|
|
72
|
+
async def anext(it):
|
|
73
|
+
return await it.__anext__()
|
|
74
|
+
|
|
47
75
|
|
|
48
76
|
class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
|
|
49
77
|
"""Persist checkpoints in Agent-Gateway с помощью `httpx` async client."""
|
|
@@ -61,15 +89,14 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
|
|
|
61
89
|
):
|
|
62
90
|
if not serde:
|
|
63
91
|
base_serde: SerializerProtocol = Serializer()
|
|
64
|
-
|
|
65
|
-
_aes_key = (
|
|
92
|
+
aes_key = (
|
|
66
93
|
os.getenv("LANGGRAPH_AES_KEY")
|
|
67
94
|
or os.getenv("AGW_AES_KEY")
|
|
68
95
|
or os.getenv("AES_KEY")
|
|
69
96
|
)
|
|
70
|
-
if
|
|
97
|
+
if aes_key:
|
|
71
98
|
base_serde = EncryptedSerializer.from_pycryptodome_aes(
|
|
72
|
-
base_serde, key=
|
|
99
|
+
base_serde, key=aes_key
|
|
73
100
|
)
|
|
74
101
|
serde = base_serde
|
|
75
102
|
super().__init__(serde=serde)
|
|
@@ -77,6 +104,66 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
|
|
|
77
104
|
self.timeout = timeout
|
|
78
105
|
self.loop = asyncio.get_running_loop()
|
|
79
106
|
|
|
107
|
+
raw_attempts = os.getenv("AGW_HTTP_MAX_RETRIES")
|
|
108
|
+
if raw_attempts is None:
|
|
109
|
+
self.retry_max_attempts = 3
|
|
110
|
+
else:
|
|
111
|
+
try:
|
|
112
|
+
self.retry_max_attempts = max(int(raw_attempts), 1)
|
|
113
|
+
except ValueError:
|
|
114
|
+
logger.warning(
|
|
115
|
+
"Env %s expected int, got %r; using default %s",
|
|
116
|
+
"AGW_HTTP_MAX_RETRIES",
|
|
117
|
+
raw_attempts,
|
|
118
|
+
3,
|
|
119
|
+
)
|
|
120
|
+
self.retry_max_attempts = 3
|
|
121
|
+
|
|
122
|
+
raw_backoff_base = os.getenv("AGW_HTTP_RETRY_BACKOFF_BASE")
|
|
123
|
+
if raw_backoff_base is None:
|
|
124
|
+
self.retry_backoff_base = 0.5
|
|
125
|
+
else:
|
|
126
|
+
try:
|
|
127
|
+
self.retry_backoff_base = max(float(raw_backoff_base), 0.0)
|
|
128
|
+
except ValueError:
|
|
129
|
+
logger.warning(
|
|
130
|
+
"Env %s expected float, got %r; using default %.3f",
|
|
131
|
+
"AGW_HTTP_RETRY_BACKOFF_BASE",
|
|
132
|
+
raw_backoff_base,
|
|
133
|
+
0.5,
|
|
134
|
+
)
|
|
135
|
+
self.retry_backoff_base = 0.5
|
|
136
|
+
|
|
137
|
+
raw_backoff_max = os.getenv("AGW_HTTP_RETRY_BACKOFF_MAX")
|
|
138
|
+
if raw_backoff_max is None:
|
|
139
|
+
self.retry_backoff_max = 5.0
|
|
140
|
+
else:
|
|
141
|
+
try:
|
|
142
|
+
self.retry_backoff_max = max(float(raw_backoff_max), 0.0)
|
|
143
|
+
except ValueError:
|
|
144
|
+
logger.warning(
|
|
145
|
+
"Env %s expected float, got %r; using default %.3f",
|
|
146
|
+
"AGW_HTTP_RETRY_BACKOFF_MAX",
|
|
147
|
+
raw_backoff_max,
|
|
148
|
+
5.0,
|
|
149
|
+
)
|
|
150
|
+
self.retry_backoff_max = 5.0
|
|
151
|
+
|
|
152
|
+
raw_jitter = os.getenv("AGW_HTTP_RETRY_JITTER")
|
|
153
|
+
if raw_jitter is None:
|
|
154
|
+
self.retry_jitter = 0.25
|
|
155
|
+
else:
|
|
156
|
+
try:
|
|
157
|
+
self.retry_jitter = max(float(raw_jitter), 0.0)
|
|
158
|
+
except ValueError:
|
|
159
|
+
logger.warning(
|
|
160
|
+
"Env %s expected float, got %r; using default %.3f",
|
|
161
|
+
"AGW_HTTP_RETRY_JITTER",
|
|
162
|
+
raw_jitter,
|
|
163
|
+
0.25,
|
|
164
|
+
)
|
|
165
|
+
self.retry_jitter = 0.25
|
|
166
|
+
|
|
80
167
|
self.headers: Dict[str, str] = {
|
|
81
168
|
"Accept": "application/json",
|
|
82
169
|
"Content-Type": "application/json",
|
|
@@ -86,77 +173,151 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
|
|
|
86
173
|
if api_key:
|
|
87
174
|
self.headers["Authorization"] = f"Bearer {api_key}"
|
|
88
175
|
|
|
89
|
-
self.
|
|
176
|
+
self._verify = verify
|
|
177
|
+
self._client: httpx.AsyncClient | None = None
|
|
178
|
+
|
|
179
|
+
def _create_client(self) -> httpx.AsyncClient:
|
|
180
|
+
return httpx.AsyncClient(
|
|
90
181
|
base_url=self.base_url,
|
|
91
182
|
headers=self.headers,
|
|
92
183
|
timeout=self.timeout,
|
|
93
|
-
verify=
|
|
184
|
+
verify=self._verify,
|
|
94
185
|
trust_env=True,
|
|
95
186
|
)
|
|
96
187
|
|
|
188
|
+
def _ensure_client(self) -> httpx.AsyncClient:
|
|
189
|
+
client = self._client
|
|
190
|
+
if client is None or client.is_closed:
|
|
191
|
+
if client is not None and client.is_closed:
|
|
192
|
+
logger.debug("Recreating closed httpx.AsyncClient for AGW")
|
|
193
|
+
client = self._create_client()
|
|
194
|
+
self._client = client
|
|
195
|
+
return client
|
|
196
|
+
|
|
197
|
+
def _compute_retry_delay(self, attempt: int) -> float:
|
|
198
|
+
if attempt <= 0:
|
|
199
|
+
attempt = 1
|
|
200
|
+
if self.retry_backoff_base <= 0:
|
|
201
|
+
delay = 0.0
|
|
202
|
+
else:
|
|
203
|
+
delay = self.retry_backoff_base * (2 ** (attempt - 1))
|
|
204
|
+
if self.retry_backoff_max > 0:
|
|
205
|
+
delay = min(delay, self.retry_backoff_max)
|
|
206
|
+
if self.retry_jitter > 0:
|
|
207
|
+
delay += self.retry_jitter * random()
|
|
208
|
+
return delay
|
|
209
|
+
|
|
97
210
|
async def __aenter__(self): # noqa: D401
|
|
98
211
|
return self
|
|
99
212
|
|
|
100
213
|
async def __aexit__(self, exc_type, exc, tb): # noqa: D401
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
# ----------------------- typed (de)serialize ---------------------
|
|
104
|
-
def _encode_typed(self, value: Any) -> dict[str, Any]:
|
|
105
|
-
"""value -> {"type": str, "blob": base64str | null}"""
|
|
106
|
-
t, b = self.serde.dumps_typed(value)
|
|
107
|
-
return {"type": t, "blob": _to_b64(b)}
|
|
108
|
-
|
|
109
|
-
def _decode_typed(self, obj: Any) -> Any:
|
|
110
|
-
"""{type, blob} | [type, blob] | legacy -> python."""
|
|
111
|
-
# Новый формат: dict с ключами type/blob — только если blob валидная base64 или None
|
|
112
|
-
if isinstance(obj, dict) and all(k in obj for k in TYPED_KEYS):
|
|
113
|
-
t = obj.get("type")
|
|
114
|
-
b64 = obj.get("blob")
|
|
115
|
-
if b64 is None:
|
|
116
|
-
return self.serde.loads_typed((t, None))
|
|
117
|
-
if isinstance(b64, str):
|
|
118
|
-
b = _b64decode_strict(b64)
|
|
119
|
-
if b is not None:
|
|
120
|
-
return self.serde.loads_typed((t, b))
|
|
121
|
-
# если невалидно — падаем ниже на общую обработку
|
|
122
|
-
|
|
123
|
-
# Допускаем tuple/list вида [type, base64] — только при валидной base64
|
|
124
|
-
if isinstance(obj, (list, tuple)) and len(obj) == 2 and isinstance(obj[0], str):
|
|
125
|
-
t, b64 = obj
|
|
126
|
-
if b64 is None and t == "empty":
|
|
127
|
-
return self.serde.loads_typed((t, None))
|
|
128
|
-
if isinstance(b64, str):
|
|
129
|
-
b = _b64decode_strict(b64)
|
|
130
|
-
if b is not None:
|
|
131
|
-
return self.serde.loads_typed((t, b))
|
|
132
|
-
# иначе это не typed-пара
|
|
133
|
-
|
|
134
|
-
# Если это строка — пробуем как base64 строго, затем как JSON-строку
|
|
135
|
-
if isinstance(obj, str):
|
|
136
|
-
b = _b64decode_strict(obj)
|
|
137
|
-
if b is not None:
|
|
138
|
-
try:
|
|
139
|
-
return self.serde.loads(b)
|
|
140
|
-
except Exception:
|
|
141
|
-
pass
|
|
214
|
+
if self._client is not None:
|
|
142
215
|
try:
|
|
143
|
-
|
|
216
|
+
await self._client.aclose()
|
|
217
|
+
except Exception as close_exc: # pragma: no cover - best effort
|
|
218
|
+
logger.debug("Failed to close AGW httpx.AsyncClient: %s", close_exc)
|
|
219
|
+
finally:
|
|
220
|
+
self._client = None
|
|
221
|
+
|
|
222
|
+
# ----------------------- universal dump/load ---------------------
|
|
223
|
+
# def _safe_dump(self, obj: Any) -> Any:
|
|
224
|
+
# """self.serde.dump → гарантированная JSON-строка."""
|
|
225
|
+
# dumped = self.serde.dumps(obj)
|
|
226
|
+
# if isinstance(dumped, (bytes, bytearray)):
|
|
227
|
+
# return base64.b64encode(dumped).decode() # str
|
|
228
|
+
# return dumped # уже json-совместимо
|
|
229
|
+
|
|
230
|
+
def _safe_dump(self, obj: Any) -> Any:
|
|
231
|
+
"""bytes → python-object; fallback base64 для реально бинарных данных."""
|
|
232
|
+
dumped = self.serde.dumps(obj)
|
|
233
|
+
if isinstance(dumped, (bytes, bytearray)):
|
|
234
|
+
try:
|
|
235
|
+
# 1) bytes → str
|
|
236
|
+
s = dumped.decode()
|
|
237
|
+
# 2) str JSON → python (list/dict/scalar)
|
|
238
|
+
return orjson.loads(s)
|
|
239
|
+
except (UnicodeDecodeError, orjson.JSONDecodeError):
|
|
240
|
+
# не UTF-8 или не JSON → base64
|
|
241
|
+
return base64.b64encode(dumped).decode()
|
|
242
|
+
return dumped
|
|
243
|
+
|
|
244
|
+
def _safe_load(self, obj: Any) -> Any:
|
|
245
|
+
if obj is None:
|
|
246
|
+
return None
|
|
247
|
+
|
|
248
|
+
if isinstance(obj, dict):
|
|
249
|
+
if all(k in obj for k in TYPED_KEYS):
|
|
250
|
+
t = obj.get("type")
|
|
251
|
+
blob = obj.get("blob")
|
|
252
|
+
if blob is None:
|
|
253
|
+
try:
|
|
254
|
+
return self.serde.loads_typed((t, None))
|
|
255
|
+
except Exception:
|
|
256
|
+
return obj
|
|
257
|
+
if isinstance(blob, str):
|
|
258
|
+
payload = _b64decode_strict(blob)
|
|
259
|
+
if payload is not None:
|
|
260
|
+
try:
|
|
261
|
+
return self.serde.loads_typed((t, payload))
|
|
262
|
+
except Exception:
|
|
263
|
+
# fall back to generic handling below
|
|
264
|
+
pass
|
|
265
|
+
try:
|
|
266
|
+
return self.serde.loads(orjson.dumps(obj))
|
|
144
267
|
except Exception:
|
|
145
268
|
return obj
|
|
146
269
|
|
|
147
|
-
|
|
148
|
-
|
|
270
|
+
if isinstance(obj, (list, tuple)):
|
|
271
|
+
if (
|
|
272
|
+
len(obj) == 2
|
|
273
|
+
and isinstance(obj[0], str)
|
|
274
|
+
and (obj[1] is None or isinstance(obj[1], str))
|
|
275
|
+
):
|
|
276
|
+
blob = obj[1]
|
|
277
|
+
if blob is None:
|
|
278
|
+
try:
|
|
279
|
+
return self.serde.loads_typed((obj[0], None))
|
|
280
|
+
except Exception:
|
|
281
|
+
pass
|
|
282
|
+
elif isinstance(blob, str):
|
|
283
|
+
payload = _b64decode_strict(blob)
|
|
284
|
+
if payload is not None:
|
|
285
|
+
try:
|
|
286
|
+
return self.serde.loads_typed((obj[0], payload))
|
|
287
|
+
except Exception:
|
|
288
|
+
pass
|
|
149
289
|
try:
|
|
150
|
-
return self.serde.loads(orjson.dumps(obj))
|
|
290
|
+
return self.serde.loads(orjson.dumps(list(obj)))
|
|
151
291
|
except Exception:
|
|
152
292
|
return obj
|
|
153
293
|
|
|
154
|
-
|
|
294
|
+
if isinstance(obj, str):
|
|
295
|
+
try:
|
|
296
|
+
return self.serde.loads(obj.encode())
|
|
297
|
+
except Exception:
|
|
298
|
+
payload = _b64decode_strict(obj)
|
|
299
|
+
if payload is not None:
|
|
300
|
+
try:
|
|
301
|
+
return self.serde.loads(payload)
|
|
302
|
+
except Exception:
|
|
303
|
+
pass
|
|
304
|
+
return obj
|
|
305
|
+
|
|
155
306
|
try:
|
|
156
307
|
return self.serde.loads(obj)
|
|
157
308
|
except Exception:
|
|
158
309
|
return obj
|
|
159
310
|
|
|
311
|
+
# def _safe_load(self, obj: Any) -> Any:
|
|
312
|
+
# """Обратная операция к _safe_dump."""
|
|
313
|
+
# if isinstance(obj, str):
|
|
314
|
+
# try:
|
|
315
|
+
# return self.serde.load(base64.b64decode(obj))
|
|
316
|
+
# except Exception:
|
|
317
|
+
# # не base64 — обычная строка
|
|
318
|
+
# return self.serde.load(obj)
|
|
319
|
+
# return self.serde.load(obj)
|
|
320
|
+
|
|
160
321
|
# ----------------------- config <-> api --------------------------
|
|
161
322
|
def _to_api_config(self, cfg: RunnableConfig | None) -> Dict[str, Any]:
|
|
162
323
|
if not cfg:
|
|
@@ -174,47 +335,44 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
|
|
|
174
335
|
|
|
175
336
|
# --------------------- checkpoint (de)ser ------------------------
|
|
176
337
|
def _encode_cp(self, cp: Checkpoint) -> Dict[str, Any]:
|
|
177
|
-
|
|
178
|
-
k: self._encode_typed(v) for k, v in cp.get("channel_values", {}).items()
|
|
179
|
-
}
|
|
180
|
-
pending = []
|
|
338
|
+
pending: list[Any] = []
|
|
181
339
|
for item in cp.get("pending_sends", []) or []:
|
|
182
340
|
try:
|
|
183
341
|
channel, value = item
|
|
184
|
-
pending.append({"channel": channel, **self._encode_typed(value)})
|
|
185
342
|
except Exception:
|
|
343
|
+
pending.append(item)
|
|
186
344
|
continue
|
|
345
|
+
pending.append([channel, self._safe_dump(value)])
|
|
187
346
|
return {
|
|
188
347
|
"v": cp["v"],
|
|
189
348
|
"id": cp["id"],
|
|
190
349
|
"ts": cp["ts"],
|
|
191
|
-
"channelValues": channel_values,
|
|
350
|
+
"channelValues": {k: self._safe_dump(v) for k, v in cp["channel_values"].items()},
|
|
192
351
|
"channelVersions": cp["channel_versions"],
|
|
193
352
|
"versionsSeen": cp["versions_seen"],
|
|
194
353
|
"pendingSends": pending,
|
|
195
354
|
}
|
|
196
355
|
|
|
197
356
|
def _decode_cp(self, raw: Dict[str, Any]) -> Checkpoint:
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
ps_raw = raw.get("pendingSends") or []
|
|
201
|
-
pending_sends = []
|
|
202
|
-
for obj in ps_raw:
|
|
203
|
-
# ожидаем {channel, type, blob}
|
|
357
|
+
pending_sends: list[Tuple[str, Any]] = []
|
|
358
|
+
for obj in raw.get("pendingSends", []) or []:
|
|
204
359
|
if isinstance(obj, dict) and "channel" in obj:
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
360
|
+
channel = obj["channel"]
|
|
361
|
+
value_payload: Any = obj.get("value")
|
|
362
|
+
if value_payload is None and all(k in obj for k in TYPED_KEYS):
|
|
363
|
+
value_payload = {k: obj[k] for k in TYPED_KEYS}
|
|
364
|
+
pending_sends.append((channel, self._safe_load(value_payload)))
|
|
365
|
+
elif isinstance(obj, (list, tuple)) and len(obj) >= 2:
|
|
366
|
+
channel = obj[0]
|
|
367
|
+
value_payload = obj[1]
|
|
368
|
+
pending_sends.append((channel, self._safe_load(value_payload)))
|
|
369
|
+
else:
|
|
370
|
+
pending_sends.append(obj) # сохраняем как есть, если формат неизвестен
|
|
213
371
|
return Checkpoint(
|
|
214
372
|
v=raw["v"],
|
|
215
373
|
id=raw["id"],
|
|
216
374
|
ts=raw["ts"],
|
|
217
|
-
channel_values=
|
|
375
|
+
channel_values={k: self._safe_load(v) for k, v in raw["channelValues"].items()},
|
|
218
376
|
channel_versions=raw["channelVersions"],
|
|
219
377
|
versions_seen=raw["versionsSeen"],
|
|
220
378
|
pending_sends=pending_sends,
|
|
@@ -238,52 +396,148 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
|
|
|
238
396
|
"thread_id": raw.get("threadId"),
|
|
239
397
|
"thread_ts": raw.get("threadTs"),
|
|
240
398
|
"checkpoint_ns": raw.get("checkpointNs"),
|
|
241
|
-
"checkpoint_id": raw.get("checkpointId")
|
|
399
|
+
"checkpoint_id": raw.get("checkpointId")
|
|
242
400
|
}
|
|
243
401
|
|
|
244
|
-
# metadata (de)ser
|
|
402
|
+
# metadata (de)ser
|
|
245
403
|
def _enc_meta(self, md: CheckpointMetadata) -> CheckpointMetadata:
|
|
246
|
-
|
|
404
|
+
if not md:
|
|
405
|
+
return {}
|
|
406
|
+
out: CheckpointMetadata = {}
|
|
407
|
+
for k, v in md.items():
|
|
408
|
+
out[k] = self._enc_meta(v) if isinstance(v, dict) else self._safe_dump(v) # type: ignore[assignment]
|
|
409
|
+
return out
|
|
247
410
|
|
|
248
411
|
def _dec_meta(self, md: Any) -> Any:
|
|
249
|
-
|
|
412
|
+
if isinstance(md, dict):
|
|
413
|
+
return {k: self._dec_meta(v) for k, v in md.items()}
|
|
414
|
+
return self._safe_load(md)
|
|
250
415
|
|
|
251
416
|
# ------------------------ HTTP wrapper ---------------------------
|
|
252
|
-
async def _http(
|
|
417
|
+
async def _http(
|
|
418
|
+
self,
|
|
419
|
+
method: str,
|
|
420
|
+
path: str,
|
|
421
|
+
*,
|
|
422
|
+
ok_statuses: Iterable[int] | None = None,
|
|
423
|
+
**kw,
|
|
424
|
+
) -> httpx.Response:
|
|
253
425
|
if "json" in kw:
|
|
254
426
|
payload = kw.pop("json")
|
|
255
427
|
kw["data"] = orjson.dumps(payload)
|
|
256
428
|
logger.debug("AGW HTTP payload: %s", kw["data"].decode())
|
|
257
|
-
|
|
429
|
+
|
|
430
|
+
ok_set = set(ok_statuses) if ok_statuses is not None else set()
|
|
431
|
+
|
|
432
|
+
attempt = 1
|
|
433
|
+
while True:
|
|
434
|
+
client = self._ensure_client()
|
|
435
|
+
try:
|
|
436
|
+
resp = await client.request(method, path, **kw)
|
|
437
|
+
except httpx.RequestError as exc:
|
|
438
|
+
AGW_HTTP_ERROR.labels(method, path).inc()
|
|
439
|
+
logger.warning(
|
|
440
|
+
"AGW request %s %s failed on attempt %d/%d: %s",
|
|
441
|
+
method,
|
|
442
|
+
path,
|
|
443
|
+
attempt,
|
|
444
|
+
self.retry_max_attempts,
|
|
445
|
+
exc,
|
|
446
|
+
)
|
|
447
|
+
if attempt >= self.retry_max_attempts:
|
|
448
|
+
AGW_HTTP_FINAL_ERROR.labels(method, path).inc()
|
|
449
|
+
if self._client is not None:
|
|
450
|
+
try:
|
|
451
|
+
await self._client.aclose()
|
|
452
|
+
except Exception as close_exc: # pragma: no cover
|
|
453
|
+
logger.debug(
|
|
454
|
+
"Failed to close AGW httpx.AsyncClient: %s",
|
|
455
|
+
close_exc,
|
|
456
|
+
)
|
|
457
|
+
finally:
|
|
458
|
+
self._client = None
|
|
459
|
+
raise
|
|
460
|
+
|
|
461
|
+
if self._client is not None:
|
|
462
|
+
try:
|
|
463
|
+
await self._client.aclose()
|
|
464
|
+
except Exception as close_exc: # pragma: no cover
|
|
465
|
+
logger.debug(
|
|
466
|
+
"Failed to close AGW httpx.AsyncClient: %s",
|
|
467
|
+
close_exc,
|
|
468
|
+
)
|
|
469
|
+
finally:
|
|
470
|
+
self._client = None
|
|
471
|
+
delay = self._compute_retry_delay(attempt)
|
|
472
|
+
if delay > 0:
|
|
473
|
+
await asyncio.sleep(delay)
|
|
474
|
+
attempt += 1
|
|
475
|
+
continue
|
|
476
|
+
|
|
477
|
+
status = resp.status_code
|
|
478
|
+
if status < 400 or status in ok_set:
|
|
479
|
+
AGW_HTTP_SUCCESS.labels(method, path).inc()
|
|
480
|
+
return resp
|
|
481
|
+
|
|
482
|
+
AGW_HTTP_ERROR.labels(method, path).inc()
|
|
483
|
+
if status in (404, 406):
|
|
484
|
+
AGW_HTTP_FINAL_ERROR.labels(method, path).inc()
|
|
485
|
+
return resp
|
|
486
|
+
|
|
487
|
+
if attempt >= self.retry_max_attempts:
|
|
488
|
+
AGW_HTTP_FINAL_ERROR.labels(method, path).inc()
|
|
489
|
+
return resp
|
|
490
|
+
|
|
491
|
+
try:
|
|
492
|
+
await resp.aclose()
|
|
493
|
+
except Exception as exc: # pragma: no cover - best effort
|
|
494
|
+
logger.debug("Failed to close AGW httpx.Response before retry: %s", exc)
|
|
495
|
+
|
|
496
|
+
if self._client is not None:
|
|
497
|
+
try:
|
|
498
|
+
await self._client.aclose()
|
|
499
|
+
except Exception as close_exc: # pragma: no cover
|
|
500
|
+
logger.debug(
|
|
501
|
+
"Failed to close AGW httpx.AsyncClient: %s",
|
|
502
|
+
close_exc,
|
|
503
|
+
)
|
|
504
|
+
finally:
|
|
505
|
+
self._client = None
|
|
506
|
+
delay = self._compute_retry_delay(attempt)
|
|
507
|
+
if delay > 0:
|
|
508
|
+
await asyncio.sleep(delay)
|
|
509
|
+
attempt += 1
|
|
258
510
|
|
|
259
511
|
# -------------------- api -> CheckpointTuple ----------------------
|
|
260
512
|
def _to_tuple(self, node: Dict[str, Any]) -> CheckpointTuple:
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
513
|
+
pending = None
|
|
514
|
+
if node.get("pendingWrites"):
|
|
515
|
+
pending = []
|
|
516
|
+
for w in node["pendingWrites"]:
|
|
517
|
+
if isinstance(w, dict):
|
|
518
|
+
first = w.get("first")
|
|
519
|
+
second = w.get("second")
|
|
520
|
+
third = w.get("third")
|
|
521
|
+
if third is None and isinstance(second, dict) and all(
|
|
522
|
+
k in second for k in TYPED_KEYS
|
|
523
|
+
):
|
|
524
|
+
third = second
|
|
525
|
+
pending.append((first, second, self._safe_load(third)))
|
|
273
526
|
elif isinstance(w, (list, tuple)):
|
|
274
|
-
|
|
275
|
-
first,
|
|
276
|
-
|
|
277
|
-
|
|
527
|
+
if len(w) == 3:
|
|
528
|
+
first, second, third = w
|
|
529
|
+
elif len(w) == 2:
|
|
530
|
+
first, second = w
|
|
531
|
+
third = None
|
|
532
|
+
else:
|
|
278
533
|
continue
|
|
279
|
-
|
|
280
|
-
|
|
534
|
+
pending.append((first, second, self._safe_load(third)))
|
|
281
535
|
return CheckpointTuple(
|
|
282
536
|
config=self._decode_config(node.get("config")),
|
|
283
537
|
checkpoint=self._decode_cp(node["checkpoint"]),
|
|
284
538
|
metadata=self._dec_meta(node.get("metadata")),
|
|
285
539
|
parent_config=self._decode_config(node.get("parentConfig")),
|
|
286
|
-
pending_writes=
|
|
540
|
+
pending_writes=pending,
|
|
287
541
|
)
|
|
288
542
|
|
|
289
543
|
# =================================================================
|
|
@@ -292,7 +546,7 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
|
|
|
292
546
|
async def aget_tuple(self, cfg: RunnableConfig) -> CheckpointTuple | None:
|
|
293
547
|
cid = get_checkpoint_id(cfg)
|
|
294
548
|
api_cfg = self._to_api_config(cfg)
|
|
295
|
-
tid = api_cfg
|
|
549
|
+
tid = api_cfg["threadId"]
|
|
296
550
|
|
|
297
551
|
if cid:
|
|
298
552
|
path = f"/checkpoint/{tid}/{cid}"
|
|
@@ -304,7 +558,9 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
|
|
|
304
558
|
resp = await self._http("GET", path, params=params)
|
|
305
559
|
logger.debug("AGW aget_tuple response: %s", resp.text)
|
|
306
560
|
|
|
307
|
-
if not resp.text
|
|
561
|
+
if not resp.text:
|
|
562
|
+
return None
|
|
563
|
+
if resp.status_code in (404, 406):
|
|
308
564
|
return None
|
|
309
565
|
resp.raise_for_status()
|
|
310
566
|
return self._to_tuple(resp.json())
|
|
@@ -339,7 +595,7 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
|
|
|
339
595
|
payload = {
|
|
340
596
|
"config": self._to_api_config(cfg),
|
|
341
597
|
"checkpoint": self._encode_cp(cp),
|
|
342
|
-
"metadata": get_checkpoint_metadata(cfg, metadata),
|
|
598
|
+
"metadata": self._enc_meta(get_checkpoint_metadata(cfg, metadata)),
|
|
343
599
|
"newVersions": new_versions,
|
|
344
600
|
}
|
|
345
601
|
resp = await self._http("POST", "/checkpoint", json=payload)
|
|
@@ -354,7 +610,7 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
|
|
|
354
610
|
task_id: str,
|
|
355
611
|
task_path: str = "",
|
|
356
612
|
) -> None:
|
|
357
|
-
enc = [{"first": ch, "second": self.
|
|
613
|
+
enc = [{"first": ch, "second": self._safe_dump(v)} for ch, v in writes]
|
|
358
614
|
payload = {
|
|
359
615
|
"config": self._to_api_config(cfg),
|
|
360
616
|
"writes": enc,
|
|
@@ -441,4 +697,10 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
|
|
|
441
697
|
try:
|
|
442
698
|
yield saver
|
|
443
699
|
finally:
|
|
444
|
-
|
|
700
|
+
if saver._client is not None:
|
|
701
|
+
try:
|
|
702
|
+
await saver._client.aclose()
|
|
703
|
+
except Exception as close_exc: # pragma: no cover - best effort
|
|
704
|
+
logger.debug("Failed to close AGW httpx.AsyncClient: %s", close_exc)
|
|
705
|
+
finally:
|
|
706
|
+
saver._client = None
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
agent_lab_sdk/__init__.py,sha256=1Dlmv-wuz1QuciymKtYtX7jXzr_fkeGTe7aENfEDl3E,108
|
|
2
2
|
agent_lab_sdk/langgraph/checkpoint/__init__.py,sha256=DnKwR1LwbaQ3qhb124lE-tnojrUIVcCdNzHEHwgpL5M,86
|
|
3
|
-
agent_lab_sdk/langgraph/checkpoint/agw_saver.py,sha256=
|
|
3
|
+
agent_lab_sdk/langgraph/checkpoint/agw_saver.py,sha256=Srf8RYcW34_u2s54ABl0Jqm-_Z1gBH97gKqVY7QrKOQ,25631
|
|
4
4
|
agent_lab_sdk/langgraph/checkpoint/serde.py,sha256=UTSYbTbhBeL1CAr-XMbaH3SSIx9TeiC7ak22duXvqkw,5175
|
|
5
5
|
agent_lab_sdk/llm/__init__.py,sha256=Yo9MbYdHS1iX05A9XiJGwWN1Hm4IARGav9mNFPrtDeA,376
|
|
6
6
|
agent_lab_sdk/llm/agw_token_manager.py,sha256=_bPPI8muaEa6H01P8hHQOJHiiivaLd8N_d3OT9UT_80,4787
|
|
@@ -14,8 +14,8 @@ agent_lab_sdk/schema/input_types.py,sha256=e75nRW7Dz_RHk5Yia8DkFfbqMafsLQsQrJPfz
|
|
|
14
14
|
agent_lab_sdk/schema/log_message.py,sha256=nadi6lZGRuDSPmfbYs9QPpRJUT9Pfy8Y7pGCvyFF5Mw,638
|
|
15
15
|
agent_lab_sdk/storage/__init__.py,sha256=ik1_v1DMTwehvcAEXIYxuvLuCjJCa3y5qAuJqoQpuSA,81
|
|
16
16
|
agent_lab_sdk/storage/storage.py,sha256=ELpt7GRwFD-aWa6ctinfA_QwcvzWLvKS0Wz8FlxVqAs,2075
|
|
17
|
-
agent_lab_sdk-0.1.
|
|
18
|
-
agent_lab_sdk-0.1.
|
|
19
|
-
agent_lab_sdk-0.1.
|
|
20
|
-
agent_lab_sdk-0.1.
|
|
21
|
-
agent_lab_sdk-0.1.
|
|
17
|
+
agent_lab_sdk-0.1.38.dist-info/licenses/LICENSE,sha256=_TRXHkF3S9ilWBPdZcHLI_S-PRjK0L_SeOb2pcPAdV4,417
|
|
18
|
+
agent_lab_sdk-0.1.38.dist-info/METADATA,sha256=lVIN4EWC5qK-5j77XORxYLX7rTfYQV3Fkr0DB6czd0U,17911
|
|
19
|
+
agent_lab_sdk-0.1.38.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
20
|
+
agent_lab_sdk-0.1.38.dist-info/top_level.txt,sha256=E1efqkJ89KNmPBWdLzdMHeVtH0dYyCo4fhnSb81_15I,14
|
|
21
|
+
agent_lab_sdk-0.1.38.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|