speedy-utils 1.1.19__py3-none-any.whl → 1.1.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/lm/openai_memoize.py +9 -1
- llm_utils/vector_cache/core.py +248 -232
- speedy_utils/common/utils_cache.py +37 -18
- speedy_utils/multi_worker/process.py +34 -9
- {speedy_utils-1.1.19.dist-info → speedy_utils-1.1.21.dist-info}/METADATA +8 -8
- {speedy_utils-1.1.19.dist-info → speedy_utils-1.1.21.dist-info}/RECORD +8 -8
- {speedy_utils-1.1.19.dist-info → speedy_utils-1.1.21.dist-info}/WHEEL +0 -0
- {speedy_utils-1.1.19.dist-info → speedy_utils-1.1.21.dist-info}/entry_points.txt +0 -0
llm_utils/lm/openai_memoize.py
CHANGED
@@ -1,4 +1,5 @@
 from openai import OpenAI, AsyncOpenAI
+from typing import Any, Callable

 from speedy_utils.common.utils_cache import memoize

@@ -30,6 +31,8 @@ class MOpenAI(OpenAI):
     - If you need a shared cache across instances, or more advanced cache controls,
       modify `memoize` or wrap at a class/static level instead of assigning to the
       bound method.
+    - Type information is now fully preserved by the memoize decorator, eliminating
+      the need for type casting.

     Example
         m = MOpenAI(api_key="...", model="gpt-4")
@@ -40,7 +43,12 @@ class MOpenAI(OpenAI):
     def __init__(self, *args, cache=True, **kwargs):
        super().__init__(*args, **kwargs)
        if cache:
-
+            # Create a memoized wrapper for the instance's post method.
+            # The memoize decorator now preserves exact type information,
+            # so no casting is needed.
+            orig_post = self.post
+            memoized = memoize(orig_post)
+            self.post = memoized


 class MAsyncOpenAI(AsyncOpenAI):
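The change above replaces a cast-based wrapper with a plain assignment of `memoize(self.post)` onto the instance. A minimal sketch of the same per-instance pattern, using a hypothetical `slow_lookup` method as a stand-in for the real OpenAI `post` call:

from speedy_utils.common.utils_cache import memoize


class Client:
    """Toy stand-in for MOpenAI; `slow_lookup` is hypothetical."""

    def __init__(self, cache: bool = True) -> None:
        if cache:
            # Rebind the bound method to its memoized wrapper, mirroring what
            # MOpenAI.__init__ now does with self.post.
            self.slow_lookup = memoize(self.slow_lookup)

    def slow_lookup(self, key: str) -> str:
        print(f"computing {key}")
        return key.upper()


c = Client()
c.slow_lookup("abc")  # computed and cached
c.slow_lookup("abc")  # same key: should be served from the cache, no recomputation

Because the decorator now preserves the wrapped callable's exact signature, the rebinding needs no `cast(...)` to keep type checkers happy.
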
llm_utils/vector_cache/core.py
CHANGED
@@ -13,50 +13,51 @@ import numpy as np
 class VectorCache:
     """
     A caching layer for text embeddings with support for multiple backends.
-
+
     This cache is designed to be safe for multi-process environments where multiple
     processes may access the same cache file simultaneously. It uses SQLite WAL mode
     and retry logic with exponential backoff to handle concurrent access.
-
+
     Examples:
         # OpenAI API
         from llm_utils import VectorCache
         cache = VectorCache("https://api.openai.com/v1", api_key="your-key")
        embeddings = cache.embeds(["Hello world", "How are you?"])
-
+
        # Custom OpenAI-compatible server (auto-detects model)
        cache = VectorCache("http://localhost:8000/v1", api_key="abc")
-
+
        # Transformers (Sentence Transformers)
        cache = VectorCache("sentence-transformers/all-MiniLM-L6-v2")
-
+
        # vLLM (local model)
        cache = VectorCache("/path/to/model")
-
+
        # Explicit backend specification
        cache = VectorCache("model-name", backend="transformers")
-
+
        # Eager loading (default: False) - load model immediately for better performance
        cache = VectorCache("model-name", lazy=False)
-
+
        # Lazy loading - load model only when needed (may cause performance issues)
        cache = VectorCache("model-name", lazy=True)
-
+
     Multi-Process Safety:
        The cache uses SQLite WAL (Write-Ahead Logging) mode and implements retry logic
        with exponential backoff to handle database locks. Multiple processes can safely
        read and write to the same cache file simultaneously.
-
+
     Race Condition Protection:
        - Uses INSERT OR IGNORE to prevent overwrites when multiple processes compute the same text
        - The first process to successfully cache a text wins, subsequent attempts are ignored
        - This ensures deterministic results even with non-deterministic embedding models
-
+
     For best performance in multi-process scenarios, consider:
        - Using separate cache files per process if cache hits are low
        - Coordinating cache warm-up to avoid redundant computation
        - Monitor for excessive lock contention in high-concurrency scenarios
     """
+
     def __init__(
        self,
        url_or_model: str,
@@ -80,7 +81,7 @@ class VectorCache:
        # SQLite parameters
        sqlite_chunk_size: int = 999,
        sqlite_cache_size: int = 10000,
-       sqlite_mmap_size: int = 268435456,
+       sqlite_mmap_size: int = 268435456,  # 256MB
        # Processing parameters
        embedding_batch_size: int = 20_000,
        # Other parameters
@@ -91,11 +92,11 @@ class VectorCache:
        self.embed_size = embed_size
        self.verbose = verbose
        self.lazy = lazy
-
+
        self.backend = self._determine_backend(backend)
        if self.verbose and backend is None:
            print(f"Auto-detected backend: {self.backend}")
-
+
        # Store all configuration parameters
        self.config = {
            # OpenAI
@@ -119,18 +120,20 @@ class VectorCache:
            # Processing
            "embedding_batch_size": embedding_batch_size,
        }
-
+
        # Auto-detect model_name for OpenAI if using custom URL and default model
-       if (
-
-
+       if (
+           self.backend == "openai"
+           and model_name == "text-embedding-3-small"
+           and self.url_or_model != "https://api.openai.com/v1"
+       ):
            if self.verbose:
                print(f"Attempting to auto-detect model from {self.url_or_model}...")
            try:
                import openai
+
                client = openai.OpenAI(
-                   base_url=self.url_or_model,
-                   api_key=self.config["api_key"]
+                   base_url=self.url_or_model, api_key=self.config["api_key"]
                )
                models = client.models.list()
                if models.data:
@@ -147,7 +150,7 @@ class VectorCache:
                    print(f"Model auto-detection failed: {e}, using default model")
                # Fallback to default if auto-detection fails
                pass
-
+
        # Set default db_path if not provided
        if db_path is None:
            if self.backend == "openai":
@@ -155,19 +158,21 @@ class VectorCache:
            else:
                model_id = self.url_or_model
            safe_name = hashlib.sha1(model_id.encode("utf-8")).hexdigest()[:16]
-           self.db_path =
+           self.db_path = (
+               Path.home() / ".cache" / "embed" / f"{self.backend}_{safe_name}.sqlite"
+           )
        else:
            self.db_path = Path(db_path)
-
+
        # Ensure the directory exists
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
-
+
        self.conn = sqlite3.connect(self.db_path)
        self._optimize_connection()
        self._ensure_schema()
        self._model = None  # Lazy loading
        self._client = None  # For OpenAI client
-
+
        # Load model/client if not lazy
        if not self.lazy:
            if self.verbose:
@@ -179,34 +184,41 @@ class VectorCache:
            if self.verbose:
                print(f"✓ {self.backend.upper()} model/client loaded successfully")

-   def _determine_backend(
+   def _determine_backend(
+       self, backend: Optional[Literal["vllm", "transformers", "openai"]]
+   ) -> str:
        """Determine the appropriate backend based on url_or_model and user preference."""
        if backend is not None:
            valid_backends = ["vllm", "transformers", "openai"]
            if backend not in valid_backends:
-               raise ValueError(
+               raise ValueError(
+                   f"Invalid backend '{backend}'. Must be one of: {valid_backends}"
+               )
            return backend
-
+
        if self.url_or_model.startswith("http"):
            return "openai"
-
+
        # Default to vllm for local models
        return "vllm"
+
    def _try_infer_model_name(self, model_name: Optional[str]) -> Optional[str]:
        """Infer model name for OpenAI backend if not explicitly provided."""
        if model_name:
            return model_name
-       if
-       model_name =
-
-       if
+       if "https://" in self.url_or_model:
+           model_name = "text-embedding-3-small"
+
+       if "http://localhost" in self.url_or_model:
            from openai import OpenAI
-
-
+
+           client = OpenAI(base_url=self.url_or_model, api_key="abc")
+           model_name = client.models.list().data[0].id

        # Default model name
-       print(
+       print("Infer model name:", model_name)
        return model_name
+
    def _optimize_connection(self) -> None:
        """Optimize SQLite connection for bulk operations and multi-process safety."""
        # Performance optimizations for bulk operations
@@ -214,13 +226,21 @@ class VectorCache:
            "PRAGMA journal_mode=WAL"
        )  # Write-Ahead Logging for better concurrency
        self.conn.execute("PRAGMA synchronous=NORMAL")  # Faster writes, still safe
-       self.conn.execute(
+       self.conn.execute(
+           f"PRAGMA cache_size={self.config['sqlite_cache_size']}"
+       )  # Configurable cache
        self.conn.execute("PRAGMA temp_store=MEMORY")  # Use memory for temp storage
-       self.conn.execute(
-
+       self.conn.execute(
+           f"PRAGMA mmap_size={self.config['sqlite_mmap_size']}"
+       )  # Configurable memory mapping
+
        # Multi-process safety improvements
-       self.conn.execute(
-
+       self.conn.execute(
+           "PRAGMA busy_timeout=30000"
+       )  # Wait up to 30 seconds for locks
+       self.conn.execute(
+           "PRAGMA wal_autocheckpoint=1000"
+       )  # Checkpoint WAL every 1000 pages

    def _ensure_schema(self) -> None:
        self.conn.execute("""
@@ -239,22 +259,24 @@ class VectorCache:
    def _load_openai_client(self) -> None:
        """Load OpenAI client."""
        import openai
+
        self._client = openai.OpenAI(
-           base_url=self.url_or_model,
-           api_key=self.config["api_key"]
+           base_url=self.url_or_model, api_key=self.config["api_key"]
        )

    def _load_model(self) -> None:
        """Load the model for vLLM or Transformers."""
        if self.backend == "vllm":
            from vllm import LLM  # type: ignore[import-not-found]
-
-           gpu_memory_utilization = cast(
+
+           gpu_memory_utilization = cast(
+               float, self.config["vllm_gpu_memory_utilization"]
+           )
            tensor_parallel_size = cast(int, self.config["vllm_tensor_parallel_size"])
            dtype = cast(str, self.config["vllm_dtype"])
            trust_remote_code = cast(bool, self.config["vllm_trust_remote_code"])
            max_model_len = cast(Optional[int], self.config["vllm_max_model_len"])
-
+
            vllm_kwargs = {
                "model": self.url_or_model,
                "task": "embed",
@@ -263,18 +285,23 @@ class VectorCache:
                "dtype": dtype,
                "trust_remote_code": trust_remote_code,
            }
-
+
            if max_model_len is not None:
                vllm_kwargs["max_model_len"] = max_model_len
-
+
            try:
                self._model = LLM(**vllm_kwargs)
            except (ValueError, AssertionError, RuntimeError) as e:
                error_msg = str(e).lower()
-               if (
-
-
-
+               if (
+                   ("kv cache" in error_msg and "gpu_memory_utilization" in error_msg)
+                   or (
+                       "memory" in error_msg
+                       and ("gpu" in error_msg or "insufficient" in error_msg)
+                   )
+                   or ("free memory" in error_msg and "initial" in error_msg)
+                   or ("engine core initialization failed" in error_msg)
+               ):
                    raise ValueError(
                        f"Insufficient GPU memory for vLLM model initialization. "
                        f"Current vllm_gpu_memory_utilization ({gpu_memory_utilization}) may be too low. "
@@ -288,27 +315,39 @@ class VectorCache:
                else:
                    raise
        elif self.backend == "transformers":
-
-           import
-
+           import torch  # type: ignore[import-not-found] # noqa: F401
+           from transformers import (  # type: ignore[import-not-found]
+               AutoModel,
+               AutoTokenizer,
+           )
+
            device = self.config["transformers_device"]
            # Handle "auto" device selection - default to CPU for transformers to avoid memory conflicts
            if device == "auto":
                device = "cpu"  # Default to CPU to avoid GPU memory conflicts with vLLM
-
-           tokenizer = AutoTokenizer.from_pretrained(
-
-
+
+           tokenizer = AutoTokenizer.from_pretrained(
+               self.url_or_model,
+               padding_side="left",
+               trust_remote_code=self.config["transformers_trust_remote_code"],
+           )
+           model = AutoModel.from_pretrained(
+               self.url_or_model,
+               trust_remote_code=self.config["transformers_trust_remote_code"],
+           )
+
            # Move model to device
            model.to(device)
            model.eval()
-
+
            self._model = {"tokenizer": tokenizer, "model": model, "device": device}

    def _get_embeddings(self, texts: list[str]) -> list[list[float]]:
        """Get embeddings using the configured backend."""
        assert isinstance(texts, list), "texts must be a list"
-       assert all(isinstance(t, str) for t in texts),
+       assert all(isinstance(t, str) for t in texts), (
+           "all elements in texts must be strings"
+       )
        if self.backend == "openai":
            return self._get_openai_embeddings(texts)
        elif self.backend == "vllm":
@@ -321,10 +360,14 @@ class VectorCache:
    def _get_openai_embeddings(self, texts: list[str]) -> list[list[float]]:
        """Get embeddings using OpenAI API."""
        assert isinstance(texts, list), "texts must be a list"
-       assert all(isinstance(t, str) for t in texts),
+       assert all(isinstance(t, str) for t in texts), (
+           "all elements in texts must be strings"
+       )
        # Assert valid model_name for OpenAI backend
        model_name = self.config["model_name"]
-       assert model_name is not None and model_name.strip(),
+       assert model_name is not None and model_name.strip(), (
+           f"Invalid model_name for OpenAI backend: {model_name}. Model name must be provided and non-empty."
+       )

        if self._client is None:
            if self.verbose:
@@ -332,10 +375,9 @@ class VectorCache:
            self._load_openai_client()
            if self.verbose:
                print("✓ OpenAI client loaded successfully")
-
+
        response = self._client.embeddings.create(  # type: ignore
-           model=model_name,
-           input=texts
+           model=model_name, input=texts
        )
        embeddings = [item.embedding for item in response.data]
        return embeddings
@@ -343,14 +385,16 @@ class VectorCache:
    def _get_vllm_embeddings(self, texts: list[str]) -> list[list[float]]:
        """Get embeddings using vLLM."""
        assert isinstance(texts, list), "texts must be a list"
-       assert all(isinstance(t, str) for t in texts),
+       assert all(isinstance(t, str) for t in texts), (
+           "all elements in texts must be strings"
+       )
        if self._model is None:
            if self.verbose:
                print("🔧 Loading vLLM model...")
            self._load_model()
            if self.verbose:
                print("✓ vLLM model loaded successfully")
-
+
        outputs = self._model.embed(texts)  # type: ignore
        embeddings = [o.outputs.embedding for o in outputs]
        return embeddings
@@ -358,26 +402,30 @@ class VectorCache:
    def _get_transformers_embeddings(self, texts: list[str]) -> list[list[float]]:
        """Get embeddings using transformers directly."""
        assert isinstance(texts, list), "texts must be a list"
-       assert all(isinstance(t, str) for t in texts),
+       assert all(isinstance(t, str) for t in texts), (
+           "all elements in texts must be strings"
+       )
        if self._model is None:
            if self.verbose:
                print("🔧 Loading Transformers model...")
            self._load_model()
            if self.verbose:
                print("✓ Transformers model loaded successfully")
-
+
        if not isinstance(self._model, dict):
            raise ValueError("Model not loaded properly for transformers backend")
-
+
        tokenizer = self._model["tokenizer"]
        model = self._model["model"]
        device = self._model["device"]
-
-       normalize_embeddings = cast(
-
+
+       normalize_embeddings = cast(
+           bool, self.config["transformers_normalize_embeddings"]
+       )
+
        # For now, use a default max_length
        max_length = 8192
-
+
        # Tokenize
        batch_dict = tokenizer(
            texts,
@@ -386,35 +434,43 @@ class VectorCache:
            max_length=max_length,
            return_tensors="pt",
        )
-
+
        # Move to device
        batch_dict = {k: v.to(device) for k, v in batch_dict.items()}
-
+
        # Run model
        import torch  # type: ignore[import-not-found]
+
        with torch.no_grad():
            outputs = model(**batch_dict)
-
+
        # Apply last token pooling
-       embeddings = self._last_token_pool(
-
+       embeddings = self._last_token_pool(
+           outputs.last_hidden_state, batch_dict["attention_mask"]
+       )
+
        # Normalize if needed
        if normalize_embeddings:
            import torch.nn.functional as F  # type: ignore[import-not-found]
+
            embeddings = F.normalize(embeddings, p=2, dim=1)
-
+
        return embeddings.cpu().numpy().tolist()

    def _last_token_pool(self, last_hidden_states, attention_mask):
        """Apply last token pooling to get embeddings."""
        import torch  # type: ignore[import-not-found]
-
+
+       left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
        if left_padding:
            return last_hidden_states[:, -1]
        else:
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_states.shape[0]
-           return last_hidden_states[
+           return last_hidden_states[
+               torch.arange(batch_size, device=last_hidden_states.device),
+               sequence_lengths,
+           ]

    def _hash_text(self, text: str) -> str:
        return hashlib.sha1(text.encode("utf-8")).hexdigest()
@@ -423,33 +479,36 @@ class VectorCache:
        """Execute SQLite query with retry logic for multi-process safety."""
        max_retries = 3
        base_delay = 0.05  # 50ms base delay for reads (faster than writes)
-
+
        last_exception = None
-
+
        for attempt in range(max_retries + 1):
            try:
                if params is None:
                    return self.conn.execute(query)
                else:
                    return self.conn.execute(query, params)
-
+
            except sqlite3.OperationalError as e:
                last_exception = e
                if "database is locked" in str(e).lower() and attempt < max_retries:
                    # Exponential backoff: 0.05s, 0.1s, 0.2s
-                   delay = base_delay * (2
+                   delay = base_delay * (2**attempt)
                    if self.verbose:
-                       print(
+                       print(
+                           f"⚠️ Database locked on read, retrying in {delay:.2f}s (attempt {attempt + 1}/{max_retries + 1})"
+                       )
                    import time
+
                    time.sleep(delay)
                    continue
                else:
                    # Re-raise if not a lock error or max retries exceeded
                    raise
-           except Exception
+           except Exception:
                # Re-raise any other exceptions
                raise
-
+
        # This should never be reached, but satisfy the type checker
        raise last_exception or RuntimeError("Failed to execute query after retries")

@@ -465,7 +524,9 @@ class VectorCache:
        computing missing embeddings.
        """
        assert isinstance(texts, list), "texts must be a list"
-       assert all(isinstance(t, str) for t in texts),
+       assert all(isinstance(t, str) for t in texts), (
+           "all elements in texts must be strings"
+       )
        if not texts:
            return np.empty((0, 0), dtype=np.float32)
        t = time()
@@ -502,7 +563,9 @@ class VectorCache:

        if missing_items:
            if self.verbose:
-               print(
+               print(
+                   f"Computing {len(missing_items)}/{len(texts)} missing embeddings..."
+               )
            self._process_missing_items_with_batches(missing_items, hit_map)

        # Return embeddings in the original order
@@ -511,92 +574,81 @@ class VectorCache:
            print(f"Retrieved {len(texts)} embeddings in {elapsed:.2f} seconds")
        return np.vstack([hit_map[h] for h in hashes])

-   def _process_missing_items_with_batches(
+   def _process_missing_items_with_batches(
+       self, missing_items: list[tuple[str, str]], hit_map: dict[str, np.ndarray]
+   ) -> None:
        """
-       Process missing items in batches with progress
+       Process missing items in batches with simple progress tracking.
        """
        t = time()  # Track total processing time
-
-       # Try to import tqdm, fall back to simple progress if not available
-       tqdm = None  # avoid "possibly unbound" in type checker
-       use_tqdm = False
-       try:
-           from tqdm import tqdm as _tqdm  # type: ignore[import-not-found]
-           tqdm = _tqdm
-           use_tqdm = True
-       except ImportError:
-           use_tqdm = False
-           if self.verbose:
-               print("tqdm not available, using simple progress reporting")
-
+
        batch_size = self.config["embedding_batch_size"]
        total_items = len(missing_items)
-
+
        if self.verbose:
-           print(
+           print(
+               f"Computing embeddings for {total_items} missing texts in batches of {batch_size}..."
+           )
            if self.backend in ["vllm", "transformers"] and self._model is None:
                print("⚠️ Model will be loaded on first batch (lazy loading enabled)")
            elif self.backend in ["vllm", "transformers"]:
                print("✓ Model already loaded, ready for efficient batch processing")
-
-       # Create progress bar
-       pbar = None
-       processed_count = 0
-       if use_tqdm and tqdm is not None:
-           pbar = tqdm(total=total_items, desc="Computing embeddings", unit="texts")
-
+
        # Track total committed items
        total_committed = 0
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                   pbar.update(batch_size_actual)  # type: ignore
-               else:
-                   processed_count += batch_size_actual
-                   if self.verbose:
-                       print(f"Progress: {processed_count}/{total_items} embeddings computed, {total_committed} committed")
-
-       finally:
-           # Clean up progress bar
-           if pbar is not None:
-               pbar.close()
-
+       processed_count = 0
+
+       # Process in batches
+       for i in range(0, total_items, batch_size):
+           batch_items = missing_items[i : i + batch_size]
+           batch_texts = [text for text, _ in batch_items]
+
+           # Get embeddings for this batch
+           batch_embeds = self._get_embeddings(batch_texts)
+
+           # Prepare batch data for immediate insert
+           batch_data: list[tuple[str, str, bytes]] = []
+           for (text, h), vec in zip(batch_items, batch_embeds):
+               arr = np.asarray(vec, dtype=np.float32)
+               batch_data.append((h, text, arr.tobytes()))
+               hit_map[h] = arr
+
+           # Immediate commit after each batch
+           self._bulk_insert(batch_data)
+           total_committed += len(batch_data)
+
+           # Update progress - simple single line
+           batch_size_actual = len(batch_items)
+           processed_count += batch_size_actual
            if self.verbose:
-
-               rate =
-
-               print(
+               elapsed = time() - t
+               rate = processed_count / elapsed if elapsed > 0 else 0
+               progress_pct = (processed_count / total_items) * 100
+               print(
+                   f"\rProgress: {processed_count}/{total_items} ({progress_pct:.1f}%) | {rate:.0f} texts/sec",
+                   end="",
+                   flush=True,
+               )
+
+       if self.verbose:
+           total_time = time() - t
+           rate = total_items / total_time if total_time > 0 else 0
+           print(
+               f"\n✅ Completed: {total_items} embeddings computed and {total_committed} items committed to database"
+           )
+           print(f"   Total time: {total_time:.2f}s | Rate: {rate:.1f} embeddings/sec")

    def __call__(self, texts: list[str], cache: bool = True) -> np.ndarray:
        assert isinstance(texts, list), "texts must be a list"
-       assert all(isinstance(t, str) for t in texts),
+       assert all(isinstance(t, str) for t in texts), (
+           "all elements in texts must be strings"
+       )
        return self.embeds(texts, cache)

    def _bulk_insert(self, data: list[tuple[str, str, bytes]]) -> None:
        """
        Perform bulk insert of embedding data with retry logic for multi-process safety.
-
+
        Uses INSERT OR IGNORE to prevent race conditions where multiple processes
        might try to insert the same text hash. The first process to successfully
        insert wins, subsequent attempts are ignored. This ensures deterministic
@@ -607,7 +659,7 @@ class VectorCache:

        max_retries = 3
        base_delay = 0.1  # 100ms base delay
-
+
        for attempt in range(max_retries + 1):
            try:
                cursor = self.conn.executemany(
@@ -615,82 +667,34 @@ class VectorCache:
                    data,
                )
                self.conn.commit()
-
+
                # Check if some insertions were ignored due to existing entries
-               if self.verbose and cursor.rowcount < len(data):
-
-
-
+               # if self.verbose and cursor.rowcount < len(data):
+               #     ignored_count = len(data) - cursor.rowcount
+               #     if ignored_count > 0:
+               #         print(f"ℹ️ {ignored_count}/{len(data)} embeddings already existed in cache (computed by another process)")
+
                return  # Success, exit the retry loop
-
+
            except sqlite3.OperationalError as e:
                if "database is locked" in str(e).lower() and attempt < max_retries:
                    # Exponential backoff: 0.1s, 0.2s, 0.4s
-                   delay = base_delay * (2
+                   delay = base_delay * (2**attempt)
                    if self.verbose:
-                       print(
+                       print(
+                           f"⚠️ Database locked, retrying in {delay:.1f}s (attempt {attempt + 1}/{max_retries + 1})"
+                       )
                    import time
+
                    time.sleep(delay)
                    continue
                else:
                    # Re-raise if not a lock error or max retries exceeded
                    raise
-           except Exception
+           except Exception:
                # Re-raise any other exceptions
                raise

-   # def precompute_embeddings(self, texts: list[str]) -> None:
-   #     """
-   #     Precompute embeddings for a large list of texts efficiently.
-   #     This is optimized for bulk operations when you know all texts upfront.
-   #     """
-   #     assert isinstance(texts, list), "texts must be a list"
-   #     assert all(isinstance(t, str) for t in texts), "all elements in texts must be strings"
-   #     if not texts:
-   #         return
-
-   #     # Remove duplicates while preserving order
-   #     unique_texts = list(dict.fromkeys(texts))
-   #     if self.verbose:
-   #         print(f"Precomputing embeddings for {len(unique_texts)} unique texts...")
-
-   #     # Check which ones are already cached
-   #     hashes = [self._hash_text(t) for t in unique_texts]
-   #     existing_hashes = set()
-
-   #     # Bulk check for existing embeddings
-   #     chunk_size = self.config["sqlite_chunk_size"]
-   #     for i in range(0, len(hashes), chunk_size):
-   #         chunk = hashes[i : i + chunk_size]
-   #         placeholders = ",".join("?" * len(chunk))
-   #         rows = self._execute_with_retry(
-   #             f"SELECT hash FROM cache WHERE hash IN ({placeholders})",
-   #             chunk,
-   #         ).fetchall()
-   #         existing_hashes.update(h[0] for h in rows)
-
-   #     # Find missing texts
-   #     missing_items = [
-   #         (t, h) for t, h in zip(unique_texts, hashes) if h not in existing_hashes
-   #     ]
-
-   #     if not missing_items:
-   #         if self.verbose:
-   #             print("All texts already cached!")
-   #         return
-
-   #     if self.verbose:
-   #         print(f"Computing {len(missing_items)} missing embeddings...")
-
-   #     # Process missing items with batches
-   #     missing_texts = [t for t, _ in missing_items]
-   #     missing_items_tupled = [(t, h) for t, h in zip(missing_texts, [self._hash_text(t) for t in missing_texts])]
-   #     hit_map_temp: dict[str, np.ndarray] = {}
-   #     self._process_missing_items_with_batches(missing_items_tupled, hit_map_temp)
-   #     if self.verbose:
-   #         print(f"Successfully cached {len(missing_items)} new embeddings!")
-
    def get_cache_stats(self) -> dict[str, int]:
        """Get statistics about the cache."""
        cursor = self._execute_with_retry("SELECT COUNT(*) FROM cache")
@@ -701,24 +705,27 @@ class VectorCache:
        """Clear all cached embeddings."""
        max_retries = 3
        base_delay = 0.1  # 100ms base delay
-
+
        for attempt in range(max_retries + 1):
            try:
                self.conn.execute("DELETE FROM cache")
                self.conn.commit()
                return  # Success
-
+
            except sqlite3.OperationalError as e:
                if "database is locked" in str(e).lower() and attempt < max_retries:
-                   delay = base_delay * (2
+                   delay = base_delay * (2**attempt)
                    if self.verbose:
-                       print(
+                       print(
+                           f"⚠️ Database locked during clear, retrying in {delay:.1f}s (attempt {attempt + 1}/{max_retries + 1})"
+                       )
                    import time
+
                    time.sleep(delay)
                    continue
                else:
                    raise
-           except Exception
+           except Exception:
                raise

    def get_config(self) -> Dict[str, Any]:
@@ -730,7 +737,7 @@ class VectorCache:
            "db_path": str(self.db_path),
            "verbose": self.verbose,
            "lazy": self.lazy,
-           **self.config
+           **self.config,
        }

    def update_config(self, **kwargs) -> None:
@@ -744,17 +751,26 @@ class VectorCache:
                self.lazy = value
            else:
                raise ValueError(f"Unknown configuration parameter: {key}")
-
+
        # Reset model if backend-specific parameters changed
        backend_params = {
-           "vllm": [
-
-
-
+           "vllm": [
+               "vllm_gpu_memory_utilization",
+               "vllm_tensor_parallel_size",
+               "vllm_dtype",
+               "vllm_trust_remote_code",
+               "vllm_max_model_len",
+           ],
+           "transformers": [
+               "transformers_device",
+               "transformers_batch_size",
+               "transformers_normalize_embeddings",
+               "transformers_trust_remote_code",
+           ],
            "openai": ["api_key", "model_name"],
-           "processing": ["embedding_batch_size"]
+           "processing": ["embedding_batch_size"],
        }
-
+
        if any(param in kwargs for param in backend_params.get(self.backend, [])):
            self._model = None  # Force reload on next use
            if self.backend == "openai":
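Most of the core.py diff is formatter-driven line wrapping, but the concurrency story it documents is concrete: WAL journaling and a busy timeout on the connection, INSERT OR IGNORE so the first writer of a hash wins, and exponential backoff (`base_delay * 2**attempt`) around "database is locked" errors. A standalone sketch of that pattern with plain `sqlite3`, using a hypothetical `cache.sqlite` path and a toy schema rather than VectorCache's real one:

import sqlite3
import time

conn = sqlite3.connect("cache.sqlite")  # hypothetical path
conn.execute("PRAGMA journal_mode=WAL")    # readers do not block the single writer
conn.execute("PRAGMA synchronous=NORMAL")  # faster commits, still safe under WAL
conn.execute("PRAGMA busy_timeout=30000")  # wait up to 30s before raising a lock error
conn.execute(
    "CREATE TABLE IF NOT EXISTS cache (hash TEXT PRIMARY KEY, text TEXT, vec BLOB)"
)


def bulk_insert(rows: list[tuple[str, str, bytes]],
                max_retries: int = 3, base_delay: float = 0.1) -> None:
    """First process to insert a hash wins; lock errors back off exponentially."""
    for attempt in range(max_retries + 1):
        try:
            conn.executemany(
                "INSERT OR IGNORE INTO cache (hash, text, vec) VALUES (?, ?, ?)",
                rows,
            )
            conn.commit()
            return
        except sqlite3.OperationalError as exc:
            if "database is locked" in str(exc).lower() and attempt < max_retries:
                time.sleep(base_delay * (2**attempt))  # 0.1s, 0.2s, 0.4s
                continue
            raise

Because later INSERTs of an already-cached hash are ignored rather than overwriting, concurrent processes converge on one stored vector per text even when the embedding model itself is non-deterministic.
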

speedy_utils/common/utils_cache.py
CHANGED

@@ -258,13 +258,13 @@ def _memory_memoize(

        with mem_lock:
            if name in mem_cache:
-               return mem_cache[name]
+               return mem_cache[name]

        result = func(*args, **kwargs)

        with mem_lock:
            if name not in mem_cache:
-               mem_cache[name] = result
+               mem_cache[name] = result
        return result

    return wrapper
@@ -292,7 +292,7 @@ def _async_memory_memoize(

        async with alock:
            if name in mem_cache:
-               return mem_cache[name]
+               return mem_cache[name]
            task = inflight.get(name)
            if task is None:
                task = asyncio.create_task(func(*args, **kwargs))  # type: ignore[arg-type]
@@ -305,7 +305,7 @@ def _async_memory_memoize(
            inflight.pop(name, None)

        with mem_lock:
-           mem_cache[name] = result
+           mem_cache[name] = result
        return result

    return wrapper
@@ -447,7 +447,7 @@ def both_memoize(
        # Memory first
        with mem_lock:
            if mem_key in mem_cache:
-               return mem_cache[mem_key]
+               return mem_cache[mem_key]

        # Disk next
        if sub_dir == "funcs":
@@ -468,7 +468,7 @@ def both_memoize(

        if disk_result is not None:
            with mem_lock:
-               mem_cache[mem_key] = disk_result
+               mem_cache[mem_key] = disk_result
            return disk_result

        # Miss: compute, then write both
@@ -477,7 +477,7 @@ def both_memoize(
        if not osp.exists(cache_path):
            dump_json_or_pickle(result, cache_path)
        with mem_lock:
-           mem_cache[mem_key] = result
+           mem_cache[mem_key] = result
        return result

    return wrapper
@@ -506,7 +506,7 @@ def _async_both_memoize(
        # Memory
        async with alock:
            if mem_key in mem_cache:
-               return mem_cache[mem_key]
+               return mem_cache[mem_key]

        # Disk
        if sub_dir == "funcs":
@@ -526,7 +526,7 @@ def _async_both_memoize(

        if disk_result is not None:
            with mem_lock:
-               mem_cache[mem_key] = disk_result
+               mem_cache[mem_key] = disk_result
            return disk_result

        # Avoid duplicate async work for same key
@@ -550,7 +550,7 @@ def _async_both_memoize(
        await loop.run_in_executor(None, write_disk_cache)

        with mem_lock:
-           mem_cache[mem_key] = result
+           mem_cache[mem_key] = result
        return result

    return wrapper
@@ -561,9 +561,10 @@ def _async_both_memoize(
# --------------------------------------------------------------------------------------


+# Define overloads to preserve exact type information
@overload
def memoize(
-   _func: Callable[P, R
+   _func: Callable[P, R],
    *,
    keys: Optional[list[str]] = ...,
    key: Optional[Callable[..., Any]] = ...,
@@ -572,7 +573,23 @@ def memoize(
    size: int = ...,
    ignore_self: bool = ...,
    verbose: bool = ...,
-) -> Callable[P, R
+) -> Callable[P, R]: ...
+
+
+@overload
+def memoize(
+   _func: Callable[P, Awaitable[R]],
+   *,
+   keys: Optional[list[str]] = ...,
+   key: Optional[Callable[..., Any]] = ...,
+   cache_dir: str = ...,
+   cache_type: Literal["memory", "disk", "both"] = ...,
+   size: int = ...,
+   ignore_self: bool = ...,
+   verbose: bool = ...,
+) -> Callable[P, Awaitable[R]]: ...
+
+
@overload
def memoize(
    _func: None = ...,
@@ -585,6 +602,8 @@ def memoize(
    ignore_self: bool = ...,
    verbose: bool = ...,
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
+
+
@overload
def memoize(  # type: ignore
    _func: None = ...,
@@ -635,24 +654,24 @@ def memoize(

        if cache_type == "memory":
            if is_async:
-               return _async_memory_memoize(target_func, size, keys, ignore_self, key)
-           return _memory_memoize(target_func, size, keys, ignore_self, key)
+               return _async_memory_memoize(target_func, size, keys, ignore_self, key)
+           return _memory_memoize(target_func, size, keys, ignore_self, key)

        if cache_type == "disk":
            if is_async:
                return _async_disk_memoize(
                    target_func, keys, cache_dir, ignore_self, verbose, key
-               )
+               )
            return _disk_memoize(
                target_func, keys, cache_dir, ignore_self, verbose, key
-           )
+           )

        # cache_type == "both"
        if is_async:
            return _async_both_memoize(
                target_func, keys, cache_dir, ignore_self, size, key
-           )
-       return both_memoize(target_func, keys, cache_dir, ignore_self, size, key)
+           )
+       return both_memoize(target_func, keys, cache_dir, ignore_self, size, key)

    # Support both @memoize and @memoize(...)
    if _func is None:
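The new `@overload` stack is what lets `memoize` advertise `Callable[P, R]` for sync functions and `Callable[P, Awaitable[R]]` for async ones instead of erasing the signature. A minimal sketch of how a type checker sees decorated functions (the decorated functions below are hypothetical):

import asyncio

from speedy_utils.common.utils_cache import memoize


@memoize
def add(a: int, b: int) -> int:
    return a + b


@memoize
async def fetch_length(url: str) -> int:
    await asyncio.sleep(0)  # stand-in for real I/O
    return len(url)


# With the overloads, the checker keeps the original signatures:
#   add           : (a: int, b: int) -> int
#   fetch_length  : (url: str) -> Awaitable[int]
total: int = add(2, 3)


async def main() -> None:
    length = await fetch_length("https://example.com")
    print(total, length)


asyncio.run(main())

This is also why MOpenAI (earlier in this diff) can assign `memoize(self.post)` directly without a cast.
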

speedy_utils/multi_worker/process.py
CHANGED

@@ -1,5 +1,10 @@
 # ray_multi_process.py
 import time, os, pickle, uuid, datetime, multiprocessing
+import datetime
+import os
+import pickle
+import time
+import uuid
 from pathlib import Path
 from typing import Any, Callable
 from tqdm import tqdm
@@ -12,11 +17,16 @@ try:
 except Exception:  # pragma: no cover
    ray = None  # type: ignore
    _HAS_RAY = False
+from typing import Any, Callable, Iterable
+
+import ray
 from fastcore.parallel import parallel
+from tqdm import tqdm


 # ─── cache helpers ──────────────────────────────────────────

+
 def _build_cache_dir(func: Callable, items: list[Any]) -> Path:
    """Build cache dir with function name + timestamp."""
    func_name = getattr(func, "__name__", "func")
@@ -27,6 +37,7 @@ def _build_cache_dir(func: Callable, items: list[Any]) -> Path:
    path.mkdir(parents=True, exist_ok=True)
    return path

+
 def wrap_dump(func: Callable, cache_dir: Path | None):
    """Wrap a function so results are dumped to .pkl when cache_dir is set."""
    if cache_dir is None:
@@ -38,12 +49,15 @@ def wrap_dump(func: Callable, cache_dir: Path | None):
        with open(p, "wb") as fh:
            pickle.dump(res, fh)
        return str(p)
+
    return wrapped

+
 # ─── ray management ─────────────────────────────────────────

 RAY_WORKER = None

+
 def ensure_ray(workers: int, pbar: tqdm | None = None):
    """Initialize or reinitialize Ray with a given worker count, log to bar postfix."""
    global RAY_WORKER
@@ -58,19 +72,21 @@ def ensure_ray(workers: int, pbar: tqdm | None = None):
        pbar.set_postfix_str(f"ray.init {workers} took {took:.2f}s")
    RAY_WORKER = workers

+
 # ─── main API ───────────────────────────────────────────────
 from typing import Literal

+
 def multi_process(
    func: Callable[[Any], Any],
-   items:
+   items: Iterable[Any] | None = None,
    *,
-   inputs:
+   inputs: Iterable[Any] | None = None,
    workers: int | None = None,
    lazy_output: bool = False,
    progress: bool = True,
    # backend: str = "ray",  # "seq", "ray", or "fastcore"
-   backend: Literal["seq", "ray", "mp", "threadpool", "safe"]
+   backend: Literal["seq", "ray", "mp", "threadpool", "safe"] = "mp",
    # Additional optional knobs (accepted for compatibility)
    batch: int | None = None,
    ordered: bool | None = None,
@@ -97,8 +113,12 @@ def multi_process(
        backend = "ray" if _HAS_RAY else "mp"

    # unify items
+   # unify items and coerce to concrete list so we can use len() and
+   # iterate multiple times. This accepts ranges and other iterables.
    if items is None and inputs is not None:
        items = list(inputs)
+   if items is not None and not isinstance(items, list):
+       items = list(items)
    if items is None:
        raise ValueError("'items' or 'inputs' must be provided")

@@ -110,8 +130,9 @@ def multi_process(
    f_wrapped = wrap_dump(func, cache_dir)

    total = len(items)
-   with tqdm(
-
+   with tqdm(
+       total=total, desc=f"multi_process [{backend}]", disable=not progress
+   ) as pbar:
        # ---- sequential backend ----
        if backend == "seq":
            pbar.set_postfix_str("backend=seq")
@@ -147,18 +168,22 @@ def multi_process(

        # ---- fastcore backend ----
        if backend == "mp":
-
-
-
+           results = parallel(
+               f_wrapped, items, n_workers=workers, progress=progress, threadpool=False
+           )
            return list(results)
        if backend == "threadpool":
-           results = parallel(
+           results = parallel(
+               f_wrapped, items, n_workers=workers, progress=progress, threadpool=True
+           )
            return list(results)
        if backend == "safe":
            # Completely safe backend for tests - no multiprocessing, no external progress bars
            import concurrent.futures
            with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
                results = list(executor.map(f_wrapped, items))
+           return results
+
        raise ValueError(f"Unsupported backend: {backend!r}")

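The widened signature (`items: Iterable[Any] | None`) plus the `isinstance(items, list)` coercion is what lets ranges and other iterables be passed straight in; `len(items)` and repeated iteration then operate on the materialized list. A hedged usage sketch (the `square` worker is hypothetical, and it assumes the 1.1.21 wheel is installed):

from speedy_utils.multi_worker.process import multi_process


def square(x: int) -> int:
    return x * x


# items may now be any iterable (a range here); multi_process coerces it to a
# list before taking len() and building the progress bar.
results = multi_process(square, range(8), backend="safe", workers=4)
print(results)  # squares of 0..7, in input order

The "safe" backend added here maps the work over a ThreadPoolExecutor and, after this release, actually returns its results instead of falling through to the unsupported-backend error.
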

{speedy_utils-1.1.19.dist-info → speedy_utils-1.1.21.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: speedy-utils
-Version: 1.1.19
+Version: 1.1.21
 Summary: Fast and easy-to-use package for data science
 Project-URL: Homepage, https://github.com/anhvth/speedy
 Project-URL: Repository, https://github.com/anhvth/speedy
@@ -18,26 +18,26 @@ Classifier: Programming Language :: Python :: 3.12
 Classifier: Programming Language :: Python :: 3.13
 Classifier: Programming Language :: Python :: 3.14
 Requires-Python: >=3.8
-Requires-Dist: aiohttp
+Requires-Dist: aiohttp
 Requires-Dist: bump2version
 Requires-Dist: cachetools
 Requires-Dist: debugpy
 Requires-Dist: fastcore
 Requires-Dist: fastprogress
-Requires-Dist: freezegun
+Requires-Dist: freezegun
 Requires-Dist: ipdb
 Requires-Dist: ipywidgets
-Requires-Dist: json-repair
+Requires-Dist: json-repair
 Requires-Dist: jupyterlab
 Requires-Dist: loguru
 Requires-Dist: matplotlib
 Requires-Dist: numpy
-Requires-Dist: openai
-Requires-Dist: packaging
+Requires-Dist: openai
+Requires-Dist: packaging
 Requires-Dist: pandas
 Requires-Dist: pydantic
-Requires-Dist: pytest
-Requires-Dist: ray
+Requires-Dist: pytest
+Requires-Dist: ray
 Requires-Dist: requests
 Requires-Dist: scikit-learn
 Requires-Dist: tabulate

{speedy_utils-1.1.19.dist-info → speedy_utils-1.1.21.dist-info}/RECORD
CHANGED

@@ -9,7 +9,7 @@ llm_utils/lm/base_prompt_builder.py,sha256=OLqyxbA8QeYIVFzB9EqxUiE_P2p4_MD_Lq4WS
 llm_utils/lm/llm_task.py,sha256=kyBeMDJwW9ZWq5A_OMgE-ou9GQ0bk5c9lxXOvfo31R4,27915
 llm_utils/lm/lm.py,sha256=8TaLuU7naPQbOFmiS2NQyWVLG0jUUzRRBQsR0In7GVo,7249
 llm_utils/lm/lm_base.py,sha256=pqbHZOdR7yUMpvwt8uBG1dZnt76SY_Wk8BkXQQ-mpWs,9557
-llm_utils/lm/openai_memoize.py,sha256=
+llm_utils/lm/openai_memoize.py,sha256=KToCcB_rhyrULxolnwMfQgl5GNrAeykePxuLS4hBjtc,3442
 llm_utils/lm/utils.py,sha256=a0KJj8vjT2fHKb7GKGNJjJHhKLThwpxIL7vnV9Fr3ZY,4584
 llm_utils/lm/async_lm/__init__.py,sha256=PUBbCuf5u6-0GBUu-2PI6YAguzsyXj-LPkU6vccqT6E,121
 llm_utils/lm/async_lm/_utils.py,sha256=P1-pUDf_0pDmo8WTIi43t5ARlyGA1RIJfpAhz-gfA5g,6105
@@ -22,7 +22,7 @@ llm_utils/scripts/vllm_load_balancer.py,sha256=TT5Ypq7gUcl52gRFp--ORFFjzhfGlcaX2
 llm_utils/scripts/vllm_serve.py,sha256=gJ0-y4kybMfSt8qzye1pJqGMY3x9JLRi6Tu7RjJMnss,14771
 llm_utils/vector_cache/__init__.py,sha256=i1KQuC4OhPewYpFl9X6HlWFBuASCTx2qgGizhpZhmn0,862
 llm_utils/vector_cache/cli.py,sha256=DMXTj8nZ2_LRjprbYPb4uzq04qZtOfBbmblmaqDcCuM,6251
-llm_utils/vector_cache/core.py,sha256=
+llm_utils/vector_cache/core.py,sha256=J8ocRX9sBfzboQkf5vFF2cx0SK-nftmKWJUa91WUBy8,31134
 llm_utils/vector_cache/types.py,sha256=ru8qmUZ8_lNd3_oYpjCMtpXTsqmwsSBe56Z4hTWm3xI,435
 llm_utils/vector_cache/utils.py,sha256=dwbbXlRrARrpmS4YqSlYQqrTURg0UWe8XvaAWcX05MM,1458
 speedy_utils/__init__.py,sha256=QBvGIbrC5yczQwh4T8iu9KQx6w9u-v_JdoQfA67hLUg,5780
@@ -34,17 +34,17 @@ speedy_utils/common/logger.py,sha256=a2iZx0eWyfi2-2X_H2QmfuA3tfR7_XSM7Nd0GdUnUOs
 speedy_utils/common/notebook_utils.py,sha256=-97kehJ_Gg3TzDLubsLIYJcykqX1NXhbvBO6nniZSYM,2063
 speedy_utils/common/patcher.py,sha256=VCmdxyTF87qroggQkQklRPhAOPJbeBqhcJoTsLcDxNw,2303
 speedy_utils/common/report_manager.py,sha256=eBiw5KY6bWUhwki3B4lK5o8bFsp7L5x28X9GCI-Sd1w,3899
-speedy_utils/common/utils_cache.py,sha256=
+speedy_utils/common/utils_cache.py,sha256=NCwILnhsK86sDPkkriDTCyuM-qUKFxYOo1Piww1ED0g,22381
 speedy_utils/common/utils_io.py,sha256=-RkQjYGa3zVqpgVInsdp8dbS5oLwdJdUsRz1XIUSJzg,14257
 speedy_utils/common/utils_misc.py,sha256=cdEuBBpiB1xpuzj0UBDHDuTIerqsMIw37ENq6EXliOw,1795
 speedy_utils/common/utils_print.py,sha256=syRrnSFtguxrV-elx6DDVcSGu4Qy7D_xVNZhPwbUY4A,4864
 speedy_utils/multi_worker/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-speedy_utils/multi_worker/process.py,sha256=
+speedy_utils/multi_worker/process.py,sha256=0Rhr2xJWtX0PeZXPFU3zAAqbybh83DdF1C2gwHJLXls,7231
 speedy_utils/multi_worker/thread.py,sha256=xhCPgJokCDjjPrWh6vUtCBlZgs3E6mM81WCAEKvZea0,19522
 speedy_utils/scripts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 speedy_utils/scripts/mpython.py,sha256=IvywP7Y0_V6tWfMP-4MjPvN5_KfxWF21xaLJsCIayCk,3821
 speedy_utils/scripts/openapi_client_codegen.py,sha256=f2125S_q0PILgH5dyzoKRz7pIvNEjCkzpi4Q4pPFRZE,9683
-speedy_utils-1.1.
-speedy_utils-1.1.
-speedy_utils-1.1.
-speedy_utils-1.1.
+speedy_utils-1.1.21.dist-info/METADATA,sha256=b_vxwYzT_2oorlbKL5NVZZ1ZVbInUw45deXmLOG0tys,8028
+speedy_utils-1.1.21.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+speedy_utils-1.1.21.dist-info/entry_points.txt,sha256=1rrFMfqvaMUE9hvwGiD6vnVh98kmgy0TARBj-v0Lfhs,244
+speedy_utils-1.1.21.dist-info/RECORD,,

{speedy_utils-1.1.19.dist-info → speedy_utils-1.1.21.dist-info}/WHEEL
File without changes

{speedy_utils-1.1.19.dist-info → speedy_utils-1.1.21.dist-info}/entry_points.txt
File without changes