agent-lab-sdk 0.1.35__py3-none-any.whl → 0.1.36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of agent-lab-sdk might be problematic. Click here for more details.
- agent_lab_sdk/langgraph/checkpoint/agw_saver.py +145 -85
- agent_lab_sdk/langgraph/checkpoint/serde.py +172 -0
- {agent_lab_sdk-0.1.35.dist-info → agent_lab_sdk-0.1.36.dist-info}/METADATA +2 -1
- {agent_lab_sdk-0.1.35.dist-info → agent_lab_sdk-0.1.36.dist-info}/RECORD +7 -6
- {agent_lab_sdk-0.1.35.dist-info → agent_lab_sdk-0.1.36.dist-info}/WHEEL +0 -0
- {agent_lab_sdk-0.1.35.dist-info → agent_lab_sdk-0.1.36.dist-info}/licenses/LICENSE +0 -0
- {agent_lab_sdk-0.1.35.dist-info → agent_lab_sdk-0.1.36.dist-info}/top_level.txt +0 -0
|
@@ -1,19 +1,18 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
import orjson
|
|
4
|
-
from random import random
|
|
5
|
-
from langgraph.checkpoint.serde.types import ChannelProtocol
|
|
6
3
|
import asyncio
|
|
7
4
|
import base64
|
|
5
|
+
import logging
|
|
6
|
+
import os
|
|
8
7
|
from contextlib import asynccontextmanager
|
|
8
|
+
from random import random
|
|
9
9
|
from typing import Any, AsyncIterator, Dict, Iterator, Optional, Sequence, Tuple
|
|
10
|
-
import logging
|
|
11
10
|
|
|
12
11
|
import httpx
|
|
12
|
+
import orjson
|
|
13
13
|
from langchain_core.runnables import RunnableConfig
|
|
14
14
|
|
|
15
15
|
from langgraph.checkpoint.base import (
|
|
16
|
-
WRITES_IDX_MAP,
|
|
17
16
|
BaseCheckpointSaver,
|
|
18
17
|
ChannelVersions,
|
|
19
18
|
Checkpoint,
|
|
@@ -23,20 +22,27 @@ from langgraph.checkpoint.base import (
|
|
|
23
22
|
get_checkpoint_metadata,
|
|
24
23
|
)
|
|
25
24
|
from langgraph.checkpoint.serde.base import SerializerProtocol
|
|
25
|
+
from langgraph.checkpoint.serde.encrypted import EncryptedSerializer
|
|
26
|
+
from langgraph.checkpoint.serde.types import ChannelProtocol
|
|
27
|
+
from .serde import Serializer
|
|
26
28
|
|
|
27
29
|
__all__ = ["AsyncAGWCheckpointSaver"]
|
|
28
30
|
|
|
29
31
|
logger = logging.getLogger(__name__)
|
|
30
32
|
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
33
|
+
TYPED_KEYS = ("type", "blob")
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _to_b64(b: bytes | None) -> str | None:
|
|
37
|
+
return base64.b64encode(b).decode() if b is not None else None
|
|
38
|
+
|
|
37
39
|
|
|
38
|
-
|
|
39
|
-
|
|
40
|
+
def _b64decode_strict(s: str) -> bytes | None:
|
|
41
|
+
"""Возвращает bytes только если строка действительно корректная base64."""
|
|
42
|
+
try:
|
|
43
|
+
return base64.b64decode(s, validate=True)
|
|
44
|
+
except Exception:
|
|
45
|
+
return None
|
|
40
46
|
|
|
41
47
|
|
|
42
48
|
class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
|
|
@@ -53,6 +59,19 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
|
|
|
53
59
|
extra_headers: Dict[str, str] | None = None,
|
|
54
60
|
verify: bool = True,
|
|
55
61
|
):
|
|
62
|
+
if not serde:
|
|
63
|
+
base_serde: SerializerProtocol = Serializer()
|
|
64
|
+
# опционально оборачиваем в AES по ENV
|
|
65
|
+
_aes_key = (
|
|
66
|
+
os.getenv("LANGGRAPH_AES_KEY")
|
|
67
|
+
or os.getenv("AGW_AES_KEY")
|
|
68
|
+
or os.getenv("AES_KEY")
|
|
69
|
+
)
|
|
70
|
+
if _aes_key:
|
|
71
|
+
base_serde = EncryptedSerializer.from_pycryptodome_aes(
|
|
72
|
+
base_serde, key=_aes_key
|
|
73
|
+
)
|
|
74
|
+
serde = base_serde
|
|
56
75
|
super().__init__(serde=serde)
|
|
57
76
|
self.base_url = base_url.rstrip("/")
|
|
58
77
|
self.timeout = timeout
|
|
@@ -66,13 +85,13 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
|
|
|
66
85
|
self.headers.update(extra_headers)
|
|
67
86
|
if api_key:
|
|
68
87
|
self.headers["Authorization"] = f"Bearer {api_key}"
|
|
69
|
-
|
|
88
|
+
|
|
70
89
|
self._client = httpx.AsyncClient(
|
|
71
90
|
base_url=self.base_url,
|
|
72
91
|
headers=self.headers,
|
|
73
92
|
timeout=self.timeout,
|
|
74
93
|
verify=verify,
|
|
75
|
-
trust_env=True
|
|
94
|
+
trust_env=True,
|
|
76
95
|
)
|
|
77
96
|
|
|
78
97
|
async def __aenter__(self): # noqa: D401
|
|
@@ -81,56 +100,63 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
|
|
|
81
100
|
async def __aexit__(self, exc_type, exc, tb): # noqa: D401
|
|
82
101
|
await self._client.aclose()
|
|
83
102
|
|
|
84
|
-
# -----------------------
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
103
|
+
# ----------------------- typed (de)serialize ---------------------
|
|
104
|
+
def _encode_typed(self, value: Any) -> dict[str, Any]:
    """Serialize ``value`` with ``self.serde.dumps_typed`` into a JSON-friendly dict.

    Returns ``{"type": <type tag>, "blob": <base64 text or None>}``.
    """
    type_tag, payload = self.serde.dumps_typed(value)
    encoded: dict[str, Any] = {"type": type_tag}
    encoded["blob"] = _to_b64(payload)
    return encoded
