agent-lab-sdk 0.1.3__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of agent-lab-sdk might be problematic.
- agent_lab_sdk/__init__.py +3 -0
- agent_lab_sdk/langgraph/checkpoint/__init__.py +3 -0
- agent_lab_sdk/langgraph/checkpoint/agw_saver.py +381 -0
- agent_lab_sdk/llm/__init__.py +14 -0
- agent_lab_sdk/llm/agw_token_manager.py +97 -0
- agent_lab_sdk/llm/gigachat_token_manager.py +156 -0
- agent_lab_sdk/llm/llm.py +28 -0
- agent_lab_sdk/llm/throttled.py +177 -0
- agent_lab_sdk/metrics/__init__.py +2 -0
- agent_lab_sdk/metrics/metrics.py +104 -0
- agent_lab_sdk/storage/__init__.py +3 -0
- agent_lab_sdk/storage/storage.py +64 -0
- {agent_lab_sdk-0.1.3.dist-info → agent_lab_sdk-0.1.4.dist-info}/METADATA +5 -1
- agent_lab_sdk-0.1.4.dist-info/RECORD +17 -0
- agent_lab_sdk-0.1.3.dist-info/RECORD +0 -6
- {agent_lab_sdk-0.1.3.dist-info → agent_lab_sdk-0.1.4.dist-info}/WHEEL +0 -0
- {agent_lab_sdk-0.1.3.dist-info → agent_lab_sdk-0.1.4.dist-info}/licenses/LICENSE +0 -0
- {agent_lab_sdk-0.1.3.dist-info → agent_lab_sdk-0.1.4.dist-info}/top_level.txt +0 -0
agent_lab_sdk/langgraph/checkpoint/agw_saver.py
ADDED
@@ -0,0 +1,381 @@
+from __future__ import annotations
+
+import json
+from random import random
+from langgraph.checkpoint.serde.types import ChannelProtocol
+import asyncio
+import base64
+from contextlib import asynccontextmanager
+from typing import Any, AsyncIterator, Dict, Iterator, Optional, Sequence, Tuple
+import logging
+
+import requests
+from langchain_core.runnables import RunnableConfig
+
+from langgraph.checkpoint.base import (
+    WRITES_IDX_MAP,
+    BaseCheckpointSaver,
+    ChannelVersions,
+    Checkpoint,
+    CheckpointMetadata,
+    CheckpointTuple,
+    get_checkpoint_id,
+    get_checkpoint_metadata,
+)
+from langgraph.checkpoint.serde.base import SerializerProtocol
+
+__all__ = ["AsyncAGWCheckpointSaver"]
+
+logger = logging.getLogger(__name__)
+
+# ------------------------------------------------------------------ #
+# helpers for Py < 3.10
+# ------------------------------------------------------------------ #
+try:
+    anext  # type: ignore[name-defined]
+except NameError:  # pragma: no cover
+
+    async def anext(it):
+        return await it.__anext__()
+
+
+class AsyncAGWCheckpointSaver(BaseCheckpointSaver):
+    """Persist checkpoints in Agent-Gateway using `requests` + threads."""
+
+    # ---------------------------- init / ctx -------------------------
+    def __init__(
+        self,
+        base_url: str = "http://localhost",
+        *,
+        serde: SerializerProtocol | None = None,
+        timeout: int | float = 10,
+        api_key: str | None = None,
+        extra_headers: Dict[str, str] | None = None,
+    ):
+        super().__init__(serde=serde)
+        self.base_url = base_url.rstrip("/")
+        self.timeout = timeout
+        self._session = requests.Session()
+        self.loop = asyncio.get_running_loop()
+
+        self.headers: Dict[str, str] = {
+            "Accept": "application/json",
+            "Content-Type": "application/json",
+        }
+        if extra_headers:
+            self.headers.update(extra_headers)
+        if api_key:
+            self.headers["Authorization"] = f"Bearer {api_key}"
+
+    async def __aenter__(self):  # noqa: D401
+        return self
+
+    async def __aexit__(self, exc_type, exc, tb):  # noqa: D401
+        await asyncio.to_thread(self._session.close)
+
+    # ----------------------- universal dump/load ---------------------
+    # def _safe_dump(self, obj: Any) -> Any:
+    #     """self.serde.dump → a guaranteed JSON string."""
+    #     dumped = self.serde.dumps(obj)
+    #     if isinstance(dumped, (bytes, bytearray)):
+    #         return base64.b64encode(dumped).decode()  # str
+    #     return dumped  # already JSON-compatible
+
+    def _safe_dump(self, obj: Any) -> Any:
+        """bytes → Python object; base64 fallback for genuinely binary data."""
+        dumped = self.serde.dumps(obj)
+        if isinstance(dumped, (bytes, bytearray)):
+            try:
+                # 1) bytes → str
+                s = dumped.decode()
+                # 2) JSON str → Python (list/dict/scalar)
+                return json.loads(s)
+            except (UnicodeDecodeError, json.JSONDecodeError):
+                # not UTF-8 or not JSON → base64
+                return base64.b64encode(dumped).decode()
+        return dumped
+
+    def _safe_load(self, obj: Any) -> Any:
+        if isinstance(obj, (dict, list)):  # already-unpacked JSON
+            return self.serde.loads(json.dumps(obj, ensure_ascii=False).encode())
+        if isinstance(obj, str):
+            # try a plain JSON string first
+            try:
+                return self.serde.loads(obj.encode())
+            except Exception:
+                # possibly base64
+                try:
+                    return self.serde.loads(base64.b64decode(obj))
+                except Exception:
+                    return obj
+        try:
+            return self.serde.loads(obj)
+        except Exception:
+            return obj
+
+    # def _safe_load(self, obj: Any) -> Any:
+    #     """Inverse of _safe_dump."""
+    #     if isinstance(obj, str):
+    #         try:
+    #             return self.serde.load(base64.b64decode(obj))
+    #         except Exception:
+    #             # not base64, a plain string
+    #             return self.serde.load(obj)
+    #     return self.serde.load(obj)
+
+    # ----------------------- config <-> api --------------------------
+    def _to_api_config(self, cfg: RunnableConfig | None) -> Dict[str, Any]:
+        if not cfg:
+            return {}
+        c = cfg.get("configurable", {})
+        res: Dict[str, Any] = {
+            "threadId": c.get("thread_id", ""),
+            "checkpointNs": c.get("checkpoint_ns", ""),
+        }
+        if cid := c.get("checkpoint_id"):
+            res["checkpointId"] = cid
+        if ts := c.get("thread_ts"):
+            res["threadTs"] = ts
+        return res
+
+    # --------------------- checkpoint (de)ser ------------------------
+    def _encode_cp(self, cp: Checkpoint) -> Dict[str, Any]:
+        return {
+            "v": cp["v"],
+            "id": cp["id"],
+            "ts": cp["ts"],
+            "channelValues": {k: self._safe_dump(v) for k, v in cp["channel_values"].items()},
+            "channelVersions": cp["channel_versions"],
+            "versionsSeen": cp["versions_seen"],
+            "pendingSends": cp.get("pending_sends", []),
+        }
+
+    def _decode_cp(self, raw: Dict[str, Any]) -> Checkpoint:
+        return Checkpoint(
+            v=raw["v"],
+            id=raw["id"],
+            ts=raw["ts"],
+            channel_values={k: self._safe_load(v) for k, v in raw["channelValues"].items()},
+            channel_versions=raw["channelVersions"],
+            versions_seen=raw["versionsSeen"],
+            pending_sends=raw.get("pendingSends", []),
+        )
+
+    def _decode_config(self, raw: Dict[str, Any]) -> Optional[RunnableConfig]:
+        if not raw:
+            return None
+        return RunnableConfig(
+            tags=raw.get("tags"),
+            metadata=raw.get("metadata"),
+            callbacks=raw.get("callbacks"),
+            run_name=raw.get("run_name"),
+            max_concurrency=raw.get("max_concurrency"),
+            recursion_limit=raw.get("recursion_limit"),
+            configurable=self._decode_configurable(raw.get("configurable"))
+        )
+
+    def _decode_configurable(self, raw: Dict[str, Any]) -> dict[str, Any]:
+        return {
+            "thread_id": raw.get("threadId"),
+            "thread_ts": raw.get("threadTs"),
+            "checkpoint_ns": raw.get("checkpointNs"),
+            "checkpoint_id": raw.get("checkpointId")
+        }
+
+    # metadata (de)ser
+    def _enc_meta(self, md: CheckpointMetadata) -> CheckpointMetadata:
+        out: CheckpointMetadata = {}
+        for k, v in md.items():
+            out[k] = self._enc_meta(v) if isinstance(v, dict) else self._safe_dump(v)  # type: ignore[assignment]
+        return out
+
+    def _dec_meta(self, md: Any) -> Any:
+        if isinstance(md, dict):
+            return {k: self._dec_meta(v) for k, v in md.items()}
+        return self._safe_load(md)
+
+    # ------------------------ HTTP wrapper ---------------------------
+    async def _http(self, m: str, path: str, **kw) -> requests.Response:
+        url = f"{self.base_url}{path}"
+        hdr = {**self.headers, **kw.pop("headers", {})}
+
+        if "json" in kw:
+            payload = kw.pop("json")
+            kw["data"] = json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
+            logger.info(kw["data"])
+
+        return await asyncio.to_thread(
+            self._session.request, m, url, headers=hdr, timeout=self.timeout, **kw
+        )
+
+    # -------------------- api -> CheckpointTuple ----------------------
+    def _to_tuple(self, node: Dict[str, Any]) -> CheckpointTuple:
+        pending = None
+        if node.get("pendingWrites"):
+            pending = [(w["first"], w["second"], self._safe_load(w["third"])) for w in node["pendingWrites"]]
+        return CheckpointTuple(
+            config=self._decode_config(node["config"]),
+            checkpoint=self._decode_cp(node["checkpoint"]),
+            metadata=self._dec_meta(node["metadata"]),
+            parent_config=self._decode_config(node.get("parentConfig")),
+            pending_writes=pending,
+        )
+
+    # =================================================================
+    # async methods of BaseCheckpointSaver
+    # =================================================================
+    async def aget_tuple(self, cfg: RunnableConfig) -> CheckpointTuple | None:
+        cid = get_checkpoint_id(cfg)
+        api_cfg = self._to_api_config(cfg)
+        tid = api_cfg["threadId"]
+
+        if cid:
+            path = f"/checkpoint/{tid}/{cid}"
+            params = {"checkpointNs": api_cfg.get("checkpointNs", "")}
+        else:
+            path = f"/checkpoint/{tid}"
+            params = None
+
+        resp = await self._http("GET", path, params=params)
+        logger.debug("AGW aget_tuple response: %s", resp.text)
+
+        if not resp.text:
+            return None
+        if resp.status_code in (404, 406):
+            return None
+        resp.raise_for_status()
+        return self._to_tuple(resp.json())
+
+    async def alist(
+        self,
+        cfg: RunnableConfig | None,
+        *,
+        filter: Dict[str, Any] | None = None,
+        before: RunnableConfig | None = None,
+        limit: int | None = None,
+    ) -> AsyncIterator[CheckpointTuple]:
+        payload = {
+            "config": self._to_api_config(cfg) if cfg else None,
+            "filter": filter,
+            "before": self._to_api_config(before) if before else None,
+            "limit": limit,
+        }
+        resp = await self._http("POST", "/checkpoint/list", json=payload)
+        logger.debug("AGW alist response: %s", resp.text)
+        resp.raise_for_status()
+        for item in resp.json():
+            yield self._to_tuple(item)
+
+    async def aput(
+        self,
+        cfg: RunnableConfig,
+        cp: Checkpoint,
+        metadata: CheckpointMetadata,
+        new_versions: ChannelVersions,
+    ) -> RunnableConfig:
+        payload = {
+            "config": self._to_api_config(cfg),
+            "checkpoint": self._encode_cp(cp),
+            "metadata": self._enc_meta(get_checkpoint_metadata(cfg, metadata)),
+            "newVersions": new_versions,
+        }
+        resp = await self._http("POST", "/checkpoint", json=payload)
+        logger.debug("AGW aput response: %s", resp.text)
+        resp.raise_for_status()
+        return resp.json()["config"]
+
+    async def aput_writes(
+        self,
+        cfg: RunnableConfig,
+        writes: Sequence[Tuple[str, Any]],
+        task_id: str,
+        task_path: str = "",
+    ) -> None:
+        enc = [{"first": ch, "second": self._safe_dump(v)} for ch, v in writes]
+        payload = {
+            "config": self._to_api_config(cfg),
+            "writes": enc,
+            "taskId": task_id,
+            "taskPath": task_path,
+        }
+        resp = await self._http("POST", "/checkpoint/writes", json=payload)
+        logger.debug("AGW aput_writes response: %s", resp.text)
+        resp.raise_for_status()
+
+    async def adelete_thread(self, thread_id: str) -> None:
+        resp = await self._http("DELETE", f"/checkpoint/{thread_id}")
+        resp.raise_for_status()
+
+    # =================================================================
+    # sync wrappers
+    # =================================================================
+    def _run(self, coro):
+        return asyncio.run_coroutine_threadsafe(coro, self.loop).result()
+
+    def list(
+        self,
+        cfg: RunnableConfig | None,
+        *,
+        filter: Dict[str, Any] | None = None,
+        before: RunnableConfig | None = None,
+        limit: int | None = None,
+    ) -> Iterator[CheckpointTuple]:
+        aiter_ = self.alist(cfg, filter=filter, before=before, limit=limit)
+        while True:
+            try:
+                yield self._run(anext(aiter_))
+            except StopAsyncIteration:
+                break
+
+    def get_tuple(self, cfg: RunnableConfig) -> CheckpointTuple | None:
+        return self._run(self.aget_tuple(cfg))
+
+    def put(
+        self,
+        cfg: RunnableConfig,
+        cp: Checkpoint,
+        metadata: CheckpointMetadata,
+        new_versions: ChannelVersions,
+    ) -> RunnableConfig:
+        return self._run(self.aput(cfg, cp, metadata, new_versions))
+
+    def put_writes(
+        self,
+        cfg: RunnableConfig,
+        writes: Sequence[Tuple[str, Any]],
+        task_id: str,
+        task_path: str = "",
+    ) -> None:
+        self._run(self.aput_writes(cfg, writes, task_id, task_path))
+
+    def delete_thread(self, thread_id: str) -> None:
+        self._run(self.adelete_thread(thread_id))
+
+    def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> str:
+        if current is None:
+            current_v = 0
+        elif isinstance(current, int):
+            current_v = current
+        else:
+            current_v = int(current.split(".")[0])
+        next_v = current_v + 1
+        next_h = random()
+        return f"{next_v:032}.{next_h:016}"
+
+    # ------------------------------------------------------------------ #
+    # Convenience factory                                                 #
+    # ------------------------------------------------------------------ #
+    @classmethod
+    @asynccontextmanager
+    async def from_base_url(
+        cls,
+        base_url: str,
+        *,
+        api_key: str | None = None,
+        **kwargs: Any,
+    ) -> AsyncIterator["AsyncAGWCheckpointSaver"]:
+        saver = cls(base_url, api_key=api_key, **kwargs)
+        try:
+            yield saver
+        finally:
+            await asyncio.to_thread(saver._session.close)
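Because __init__ captures the running event loop via asyncio.get_running_loop(), the saver must be constructed inside an active loop, and from_base_url also closes the underlying requests.Session on exit. A minimal usage sketch, assuming a reachable gateway and that the checkpoint package's __init__ re-exports the class; the URL and key below are placeholders, not values shipped with the package:

import asyncio
from agent_lab_sdk.langgraph.checkpoint import AsyncAGWCheckpointSaver

async def main() -> None:
    # from_base_url closes the underlying requests.Session on exit
    async with AsyncAGWCheckpointSaver.from_base_url(
        "http://localhost:8080",   # placeholder gateway address
        api_key="example-key",     # placeholder credential
    ) as saver:
        cfg = {"configurable": {"thread_id": "demo-thread", "checkpoint_ns": ""}}
        # Issues GET /checkpoint/demo-thread; returns None when nothing is stored yet
        print(await saver.aget_tuple(cfg))

asyncio.run(main())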
agent_lab_sdk/llm/__init__.py
ADDED
@@ -0,0 +1,14 @@
+# agent_lab_sdk/llm/__init__.py
+
+from .llm import get_model
+from .gigachat_token_manager import GigaChatTokenManager
+from .agw_token_manager import AgwTokenManager
+from .throttled import ThrottledGigaChat, ThrottledGigaChatEmbeddings
+
+__all__ = [
+    "get_model",
+    "GigaChatTokenManager",
+    "AgwTokenManager",
+    "ThrottledGigaChat",
+    "ThrottledGigaChatEmbeddings",
+]
agent_lab_sdk/llm/agw_token_manager.py
ADDED
@@ -0,0 +1,97 @@
+from typing import Any, Optional
+
+import requests
+import logging
+import os
+import threading
+import urllib.parse
+import json
+import random
+import time
+from datetime import datetime, timedelta, timezone
+
+logger = logging.getLogger(__name__)
+
+# Agent-Gateway token-provider URL (can be overridden via ENV)
+TOKEN_PROVIDER_AGW_URL = os.environ.get("TOKEN_PROVIDER_AGW_URL", "https://agent-gateway.apps.advosd.sberdevices.ru")
+# Maximum number of token-fetch attempts (default 3, can be overridden via ENV)
+TOKEN_PROVIDER_AGW_DEFAULT_MAX_RETRIES = int(os.environ.get("TOKEN_PROVIDER_AGW_DEFAULT_MAX_RETRIES", 3))
+# Timeout for the AGW response, in seconds
+TOKEN_PROVIDER_AGW_TIMEOUT_SEC = int(os.environ.get("TOKEN_PROVIDER_AGW_TIMEOUT_SEC", 5))
+# Random delay between retries: 1 to 5 seconds
+BACKOFF_MIN = 1
+BACKOFF_MAX = 5
+# Random token-refresh threshold: 0 to 300 seconds (5 minutes)
+REFRESH_WINDOW_MAX = 300
+
+class AgwTokenManager:
+    # Lock for synchronizing threads within one process
+    _thread_lock = threading.Lock()
+    _tokens: dict[str, Any] = {}
+
+    @staticmethod
+    def _get_new_token(provider: str, token_type: Optional[str] = None) -> dict[str, Any]:
+        req_url = urllib.parse.urljoin(TOKEN_PROVIDER_AGW_URL, "/tokens")
+        params = {"provider": provider}
+        if token_type:
+            params["type"] = token_type
+
+        max_retries = TOKEN_PROVIDER_AGW_DEFAULT_MAX_RETRIES
+        for attempt in range(1, max_retries + 1):
+            try:
+                logger.info(f"Requesting a token from AGW ({attempt}/{max_retries})...")
+                resp = requests.post(req_url, params=params, data={}, verify=False, timeout=TOKEN_PROVIDER_AGW_TIMEOUT_SEC)
+                resp.raise_for_status()
+                result = resp.json()
+
+                token = result.get("token")
+                expires_in = result.get("expiresIn")
+
+                if not token or not expires_in:
+                    raise ValueError(f"Unexpected API response: {resp.text}")
+
+                expiry_time = datetime.fromtimestamp(expires_in, tz=timezone.utc) - timedelta(seconds=60)
+                logger.info("Token obtained successfully, valid until %s", expiry_time.isoformat())
+                return {
+                    "token": token,
+                    "expiry_ts": expires_in
+                }
+
+            except Exception as e:  # requests.RequestException, ValueError, json.JSONDecodeError, ...
+                logger.error("Error on attempt %d: %s", attempt, e)
+                if attempt < max_retries:
+                    delay = random.uniform(BACKOFF_MIN, BACKOFF_MAX)
+                    logger.info("Retrying in %.2f s...", delay)
+                    time.sleep(delay)
+                else:
+                    logger.error("All %d token-fetch attempts failed.", max_retries)
+                    raise
+
+    @classmethod
+    def get_token(cls, provider: str, token_type: Optional[str] = None) -> str:
+        """
+        Returns a valid token, refreshing it when necessary.
+        """
+        key = "{}_{}".format(provider, token_type if token_type is not None else "default")
+        # Lock between threads
+        with cls._thread_lock:
+            by_provider = cls._tokens.get(key)
+            if by_provider:
+                token = by_provider.get("token")
+                expiry_ts = by_provider.get("expiry_ts")
+                now = datetime.now(timezone.utc)
+
+                if token and expiry_ts:
+                    expiry_time = datetime.fromtimestamp(expiry_ts, tz=timezone.utc)
+                    refresh_window = timedelta(seconds=random.uniform(0, REFRESH_WINDOW_MAX))
+                    time_left = expiry_time - now
+                    logger.debug("Time left: %s, refresh threshold: %s", time_left, refresh_window)
+                    if time_left > refresh_window:
+                        logger.debug("Using the cached token; expires %s", expiry_time.isoformat())
+                        return token
+
+
+            # Otherwise request a new token
+            new_token = cls._get_new_token(provider, token_type)
+            cls._tokens[key] = new_token
+            return new_token.get("token")
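AgwTokenManager caches one token per (provider, type) pair and re-requests it only when the remaining lifetime drops below a randomized 0 to 300 second threshold, which spreads refreshes across workers. A sketch of the caching behavior, assuming the AGW endpoint is reachable; the provider name mirrors the one used by GigaChatTokenManager below:

from agent_lab_sdk.llm import AgwTokenManager

# First call POSTs to {TOKEN_PROVIDER_AGW_URL}/tokens?provider=GIGACHAT;
# repeat calls inside the refresh window are served from the in-process cache.
token = AgwTokenManager.get_token(provider="GIGACHAT")
cached = AgwTokenManager.get_token(provider="GIGACHAT")
# No second HTTP round-trip while the remaining lifetime exceeds the randomized threshold
assert token == cached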
agent_lab_sdk/llm/gigachat_token_manager.py
ADDED
@@ -0,0 +1,156 @@
+import os
+import json
+import logging
+import uuid
+import requests
+import fcntl
+import random
+import time
+import threading
+from datetime import datetime, timedelta, timezone
+from .agw_token_manager import AgwTokenManager
+import urllib3
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
+
+# Path to the token file (can be overridden via ENV)
+TOKEN_FILE_PATH = os.environ.get("GIGACHAT_TOKEN_PATH", "/tmp/gigachat_token.json")
+USE_TOKEN_PROVIDER_AGW = os.environ.get("USE_TOKEN_PROVIDER_AGW", 'False').lower() in ('true', '1', 't')
+
+# Maximum number of token-fetch attempts (default 3, can be overridden via ENV)
+DEFAULT_MAX_RETRIES = int(os.environ.get("GIGACHAT_TOKEN_FETCH_RETRIES", 3))
+# Random delay between retries: 1 to 5 seconds
+BACKOFF_MIN = 1
+BACKOFF_MAX = 5
+# Random token-refresh threshold: 0 to 300 seconds (5 minutes)
+REFRESH_WINDOW_MAX = 300
+
+class GigaChatTokenManager:
+    # Lock for synchronizing threads within one process
+    _thread_lock = threading.Lock()
+
+    @staticmethod
+    def _get_new_token():
+        """Requests a new token from the GigaChat API with retries and random backoff."""
+        if os.getenv("USE_GIGACHAT_ADVANCED", 'False').lower() in ('true', '1', 't'):
+            url = os.environ.get("GIGACHAT_BASE_URL") + "token"
+            data = None
+        else:
+            gigachat_scope = os.environ.get("GIGACHAT_SCOPE", "GIGACHAT_API_PERS")
+
+            url = "https://ngw.devices.sberbank.ru:9443/api/v2/oauth"
+            data = {"scope": gigachat_scope}
+
+        logger.debug(url)
+        gigachat_credentials = os.environ.get("GIGACHAT_CREDENTIALS")
+        if not gigachat_credentials:
+            raise ValueError("The GIGACHAT_CREDENTIALS environment variable is not set.")
+        headers = {
+            "Content-Type": "application/x-www-form-urlencoded",
+            "Accept": "application/json",
+            "User-Agent": "agent-toolkit",
+            "RqUID": str(uuid.uuid4()),  # Unique request identifier
+            "Authorization": f"Basic {gigachat_credentials}"
+        }
+
+        max_retries = DEFAULT_MAX_RETRIES
+        for attempt in range(1, max_retries + 1):
+            try:
+                logger.info(f"Fetching a GigaChat token ({attempt}/{max_retries})...")
+                resp = requests.post(url, headers=headers, data=data, verify=False)
+                resp.raise_for_status()
+                result = resp.json()
+
+                token = result.get("access_token")
+                if not token:
+                    token = result.get("tok")
+                expires_at = result.get("expires_at")
+                if not expires_at:
+                    expires_at = result.get("exp") * 1000
+                if not token or not expires_at:
+                    raise ValueError(f"Unexpected API response: {resp.text}")
+
+                # expires_at is in milliseconds
+                expiry_time = datetime.fromtimestamp(expires_at / 1000, tz=timezone.utc) - timedelta(seconds=60)
+                logger.info("Token obtained successfully, valid until %s", expiry_time.isoformat())
+                return token, expiry_time
+
+            except Exception as e:  # requests.RequestException, ValueError, json.JSONDecodeError, ...
+                logger.error("Error on attempt %d: %s", attempt, e)
+                if attempt < max_retries:
+                    delay = random.uniform(BACKOFF_MIN, BACKOFF_MAX)
+                    logger.info("Retrying in %.2f s...", delay)
+                    time.sleep(delay)
+                else:
+                    logger.error("All %d token-fetch attempts failed.", max_retries)
+                    raise
+
+    @classmethod
+    def get_token(cls) -> str:
+        """
+        Returns a valid token, refreshing it when necessary.
+        Uses an inter-process lock (fcntl.flock) plus a thread lock
+        to avoid duplicate requests.
+        """
+        if USE_TOKEN_PROVIDER_AGW:
+            provider = "GIGACHAT-ADVANCED" if os.getenv("USE_GIGACHAT_ADVANCED", 'False').lower() in ('true', '1', 't') else "GIGACHAT"
+            return AgwTokenManager.get_token(provider=provider)
+
+        # Lock between threads
+        with cls._thread_lock:
+            # Make sure the directory for the token file exists
+            dirpath = os.path.dirname(TOKEN_FILE_PATH)
+            if dirpath and not os.path.exists(dirpath):
+                os.makedirs(dirpath, exist_ok=True)
+
+            # Open (or create) the token file
+            with open(TOKEN_FILE_PATH, "a+", encoding="utf-8") as f:
+                # Inter-process file lock
+                fcntl.flock(f, fcntl.LOCK_EX)
+                try:
+                    # Read the existing data
+                    f.seek(0)
+                    try:
+                        data = json.load(f)
+                    except (json.JSONDecodeError, IOError):
+                        data = {}
+
+                    token = data.get("token")
+                    expiry_ts = data.get("expiry_timestamp")
+                    now = datetime.now(timezone.utc)
+
+                    # If a token exists and does not expire before a random threshold, return it
+                    if token and expiry_ts:
+                        expiry_time = datetime.fromtimestamp(expiry_ts, tz=timezone.utc)
+                        # generate a random threshold in seconds before expiry (0-300)
+                        refresh_window = timedelta(seconds=random.uniform(0, REFRESH_WINDOW_MAX))
+                        time_left = expiry_time - now
+                        logger.debug("Time left: %s, refresh threshold: %s", time_left, refresh_window)
+                        if time_left > refresh_window:
+                            logger.debug("Using the cached token; expires %s", expiry_time.isoformat())
+                            return token
+
+                    # Otherwise request a new token
+                    new_token, new_expiry = cls._get_new_token()
+
+                    # Persist the new token
+                    f.seek(0)
+                    f.truncate()
+                    json.dump({
+                        "token": new_token,
+                        "expiry_timestamp": new_expiry.timestamp()
+                    }, f)
+                    f.flush()
+                    logger.debug("New token saved to %s", TOKEN_FILE_PATH)
+                    return new_token
+
+                finally:
+                    # Release the inter-process lock
+                    fcntl.flock(f, fcntl.LOCK_UN)
+
+# Note: fcntl.flock provides inter-process locking,
+# while threading.Lock synchronizes threads within one process.
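A configuration sketch for the direct-OAuth path (USE_TOKEN_PROVIDER_AGW unset); the credential value is a placeholder, not a real secret:

import os

# Placeholder values; GIGACHAT_CREDENTIALS must hold the real Basic-auth blob
os.environ.setdefault("GIGACHAT_CREDENTIALS", "base64-client-id-and-secret")
os.environ.setdefault("GIGACHAT_SCOPE", "GIGACHAT_API_PERS")

from agent_lab_sdk.llm import GigaChatTokenManager

# Reads /tmp/gigachat_token.json under fcntl.flock and refreshes it when stale
token = GigaChatTokenManager.get_token()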
agent_lab_sdk/llm/llm.py
ADDED
@@ -0,0 +1,28 @@
+from langchain_gigachat.chat_models import GigaChat
+from agent_lab_sdk.llm.gigachat_token_manager import GigaChatTokenManager
+import os
+
+def get_model(**kwargs) -> GigaChat:
+    access_token = kwargs.pop("access_token", None)
+    if not access_token:
+        access_token = GigaChatTokenManager.get_token()
+    timeout = kwargs.pop("timeout", None)
+    if not timeout:
+        timeout = int(os.getenv("GLOBAL_GIGACHAT_TIMEOUT", "120"))
+
+    scope = kwargs.pop("scope", None)
+    if not scope:
+        scope = os.getenv("GIGACHAT_SCOPE")
+
+    verify_ssl_certs = kwargs.pop("verify_ssl_certs", False)
+
+    if not scope:
+        raise ValueError("GIGACHAT_SCOPE environment variable is not set.")
+
+    return GigaChat(
+        access_token=access_token,
+        verify_ssl_certs=verify_ssl_certs,
+        scope=scope,
+        timeout=timeout,
+        **kwargs
+    )
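get_model is a thin factory: it resolves the token, timeout, and scope, then forwards everything else to the GigaChat constructor. A sketch, assuming credentials are configured as above; the model name is illustrative, not a default chosen by the SDK:

from agent_lab_sdk.llm import get_model

# Extra kwargs (model, temperature, ...) pass straight through to GigaChat
model = get_model(model="GigaChat-2", temperature=0.1)
print(model.invoke("ping").content)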
agent_lab_sdk/llm/throttled.py
ADDED
@@ -0,0 +1,177 @@
+import os
+import asyncio
+import threading
+import time
+from langchain_gigachat.chat_models import GigaChat
+from langchain_gigachat import GigaChatEmbeddings
+import langchain_gigachat.embeddings.gigachat
+
+langchain_gigachat.embeddings.gigachat.MAX_BATCH_SIZE_PARTS = int(os.getenv("EMBEDDINGS_MAX_BATCH_SIZE_PARTS", "90"))
+
+MAX_CHAT_CONCURRENCY = int(os.getenv("MAX_CHAT_CONCURRENCY", "100000"))
+MAX_EMBED_CONCURRENCY = int(os.getenv("MAX_EMBED_CONCURRENCY", "100000"))
+
+
+from agent_lab_sdk.metrics import get_metric
+
+def create_metrics(prefix: str):
+    in_use = get_metric(
+        metric_type="gauge", name=f"{prefix}_slots_in_use",
+        documentation=f"Number of {prefix} slots currently in use"
+    )
+    waiting = get_metric(
+        metric_type="gauge", name=f"{prefix}_waiting_tasks",
+        documentation=f"Number of tasks waiting for {prefix}"
+    )
+    wait_time = get_metric(
+        metric_type="histogram", name=f"{prefix}_wait_time_seconds",
+        documentation=f"Time tasks wait for {prefix}",
+        buckets=[3, 5, 10, 15, 30, 60, 120, 240, 480, 960, 1920, float("inf")]
+    )
+
+    return in_use, waiting, wait_time
+
+chat_in_use, chat_waiting, chat_wait_hist = create_metrics("chat")
+embed_in_use, embed_waiting, embed_wait_hist = create_metrics("embed")
+
+class UnifiedSemaphore:
+    """Threading-based semaphore with sync/async APIs, metrics, and context managers."""
+    def __init__(self, limit, in_use, waiting, wait_hist):
+        self._sem = threading.Semaphore(limit)
+        self._limit = limit
+        self._in_use = in_use
+        self._waiting = waiting
+        self._wait_hist = wait_hist
+        self._current = 0
+
+        self._in_use.set(0)
+        self._waiting.set(0)
+
+    # --- sync API ---
+    def acquire(self):
+        self._waiting.inc()
+        start = time.time()
+
+        self._sem.acquire()
+        elapsed = time.time() - start
+        self._wait_hist.observe(elapsed)
+        self._waiting.dec()
+
+        self._current += 1
+        self._in_use.set(self._current)
+
+    def release(self):
+        self._sem.release()
+        self._current -= 1
+        self._in_use.set(self._current)
+
+    # sync context manager
+    def __enter__(self):
+        self.acquire()
+        return self
+
+    def __exit__(self, exc_type, exc, tb):
+        self.release()
+
+    # --- async API ---
+    async def acquire_async(self):
+        self._waiting.inc()
+        start = time.time()
+        loop = asyncio.get_running_loop()
+        await loop.run_in_executor(None, self._sem.acquire)
+        elapsed = time.time() - start
+        self._wait_hist.observe(elapsed)
+        self._waiting.dec()
+
+        self._current += 1
+        self._in_use.set(self._current)
+
+    async def release_async(self):
+        # release is very fast
+        self._sem.release()
+        self._current -= 1
+        self._in_use.set(self._current)
+
+    # async context manager
+    async def __aenter__(self):
+        await self.acquire_async()
+        return self
+
+    async def __aexit__(self, exc_type, exc, tb):
+        await self.release_async()
+
+# Semaphores for chat and embeddings
+_semaphores = {
+    "chat": UnifiedSemaphore(MAX_CHAT_CONCURRENCY, chat_in_use, chat_waiting, chat_wait_hist),
+    "embed": UnifiedSemaphore(MAX_EMBED_CONCURRENCY, embed_in_use, embed_waiting, embed_wait_hist),
+}
+
+class ThrottledGigaChatEmbeddings(GigaChatEmbeddings):
+    def embed_documents(self, *args, **kwargs):
+        with _semaphores["embed"]:
+            return super().embed_documents(*args, **kwargs)
+
+    def embed_query(self, *args, **kwargs):
+        # no semaphore needed here: embed_documents is called under the hood and already holds it
+        return super().embed_query(*args, **kwargs)
+
+    async def aembed_documents(self, *args, **kwargs):
+        async with _semaphores["embed"]:
+            return await super().aembed_documents(*args, **kwargs)
+
+    async def aembed_query(self, *args, **kwargs):
+        # no semaphore needed here: aembed_documents is called under the hood and already holds it
+        return await super().aembed_query(*args, **kwargs)
+
+# ideally the GigaChat client itself would be wrapped or monkey-patched, but that is not straightforward
+class ThrottledGigaChat(GigaChat):
+    def invoke(self, *args, **kwargs):
+        with _semaphores["chat"]:
+            return super().invoke(*args, **kwargs)
+
+    async def ainvoke(self, *args, **kwargs):
+        async with _semaphores["chat"]:
+            return await super().ainvoke(*args, **kwargs)
+
+    def stream(self, *args, **kwargs):
+        if super()._should_stream(async_api=False, **{**kwargs, "stream": True}):
+            with _semaphores["chat"]:
+                for chunk in super().stream(*args, **kwargs):
+                    yield chunk
+        else:
+            # problematic when stream falls back to invoke internally, so no semaphore here
+            for chunk in super().stream(*args, **kwargs):
+                yield chunk
+
+    async def astream(self, *args, **kwargs):
+        if super()._should_stream(async_api=True, **{**kwargs, "stream": True}):
+            async with _semaphores["chat"]:
+                async for chunk in super().astream(*args, **kwargs):
+                    yield chunk
+        else:
+            # problematic when astream falls back to ainvoke internally, so no semaphore here
+            async for chunk in super().astream(*args, **kwargs):
+                yield chunk
+
+    async def astream_events(self, *args, **kwargs):
+        async with _semaphores["chat"]:
+            async for ev in super().astream_events(*args, **kwargs):
+                yield ev
+
+    def batch(self, *args, **kwargs):
+        # no semaphore needed here: invoke is called under the hood and already holds it
+        return super().batch(*args, **kwargs)
+
+    async def abatch(self, *args, **kwargs):
+        # no semaphore needed here: ainvoke is called under the hood and already holds it
+        return await super().abatch(*args, **kwargs)
+
+    def batch_as_completed(self, *args, **kwargs):
+        # no semaphore needed here: invoke is called under the hood and already holds it
+        for item in super().batch_as_completed(*args, **kwargs):
+            yield item
+
+    async def abatch_as_completed(self, *args, **kwargs):
+        # no semaphore needed here: ainvoke is called under the hood and already holds it
+        async for item in super().abatch_as_completed(*args, **kwargs):
+            yield item
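Because the concurrency limits are read at module import time, they must be set in the environment before agent_lab_sdk.llm is first imported. A sketch assuming a limit of 4; the token is a placeholder:

import os
os.environ["MAX_CHAT_CONCURRENCY"] = "4"  # must precede the first import below

from agent_lab_sdk.llm import ThrottledGigaChat

# Drop-in GigaChat subclass: at most 4 concurrent invoke/stream calls per process;
# extra callers block and show up in the chat_waiting_tasks gauge.
llm = ThrottledGigaChat(access_token="placeholder-token", verify_ssl_certs=False)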
agent_lab_sdk/metrics/metrics.py
ADDED
@@ -0,0 +1,104 @@
+import logging
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
+
+from prometheus_client import Counter, Gauge, Summary, Histogram, Info, Enum
+
+logger = logging.getLogger(__name__)
+
+_metrics: Dict[str, Any] = {}
+
+MetricType = Union[
+    Type[Counter],
+    Type[Gauge],
+    Type[Summary],
+    Type[Histogram],
+    Type[Info],
+    Type[Enum],
+]
+
+_CONSTRUCTORS: Dict[str, MetricType] = {
+    "counter": Counter,
+    "gauge": Gauge,
+    "summary": Summary,
+    "histogram": Histogram,
+    "info": Info,
+    "enum": Enum,
+}
+
+def get_metric(
+    metric_type: str,
+    name: str,
+    documentation: str,
+    *,
+    labelnames: Optional[List[str]] = None,
+    buckets: Optional[List[float]] = None,
+    states: Optional[List[str]] = None,
+    **kwargs
+) -> Any:
+    """
+    Returns the metric registered under `name`, creating it on first request.
+
+    metric_type — one of the _CONSTRUCTORS keys ('counter', 'gauge', 'summary', 'histogram', 'info', 'enum').
+    labelnames  — list of label names (for Counter, Gauge, Summary, Histogram).
+    buckets     — for histogram.
+    states      — for enum.
+    kwargs      — remaining constructor arguments.
+    """
+    if name in _metrics:
+        return _metrics[name]
+
+    ctor = _CONSTRUCTORS.get(metric_type)
+    if ctor is None:
+        raise ValueError(f"Unknown metric type: {metric_type!r}")
+
+    try:
+        # Assemble positional and keyword arguments depending on the type
+        args: Tuple = ()
+        if metric_type == "histogram" and buckets is not None:
+            # Histogram(name, doc, labelnames?, buckets=buckets, **kwargs)
+            args = (name, documentation)
+            if labelnames:
+                args += (labelnames,)
+            _metrics[name] = ctor(*args, buckets=buckets, **kwargs)
+        elif metric_type == "enum" and states is not None:
+            # Enum(name, doc, labelnames?, states=states, **kwargs)
+            args = (name, documentation)
+            if labelnames:
+                args += (labelnames,)
+            _metrics[name] = ctor(*args, states=states, **kwargs)
+        else:
+            # Counter, Gauge, Summary, Info
+            args = (name, documentation)
+            if labelnames:
+                args += (labelnames,)
+            _metrics[name] = ctor(*args, **kwargs)
+
+    except Exception as e:
+        logger.error("Failed to create %s metric '%s': %s", metric_type, name, e)
+        return None
+
+    return _metrics[name]
+
+
+if __name__ == "__main__":
+
+    # create a counter
+    reqs = get_metric("counter", "http_requests_total",
+                      "Total HTTP requests", labelnames=["method", "endpoint"])
+
+    # increment
+    if reqs:
+        reqs.labels("GET", "/api").inc()
+
+    # create a histogram
+    lat = get_metric("histogram", "http_request_latency_seconds",
+                     "HTTP request duration", buckets=[0.1, 0.5, 1.0, 5.0])
+
+    # time a block via the context manager
+    if lat:
+        with lat.time():
+            import time
+            time.sleep(0.5)  # simulate a slow request
+
+    print(reqs.collect())
+    print(lat.collect())
agent_lab_sdk/storage/storage.py
ADDED
@@ -0,0 +1,64 @@
+import base64
+import os
+import mimetypes
+from typing import Optional
+from io import BytesIO
+import requests
+
+def store_file_in_sd_asset(filename: str, file_base64: str, folder: str = "giga-agents") -> Optional[str]:
+    """
+    Uploads a base64-encoded file to the SD Asset API and returns the uploaded file's URL.
+
+    Args:
+        filename: name of the uploaded file (also used to guess its MIME type)
+        file_base64: the file contents as a base64 string
+        folder: target folder name
+
+    Returns:
+        URL of the uploaded file, or None on error
+    """
+    # Decode base64 into binary data
+    try:
+        file_data = base64.b64decode(file_base64.split(",")[-1])
+    except Exception as e:
+        print(f"base64 decoding error: {e}")
+        return None
+
+    # API URL
+    url = "https://asset.tools.sberdevices.ru/api/file/upload"
+
+    api_key = os.getenv("SD_ASSET_API_KEY")
+    if not api_key:
+        raise ValueError("SD_ASSET_API_KEY is missing")
+
+    # Request headers
+    headers = {
+        "X-Api-Key": api_key
+    }
+
+    mimetype, _ = mimetypes.guess_type(filename)
+
+    # Form fields
+    files = {
+        "files": (f"{filename}", BytesIO(file_data), mimetype)
+    }
+
+    # Extra parameters
+    data = {
+        "folder": folder,
+        "uniqueNames": "false",
+        "removable": "true"
+    }
+
+    try:
+        # Send the POST request
+        response = requests.post(url, headers=headers, data=data, files=files)
+        response.raise_for_status()
+
+        # Extract the URL of the uploaded file
+        if response.status_code == 200 and response.json():
+            return response.json()[0].replace('cdn-app.sberdevices.ru', 'cdn-app.giga.chat')
+
+    except requests.exceptions.RequestException as e:
+        print(f"File upload error: {e}")
+
+    return None
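A usage sketch for the uploader; it assumes SD_ASSET_API_KEY is set and that storage/__init__.py re-exports the function (its three lines are not shown in this diff):

import base64
from agent_lab_sdk.storage import store_file_in_sd_asset

payload = base64.b64encode(b"hello world").decode()
# Returns the CDN URL (rewritten to cdn-app.giga.chat) or None on upload failure
print(store_file_in_sd_asset("hello.txt", payload))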
{agent_lab_sdk-0.1.3.dist-info → agent_lab_sdk-0.1.4.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: agent-lab-sdk
-Version: 0.1.3
+Version: 0.1.4
 Summary: SDK for working with Agent Lab
 Author-email: Andrew Ohurtsov <andermirik@yandex.com>
 License: Proprietary and Confidential — All Rights Reserved
@@ -16,4 +16,8 @@ Classifier: Operating System :: OS Independent
 Requires-Python: >=3.11
 Description-Content-Type: text/markdown
 License-File: LICENSE
+Requires-Dist: requests
+Requires-Dist: langgraph~=0.4.1
+Requires-Dist: langchain_gigachat
+Requires-Dist: prometheus-client
 Dynamic: license-file
agent_lab_sdk-0.1.4.dist-info/RECORD
ADDED
@@ -0,0 +1,17 @@
+agent_lab_sdk/__init__.py,sha256=rggZCMNBJ3Vd5rZkC6kepxILIGCCAmFTPHZ-ORv97nY,90
+agent_lab_sdk/langgraph/checkpoint/__init__.py,sha256=DnKwR1LwbaQ3qhb124lE-tnojrUIVcCdNzHEHwgpL5M,86
+agent_lab_sdk/langgraph/checkpoint/agw_saver.py,sha256=ALhZW-7XUIVsbZVjCRVhZcMmcYj_f9k3UU-owaq5vxM,13797
+agent_lab_sdk/llm/__init__.py,sha256=Yo9MbYdHS1iX05A9XiJGwWN1Hm4IARGav9mNFPrtDeA,376
+agent_lab_sdk/llm/agw_token_manager.py,sha256=_bPPI8muaEa6H01P8hHQOJHiiivaLd8N_d3OT9UT_80,4787
+agent_lab_sdk/llm/gigachat_token_manager.py,sha256=TPA8cb0ypdWtRTI5C7GItL9jbLt93vR-Ijf2yMrOytQ,7921
+agent_lab_sdk/llm/llm.py,sha256=NOEH9TOH66EIJXGevxPm6w6px7Z0cZl9DJ-9A7jOnd0,873
+agent_lab_sdk/llm/throttled.py,sha256=9_nm1i3Uuep0VEWsY1KNCllZA-vM202XVdlgXhgC8BA,7005
+agent_lab_sdk/metrics/__init__.py,sha256=G4VSlzKwupPMM4c6vZaF1rnd0KusKarezDMjli9pVFw,57
+agent_lab_sdk/metrics/metrics.py,sha256=2e0c7BanThUNtCxpS6BUlAIDoLSidQsuaaBP5EB48Yo,3432
+agent_lab_sdk/storage/__init__.py,sha256=ik1_v1DMTwehvcAEXIYxuvLuCjJCa3y5qAuJqoQpuSA,81
+agent_lab_sdk/storage/storage.py,sha256=ELpt7GRwFD-aWa6ctinfA_QwcvzWLvKS0Wz8FlxVqAs,2075
+agent_lab_sdk-0.1.4.dist-info/licenses/LICENSE,sha256=_TRXHkF3S9ilWBPdZcHLI_S-PRjK0L_SeOb2pcPAdV4,417
+agent_lab_sdk-0.1.4.dist-info/METADATA,sha256=cfGy-8cVIcRC7IhJKxUCb27nPC5vnOBuSwTpTKmetcA,884
+agent_lab_sdk-0.1.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+agent_lab_sdk-0.1.4.dist-info/top_level.txt,sha256=E1efqkJ89KNmPBWdLzdMHeVtH0dYyCo4fhnSb81_15I,14
+agent_lab_sdk-0.1.4.dist-info/RECORD,,
agent_lab_sdk-0.1.3.dist-info/RECORD
DELETED
@@ -1,6 +0,0 @@
-agent_lab_sdk/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-agent_lab_sdk-0.1.3.dist-info/licenses/LICENSE,sha256=_TRXHkF3S9ilWBPdZcHLI_S-PRjK0L_SeOb2pcPAdV4,417
-agent_lab_sdk-0.1.3.dist-info/METADATA,sha256=yvcBFQxv8NoD7Gqe3mthMUkwgfJrd_pnTfafa8bZT5M,761
-agent_lab_sdk-0.1.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-agent_lab_sdk-0.1.3.dist-info/top_level.txt,sha256=E1efqkJ89KNmPBWdLzdMHeVtH0dYyCo4fhnSb81_15I,14
-agent_lab_sdk-0.1.3.dist-info/RECORD,,
{agent_lab_sdk-0.1.3.dist-info → agent_lab_sdk-0.1.4.dist-info}/WHEEL
File without changes
{agent_lab_sdk-0.1.3.dist-info → agent_lab_sdk-0.1.4.dist-info}/licenses/LICENSE
File without changes
{agent_lab_sdk-0.1.3.dist-info → agent_lab_sdk-0.1.4.dist-info}/top_level.txt
File without changes