speedy-utils 1.1.17__py3-none-any.whl → 1.1.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +9 -1
- llm_utils/chat_format/display.py +109 -14
- llm_utils/lm/__init__.py +12 -11
- llm_utils/lm/async_lm/async_llm_task.py +1 -10
- llm_utils/lm/async_lm/async_lm.py +13 -4
- llm_utils/lm/async_lm/async_lm_base.py +24 -14
- llm_utils/lm/base_prompt_builder.py +288 -0
- llm_utils/lm/llm_task.py +693 -0
- llm_utils/lm/lm.py +207 -0
- llm_utils/lm/lm_base.py +285 -0
- llm_utils/lm/openai_memoize.py +2 -2
- llm_utils/vector_cache/core.py +285 -89
- speedy_utils/__init__.py +2 -1
- speedy_utils/common/patcher.py +68 -0
- speedy_utils/common/utils_cache.py +6 -6
- speedy_utils/common/utils_io.py +238 -8
- speedy_utils/multi_worker/process.py +180 -192
- speedy_utils/multi_worker/thread.py +94 -2
- {speedy_utils-1.1.17.dist-info → speedy_utils-1.1.19.dist-info}/METADATA +36 -14
- {speedy_utils-1.1.17.dist-info → speedy_utils-1.1.19.dist-info}/RECORD +24 -19
- {speedy_utils-1.1.17.dist-info → speedy_utils-1.1.19.dist-info}/WHEEL +1 -1
- speedy_utils-1.1.19.dist-info/entry_points.txt +5 -0
- speedy_utils-1.1.17.dist-info/entry_points.txt +0 -6
llm_utils/vector_cache/core.py
CHANGED
@@ -14,6 +14,10 @@ class VectorCache:
     """
     A caching layer for text embeddings with support for multiple backends.

+    This cache is designed to be safe for multi-process environments where multiple
+    processes may access the same cache file simultaneously. It uses SQLite WAL mode
+    and retry logic with exponential backoff to handle concurrent access.
+
     Examples:
         # OpenAI API
         from llm_utils import VectorCache
@@ -32,11 +36,26 @@ class VectorCache:
         # Explicit backend specification
         cache = VectorCache("model-name", backend="transformers")

-        #
+        # Eager loading (default: False) - load model immediately for better performance
+        cache = VectorCache("model-name", lazy=False)
+
+        # Lazy loading - load model only when needed (may cause performance issues)
         cache = VectorCache("model-name", lazy=True)

-
-        cache
+    Multi-Process Safety:
+        The cache uses SQLite WAL (Write-Ahead Logging) mode and implements retry logic
+        with exponential backoff to handle database locks. Multiple processes can safely
+        read and write to the same cache file simultaneously.
+
+    Race Condition Protection:
+        - Uses INSERT OR IGNORE to prevent overwrites when multiple processes compute the same text
+        - The first process to successfully cache a text wins, subsequent attempts are ignored
+        - This ensures deterministic results even with non-deterministic embedding models
+
+    For best performance in multi-process scenarios, consider:
+        - Using separate cache files per process if cache hits are low
+        - Coordinating cache warm-up to avoid redundant computation
+        - Monitor for excessive lock contention in high-concurrency scenarios
     """
     def __init__(
         self,
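For orientation, a brief usage sketch of what the expanded docstring describes: eager loading plus concurrent use of one cache file from several worker processes. It reuses the docstring's own "model-name" placeholder; the worker texts and process count are illustrative only, not taken from the package's tests.

# Hedged sketch based on the docstring above.
from multiprocessing import Process

from llm_utils import VectorCache

def worker(texts: list[str]) -> None:
    # Each process opens the same cache file; WAL mode plus the retry logic added
    # in this release are what make the concurrent reads/writes safe.
    cache = VectorCache("model-name", lazy=False)  # eager: load the model/client up front
    vectors = cache(texts)  # __call__ -> embeds(), returns an array of shape (len(texts), dim)
    print(vectors.shape, cache.get_cache_stats())

if __name__ == "__main__":
    procs = [Process(target=worker, args=(["hello world", "speedy utils"],)) for _ in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()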
@@ -62,9 +81,11 @@ class VectorCache:
         sqlite_chunk_size: int = 999,
         sqlite_cache_size: int = 10000,
         sqlite_mmap_size: int = 268435456,
+        # Processing parameters
+        embedding_batch_size: int = 20_000,
         # Other parameters
         verbose: bool = True,
-        lazy: bool =
+        lazy: bool = False,
     ) -> None:
         self.url_or_model = url_or_model
         self.embed_size = embed_size
@@ -95,6 +116,8 @@ class VectorCache:
             "sqlite_chunk_size": sqlite_chunk_size,
             "sqlite_cache_size": sqlite_cache_size,
             "sqlite_mmap_size": sqlite_mmap_size,
+            # Processing
+            "embedding_batch_size": embedding_batch_size,
         }

         # Auto-detect model_name for OpenAI if using custom URL and default model
@@ -147,10 +170,14 @@ class VectorCache:

         # Load model/client if not lazy
         if not self.lazy:
+            if self.verbose:
+                print(f"Loading {self.backend} model/client: {self.url_or_model}")
             if self.backend == "openai":
                 self._load_openai_client()
             elif self.backend in ["vllm", "transformers"]:
                 self._load_model()
+            if self.verbose:
+                print(f"✓ {self.backend.upper()} model/client loaded successfully")

     def _determine_backend(self, backend: Optional[Literal["vllm", "transformers", "openai"]]) -> str:
         """Determine the appropriate backend based on url_or_model and user preference."""
@@ -181,7 +208,7 @@ class VectorCache:
             print('Infer model name:', model_name)
         return model_name
     def _optimize_connection(self) -> None:
-        """Optimize SQLite connection for bulk operations."""
+        """Optimize SQLite connection for bulk operations and multi-process safety."""
         # Performance optimizations for bulk operations
         self.conn.execute(
             "PRAGMA journal_mode=WAL"
@@ -190,6 +217,10 @@ class VectorCache:
         self.conn.execute(f"PRAGMA cache_size={self.config['sqlite_cache_size']}")  # Configurable cache
         self.conn.execute("PRAGMA temp_store=MEMORY")  # Use memory for temp storage
         self.conn.execute(f"PRAGMA mmap_size={self.config['sqlite_mmap_size']}")  # Configurable memory mapping
+
+        # Multi-process safety improvements
+        self.conn.execute("PRAGMA busy_timeout=30000")  # Wait up to 30 seconds for locks
+        self.conn.execute("PRAGMA wal_autocheckpoint=1000")  # Checkpoint WAL every 1000 pages

     def _ensure_schema(self) -> None:
         self.conn.execute("""
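As a standalone illustration of the connection settings this hunk adds, the sketch below opens a scratch SQLite database with the same pragmas; the file name is arbitrary and only the pragmas mirror the code above.

import sqlite3

conn = sqlite3.connect("demo_cache.sqlite")       # arbitrary scratch path
conn.execute("PRAGMA journal_mode=WAL")           # readers no longer block the single writer
conn.execute("PRAGMA busy_timeout=30000")         # wait up to 30 s for a lock instead of failing at once
conn.execute("PRAGMA wal_autocheckpoint=1000")    # fold the WAL back into the main file every 1000 pages
conn.execute("PRAGMA temp_store=MEMORY")
print(conn.execute("PRAGMA journal_mode").fetchone())  # ('wal',)
conn.close()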
@@ -216,7 +247,7 @@ class VectorCache:
     def _load_model(self) -> None:
         """Load the model for vLLM or Transformers."""
         if self.backend == "vllm":
-            from vllm import LLM
+            from vllm import LLM  # type: ignore[import-not-found]

             gpu_memory_utilization = cast(float, self.config["vllm_gpu_memory_utilization"])
             tensor_parallel_size = cast(int, self.config["vllm_tensor_parallel_size"])
@@ -257,8 +288,8 @@ class VectorCache:
             else:
                 raise
         elif self.backend == "transformers":
-            from transformers import AutoTokenizer, AutoModel
-            import torch
+            from transformers import AutoTokenizer, AutoModel  # type: ignore[import-not-found]
+            import torch  # type: ignore[import-not-found]

             device = self.config["transformers_device"]
             # Handle "auto" device selection - default to CPU for transformers to avoid memory conflicts
@@ -296,7 +327,11 @@ class VectorCache:
         assert model_name is not None and model_name.strip(), f"Invalid model_name for OpenAI backend: {model_name}. Model name must be provided and non-empty."

         if self._client is None:
+            if self.verbose:
+                print("🔧 Loading OpenAI client...")
             self._load_openai_client()
+            if self.verbose:
+                print("✓ OpenAI client loaded successfully")

         response = self._client.embeddings.create(  # type: ignore
             model=model_name,
@@ -310,7 +345,11 @@ class VectorCache:
         assert isinstance(texts, list), "texts must be a list"
         assert all(isinstance(t, str) for t in texts), "all elements in texts must be strings"
         if self._model is None:
+            if self.verbose:
+                print("🔧 Loading vLLM model...")
             self._load_model()
+            if self.verbose:
+                print("✓ vLLM model loaded successfully")

         outputs = self._model.embed(texts)  # type: ignore
         embeddings = [o.outputs.embedding for o in outputs]
@@ -321,7 +360,11 @@ class VectorCache:
         assert isinstance(texts, list), "texts must be a list"
         assert all(isinstance(t, str) for t in texts), "all elements in texts must be strings"
         if self._model is None:
+            if self.verbose:
+                print("🔧 Loading Transformers model...")
             self._load_model()
+            if self.verbose:
+                print("✓ Transformers model loaded successfully")

         if not isinstance(self._model, dict):
             raise ValueError("Model not loaded properly for transformers backend")
@@ -348,7 +391,7 @@ class VectorCache:
             batch_dict = {k: v.to(device) for k, v in batch_dict.items()}

             # Run model
-            import torch
+            import torch  # type: ignore[import-not-found]
             with torch.no_grad():
                 outputs = model(**batch_dict)

@@ -357,14 +400,14 @@ class VectorCache:

             # Normalize if needed
             if normalize_embeddings:
-                import torch.nn.functional as F
+                import torch.nn.functional as F  # type: ignore[import-not-found]
                 embeddings = F.normalize(embeddings, p=2, dim=1)

         return embeddings.cpu().numpy().tolist()

     def _last_token_pool(self, last_hidden_states, attention_mask):
         """Apply last token pooling to get embeddings."""
-        import torch
+        import torch  # type: ignore[import-not-found]
         left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
         if left_padding:
             return last_hidden_states[:, -1]
@@ -376,6 +419,40 @@ class VectorCache:
     def _hash_text(self, text: str) -> str:
         return hashlib.sha1(text.encode("utf-8")).hexdigest()

+    def _execute_with_retry(self, query: str, params=None) -> sqlite3.Cursor:
+        """Execute SQLite query with retry logic for multi-process safety."""
+        max_retries = 3
+        base_delay = 0.05  # 50ms base delay for reads (faster than writes)
+
+        last_exception = None
+
+        for attempt in range(max_retries + 1):
+            try:
+                if params is None:
+                    return self.conn.execute(query)
+                else:
+                    return self.conn.execute(query, params)
+
+            except sqlite3.OperationalError as e:
+                last_exception = e
+                if "database is locked" in str(e).lower() and attempt < max_retries:
+                    # Exponential backoff: 0.05s, 0.1s, 0.2s
+                    delay = base_delay * (2 ** attempt)
+                    if self.verbose:
+                        print(f"⚠️ Database locked on read, retrying in {delay:.2f}s (attempt {attempt + 1}/{max_retries + 1})")
+                    import time
+                    time.sleep(delay)
+                    continue
+                else:
+                    # Re-raise if not a lock error or max retries exceeded
+                    raise
+            except Exception as e:
+                # Re-raise any other exceptions
+                raise
+
+        # This should never be reached, but satisfy the type checker
+        raise last_exception or RuntimeError("Failed to execute query after retries")
+
     def embeds(self, texts: list[str], cache: bool = True) -> np.ndarray:
         """
         Return embeddings for all texts.
@@ -402,11 +479,11 @@ class VectorCache:
         hit_map: dict[str, np.ndarray] = {}
         chunk_size = self.config["sqlite_chunk_size"]

-        # Use bulk lookup with optimized query
+        # Use bulk lookup with optimized query and retry logic
         hash_chunks = _chunks(hashes, chunk_size)
         for chunk in hash_chunks:
             placeholders = ",".join("?" * len(chunk))
-            rows = self.conn.execute(
+            rows = self._execute_with_retry(
                 f"SELECT hash, embedding FROM cache WHERE hash IN ({placeholders})",
                 chunk,
             ).fetchall()
@@ -425,18 +502,8 @@ class VectorCache:

         if missing_items:
             if self.verbose:
-                print(f"Computing
-
-            embeds = self._get_embeddings(missing_texts)
-
-            # Prepare batch data for bulk insert
-            bulk_insert_data: list[tuple[str, str, bytes]] = []
-            for (text, h), vec in zip(missing_items, embeds):
-                arr = np.asarray(vec, dtype=np.float32)
-                bulk_insert_data.append((h, text, arr.tobytes()))
-                hit_map[h] = arr
-
-            self._bulk_insert(bulk_insert_data)
+                print(f"Computing {len(missing_items)}/{len(texts)} missing embeddings...")
+            self._process_missing_items_with_batches(missing_items, hit_map)

         # Return embeddings in the original order
         elapsed = time() - t
@@ -444,87 +511,215 @@ class VectorCache:
             print(f"Retrieved {len(texts)} embeddings in {elapsed:.2f} seconds")
         return np.vstack([hit_map[h] for h in hashes])

+    def _process_missing_items_with_batches(self, missing_items: list[tuple[str, str]], hit_map: dict[str, np.ndarray]) -> None:
+        """
+        Process missing items in batches with progress bar and incremental DB insertion.
+        """
+        t = time()  # Track total processing time
+
+        # Try to import tqdm, fall back to simple progress if not available
+        tqdm = None  # avoid "possibly unbound" in type checker
+        use_tqdm = False
+        try:
+            from tqdm import tqdm as _tqdm  # type: ignore[import-not-found]
+            tqdm = _tqdm
+            use_tqdm = True
+        except ImportError:
+            use_tqdm = False
+            if self.verbose:
+                print("tqdm not available, using simple progress reporting")
+
+        batch_size = self.config["embedding_batch_size"]
+        total_items = len(missing_items)
+
+        if self.verbose:
+            print(f"Computing embeddings for {total_items} missing texts in batches of {batch_size}...")
+            if self.backend in ["vllm", "transformers"] and self._model is None:
+                print("⚠️ Model will be loaded on first batch (lazy loading enabled)")
+            elif self.backend in ["vllm", "transformers"]:
+                print("✓ Model already loaded, ready for efficient batch processing")
+
+        # Create progress bar
+        pbar = None
+        processed_count = 0
+        if use_tqdm and tqdm is not None:
+            pbar = tqdm(total=total_items, desc="Computing embeddings", unit="texts")
+
+        # Track total committed items
+        total_committed = 0
+
+        try:
+            # Process in batches
+            for i in range(0, total_items, batch_size):
+                batch_items = missing_items[i:i + batch_size]
+                batch_texts = [text for text, _ in batch_items]
+
+                # Get embeddings for this batch
+                batch_embeds = self._get_embeddings(batch_texts)
+
+                # Prepare batch data for immediate insert
+                batch_data: list[tuple[str, str, bytes]] = []
+                for (text, h), vec in zip(batch_items, batch_embeds):
+                    arr = np.asarray(vec, dtype=np.float32)
+                    batch_data.append((h, text, arr.tobytes()))
+                    hit_map[h] = arr
+
+                # Immediate commit after each batch
+                self._bulk_insert(batch_data)
+                total_committed += len(batch_data)
+
+                # Update progress
+                batch_size_actual = len(batch_items)
+                if use_tqdm:
+                    pbar.update(batch_size_actual)  # type: ignore
+                else:
+                    processed_count += batch_size_actual
+                    if self.verbose:
+                        print(f"Progress: {processed_count}/{total_items} embeddings computed, {total_committed} committed")

+        finally:
+            # Clean up progress bar
+            if pbar is not None:
+                pbar.close()
+
+        if self.verbose:
+            total_time = time() - t
+            rate = total_items / total_time if total_time > 0 else 0
+            print(f"✅ Completed: {total_items} embeddings computed and {total_committed} items committed to database")
+            print(f"   Total time: {total_time:.2f}s | Rate: {rate:.1f} embeddings/sec")
+
     def __call__(self, texts: list[str], cache: bool = True) -> np.ndarray:
         assert isinstance(texts, list), "texts must be a list"
         assert all(isinstance(t, str) for t in texts), "all elements in texts must be strings"
         return self.embeds(texts, cache)

     def _bulk_insert(self, data: list[tuple[str, str, bytes]]) -> None:
-        """Perform bulk insert of embedding data."""
-        if not data:
-            return
-
-        self.conn.executemany(
-            "INSERT OR REPLACE INTO cache (hash, text, embedding) VALUES (?, ?, ?)",
-            data,
-        )
-        self.conn.commit()
-
-    def precompute_embeddings(self, texts: list[str]) -> None:
         """
-
-
+        Perform bulk insert of embedding data with retry logic for multi-process safety.
+
+        Uses INSERT OR IGNORE to prevent race conditions where multiple processes
+        might try to insert the same text hash. The first process to successfully
+        insert wins, subsequent attempts are ignored. This ensures deterministic
+        caching behavior in multi-process environments.
         """
-
-        assert all(isinstance(t, str) for t in texts), "all elements in texts must be strings"
-        if not texts:
-            return
-
-        # Remove duplicates while preserving order
-        unique_texts = list(dict.fromkeys(texts))
-        if self.verbose:
-            print(f"Precomputing embeddings for {len(unique_texts)} unique texts...")
-
-        # Check which ones are already cached
-        hashes = [self._hash_text(t) for t in unique_texts]
-        existing_hashes = set()
-
-        # Bulk check for existing embeddings
-        chunk_size = self.config["sqlite_chunk_size"]
-        for i in range(0, len(hashes), chunk_size):
-            chunk = hashes[i : i + chunk_size]
-            placeholders = ",".join("?" * len(chunk))
-            rows = self.conn.execute(
-                f"SELECT hash FROM cache WHERE hash IN ({placeholders})",
-                chunk,
-            ).fetchall()
-            existing_hashes.update(h[0] for h in rows)
-
-        # Find missing texts
-        missing_items = [
-            (t, h) for t, h in zip(unique_texts, hashes) if h not in existing_hashes
-        ]
-
-        if not missing_items:
-            if self.verbose:
-                print("All texts already cached!")
+        if not data:
             return

-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        max_retries = 3
+        base_delay = 0.1  # 100ms base delay
+
+        for attempt in range(max_retries + 1):
+            try:
+                cursor = self.conn.executemany(
+                    "INSERT OR IGNORE INTO cache (hash, text, embedding) VALUES (?, ?, ?)",
+                    data,
+                )
+                self.conn.commit()
+
+                # Check if some insertions were ignored due to existing entries
+                if self.verbose and cursor.rowcount < len(data):
+                    ignored_count = len(data) - cursor.rowcount
+                    if ignored_count > 0:
+                        print(f"ℹ️ {ignored_count}/{len(data)} embeddings already existed in cache (computed by another process)")
+
+                return  # Success, exit the retry loop
+
+            except sqlite3.OperationalError as e:
+                if "database is locked" in str(e).lower() and attempt < max_retries:
+                    # Exponential backoff: 0.1s, 0.2s, 0.4s
+                    delay = base_delay * (2 ** attempt)
+                    if self.verbose:
+                        print(f"⚠️ Database locked, retrying in {delay:.1f}s (attempt {attempt + 1}/{max_retries + 1})")
+                    import time
+                    time.sleep(delay)
+                    continue
+                else:
+                    # Re-raise if not a lock error or max retries exceeded
+                    raise
+            except Exception as e:
+                # Re-raise any other exceptions
+                raise
+
+    # def precompute_embeddings(self, texts: list[str]) -> None:
+    #     """
+    #     Precompute embeddings for a large list of texts efficiently.
+    #     This is optimized for bulk operations when you know all texts upfront.
+    #     """
+    #     assert isinstance(texts, list), "texts must be a list"
+    #     assert all(isinstance(t, str) for t in texts), "all elements in texts must be strings"
+    #     if not texts:
+    #         return
+
+    #     # Remove duplicates while preserving order
+    #     unique_texts = list(dict.fromkeys(texts))
+    #     if self.verbose:
+    #         print(f"Precomputing embeddings for {len(unique_texts)} unique texts...")
+
+    #     # Check which ones are already cached
+    #     hashes = [self._hash_text(t) for t in unique_texts]
+    #     existing_hashes = set()
+
+    #     # Bulk check for existing embeddings
+    #     chunk_size = self.config["sqlite_chunk_size"]
+    #     for i in range(0, len(hashes), chunk_size):
+    #         chunk = hashes[i : i + chunk_size]
+    #         placeholders = ",".join("?" * len(chunk))
+    #         rows = self._execute_with_retry(
+    #             f"SELECT hash FROM cache WHERE hash IN ({placeholders})",
+    #             chunk,
+    #         ).fetchall()
+    #         existing_hashes.update(h[0] for h in rows)
+
+    #     # Find missing texts
+    #     missing_items = [
+    #         (t, h) for t, h in zip(unique_texts, hashes) if h not in existing_hashes
+    #     ]
+
+    #     if not missing_items:
+    #         if self.verbose:
+    #             print("All texts already cached!")
+    #         return
+
+    #     if self.verbose:
+    #         print(f"Computing {len(missing_items)} missing embeddings...")
+
+    #     # Process missing items with batches
+    #     missing_texts = [t for t, _ in missing_items]
+    #     missing_items_tupled = [(t, h) for t, h in zip(missing_texts, [self._hash_text(t) for t in missing_texts])]
+    #     hit_map_temp: dict[str, np.ndarray] = {}
+    #     self._process_missing_items_with_batches(missing_items_tupled, hit_map_temp)
+    #     if self.verbose:
+    #         print(f"Successfully cached {len(missing_items)} new embeddings!")

     def get_cache_stats(self) -> dict[str, int]:
         """Get statistics about the cache."""
-        cursor = self.conn.execute("SELECT COUNT(*) FROM cache")
+        cursor = self._execute_with_retry("SELECT COUNT(*) FROM cache")
         count = cursor.fetchone()[0]
         return {"total_cached": count}

     def clear_cache(self) -> None:
         """Clear all cached embeddings."""
-
-
+        max_retries = 3
+        base_delay = 0.1  # 100ms base delay
+
+        for attempt in range(max_retries + 1):
+            try:
+                self.conn.execute("DELETE FROM cache")
+                self.conn.commit()
+                return  # Success
+
+            except sqlite3.OperationalError as e:
+                if "database is locked" in str(e).lower() and attempt < max_retries:
+                    delay = base_delay * (2 ** attempt)
+                    if self.verbose:
+                        print(f"⚠️ Database locked during clear, retrying in {delay:.1f}s (attempt {attempt + 1}/{max_retries + 1})")
+                    import time
+                    time.sleep(delay)
+                    continue
+                else:
+                    raise
+            except Exception as e:
+                raise

     def get_config(self) -> Dict[str, Any]:
         """Get current configuration."""
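To make the race-condition guarantee in _bulk_insert's new docstring concrete, here is a small self-contained sketch of SQLite's INSERT OR IGNORE semantics; the three-column table is a simplified stand-in for the real cache table, not its actual schema.

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE cache (hash TEXT PRIMARY KEY, text TEXT, embedding BLOB)")

# Two "processes" compute the same text; the second insert is silently ignored,
# so the first stored embedding stays authoritative.
conn.execute(
    "INSERT OR IGNORE INTO cache (hash, text, embedding) VALUES (?, ?, ?)",
    ("abc123", "hello", b"\x00\x01"),
)
cur = conn.execute(
    "INSERT OR IGNORE INTO cache (hash, text, embedding) VALUES (?, ?, ?)",
    ("abc123", "hello", b"\xff\xff"),  # same hash later, different bytes
)
print(cur.rowcount)  # 0 -> ignored, mirroring the rowcount check in _bulk_insert
print(conn.execute("SELECT embedding FROM cache WHERE hash = ?", ("abc123",)).fetchone())  # (b'\x00\x01',)
conn.close()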
@@ -556,7 +751,8 @@ class VectorCache:
                                  "vllm_trust_remote_code", "vllm_max_model_len"],
             "transformers": ["transformers_device", "transformers_batch_size",
                              "transformers_normalize_embeddings", "transformers_trust_remote_code"],
-            "openai": ["api_key", "model_name"]
+            "openai": ["api_key", "model_name"],
+            "processing": ["embedding_batch_size"]
         }

         if any(param in kwargs for param in backend_params.get(self.backend, [])):
speedy_utils/__init__.py
CHANGED
@@ -138,7 +138,7 @@ from .common.utils_print import (

 # Multi-worker processing
 from .multi_worker.process import multi_process
-from .multi_worker.thread import multi_thread
+from .multi_worker.thread import kill_all_thread, multi_thread

 # Define __all__ explicitly
 __all__ = [
@@ -224,6 +224,7 @@ __all__ = [
     # Multi-worker processing
     "multi_process",
     "multi_thread",
+    "kill_all_thread",
     # Notebook utilities
     "change_dir",
 ]
speedy_utils/common/patcher.py
ADDED
@@ -0,0 +1,68 @@
+# utils/patching.py
+import inspect
+import types
+import re
+from typing import Annotated, Union
+
+def patch_method(
+    cls: Annotated[type, "Class containing the method"],
+    method_name: Annotated[str, "Name of the method to patch"],
+    replacements: Annotated[
+        dict[Union[str, re.Pattern], str],
+        "Mapping of {old_substring_or_regex: new_string} replacements"
+    ],
+    tag: Annotated[str, "Optional logging tag"] = "",
+) -> bool:
+    """
+    Generic patcher for replacing substrings or regex matches in a method's source code.
+
+    Args:
+        cls: Target class
+        method_name: Method name to patch
+        replacements: {pattern: replacement}. Patterns may be plain strings or regex patterns.
+        tag: Optional string shown in logs
+
+    Returns:
+        bool: True if successfully patched, False otherwise
+    """
+
+    try:
+        method = getattr(cls, method_name)
+    except AttributeError:
+        print(f"[patcher{':' + tag if tag else ''}] No method {method_name} in {cls.__name__}")
+        return False
+
+    try:
+        src = inspect.getsource(method)
+    except (OSError, TypeError):
+        print(f"[patcher{':' + tag if tag else ''}] Could not get source for {cls.__name__}.{method_name}")
+        return False
+
+    new_src = src
+    did_patch = False
+
+    for old, new in replacements.items():
+        if isinstance(old, re.Pattern):
+            if old.search(new_src):
+                new_src = old.sub(new, new_src)
+                did_patch = True
+        elif isinstance(old, str):
+            if old in new_src:
+                new_src = new_src.replace(old, new)
+                did_patch = True
+        else:
+            raise TypeError("Replacement keys must be str or re.Pattern")
+
+    if not did_patch:
+        print(f"[patcher{':' + tag if tag else ''}] No matching patterns found in {cls.__name__}.{method_name}")
+        return False
+
+    # Recompile the patched function
+    code_obj = compile(new_src, filename=f"<patched_{method_name}>", mode="exec")
+    ns = {}
+    exec(code_obj, cls.__dict__, ns)  # type: ignore
+
+    # Attach patched method back
+    setattr(cls, method_name, types.MethodType(ns[method_name], None, cls))  # type: ignore
+    print(f"[patcher{':' + tag if tag else ''}] Patched {cls.__name__}.{method_name}")
+    return True
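To make the idea behind patch_method concrete, here is an independent minimal sketch of the same source-rewrite technique (fetch the source, substitute a string, recompile, rebind). It does not call patch_method itself, and the Greeter class is invented for the example; run it as a script so inspect can locate the source.

import inspect
import textwrap

class Greeter:
    def hello(self) -> str:
        return "hello"

# Grab the method's source, apply a plain string replacement, and recompile it.
src = textwrap.dedent(inspect.getsource(Greeter.hello))
patched_src = src.replace('"hello"', '"hello, patched"')

ns: dict = {}
exec(compile(patched_src, "<patched_hello>", "exec"), {}, ns)
Greeter.hello = ns["hello"]  # rebind the freshly compiled function as the method

assert Greeter().hello() == "hello, patched"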
speedy_utils/common/utils_cache.py
CHANGED
@@ -563,7 +563,7 @@ def _async_both_memoize(

 @overload
 def memoize(
-    _func: Callable[P, R],
+    _func: Callable[P, R | Awaitable[R]],
     *,
     keys: Optional[list[str]] = ...,
     key: Optional[Callable[..., Any]] = ...,
@@ -572,10 +572,10 @@ def memoize(
     size: int = ...,
     ignore_self: bool = ...,
     verbose: bool = ...,
-) -> Callable[P, R]: ...
+) -> Callable[P, R | Awaitable[R]]: ...
 @overload
 def memoize(
-    _func:
+    _func: None = ...,
     *,
     keys: Optional[list[str]] = ...,
     key: Optional[Callable[..., Any]] = ...,
@@ -584,9 +584,9 @@ def memoize(
     size: int = ...,
     ignore_self: bool = ...,
     verbose: bool = ...,
-) -> Callable[P,
+) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
 @overload
-def memoize(
+def memoize(  # type: ignore
     _func: None = ...,
     *,
     keys: Optional[list[str]] = ...,
@@ -596,7 +596,7 @@ def memoize(
     size: int = ...,
     ignore_self: bool = ...,
     verbose: bool = ...,
-) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
+) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: ...


 def memoize(
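The reworked overloads above describe a decorator that can wrap either a plain callable or a coroutine function and preserve the matching return type. A rough sketch of that dual sync/async dispatch pattern is shown below; it is not the library's memoize implementation, only an illustration of the shape the stubs imply.

import asyncio
import functools
import inspect

def memoize_sketch(func):
    """Cache results for both plain functions and coroutine functions."""
    cache = {}

    def make_key(args, kwargs):
        return (args, tuple(sorted(kwargs.items())))

    if inspect.iscoroutinefunction(func):
        @functools.wraps(func)
        async def async_wrapper(*args, **kwargs):
            key = make_key(args, kwargs)
            if key not in cache:
                cache[key] = await func(*args, **kwargs)
            return cache[key]
        return async_wrapper

    @functools.wraps(func)
    def sync_wrapper(*args, **kwargs):
        key = make_key(args, kwargs)
        if key not in cache:
            cache[key] = func(*args, **kwargs)
        return cache[key]
    return sync_wrapper

@memoize_sketch
def square(x):
    return x * x

@memoize_sketch
async def asquare(x):
    return x * x

print(square(3), asyncio.run(asquare(3)))  # 9 9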