|
|
108
|
+
|
|
109
|
+
def _decode_typed(self, obj: Any) -> Any:
    """Decode a value received from the API back into a Python object.

    Shapes are tried in this order:

    1. ``{"type": ..., "blob": ...}`` dict — ``blob`` is ``None`` or valid base64.
    2. ``[type, blob]`` pair whose first item is a string — ``blob`` is valid
       base64, or ``None`` together with the type ``"empty"``.
    3. A string — first as base64 of serialized bytes, then as raw serialized text.
    4. A dict or list — dumped with ``orjson`` and loaded through the serializer.
    5. Anything else — handed to ``self.serde.loads`` as is.

    Whenever a legacy path (3-5) fails, ``obj`` itself is returned unchanged.
    """
    serde = self.serde

    # New format: dict carrying both typed keys.
    if isinstance(obj, dict) and obj.keys() >= set(TYPED_KEYS):
        type_tag = obj.get("type")
        blob_text = obj.get("blob")
        if blob_text is None:
            return serde.loads_typed((type_tag, None))
        if isinstance(blob_text, str):
            raw = _b64decode_strict(blob_text)
            if raw is not None:
                return serde.loads_typed((type_tag, raw))
        # Invalid blob: fall through to the generic handling below.

    # A two-item [type, base64] sequence is also accepted.
    if isinstance(obj, (list, tuple)) and len(obj) == 2 and isinstance(obj[0], str):
        type_tag, blob_text = obj
        if blob_text is None and type_tag == "empty":
            return serde.loads_typed((type_tag, None))
        if isinstance(blob_text, str):
            raw = _b64decode_strict(blob_text)
            if raw is not None:
                return serde.loads_typed((type_tag, raw))
        # Otherwise this is not a typed pair.

    # Strings: strict base64 first, then the text itself as serialized data.
    if isinstance(obj, str):
        raw = _b64decode_strict(obj)
        try:
            if raw is not None:
                return serde.loads(raw)
        except Exception:
            pass
        try:
            return serde.loads(obj.encode())
        except Exception:
            return obj

    # dict/list are treated as already-parsed JSON; everything else goes to
    # the serializer untouched.
    try:
        if isinstance(obj, (dict, list)):
            return serde.loads(orjson.dumps(obj))
        return serde.loads(obj)
    except Exception:
        return obj
