ai-lib-python 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_lib_python/__init__.py +43 -0
- ai_lib_python/batch/__init__.py +15 -0
- ai_lib_python/batch/collector.py +244 -0
- ai_lib_python/batch/executor.py +224 -0
- ai_lib_python/cache/__init__.py +26 -0
- ai_lib_python/cache/backends.py +380 -0
- ai_lib_python/cache/key.py +237 -0
- ai_lib_python/cache/manager.py +332 -0
- ai_lib_python/client/__init__.py +37 -0
- ai_lib_python/client/builder.py +528 -0
- ai_lib_python/client/cancel.py +368 -0
- ai_lib_python/client/core.py +433 -0
- ai_lib_python/client/response.py +134 -0
- ai_lib_python/embeddings/__init__.py +36 -0
- ai_lib_python/embeddings/client.py +339 -0
- ai_lib_python/embeddings/types.py +234 -0
- ai_lib_python/embeddings/vectors.py +246 -0
- ai_lib_python/errors/__init__.py +41 -0
- ai_lib_python/errors/base.py +316 -0
- ai_lib_python/errors/classification.py +210 -0
- ai_lib_python/guardrails/__init__.py +35 -0
- ai_lib_python/guardrails/base.py +336 -0
- ai_lib_python/guardrails/filters.py +583 -0
- ai_lib_python/guardrails/validators.py +475 -0
- ai_lib_python/pipeline/__init__.py +55 -0
- ai_lib_python/pipeline/accumulate.py +248 -0
- ai_lib_python/pipeline/base.py +240 -0
- ai_lib_python/pipeline/decode.py +281 -0
- ai_lib_python/pipeline/event_map.py +506 -0
- ai_lib_python/pipeline/fan_out.py +284 -0
- ai_lib_python/pipeline/select.py +297 -0
- ai_lib_python/plugins/__init__.py +32 -0
- ai_lib_python/plugins/base.py +294 -0
- ai_lib_python/plugins/hooks.py +296 -0
- ai_lib_python/plugins/middleware.py +285 -0
- ai_lib_python/plugins/registry.py +294 -0
- ai_lib_python/protocol/__init__.py +71 -0
- ai_lib_python/protocol/loader.py +317 -0
- ai_lib_python/protocol/manifest.py +385 -0
- ai_lib_python/protocol/validator.py +460 -0
- ai_lib_python/py.typed +1 -0
- ai_lib_python/resilience/__init__.py +102 -0
- ai_lib_python/resilience/backpressure.py +225 -0
- ai_lib_python/resilience/circuit_breaker.py +318 -0
- ai_lib_python/resilience/executor.py +343 -0
- ai_lib_python/resilience/fallback.py +341 -0
- ai_lib_python/resilience/preflight.py +413 -0
- ai_lib_python/resilience/rate_limiter.py +291 -0
- ai_lib_python/resilience/retry.py +299 -0
- ai_lib_python/resilience/signals.py +283 -0
- ai_lib_python/routing/__init__.py +118 -0
- ai_lib_python/routing/manager.py +593 -0
- ai_lib_python/routing/strategy.py +345 -0
- ai_lib_python/routing/types.py +397 -0
- ai_lib_python/structured/__init__.py +33 -0
- ai_lib_python/structured/json_mode.py +281 -0
- ai_lib_python/structured/schema.py +316 -0
- ai_lib_python/structured/validator.py +334 -0
- ai_lib_python/telemetry/__init__.py +127 -0
- ai_lib_python/telemetry/exporters/__init__.py +9 -0
- ai_lib_python/telemetry/exporters/prometheus.py +111 -0
- ai_lib_python/telemetry/feedback.py +446 -0
- ai_lib_python/telemetry/health.py +409 -0
- ai_lib_python/telemetry/logger.py +389 -0
- ai_lib_python/telemetry/metrics.py +496 -0
- ai_lib_python/telemetry/tracer.py +473 -0
- ai_lib_python/tokens/__init__.py +25 -0
- ai_lib_python/tokens/counter.py +282 -0
- ai_lib_python/tokens/estimator.py +286 -0
- ai_lib_python/transport/__init__.py +34 -0
- ai_lib_python/transport/auth.py +141 -0
- ai_lib_python/transport/http.py +364 -0
- ai_lib_python/transport/pool.py +425 -0
- ai_lib_python/types/__init__.py +41 -0
- ai_lib_python/types/events.py +343 -0
- ai_lib_python/types/message.py +332 -0
- ai_lib_python/types/tool.py +191 -0
- ai_lib_python/utils/__init__.py +21 -0
- ai_lib_python/utils/tool_call_assembler.py +317 -0
- ai_lib_python-0.5.0.dist-info/METADATA +837 -0
- ai_lib_python-0.5.0.dist-info/RECORD +84 -0
- ai_lib_python-0.5.0.dist-info/WHEEL +4 -0
- ai_lib_python-0.5.0.dist-info/licenses/LICENSE-APACHE +201 -0
- ai_lib_python-0.5.0.dist-info/licenses/LICENSE-MIT +21 -0
|
@@ -0,0 +1,380 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Cache backend implementations.
|
|
3
|
+
|
|
4
|
+
Provides memory, disk, and null cache backends.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import asyncio
|
|
10
|
+
import hashlib
|
|
11
|
+
import json
|
|
12
|
+
import time
|
|
13
|
+
from abc import ABC, abstractmethod
|
|
14
|
+
from dataclasses import dataclass
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
from typing import Any
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass
class CacheEntry:
    """Single cached value plus the bookkeeping needed for expiry/eviction.

    Attributes:
        value: The stored payload.
        created_at: Unix timestamp at which the entry was stored.
        ttl: Lifetime in seconds; None means the entry never expires.
        hits: How many times this entry has been served.
    """

    value: Any
    created_at: float
    ttl: float | None = None
    hits: int = 0

    @property
    def is_expired(self) -> bool:
        """Return True once the entry has outlived its TTL (never, if ttl is None)."""
        lifetime = self.ttl
        return lifetime is not None and time.time() > self.created_at + lifetime

    @property
    def age_seconds(self) -> float:
        """Seconds elapsed since the entry was created."""
        now = time.time()
        return now - self.created_at
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class CacheBackend(ABC):
    """Interface that every cache backend implements.

    Concrete backends must provide async ``get``/``set``/``delete``/``clear``/
    ``exists``; ``close`` is optional and defaults to a no-op.
    """

    @abstractmethod
    async def get(self, key: str) -> Any | None:
        """Look up *key* and return the cached value, or None on a miss.

        Args:
            key: Cache key

        Returns:
            Cached value or None
        """
        raise NotImplementedError

    @abstractmethod
    async def set(
        self, key: str, value: Any, ttl: float | None = None
    ) -> None:
        """Store *value* under *key*, optionally expiring after *ttl* seconds.

        Args:
            key: Cache key
            value: Value to cache
            ttl: Time-to-live in seconds
        """
        raise NotImplementedError

    @abstractmethod
    async def delete(self, key: str) -> bool:
        """Remove *key* from the cache.

        Args:
            key: Cache key

        Returns:
            True if deleted, False if not found
        """
        raise NotImplementedError

    @abstractmethod
    async def clear(self) -> None:
        """Drop every entry held by the backend."""
        raise NotImplementedError

    @abstractmethod
    async def exists(self, key: str) -> bool:
        """Report whether *key* currently has an entry.

        Args:
            key: Cache key

        Returns:
            True if exists, False otherwise
        """
        raise NotImplementedError

    async def close(self) -> None:
        """Release backend resources; the default implementation does nothing."""
        pass
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class MemoryCache(CacheBackend):
    """In-memory cache backend with TTL support.

    Entries live in a plain dict guarded by an asyncio.Lock. Expiry is lazy:
    expired entries are purged when touched by get()/exists() or when room is
    needed. At capacity, all expired entries are purged first; if none are
    expired, the entry with the fewest hits (approximate LFU) is evicted.

    Example:
        >>> cache = MemoryCache(max_size=1000, default_ttl=3600)
        >>> await cache.set("key", {"data": "value"})
        >>> value = await cache.get("key")
    """

    def __init__(
        self,
        max_size: int = 1000,
        default_ttl: float | None = None,
    ) -> None:
        """Initialize memory cache.

        Args:
            max_size: Maximum number of entries
            default_ttl: Default TTL in seconds (None = entries never expire)
        """
        self._cache: dict[str, CacheEntry] = {}
        self._max_size = max_size
        self._default_ttl = default_ttl
        self._lock = asyncio.Lock()

    async def get(self, key: str) -> Any | None:
        """Get a value from the cache, dropping it if it has expired."""
        async with self._lock:
            entry = self._cache.get(key)
            if entry is None:
                return None

            if entry.is_expired:
                # Lazy expiry: purge on access rather than via a timer.
                del self._cache[key]
                return None

            entry.hits += 1
            return entry.value

    async def set(
        self, key: str, value: Any, ttl: float | None = None
    ) -> None:
        """Set a value in the cache, evicting if at capacity.

        Args:
            key: Cache key
            value: Value to cache
            ttl: Time-to-live in seconds (falls back to the default TTL)
        """
        async with self._lock:
            # Only evict when inserting a brand-new key at capacity;
            # overwriting an existing key never grows the dict.
            if len(self._cache) >= self._max_size and key not in self._cache:
                self._evict_one()

            self._cache[key] = CacheEntry(
                value=value,
                created_at=time.time(),
                ttl=ttl if ttl is not None else self._default_ttl,
            )

    async def delete(self, key: str) -> bool:
        """Delete a value from the cache.

        Returns:
            True if the key was present, False otherwise.
        """
        async with self._lock:
            # pop() gives one lookup instead of a contains-check plus del.
            return self._cache.pop(key, None) is not None

    async def clear(self) -> None:
        """Clear all cache entries."""
        async with self._lock:
            self._cache.clear()

    async def exists(self, key: str) -> bool:
        """Check whether a key exists and has not expired."""
        async with self._lock:
            entry = self._cache.get(key)
            if entry is None:
                return False
            if entry.is_expired:
                del self._cache[key]
                return False
            return True

    def _evict_one(self) -> None:
        """Make room for a new entry. Must be called with the lock held.

        Expired entries are dead weight, so all of them are purged first
        (the previous implementation built the full expired list but removed
        only one entry, leaving the rest to inflate ``size``). If nothing has
        expired, the entry with the fewest hits is evicted (approximate LFU).
        """
        if not self._cache:
            return

        expired = [k for k, v in self._cache.items() if v.is_expired]
        if expired:
            for k in expired:
                del self._cache[k]
            return

        min_hits_key = min(self._cache, key=lambda k: self._cache[k].hits)
        del self._cache[min_hits_key]

    @property
    def size(self) -> int:
        """Current number of stored entries (may include not-yet-purged expired ones)."""
        return len(self._cache)
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
class DiskCache(CacheBackend):
    """Disk-based cache backend with TTL support.

    Each entry is stored as a JSON file named after the SHA-256 of its key,
    so values must be JSON-serializable (non-serializable parts are
    stringified via ``default=str`` and will not round-trip).

    NOTE(review): ``max_size_bytes`` is accepted and reported via
    ``cache_size_bytes`` but no eviction happens when the limit is exceeded —
    confirm whether enforcement belongs in the manager layer.

    Example:
        >>> cache = DiskCache(path="/tmp/ai_cache", default_ttl=86400)
        >>> await cache.set("key", {"data": "value"})
        >>> value = await cache.get("key")
    """

    def __init__(
        self,
        path: str | Path,
        default_ttl: float | None = None,
        max_size_bytes: int = 100 * 1024 * 1024,  # 100MB default
    ) -> None:
        """Initialize disk cache.

        Args:
            path: Cache directory path (created if missing)
            default_ttl: Default TTL in seconds
            max_size_bytes: Advisory maximum cache size in bytes (not enforced)
        """
        self._path = Path(path)
        self._default_ttl = default_ttl
        self._max_size_bytes = max_size_bytes
        self._lock = asyncio.Lock()

        # Create cache directory
        self._path.mkdir(parents=True, exist_ok=True)

    def _key_to_path(self, key: str) -> Path:
        """Convert cache key to file path.

        The key is hashed with SHA-256 so arbitrary key strings map to
        filesystem-safe names.

        Args:
            key: Cache key

        Returns:
            File path
        """
        key_hash = hashlib.sha256(key.encode()).hexdigest()
        return self._path / f"{key_hash}.json"

    def _load(self, path: Path) -> dict[str, Any] | None:
        """Read and parse an entry file.

        Expired entries are unlinked as a side effect. Returns None when the
        entry is missing fields, corrupt, unreadable, or expired.
        """
        try:
            data = json.loads(path.read_text())
            ttl = data.get("ttl")
            if ttl is not None and time.time() > data["created_at"] + ttl:
                path.unlink(missing_ok=True)
                return None
            return data
        except (json.JSONDecodeError, KeyError, OSError):
            return None

    async def get(self, key: str) -> Any | None:
        """Get a value from the cache, or None on miss, expiry, or corruption."""
        path = self._key_to_path(key)

        async with self._lock:
            if not path.exists():
                return None
            data = self._load(path)
            return None if data is None else data.get("value")

    async def set(
        self, key: str, value: Any, ttl: float | None = None
    ) -> None:
        """Set a value in the cache; write failures are silently ignored.

        Args:
            key: Cache key
            value: JSON-serializable value to cache
            ttl: Time-to-live in seconds (falls back to the default TTL)
        """
        path = self._key_to_path(key)

        async with self._lock:
            data = {
                "key": key,
                "value": value,
                "created_at": time.time(),
                "ttl": ttl if ttl is not None else self._default_ttl,
            }

            try:
                path.write_text(json.dumps(data, default=str))
            except OSError:
                # Best-effort cache: a failed write must not break the caller.
                pass

    async def delete(self, key: str) -> bool:
        """Delete a value from the cache.

        Returns:
            True if an entry file was removed, False if not found.
        """
        path = self._key_to_path(key)

        async with self._lock:
            if path.exists():
                path.unlink(missing_ok=True)
                return True
            return False

    async def clear(self) -> None:
        """Remove every entry file in the cache directory."""
        async with self._lock:
            for path in self._path.glob("*.json"):
                path.unlink(missing_ok=True)

    async def exists(self, key: str) -> bool:
        """Check whether a key exists and has not expired.

        BUGFIX: previously this only tested file existence, so expired (and
        corrupt) entries reported True even though get() would return None.
        Now it agrees with get() and MemoryCache.exists, and an expired file
        is unlinked as a side effect.
        """
        path = self._key_to_path(key)

        async with self._lock:
            if not path.exists():
                return False
            return self._load(path) is not None

    async def cleanup_expired(self) -> int:
        """Remove expired entries.

        Returns:
            Number of entries removed
        """
        removed = 0
        now = time.time()

        async with self._lock:
            for path in self._path.glob("*.json"):
                try:
                    data = json.loads(path.read_text())

                    if data.get("ttl") is not None:
                        if now > data["created_at"] + data["ttl"]:
                            path.unlink(missing_ok=True)
                            removed += 1

                except (json.JSONDecodeError, KeyError, OSError):
                    # Leave unreadable files alone; get() treats them as misses.
                    pass

        return removed

    @property
    def cache_size_bytes(self) -> int:
        """Get total cache size in bytes (sum of all entry files)."""
        return sum(p.stat().st_size for p in self._path.glob("*.json"))
