speedy-utils 1.1.16__py3-none-any.whl → 1.1.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +8 -1
- llm_utils/chat_format/display.py +109 -14
- llm_utils/lm/__init__.py +12 -11
- llm_utils/lm/async_lm/async_llm_task.py +0 -12
- llm_utils/lm/async_lm/async_lm.py +13 -4
- llm_utils/lm/async_lm/async_lm_base.py +24 -14
- llm_utils/lm/base_prompt_builder.py +288 -0
- llm_utils/lm/llm_task.py +400 -0
- llm_utils/lm/lm.py +207 -0
- llm_utils/lm/lm_base.py +285 -0
- llm_utils/vector_cache/core.py +297 -87
- speedy_utils/common/patcher.py +68 -0
- speedy_utils/common/utils_cache.py +5 -5
- speedy_utils/common/utils_io.py +232 -6
- speedy_utils/multi_worker/process.py +124 -193
- {speedy_utils-1.1.16.dist-info → speedy_utils-1.1.18.dist-info}/METADATA +3 -2
- {speedy_utils-1.1.16.dist-info → speedy_utils-1.1.18.dist-info}/RECORD +19 -14
- {speedy_utils-1.1.16.dist-info → speedy_utils-1.1.18.dist-info}/WHEEL +1 -1
- {speedy_utils-1.1.16.dist-info → speedy_utils-1.1.18.dist-info}/entry_points.txt +0 -0
llm_utils/vector_cache/core.py
CHANGED
@@ -14,6 +14,10 @@ class VectorCache:
     """
     A caching layer for text embeddings with support for multiple backends.
 
+    This cache is designed to be safe for multi-process environments where multiple
+    processes may access the same cache file simultaneously. It uses SQLite WAL mode
+    and retry logic with exponential backoff to handle concurrent access.
+
     Examples:
         # OpenAI API
         from llm_utils import VectorCache
@@ -32,11 +36,26 @@ class VectorCache:
         # Explicit backend specification
         cache = VectorCache("model-name", backend="transformers")
 
-        #
+        # Eager loading (default: False) - load model immediately for better performance
+        cache = VectorCache("model-name", lazy=False)
+
+        # Lazy loading - load model only when needed (may cause performance issues)
         cache = VectorCache("model-name", lazy=True)
 
-
-        cache
+    Multi-Process Safety:
+        The cache uses SQLite WAL (Write-Ahead Logging) mode and implements retry logic
+        with exponential backoff to handle database locks. Multiple processes can safely
+        read and write to the same cache file simultaneously.
+
+    Race Condition Protection:
+        - Uses INSERT OR IGNORE to prevent overwrites when multiple processes compute the same text
+        - The first process to successfully cache a text wins, subsequent attempts are ignored
+        - This ensures deterministic results even with non-deterministic embedding models
+
+    For best performance in multi-process scenarios, consider:
+        - Using separate cache files per process if cache hits are low
+        - Coordinating cache warm-up to avoid redundant computation
+        - Monitor for excessive lock contention in high-concurrency scenarios
     """
     def __init__(
         self,
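For orientation, a minimal usage sketch based only on the constructor options documented in the docstring above; the endpoint string and batch size below are placeholders, not values from the package:

from llm_utils import VectorCache

# Eager loading plus an explicit embedding batch size (both introduced in 1.1.18).
cache = VectorCache(
    "http://localhost:8000/v1",   # placeholder OpenAI-compatible endpoint
    lazy=False,                   # load the client/model up front instead of on first use
    embedding_batch_size=5_000,   # compute and commit missing embeddings in smaller batches
)
vectors = cache(["first text", "second text"])  # __call__ delegates to embeds() and returns an np.ndarray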
@@ -62,9 +81,11 @@ class VectorCache:
         sqlite_chunk_size: int = 999,
         sqlite_cache_size: int = 10000,
         sqlite_mmap_size: int = 268435456,
+        # Processing parameters
+        embedding_batch_size: int = 20_000,
         # Other parameters
         verbose: bool = True,
-        lazy: bool =
+        lazy: bool = False,
     ) -> None:
         self.url_or_model = url_or_model
         self.embed_size = embed_size
@@ -95,6 +116,8 @@ class VectorCache:
             "sqlite_chunk_size": sqlite_chunk_size,
             "sqlite_cache_size": sqlite_cache_size,
             "sqlite_mmap_size": sqlite_mmap_size,
+            # Processing
+            "embedding_batch_size": embedding_batch_size,
         }
 
         # Auto-detect model_name for OpenAI if using custom URL and default model
@@ -147,10 +170,14 @@ class VectorCache:
 
         # Load model/client if not lazy
         if not self.lazy:
+            if self.verbose:
+                print(f"Loading {self.backend} model/client: {self.url_or_model}")
             if self.backend == "openai":
                 self._load_openai_client()
             elif self.backend in ["vllm", "transformers"]:
                 self._load_model()
+            if self.verbose:
+                print(f"✓ {self.backend.upper()} model/client loaded successfully")
 
     def _determine_backend(self, backend: Optional[Literal["vllm", "transformers", "openai"]]) -> str:
         """Determine the appropriate backend based on url_or_model and user preference."""
@@ -181,7 +208,7 @@ class VectorCache:
             print('Infer model name:', model_name)
         return model_name
     def _optimize_connection(self) -> None:
-        """Optimize SQLite connection for bulk operations."""
+        """Optimize SQLite connection for bulk operations and multi-process safety."""
         # Performance optimizations for bulk operations
         self.conn.execute(
             "PRAGMA journal_mode=WAL"
@@ -190,6 +217,10 @@ class VectorCache:
         self.conn.execute(f"PRAGMA cache_size={self.config['sqlite_cache_size']}")  # Configurable cache
         self.conn.execute("PRAGMA temp_store=MEMORY")  # Use memory for temp storage
         self.conn.execute(f"PRAGMA mmap_size={self.config['sqlite_mmap_size']}")  # Configurable memory mapping
+
+        # Multi-process safety improvements
+        self.conn.execute("PRAGMA busy_timeout=30000")  # Wait up to 30 seconds for locks
+        self.conn.execute("PRAGMA wal_autocheckpoint=1000")  # Checkpoint WAL every 1000 pages
 
     def _ensure_schema(self) -> None:
         self.conn.execute("""
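For readers unfamiliar with the PRAGMAs added above, a standalone sketch (not package code) of the same WAL and busy-timeout setup on a plain sqlite3 connection; the file name is hypothetical:

import sqlite3

conn = sqlite3.connect("embeddings.db")            # hypothetical cache file
conn.execute("PRAGMA journal_mode=WAL")            # readers keep working while a writer appends to the WAL
conn.execute("PRAGMA busy_timeout=30000")          # wait up to 30 s on a locked database instead of failing at once
conn.execute("PRAGMA wal_autocheckpoint=1000")     # fold the WAL back into the main file roughly every 1000 pages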
@@ -216,7 +247,7 @@ class VectorCache:
     def _load_model(self) -> None:
         """Load the model for vLLM or Transformers."""
         if self.backend == "vllm":
-            from vllm import LLM
+            from vllm import LLM  # type: ignore[import-not-found]
 
             gpu_memory_utilization = cast(float, self.config["vllm_gpu_memory_utilization"])
             tensor_parallel_size = cast(int, self.config["vllm_tensor_parallel_size"])
@@ -257,8 +288,8 @@ class VectorCache:
             else:
                 raise
         elif self.backend == "transformers":
-            from transformers import AutoTokenizer, AutoModel
-            import torch
+            from transformers import AutoTokenizer, AutoModel  # type: ignore[import-not-found]
+            import torch  # type: ignore[import-not-found]
 
             device = self.config["transformers_device"]
             # Handle "auto" device selection - default to CPU for transformers to avoid memory conflicts
@@ -276,6 +307,8 @@ class VectorCache:
 
     def _get_embeddings(self, texts: list[str]) -> list[list[float]]:
         """Get embeddings using the configured backend."""
+        assert isinstance(texts, list), "texts must be a list"
+        assert all(isinstance(t, str) for t in texts), "all elements in texts must be strings"
         if self.backend == "openai":
             return self._get_openai_embeddings(texts)
         elif self.backend == "vllm":
@@ -287,12 +320,18 @@ class VectorCache:
 
     def _get_openai_embeddings(self, texts: list[str]) -> list[list[float]]:
         """Get embeddings using OpenAI API."""
+        assert isinstance(texts, list), "texts must be a list"
+        assert all(isinstance(t, str) for t in texts), "all elements in texts must be strings"
         # Assert valid model_name for OpenAI backend
         model_name = self.config["model_name"]
         assert model_name is not None and model_name.strip(), f"Invalid model_name for OpenAI backend: {model_name}. Model name must be provided and non-empty."
 
         if self._client is None:
+            if self.verbose:
+                print("🔧 Loading OpenAI client...")
             self._load_openai_client()
+            if self.verbose:
+                print("✓ OpenAI client loaded successfully")
 
         response = self._client.embeddings.create(  # type: ignore
             model=model_name,
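As context for the embeddings.create call above, a minimal sketch of the OpenAI-style request the backend issues; the base URL, key, and model name are placeholders:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")  # placeholder endpoint and key
response = client.embeddings.create(model="some-embedding-model", input=["hello", "world"])
vectors = [d.embedding for d in response.data]  # one list[float] per input text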
@@ -303,8 +342,14 @@ class VectorCache:
 
     def _get_vllm_embeddings(self, texts: list[str]) -> list[list[float]]:
         """Get embeddings using vLLM."""
+        assert isinstance(texts, list), "texts must be a list"
+        assert all(isinstance(t, str) for t in texts), "all elements in texts must be strings"
         if self._model is None:
+            if self.verbose:
+                print("🔧 Loading vLLM model...")
             self._load_model()
+            if self.verbose:
+                print("✓ vLLM model loaded successfully")
 
         outputs = self._model.embed(texts)  # type: ignore
         embeddings = [o.outputs.embedding for o in outputs]
@@ -312,8 +357,14 @@ class VectorCache:
 
     def _get_transformers_embeddings(self, texts: list[str]) -> list[list[float]]:
         """Get embeddings using transformers directly."""
+        assert isinstance(texts, list), "texts must be a list"
+        assert all(isinstance(t, str) for t in texts), "all elements in texts must be strings"
         if self._model is None:
+            if self.verbose:
+                print("🔧 Loading Transformers model...")
             self._load_model()
+            if self.verbose:
+                print("✓ Transformers model loaded successfully")
 
         if not isinstance(self._model, dict):
             raise ValueError("Model not loaded properly for transformers backend")
@@ -340,7 +391,7 @@ class VectorCache:
             batch_dict = {k: v.to(device) for k, v in batch_dict.items()}
 
             # Run model
-            import torch
+            import torch  # type: ignore[import-not-found]
             with torch.no_grad():
                 outputs = model(**batch_dict)
 
@@ -349,14 +400,14 @@ class VectorCache:
 
             # Normalize if needed
             if normalize_embeddings:
-                import torch.nn.functional as F
+                import torch.nn.functional as F  # type: ignore[import-not-found]
                 embeddings = F.normalize(embeddings, p=2, dim=1)
 
         return embeddings.cpu().numpy().tolist()
 
     def _last_token_pool(self, last_hidden_states, attention_mask):
         """Apply last token pooling to get embeddings."""
-        import torch
+        import torch  # type: ignore[import-not-found]
         left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
         if left_padding:
             return last_hidden_states[:, -1]
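The hunk above only shows the left-padding branch of _last_token_pool; for orientation, a commonly used standalone version of last-token pooling looks like the sketch below (an assumption about the remaining branch, not the package's verbatim code):

import torch

def last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Left padding: every sequence ends at the final position, so take it directly.
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    # Right padding: gather each sequence's last non-padded position instead.
    sequence_lengths = attention_mask.sum(dim=1) - 1
    batch_size = last_hidden_states.shape[0]
    return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]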
@@ -368,6 +419,40 @@ class VectorCache:
     def _hash_text(self, text: str) -> str:
         return hashlib.sha1(text.encode("utf-8")).hexdigest()
 
+    def _execute_with_retry(self, query: str, params=None) -> sqlite3.Cursor:
+        """Execute SQLite query with retry logic for multi-process safety."""
+        max_retries = 3
+        base_delay = 0.05  # 50ms base delay for reads (faster than writes)
+
+        last_exception = None
+
+        for attempt in range(max_retries + 1):
+            try:
+                if params is None:
+                    return self.conn.execute(query)
+                else:
+                    return self.conn.execute(query, params)
+
+            except sqlite3.OperationalError as e:
+                last_exception = e
+                if "database is locked" in str(e).lower() and attempt < max_retries:
+                    # Exponential backoff: 0.05s, 0.1s, 0.2s
+                    delay = base_delay * (2 ** attempt)
+                    if self.verbose:
+                        print(f"⚠️ Database locked on read, retrying in {delay:.2f}s (attempt {attempt + 1}/{max_retries + 1})")
+                    import time
+                    time.sleep(delay)
+                    continue
+                else:
+                    # Re-raise if not a lock error or max retries exceeded
+                    raise
+            except Exception as e:
+                # Re-raise any other exceptions
+                raise
+
+        # This should never be reached, but satisfy the type checker
+        raise last_exception or RuntimeError("Failed to execute query after retries")
+
     def embeds(self, texts: list[str], cache: bool = True) -> np.ndarray:
         """
         Return embeddings for all texts.
@@ -379,6 +464,8 @@ class VectorCache:
         handle very large input lists. A tqdm progress bar is shown while
         computing missing embeddings.
         """
+        assert isinstance(texts, list), "texts must be a list"
+        assert all(isinstance(t, str) for t in texts), "all elements in texts must be strings"
         if not texts:
             return np.empty((0, 0), dtype=np.float32)
         t = time()
@@ -392,11 +479,11 @@ class VectorCache:
         hit_map: dict[str, np.ndarray] = {}
         chunk_size = self.config["sqlite_chunk_size"]
 
-        # Use bulk lookup with optimized query
+        # Use bulk lookup with optimized query and retry logic
         hash_chunks = _chunks(hashes, chunk_size)
         for chunk in hash_chunks:
             placeholders = ",".join("?" * len(chunk))
-            rows = self.
+            rows = self._execute_with_retry(
                 f"SELECT hash, embedding FROM cache WHERE hash IN ({placeholders})",
                 chunk,
             ).fetchall()
@@ -415,18 +502,8 @@ class VectorCache:
 
         if missing_items:
             if self.verbose:
-                print(f"Computing
-
-            embeds = self._get_embeddings(missing_texts)
-
-            # Prepare batch data for bulk insert
-            bulk_insert_data: list[tuple[str, str, bytes]] = []
-            for (text, h), vec in zip(missing_items, embeds):
-                arr = np.asarray(vec, dtype=np.float32)
-                bulk_insert_data.append((h, text, arr.tobytes()))
-                hit_map[h] = arr
-
-            self._bulk_insert(bulk_insert_data)
+                print(f"Computing {len(missing_items)}/{len(texts)} missing embeddings...")
+            self._process_missing_items_with_batches(missing_items, hit_map)
 
         # Return embeddings in the original order
         elapsed = time() - t
@@ -434,83 +511,215 @@ class VectorCache:
             print(f"Retrieved {len(texts)} embeddings in {elapsed:.2f} seconds")
         return np.vstack([hit_map[h] for h in hashes])
 
+    def _process_missing_items_with_batches(self, missing_items: list[tuple[str, str]], hit_map: dict[str, np.ndarray]) -> None:
+        """
+        Process missing items in batches with progress bar and incremental DB insertion.
+        """
+        t = time()  # Track total processing time
+
+        # Try to import tqdm, fall back to simple progress if not available
+        tqdm = None  # avoid "possibly unbound" in type checker
+        use_tqdm = False
+        try:
+            from tqdm import tqdm as _tqdm  # type: ignore[import-not-found]
+            tqdm = _tqdm
+            use_tqdm = True
+        except ImportError:
+            use_tqdm = False
+            if self.verbose:
+                print("tqdm not available, using simple progress reporting")
+
+        batch_size = self.config["embedding_batch_size"]
+        total_items = len(missing_items)
+
+        if self.verbose:
+            print(f"Computing embeddings for {total_items} missing texts in batches of {batch_size}...")
+            if self.backend in ["vllm", "transformers"] and self._model is None:
+                print(f"⚠️ Model will be loaded on first batch (lazy loading enabled)")
+            elif self.backend in ["vllm", "transformers"]:
+                print(f"✓ Model already loaded, ready for efficient batch processing")
+
+        # Create progress bar
+        pbar = None
+        processed_count = 0
+        if use_tqdm and tqdm is not None:
+            pbar = tqdm(total=total_items, desc="Computing embeddings", unit="texts")
+
+        # Track total committed items
+        total_committed = 0
+
+        try:
+            # Process in batches
+            for i in range(0, total_items, batch_size):
+                batch_items = missing_items[i:i + batch_size]
+                batch_texts = [text for text, _ in batch_items]
+
+                # Get embeddings for this batch
+                batch_embeds = self._get_embeddings(batch_texts)
+
+                # Prepare batch data for immediate insert
+                batch_data: list[tuple[str, str, bytes]] = []
+                for (text, h), vec in zip(batch_items, batch_embeds):
+                    arr = np.asarray(vec, dtype=np.float32)
+                    batch_data.append((h, text, arr.tobytes()))
+                    hit_map[h] = arr
+
+                # Immediate commit after each batch
+                self._bulk_insert(batch_data)
+                total_committed += len(batch_data)
+
+                # Update progress
+                batch_size_actual = len(batch_items)
+                if use_tqdm:
+                    pbar.update(batch_size_actual)
+                else:
+                    processed_count += batch_size_actual
+                    if self.verbose:
+                        print(f"Progress: {processed_count}/{total_items} embeddings computed, {total_committed} committed")
+
+        finally:
+            # Clean up progress bar
+            if pbar is not None:
+                pbar.close()
+
+        if self.verbose:
+            total_time = time() - t
+            rate = total_items / total_time if total_time > 0 else 0
+            print(f"✅ Completed: {total_items} embeddings computed and {total_committed} items committed to database")
+            print(f"   Total time: {total_time:.2f}s | Rate: {rate:.1f} embeddings/sec")
+
     def __call__(self, texts: list[str], cache: bool = True) -> np.ndarray:
+        assert isinstance(texts, list), "texts must be a list"
+        assert all(isinstance(t, str) for t in texts), "all elements in texts must be strings"
         return self.embeds(texts, cache)
 
     def _bulk_insert(self, data: list[tuple[str, str, bytes]]) -> None:
-        """Perform bulk insert of embedding data."""
-        if not data:
-            return
-
-        self.conn.executemany(
-            "INSERT OR REPLACE INTO cache (hash, text, embedding) VALUES (?, ?, ?)",
-            data,
-        )
-        self.conn.commit()
-
-    def precompute_embeddings(self, texts: list[str]) -> None:
         """
-
-
+        Perform bulk insert of embedding data with retry logic for multi-process safety.
+
+        Uses INSERT OR IGNORE to prevent race conditions where multiple processes
+        might try to insert the same text hash. The first process to successfully
+        insert wins, subsequent attempts are ignored. This ensures deterministic
+        caching behavior in multi-process environments.
         """
-        if not
-            return
-
-        # Remove duplicates while preserving order
-        unique_texts = list(dict.fromkeys(texts))
-        if self.verbose:
-            print(f"Precomputing embeddings for {len(unique_texts)} unique texts...")
-
-        # Check which ones are already cached
-        hashes = [self._hash_text(t) for t in unique_texts]
-        existing_hashes = set()
-
-        # Bulk check for existing embeddings
-        chunk_size = self.config["sqlite_chunk_size"]
-        for i in range(0, len(hashes), chunk_size):
-            chunk = hashes[i : i + chunk_size]
-            placeholders = ",".join("?" * len(chunk))
-            rows = self.conn.execute(
-                f"SELECT hash FROM cache WHERE hash IN ({placeholders})",
-                chunk,
-            ).fetchall()
-            existing_hashes.update(h[0] for h in rows)
-
-        # Find missing texts
-        missing_items = [
-            (t, h) for t, h in zip(unique_texts, hashes) if h not in existing_hashes
-        ]
-
-        if not missing_items:
-            if self.verbose:
-                print("All texts already cached!")
+        if not data:
             return
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        max_retries = 3
+        base_delay = 0.1  # 100ms base delay
+
+        for attempt in range(max_retries + 1):
+            try:
+                cursor = self.conn.executemany(
+                    "INSERT OR IGNORE INTO cache (hash, text, embedding) VALUES (?, ?, ?)",
+                    data,
+                )
+                self.conn.commit()
+
+                # Check if some insertions were ignored due to existing entries
+                if self.verbose and cursor.rowcount < len(data):
+                    ignored_count = len(data) - cursor.rowcount
+                    if ignored_count > 0:
+                        print(f"ℹ️ {ignored_count}/{len(data)} embeddings already existed in cache (computed by another process)")
+
+                return  # Success, exit the retry loop
+
+            except sqlite3.OperationalError as e:
+                if "database is locked" in str(e).lower() and attempt < max_retries:
+                    # Exponential backoff: 0.1s, 0.2s, 0.4s
+                    delay = base_delay * (2 ** attempt)
+                    if self.verbose:
+                        print(f"⚠️ Database locked, retrying in {delay:.1f}s (attempt {attempt + 1}/{max_retries + 1})")
+                    import time
+                    time.sleep(delay)
+                    continue
+                else:
+                    # Re-raise if not a lock error or max retries exceeded
+                    raise
+            except Exception as e:
+                # Re-raise any other exceptions
+                raise
+
+    # def precompute_embeddings(self, texts: list[str]) -> None:
+    #     """
+    #     Precompute embeddings for a large list of texts efficiently.
+    #     This is optimized for bulk operations when you know all texts upfront.
+    #     """
+    #     assert isinstance(texts, list), "texts must be a list"
+    #     assert all(isinstance(t, str) for t in texts), "all elements in texts must be strings"
+    #     if not texts:
+    #         return
+
+    #     # Remove duplicates while preserving order
+    #     unique_texts = list(dict.fromkeys(texts))
+    #     if self.verbose:
+    #         print(f"Precomputing embeddings for {len(unique_texts)} unique texts...")
+
+    #     # Check which ones are already cached
+    #     hashes = [self._hash_text(t) for t in unique_texts]
+    #     existing_hashes = set()
+
+    #     # Bulk check for existing embeddings
+    #     chunk_size = self.config["sqlite_chunk_size"]
+    #     for i in range(0, len(hashes), chunk_size):
+    #         chunk = hashes[i : i + chunk_size]
+    #         placeholders = ",".join("?" * len(chunk))
+    #         rows = self._execute_with_retry(
+    #             f"SELECT hash FROM cache WHERE hash IN ({placeholders})",
+    #             chunk,
+    #         ).fetchall()
+    #         existing_hashes.update(h[0] for h in rows)
+
+    #     # Find missing texts
+    #     missing_items = [
+    #         (t, h) for t, h in zip(unique_texts, hashes) if h not in existing_hashes
+    #     ]
+
+    #     if not missing_items:
+    #         if self.verbose:
+    #             print("All texts already cached!")
+    #             return
+
+    #     if self.verbose:
+    #         print(f"Computing {len(missing_items)} missing embeddings...")
+
+    #     # Process missing items with batches
+    #     missing_texts = [t for t, _ in missing_items]
+    #     missing_items_tupled = [(t, h) for t, h in zip(missing_texts, [self._hash_text(t) for t in missing_texts])]
+    #     hit_map_temp: dict[str, np.ndarray] = {}
+    #     self._process_missing_items_with_batches(missing_items_tupled, hit_map_temp)
+    #     if self.verbose:
+    #         print(f"Successfully cached {len(missing_items)} new embeddings!")
 
     def get_cache_stats(self) -> dict[str, int]:
         """Get statistics about the cache."""
-        cursor = self.
+        cursor = self._execute_with_retry("SELECT COUNT(*) FROM cache")
         count = cursor.fetchone()[0]
         return {"total_cached": count}
 
     def clear_cache(self) -> None:
         """Clear all cached embeddings."""
-
-
+        max_retries = 3
+        base_delay = 0.1  # 100ms base delay
+
+        for attempt in range(max_retries + 1):
+            try:
+                self.conn.execute("DELETE FROM cache")
+                self.conn.commit()
+                return  # Success
+
+            except sqlite3.OperationalError as e:
+                if "database is locked" in str(e).lower() and attempt < max_retries:
+                    delay = base_delay * (2 ** attempt)
+                    if self.verbose:
+                        print(f"⚠️ Database locked during clear, retrying in {delay:.1f}s (attempt {attempt + 1}/{max_retries + 1})")
+                    import time
+                    time.sleep(delay)
+                    continue
+                else:
+                    raise
+            except Exception as e:
+                raise
 
     def get_config(self) -> Dict[str, Any]:
         """Get current configuration."""
@@ -542,7 +751,8 @@ class VectorCache:
                      "vllm_trust_remote_code", "vllm_max_model_len"],
             "transformers": ["transformers_device", "transformers_batch_size",
                              "transformers_normalize_embeddings", "transformers_trust_remote_code"],
-            "openai": ["api_key", "model_name"]
+            "openai": ["api_key", "model_name"],
+            "processing": ["embedding_batch_size"]
         }
 
         if any(param in kwargs for param in backend_params.get(self.backend, [])):
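The reworked _bulk_insert relies on SQLite's INSERT OR IGNORE to give first-writer-wins semantics; below is a standalone sketch of that behavior, assuming hash is the table's primary key (the actual schema is not shown in this diff):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE cache (hash TEXT PRIMARY KEY, text TEXT, embedding BLOB)")  # assumed schema

conn.execute("INSERT OR IGNORE INTO cache VALUES (?, ?, ?)", ("h1", "hello", b"\x01"))
conn.execute("INSERT OR IGNORE INTO cache VALUES (?, ?, ?)", ("h1", "hello", b"\x02"))  # ignored: hash exists

print(conn.execute("SELECT embedding FROM cache WHERE hash = ?", ("h1",)).fetchone()[0])  # b'\x01', first writer wins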
speedy_utils/common/patcher.py
ADDED
@@ -0,0 +1,68 @@
+# utils/patching.py
+import inspect
+import types
+import re
+from typing import Annotated, Union
+
+def patch_method(
+    cls: Annotated[type, "Class containing the method"],
+    method_name: Annotated[str, "Name of the method to patch"],
+    replacements: Annotated[
+        dict[Union[str, re.Pattern], str],
+        "Mapping of {old_substring_or_regex: new_string} replacements"
+    ],
+    tag: Annotated[str, "Optional logging tag"] = "",
+) -> bool:
+    """
+    Generic patcher for replacing substrings or regex matches in a method's source code.
+
+    Args:
+        cls: Target class
+        method_name: Method name to patch
+        replacements: {pattern: replacement}. Patterns may be plain strings or regex patterns.
+        tag: Optional string shown in logs
+
+    Returns:
+        bool: True if successfully patched, False otherwise
+    """
+
+    try:
+        method = getattr(cls, method_name)
+    except AttributeError:
+        print(f"[patcher{':' + tag if tag else ''}] No method {method_name} in {cls.__name__}")
+        return False
+
+    try:
+        src = inspect.getsource(method)
+    except (OSError, TypeError):
+        print(f"[patcher{':' + tag if tag else ''}] Could not get source for {cls.__name__}.{method_name}")
+        return False
+
+    new_src = src
+    did_patch = False
+
+    for old, new in replacements.items():
+        if isinstance(old, re.Pattern):
+            if old.search(new_src):
+                new_src = old.sub(new, new_src)
+                did_patch = True
+        elif isinstance(old, str):
+            if old in new_src:
+                new_src = new_src.replace(old, new)
+                did_patch = True
+        else:
+            raise TypeError("Replacement keys must be str or re.Pattern")
+
+    if not did_patch:
+        print(f"[patcher{':' + tag if tag else ''}] No matching patterns found in {cls.__name__}.{method_name}")
+        return False
+
+    # Recompile the patched function
+    code_obj = compile(new_src, filename=f"<patched_{method_name}>", mode="exec")
+    ns = {}
+    exec(code_obj, cls.__dict__, ns)  # type: ignore
+
+    # Attach patched method back
+    setattr(cls, method_name, types.MethodType(ns[method_name], None, cls))  # type: ignore
+    print(f"[patcher{':' + tag if tag else ''}] Patched {cls.__name__}.{method_name}")
+    return True
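A hedged sketch of how the new patch_method helper is meant to be called; the target class, method, and replacement strings are invented for illustration, and the import path is assumed from the file location:

import re
from speedy_utils.common.patcher import patch_method  # assumed import path

class Greeter:                      # hypothetical target class
    def hello(self):
        return "hello, world"

# Rewrite the method's source with one literal and one regex replacement, then recompile it.
patch_method(Greeter, "hello", {"hello": "hi", re.compile(r"world"): "there"}, tag="demo")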
speedy_utils/common/utils_cache.py
CHANGED
@@ -563,7 +563,7 @@ def _async_both_memoize(
 
 @overload
 def memoize(
-    _func: Callable[P, R],
+    _func: Callable[P, R | Awaitable[R]],
     *,
     keys: Optional[list[str]] = ...,
     key: Optional[Callable[..., Any]] = ...,
@@ -572,10 +572,10 @@ def memoize(
     size: int = ...,
     ignore_self: bool = ...,
     verbose: bool = ...,
-) -> Callable[P, R]: ...
+) -> Callable[P, R | Awaitable[R]]: ...
 @overload
 def memoize(
-    _func:
+    _func: None = ...,
     *,
     keys: Optional[list[str]] = ...,
     key: Optional[Callable[..., Any]] = ...,
@@ -584,7 +584,7 @@ def memoize(
     size: int = ...,
     ignore_self: bool = ...,
     verbose: bool = ...,
-) -> Callable[P,
+) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
 @overload
 def memoize(
     _func: None = ...,
@@ -596,7 +596,7 @@ def memoize(
     size: int = ...,
     ignore_self: bool = ...,
     verbose: bool = ...,
-) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
+) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: ...
 
 
 def memoize(