|
|
123
159
|
|
|
124
|
-
# def _safe_load(self, obj: Any) -> Any:
|
|
125
|
-
# """Обратная операция к _safe_dump."""
|
|
126
|
-
# if isinstance(obj, str):
|
|
127
|
-
# try:
|
|
128
|
-
# return self.serde.load(base64.b64decode(obj))
|
|
129
|
-
# except Exception:
|
|
130
|
-
# # не base64 — обычная строка
|
|
131
|
-
# return self.serde.load(obj)
|
|
132
|
-
# return self.serde.load(obj)
|
|
133
|
-
|
|
134
160
|
# ----------------------- config <-> api --------------------------
|
|
135
161
|
def _to_api_config(self, cfg: RunnableConfig | None) -> Dict[str, Any]:
|
|
136
162
|
if not cfg:
|
|
@@ -148,28 +174,53 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
|
|
|
148
174
|
|
|
149
175
|
# --------------------- checkpoint (de)ser ------------------------
|
|
150
176
|
def _encode_cp(self, cp: Checkpoint) -> Dict[str, Any]:
    """Convert a checkpoint into the camelCase payload sent to the API.

    Every channel value is wrapped with ``_encode_typed``. Each pending send
    becomes ``{"channel": ..., "type": ..., "blob": ...}``.

    Raises:
        KeyError: if ``cp`` lacks ``v``, ``id``, ``ts``, ``channel_versions``
            or ``versions_seen``.
    """
    channel_values = {
        name: self._encode_typed(value)
        for name, value in (cp.get("channel_values") or {}).items()
    }

    pending: list[dict[str, Any]] = []
    for item in cp.get("pending_sends") or []:
        try:
            channel, value = item
            pending.append({"channel": channel, **self._encode_typed(value)})
        except Exception:
            # Best effort: a malformed or unserializable send must not block
            # saving the checkpoint, but dropping it silently hides data loss.
            logger.warning(
                "Dropping pending send that could not be encoded", exc_info=True
            )
            continue

    return {
        "v": cp["v"],
        "id": cp["id"],
        "ts": cp["ts"],
        "channelValues": channel_values,
        "channelVersions": cp["channel_versions"],
        "versionsSeen": cp["versions_seen"],
        "pendingSends": pending,
    }
|
|
160
196
|
|
|
161
197
|
def _decode_cp(self, raw: Dict[str, Any]) -> Checkpoint:
    """Rebuild a ``Checkpoint`` from the camelCase payload returned by the API.

    Channel values and pending sends are decoded with ``_decode_typed``.
    A pending send may be a ``{channel, type, blob}`` dict or a two-item
    ``[channel, value]`` sequence; entries of any other shape are skipped.
    """
    channel_values = {
        name: self._decode_typed(encoded)
        for name, encoded in (raw.get("channelValues") or {}).items()
    }

    pending_sends = []
    for entry in raw.get("pendingSends") or []:
        if isinstance(entry, dict) and "channel" in entry:
            # Keep only the typed keys; "channel" is not part of the value.
            typed_part = {key: entry[key] for key in entry if key in TYPED_KEYS}
            pending_sends.append((entry["channel"], self._decode_typed(typed_part)))
        elif isinstance(entry, (list, tuple)) and len(entry) == 2:
            channel, encoded = entry
            pending_sends.append((channel, self._decode_typed(encoded)))

    return Checkpoint(
        v=raw["v"],
        id=raw["id"],
        ts=raw["ts"],
        channel_values=channel_values,
        channel_versions=raw["channelVersions"],
        versions_seen=raw["versionsSeen"],
        pending_sends=pending_sends,
    )
