openaivec 0.10.0__py3-none-any.whl → 1.0.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. openaivec/__init__.py +13 -4
  2. openaivec/_cache/__init__.py +12 -0
  3. openaivec/_cache/optimize.py +109 -0
  4. openaivec/_cache/proxy.py +806 -0
  5. openaivec/_di.py +326 -0
  6. openaivec/_embeddings.py +203 -0
  7. openaivec/{log.py → _log.py} +2 -2
  8. openaivec/_model.py +113 -0
  9. openaivec/{prompt.py → _prompt.py} +95 -28
  10. openaivec/_provider.py +207 -0
  11. openaivec/_responses.py +511 -0
  12. openaivec/_schema/__init__.py +9 -0
  13. openaivec/_schema/infer.py +340 -0
  14. openaivec/_schema/spec.py +350 -0
  15. openaivec/_serialize.py +234 -0
  16. openaivec/{util.py → _util.py} +25 -85
  17. openaivec/pandas_ext.py +1635 -425
  18. openaivec/spark.py +604 -335
  19. openaivec/task/__init__.py +27 -29
  20. openaivec/task/customer_support/__init__.py +9 -15
  21. openaivec/task/customer_support/customer_sentiment.py +51 -41
  22. openaivec/task/customer_support/inquiry_classification.py +86 -61
  23. openaivec/task/customer_support/inquiry_summary.py +44 -45
  24. openaivec/task/customer_support/intent_analysis.py +56 -41
  25. openaivec/task/customer_support/response_suggestion.py +49 -43
  26. openaivec/task/customer_support/urgency_analysis.py +76 -71
  27. openaivec/task/nlp/__init__.py +4 -4
  28. openaivec/task/nlp/dependency_parsing.py +19 -20
  29. openaivec/task/nlp/keyword_extraction.py +22 -24
  30. openaivec/task/nlp/morphological_analysis.py +25 -25
  31. openaivec/task/nlp/named_entity_recognition.py +26 -28
  32. openaivec/task/nlp/sentiment_analysis.py +29 -21
  33. openaivec/task/nlp/translation.py +24 -30
  34. openaivec/task/table/__init__.py +3 -0
  35. openaivec/task/table/fillna.py +183 -0
  36. openaivec-1.0.10.dist-info/METADATA +399 -0
  37. openaivec-1.0.10.dist-info/RECORD +39 -0
  38. {openaivec-0.10.0.dist-info → openaivec-1.0.10.dist-info}/WHEEL +1 -1
  39. openaivec/embeddings.py +0 -172
  40. openaivec/responses.py +0 -392
  41. openaivec/serialize.py +0 -225
  42. openaivec/task/model.py +0 -84
  43. openaivec-0.10.0.dist-info/METADATA +0 -546
  44. openaivec-0.10.0.dist-info/RECORD +0 -29
  45. {openaivec-0.10.0.dist-info → openaivec-1.0.10.dist-info}/licenses/LICENSE +0 -0
openaivec/_cache/proxy.py (new file)
@@ -0,0 +1,806 @@
+ import asyncio
+ import threading
+ from collections.abc import Awaitable, Callable, Hashable
+ from dataclasses import dataclass, field
+ from typing import Any, Generic, TypeVar
+
+ from openaivec._cache import BatchSizeSuggester
+
+ __all__ = []
+
+ S = TypeVar("S", bound=Hashable)
+ T = TypeVar("T")
+
+
+ class ProxyBase(Generic[S, T]):
+     """Common utilities shared by BatchingMapProxy and AsyncBatchingMapProxy.
+
+     Provides order-preserving deduplication, batch size normalization, and
+     progress-bar helpers that do not touch concurrency primitives.
+
+     Attributes:
+         batch_size: Optional mini-batch size hint used by implementations to
+             split work into chunks. When ``None``, implementations derive a
+             size from ``suggester``; when non-positive, they process the
+             entire input in a single call.
+     """
+
+     batch_size: int | None  # subclasses may override via dataclass
+     show_progress: bool  # Enable progress bar display
+     suggester: BatchSizeSuggester  # Batch size optimization, initialized by subclasses
+
+     def _is_notebook_environment(self) -> bool:
+         """Check if running in a Jupyter notebook environment.
+
+         Returns:
+             bool: True if running in a notebook, False otherwise.
+         """
+         try:
+             from IPython.core.getipython import get_ipython
+
+             ipython = get_ipython()
+             if ipython is not None:
+                 # Check for different notebook environments
+                 class_name = ipython.__class__.__name__
+                 module_name = ipython.__class__.__module__
+
+                 # Standard Jupyter notebook/lab
+                 if class_name == "ZMQInteractiveShell":
+                     return True
+
+                 # JupyterLab and newer environments
+                 if "zmq" in module_name.lower() or "jupyter" in module_name.lower():
+                     return True
+
+                 # Google Colab
+                 if "google.colab" in module_name:
+                     return True
+
+                 # VS Code notebooks and others
+                 if hasattr(ipython, "kernel"):
+                     return True
+
+         except ImportError:
+             pass
+
+         # Check for other notebook indicators
+         # Check for common notebook environment variables
+         import os
+         import sys
+
+         notebook_vars = [
+             "JUPYTER_CONFIG_DIR",
+             "JUPYTERLAB_DIR",
+             "COLAB_GPU",
+             "VSCODE_PID",  # VS Code
+         ]
+
+         for var in notebook_vars:
+             if var in os.environ:
+                 return True
+
+         # Check if running in IPython without terminal
+         if "IPython" in sys.modules:
+             try:
+                 # If we can import display from IPython, likely in notebook
+                 import importlib.util
+
+                 if importlib.util.find_spec("IPython.display") is not None:
+                     return True
+             except ImportError:
+                 pass
+
+         return False
+
+     def _create_progress_bar(self, total: int, desc: str = "Processing batches") -> Any:
+         """Create a progress bar if conditions are met.
+
+         Args:
+             total (int): Total number of items to process.
+             desc (str): Description for the progress bar.
+
+         Returns:
+             Any: Progress bar instance or None if not available.
+         """
+         try:
+             from tqdm.auto import tqdm as tqdm_progress
+
+             if self.show_progress and self._is_notebook_environment():
+                 return tqdm_progress(total=total, desc=desc, unit="item")
+         except ImportError:
+             pass
+         return None
+
+     def _update_progress_bar(self, progress_bar: Any, increment: int) -> None:
+         """Update progress bar with the given increment.
+
+         Args:
+             progress_bar (Any): Progress bar instance.
+             increment (int): Number of items to increment.
+         """
+         if progress_bar:
+             progress_bar.update(increment)
+
+     def _close_progress_bar(self, progress_bar: Any) -> None:
+         """Close the progress bar.
+
+         Args:
+             progress_bar (Any): Progress bar instance.
+         """
+         if progress_bar:
+             progress_bar.close()
+
+     @staticmethod
+     def _unique_in_order(seq: list[S]) -> list[S]:
+         """Return unique items preserving their first-occurrence order.
+
+         Args:
+             seq (list[S]): Sequence of items which may contain duplicates.
+
+         Returns:
+             list[S]: A new list containing each distinct item from ``seq`` exactly
+                 once, in the order of their first occurrence.
+         """
+         seen: set[S] = set()
+         out: list[S] = []
+         for x in seq:
+             if x not in seen:
+                 seen.add(x)
+                 out.append(x)
+         return out
+
+     def _normalized_batch_size(self, total: int) -> int:
+         """Compute the effective batch size used for processing.
+
+         If ``batch_size`` is None, use the suggester to determine optimal batch size.
+         If ``batch_size`` is non-positive, process the entire ``total`` in a single call.
+
+         Args:
+             total (int): Number of items intended to be processed.
+
+         Returns:
+             int: The positive batch size to use.
+         """
+         if self.batch_size and self.batch_size > 0:
+             return self.batch_size
+         elif self.batch_size is None:
+             # Use suggester to determine optimal batch size
+             suggested = self.suggester.suggest_batch_size()
+             return min(suggested, total)  # Don't exceed total items
+         else:
+             # batch_size is 0 or negative, process all at once
+             return total
+
+
+ @dataclass
+ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
+     """Thread-safe local proxy that caches results of a mapping function.
+
+     This proxy batches calls to the ``map_func`` passed to ``map()`` (when
+     ``batch_size`` is set), deduplicates inputs while preserving order, and
+     ensures that concurrent calls do not duplicate work via an in-flight
+     registry. Output order always matches input order, and redundant requests
+     are kept to a minimum.
+
+     When ``batch_size=None``, automatic batch size optimization is enabled,
+     dynamically adjusting batch sizes based on execution time to maintain
+     optimal performance (targeting 30-60 seconds per batch).
+
+     Example:
+         ```python
+         p = BatchingMapProxy[int, str](batch_size=3)
+
+         def f(xs: list[int]) -> list[str]:
+             return [f"v:{x}" for x in xs]
+
+         p.map([1, 2, 2, 3, 4], f)
+         # ['v:1', 'v:2', 'v:2', 'v:3', 'v:4']
+         ```
+     """
+
+     # Number of items to process per call to map_func.
+     # - If None (default): Enables automatic batch size optimization, dynamically adjusting
+     #   based on execution time (targeting 30-60 seconds per batch)
+     # - If positive integer: Fixed batch size
+     # - If <= 0: Process all items at once
+     batch_size: int | None = None
+     show_progress: bool = True
+     suggester: BatchSizeSuggester = field(default_factory=BatchSizeSuggester, repr=False)
+
+     # internals
+     _cache: dict[S, T] = field(default_factory=dict)
+     _lock: threading.RLock = field(default_factory=threading.RLock, repr=False)
+     _inflight: dict[S, threading.Event] = field(default_factory=dict, repr=False)
+
+     def __all_cached(self, items: list[S]) -> bool:
+         """Check whether all items are present in the cache.
+
+         This method acquires the internal lock to perform a consistent check.
+
+         Args:
+             items (list[S]): Items to verify against the cache.
+
+         Returns:
+             bool: True if every item is already cached, False otherwise.
+         """
+         with self._lock:
+             return all(x in self._cache for x in items)
+
+     def __values(self, items: list[S]) -> list[T]:
+         """Fetch cached values for ``items`` preserving the given order.
+
+         This method acquires the internal lock while reading the cache.
+
+         Args:
+             items (list[S]): Items to retrieve from the cache.
+
+         Returns:
+             list[T]: The cached values corresponding to ``items`` in the same
+                 order.
+         """
+         with self._lock:
+             return [self._cache[x] for x in items]
+
+     def __acquire_ownership(self, items: list[S]) -> tuple[list[S], list[S]]:
+         """Acquire ownership for missing items and identify keys to wait for.
+
+         For each unique item, if it's already cached, it is ignored. If it's
+         currently being computed by another thread (in-flight), it is added to
+         the wait list. Otherwise, this method marks the key as in-flight and
+         considers it "owned" by the current thread.
+
+         Args:
+             items (list[S]): Unique items (order-preserving) to be processed.
+
+         Returns:
+             tuple[list[S], list[S]]: A tuple ``(owned, wait_for)`` where
+                 - ``owned`` are items this thread is responsible for computing.
+                 - ``wait_for`` are items that another thread is already computing.
+         """
+         owned: list[S] = []
+         wait_for: list[S] = []
+         with self._lock:
+             for x in items:
+                 if x in self._cache:
+                     continue
+                 if x in self._inflight:
+                     wait_for.append(x)
+                 else:
+                     self._inflight[x] = threading.Event()
+                     owned.append(x)
+         return owned, wait_for
+
+     def __finalize_success(self, to_call: list[S], results: list[T]) -> None:
+         """Populate cache with results and signal completion events.
+
+         Args:
+             to_call (list[S]): Items that were computed.
+             results (list[T]): Results corresponding to ``to_call`` in order.
+         """
+         if len(results) != len(to_call):
+             # Prevent deadlocks if map_func violates the contract.
+             # Release waiters and surface a clear error.
+             self.__finalize_failure(to_call)
+             raise ValueError("map_func must return a list of results with the same length and order as inputs")
+         with self._lock:
+             for x, y in zip(to_call, results):
+                 self._cache[x] = y
+                 ev = self._inflight.pop(x, None)
+                 if ev:
+                     ev.set()
+
+     def __finalize_failure(self, to_call: list[S]) -> None:
+         """Release in-flight events on failure to avoid deadlocks.
+
+         Args:
+             to_call (list[S]): Items that were intended to be computed when an
+                 error occurred.
+         """
+         with self._lock:
+             for x in to_call:
+                 ev = self._inflight.pop(x, None)
+                 if ev:
+                     ev.set()
+
+     def clear(self) -> None:
+         """Clear all cached results and release any in-flight waiters.
+
+         Notes:
+             - Intended to be called after all processing is finished.
+             - Do not call concurrently with active map() calls to avoid
+               unnecessary recomputation or racy wake-ups.
+         """
+         with self._lock:
+             for ev in self._inflight.values():
+                 ev.set()
+             self._inflight.clear()
+             self._cache.clear()
+
+     def close(self) -> None:
+         """Alias for clear()."""
+         self.clear()
+
+     def __process_owned(self, owned: list[S], map_func: Callable[[list[S]], list[T]]) -> None:
+         """Process owned items in mini-batches and fill the cache.
+
+         Before calling ``map_func`` for each batch, the cache is re-checked
+         to skip any items that may have been filled in the meantime. Items
+         are accumulated across multiple original batches to maximize batch
+         size utilization when some items are cached. On exceptions raised
+         by ``map_func``, all corresponding in-flight events are released
+         to prevent deadlocks, and the exception is propagated.
+
+         Args:
+             owned (list[S]): Items for which the current thread has computation
+                 ownership.
+             map_func (Callable[[list[S]], list[T]]): Function that maps a batch
+                 of items to results, preserving input order.
+
+         Raises:
+             Exception: Propagates any exception raised by ``map_func``.
+         """
+         if not owned:
+             return
+         # Setup progress bar
+         progress_bar = self._create_progress_bar(len(owned))
+
+         # Accumulate uncached items to maximize batch size utilization
+         pending_to_call: list[S] = []
+
+         i = 0
+         while i < len(owned):
+             # Get dynamic batch size for each iteration
+             current_batch_size = self._normalized_batch_size(len(owned))
+             batch = owned[i : i + current_batch_size]
+             # Double-check cache right before processing
+             with self._lock:
+                 uncached_in_batch = [x for x in batch if x not in self._cache]
+
+             pending_to_call.extend(uncached_in_batch)
+
+             # Process accumulated items when we reach batch_size or at the end
+             is_last_batch = i + current_batch_size >= len(owned)
+             if len(pending_to_call) >= current_batch_size or (is_last_batch and pending_to_call):
+                 # Take up to batch_size items to process
+                 to_call = pending_to_call[:current_batch_size]
+                 pending_to_call = pending_to_call[current_batch_size:]
+
+                 try:
+                     # Always measure execution time using suggester
+                     with self.suggester.record(len(to_call)):
+                         results = map_func(to_call)
+                 except Exception:
+                     self.__finalize_failure(to_call)
+                     raise
+                 self.__finalize_success(to_call, results)
+
+                 # Update progress bar
+                 self._update_progress_bar(progress_bar, len(to_call))
+
+             # Move to next batch
+             i += current_batch_size
+
+         # Process any remaining items
+         while pending_to_call:
+             # Get dynamic batch size for remaining items
+             remaining_batch_size = self._normalized_batch_size(len(pending_to_call))
+             to_call = pending_to_call[:remaining_batch_size]
+             pending_to_call = pending_to_call[remaining_batch_size:]
+
+             try:
+                 with self.suggester.record(len(to_call)):
+                     results = map_func(to_call)
+             except Exception:
+                 self.__finalize_failure(to_call)
+                 raise
+             self.__finalize_success(to_call, results)
+
+             # Update progress bar
+             self._update_progress_bar(progress_bar, len(to_call))
+
+         # Close progress bar
+         self._close_progress_bar(progress_bar)
+
+     def __wait_for(self, keys: list[S], map_func: Callable[[list[S]], list[T]]) -> None:
+         """Wait for other threads to complete computations for the given keys.
+
+         If a key is neither cached nor in-flight, this method claims ownership
+         for that key immediately (registers an in-flight Event) and defers the
+         computation so that all such rescued keys can be processed together in a
+         single batched call to ``map_func`` after the scan completes. This avoids
+         high-cost single-item calls.
+
+         Args:
+             keys (list[S]): Items whose computations are owned by other threads.
+             map_func (Callable[[list[S]], list[T]]): Function used to compute any
+                 rescued keys in a single batched call.
+         """
+         rescued: list[S] = []  # keys we claim to batch-process
+         for x in keys:
+             while True:
+                 with self._lock:
+                     if x in self._cache:
+                         break
+                     ev = self._inflight.get(x)
+                     if ev is None:
+                         # Not cached and no one computing; claim ownership to batch later.
+                         self._inflight[x] = threading.Event()
+                         rescued.append(x)
+                         break
+                 # Someone else is computing; wait for completion.
+                 ev.wait()
+         # Batch-process rescued keys, if any
+         if rescued:
+             try:
+                 self.__process_owned(rescued, map_func)
+             except Exception:
+                 # Ensure events are released on failure to avoid deadlock
+                 self.__finalize_failure(rescued)
+                 raise
+
+     # ---- public API ------------------------------------------------------
+     def map(self, items: list[S], map_func: Callable[[list[S]], list[T]]) -> list[T]:
+         """Map ``items`` to values using caching and optional mini-batching.
+
+         This method is thread-safe. It deduplicates inputs while preserving order,
+         coordinates concurrent work to prevent duplicate computation, and processes
+         owned items in mini-batches determined by ``batch_size``. Before each batch
+         call to ``map_func``, the cache is re-checked to avoid redundant requests.
+
+         Args:
+             items (list[S]): Input items to map.
+             map_func (Callable[[list[S]], list[T]]): Function that maps a batch of
+                 items to their corresponding results. Must return results in the
+                 same order as inputs.
+
+         Returns:
+             list[T]: Mapped values corresponding to ``items`` in the same order.
+
+         Raises:
+             Exception: Propagates any exception raised by ``map_func``.
+
+         Example:
+             ```python
+             proxy: BatchingMapProxy[int, str] = BatchingMapProxy(batch_size=2)
+             calls: list[list[int]] = []
+
+             def mapper(chunk: list[int]) -> list[str]:
+                 calls.append(chunk)
+                 return [f"v:{x}" for x in chunk]
+
+             proxy.map([1, 2, 2, 3], mapper)
+             # ['v:1', 'v:2', 'v:2', 'v:3']
+             calls  # duplicate ``2`` is only computed once
+             # [[1, 2], [3]]
+             ```
+         """
+         if self.__all_cached(items):
+             return self.__values(items)
+
+         unique_items = self._unique_in_order(items)
+         owned, wait_for = self.__acquire_ownership(unique_items)
+
+         self.__process_owned(owned, map_func)
+         self.__wait_for(wait_for, map_func)
+
+         # Fetch results before purging None entries
+         results = self.__values(items)
+
+         # Remove None values from cache so they are recomputed on future calls
+         with self._lock:
+             if self._cache:  # micro-optimization
+                 for k in set(items):
+                     try:
+                         if self._cache.get(k, object()) is None:
+                             del self._cache[k]
+                     except KeyError:
+                         pass
+
+         return results
+
+
+ @dataclass
+ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
+     """Asynchronous version of BatchingMapProxy for use with async functions.
+
+     The ``map()`` method accepts an async ``map_func`` that may perform I/O,
+     awaiting it in mini-batches. It deduplicates inputs, maintains cache
+     consistency, and coordinates concurrent coroutines to avoid duplicate work
+     via an in-flight registry of asyncio events.
+
+     When ``batch_size=None``, automatic batch size optimization is enabled,
+     dynamically adjusting batch sizes based on execution time to maintain
+     optimal performance (targeting 30-60 seconds per batch).
+
+     Example:
+         ```python
+         import asyncio
+
+         p = AsyncBatchingMapProxy[int, str](batch_size=2)
+
+         async def af(xs: list[int]) -> list[str]:
+             await asyncio.sleep(0)
+             return [f"v:{x}" for x in xs]
+
+         async def run():
+             return await p.map([1, 2, 3], af)
+
+         asyncio.run(run())
+         # ['v:1', 'v:2', 'v:3']
+         ```
+     """
+
+     # Number of items to process per call to map_func.
+     # - If None (default): Enables automatic batch size optimization, dynamically adjusting
+     #   based on execution time (targeting 30-60 seconds per batch)
+     # - If positive integer: Fixed batch size
+     # - If <= 0: Process all items at once
+     batch_size: int | None = None
+     max_concurrency: int = 8
+     show_progress: bool = True
+     suggester: BatchSizeSuggester = field(default_factory=BatchSizeSuggester, repr=False)
+
+     # internals
+     _cache: dict[S, T] = field(default_factory=dict, repr=False)
+     _lock: asyncio.Lock = field(default_factory=asyncio.Lock, repr=False)
+     _inflight: dict[S, asyncio.Event] = field(default_factory=dict, repr=False)
+     __sema: asyncio.Semaphore | None = field(default=None, init=False, repr=False)
+
+     def __post_init__(self) -> None:
+         """Initialize internal semaphore based on ``max_concurrency``.
+
+         If ``max_concurrency`` is a positive integer, an ``asyncio.Semaphore``
+         is created to limit the number of concurrent ``map_func`` calls across
+         overlapping ``map`` invocations. When non-positive or ``None``, no
+         semaphore is used and concurrency is unrestricted by this proxy.
+
+         Notes:
+             This method is invoked automatically by ``dataclasses`` after
+             initialization and does not need to be called directly.
+         """
+         # Initialize semaphore if limiting is requested; non-positive disables limiting
+         if self.max_concurrency and self.max_concurrency > 0:
+             self.__sema = asyncio.Semaphore(self.max_concurrency)
+         else:
+             self.__sema = None
+
+     async def __all_cached(self, items: list[S]) -> bool:
+         """Check whether all items are present in the cache.
+
+         This method acquires the internal asyncio lock for a consistent view
+         of the cache.
+
+         Args:
+             items (list[S]): Items to verify against the cache.
+
+         Returns:
+             bool: True if every item in ``items`` is already cached, False otherwise.
+         """
+         async with self._lock:
+             return all(x in self._cache for x in items)
+
+     async def __values(self, items: list[S]) -> list[T]:
+         """Get cached values for ``items`` preserving their given order.
+
+         The internal asyncio lock is held while reading the cache to preserve
+         consistency under concurrency.
+
+         Args:
+             items (list[S]): Items to read from the cache.
+
+         Returns:
+             list[T]: Cached values corresponding to ``items`` in the same order.
+         """
+         async with self._lock:
+             return [self._cache[x] for x in items]
+
+     async def __acquire_ownership(self, items: list[S]) -> tuple[list[S], list[S]]:
+         """Acquire ownership for missing keys and identify keys to wait for.
+
+         Args:
+             items (list[S]): Unique items (order-preserving) to be processed.
+
+         Returns:
+             tuple[list[S], list[S]]: A tuple ``(owned, wait_for)`` where owned are
+                 keys this coroutine should compute, and wait_for are keys currently
+                 being computed elsewhere.
+         """
+         owned: list[S] = []
+         wait_for: list[S] = []
+         async with self._lock:
+             for x in items:
+                 if x in self._cache:
+                     continue
+                 if x in self._inflight:
+                     wait_for.append(x)
+                 else:
+                     self._inflight[x] = asyncio.Event()
+                     owned.append(x)
+         return owned, wait_for
+
+     async def __finalize_success(self, to_call: list[S], results: list[T]) -> None:
+         """Populate cache and signal completion for successfully computed keys.
+
+         Args:
+             to_call (list[S]): Items that were computed in the recent batch.
+             results (list[T]): Results corresponding to ``to_call`` in order.
+         """
+         if len(results) != len(to_call):
+             # Prevent deadlocks if map_func violates the contract.
+             await self.__finalize_failure(to_call)
+             raise ValueError("map_func must return a list of results with the same length and order as inputs")
+         async with self._lock:
+             for x, y in zip(to_call, results):
+                 self._cache[x] = y
+                 ev = self._inflight.pop(x, None)
+                 if ev:
+                     ev.set()
+
+     async def __finalize_failure(self, to_call: list[S]) -> None:
+         """Release in-flight events on failure to avoid deadlocks.
+
+         Args:
+             to_call (list[S]): Items whose computation failed; their waiters will
+                 be released.
+         """
+         async with self._lock:
+             for x in to_call:
+                 ev = self._inflight.pop(x, None)
+                 if ev:
+                     ev.set()
+
+     async def clear(self) -> None:
+         """Clear all cached results and release any in-flight waiters.
+
+         Notes:
+             - Intended to be awaited after all processing is finished.
+             - Do not call concurrently with active map() calls to avoid
+               unnecessary recomputation or racy wake-ups.
+         """
+         async with self._lock:
+             for ev in self._inflight.values():
+                 ev.set()
+             self._inflight.clear()
+             self._cache.clear()
+
+     async def aclose(self) -> None:
+         """Alias for clear()."""
+         await self.clear()
+
+     async def __process_owned(self, owned: list[S], map_func: Callable[[list[S]], Awaitable[list[T]]]) -> None:
+         """Process owned keys using a producer-consumer pattern with dynamic batch sizing.
+
+         Args:
+             owned (list[S]): Items for which this coroutine holds computation ownership.
+             map_func (Callable[[list[S]], Awaitable[list[T]]]): Async function that
+                 maps a batch of items to results, preserving input order.
+
+         Raises:
+             Exception: Propagates any exception raised by ``map_func``.
+         """
+         if not owned:
+             return
+
+         progress_bar = self._create_progress_bar(len(owned))
+         batch_queue: asyncio.Queue = asyncio.Queue(maxsize=self.max_concurrency)
+
+         async def producer():
+             index = 0
+             while index < len(owned):
+                 batch_size = self._normalized_batch_size(len(owned) - index)
+                 batch = owned[index : index + batch_size]
+                 await batch_queue.put(batch)
+                 index += batch_size
+             # Send completion signals
+             for _ in range(self.max_concurrency):
+                 await batch_queue.put(None)
+
+         async def consumer():
+             while True:
+                 batch = await batch_queue.get()
+                 try:
+                     if batch is None:
+                         break
+                     await self.__process_single_batch(batch, map_func, progress_bar)
+                 finally:
+                     batch_queue.task_done()
+
+         await asyncio.gather(producer(), *[consumer() for _ in range(self.max_concurrency)])
+
+         self._close_progress_bar(progress_bar)
+
+     async def __process_single_batch(
+         self, to_call: list[S], map_func: Callable[[list[S]], Awaitable[list[T]]], progress_bar
+     ) -> None:
+         """Process a single batch with semaphore control."""
+         acquired = False
+         try:
+             if self.__sema:
+                 await self.__sema.acquire()
+                 acquired = True
+             # Measure async map_func execution using suggester
+             with self.suggester.record(len(to_call)):
+                 results = await map_func(to_call)
+         except Exception:
+             await self.__finalize_failure(to_call)
+             raise
+         finally:
+             if self.__sema and acquired:
+                 self.__sema.release()
+         await self.__finalize_success(to_call, results)
+
+         # Update progress bar
+         self._update_progress_bar(progress_bar, len(to_call))
+
+     async def __wait_for(self, keys: list[S], map_func: Callable[[list[S]], Awaitable[list[T]]]) -> None:
+         """Wait for computations owned by other coroutines to complete.
+
+         If a key is neither cached nor in-flight, this method claims ownership
+         for that key immediately (registers an in-flight Event) and defers the
+         computation so that all such rescued keys can be processed together in a
+         single batched call to ``map_func`` after the scan completes. This avoids
+         high-cost single-item calls.
+
+         Args:
+             keys (list[S]): Items whose computations are owned by other coroutines.
+             map_func (Callable[[list[S]], Awaitable[list[T]]]): Async function used
+                 to compute any rescued keys in a single batched call.
+         """
+         rescued: list[S] = []  # keys we claim to batch-process
+         for x in keys:
+             while True:
+                 async with self._lock:
+                     if x in self._cache:
+                         break
+                     ev = self._inflight.get(x)
+                     if ev is None:
+                         # Not cached and no one computing; claim ownership to batch later.
+                         self._inflight[x] = asyncio.Event()
+                         rescued.append(x)
+                         break
+                 # Someone else is computing; wait for completion.
+                 await ev.wait()
+         # Batch-process rescued keys, if any
+         if rescued:
+             try:
+                 await self.__process_owned(rescued, map_func)
+             except Exception:
+                 await self.__finalize_failure(rescued)
+                 raise
+
+     # ---- public API ------------------------------------------------------
+     async def map(self, items: list[S], map_func: Callable[[list[S]], Awaitable[list[T]]]) -> list[T]:
+         """Async map with caching, de-duplication, and optional mini-batching.
+
+         Args:
+             items (list[S]): Input items to map.
+             map_func (Callable[[list[S]], Awaitable[list[T]]]): Async function that
+                 maps a batch of items to their results, preserving input order.
+
+         Returns:
+             list[T]: Mapped values corresponding to ``items`` in the same order.
+
+         Example:
+             ```python
+             import asyncio
+
+             async def mapper(chunk: list[int]) -> list[str]:
+                 await asyncio.sleep(0)
+                 return [f"v:{x}" for x in chunk]
+
+             proxy: AsyncBatchingMapProxy[int, str] = AsyncBatchingMapProxy(batch_size=2)
+             asyncio.run(proxy.map([1, 1, 2], mapper))
+             # ['v:1', 'v:1', 'v:2']
+             ```
+         """
+         if await self.__all_cached(items):
+             return await self.__values(items)
+
+         unique_items = self._unique_in_order(items)
+         owned, wait_for = await self.__acquire_ownership(unique_items)
+
+         await self.__process_owned(owned, map_func)
+         await self.__wait_for(wait_for, map_func)
+
+         results = await self.__values(items)
+
+         # Remove None values from cache after retrieval to avoid persisting incomplete results
+         async with self._lock:
+             if self._cache:
+                 for k in set(items):
+                     if self._cache.get(k, object()) is None:
+                         self._cache.pop(k, None)
+
+         return results
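
The proxies above are internal plumbing (the module exports nothing via `__all__`), but their coordination contract is easy to exercise in isolation. Below is a minimal sketch of the cross-thread de-duplication implemented by `__acquire_ownership` and `__wait_for`, assuming the internal import path `openaivec._cache.proxy` as shipped in this wheel; `show_progress=False` just keeps tqdm out of the picture.

```python
import threading

from openaivec._cache.proxy import BatchingMapProxy

calls: list[list[int]] = []
calls_lock = threading.Lock()

def mapper(chunk: list[int]) -> list[str]:
    with calls_lock:
        calls.append(list(chunk))  # record every batch actually computed
    return [f"v:{x}" for x in chunk]

proxy: BatchingMapProxy[int, str] = BatchingMapProxy(batch_size=2, show_progress=False)

# Two threads request overlapping keys; the in-flight registry ensures
# each unique key is computed by exactly one thread.
t1 = threading.Thread(target=proxy.map, args=([1, 2, 3], mapper))
t2 = threading.Thread(target=proxy.map, args=([2, 3, 4], mapper))
t1.start(); t2.start()
t1.join(); t2.join()

computed = sorted(x for chunk in calls for x in chunk)
assert computed == [1, 2, 3, 4]  # no key was computed twice
```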
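A similar sketch for `AsyncBatchingMapProxy` shows the semaphore installed in `__post_init__` capping concurrent `map_func` calls. The `in_flight`/`peak` counters are illustrative instrumentation, not part of the package:

```python
import asyncio

from openaivec._cache.proxy import AsyncBatchingMapProxy

in_flight = 0
peak = 0

async def mapper(chunk: list[int]) -> list[str]:
    global in_flight, peak
    in_flight += 1
    peak = max(peak, in_flight)
    await asyncio.sleep(0.05)  # simulate I/O latency
    in_flight -= 1
    return [str(x * x) for x in chunk]

async def main() -> None:
    proxy: AsyncBatchingMapProxy[int, str] = AsyncBatchingMapProxy(
        batch_size=2, max_concurrency=2, show_progress=False
    )
    print(await proxy.map(list(range(6)), mapper))
    # ['0', '1', '4', '9', '16', '25']
    print("peak concurrent batches:", peak)  # never exceeds max_concurrency

asyncio.run(main())
```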
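Both `map()` implementations purge `None` results from the cache after returning them, so a transiently failed item is recomputed on the next call while its successful neighbors stay cached. A hypothetical `flaky` mapper makes the effect visible:

```python
from openaivec._cache.proxy import BatchingMapProxy

attempts: dict[int, int] = {}

def flaky(chunk: list[int]) -> list[str | None]:
    out: list[str | None] = []
    for x in chunk:
        attempts[x] = attempts.get(x, 0) + 1
        # Return None for key 2 on its first attempt only.
        out.append(None if x == 2 and attempts[x] == 1 else f"v:{x}")
    return out

proxy: BatchingMapProxy[int, str | None] = BatchingMapProxy(batch_size=8, show_progress=False)
print(proxy.map([1, 2], flaky))  # ['v:1', None]
print(proxy.map([1, 2], flaky))  # ['v:1', 'v:2']; 1 came from cache, 2 was retried
```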
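Finally, `_normalized_batch_size` gives `batch_size` three regimes: a positive value fixes the chunk size, a non-positive value sends everything in one call, and `None` (the default) defers to `BatchSizeSuggester`, which the docstrings describe as targeting 30-60 seconds per batch. A quick sketch of the two deterministic regimes:

```python
from openaivec._cache.proxy import BatchingMapProxy

seen: list[int] = []

def mapper(chunk: list[int]) -> list[int]:
    seen.append(len(chunk))  # record the size of each chunk map_func receives
    return chunk

BatchingMapProxy[int, int](batch_size=0, show_progress=False).map(list(range(5)), mapper)
print(seen)  # [5]: non-positive batch_size processes everything at once

seen.clear()
BatchingMapProxy[int, int](batch_size=2, show_progress=False).map(list(range(5)), mapper)
print(seen)  # [2, 2, 1]: fixed chunks of two
```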