|
|
352
|
+
|
|
353
|
+
|
|
354
|
+
class NullCache(CacheBackend):
    """No-op backend: stores nothing and reports every lookup as a miss.

    Useful for testing or for disabling caching without changing call sites.
    """

    async def get(self, key: str) -> Any | None:  # noqa: ARG002
        """Always a miss."""
        return None

    async def set(
        self, key: str, value: Any, ttl: float | None = None
    ) -> None:
        """Discard the value."""
        return None

    async def delete(self, key: str) -> bool:  # noqa: ARG002
        """Nothing is ever stored, so nothing is ever deleted."""
        return False

    async def clear(self) -> None:
        """Nothing to clear."""
        return None

    async def exists(self, key: str) -> bool:  # noqa: ARG002
        """No key ever exists."""
        return False
|
|
@@ -0,0 +1,237 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Cache key generation utilities.
|
|
3
|
+
|
|
4
|
+
Provides deterministic cache key generation for requests.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import hashlib
|
|
10
|
+
import json
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class CacheKey:
    """A generated cache key plus the hashes it was built from.

    str(), hash() and == all operate on the ``key`` string, so a CacheKey
    can be used interchangeably with a plain str in dicts and comparisons.

    Attributes:
        key: The cache key string
        model: Model the key was generated for
        messages_hash: Hash of the normalized messages
        params_hash: Hash of the filtered parameters
    """

    key: str
    model: str = ""
    messages_hash: str = ""
    params_hash: str = ""

    def __str__(self) -> str:
        """The key string itself."""
        return self.key

    def __hash__(self) -> int:
        """Hash on the key string so CacheKey and str index dicts alike."""
        return hash(self.key)

    def __eq__(self, other: object) -> bool:
        """Compare equal to another CacheKey or a plain str with the same key."""
        if isinstance(other, CacheKey):
            return self.key == other.key
        return isinstance(other, str) and self.key == other
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class CacheKeyGenerator:
    """Builds deterministic cache keys for chat and embedding requests.

    Keys take the shape ``<prefix>[:<model>]:<messages-hash>[:<params-hash>]``
    where the hashes are truncated SHA-256 hex digests of a canonical JSON
    rendering, so identical requests always map to identical keys.

    Example:
        >>> generator = CacheKeyGenerator()
        >>> key = generator.generate(
        ...     model="gpt-4o",
        ...     messages=[{"role": "user", "content": "Hello"}],
        ...     temperature=0.7,
        ... )
        >>> print(key.key)  # "ai:gpt-4o:a1b2c3..."
    """

    def __init__(
        self,
        prefix: str = "ai",
        include_model: bool = True,
        include_params: bool = True,
        excluded_params: list[str] | None = None,
    ) -> None:
        """Initialize key generator.

        Args:
            prefix: Leading component of every key
            include_model: Whether the model name becomes a key component
            include_params: Whether hashed params become a key component
            excluded_params: Parameter names left out of the hash; defaults
                to ["user", "stream"], which don't affect model output
        """
        self._prefix = prefix
        self._include_model = include_model
        self._include_params = include_params
        self._excluded_params = set(excluded_params or ["user", "stream"])

    def generate(
        self,
        model: str,
        messages: list[dict[str, Any]],
        **params: Any,
    ) -> CacheKey:
        """Generate a cache key for a chat request.

        Args:
            model: Model name
            messages: Chat messages
            **params: Additional request parameters

        Returns:
            CacheKey instance
        """
        messages_hash = self._hash_messages(messages)

        relevant = {
            name: value
            for name, value in params.items()
            if name not in self._excluded_params
        }
        params_hash = self._hash_params(relevant)

        components = [self._prefix]
        if self._include_model:
            components.append(model)
        components.append(messages_hash[:16])
        if self._include_params and params_hash:
            components.append(params_hash[:8])

        return CacheKey(
            key=":".join(components),
            model=model,
            messages_hash=messages_hash,
            params_hash=params_hash,
        )

    def generate_for_embedding(
        self,
        model: str,
        input_text: str | list[str],
        dimensions: int | None = None,
    ) -> CacheKey:
        """Generate a cache key for an embedding request.

        Args:
            model: Model name
            input_text: Input text or list of texts
            dimensions: Output dimensions (omitted from the key when falsy)

        Returns:
            CacheKey instance
        """
        # A list of texts is canonicalized through JSON before hashing.
        if isinstance(input_text, str):
            canonical = input_text
        else:
            canonical = json.dumps(input_text, sort_keys=True)
        input_hash = self._hash_string(canonical)

        components = [self._prefix, "emb", model, input_hash[:16]]
        # NOTE: truthiness check, so dimensions=0 is treated like None.
        if dimensions:
            components.append(str(dimensions))

        return CacheKey(
            key=":".join(components),
            model=model,
            messages_hash=input_hash,
        )

    def _hash_messages(self, messages: list[dict[str, Any]]) -> str:
        """SHA-256 of a canonical JSON rendering of *messages*.

        Only ``role`` and (normalized) ``content`` participate in the hash;
        any other message fields are deliberately ignored.

        Args:
            messages: Messages to hash

        Returns:
            Hex digest
        """
        reduced = [
            {
                "role": msg.get("role", ""),
                "content": self._normalize_content(msg.get("content", "")),
            }
            for msg in messages
        ]
        return self._hash_string(
            json.dumps(reduced, sort_keys=True, ensure_ascii=True)
        )

    def _normalize_content(self, content: Any) -> Any:
        """Reduce message content to a canonical, hash-stable form.

        Strings pass through; content-block lists keep only the fields that
        matter (text, image URL); anything else is returned unchanged.

        Args:
            content: Content to normalize

        Returns:
            Normalized content
        """
        if isinstance(content, str):
            return content
        if not isinstance(content, list):
            return content

        blocks: list[Any] = []
        for block in content:
            if not isinstance(block, dict):
                blocks.append(block)
            elif block.get("type") == "text":
                blocks.append({"type": "text", "text": block.get("text", "")})
            elif block.get("type") == "image_url":
                # Only the URL itself feeds the hash.
                blocks.append({
                    "type": "image_url",
                    "url": block.get("image_url", {}).get("url", ""),
                })
            else:
                blocks.append(block)
        return blocks

    def _hash_params(self, params: dict[str, Any]) -> str:
        """SHA-256 of sorted-JSON params; empty string when there are none.

        Args:
            params: Parameters to hash

        Returns:
            Hex digest, or "" for an empty dict
        """
        if not params:
            return ""
        return self._hash_string(
            json.dumps(params, sort_keys=True, ensure_ascii=True)
        )

    def _hash_string(self, content: str) -> str:
        """Return the SHA-256 hex digest of *content* (UTF-8).

        Args:
            content: String to hash

        Returns:
            Hex digest
        """
        return hashlib.sha256(content.encode()).hexdigest()
|