|
|
171
222
|
|
|
172
|
-
def _decode_config(self, raw: Dict[str, Any]) -> Optional[RunnableConfig]:
|
|
223
|
+
def _decode_config(self, raw: Dict[str, Any] | None) -> Optional[RunnableConfig]:
|
|
173
224
|
if not raw:
|
|
174
225
|
return None
|
|
175
226
|
return RunnableConfig(
|
|
@@ -179,7 +230,7 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
|
|
|
179
230
|
run_name=raw.get("run_name"),
|
|
180
231
|
max_concurrency=raw.get("max_concurrency"),
|
|
181
232
|
recursion_limit=raw.get("recursion_limit"),
|
|
182
|
-
configurable=self._decode_configurable(raw.get("configurable"))
|
|
233
|
+
configurable=self._decode_configurable(raw.get("configurable") or {}),
|
|
183
234
|
)
|
|
184
235
|
|
|
185
236
|
def _decode_configurable(self, raw: Dict[str, Any]) -> dict[str, Any]:
|
|
@@ -187,41 +238,52 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
|
|
|
187
238
|
"thread_id": raw.get("threadId"),
|
|
188
239
|
"thread_ts": raw.get("threadTs"),
|
|
189
240
|
"checkpoint_ns": raw.get("checkpointNs"),
|
|
190
|
-
"checkpoint_id": raw.get("checkpointId")
|
|
241
|
+
"checkpoint_id": raw.get("checkpointId"),
|
|
191
242
|
}
|
|
192
243
|
|
|
193
|
-
# metadata (de)ser
|
|
244
|
+
# metadata (de)ser — passed through as is (a JSON-compatible dict)
|
|
194
245
|
def _enc_meta(self, md: CheckpointMetadata) -> CheckpointMetadata:
    """Return ``md`` unchanged, or an empty dict when ``md`` is falsy."""
    if not md:
        return {}
    return md
|
|
199
247
|
|
|
200
248
|
def _dec_meta(self, md: Any) -> Any:
|
|
201
|
-
|
|
202
|
-
return {k: self._dec_meta(v) for k, v in md.items()}
|
|
203
|
-
return self._safe_load(md)
|
|
249
|
+
return md
|
|
204
250
|
|
|
205
251
|
# ------------------------ HTTP wrapper ---------------------------
|
|
206
252
|
async def _http(self, method: str, path: str, **kw) -> httpx.Response:
    """Send a request through the shared client, encoding ``json`` with orjson.

    If ``json`` is given, it is serialized with ``orjson.dumps`` and sent as
    the raw request body. All other keyword arguments are forwarded to
    ``self._client.request`` untouched.
    """
    if "json" in kw:
        body = orjson.dumps(kw.pop("json"))
        # httpx deprecates passing raw bytes via ``data=``; ``content=`` is
        # the supported way to send a pre-encoded body.
        kw["content"] = body
        # NOTE(review): no Content-Type is set here — presumably the client's
        # default headers already carry it; confirm in __init__.
        if logger.isEnabledFor(logging.DEBUG):
            # Decode only when the message will actually be emitted.
            logger.debug("AGW HTTP payload: %s", body.decode())
    return await self._client.request(method, path, **kw)
