speedy-utils 1.1.17__py3-none-any.whl → 1.1.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -14,6 +14,10 @@ class VectorCache:
  """
  A caching layer for text embeddings with support for multiple backends.
 
+ This cache is designed to be safe for multi-process environments where multiple
+ processes may access the same cache file simultaneously. It uses SQLite WAL mode
+ and retry logic with exponential backoff to handle concurrent access.
+
  Examples:
  # OpenAI API
  from llm_utils import VectorCache
@@ -32,11 +36,26 @@ class VectorCache:
  # Explicit backend specification
  cache = VectorCache("model-name", backend="transformers")
 
- # Lazy loading (default: True) - load model only when needed
+ # Eager loading (default: False) - load model immediately for better performance
+ cache = VectorCache("model-name", lazy=False)
+
+ # Lazy loading - load model only when needed (the first embedding call pays the load cost)
  cache = VectorCache("model-name", lazy=True)
 
- # Eager loading - load model immediately
- cache = VectorCache("model-name", lazy=False)
+ Multi-Process Safety:
+ The cache uses SQLite WAL (Write-Ahead Logging) mode and implements retry logic
+ with exponential backoff to handle database locks. Multiple processes can safely
+ read and write to the same cache file simultaneously.
+
+ Race Condition Protection:
+ - Uses INSERT OR IGNORE to prevent overwrites when multiple processes compute the same text
+ - The first process to successfully cache a text wins; subsequent attempts are ignored
+ - This ensures deterministic results even with non-deterministic embedding models
+
+ For best performance in multi-process scenarios, consider:
+ - Using separate cache files per process if cache hits are low
+ - Coordinating cache warm-up to avoid redundant computation
+ - Monitoring for excessive lock contention in high-concurrency scenarios
  """
  def __init__(
  self,
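To make the multi-process claims in the docstring above concrete, here is a minimal usage sketch (not part of the package): two worker processes embed overlapping text lists, assuming the default cache location for a given model is shared between them. The model name is a placeholder; the import path follows the docstring's own examples.

    from multiprocessing import Pool

    def embed_worker(texts: list[str]) -> int:
        # Each process opens its own VectorCache; WAL mode plus INSERT OR IGNORE
        # lets them share the underlying SQLite file safely.
        from llm_utils import VectorCache
        cache = VectorCache("model-name", backend="transformers")
        cache(texts)  # __call__ -> embeds(), results are cached in SQLite
        return cache.get_cache_stats()["total_cached"]

    if __name__ == "__main__":
        batches = [["hello", "world"], ["hello", "again"]]  # "hello" overlaps
        with Pool(2) as pool:
            print(pool.map(embed_worker, batches))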
@@ -62,9 +81,11 @@ class VectorCache:
  sqlite_chunk_size: int = 999,
  sqlite_cache_size: int = 10000,
  sqlite_mmap_size: int = 268435456,
+ # Processing parameters
+ embedding_batch_size: int = 20_000,
  # Other parameters
  verbose: bool = True,
- lazy: bool = True,
+ lazy: bool = False,
  ) -> None:
  self.url_or_model = url_or_model
  self.embed_size = embed_size
@@ -95,6 +116,8 @@ class VectorCache:
  "sqlite_chunk_size": sqlite_chunk_size,
  "sqlite_cache_size": sqlite_cache_size,
  "sqlite_mmap_size": sqlite_mmap_size,
+ # Processing
+ "embedding_batch_size": embedding_batch_size,
  }
 
  # Auto-detect model_name for OpenAI if using custom URL and default model
@@ -147,10 +170,14 @@ class VectorCache:
 
  # Load model/client if not lazy
  if not self.lazy:
+ if self.verbose:
+ print(f"Loading {self.backend} model/client: {self.url_or_model}")
  if self.backend == "openai":
  self._load_openai_client()
  elif self.backend in ["vllm", "transformers"]:
  self._load_model()
+ if self.verbose:
+ print(f"✓ {self.backend.upper()} model/client loaded successfully")
 
  def _determine_backend(self, backend: Optional[Literal["vllm", "transformers", "openai"]]) -> str:
  """Determine the appropriate backend based on url_or_model and user preference."""
@@ -181,7 +208,7 @@ class VectorCache:
  print('Infer model name:', model_name)
  return model_name
  def _optimize_connection(self) -> None:
- """Optimize SQLite connection for bulk operations."""
+ """Optimize SQLite connection for bulk operations and multi-process safety."""
  # Performance optimizations for bulk operations
  self.conn.execute(
  "PRAGMA journal_mode=WAL"
@@ -190,6 +217,10 @@ class VectorCache:
  self.conn.execute(f"PRAGMA cache_size={self.config['sqlite_cache_size']}") # Configurable cache
  self.conn.execute("PRAGMA temp_store=MEMORY") # Use memory for temp storage
  self.conn.execute(f"PRAGMA mmap_size={self.config['sqlite_mmap_size']}") # Configurable memory mapping
+
+ # Multi-process safety improvements
+ self.conn.execute("PRAGMA busy_timeout=30000") # Wait up to 30 seconds for locks
+ self.conn.execute("PRAGMA wal_autocheckpoint=1000") # Checkpoint WAL every 1000 pages
 
  def _ensure_schema(self) -> None:
  self.conn.execute("""
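The PRAGMAs added above can be exercised in a plain sqlite3 session; this standalone sketch (independent of the package) applies the same settings and notes what each one is for.

    import sqlite3

    conn = sqlite3.connect("cache.db")
    # WAL allows one writer alongside concurrent readers from other processes.
    print(conn.execute("PRAGMA journal_mode=WAL").fetchone())  # ('wal',)
    # A blocked connection waits up to 30 s instead of failing immediately
    # with "database is locked".
    conn.execute("PRAGMA busy_timeout=30000")
    # Fold the write-ahead log back into the main file roughly every 1000 pages.
    conn.execute("PRAGMA wal_autocheckpoint=1000")
    conn.close()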
@@ -216,7 +247,7 @@ class VectorCache:
  def _load_model(self) -> None:
  """Load the model for vLLM or Transformers."""
  if self.backend == "vllm":
- from vllm import LLM
+ from vllm import LLM # type: ignore[import-not-found]
 
  gpu_memory_utilization = cast(float, self.config["vllm_gpu_memory_utilization"])
  tensor_parallel_size = cast(int, self.config["vllm_tensor_parallel_size"])
@@ -257,8 +288,8 @@ class VectorCache:
  else:
  raise
  elif self.backend == "transformers":
- from transformers import AutoTokenizer, AutoModel
- import torch
+ from transformers import AutoTokenizer, AutoModel # type: ignore[import-not-found]
+ import torch # type: ignore[import-not-found]
 
  device = self.config["transformers_device"]
  # Handle "auto" device selection - default to CPU for transformers to avoid memory conflicts
@@ -296,7 +327,11 @@ class VectorCache:
  assert model_name is not None and model_name.strip(), f"Invalid model_name for OpenAI backend: {model_name}. Model name must be provided and non-empty."
 
  if self._client is None:
+ if self.verbose:
+ print("🔧 Loading OpenAI client...")
  self._load_openai_client()
+ if self.verbose:
+ print("✓ OpenAI client loaded successfully")
 
  response = self._client.embeddings.create( # type: ignore
  model=model_name,
@@ -310,7 +345,11 @@ class VectorCache:
  assert isinstance(texts, list), "texts must be a list"
  assert all(isinstance(t, str) for t in texts), "all elements in texts must be strings"
  if self._model is None:
+ if self.verbose:
+ print("🔧 Loading vLLM model...")
  self._load_model()
+ if self.verbose:
+ print("✓ vLLM model loaded successfully")
 
  outputs = self._model.embed(texts) # type: ignore
  embeddings = [o.outputs.embedding for o in outputs]
@@ -321,7 +360,11 @@ class VectorCache:
  assert isinstance(texts, list), "texts must be a list"
  assert all(isinstance(t, str) for t in texts), "all elements in texts must be strings"
  if self._model is None:
+ if self.verbose:
+ print("🔧 Loading Transformers model...")
  self._load_model()
+ if self.verbose:
+ print("✓ Transformers model loaded successfully")
 
  if not isinstance(self._model, dict):
  raise ValueError("Model not loaded properly for transformers backend")
@@ -348,7 +391,7 @@ class VectorCache:
  batch_dict = {k: v.to(device) for k, v in batch_dict.items()}
 
  # Run model
- import torch
+ import torch # type: ignore[import-not-found]
  with torch.no_grad():
  outputs = model(**batch_dict)
 
@@ -357,14 +400,14 @@ class VectorCache:
 
  # Normalize if needed
  if normalize_embeddings:
- import torch.nn.functional as F
+ import torch.nn.functional as F # type: ignore[import-not-found]
  embeddings = F.normalize(embeddings, p=2, dim=1)
 
  return embeddings.cpu().numpy().tolist()
 
  def _last_token_pool(self, last_hidden_states, attention_mask):
  """Apply last token pooling to get embeddings."""
- import torch
+ import torch # type: ignore[import-not-found]
  left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
  if left_padding:
  return last_hidden_states[:, -1]
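The hunk above shows only the left-padding branch of `_last_token_pool`. For reference, a common standalone formulation of last-token pooling is sketched below; the right-padding branch is an assumption about the omitted code, not a copy of it.

    import torch

    def last_token_pool(last_hidden_states: torch.Tensor,
                        attention_mask: torch.Tensor) -> torch.Tensor:
        # With left padding, the final position of every sequence is a real token.
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            return last_hidden_states[:, -1]
        # With right padding, pick each row's last non-padding position instead.
        lengths = attention_mask.sum(dim=1) - 1
        rows = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
        return last_hidden_states[rows, lengths]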
@@ -376,6 +419,40 @@ class VectorCache:
  def _hash_text(self, text: str) -> str:
  return hashlib.sha1(text.encode("utf-8")).hexdigest()
 
+ def _execute_with_retry(self, query: str, params=None) -> sqlite3.Cursor:
+ """Execute SQLite query with retry logic for multi-process safety."""
+ max_retries = 3
+ base_delay = 0.05 # 50ms base delay for reads (faster than writes)
+
+ last_exception = None
+
+ for attempt in range(max_retries + 1):
+ try:
+ if params is None:
+ return self.conn.execute(query)
+ else:
+ return self.conn.execute(query, params)
+
+ except sqlite3.OperationalError as e:
+ last_exception = e
+ if "database is locked" in str(e).lower() and attempt < max_retries:
+ # Exponential backoff: 0.05s, 0.1s, 0.2s
+ delay = base_delay * (2 ** attempt)
+ if self.verbose:
+ print(f"⚠️ Database locked on read, retrying in {delay:.2f}s (attempt {attempt + 1}/{max_retries + 1})")
+ import time
+ time.sleep(delay)
+ continue
+ else:
+ # Re-raise if not a lock error or max retries exceeded
+ raise
+ except Exception as e:
+ # Re-raise any other exceptions
+ raise
+
+ # This should never be reached, but satisfy the type checker
+ raise last_exception or RuntimeError("Failed to execute query after retries")
+
  def embeds(self, texts: list[str], cache: bool = True) -> np.ndarray:
  """
  Return embeddings for all texts.
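The retry loop added above targets SQLite's `database is locked` error. This self-contained sketch (not package code) reproduces the situation it handles: one thread holds a write transaction while the main thread retries with the same 0.05 s / 0.1 s / 0.2 s backoff schedule.

    import sqlite3
    import threading
    import time

    setup = sqlite3.connect("demo.db")
    setup.execute("CREATE TABLE IF NOT EXISTS t (x)")
    setup.commit()

    def hold_write_lock() -> None:
        writer = sqlite3.connect("demo.db")
        writer.execute("BEGIN IMMEDIATE")            # take the write lock
        writer.execute("INSERT INTO t VALUES (1)")
        time.sleep(0.2)                              # hold it long enough to force retries
        writer.commit()                              # release

    threading.Thread(target=hold_write_lock).start()
    time.sleep(0.05)                                 # let the writer grab the lock first

    reader = sqlite3.connect("demo.db", timeout=0)   # surface lock errors immediately
    for attempt in range(4):                         # up to 3 retries, as in _execute_with_retry
        try:
            reader.execute("INSERT INTO t VALUES (2)")
            reader.commit()
            print(f"succeeded on attempt {attempt + 1}")
            break
        except sqlite3.OperationalError as exc:
            if "locked" not in str(exc).lower() or attempt == 3:
                raise
            time.sleep(0.05 * (2 ** attempt))        # 0.05 s, 0.1 s, 0.2 s backoff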
@@ -402,11 +479,11 @@ class VectorCache:
  hit_map: dict[str, np.ndarray] = {}
  chunk_size = self.config["sqlite_chunk_size"]
 
- # Use bulk lookup with optimized query
+ # Use bulk lookup with optimized query and retry logic
  hash_chunks = _chunks(hashes, chunk_size)
  for chunk in hash_chunks:
  placeholders = ",".join("?" * len(chunk))
- rows = self.conn.execute(
+ rows = self._execute_with_retry(
  f"SELECT hash, embedding FROM cache WHERE hash IN ({placeholders})",
  chunk,
  ).fetchall()
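`_chunks` is defined outside this hunk; a plausible stand-in is sketched below (an assumption, not the package's code). The chunking matters because many SQLite builds cap the number of `?` placeholders per statement at 999, which is where the default `sqlite_chunk_size=999` comes from.

    from typing import Iterator, Sequence, TypeVar

    T = TypeVar("T")

    def _chunks(items: Sequence[T], size: int) -> Iterator[Sequence[T]]:
        # Yield consecutive slices of at most `size` items so the generated
        # "WHERE hash IN (?, ?, ...)" clause never exceeds the placeholder limit.
        for start in range(0, len(items), size):
            yield items[start:start + size]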
@@ -425,18 +502,8 @@ class VectorCache:
 
  if missing_items:
  if self.verbose:
- print(f"Computing embeddings for {len(missing_items)} missing texts...")
- missing_texts = [t for t, _ in missing_items]
- embeds = self._get_embeddings(missing_texts)
-
- # Prepare batch data for bulk insert
- bulk_insert_data: list[tuple[str, str, bytes]] = []
- for (text, h), vec in zip(missing_items, embeds):
- arr = np.asarray(vec, dtype=np.float32)
- bulk_insert_data.append((h, text, arr.tobytes()))
- hit_map[h] = arr
-
- self._bulk_insert(bulk_insert_data)
+ print(f"Computing {len(missing_items)}/{len(texts)} missing embeddings...")
+ self._process_missing_items_with_batches(missing_items, hit_map)
 
  # Return embeddings in the original order
  elapsed = time() - t
@@ -444,87 +511,215 @@ class VectorCache:
  print(f"Retrieved {len(texts)} embeddings in {elapsed:.2f} seconds")
  return np.vstack([hit_map[h] for h in hashes])
 
+ def _process_missing_items_with_batches(self, missing_items: list[tuple[str, str]], hit_map: dict[str, np.ndarray]) -> None:
+ """
+ Process missing items in batches with progress bar and incremental DB insertion.
+ """
+ t = time() # Track total processing time
+
+ # Try to import tqdm, fall back to simple progress if not available
+ tqdm = None # avoid "possibly unbound" in type checker
+ use_tqdm = False
+ try:
+ from tqdm import tqdm as _tqdm # type: ignore[import-not-found]
+ tqdm = _tqdm
+ use_tqdm = True
+ except ImportError:
+ use_tqdm = False
+ if self.verbose:
+ print("tqdm not available, using simple progress reporting")
+
+ batch_size = self.config["embedding_batch_size"]
+ total_items = len(missing_items)
+
+ if self.verbose:
+ print(f"Computing embeddings for {total_items} missing texts in batches of {batch_size}...")
+ if self.backend in ["vllm", "transformers"] and self._model is None:
+ print(f"⚠️ Model will be loaded on first batch (lazy loading enabled)")
+ elif self.backend in ["vllm", "transformers"]:
+ print(f"✓ Model already loaded, ready for efficient batch processing")
+
+ # Create progress bar
+ pbar = None
+ processed_count = 0
+ if use_tqdm and tqdm is not None:
+ pbar = tqdm(total=total_items, desc="Computing embeddings", unit="texts")
+
+ # Track total committed items
+ total_committed = 0
+
+ try:
+ # Process in batches
+ for i in range(0, total_items, batch_size):
+ batch_items = missing_items[i:i + batch_size]
+ batch_texts = [text for text, _ in batch_items]
+
+ # Get embeddings for this batch
+ batch_embeds = self._get_embeddings(batch_texts)
+
+ # Prepare batch data for immediate insert
+ batch_data: list[tuple[str, str, bytes]] = []
+ for (text, h), vec in zip(batch_items, batch_embeds):
+ arr = np.asarray(vec, dtype=np.float32)
+ batch_data.append((h, text, arr.tobytes()))
+ hit_map[h] = arr
+
+ # Immediate commit after each batch
+ self._bulk_insert(batch_data)
+ total_committed += len(batch_data)
+
+ # Update progress
+ batch_size_actual = len(batch_items)
+ if use_tqdm:
+ pbar.update(batch_size_actual)
+ else:
+ processed_count += batch_size_actual
+ if self.verbose:
+ print(f"Progress: {processed_count}/{total_items} embeddings computed, {total_committed} committed")
+
+ finally:
+ # Clean up progress bar
+ if pbar is not None:
+ pbar.close()
+
+ if self.verbose:
+ total_time = time() - t
+ rate = total_items / total_time if total_time > 0 else 0
+ print(f"✅ Completed: {total_items} embeddings computed and {total_committed} items committed to database")
+ print(f" Total time: {total_time:.2f}s | Rate: {rate:.1f} embeddings/sec")
+
  def __call__(self, texts: list[str], cache: bool = True) -> np.ndarray:
  assert isinstance(texts, list), "texts must be a list"
  assert all(isinstance(t, str) for t in texts), "all elements in texts must be strings"
  return self.embeds(texts, cache)
 
  def _bulk_insert(self, data: list[tuple[str, str, bytes]]) -> None:
- """Perform bulk insert of embedding data."""
- if not data:
- return
-
- self.conn.executemany(
- "INSERT OR REPLACE INTO cache (hash, text, embedding) VALUES (?, ?, ?)",
- data,
- )
- self.conn.commit()
-
- def precompute_embeddings(self, texts: list[str]) -> None:
  """
- Precompute embeddings for a large list of texts efficiently.
- This is optimized for bulk operations when you know all texts upfront.
+ Perform bulk insert of embedding data with retry logic for multi-process safety.
+
+ Uses INSERT OR IGNORE to prevent race conditions where multiple processes
+ might try to insert the same text hash. The first process to successfully
+ insert wins; subsequent attempts are ignored. This ensures deterministic
+ caching behavior in multi-process environments.
  """
- assert isinstance(texts, list), "texts must be a list"
- assert all(isinstance(t, str) for t in texts), "all elements in texts must be strings"
- if not texts:
- return
-
- # Remove duplicates while preserving order
- unique_texts = list(dict.fromkeys(texts))
- if self.verbose:
- print(f"Precomputing embeddings for {len(unique_texts)} unique texts...")
-
- # Check which ones are already cached
- hashes = [self._hash_text(t) for t in unique_texts]
- existing_hashes = set()
-
- # Bulk check for existing embeddings
- chunk_size = self.config["sqlite_chunk_size"]
- for i in range(0, len(hashes), chunk_size):
- chunk = hashes[i : i + chunk_size]
- placeholders = ",".join("?" * len(chunk))
- rows = self.conn.execute(
- f"SELECT hash FROM cache WHERE hash IN ({placeholders})",
- chunk,
- ).fetchall()
- existing_hashes.update(h[0] for h in rows)
-
- # Find missing texts
- missing_items = [
- (t, h) for t, h in zip(unique_texts, hashes) if h not in existing_hashes
- ]
-
- if not missing_items:
- if self.verbose:
- print("All texts already cached!")
+ if not data:
  return
 
- if self.verbose:
- print(f"Computing {len(missing_items)} missing embeddings...")
- missing_texts = [t for t, _ in missing_items]
- embeds = self._get_embeddings(missing_texts)
-
- # Prepare batch data for bulk insert
- bulk_insert_data: list[tuple[str, str, bytes]] = []
- for (text, h), vec in zip(missing_items, embeds):
- arr = np.asarray(vec, dtype=np.float32)
- bulk_insert_data.append((h, text, arr.tobytes()))
-
- self._bulk_insert(bulk_insert_data)
- if self.verbose:
- print(f"Successfully cached {len(missing_items)} new embeddings!")
+ max_retries = 3
+ base_delay = 0.1 # 100ms base delay
+
+ for attempt in range(max_retries + 1):
+ try:
+ cursor = self.conn.executemany(
+ "INSERT OR IGNORE INTO cache (hash, text, embedding) VALUES (?, ?, ?)",
+ data,
+ )
+ self.conn.commit()
+
+ # Check if some insertions were ignored due to existing entries
+ if self.verbose and cursor.rowcount < len(data):
+ ignored_count = len(data) - cursor.rowcount
+ if ignored_count > 0:
+ print(f"ℹ️ {ignored_count}/{len(data)} embeddings already existed in cache (computed by another process)")
+
+ return # Success, exit the retry loop
+
+ except sqlite3.OperationalError as e:
+ if "database is locked" in str(e).lower() and attempt < max_retries:
+ # Exponential backoff: 0.1s, 0.2s, 0.4s
+ delay = base_delay * (2 ** attempt)
+ if self.verbose:
+ print(f"⚠️ Database locked, retrying in {delay:.1f}s (attempt {attempt + 1}/{max_retries + 1})")
+ import time
+ time.sleep(delay)
+ continue
+ else:
+ # Re-raise if not a lock error or max retries exceeded
+ raise
+ except Exception as e:
+ # Re-raise any other exceptions
+ raise
+
+ # def precompute_embeddings(self, texts: list[str]) -> None:
+ # """
+ # Precompute embeddings for a large list of texts efficiently.
+ # This is optimized for bulk operations when you know all texts upfront.
+ # """
+ # assert isinstance(texts, list), "texts must be a list"
+ # assert all(isinstance(t, str) for t in texts), "all elements in texts must be strings"
+ # if not texts:
+ # return
+
+ # # Remove duplicates while preserving order
+ # unique_texts = list(dict.fromkeys(texts))
+ # if self.verbose:
+ # print(f"Precomputing embeddings for {len(unique_texts)} unique texts...")
+
+ # # Check which ones are already cached
+ # hashes = [self._hash_text(t) for t in unique_texts]
+ # existing_hashes = set()
+
+ # # Bulk check for existing embeddings
+ # chunk_size = self.config["sqlite_chunk_size"]
+ # for i in range(0, len(hashes), chunk_size):
+ # chunk = hashes[i : i + chunk_size]
+ # placeholders = ",".join("?" * len(chunk))
+ # rows = self._execute_with_retry(
+ # f"SELECT hash FROM cache WHERE hash IN ({placeholders})",
+ # chunk,
+ # ).fetchall()
+ # existing_hashes.update(h[0] for h in rows)
+
+ # # Find missing texts
+ # missing_items = [
+ # (t, h) for t, h in zip(unique_texts, hashes) if h not in existing_hashes
+ # ]
+
+ # if not missing_items:
+ # if self.verbose:
+ # print("All texts already cached!")
+ # return
+
+ # if self.verbose:
+ # print(f"Computing {len(missing_items)} missing embeddings...")
+
+ # # Process missing items with batches
+ # missing_texts = [t for t, _ in missing_items]
+ # missing_items_tupled = [(t, h) for t, h in zip(missing_texts, [self._hash_text(t) for t in missing_texts])]
+ # hit_map_temp: dict[str, np.ndarray] = {}
+ # self._process_missing_items_with_batches(missing_items_tupled, hit_map_temp)
+ # if self.verbose:
+ # print(f"Successfully cached {len(missing_items)} new embeddings!")
 
  def get_cache_stats(self) -> dict[str, int]:
  """Get statistics about the cache."""
- cursor = self.conn.execute("SELECT COUNT(*) FROM cache")
+ cursor = self._execute_with_retry("SELECT COUNT(*) FROM cache")
  count = cursor.fetchone()[0]
  return {"total_cached": count}
 
  def clear_cache(self) -> None:
  """Clear all cached embeddings."""
- self.conn.execute("DELETE FROM cache")
- self.conn.commit()
+ max_retries = 3
+ base_delay = 0.1 # 100ms base delay
+
+ for attempt in range(max_retries + 1):
+ try:
+ self.conn.execute("DELETE FROM cache")
+ self.conn.commit()
+ return # Success
+
+ except sqlite3.OperationalError as e:
+ if "database is locked" in str(e).lower() and attempt < max_retries:
+ delay = base_delay * (2 ** attempt)
+ if self.verbose:
+ print(f"⚠️ Database locked during clear, retrying in {delay:.1f}s (attempt {attempt + 1}/{max_retries + 1})")
+ import time
+ time.sleep(delay)
+ continue
+ else:
+ raise
+ except Exception as e:
+ raise
 
  def get_config(self) -> Dict[str, Any]:
  """Get current configuration."""
@@ -556,7 +751,8 @@ class VectorCache:
  "vllm_trust_remote_code", "vllm_max_model_len"],
  "transformers": ["transformers_device", "transformers_batch_size",
  "transformers_normalize_embeddings", "transformers_trust_remote_code"],
- "openai": ["api_key", "model_name"]
+ "openai": ["api_key", "model_name"],
+ "processing": ["embedding_batch_size"]
  }
 
  if any(param in kwargs for param in backend_params.get(self.backend, [])):
@@ -0,0 +1,68 @@
+ # utils/patching.py
+ import inspect
+ import textwrap
+ import re
+ from typing import Annotated, Union
+
+ def patch_method(
+ cls: Annotated[type, "Class containing the method"],
+ method_name: Annotated[str, "Name of the method to patch"],
+ replacements: Annotated[
+ dict[Union[str, re.Pattern], str],
+ "Mapping of {old_substring_or_regex: new_string} replacements"
+ ],
+ tag: Annotated[str, "Optional logging tag"] = "",
+ ) -> bool:
+ """
+ Generic patcher for replacing substrings or regex matches in a method's source code.
+
+ Args:
+ cls: Target class
+ method_name: Method name to patch
+ replacements: {pattern: replacement}. Patterns may be plain strings or regex patterns.
+ tag: Optional string shown in logs
+
+ Returns:
+ bool: True if successfully patched, False otherwise
+ """
+
+ try:
+ method = getattr(cls, method_name)
+ except AttributeError:
+ print(f"[patcher{':' + tag if tag else ''}] No method {method_name} in {cls.__name__}")
+ return False
+
+ try:
+ src = inspect.getsource(method)
+ except (OSError, TypeError):
+ print(f"[patcher{':' + tag if tag else ''}] Could not get source for {cls.__name__}.{method_name}")
+ return False
+
+ new_src = src
+ did_patch = False
+
+ for old, new in replacements.items():
+ if isinstance(old, re.Pattern):
+ if old.search(new_src):
+ new_src = old.sub(new, new_src)
+ did_patch = True
+ elif isinstance(old, str):
+ if old in new_src:
+ new_src = new_src.replace(old, new)
+ did_patch = True
+ else:
+ raise TypeError("Replacement keys must be str or re.Pattern")
+
+ if not did_patch:
+ print(f"[patcher{':' + tag if tag else ''}] No matching patterns found in {cls.__name__}.{method_name}")
+ return False
+
+ # Recompile the patched function (dedent first: inspect.getsource returns class-indented source)
+ code_obj = compile(textwrap.dedent(new_src), filename=f"<patched_{method_name}>", mode="exec")
+ ns = {}
+ exec(code_obj, getattr(method, "__globals__", {}), ns)  # exec needs a real dict; cls.__dict__ is a mappingproxy
+
+ # Attach patched method back
+ setattr(cls, method_name, ns[method_name])  # plain functions bind as methods when set on the class
+ print(f"[patcher{':' + tag if tag else ''}] Patched {cls.__name__}.{method_name}")
+ return True
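A hedged usage sketch for the new `utils/patching.py` helper (the import path is taken from the file's header comment; the target class and replacement strings are made up for illustration):

    from utils.patching import patch_method  # assumed import path, per the file header comment

    class Greeter:
        def greet(self) -> str:
            return "hello"

    # Rewrite Greeter.greet's source, recompile it, and re-attach it to the class.
    ok = patch_method(Greeter, "greet", {"hello": "hi"}, tag="demo")
    print(ok, Greeter().greet())  # True hi  (if lookup and recompilation succeed)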
@@ -563,7 +563,7 @@ def _async_both_memoize(
 
  @overload
  def memoize(
- _func: Callable[P, R],
+ _func: Callable[P, R | Awaitable[R]],
  *,
  keys: Optional[list[str]] = ...,
  key: Optional[Callable[..., Any]] = ...,
@@ -572,10 +572,10 @@ def memoize(
  size: int = ...,
  ignore_self: bool = ...,
  verbose: bool = ...,
- ) -> Callable[P, R]: ...
+ ) -> Callable[P, R | Awaitable[R]]: ...
  @overload
  def memoize(
- _func: Callable[P, Awaitable[R]],
+ _func: None = ...,
  *,
  keys: Optional[list[str]] = ...,
  key: Optional[Callable[..., Any]] = ...,
@@ -584,7 +584,7 @@ def memoize(
  size: int = ...,
  ignore_self: bool = ...,
  verbose: bool = ...,
- ) -> Callable[P, Awaitable[R]]: ...
+ ) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
  @overload
  def memoize(
  _func: None = ...,
@@ -596,7 +596,7 @@ def memoize(
  size: int = ...,
  ignore_self: bool = ...,
  verbose: bool = ...,
- ) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
+ ) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: ...
 
 
  def memoize(
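Taken together, the reworked overloads say: calling `memoize` directly on a function (sync or async) returns a callable of the same shape, while calling it with only keyword arguments returns a decorator for either kind. A usage sketch under those stubs (the `speedy_utils` import location and the decorated functions are assumptions for illustration; the decorator's runtime behaviour is defined elsewhere in the package):

    import asyncio
    from speedy_utils import memoize  # assumed export location

    @memoize                      # direct call: Callable[P, R | Awaitable[R]] overload
    def slow_add(a: int, b: int) -> int:
        return a + b

    @memoize(verbose=False)       # factory call: decorator-returning overloads
    async def slow_fetch(name: str) -> str:
        await asyncio.sleep(0.1)
        return name

    print(slow_add(1, 2))
    print(asyncio.run(slow_fetch("example")))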