openaivec 0.10.0__py3-none-any.whl → 1.0.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openaivec/__init__.py +13 -4
- openaivec/_cache/__init__.py +12 -0
- openaivec/_cache/optimize.py +109 -0
- openaivec/_cache/proxy.py +806 -0
- openaivec/_di.py +326 -0
- openaivec/_embeddings.py +203 -0
- openaivec/{log.py → _log.py} +2 -2
- openaivec/_model.py +113 -0
- openaivec/{prompt.py → _prompt.py} +95 -28
- openaivec/_provider.py +207 -0
- openaivec/_responses.py +511 -0
- openaivec/_schema/__init__.py +9 -0
- openaivec/_schema/infer.py +340 -0
- openaivec/_schema/spec.py +350 -0
- openaivec/_serialize.py +234 -0
- openaivec/{util.py → _util.py} +25 -85
- openaivec/pandas_ext.py +1635 -425
- openaivec/spark.py +604 -335
- openaivec/task/__init__.py +27 -29
- openaivec/task/customer_support/__init__.py +9 -15
- openaivec/task/customer_support/customer_sentiment.py +51 -41
- openaivec/task/customer_support/inquiry_classification.py +86 -61
- openaivec/task/customer_support/inquiry_summary.py +44 -45
- openaivec/task/customer_support/intent_analysis.py +56 -41
- openaivec/task/customer_support/response_suggestion.py +49 -43
- openaivec/task/customer_support/urgency_analysis.py +76 -71
- openaivec/task/nlp/__init__.py +4 -4
- openaivec/task/nlp/dependency_parsing.py +19 -20
- openaivec/task/nlp/keyword_extraction.py +22 -24
- openaivec/task/nlp/morphological_analysis.py +25 -25
- openaivec/task/nlp/named_entity_recognition.py +26 -28
- openaivec/task/nlp/sentiment_analysis.py +29 -21
- openaivec/task/nlp/translation.py +24 -30
- openaivec/task/table/__init__.py +3 -0
- openaivec/task/table/fillna.py +183 -0
- openaivec-1.0.10.dist-info/METADATA +399 -0
- openaivec-1.0.10.dist-info/RECORD +39 -0
- {openaivec-0.10.0.dist-info → openaivec-1.0.10.dist-info}/WHEEL +1 -1
- openaivec/embeddings.py +0 -172
- openaivec/responses.py +0 -392
- openaivec/serialize.py +0 -225
- openaivec/task/model.py +0 -84
- openaivec-0.10.0.dist-info/METADATA +0 -546
- openaivec-0.10.0.dist-info/RECORD +0 -29
- {openaivec-0.10.0.dist-info → openaivec-1.0.10.dist-info}/licenses/LICENSE +0 -0
openaivec/_cache/proxy.py
@@ -0,0 +1,806 @@
+import asyncio
+import threading
+from collections.abc import Awaitable, Callable, Hashable
+from dataclasses import dataclass, field
+from typing import Any, Generic, TypeVar
+
+from openaivec._cache import BatchSizeSuggester
+
+__all__ = []
+
+S = TypeVar("S", bound=Hashable)
+T = TypeVar("T")
+
+
+class ProxyBase(Generic[S, T]):
+    """Common utilities shared by BatchingMapProxy and AsyncBatchingMapProxy.
+
+    Provides order-preserving deduplication and batch size normalization that
+    depend only on ``batch_size`` and do not touch concurrency primitives.
+
+    Attributes:
+        batch_size: Optional mini-batch size hint used by implementations to
+            split work into chunks. When unset or non-positive, implementations
+            should process the entire input in a single call.
+    """
+
+    batch_size: int | None  # subclasses may override via dataclass
+    show_progress: bool  # Enable progress bar display
+    suggester: BatchSizeSuggester  # Batch size optimization, initialized by subclasses
+
+    def _is_notebook_environment(self) -> bool:
+        """Check if running in a Jupyter notebook environment.
+
+        Returns:
+            bool: True if running in a notebook, False otherwise.
+        """
+        try:
+            from IPython.core.getipython import get_ipython
+
+            ipython = get_ipython()
+            if ipython is not None:
+                # Check for different notebook environments
+                class_name = ipython.__class__.__name__
+                module_name = ipython.__class__.__module__
+
+                # Standard Jupyter notebook/lab
+                if class_name == "ZMQInteractiveShell":
+                    return True
+
+                # JupyterLab and newer environments
+                if "zmq" in module_name.lower() or "jupyter" in module_name.lower():
+                    return True
+
+                # Google Colab
+                if "google.colab" in module_name:
+                    return True
+
+                # VS Code notebooks and others
+                if hasattr(ipython, "kernel"):
+                    return True
+
+        except ImportError:
+            pass
+
+        # Check for other notebook indicators
+        # Check for common notebook environment variables
+        import os
+        import sys
+
+        notebook_vars = [
+            "JUPYTER_CONFIG_DIR",
+            "JUPYTERLAB_DIR",
+            "COLAB_GPU",
+            "VSCODE_PID",  # VS Code
+        ]
+
+        for var in notebook_vars:
+            if var in os.environ:
+                return True
+
+        # Check if running in IPython without terminal
+        if "IPython" in sys.modules:
+            try:
+                # If we can import display from IPython, likely in notebook
+                import importlib.util
+
+                if importlib.util.find_spec("IPython.display") is not None:
+                    return True
+            except ImportError:
+                pass
+
+        return False
+
+    def _create_progress_bar(self, total: int, desc: str = "Processing batches") -> Any:
+        """Create a progress bar if conditions are met.
+
+        Args:
+            total (int): Total number of items to process.
+            desc (str): Description for the progress bar.
+
+        Returns:
+            Any: Progress bar instance or None if not available.
+        """
+        try:
+            from tqdm.auto import tqdm as tqdm_progress
+
+            if self.show_progress and self._is_notebook_environment():
+                return tqdm_progress(total=total, desc=desc, unit="item")
+        except ImportError:
+            pass
+        return None
+
+    def _update_progress_bar(self, progress_bar: Any, increment: int) -> None:
+        """Update progress bar with the given increment.
+
+        Args:
+            progress_bar (Any): Progress bar instance.
+            increment (int): Number of items to increment.
+        """
+        if progress_bar:
+            progress_bar.update(increment)
+
+    def _close_progress_bar(self, progress_bar: Any) -> None:
+        """Close the progress bar.
+
+        Args:
+            progress_bar (Any): Progress bar instance.
+        """
+        if progress_bar:
+            progress_bar.close()
+
+    @staticmethod
+    def _unique_in_order(seq: list[S]) -> list[S]:
+        """Return unique items preserving their first-occurrence order.
+
+        Args:
+            seq (list[S]): Sequence of items which may contain duplicates.
+
+        Returns:
+            list[S]: A new list containing each distinct item from ``seq`` exactly
+                once, in the order of their first occurrence.
+        """
+        seen: set[S] = set()
+        out: list[S] = []
+        for x in seq:
+            if x not in seen:
+                seen.add(x)
+                out.append(x)
+        return out
+
+    def _normalized_batch_size(self, total: int) -> int:
+        """Compute the effective batch size used for processing.
+
+        If ``batch_size`` is None, use the suggester to determine optimal batch size.
+        If ``batch_size`` is non-positive, process the entire ``total`` in a single call.
+
+        Args:
+            total (int): Number of items intended to be processed.
+
+        Returns:
+            int: The positive batch size to use.
+        """
+        if self.batch_size and self.batch_size > 0:
+            return self.batch_size
+        elif self.batch_size is None:
+            # Use suggester to determine optimal batch size
+            suggested = self.suggester.suggest_batch_size()
+            return min(suggested, total)  # Don't exceed total items
+        else:
+            # batch_size is 0 or negative, process all at once
+            return total
+
+
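The `_normalized_batch_size` helper above is the fulcrum for both proxies: a fixed positive `batch_size` wins, `None` defers to the suggester, and zero or negative means a single call. A minimal sketch (not part of the diff) that exercises the three modes, with a stub standing in for the real `BatchSizeSuggester` from `openaivec._cache`, which tunes its suggestion from recorded batch execution times:

```python
# Sketch only: replicates the branching of ProxyBase._normalized_batch_size
# with a stubbed suggester instead of the real optimizer.
class StubSuggester:
    def suggest_batch_size(self) -> int:
        return 64  # pretend the optimizer has settled on 64


class Demo:
    def __init__(self, batch_size: int | None) -> None:
        self.batch_size = batch_size
        self.suggester = StubSuggester()

    def normalized(self, total: int) -> int:
        if self.batch_size and self.batch_size > 0:
            return self.batch_size  # fixed batch size
        elif self.batch_size is None:
            return min(self.suggester.suggest_batch_size(), total)
        else:
            return total  # <= 0: process everything at once


print(Demo(10).normalized(100))    # 10  (fixed)
print(Demo(None).normalized(100))  # 64  (suggester-driven)
print(Demo(None).normalized(8))    # 8   (never exceeds total)
print(Demo(0).normalized(100))     # 100 (single call)
```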
+@dataclass
+class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
+    """Thread-safe local proxy that caches results of a mapping function.
+
+    This proxy batches calls to the ``map_func`` you pass to ``map()`` (if
+    ``batch_size`` is set),
+    deduplicates inputs while preserving order, and ensures that concurrent calls do
+    not duplicate work via an in-flight registry. All public behavior is preserved
+    while minimizing redundant requests and maintaining input order in the output.
+
+    When ``batch_size=None``, automatic batch size optimization is enabled,
+    dynamically adjusting batch sizes based on execution time to maintain optimal
+    performance (targeting 30-60 seconds per batch).
+
+    Example:
+        ```python
+        p = BatchingMapProxy[int, str](batch_size=3)
+
+        def f(xs: list[int]) -> list[str]:
+            return [f"v:{x}" for x in xs]
+
+        p.map([1, 2, 2, 3, 4], f)
+        # ['v:1', 'v:2', 'v:2', 'v:3', 'v:4']
+        ```
+    """
+
+    # Number of items to process per call to map_func.
+    # - If None (default): Enables automatic batch size optimization, dynamically adjusting
+    #   based on execution time (targeting 30-60 seconds per batch)
+    # - If positive integer: Fixed batch size
+    # - If <= 0: Process all items at once
+    batch_size: int | None = None
+    show_progress: bool = True
+    suggester: BatchSizeSuggester = field(default_factory=BatchSizeSuggester, repr=False)
+
+    # internals
+    _cache: dict[S, T] = field(default_factory=dict)
+    _lock: threading.RLock = field(default_factory=threading.RLock, repr=False)
+    _inflight: dict[S, threading.Event] = field(default_factory=dict, repr=False)
+
+    def __all_cached(self, items: list[S]) -> bool:
+        """Check whether all items are present in the cache.
+
+        This method acquires the internal lock to perform a consistent check.
+
+        Args:
+            items (list[S]): Items to verify against the cache.
+
+        Returns:
+            bool: True if every item is already cached, False otherwise.
+        """
+        with self._lock:
+            return all(x in self._cache for x in items)
+
+    def __values(self, items: list[S]) -> list[T]:
+        """Fetch cached values for ``items`` preserving the given order.
+
+        This method acquires the internal lock while reading the cache.
+
+        Args:
+            items (list[S]): Items to retrieve from the cache.
+
+        Returns:
+            list[T]: The cached values corresponding to ``items`` in the same
+                order.
+        """
+        with self._lock:
+            return [self._cache[x] for x in items]
+
+    def __acquire_ownership(self, items: list[S]) -> tuple[list[S], list[S]]:
+        """Acquire ownership for missing items and identify keys to wait for.
+
+        For each unique item, if it's already cached, it is ignored. If it's
+        currently being computed by another thread (in-flight), it is added to
+        the wait list. Otherwise, this method marks the key as in-flight and
+        considers it "owned" by the current thread.
+
+        Args:
+            items (list[S]): Unique items (order-preserving) to be processed.
+
+        Returns:
+            tuple[list[S], list[S]]: A tuple ``(owned, wait_for)`` where
+                - ``owned`` are items this thread is responsible for computing.
+                - ``wait_for`` are items that another thread is already computing.
+        """
+        owned: list[S] = []
+        wait_for: list[S] = []
+        with self._lock:
+            for x in items:
+                if x in self._cache:
+                    continue
+                if x in self._inflight:
+                    wait_for.append(x)
+                else:
+                    self._inflight[x] = threading.Event()
+                    owned.append(x)
+        return owned, wait_for
+
+    def __finalize_success(self, to_call: list[S], results: list[T]) -> None:
+        """Populate cache with results and signal completion events.
+
+        Args:
+            to_call (list[S]): Items that were computed.
+            results (list[T]): Results corresponding to ``to_call`` in order.
+        """
+        if len(results) != len(to_call):
+            # Prevent deadlocks if map_func violates the contract.
+            # Release waiters and surface a clear error.
+            self.__finalize_failure(to_call)
+            raise ValueError("map_func must return a list of results with the same length and order as inputs")
+        with self._lock:
+            for x, y in zip(to_call, results):
+                self._cache[x] = y
+                ev = self._inflight.pop(x, None)
+                if ev:
+                    ev.set()
+
+    def __finalize_failure(self, to_call: list[S]) -> None:
+        """Release in-flight events on failure to avoid deadlocks.
+
+        Args:
+            to_call (list[S]): Items that were intended to be computed when an
+                error occurred.
+        """
+        with self._lock:
+            for x in to_call:
+                ev = self._inflight.pop(x, None)
+                if ev:
+                    ev.set()
+
+    def clear(self) -> None:
+        """Clear all cached results and release any in-flight waiters.
+
+        Notes:
+            - Intended to be called after all processing is finished.
+            - Do not call concurrently with active map() calls to avoid
+              unnecessary recomputation or racy wake-ups.
+        """
+        with self._lock:
+            for ev in self._inflight.values():
+                ev.set()
+            self._inflight.clear()
+            self._cache.clear()
+
+    def close(self) -> None:
+        """Alias for clear()."""
+        self.clear()
+
+    def __process_owned(self, owned: list[S], map_func: Callable[[list[S]], list[T]]) -> None:
+        """Process owned items in mini-batches and fill the cache.
+
+        Before calling ``map_func`` for each batch, the cache is re-checked
+        to skip any items that may have been filled in the meantime. Items
+        are accumulated across multiple original batches to maximize batch
+        size utilization when some items are cached. On exceptions raised
+        by ``map_func``, all corresponding in-flight events are released
+        to prevent deadlocks, and the exception is propagated.
+
+        Args:
+            owned (list[S]): Items for which the current thread has computation
+                ownership.
+
+        Raises:
+            Exception: Propagates any exception raised by ``map_func``.
+        """
+        if not owned:
+            return
+        # Setup progress bar
+        progress_bar = self._create_progress_bar(len(owned))
+
+        # Accumulate uncached items to maximize batch size utilization
+        pending_to_call: list[S] = []
+
+        i = 0
+        while i < len(owned):
+            # Get dynamic batch size for each iteration
+            current_batch_size = self._normalized_batch_size(len(owned))
+            batch = owned[i : i + current_batch_size]
+            # Double-check cache right before processing
+            with self._lock:
+                uncached_in_batch = [x for x in batch if x not in self._cache]
+
+            pending_to_call.extend(uncached_in_batch)
+
+            # Process accumulated items when we reach batch_size or at the end
+            is_last_batch = i + current_batch_size >= len(owned)
+            if len(pending_to_call) >= current_batch_size or (is_last_batch and pending_to_call):
+                # Take up to batch_size items to process
+                to_call = pending_to_call[:current_batch_size]
+                pending_to_call = pending_to_call[current_batch_size:]
+
+                try:
+                    # Always measure execution time using suggester
+                    with self.suggester.record(len(to_call)):
+                        results = map_func(to_call)
+                except Exception:
+                    self.__finalize_failure(to_call)
+                    raise
+                self.__finalize_success(to_call, results)
+
+                # Update progress bar
+                self._update_progress_bar(progress_bar, len(to_call))
+
+            # Move to next batch
+            i += current_batch_size
+
+        # Process any remaining items
+        while pending_to_call:
+            # Get dynamic batch size for remaining items
+            remaining_batch_size = self._normalized_batch_size(len(pending_to_call))
+            to_call = pending_to_call[:remaining_batch_size]
+            pending_to_call = pending_to_call[remaining_batch_size:]
+
+            try:
+                with self.suggester.record(len(to_call)):
+                    results = map_func(to_call)
+            except Exception:
+                self.__finalize_failure(to_call)
+                raise
+            self.__finalize_success(to_call, results)
+
+            # Update progress bar
+            self._update_progress_bar(progress_bar, len(to_call))
+
+        # Close progress bar
+        self._close_progress_bar(progress_bar)
+
+    def __wait_for(self, keys: list[S], map_func: Callable[[list[S]], list[T]]) -> None:
+        """Wait for other threads to complete computations for the given keys.
+
+        If a key is neither cached nor in-flight, this method now claims ownership
+        for that key immediately (registers an in-flight Event) and defers the
+        computation so that all such rescued keys can be processed together in a
+        single batched call to ``map_func`` after the scan completes. This avoids
+        high-cost single-item calls.
+
+        Args:
+            keys (list[S]): Items whose computations are owned by other threads.
+        """
+        rescued: list[S] = []  # keys we claim to batch-process
+        for x in keys:
+            while True:
+                with self._lock:
+                    if x in self._cache:
+                        break
+                    ev = self._inflight.get(x)
+                    if ev is None:
+                        # Not cached and no one computing; claim ownership to batch later.
+                        self._inflight[x] = threading.Event()
+                        rescued.append(x)
+                        break
+                # Someone else is computing; wait for completion.
+                ev.wait()
+        # Batch-process rescued keys, if any
+        if rescued:
+            try:
+                self.__process_owned(rescued, map_func)
+            except Exception:
+                # Ensure events are released on failure to avoid deadlock
+                self.__finalize_failure(rescued)
+                raise
+
+    # ---- public API ------------------------------------------------------
+    def map(self, items: list[S], map_func: Callable[[list[S]], list[T]]) -> list[T]:
+        """Map ``items`` to values using caching and optional mini-batching.
+
+        This method is thread-safe. It deduplicates inputs while preserving order,
+        coordinates concurrent work to prevent duplicate computation, and processes
+        owned items in mini-batches determined by ``batch_size``. Before each batch
+        call to ``map_func``, the cache is re-checked to avoid redundant requests.
+
+        Args:
+            items (list[S]): Input items to map.
+            map_func (Callable[[list[S]], list[T]]): Function that maps a batch of
+                items to their corresponding results. Must return results in the
+                same order as inputs.
+
+        Returns:
+            list[T]: Mapped values corresponding to ``items`` in the same order.
+
+        Raises:
+            Exception: Propagates any exception raised by ``map_func``.
+
+        Example:
+            ```python
+            proxy: BatchingMapProxy[int, str] = BatchingMapProxy(batch_size=2)
+            calls: list[list[int]] = []
+
+            def mapper(chunk: list[int]) -> list[str]:
+                calls.append(chunk)
+                return [f"v:{x}" for x in chunk]
+
+            proxy.map([1, 2, 2, 3], mapper)
+            # ['v:1', 'v:2', 'v:2', 'v:3']
+            calls  # duplicate ``2`` is only computed once
+            # [[1, 2], [3]]
+            ```
+        """
+        if self.__all_cached(items):
+            return self.__values(items)
+
+        unique_items = self._unique_in_order(items)
+        owned, wait_for = self.__acquire_ownership(unique_items)
+
+        self.__process_owned(owned, map_func)
+        self.__wait_for(wait_for, map_func)
+
+        # Fetch results before purging None entries
+        results = self.__values(items)
+
+        # Remove None values from cache so they are recomputed on future calls
+        with self._lock:
+            if self._cache:  # micro-optimization
+                for k in set(items):
+                    try:
+                        if self._cache.get(k, object()) is None:
+                            del self._cache[k]
+                    except KeyError:
+                        pass
+
+        return results
+
+
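Before the async variant, a concurrency sketch (editorial, not part of the diff; it assumes the module path `openaivec._cache.proxy` shown above is importable from the installed wheel). Two threads request overlapping key ranges through one shared proxy, and the in-flight registry guarantees each key reaches `map_func` exactly once:

```python
import threading
from collections import Counter

from openaivec._cache.proxy import BatchingMapProxy

calls: Counter = Counter()


def mapper(xs: list[int]) -> list[str]:
    calls.update(xs)  # record how often each key reaches map_func
    return [f"v:{x}" for x in xs]


proxy: BatchingMapProxy[int, str] = BatchingMapProxy(batch_size=2, show_progress=False)
t1 = threading.Thread(target=proxy.map, args=(list(range(8)), mapper))
t2 = threading.Thread(target=proxy.map, args=(list(range(4, 12)), mapper))
t1.start()
t2.start()
t1.join()
t2.join()

# The overlap (keys 4..7) is computed by whichever thread acquired
# ownership first; the other thread waits on the in-flight events.
assert all(n == 1 for n in calls.values())
```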
+@dataclass
+class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
+    """Asynchronous version of BatchingMapProxy for use with async functions.
+
+    The ``map()`` method accepts an async ``map_func`` that may perform I/O and
+    awaits it
+    in mini-batches. It deduplicates inputs, maintains cache consistency, and
+    coordinates concurrent coroutines to avoid duplicate work via an in-flight
+    registry of asyncio events.
+
+    When ``batch_size=None``, automatic batch size optimization is enabled,
+    dynamically adjusting batch sizes based on execution time to maintain optimal
+    performance (targeting 30-60 seconds per batch).
+
+    Example:
+        ```python
+        import asyncio
+
+        p = AsyncBatchingMapProxy[int, str](batch_size=2)
+
+        async def af(xs: list[int]) -> list[str]:
+            await asyncio.sleep(0)
+            return [f"v:{x}" for x in xs]
+
+        async def run():
+            return await p.map([1, 2, 3], af)
+
+        asyncio.run(run())
+        # ['v:1', 'v:2', 'v:3']
+        ```
+    """
+
+    # Number of items to process per call to map_func.
+    # - If None (default): Enables automatic batch size optimization, dynamically adjusting
+    #   based on execution time (targeting 30-60 seconds per batch)
+    # - If positive integer: Fixed batch size
+    # - If <= 0: Process all items at once
+    batch_size: int | None = None
+    max_concurrency: int = 8
+    show_progress: bool = True
+    suggester: BatchSizeSuggester = field(default_factory=BatchSizeSuggester, repr=False)
+
+    # internals
+    _cache: dict[S, T] = field(default_factory=dict, repr=False)
+    _lock: asyncio.Lock = field(default_factory=asyncio.Lock, repr=False)
+    _inflight: dict[S, asyncio.Event] = field(default_factory=dict, repr=False)
+    __sema: asyncio.Semaphore | None = field(default=None, init=False, repr=False)
+
+    def __post_init__(self) -> None:
+        """Initialize internal semaphore based on ``max_concurrency``.
+
+        If ``max_concurrency`` is a positive integer, an ``asyncio.Semaphore``
+        is created to limit the number of concurrent ``map_func`` calls across
+        overlapping ``map`` invocations. When non-positive or ``None``, no
+        semaphore is used and concurrency is unrestricted by this proxy.
+
+        Notes:
+            This method is invoked automatically by ``dataclasses`` after
+            initialization and does not need to be called directly.
+        """
+        # Initialize semaphore if limiting is requested; non-positive disables limiting
+        if self.max_concurrency and self.max_concurrency > 0:
+            self.__sema = asyncio.Semaphore(self.max_concurrency)
+        else:
+            self.__sema = None
+
+    async def __all_cached(self, items: list[S]) -> bool:
+        """Check whether all items are present in the cache.
+
+        This method acquires the internal asyncio lock for a consistent view
+        of the cache.
+
+        Args:
+            items (list[S]): Items to verify against the cache.
+
+        Returns:
+            bool: True if every item in ``items`` is already cached, False otherwise.
+        """
+        async with self._lock:
+            return all(x in self._cache for x in items)
+
+    async def __values(self, items: list[S]) -> list[T]:
+        """Get cached values for ``items`` preserving their given order.
+
+        The internal asyncio lock is held while reading the cache to preserve
+        consistency under concurrency.
+
+        Args:
+            items (list[S]): Items to read from the cache.
+
+        Returns:
+            list[T]: Cached values corresponding to ``items`` in the same order.
+        """
+        async with self._lock:
+            return [self._cache[x] for x in items]
+
+    async def __acquire_ownership(self, items: list[S]) -> tuple[list[S], list[S]]:
+        """Acquire ownership for missing keys and identify keys to wait for.
+
+        Args:
+            items (list[S]): Unique items (order-preserving) to be processed.
+
+        Returns:
+            tuple[list[S], list[S]]: A tuple ``(owned, wait_for)`` where owned are
+                keys this coroutine should compute, and wait_for are keys currently
+                being computed elsewhere.
+        """
+        owned: list[S] = []
+        wait_for: list[S] = []
+        async with self._lock:
+            for x in items:
+                if x in self._cache:
+                    continue
+                if x in self._inflight:
+                    wait_for.append(x)
+                else:
+                    self._inflight[x] = asyncio.Event()
+                    owned.append(x)
+        return owned, wait_for
+
+    async def __finalize_success(self, to_call: list[S], results: list[T]) -> None:
+        """Populate cache and signal completion for successfully computed keys.
+
+        Args:
+            to_call (list[S]): Items that were computed in the recent batch.
+            results (list[T]): Results corresponding to ``to_call`` in order.
+        """
+        if len(results) != len(to_call):
+            # Prevent deadlocks if map_func violates the contract.
+            await self.__finalize_failure(to_call)
+            raise ValueError("map_func must return a list of results with the same length and order as inputs")
+        async with self._lock:
+            for x, y in zip(to_call, results):
+                self._cache[x] = y
+                ev = self._inflight.pop(x, None)
+                if ev:
+                    ev.set()
+
+    async def __finalize_failure(self, to_call: list[S]) -> None:
+        """Release in-flight events on failure to avoid deadlocks.
+
+        Args:
+            to_call (list[S]): Items whose computation failed; their waiters will
+                be released.
+        """
+        async with self._lock:
+            for x in to_call:
+                ev = self._inflight.pop(x, None)
+                if ev:
+                    ev.set()
+
+    async def clear(self) -> None:
+        """Clear all cached results and release any in-flight waiters.
+
+        Notes:
+            - Intended to be awaited after all processing is finished.
+            - Do not call concurrently with active map() calls to avoid
+              unnecessary recomputation or racy wake-ups.
+        """
+        async with self._lock:
+            for ev in self._inflight.values():
+                ev.set()
+            self._inflight.clear()
+            self._cache.clear()
+
+    async def aclose(self) -> None:
+        """Alias for clear()."""
+        await self.clear()
+
+    async def __process_owned(self, owned: list[S], map_func: Callable[[list[S]], Awaitable[list[T]]]) -> None:
+        """Process owned keys using Producer-Consumer pattern with dynamic batch sizing.
+
+        Args:
+            owned (list[S]): Items for which this coroutine holds computation ownership.
+
+        Raises:
+            Exception: Propagates any exception raised by ``map_func``.
+        """
+        if not owned:
+            return
+
+        progress_bar = self._create_progress_bar(len(owned))
+        batch_queue: asyncio.Queue = asyncio.Queue(maxsize=self.max_concurrency)
+
+        async def producer():
+            index = 0
+            while index < len(owned):
+                batch_size = self._normalized_batch_size(len(owned) - index)
+                batch = owned[index : index + batch_size]
+                await batch_queue.put(batch)
+                index += batch_size
+            # Send completion signals
+            for _ in range(self.max_concurrency):
+                await batch_queue.put(None)
+
+        async def consumer():
+            while True:
+                batch = await batch_queue.get()
+                try:
+                    if batch is None:
+                        break
+                    await self.__process_single_batch(batch, map_func, progress_bar)
+                finally:
+                    batch_queue.task_done()
+
+        await asyncio.gather(producer(), *[consumer() for _ in range(self.max_concurrency)])
+
+        self._close_progress_bar(progress_bar)
+
+    async def __process_single_batch(
+        self, to_call: list[S], map_func: Callable[[list[S]], Awaitable[list[T]]], progress_bar
+    ) -> None:
+        """Process a single batch with semaphore control."""
+        acquired = False
+        try:
+            if self.__sema:
+                await self.__sema.acquire()
+                acquired = True
+            # Measure async map_func execution using suggester
+            with self.suggester.record(len(to_call)):
+                results = await map_func(to_call)
+        except Exception:
+            await self.__finalize_failure(to_call)
+            raise
+        finally:
+            if self.__sema and acquired:
+                self.__sema.release()
+        await self.__finalize_success(to_call, results)
+
+        # Update progress bar
+        self._update_progress_bar(progress_bar, len(to_call))
+
+    async def __wait_for(self, keys: list[S], map_func: Callable[[list[S]], Awaitable[list[T]]]) -> None:
+        """Wait for computations owned by other coroutines to complete.
+
+        If a key is neither cached nor in-flight, this method now claims ownership
+        for that key immediately (registers an in-flight Event) and defers the
+        computation so that all such rescued keys can be processed together in a
+        single batched call to ``map_func`` after the scan completes. This avoids
+        high-cost single-item calls.
+
+        Args:
+            keys (list[S]): Items whose computations are owned by other coroutines.
+        """
+        rescued: list[S] = []  # keys we claim to batch-process
+        for x in keys:
+            while True:
+                async with self._lock:
+                    if x in self._cache:
+                        break
+                    ev = self._inflight.get(x)
+                    if ev is None:
+                        # Not cached and no one computing; claim ownership to batch later.
+                        self._inflight[x] = asyncio.Event()
+                        rescued.append(x)
+                        break
+                # Someone else is computing; wait for completion.
+                await ev.wait()
+        # Batch-process rescued keys, if any
+        if rescued:
+            try:
+                await self.__process_owned(rescued, map_func)
+            except Exception:
+                await self.__finalize_failure(rescued)
+                raise
+
+    # ---- public API ------------------------------------------------------
+    async def map(self, items: list[S], map_func: Callable[[list[S]], Awaitable[list[T]]]) -> list[T]:
+        """Async map with caching, de-duplication, and optional mini-batching.
+
+        Args:
+            items (list[S]): Input items to map.
+            map_func (Callable[[list[S]], Awaitable[list[T]]]): Async function that
+                maps a batch of items to their results, preserving input order.
+
+        Returns:
+            list[T]: Mapped values corresponding to ``items`` in the same order.
+
+        Example:
+            ```python
+            import asyncio
+
+            async def mapper(chunk: list[int]) -> list[str]:
+                await asyncio.sleep(0)
+                return [f"v:{x}" for x in chunk]
+
+            proxy: AsyncBatchingMapProxy[int, str] = AsyncBatchingMapProxy(batch_size=2)
+            asyncio.run(proxy.map([1, 1, 2], mapper))
+            # ['v:1', 'v:1', 'v:2']
+            ```
+        """
+        if await self.__all_cached(items):
+            return await self.__values(items)
+
+        unique_items = self._unique_in_order(items)
+        owned, wait_for = await self.__acquire_ownership(unique_items)
+
+        await self.__process_owned(owned, map_func)
+        await self.__wait_for(wait_for, map_func)
+
+        results = await self.__values(items)
+
+        # Remove None values from cache after retrieval to avoid persisting incomplete results
+        async with self._lock:
+            if self._cache:
+                for k in set(items):
+                    if self._cache.get(k, object()) is None:
+                        self._cache.pop(k, None)
+
+        return results
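To close, a usage sketch for the async class (editorial, not part of the diff; same import-path assumption as above). `__process_owned` fans batches out to `max_concurrency` consumer coroutines, and the semaphore created in `__post_init__` additionally bounds concurrent `map_func` calls across overlapping `map()` invocations:

```python
import asyncio

from openaivec._cache.proxy import AsyncBatchingMapProxy


async def mapper(xs: list[int]) -> list[str]:
    await asyncio.sleep(0.01)  # stand-in for real I/O such as an API call
    return [f"v:{x}" for x in xs]


async def main() -> None:
    proxy: AsyncBatchingMapProxy[int, str] = AsyncBatchingMapProxy(
        batch_size=4, max_concurrency=2, show_progress=False
    )
    out = await proxy.map(list(range(10)), mapper)
    print(out[:3])  # ['v:0', 'v:1', 'v:2']
    await proxy.aclose()  # clear the cache and release any waiters


asyncio.run(main())
```

One behavior shared by both classes is easy to miss in the hunk: `None` results are returned to the caller but evicted from the cache, so incomplete computations are retried on the next `map()` call rather than served stale. A minimal sketch under the same assumptions:

```python
from openaivec._cache.proxy import BatchingMapProxy

attempts = {"n": 0}


def flaky(xs: list[int]) -> list:
    attempts["n"] += 1
    # First call yields None (e.g. an unusable response); later calls succeed.
    return [None if attempts["n"] == 1 else f"v:{x}" for x in xs]


p: BatchingMapProxy[int, object] = BatchingMapProxy(batch_size=0, show_progress=False)
print(p.map([1], flaky))  # [None]  -- evicted from the cache afterwards
print(p.map([1], flaky))  # ['v:1'] -- recomputed on the second call
```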