speedy-utils 1.1.18__py3-none-any.whl → 1.1.20__py3-none-any.whl

This diff shows the content of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their public registries.
@@ -13,50 +13,51 @@ import numpy as np
  class VectorCache:
  """
  A caching layer for text embeddings with support for multiple backends.
-
+
  This cache is designed to be safe for multi-process environments where multiple
  processes may access the same cache file simultaneously. It uses SQLite WAL mode
  and retry logic with exponential backoff to handle concurrent access.
-
+
  Examples:
  # OpenAI API
  from llm_utils import VectorCache
  cache = VectorCache("https://api.openai.com/v1", api_key="your-key")
  embeddings = cache.embeds(["Hello world", "How are you?"])
-
+
  # Custom OpenAI-compatible server (auto-detects model)
  cache = VectorCache("http://localhost:8000/v1", api_key="abc")
-
+
  # Transformers (Sentence Transformers)
  cache = VectorCache("sentence-transformers/all-MiniLM-L6-v2")
-
+
  # vLLM (local model)
  cache = VectorCache("/path/to/model")
-
+
  # Explicit backend specification
  cache = VectorCache("model-name", backend="transformers")
-
+
  # Eager loading (default: False) - load model immediately for better performance
  cache = VectorCache("model-name", lazy=False)
-
+
  # Lazy loading - load model only when needed (may cause performance issues)
  cache = VectorCache("model-name", lazy=True)
-
+
  Multi-Process Safety:
  The cache uses SQLite WAL (Write-Ahead Logging) mode and implements retry logic
  with exponential backoff to handle database locks. Multiple processes can safely
  read and write to the same cache file simultaneously.
-
+
  Race Condition Protection:
  - Uses INSERT OR IGNORE to prevent overwrites when multiple processes compute the same text
  - The first process to successfully cache a text wins, subsequent attempts are ignored
  - This ensures deterministic results even with non-deterministic embedding models
-
+
  For best performance in multi-process scenarios, consider:
  - Using separate cache files per process if cache hits are low
  - Coordinating cache warm-up to avoid redundant computation
  - Monitor for excessive lock contention in high-concurrency scenarios
  """
+
  def __init__(
  self,
  url_or_model: str,
@@ -80,7 +81,7 @@ class VectorCache:
  # SQLite parameters
  sqlite_chunk_size: int = 999,
  sqlite_cache_size: int = 10000,
- sqlite_mmap_size: int = 268435456,
+ sqlite_mmap_size: int = 268435456, # 256MB
  # Processing parameters
  embedding_batch_size: int = 20_000,
  # Other parameters
@@ -91,11 +92,11 @@ class VectorCache:
  self.embed_size = embed_size
  self.verbose = verbose
  self.lazy = lazy
-
+
  self.backend = self._determine_backend(backend)
  if self.verbose and backend is None:
  print(f"Auto-detected backend: {self.backend}")
-
+
  # Store all configuration parameters
  self.config = {
  # OpenAI
@@ -119,18 +120,20 @@ class VectorCache:
  # Processing
  "embedding_batch_size": embedding_batch_size,
  }
-
+
  # Auto-detect model_name for OpenAI if using custom URL and default model
- if (self.backend == "openai" and
- model_name == "text-embedding-3-small" and
- self.url_or_model != "https://api.openai.com/v1"):
+ if (
+ self.backend == "openai"
+ and model_name == "text-embedding-3-small"
+ and self.url_or_model != "https://api.openai.com/v1"
+ ):
  if self.verbose:
  print(f"Attempting to auto-detect model from {self.url_or_model}...")
  try:
  import openai
+
  client = openai.OpenAI(
- base_url=self.url_or_model,
- api_key=self.config["api_key"]
+ base_url=self.url_or_model, api_key=self.config["api_key"]
  )
  models = client.models.list()
  if models.data:
@@ -147,7 +150,7 @@ class VectorCache:
  print(f"Model auto-detection failed: {e}, using default model")
  # Fallback to default if auto-detection fails
  pass
-
+
  # Set default db_path if not provided
  if db_path is None:
  if self.backend == "openai":
@@ -155,19 +158,21 @@ class VectorCache:
  else:
  model_id = self.url_or_model
  safe_name = hashlib.sha1(model_id.encode("utf-8")).hexdigest()[:16]
- self.db_path = Path.home() / ".cache" / "embed" / f"{self.backend}_{safe_name}.sqlite"
+ self.db_path = (
+ Path.home() / ".cache" / "embed" / f"{self.backend}_{safe_name}.sqlite"
+ )
  else:
  self.db_path = Path(db_path)
-
+
  # Ensure the directory exists
  self.db_path.parent.mkdir(parents=True, exist_ok=True)
-
+
  self.conn = sqlite3.connect(self.db_path)
  self._optimize_connection()
  self._ensure_schema()
  self._model = None # Lazy loading
  self._client = None # For OpenAI client
-
+
  # Load model/client if not lazy
  if not self.lazy:
  if self.verbose:
@@ -179,34 +184,41 @@ class VectorCache:
  if self.verbose:
  print(f"✓ {self.backend.upper()} model/client loaded successfully")

- def _determine_backend(self, backend: Optional[Literal["vllm", "transformers", "openai"]]) -> str:
+ def _determine_backend(
+ self, backend: Optional[Literal["vllm", "transformers", "openai"]]
+ ) -> str:
  """Determine the appropriate backend based on url_or_model and user preference."""
  if backend is not None:
  valid_backends = ["vllm", "transformers", "openai"]
  if backend not in valid_backends:
- raise ValueError(f"Invalid backend '{backend}'. Must be one of: {valid_backends}")
+ raise ValueError(
+ f"Invalid backend '{backend}'. Must be one of: {valid_backends}"
+ )
  return backend
-
+
  if self.url_or_model.startswith("http"):
  return "openai"
-
+
  # Default to vllm for local models
  return "vllm"
+
  def _try_infer_model_name(self, model_name: Optional[str]) -> Optional[str]:
  """Infer model name for OpenAI backend if not explicitly provided."""
  if model_name:
  return model_name
- if 'https://' in self.url_or_model:
- model_name = "text-embedding-3-small"
-
- if 'http://localhost' in self.url_or_model:
+ if "https://" in self.url_or_model:
+ model_name = "text-embedding-3-small"
+
+ if "http://localhost" in self.url_or_model:
  from openai import OpenAI
- client = OpenAI(base_url=self.url_or_model, api_key='abc')
- model_name = client.models.list().data[0].id
+
+ client = OpenAI(base_url=self.url_or_model, api_key="abc")
+ model_name = client.models.list().data[0].id

  # Default model name
- print('Infer model name:', model_name)
+ print("Infer model name:", model_name)
  return model_name
+
  def _optimize_connection(self) -> None:
  """Optimize SQLite connection for bulk operations and multi-process safety."""
  # Performance optimizations for bulk operations
@@ -214,13 +226,21 @@ class VectorCache:
  "PRAGMA journal_mode=WAL"
  ) # Write-Ahead Logging for better concurrency
  self.conn.execute("PRAGMA synchronous=NORMAL") # Faster writes, still safe
- self.conn.execute(f"PRAGMA cache_size={self.config['sqlite_cache_size']}") # Configurable cache
+ self.conn.execute(
+ f"PRAGMA cache_size={self.config['sqlite_cache_size']}"
+ ) # Configurable cache
  self.conn.execute("PRAGMA temp_store=MEMORY") # Use memory for temp storage
- self.conn.execute(f"PRAGMA mmap_size={self.config['sqlite_mmap_size']}") # Configurable memory mapping
-
+ self.conn.execute(
+ f"PRAGMA mmap_size={self.config['sqlite_mmap_size']}"
+ ) # Configurable memory mapping
+
  # Multi-process safety improvements
- self.conn.execute("PRAGMA busy_timeout=30000") # Wait up to 30 seconds for locks
- self.conn.execute("PRAGMA wal_autocheckpoint=1000") # Checkpoint WAL every 1000 pages
+ self.conn.execute(
+ "PRAGMA busy_timeout=30000"
+ ) # Wait up to 30 seconds for locks
+ self.conn.execute(
+ "PRAGMA wal_autocheckpoint=1000"
+ ) # Checkpoint WAL every 1000 pages

  def _ensure_schema(self) -> None:
  self.conn.execute("""
@@ -239,22 +259,24 @@ class VectorCache:
  def _load_openai_client(self) -> None:
  """Load OpenAI client."""
  import openai
+
  self._client = openai.OpenAI(
- base_url=self.url_or_model,
- api_key=self.config["api_key"]
+ base_url=self.url_or_model, api_key=self.config["api_key"]
  )

  def _load_model(self) -> None:
  """Load the model for vLLM or Transformers."""
  if self.backend == "vllm":
  from vllm import LLM # type: ignore[import-not-found]
-
- gpu_memory_utilization = cast(float, self.config["vllm_gpu_memory_utilization"])
+
+ gpu_memory_utilization = cast(
+ float, self.config["vllm_gpu_memory_utilization"]
+ )
  tensor_parallel_size = cast(int, self.config["vllm_tensor_parallel_size"])
  dtype = cast(str, self.config["vllm_dtype"])
  trust_remote_code = cast(bool, self.config["vllm_trust_remote_code"])
  max_model_len = cast(Optional[int], self.config["vllm_max_model_len"])
-
+
  vllm_kwargs = {
  "model": self.url_or_model,
  "task": "embed",
@@ -263,18 +285,23 @@ class VectorCache:
  "dtype": dtype,
  "trust_remote_code": trust_remote_code,
  }
-
+
  if max_model_len is not None:
  vllm_kwargs["max_model_len"] = max_model_len
-
+
  try:
  self._model = LLM(**vllm_kwargs)
  except (ValueError, AssertionError, RuntimeError) as e:
  error_msg = str(e).lower()
- if ("kv cache" in error_msg and "gpu_memory_utilization" in error_msg) or \
- ("memory" in error_msg and ("gpu" in error_msg or "insufficient" in error_msg)) or \
- ("free memory" in error_msg and "initial" in error_msg) or \
- ("engine core initialization failed" in error_msg):
+ if (
+ ("kv cache" in error_msg and "gpu_memory_utilization" in error_msg)
+ or (
+ "memory" in error_msg
+ and ("gpu" in error_msg or "insufficient" in error_msg)
+ )
+ or ("free memory" in error_msg and "initial" in error_msg)
+ or ("engine core initialization failed" in error_msg)
+ ):
  raise ValueError(
  f"Insufficient GPU memory for vLLM model initialization. "
  f"Current vllm_gpu_memory_utilization ({gpu_memory_utilization}) may be too low. "
@@ -288,27 +315,39 @@ class VectorCache:
  else:
  raise
  elif self.backend == "transformers":
- from transformers import AutoTokenizer, AutoModel # type: ignore[import-not-found]
- import torch # type: ignore[import-not-found]
-
+ import torch # type: ignore[import-not-found] # noqa: F401
+ from transformers import ( # type: ignore[import-not-found]
+ AutoModel,
+ AutoTokenizer,
+ )
+
  device = self.config["transformers_device"]
  # Handle "auto" device selection - default to CPU for transformers to avoid memory conflicts
  if device == "auto":
  device = "cpu" # Default to CPU to avoid GPU memory conflicts with vLLM
-
- tokenizer = AutoTokenizer.from_pretrained(self.url_or_model, padding_side='left', trust_remote_code=self.config["transformers_trust_remote_code"])
- model = AutoModel.from_pretrained(self.url_or_model, trust_remote_code=self.config["transformers_trust_remote_code"])
-
+
+ tokenizer = AutoTokenizer.from_pretrained(
+ self.url_or_model,
+ padding_side="left",
+ trust_remote_code=self.config["transformers_trust_remote_code"],
+ )
+ model = AutoModel.from_pretrained(
+ self.url_or_model,
+ trust_remote_code=self.config["transformers_trust_remote_code"],
+ )
+
  # Move model to device
  model.to(device)
  model.eval()
-
+
  self._model = {"tokenizer": tokenizer, "model": model, "device": device}

  def _get_embeddings(self, texts: list[str]) -> list[list[float]]:
  """Get embeddings using the configured backend."""
  assert isinstance(texts, list), "texts must be a list"
- assert all(isinstance(t, str) for t in texts), "all elements in texts must be strings"
+ assert all(isinstance(t, str) for t in texts), (
+ "all elements in texts must be strings"
+ )
  if self.backend == "openai":
  return self._get_openai_embeddings(texts)
  elif self.backend == "vllm":
@@ -321,10 +360,14 @@ class VectorCache:
  def _get_openai_embeddings(self, texts: list[str]) -> list[list[float]]:
  """Get embeddings using OpenAI API."""
  assert isinstance(texts, list), "texts must be a list"
- assert all(isinstance(t, str) for t in texts), "all elements in texts must be strings"
+ assert all(isinstance(t, str) for t in texts), (
+ "all elements in texts must be strings"
+ )
  # Assert valid model_name for OpenAI backend
  model_name = self.config["model_name"]
- assert model_name is not None and model_name.strip(), f"Invalid model_name for OpenAI backend: {model_name}. Model name must be provided and non-empty."
+ assert model_name is not None and model_name.strip(), (
+ f"Invalid model_name for OpenAI backend: {model_name}. Model name must be provided and non-empty."
+ )

  if self._client is None:
  if self.verbose:
@@ -332,10 +375,9 @@ class VectorCache:
  self._load_openai_client()
  if self.verbose:
  print("✓ OpenAI client loaded successfully")
-
+
  response = self._client.embeddings.create( # type: ignore
- model=model_name,
- input=texts
+ model=model_name, input=texts
  )
  embeddings = [item.embedding for item in response.data]
  return embeddings
@@ -343,14 +385,16 @@ class VectorCache:
  def _get_vllm_embeddings(self, texts: list[str]) -> list[list[float]]:
  """Get embeddings using vLLM."""
  assert isinstance(texts, list), "texts must be a list"
- assert all(isinstance(t, str) for t in texts), "all elements in texts must be strings"
+ assert all(isinstance(t, str) for t in texts), (
+ "all elements in texts must be strings"
+ )
  if self._model is None:
  if self.verbose:
  print("🔧 Loading vLLM model...")
  self._load_model()
  if self.verbose:
  print("✓ vLLM model loaded successfully")
-
+
  outputs = self._model.embed(texts) # type: ignore
  embeddings = [o.outputs.embedding for o in outputs]
  return embeddings
@@ -358,26 +402,30 @@ class VectorCache:
  def _get_transformers_embeddings(self, texts: list[str]) -> list[list[float]]:
  """Get embeddings using transformers directly."""
  assert isinstance(texts, list), "texts must be a list"
- assert all(isinstance(t, str) for t in texts), "all elements in texts must be strings"
+ assert all(isinstance(t, str) for t in texts), (
+ "all elements in texts must be strings"
+ )
  if self._model is None:
  if self.verbose:
  print("🔧 Loading Transformers model...")
  self._load_model()
  if self.verbose:
  print("✓ Transformers model loaded successfully")
-
+
  if not isinstance(self._model, dict):
  raise ValueError("Model not loaded properly for transformers backend")
-
+
  tokenizer = self._model["tokenizer"]
  model = self._model["model"]
  device = self._model["device"]
-
- normalize_embeddings = cast(bool, self.config["transformers_normalize_embeddings"])
-
+
+ normalize_embeddings = cast(
+ bool, self.config["transformers_normalize_embeddings"]
+ )
+
  # For now, use a default max_length
  max_length = 8192
-
+
  # Tokenize
  batch_dict = tokenizer(
  texts,
@@ -386,35 +434,43 @@ class VectorCache:
  max_length=max_length,
  return_tensors="pt",
  )
-
+
  # Move to device
  batch_dict = {k: v.to(device) for k, v in batch_dict.items()}
-
+
  # Run model
  import torch # type: ignore[import-not-found]
+
  with torch.no_grad():
  outputs = model(**batch_dict)
-
+
  # Apply last token pooling
- embeddings = self._last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
-
+ embeddings = self._last_token_pool(
+ outputs.last_hidden_state, batch_dict["attention_mask"]
+ )
+
  # Normalize if needed
  if normalize_embeddings:
  import torch.nn.functional as F # type: ignore[import-not-found]
+
  embeddings = F.normalize(embeddings, p=2, dim=1)
-
+
  return embeddings.cpu().numpy().tolist()

  def _last_token_pool(self, last_hidden_states, attention_mask):
  """Apply last token pooling to get embeddings."""
  import torch # type: ignore[import-not-found]
- left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
+
+ left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
  if left_padding:
  return last_hidden_states[:, -1]
  else:
  sequence_lengths = attention_mask.sum(dim=1) - 1
  batch_size = last_hidden_states.shape[0]
- return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
+ return last_hidden_states[
+ torch.arange(batch_size, device=last_hidden_states.device),
+ sequence_lengths,
+ ]

  def _hash_text(self, text: str) -> str:
  return hashlib.sha1(text.encode("utf-8")).hexdigest()
@@ -423,33 +479,36 @@ class VectorCache:
  """Execute SQLite query with retry logic for multi-process safety."""
  max_retries = 3
  base_delay = 0.05 # 50ms base delay for reads (faster than writes)
-
+
  last_exception = None
-
+
  for attempt in range(max_retries + 1):
  try:
  if params is None:
  return self.conn.execute(query)
  else:
  return self.conn.execute(query, params)
-
+
  except sqlite3.OperationalError as e:
  last_exception = e
  if "database is locked" in str(e).lower() and attempt < max_retries:
  # Exponential backoff: 0.05s, 0.1s, 0.2s
- delay = base_delay * (2 ** attempt)
+ delay = base_delay * (2**attempt)
  if self.verbose:
- print(f"⚠️ Database locked on read, retrying in {delay:.2f}s (attempt {attempt + 1}/{max_retries + 1})")
+ print(
+ f"⚠️ Database locked on read, retrying in {delay:.2f}s (attempt {attempt + 1}/{max_retries + 1})"
+ )
  import time
+
  time.sleep(delay)
  continue
  else:
  # Re-raise if not a lock error or max retries exceeded
  raise
- except Exception as e:
+ except Exception:
  # Re-raise any other exceptions
  raise
-
+
  # This should never be reached, but satisfy the type checker
  raise last_exception or RuntimeError("Failed to execute query after retries")

@@ -465,7 +524,9 @@ class VectorCache:
  computing missing embeddings.
  """
  assert isinstance(texts, list), "texts must be a list"
- assert all(isinstance(t, str) for t in texts), "all elements in texts must be strings"
+ assert all(isinstance(t, str) for t in texts), (
+ "all elements in texts must be strings"
+ )
  if not texts:
  return np.empty((0, 0), dtype=np.float32)
  t = time()
@@ -502,7 +563,9 @@ class VectorCache:

  if missing_items:
  if self.verbose:
- print(f"Computing {len(missing_items)}/{len(texts)} missing embeddings...")
+ print(
+ f"Computing {len(missing_items)}/{len(texts)} missing embeddings..."
+ )
  self._process_missing_items_with_batches(missing_items, hit_map)

  # Return embeddings in the original order
@@ -511,92 +574,81 @@ class VectorCache:
  print(f"Retrieved {len(texts)} embeddings in {elapsed:.2f} seconds")
  return np.vstack([hit_map[h] for h in hashes])

- def _process_missing_items_with_batches(self, missing_items: list[tuple[str, str]], hit_map: dict[str, np.ndarray]) -> None:
+ def _process_missing_items_with_batches(
+ self, missing_items: list[tuple[str, str]], hit_map: dict[str, np.ndarray]
+ ) -> None:
  """
- Process missing items in batches with progress bar and incremental DB insertion.
+ Process missing items in batches with simple progress tracking.
  """
  t = time() # Track total processing time
-
- # Try to import tqdm, fall back to simple progress if not available
- tqdm = None # avoid "possibly unbound" in type checker
- use_tqdm = False
- try:
- from tqdm import tqdm as _tqdm # type: ignore[import-not-found]
- tqdm = _tqdm
- use_tqdm = True
- except ImportError:
- use_tqdm = False
- if self.verbose:
- print("tqdm not available, using simple progress reporting")
-
+
  batch_size = self.config["embedding_batch_size"]
  total_items = len(missing_items)
-
+
  if self.verbose:
- print(f"Computing embeddings for {total_items} missing texts in batches of {batch_size}...")
+ print(
+ f"Computing embeddings for {total_items} missing texts in batches of {batch_size}..."
+ )
  if self.backend in ["vllm", "transformers"] and self._model is None:
- print(f"⚠️ Model will be loaded on first batch (lazy loading enabled)")
+ print("⚠️ Model will be loaded on first batch (lazy loading enabled)")
  elif self.backend in ["vllm", "transformers"]:
- print(f"✓ Model already loaded, ready for efficient batch processing")
-
- # Create progress bar
- pbar = None
- processed_count = 0
- if use_tqdm and tqdm is not None:
- pbar = tqdm(total=total_items, desc="Computing embeddings", unit="texts")
-
+ print("✓ Model already loaded, ready for efficient batch processing")
+
  # Track total committed items
  total_committed = 0
-
- try:
- # Process in batches
- for i in range(0, total_items, batch_size):
- batch_items = missing_items[i:i + batch_size]
- batch_texts = [text for text, _ in batch_items]
-
- # Get embeddings for this batch
- batch_embeds = self._get_embeddings(batch_texts)
-
- # Prepare batch data for immediate insert
- batch_data: list[tuple[str, str, bytes]] = []
- for (text, h), vec in zip(batch_items, batch_embeds):
- arr = np.asarray(vec, dtype=np.float32)
- batch_data.append((h, text, arr.tobytes()))
- hit_map[h] = arr
-
- # Immediate commit after each batch
- self._bulk_insert(batch_data)
- total_committed += len(batch_data)
-
- # Update progress
- batch_size_actual = len(batch_items)
- if use_tqdm:
- pbar.update(batch_size_actual)
- else:
- processed_count += batch_size_actual
- if self.verbose:
- print(f"Progress: {processed_count}/{total_items} embeddings computed, {total_committed} committed")
-
- finally:
- # Clean up progress bar
- if pbar is not None:
- pbar.close()
-
+ processed_count = 0
+
+ # Process in batches
+ for i in range(0, total_items, batch_size):
+ batch_items = missing_items[i : i + batch_size]
+ batch_texts = [text for text, _ in batch_items]
+
+ # Get embeddings for this batch
+ batch_embeds = self._get_embeddings(batch_texts)
+
+ # Prepare batch data for immediate insert
+ batch_data: list[tuple[str, str, bytes]] = []
+ for (text, h), vec in zip(batch_items, batch_embeds):
+ arr = np.asarray(vec, dtype=np.float32)
+ batch_data.append((h, text, arr.tobytes()))
+ hit_map[h] = arr
+
+ # Immediate commit after each batch
+ self._bulk_insert(batch_data)
+ total_committed += len(batch_data)
+
+ # Update progress - simple single line
+ batch_size_actual = len(batch_items)
+ processed_count += batch_size_actual
  if self.verbose:
- total_time = time() - t
- rate = total_items / total_time if total_time > 0 else 0
- print(f"✅ Completed: {total_items} embeddings computed and {total_committed} items committed to database")
- print(f" Total time: {total_time:.2f}s | Rate: {rate:.1f} embeddings/sec")
+ elapsed = time() - t
+ rate = processed_count / elapsed if elapsed > 0 else 0
+ progress_pct = (processed_count / total_items) * 100
+ print(
+ f"\rProgress: {processed_count}/{total_items} ({progress_pct:.1f}%) | {rate:.0f} texts/sec",
+ end="",
+ flush=True,
+ )
+
+ if self.verbose:
+ total_time = time() - t
+ rate = total_items / total_time if total_time > 0 else 0
+ print(
+ f"\n✅ Completed: {total_items} embeddings computed and {total_committed} items committed to database"
+ )
+ print(f" Total time: {total_time:.2f}s | Rate: {rate:.1f} embeddings/sec")

  def __call__(self, texts: list[str], cache: bool = True) -> np.ndarray:
  assert isinstance(texts, list), "texts must be a list"
- assert all(isinstance(t, str) for t in texts), "all elements in texts must be strings"
+ assert all(isinstance(t, str) for t in texts), (
+ "all elements in texts must be strings"
+ )
  return self.embeds(texts, cache)

  def _bulk_insert(self, data: list[tuple[str, str, bytes]]) -> None:
  """
  Perform bulk insert of embedding data with retry logic for multi-process safety.
-
+
  Uses INSERT OR IGNORE to prevent race conditions where multiple processes
  might try to insert the same text hash. The first process to successfully
  insert wins, subsequent attempts are ignored. This ensures deterministic
@@ -607,7 +659,7 @@ class VectorCache:

  max_retries = 3
  base_delay = 0.1 # 100ms base delay
-
+
  for attempt in range(max_retries + 1):
  try:
  cursor = self.conn.executemany(
@@ -615,82 +667,34 @@ class VectorCache:
  data,
  )
  self.conn.commit()
-
+
  # Check if some insertions were ignored due to existing entries
- if self.verbose and cursor.rowcount < len(data):
- ignored_count = len(data) - cursor.rowcount
- if ignored_count > 0:
- print(f"ℹ️ {ignored_count}/{len(data)} embeddings already existed in cache (computed by another process)")
-
+ # if self.verbose and cursor.rowcount < len(data):
+ # ignored_count = len(data) - cursor.rowcount
+ # if ignored_count > 0:
+ # print(f"ℹ️ {ignored_count}/{len(data)} embeddings already existed in cache (computed by another process)")
+
  return # Success, exit the retry loop
-
+
  except sqlite3.OperationalError as e:
  if "database is locked" in str(e).lower() and attempt < max_retries:
  # Exponential backoff: 0.1s, 0.2s, 0.4s
- delay = base_delay * (2 ** attempt)
+ delay = base_delay * (2**attempt)
  if self.verbose:
- print(f"⚠️ Database locked, retrying in {delay:.1f}s (attempt {attempt + 1}/{max_retries + 1})")
+ print(
+ f"⚠️ Database locked, retrying in {delay:.1f}s (attempt {attempt + 1}/{max_retries + 1})"
+ )
  import time
+
  time.sleep(delay)
  continue
  else:
  # Re-raise if not a lock error or max retries exceeded
  raise
- except Exception as e:
+ except Exception:
  # Re-raise any other exceptions
  raise

- # def precompute_embeddings(self, texts: list[str]) -> None:
- # """
- # Precompute embeddings for a large list of texts efficiently.
- # This is optimized for bulk operations when you know all texts upfront.
- # """
- # assert isinstance(texts, list), "texts must be a list"
- # assert all(isinstance(t, str) for t in texts), "all elements in texts must be strings"
- # if not texts:
- # return
-
- # # Remove duplicates while preserving order
- # unique_texts = list(dict.fromkeys(texts))
- # if self.verbose:
- # print(f"Precomputing embeddings for {len(unique_texts)} unique texts...")
-
- # # Check which ones are already cached
- # hashes = [self._hash_text(t) for t in unique_texts]
- # existing_hashes = set()
-
- # # Bulk check for existing embeddings
- # chunk_size = self.config["sqlite_chunk_size"]
- # for i in range(0, len(hashes), chunk_size):
- # chunk = hashes[i : i + chunk_size]
- # placeholders = ",".join("?" * len(chunk))
- # rows = self._execute_with_retry(
- # f"SELECT hash FROM cache WHERE hash IN ({placeholders})",
- # chunk,
- # ).fetchall()
- # existing_hashes.update(h[0] for h in rows)
-
- # # Find missing texts
- # missing_items = [
- # (t, h) for t, h in zip(unique_texts, hashes) if h not in existing_hashes
- # ]
-
- # if not missing_items:
- # if self.verbose:
- # print("All texts already cached!")
- # return
-
- # if self.verbose:
- # print(f"Computing {len(missing_items)} missing embeddings...")
-
- # # Process missing items with batches
- # missing_texts = [t for t, _ in missing_items]
- # missing_items_tupled = [(t, h) for t, h in zip(missing_texts, [self._hash_text(t) for t in missing_texts])]
- # hit_map_temp: dict[str, np.ndarray] = {}
- # self._process_missing_items_with_batches(missing_items_tupled, hit_map_temp)
- # if self.verbose:
- # print(f"Successfully cached {len(missing_items)} new embeddings!")
-
  def get_cache_stats(self) -> dict[str, int]:
  """Get statistics about the cache."""
  cursor = self._execute_with_retry("SELECT COUNT(*) FROM cache")
@@ -701,24 +705,27 @@ class VectorCache:
  """Clear all cached embeddings."""
  max_retries = 3
  base_delay = 0.1 # 100ms base delay
-
+
  for attempt in range(max_retries + 1):
  try:
  self.conn.execute("DELETE FROM cache")
  self.conn.commit()
  return # Success
-
+
  except sqlite3.OperationalError as e:
  if "database is locked" in str(e).lower() and attempt < max_retries:
- delay = base_delay * (2 ** attempt)
+ delay = base_delay * (2**attempt)
  if self.verbose:
- print(f"⚠️ Database locked during clear, retrying in {delay:.1f}s (attempt {attempt + 1}/{max_retries + 1})")
+ print(
+ f"⚠️ Database locked during clear, retrying in {delay:.1f}s (attempt {attempt + 1}/{max_retries + 1})"
+ )
  import time
+
  time.sleep(delay)
  continue
  else:
  raise
- except Exception as e:
+ except Exception:
  raise

  def get_config(self) -> Dict[str, Any]:
@@ -730,7 +737,7 @@ class VectorCache:
  "db_path": str(self.db_path),
  "verbose": self.verbose,
  "lazy": self.lazy,
- **self.config
+ **self.config,
  }

  def update_config(self, **kwargs) -> None:
@@ -744,17 +751,26 @@ class VectorCache:
  self.lazy = value
  else:
  raise ValueError(f"Unknown configuration parameter: {key}")
-
+
  # Reset model if backend-specific parameters changed
  backend_params = {
- "vllm": ["vllm_gpu_memory_utilization", "vllm_tensor_parallel_size", "vllm_dtype",
- "vllm_trust_remote_code", "vllm_max_model_len"],
- "transformers": ["transformers_device", "transformers_batch_size",
- "transformers_normalize_embeddings", "transformers_trust_remote_code"],
+ "vllm": [
+ "vllm_gpu_memory_utilization",
+ "vllm_tensor_parallel_size",
+ "vllm_dtype",
+ "vllm_trust_remote_code",
+ "vllm_max_model_len",
+ ],
+ "transformers": [
+ "transformers_device",
+ "transformers_batch_size",
+ "transformers_normalize_embeddings",
+ "transformers_trust_remote_code",
+ ],
  "openai": ["api_key", "model_name"],
- "processing": ["embedding_batch_size"]
+ "processing": ["embedding_batch_size"],
  }
-
+
  if any(param in kwargs for param in backend_params.get(self.backend, [])):
  self._model = None # Force reload on next use
  if self.backend == "openai":
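
The docstring in the first hunk describes the concurrency strategy this class relies on: SQLite in WAL mode with a busy timeout, exponential-backoff retries when the database is locked, and INSERT OR IGNORE so the first process to cache a given text hash wins. For readers of this diff, the standalone sketch below reproduces that pattern in isolation. It is not the package's code; the table layout cache(hash, text, vec) and the helper names are assumed purely for illustration.

import sqlite3
import time

def open_cache(path: str) -> sqlite3.Connection:
    # WAL lets concurrent readers proceed while a single writer appends to the log.
    conn = sqlite3.connect(path)
    conn.execute("PRAGMA journal_mode=WAL")
    conn.execute("PRAGMA busy_timeout=30000")  # wait up to 30 s for a lock
    conn.execute(
        "CREATE TABLE IF NOT EXISTS cache (hash TEXT PRIMARY KEY, text TEXT, vec BLOB)"
    )
    return conn

def insert_with_retry(conn: sqlite3.Connection, rows, max_retries: int = 3) -> None:
    base_delay = 0.1
    for attempt in range(max_retries + 1):
        try:
            # INSERT OR IGNORE: the first writer of a given hash wins; later
            # duplicates are silently dropped, so cached vectors stay stable.
            conn.executemany("INSERT OR IGNORE INTO cache VALUES (?, ?, ?)", rows)
            conn.commit()
            return
        except sqlite3.OperationalError as e:
            if "database is locked" in str(e).lower() and attempt < max_retries:
                time.sleep(base_delay * (2 ** attempt))  # 0.1s, 0.2s, 0.4s
                continue
            raise

With this setup, several processes can call insert_with_retry against the same file: at most one holds the write lock at a time, lock contention is absorbed by the backoff loop, and a duplicate hash never overwrites a vector that another process already stored.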