|
|
213
258
|
|
|
214
259
|
# -------------------- api -> CheckpointTuple ----------------------
|
|
215
260
|
def _to_tuple(self, node: Dict[str, Any]) -> CheckpointTuple:
    """Build a ``CheckpointTuple`` from one API response node.

    ``pendingWrites`` entries may be dicts with ``first`` (task id),
    ``second`` (channel) and an optional ``third`` (typed value), or
    three-item sequences in the same order. Entries of any other shape are
    skipped. ``pending_writes`` stays ``None`` when the node has none.
    """
    pending_writes = None
    entries = node.get("pendingWrites")
    if entries:
        pending_writes = []
        for entry in entries:
            if isinstance(entry, dict) and "first" in entry and "second" in entry:
                pending_writes.append(
                    (
                        entry["first"],
                        entry["second"],
                        self._decode_typed(entry.get("third")),
                    )
                )
                continue
            if not isinstance(entry, (list, tuple)):
                continue
            try:
                task_id, channel, typed_value = entry
                pending_writes.append(
                    (task_id, channel, self._decode_typed(typed_value))
                )
            except Exception:  # pragma: no cover
                continue

    return CheckpointTuple(
        config=self._decode_config(node.get("config")),
        checkpoint=self._decode_cp(node["checkpoint"]),
        metadata=self._dec_meta(node.get("metadata")),
        parent_config=self._decode_config(node.get("parentConfig")),
        pending_writes=pending_writes,
    )
