speedy-utils 1.1.18__py3-none-any.whl → 1.1.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +3 -2
- llm_utils/lm/async_lm/async_llm_task.py +1 -0
- llm_utils/lm/llm_task.py +303 -10
- llm_utils/lm/openai_memoize.py +10 -2
- llm_utils/vector_cache/core.py +250 -234
- speedy_utils/__init__.py +2 -1
- speedy_utils/common/utils_cache.py +38 -19
- speedy_utils/common/utils_io.py +9 -5
- speedy_utils/multi_worker/process.py +91 -10
- speedy_utils/multi_worker/thread.py +94 -2
- {speedy_utils-1.1.18.dist-info → speedy_utils-1.1.20.dist-info}/METADATA +34 -13
- {speedy_utils-1.1.18.dist-info → speedy_utils-1.1.20.dist-info}/RECORD +19 -19
- {speedy_utils-1.1.18.dist-info → speedy_utils-1.1.20.dist-info}/WHEEL +1 -1
- speedy_utils-1.1.20.dist-info/entry_points.txt +5 -0
- speedy_utils-1.1.18.dist-info/entry_points.txt +0 -6
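
The per-file summary above is the registry's own added/removed line count report. As a rough illustration only, a comparable summary can be produced locally from the two wheels using nothing but Python's standard library; the wheel file paths below are assumptions and not part of the registry output.

```python
# Hypothetical local reproduction of the per-file +N -M summary above.
# Wheels are plain zip archives, so zipfile + difflib from the stdlib suffice.
import difflib
import zipfile

OLD_WHEEL = "speedy_utils-1.1.18-py3-none-any.whl"  # assumed local path
NEW_WHEEL = "speedy_utils-1.1.20-py3-none-any.whl"  # assumed local path


def wheel_texts(path: str) -> dict[str, list[str]]:
    """Read every text member of a wheel as a list of lines."""
    out: dict[str, list[str]] = {}
    with zipfile.ZipFile(path) as zf:
        for name in zf.namelist():
            try:
                out[name] = zf.read(name).decode("utf-8").splitlines(keepends=True)
            except UnicodeDecodeError:
                continue  # skip binary members
    return out


old, new = wheel_texts(OLD_WHEEL), wheel_texts(NEW_WHEEL)
for name in sorted(set(old) | set(new)):
    diff = list(difflib.unified_diff(old.get(name, []), new.get(name, []),
                                     fromfile=name, tofile=name))
    added = sum(1 for ln in diff if ln.startswith("+") and not ln.startswith("+++"))
    removed = sum(1 for ln in diff if ln.startswith("-") and not ln.startswith("---"))
    if added or removed:
        print(f"{name} +{added} -{removed}")
```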
llm_utils/vector_cache/core.py
CHANGED
@@ -13,50 +13,51 @@ import numpy as np
class VectorCache:
    """
    A caching layer for text embeddings with support for multiple backends.
-
+
    This cache is designed to be safe for multi-process environments where multiple
    processes may access the same cache file simultaneously. It uses SQLite WAL mode
    and retry logic with exponential backoff to handle concurrent access.
-
+
    Examples:
        # OpenAI API
        from llm_utils import VectorCache
        cache = VectorCache("https://api.openai.com/v1", api_key="your-key")
        embeddings = cache.embeds(["Hello world", "How are you?"])
-
+
        # Custom OpenAI-compatible server (auto-detects model)
        cache = VectorCache("http://localhost:8000/v1", api_key="abc")
-
+
        # Transformers (Sentence Transformers)
        cache = VectorCache("sentence-transformers/all-MiniLM-L6-v2")
-
+
        # vLLM (local model)
        cache = VectorCache("/path/to/model")
-
+
        # Explicit backend specification
        cache = VectorCache("model-name", backend="transformers")
-
+
        # Eager loading (default: False) - load model immediately for better performance
        cache = VectorCache("model-name", lazy=False)
-
+
        # Lazy loading - load model only when needed (may cause performance issues)
        cache = VectorCache("model-name", lazy=True)
-
+
    Multi-Process Safety:
    The cache uses SQLite WAL (Write-Ahead Logging) mode and implements retry logic
    with exponential backoff to handle database locks. Multiple processes can safely
    read and write to the same cache file simultaneously.
-
+
    Race Condition Protection:
    - Uses INSERT OR IGNORE to prevent overwrites when multiple processes compute the same text
    - The first process to successfully cache a text wins, subsequent attempts are ignored
    - This ensures deterministic results even with non-deterministic embedding models
-
+
    For best performance in multi-process scenarios, consider:
    - Using separate cache files per process if cache hits are low
    - Coordinating cache warm-up to avoid redundant computation
    - Monitor for excessive lock contention in high-concurrency scenarios
    """
+
    def __init__(
        self,
        url_or_model: str,
@@ -80,7 +81,7 @@ class VectorCache:
        # SQLite parameters
        sqlite_chunk_size: int = 999,
        sqlite_cache_size: int = 10000,
-        sqlite_mmap_size: int = 268435456,
+        sqlite_mmap_size: int = 268435456,  # 256MB
        # Processing parameters
        embedding_batch_size: int = 20_000,
        # Other parameters
@@ -91,11 +92,11 @@
        self.embed_size = embed_size
        self.verbose = verbose
        self.lazy = lazy
-
+
        self.backend = self._determine_backend(backend)
        if self.verbose and backend is None:
            print(f"Auto-detected backend: {self.backend}")
-
+
        # Store all configuration parameters
        self.config = {
            # OpenAI
@@ -119,18 +120,20 @@
            # Processing
            "embedding_batch_size": embedding_batch_size,
        }
-
+
        # Auto-detect model_name for OpenAI if using custom URL and default model
-        if (
-
-
+        if (
+            self.backend == "openai"
+            and model_name == "text-embedding-3-small"
+            and self.url_or_model != "https://api.openai.com/v1"
+        ):
            if self.verbose:
                print(f"Attempting to auto-detect model from {self.url_or_model}...")
            try:
                import openai
+
                client = openai.OpenAI(
-                    base_url=self.url_or_model,
-                    api_key=self.config["api_key"]
+                    base_url=self.url_or_model, api_key=self.config["api_key"]
                )
                models = client.models.list()
                if models.data:
@@ -147,7 +150,7 @@
                    print(f"Model auto-detection failed: {e}, using default model")
                # Fallback to default if auto-detection fails
                pass
-
+
        # Set default db_path if not provided
        if db_path is None:
            if self.backend == "openai":
@@ -155,19 +158,21 @@
            else:
                model_id = self.url_or_model
            safe_name = hashlib.sha1(model_id.encode("utf-8")).hexdigest()[:16]
-            self.db_path =
+            self.db_path = (
+                Path.home() / ".cache" / "embed" / f"{self.backend}_{safe_name}.sqlite"
+            )
        else:
            self.db_path = Path(db_path)
-
+
        # Ensure the directory exists
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
-
+
        self.conn = sqlite3.connect(self.db_path)
        self._optimize_connection()
        self._ensure_schema()
        self._model = None  # Lazy loading
        self._client = None  # For OpenAI client
-
+
        # Load model/client if not lazy
        if not self.lazy:
            if self.verbose:
@@ -179,34 +184,41 @@
            if self.verbose:
                print(f"✓ {self.backend.upper()} model/client loaded successfully")

-    def _determine_backend(
+    def _determine_backend(
+        self, backend: Optional[Literal["vllm", "transformers", "openai"]]
+    ) -> str:
        """Determine the appropriate backend based on url_or_model and user preference."""
        if backend is not None:
            valid_backends = ["vllm", "transformers", "openai"]
            if backend not in valid_backends:
-                raise ValueError(
+                raise ValueError(
+                    f"Invalid backend '{backend}'. Must be one of: {valid_backends}"
+                )
            return backend
-
+
        if self.url_or_model.startswith("http"):
            return "openai"
-
+
        # Default to vllm for local models
        return "vllm"
+
    def _try_infer_model_name(self, model_name: Optional[str]) -> Optional[str]:
        """Infer model name for OpenAI backend if not explicitly provided."""
        if model_name:
            return model_name
-        if
-            model_name =
-
-        if
+        if "https://" in self.url_or_model:
+            model_name = "text-embedding-3-small"
+
+        if "http://localhost" in self.url_or_model:
            from openai import OpenAI
-
-
+
+            client = OpenAI(base_url=self.url_or_model, api_key="abc")
+            model_name = client.models.list().data[0].id

        # Default model name
-        print(
+        print("Infer model name:", model_name)
        return model_name
+
    def _optimize_connection(self) -> None:
        """Optimize SQLite connection for bulk operations and multi-process safety."""
        # Performance optimizations for bulk operations
@@ -214,13 +226,21 @@
            "PRAGMA journal_mode=WAL"
        )  # Write-Ahead Logging for better concurrency
        self.conn.execute("PRAGMA synchronous=NORMAL")  # Faster writes, still safe
-        self.conn.execute(
+        self.conn.execute(
+            f"PRAGMA cache_size={self.config['sqlite_cache_size']}"
+        )  # Configurable cache
        self.conn.execute("PRAGMA temp_store=MEMORY")  # Use memory for temp storage
-        self.conn.execute(
-
+        self.conn.execute(
+            f"PRAGMA mmap_size={self.config['sqlite_mmap_size']}"
+        )  # Configurable memory mapping
+
        # Multi-process safety improvements
-        self.conn.execute(
-
+        self.conn.execute(
+            "PRAGMA busy_timeout=30000"
+        )  # Wait up to 30 seconds for locks
+        self.conn.execute(
+            "PRAGMA wal_autocheckpoint=1000"
+        )  # Checkpoint WAL every 1000 pages

    def _ensure_schema(self) -> None:
        self.conn.execute("""
@@ -239,22 +259,24 @@
    def _load_openai_client(self) -> None:
        """Load OpenAI client."""
        import openai
+
        self._client = openai.OpenAI(
-            base_url=self.url_or_model,
-            api_key=self.config["api_key"]
+            base_url=self.url_or_model, api_key=self.config["api_key"]
        )

    def _load_model(self) -> None:
        """Load the model for vLLM or Transformers."""
        if self.backend == "vllm":
            from vllm import LLM  # type: ignore[import-not-found]
-
-            gpu_memory_utilization = cast(
+
+            gpu_memory_utilization = cast(
+                float, self.config["vllm_gpu_memory_utilization"]
+            )
            tensor_parallel_size = cast(int, self.config["vllm_tensor_parallel_size"])
            dtype = cast(str, self.config["vllm_dtype"])
            trust_remote_code = cast(bool, self.config["vllm_trust_remote_code"])
            max_model_len = cast(Optional[int], self.config["vllm_max_model_len"])
-
+
            vllm_kwargs = {
                "model": self.url_or_model,
                "task": "embed",
@@ -263,18 +285,23 @@
                "dtype": dtype,
                "trust_remote_code": trust_remote_code,
            }
-
+
            if max_model_len is not None:
                vllm_kwargs["max_model_len"] = max_model_len
-
+
            try:
                self._model = LLM(**vllm_kwargs)
            except (ValueError, AssertionError, RuntimeError) as e:
                error_msg = str(e).lower()
-                if (
-
-
-
+                if (
+                    ("kv cache" in error_msg and "gpu_memory_utilization" in error_msg)
+                    or (
+                        "memory" in error_msg
+                        and ("gpu" in error_msg or "insufficient" in error_msg)
+                    )
+                    or ("free memory" in error_msg and "initial" in error_msg)
+                    or ("engine core initialization failed" in error_msg)
+                ):
                    raise ValueError(
                        f"Insufficient GPU memory for vLLM model initialization. "
                        f"Current vllm_gpu_memory_utilization ({gpu_memory_utilization}) may be too low. "
@@ -288,27 +315,39 @@
                else:
                    raise
        elif self.backend == "transformers":
-
-            import
-
+            import torch  # type: ignore[import-not-found]  # noqa: F401
+            from transformers import (  # type: ignore[import-not-found]
+                AutoModel,
+                AutoTokenizer,
+            )
+
            device = self.config["transformers_device"]
            # Handle "auto" device selection - default to CPU for transformers to avoid memory conflicts
            if device == "auto":
                device = "cpu"  # Default to CPU to avoid GPU memory conflicts with vLLM
-
-            tokenizer = AutoTokenizer.from_pretrained(
-
-
+
+            tokenizer = AutoTokenizer.from_pretrained(
+                self.url_or_model,
+                padding_side="left",
+                trust_remote_code=self.config["transformers_trust_remote_code"],
+            )
+            model = AutoModel.from_pretrained(
+                self.url_or_model,
+                trust_remote_code=self.config["transformers_trust_remote_code"],
+            )
+
            # Move model to device
            model.to(device)
            model.eval()
-
+
            self._model = {"tokenizer": tokenizer, "model": model, "device": device}

    def _get_embeddings(self, texts: list[str]) -> list[list[float]]:
        """Get embeddings using the configured backend."""
        assert isinstance(texts, list), "texts must be a list"
-        assert all(isinstance(t, str) for t in texts),
+        assert all(isinstance(t, str) for t in texts), (
+            "all elements in texts must be strings"
+        )
        if self.backend == "openai":
            return self._get_openai_embeddings(texts)
        elif self.backend == "vllm":
@@ -321,10 +360,14 @@
    def _get_openai_embeddings(self, texts: list[str]) -> list[list[float]]:
        """Get embeddings using OpenAI API."""
        assert isinstance(texts, list), "texts must be a list"
-        assert all(isinstance(t, str) for t in texts),
+        assert all(isinstance(t, str) for t in texts), (
+            "all elements in texts must be strings"
+        )
        # Assert valid model_name for OpenAI backend
        model_name = self.config["model_name"]
-        assert model_name is not None and model_name.strip(),
+        assert model_name is not None and model_name.strip(), (
+            f"Invalid model_name for OpenAI backend: {model_name}. Model name must be provided and non-empty."
+        )

        if self._client is None:
            if self.verbose:
@@ -332,10 +375,9 @@
            self._load_openai_client()
            if self.verbose:
                print("✓ OpenAI client loaded successfully")
-
+
        response = self._client.embeddings.create(  # type: ignore
-            model=model_name,
-            input=texts
+            model=model_name, input=texts
        )
        embeddings = [item.embedding for item in response.data]
        return embeddings
@@ -343,14 +385,16 @@
    def _get_vllm_embeddings(self, texts: list[str]) -> list[list[float]]:
        """Get embeddings using vLLM."""
        assert isinstance(texts, list), "texts must be a list"
-        assert all(isinstance(t, str) for t in texts),
+        assert all(isinstance(t, str) for t in texts), (
+            "all elements in texts must be strings"
+        )
        if self._model is None:
            if self.verbose:
                print("🔧 Loading vLLM model...")
            self._load_model()
            if self.verbose:
                print("✓ vLLM model loaded successfully")
-
+
        outputs = self._model.embed(texts)  # type: ignore
        embeddings = [o.outputs.embedding for o in outputs]
        return embeddings
@@ -358,26 +402,30 @@
    def _get_transformers_embeddings(self, texts: list[str]) -> list[list[float]]:
        """Get embeddings using transformers directly."""
        assert isinstance(texts, list), "texts must be a list"
-        assert all(isinstance(t, str) for t in texts),
+        assert all(isinstance(t, str) for t in texts), (
+            "all elements in texts must be strings"
+        )
        if self._model is None:
            if self.verbose:
                print("🔧 Loading Transformers model...")
            self._load_model()
            if self.verbose:
                print("✓ Transformers model loaded successfully")
-
+
        if not isinstance(self._model, dict):
            raise ValueError("Model not loaded properly for transformers backend")
-
+
        tokenizer = self._model["tokenizer"]
        model = self._model["model"]
        device = self._model["device"]
-
-        normalize_embeddings = cast(
-
+
+        normalize_embeddings = cast(
+            bool, self.config["transformers_normalize_embeddings"]
+        )
+
        # For now, use a default max_length
        max_length = 8192
-
+
        # Tokenize
        batch_dict = tokenizer(
            texts,
@@ -386,35 +434,43 @@
            max_length=max_length,
            return_tensors="pt",
        )
-
+
        # Move to device
        batch_dict = {k: v.to(device) for k, v in batch_dict.items()}
-
+
        # Run model
        import torch  # type: ignore[import-not-found]
+
        with torch.no_grad():
            outputs = model(**batch_dict)
-
+
        # Apply last token pooling
-        embeddings = self._last_token_pool(
-
+        embeddings = self._last_token_pool(
+            outputs.last_hidden_state, batch_dict["attention_mask"]
+        )
+
        # Normalize if needed
        if normalize_embeddings:
            import torch.nn.functional as F  # type: ignore[import-not-found]
+
            embeddings = F.normalize(embeddings, p=2, dim=1)
-
+
        return embeddings.cpu().numpy().tolist()

    def _last_token_pool(self, last_hidden_states, attention_mask):
        """Apply last token pooling to get embeddings."""
        import torch  # type: ignore[import-not-found]
-
+
+        left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
        if left_padding:
            return last_hidden_states[:, -1]
        else:
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_states.shape[0]
-            return last_hidden_states[
+            return last_hidden_states[
+                torch.arange(batch_size, device=last_hidden_states.device),
+                sequence_lengths,
+            ]

    def _hash_text(self, text: str) -> str:
        return hashlib.sha1(text.encode("utf-8")).hexdigest()
@@ -423,33 +479,36 @@
        """Execute SQLite query with retry logic for multi-process safety."""
        max_retries = 3
        base_delay = 0.05  # 50ms base delay for reads (faster than writes)
-
+
        last_exception = None
-
+
        for attempt in range(max_retries + 1):
            try:
                if params is None:
                    return self.conn.execute(query)
                else:
                    return self.conn.execute(query, params)
-
+
            except sqlite3.OperationalError as e:
                last_exception = e
                if "database is locked" in str(e).lower() and attempt < max_retries:
                    # Exponential backoff: 0.05s, 0.1s, 0.2s
-                    delay = base_delay * (2
+                    delay = base_delay * (2**attempt)
                    if self.verbose:
-                        print(
+                        print(
+                            f"⚠️ Database locked on read, retrying in {delay:.2f}s (attempt {attempt + 1}/{max_retries + 1})"
+                        )
                    import time
+
                    time.sleep(delay)
                    continue
                else:
                    # Re-raise if not a lock error or max retries exceeded
                    raise
-            except Exception
+            except Exception:
                # Re-raise any other exceptions
                raise
-
+
        # This should never be reached, but satisfy the type checker
        raise last_exception or RuntimeError("Failed to execute query after retries")
@@ -465,7 +524,9 @@
        computing missing embeddings.
        """
        assert isinstance(texts, list), "texts must be a list"
-        assert all(isinstance(t, str) for t in texts),
+        assert all(isinstance(t, str) for t in texts), (
+            "all elements in texts must be strings"
+        )
        if not texts:
            return np.empty((0, 0), dtype=np.float32)
        t = time()
@@ -502,7 +563,9 @@

        if missing_items:
            if self.verbose:
-                print(
+                print(
+                    f"Computing {len(missing_items)}/{len(texts)} missing embeddings..."
+                )
            self._process_missing_items_with_batches(missing_items, hit_map)

        # Return embeddings in the original order
@@ -511,92 +574,81 @@
            print(f"Retrieved {len(texts)} embeddings in {elapsed:.2f} seconds")
        return np.vstack([hit_map[h] for h in hashes])

-    def _process_missing_items_with_batches(
+    def _process_missing_items_with_batches(
+        self, missing_items: list[tuple[str, str]], hit_map: dict[str, np.ndarray]
+    ) -> None:
        """
-        Process missing items in batches with progress
+        Process missing items in batches with simple progress tracking.
        """
        t = time()  # Track total processing time
-
-        # Try to import tqdm, fall back to simple progress if not available
-        tqdm = None  # avoid "possibly unbound" in type checker
-        use_tqdm = False
-        try:
-            from tqdm import tqdm as _tqdm  # type: ignore[import-not-found]
-            tqdm = _tqdm
-            use_tqdm = True
-        except ImportError:
-            use_tqdm = False
-            if self.verbose:
-                print("tqdm not available, using simple progress reporting")
-
+
        batch_size = self.config["embedding_batch_size"]
        total_items = len(missing_items)
-
+
        if self.verbose:
-            print(
+            print(
+                f"Computing embeddings for {total_items} missing texts in batches of {batch_size}..."
+            )
            if self.backend in ["vllm", "transformers"] and self._model is None:
-                print(
+                print("⚠️ Model will be loaded on first batch (lazy loading enabled)")
            elif self.backend in ["vllm", "transformers"]:
-                print(
-
-        # Create progress bar
-        pbar = None
-        processed_count = 0
-        if use_tqdm and tqdm is not None:
-            pbar = tqdm(total=total_items, desc="Computing embeddings", unit="texts")
-
+                print("✓ Model already loaded, ready for efficient batch processing")
+
        # Track total committed items
        total_committed = 0
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                    pbar.update(batch_size_actual)
-                else:
-                    processed_count += batch_size_actual
-                    if self.verbose:
-                        print(f"Progress: {processed_count}/{total_items} embeddings computed, {total_committed} committed")
-
-        finally:
-            # Clean up progress bar
-            if pbar is not None:
-                pbar.close()
-
+        processed_count = 0
+
+        # Process in batches
+        for i in range(0, total_items, batch_size):
+            batch_items = missing_items[i : i + batch_size]
+            batch_texts = [text for text, _ in batch_items]
+
+            # Get embeddings for this batch
+            batch_embeds = self._get_embeddings(batch_texts)
+
+            # Prepare batch data for immediate insert
+            batch_data: list[tuple[str, str, bytes]] = []
+            for (text, h), vec in zip(batch_items, batch_embeds):
+                arr = np.asarray(vec, dtype=np.float32)
+                batch_data.append((h, text, arr.tobytes()))
+                hit_map[h] = arr
+
+            # Immediate commit after each batch
+            self._bulk_insert(batch_data)
+            total_committed += len(batch_data)
+
+            # Update progress - simple single line
+            batch_size_actual = len(batch_items)
+            processed_count += batch_size_actual
            if self.verbose:
-
-                rate =
-
-                print(
+                elapsed = time() - t
+                rate = processed_count / elapsed if elapsed > 0 else 0
+                progress_pct = (processed_count / total_items) * 100
+                print(
+                    f"\rProgress: {processed_count}/{total_items} ({progress_pct:.1f}%) | {rate:.0f} texts/sec",
+                    end="",
+                    flush=True,
+                )
+
+        if self.verbose:
+            total_time = time() - t
+            rate = total_items / total_time if total_time > 0 else 0
+            print(
+                f"\n✅ Completed: {total_items} embeddings computed and {total_committed} items committed to database"
+            )
+            print(f"   Total time: {total_time:.2f}s | Rate: {rate:.1f} embeddings/sec")

    def __call__(self, texts: list[str], cache: bool = True) -> np.ndarray:
        assert isinstance(texts, list), "texts must be a list"
-        assert all(isinstance(t, str) for t in texts),
+        assert all(isinstance(t, str) for t in texts), (
+            "all elements in texts must be strings"
+        )
        return self.embeds(texts, cache)

    def _bulk_insert(self, data: list[tuple[str, str, bytes]]) -> None:
        """
        Perform bulk insert of embedding data with retry logic for multi-process safety.
-
+
        Uses INSERT OR IGNORE to prevent race conditions where multiple processes
        might try to insert the same text hash. The first process to successfully
        insert wins, subsequent attempts are ignored. This ensures deterministic
@@ -607,7 +659,7 @@

        max_retries = 3
        base_delay = 0.1  # 100ms base delay
-
+
        for attempt in range(max_retries + 1):
            try:
                cursor = self.conn.executemany(
@@ -615,82 +667,34 @@
                    data,
                )
                self.conn.commit()
-
+
                # Check if some insertions were ignored due to existing entries
-                if self.verbose and cursor.rowcount < len(data):
-
-
-
+                # if self.verbose and cursor.rowcount < len(data):
+                #     ignored_count = len(data) - cursor.rowcount
+                #     if ignored_count > 0:
+                #         print(f"ℹ️ {ignored_count}/{len(data)} embeddings already existed in cache (computed by another process)")
+
                return  # Success, exit the retry loop
-
+
            except sqlite3.OperationalError as e:
                if "database is locked" in str(e).lower() and attempt < max_retries:
                    # Exponential backoff: 0.1s, 0.2s, 0.4s
-                    delay = base_delay * (2
+                    delay = base_delay * (2**attempt)
                    if self.verbose:
-                        print(
+                        print(
+                            f"⚠️ Database locked, retrying in {delay:.1f}s (attempt {attempt + 1}/{max_retries + 1})"
+                        )
                    import time
+
                    time.sleep(delay)
                    continue
                else:
                    # Re-raise if not a lock error or max retries exceeded
                    raise
-            except Exception
+            except Exception:
                # Re-raise any other exceptions
                raise

-    # def precompute_embeddings(self, texts: list[str]) -> None:
-    #     """
-    #     Precompute embeddings for a large list of texts efficiently.
-    #     This is optimized for bulk operations when you know all texts upfront.
-    #     """
-    #     assert isinstance(texts, list), "texts must be a list"
-    #     assert all(isinstance(t, str) for t in texts), "all elements in texts must be strings"
-    #     if not texts:
-    #         return
-
-    #     # Remove duplicates while preserving order
-    #     unique_texts = list(dict.fromkeys(texts))
-    #     if self.verbose:
-    #         print(f"Precomputing embeddings for {len(unique_texts)} unique texts...")
-
-    #     # Check which ones are already cached
-    #     hashes = [self._hash_text(t) for t in unique_texts]
-    #     existing_hashes = set()
-
-    #     # Bulk check for existing embeddings
-    #     chunk_size = self.config["sqlite_chunk_size"]
-    #     for i in range(0, len(hashes), chunk_size):
-    #         chunk = hashes[i : i + chunk_size]
-    #         placeholders = ",".join("?" * len(chunk))
-    #         rows = self._execute_with_retry(
-    #             f"SELECT hash FROM cache WHERE hash IN ({placeholders})",
-    #             chunk,
-    #         ).fetchall()
-    #         existing_hashes.update(h[0] for h in rows)
-
-    #     # Find missing texts
-    #     missing_items = [
-    #         (t, h) for t, h in zip(unique_texts, hashes) if h not in existing_hashes
-    #     ]
-
-    #     if not missing_items:
-    #         if self.verbose:
-    #             print("All texts already cached!")
-    #         return
-
-    #     if self.verbose:
-    #         print(f"Computing {len(missing_items)} missing embeddings...")
-
-    #     # Process missing items with batches
-    #     missing_texts = [t for t, _ in missing_items]
-    #     missing_items_tupled = [(t, h) for t, h in zip(missing_texts, [self._hash_text(t) for t in missing_texts])]
-    #     hit_map_temp: dict[str, np.ndarray] = {}
-    #     self._process_missing_items_with_batches(missing_items_tupled, hit_map_temp)
-    #     if self.verbose:
-    #         print(f"Successfully cached {len(missing_items)} new embeddings!")
-
    def get_cache_stats(self) -> dict[str, int]:
        """Get statistics about the cache."""
        cursor = self._execute_with_retry("SELECT COUNT(*) FROM cache")
@@ -701,24 +705,27 @@
        """Clear all cached embeddings."""
        max_retries = 3
        base_delay = 0.1  # 100ms base delay
-
+
        for attempt in range(max_retries + 1):
            try:
                self.conn.execute("DELETE FROM cache")
                self.conn.commit()
                return  # Success
-
+
            except sqlite3.OperationalError as e:
                if "database is locked" in str(e).lower() and attempt < max_retries:
-                    delay = base_delay * (2
+                    delay = base_delay * (2**attempt)
                    if self.verbose:
-                        print(
+                        print(
+                            f"⚠️ Database locked during clear, retrying in {delay:.1f}s (attempt {attempt + 1}/{max_retries + 1})"
+                        )
                    import time
+
                    time.sleep(delay)
                    continue
                else:
                    raise
-            except Exception
+            except Exception:
                raise

    def get_config(self) -> Dict[str, Any]:
@@ -730,7 +737,7 @@
            "db_path": str(self.db_path),
            "verbose": self.verbose,
            "lazy": self.lazy,
-            **self.config
+            **self.config,
        }

    def update_config(self, **kwargs) -> None:
@@ -744,17 +751,26 @@
                self.lazy = value
            else:
                raise ValueError(f"Unknown configuration parameter: {key}")
-
+
        # Reset model if backend-specific parameters changed
        backend_params = {
-            "vllm": [
-
-
-
+            "vllm": [
+                "vllm_gpu_memory_utilization",
+                "vllm_tensor_parallel_size",
+                "vllm_dtype",
+                "vllm_trust_remote_code",
+                "vllm_max_model_len",
+            ],
+            "transformers": [
+                "transformers_device",
+                "transformers_batch_size",
+                "transformers_normalize_embeddings",
+                "transformers_trust_remote_code",
+            ],
            "openai": ["api_key", "model_name"],
-            "processing": ["embedding_batch_size"]
+            "processing": ["embedding_batch_size"],
        }
-
+
        if any(param in kwargs for param in backend_params.get(self.backend, [])):
            self._model = None  # Force reload on next use
        if self.backend == "openai":