|
|
226
288
|
|
|
227
289
|
# =================================================================
|
|
@@ -230,7 +292,7 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
|
|
|
230
292
|
async def aget_tuple(self, cfg: RunnableConfig) -> CheckpointTuple | None:
|
|
231
293
|
cid = get_checkpoint_id(cfg)
|
|
232
294
|
api_cfg = self._to_api_config(cfg)
|
|
233
|
-
tid = api_cfg
|
|
295
|
+
tid = api_cfg.get("threadId")
|
|
234
296
|
|
|
235
297
|
if cid:
|
|
236
298
|
path = f"/checkpoint/{tid}/{cid}"
|
|
@@ -242,9 +304,7 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
|
|
|
242
304
|
resp = await self._http("GET", path, params=params)
|
|
243
305
|
logger.debug("AGW aget_tuple response: %s", resp.text)
|
|
244
306
|
|
|
245
|
-
if not resp.text:
|
|
246
|
-
return None
|
|
247
|
-
if resp.status_code in (404, 406):
|
|
307
|
+
if not resp.text or resp.status_code in (404, 406):
|
|
248
308
|
return None
|
|
249
309
|
resp.raise_for_status()
|
|
250
310
|
return self._to_tuple(resp.json())
|
|
@@ -279,7 +339,7 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
|
|
|
279
339
|
payload = {
|
|
280
340
|
"config": self._to_api_config(cfg),
|
|
281
341
|
"checkpoint": self._encode_cp(cp),
|
|
282
|
-
"metadata":
|
|
342
|
+
"metadata": get_checkpoint_metadata(cfg, metadata),
|
|
283
343
|
"newVersions": new_versions,
|
|
284
344
|
}
|
|
285
345
|
resp = await self._http("POST", "/checkpoint", json=payload)
|
|
@@ -294,7 +354,7 @@ class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
|
|
|
294
354
|
task_id: str,
|
|
295
355
|
task_path: str = "",
|
|
296
356
|
) -> None:
|
|
297
|
-
enc = [{"first": ch, "second": self.
|
|
357
|
+
enc = [{"first": ch, "second": self._encode_typed(v)} for ch, v in writes]
|
|
298
358
|
payload = {
|
|
299
359
|
"config": self._to_api_config(cfg),
|
|
300
360
|
"writes": enc,
|
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import re
|
|
3
|
+
import uuid
|
|
4
|
+
from base64 import b64encode
|
|
5
|
+
from collections import deque
|
|
6
|
+
from collections.abc import Mapping
|
|
7
|
+
from datetime import timedelta, timezone
|
|
8
|
+
from decimal import Decimal
|
|
9
|
+
from ipaddress import (
|
|
10
|
+
IPv4Address,
|
|
11
|
+
IPv4Interface,
|
|
12
|
+
IPv4Network,
|
|
13
|
+
IPv6Address,
|
|
14
|
+
IPv6Interface,
|
|
15
|
+
IPv6Network,
|
|
16
|
+
)
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
from re import Pattern
|
|
19
|
+
from typing import Any, NamedTuple
|
|
20
|
+
from zoneinfo import ZoneInfo
|
|
21
|
+
|
|
22
|
+
import cloudpickle
|
|
23
|
+
import orjson
|
|
24
|
+
import logging
|
|
25
|
+
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class Fragment(NamedTuple):
    """Single-field tuple wrapping a ``bytes`` buffer."""

    # Raw bytes: ``default`` forwards them to ``orjson.Fragment`` and
    # ``json_loads`` parses them.
    buf: bytes
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def decimal_encoder(dec_value: Decimal) -> int | float:
|
|
35
|
+
"""
|
|
36
|
+
Encodes a Decimal as int of there's no exponent, otherwise float
|
|
37
|
+
|
|
38
|
+
This is useful when we use ConstrainedDecimal to represent Numeric(x,0)
|
|
39
|
+
where a integer (but not int typed) is used. Encoding this as a float
|
|
40
|
+
results in failed round-tripping between encode and parse.
|
|
41
|
+
Our Id type is a prime example of this.
|
|
42
|
+
|
|
43
|
+
>>> decimal_encoder(Decimal("1.0"))
|
|
44
|
+
1.0
|
|
45
|
+
|
|
46
|
+
>>> decimal_encoder(Decimal("1"))
|
|
47
|
+
1
|
|
48
|
+
"""
|
|
49
|
+
if dec_value.as_tuple().exponent >= 0:
|
|
50
|
+
return int(dec_value)
|
|
51
|
+
else:
|
|
52
|
+
return float(dec_value)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def default(obj):
    """Fallback encoder for values ``orjson`` cannot serialize on its own.

    Only types outside orjson's native set need handling, see
    https://github.com/ijl/orjson#serialize. Checks run top to bottom and the
    first match wins; unrecognised objects yield ``None``.
    """
    if isinstance(obj, Fragment):
        return orjson.Fragment(obj.buf)

    # Duck-typed "convert to dict" hooks, tried in priority order. Classes
    # are excluded because their hooks are unbound.
    if not isinstance(obj, type):
        for hook_name in ("model_dump", "dict", "_asdict"):
            hook = getattr(obj, hook_name, None)
            if callable(hook):
                return hook()

    if isinstance(obj, BaseException):
        return {"error": type(obj).__name__, "message": str(obj)}
    if isinstance(obj, (set, frozenset, deque)):
        return list(obj)
    if isinstance(obj, (timezone, ZoneInfo)):
        return obj.tzname(None)
    if isinstance(obj, timedelta):
        return obj.total_seconds()
    if isinstance(obj, Decimal):
        return decimal_encoder(obj)

    stringified = (
        uuid.UUID,
        IPv4Address,
        IPv4Interface,
        IPv4Network,
        IPv6Address,
        IPv6Interface,
        IPv6Network,
        Path,
    )
    if isinstance(obj, stringified):
        return str(obj)
    if isinstance(obj, Pattern):
        return obj.pattern
    if isinstance(obj, (bytes, bytearray)):
        return b64encode(obj).decode()
    return None
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
_option = orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_NON_STR_KEYS
|
|
105
|
+
|
|
106
|
+
_SURROGATE_RE = re.compile(r"[\ud800-\udfff]")
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _strip_surr(s: str) -> str:
|
|
110
|
+
return s if _SURROGATE_RE.search(s) is None else _SURROGATE_RE.sub("", s)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def _sanitise(o: Any) -> Any:
|
|
114
|
+
if isinstance(o, str):
|
|
115
|
+
return _strip_surr(o)
|
|
116
|
+
if isinstance(o, Mapping):
|
|
117
|
+
return {_sanitise(k): _sanitise(v) for k, v in o.items()}
|
|
118
|
+
if isinstance(o, list | tuple | set):
|
|
119
|
+
ctor = list if isinstance(o, list) else type(o)
|
|
120
|
+
return ctor(_sanitise(x) for x in o)
|
|
121
|
+
return o
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def json_dumpb(obj) -> bytes:
    """Serialize ``obj`` to JSON bytes with escaped NUL sequences removed.

    If orjson rejects the input because of surrogate characters, the input
    is sanitised and dumped once more. Any other ``TypeError`` propagates.
    """
    try:
        raw = orjson.dumps(obj, default=default, option=_option)
    except TypeError as exc:
        if "surrogates not allowed" in str(exc):
            raw = orjson.dumps(_sanitise(obj), default=default, option=_option)
        else:
            raise

    # Removing only the single-escaped form from the dumped bytes can leave
    # an orphaned backslash, which makes the JSON invalid. So the
    # double-backslash form is deleted first, then the single-escaped one.
    without_double_escaped = raw.replace(rb"\\u0000", b"")
    return without_double_escaped.replace(rb"\u0000", b"")
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def json_loads(content: bytes | Fragment | dict) -> Any:
    """Parse JSON ``content`` with orjson.

    A ``Fragment`` is unwrapped to its buffer first; a dict is returned as
    is, without parsing.
    """
    payload = content.buf if isinstance(content, Fragment) else content
    if isinstance(payload, dict):
        return payload
    return orjson.loads(payload)
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
async def ajson_loads(content: bytes | Fragment) -> Any:
    """Run ``json_loads(content)`` in a worker thread and return its result."""
    parsed = await asyncio.to_thread(json_loads, content)
    return parsed
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class Serializer(JsonPlusSerializer):
    """``JsonPlusSerializer`` with a cloudpickle fallback and NUL-escape stripping.

    NOTE(review): ``loads_typed`` unpickles payloads tagged ``"pickle"``.
    Unpickling executes arbitrary code, so stored checkpoints must come from
    a trusted source — confirm this holds for the backing service.
    """

    def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
        """Serialize via the parent; on ``TypeError`` fall back to cloudpickle."""
        try:
            typed = super().dumps_typed(obj)
        except TypeError:
            typed = ("pickle", cloudpickle.dumps(obj))
        return typed

    def dumps(self, obj: Any) -> bytes:
        """Serialize via the parent and remove escaped NUL sequences."""
        encoded = super().dumps(obj)
        # Double-backslash form first, then the single-escaped form (same
        # order as in json_dumpb).
        for needle in (rb"\\u0000", rb"\u0000"):
            encoded = encoded.replace(needle, b"")
        return encoded

    def loads_typed(self, data: tuple[str, bytes]) -> Any:
        """Load a typed pair; ``"pickle"`` payloads go through cloudpickle.

        A payload that fails to unpickle is logged and replaced with ``None``.
        """
        if data[0] != "pickle":
            return super().loads_typed(data)
        try:
            return cloudpickle.loads(data[1])
        except Exception as exc:
            logger.warning(
                "Failed to unpickle object, replacing w None", exc_info=exc
            )
            return None
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: agent-lab-sdk
|
|
3
|
-
Version: 0.1.
|
|
3
|
+
Version: 0.1.36
|
|
4
4
|
Summary: SDK для работы с Agent Lab
|
|
5
5
|
Author-email: Andrew Ohurtsov <andermirik@yandex.com>
|
|
6
6
|
License: Proprietary and Confidential — All Rights Reserved
|
|
@@ -25,6 +25,7 @@ Requires-Dist: prometheus-client
|
|
|
25
25
|
Requires-Dist: langchain
|
|
26
26
|
Requires-Dist: httpx
|
|
27
27
|
Requires-Dist: orjson
|
|
28
|
+
Requires-Dist: cloudpickle
|
|
28
29
|
Dynamic: license-file
|
|
29
30
|
|
|
30
31
|
# Agent Lab SDK
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
agent_lab_sdk/__init__.py,sha256=1Dlmv-wuz1QuciymKtYtX7jXzr_fkeGTe7aENfEDl3E,108
|
|
2
2
|
agent_lab_sdk/langgraph/checkpoint/__init__.py,sha256=DnKwR1LwbaQ3qhb124lE-tnojrUIVcCdNzHEHwgpL5M,86
|
|
3
|
-
agent_lab_sdk/langgraph/checkpoint/agw_saver.py,sha256=
|
|
3
|
+
agent_lab_sdk/langgraph/checkpoint/agw_saver.py,sha256=QeYUAGEldw9SNXup3FRI7gNcGXYQDecKhoF1NyXa7yQ,16365
|
|
4
|
+
agent_lab_sdk/langgraph/checkpoint/serde.py,sha256=UTSYbTbhBeL1CAr-XMbaH3SSIx9TeiC7ak22duXvqkw,5175
|
|
4
5
|
agent_lab_sdk/llm/__init__.py,sha256=Yo9MbYdHS1iX05A9XiJGwWN1Hm4IARGav9mNFPrtDeA,376
|
|
5
6
|
agent_lab_sdk/llm/agw_token_manager.py,sha256=_bPPI8muaEa6H01P8hHQOJHiiivaLd8N_d3OT9UT_80,4787
|
|
6
7
|
agent_lab_sdk/llm/gigachat_token_manager.py,sha256=nlOxHcwJovsmM4cpI4fwMrYjoSeMjelDaHTipXsrUuA,8282
|
|
@@ -13,8 +14,8 @@ agent_lab_sdk/schema/input_types.py,sha256=e75nRW7Dz_RHk5Yia8DkFfbqMafsLQsQrJPfz
|
|
|
13
14
|
agent_lab_sdk/schema/log_message.py,sha256=nadi6lZGRuDSPmfbYs9QPpRJUT9Pfy8Y7pGCvyFF5Mw,638
|
|
14
15
|
agent_lab_sdk/storage/__init__.py,sha256=ik1_v1DMTwehvcAEXIYxuvLuCjJCa3y5qAuJqoQpuSA,81
|
|
15
16
|
agent_lab_sdk/storage/storage.py,sha256=ELpt7GRwFD-aWa6ctinfA_QwcvzWLvKS0Wz8FlxVqAs,2075
|
|
16
|
-
agent_lab_sdk-0.1.
|
|
17
|
-
agent_lab_sdk-0.1.
|
|
18
|
-
agent_lab_sdk-0.1.
|
|
19
|
-
agent_lab_sdk-0.1.
|
|
20
|
-
agent_lab_sdk-0.1.
|
|
17
|
+
agent_lab_sdk-0.1.36.dist-info/licenses/LICENSE,sha256=_TRXHkF3S9ilWBPdZcHLI_S-PRjK0L_SeOb2pcPAdV4,417
|
|
18
|
+
agent_lab_sdk-0.1.36.dist-info/METADATA,sha256=p5wE_ZTl7073DHAqLOI7DPb-4EGI_E44w8TU11K2960,17911
|
|
19
|
+
agent_lab_sdk-0.1.36.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
20
|
+
agent_lab_sdk-0.1.36.dist-info/top_level.txt,sha256=E1efqkJ89KNmPBWdLzdMHeVtH0dYyCo4fhnSb81_15I,14
|
|
21
|
+
agent_lab_sdk-0.1.36.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|