openaivec 0.14.7__py3-none-any.whl → 0.14.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openaivec/_di.py +10 -9
- openaivec/_embeddings.py +12 -13
- openaivec/_log.py +1 -1
- openaivec/_model.py +3 -3
- openaivec/_optimize.py +3 -4
- openaivec/_prompt.py +4 -5
- openaivec/_proxy.py +34 -35
- openaivec/_responses.py +29 -29
- openaivec/_schema.py +56 -18
- openaivec/_serialize.py +19 -15
- openaivec/_util.py +9 -8
- openaivec/pandas_ext.py +20 -19
- openaivec/spark.py +11 -10
- openaivec/task/customer_support/customer_sentiment.py +2 -2
- openaivec/task/customer_support/inquiry_classification.py +8 -8
- openaivec/task/customer_support/inquiry_summary.py +4 -4
- openaivec/task/customer_support/intent_analysis.py +5 -5
- openaivec/task/customer_support/response_suggestion.py +4 -4
- openaivec/task/customer_support/urgency_analysis.py +9 -9
- openaivec/task/nlp/dependency_parsing.py +2 -4
- openaivec/task/nlp/keyword_extraction.py +3 -5
- openaivec/task/nlp/morphological_analysis.py +4 -6
- openaivec/task/nlp/named_entity_recognition.py +7 -9
- openaivec/task/nlp/sentiment_analysis.py +3 -3
- openaivec/task/nlp/translation.py +1 -2
- openaivec/task/table/fillna.py +2 -3
- {openaivec-0.14.7.dist-info → openaivec-0.14.8.dist-info}/METADATA +1 -1
- openaivec-0.14.8.dist-info/RECORD +36 -0
- openaivec-0.14.7.dist-info/RECORD +0 -36
- {openaivec-0.14.7.dist-info → openaivec-0.14.8.dist-info}/WHEEL +0 -0
- {openaivec-0.14.7.dist-info → openaivec-0.14.8.dist-info}/licenses/LICENSE +0 -0
openaivec/_di.py
CHANGED
@@ -1,6 +1,7 @@
+from collections.abc import Callable
 from dataclasses import dataclass, field
 from threading import RLock
-from typing import Any, …
+from typing import Any, TypeVar

 __all__ = []

@@ -119,12 +120,12 @@ class Container:
     ```
     """

-    _instances: …
-    _providers: …
+    _instances: dict[type[Any], Any] = field(default_factory=dict)
+    _providers: dict[type[Any], Provider[Any]] = field(default_factory=dict)
     _lock: RLock = field(default_factory=RLock)
-    _resolving: …
+    _resolving: set[type[Any]] = field(default_factory=set)

-    def register(self, cls: …
+    def register(self, cls: type[T], provider: Provider[T]) -> None:
         """Register a provider function for a service type.

         The provider function will be called once to create the singleton instance
@@ -150,7 +151,7 @@ class Container:

         self._providers[cls] = provider

-    def register_instance(self, cls: …
+    def register_instance(self, cls: type[T], instance: T) -> None:
         """Register a pre-created instance for a service type.

         The provided instance will be stored directly in the container and returned
@@ -178,7 +179,7 @@ class Container:
         self._instances[cls] = instance
         self._providers[cls] = lambda: instance

-    def resolve(self, cls: …
+    def resolve(self, cls: type[T]) -> T:
         """Resolve a service instance, creating it if necessary.

         Returns the singleton instance for the requested service type. If this is
@@ -232,7 +233,7 @@ class Container:
         finally:
             self._resolving.discard(cls)

-    def is_registered(self, cls: …
+    def is_registered(self, cls: type[Any]) -> bool:
         """Check if a service type is registered in the container.

         Args:
@@ -252,7 +253,7 @@ class Container:
         with self._lock:
             return cls in self._providers

-    def unregister(self, cls: …
+    def unregister(self, cls: type[Any]) -> None:
         """Unregister a service type from the container.

         Removes the provider function and any cached singleton instance for
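The retyped `Container` API keeps its documented singleton contract: a provider registered for a type is called once, and later `resolve` calls return the cached instance. A minimal sketch of that behavior (illustrative only: `_di` is a private module, and `AppConfig` is a made-up example class, not part of openaivec):

```python
from openaivec._di import Container  # private module; import shown for illustration only

class AppConfig:  # hypothetical service type, not part of openaivec
    def __init__(self) -> None:
        self.model_name = "gpt-4.1-mini"

container = Container()
container.register(AppConfig, AppConfig)  # provider is invoked lazily, once

first = container.resolve(AppConfig)      # provider runs here
second = container.resolve(AppConfig)     # cached singleton is returned
assert first is second

container.unregister(AppConfig)           # drops provider and cached instance
assert not container.is_registered(AppConfig)
```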
openaivec/_embeddings.py
CHANGED
@@ -1,6 +1,5 @@
 from dataclasses import dataclass, field
 from logging import Logger, getLogger
-from typing import List

 import numpy as np
 from numpy.typing import NDArray
@@ -50,7 +49,7 @@ class BatchEmbeddings:

     @observe(_LOGGER)
     @backoff(exceptions=[RateLimitError, InternalServerError], scale=1, max_retries=12)
-    def _embed_chunk(self, inputs: …
+    def _embed_chunk(self, inputs: list[str]) -> list[NDArray[np.float32]]:
         """Embed one minibatch of strings.

         This private helper is the unit of work used by the map/parallel
@@ -58,23 +57,23 @@ class BatchEmbeddings:
         ``openai.RateLimitError`` is raised.

         Args:
-            inputs (…
+            inputs (list[str]): Input strings to be embedded. Duplicates allowed.

         Returns:
-            …
+            list[NDArray[np.float32]]: Embedding vectors aligned to ``inputs``.
         """
         responses = self.client.embeddings.create(input=inputs, model=self.model_name)
         return [np.array(d.embedding, dtype=np.float32) for d in responses.data]

     @observe(_LOGGER)
-    def create(self, inputs: …
+    def create(self, inputs: list[str]) -> list[NDArray[np.float32]]:
         """Generate embeddings for inputs using cached, ordered batching.

         Args:
-            inputs (…
+            inputs (list[str]): Input strings. Duplicates allowed.

         Returns:
-            …
+            list[NDArray[np.float32]]: Embedding vectors aligned to ``inputs``.
         """
         return self.cache.map(inputs, self._embed_chunk)

@@ -159,7 +158,7 @@ class AsyncBatchEmbeddings:

     @backoff_async(exceptions=[RateLimitError, InternalServerError], scale=1, max_retries=12)
     @observe(_LOGGER)
-    async def _embed_chunk(self, inputs: …
+    async def _embed_chunk(self, inputs: list[str]) -> list[NDArray[np.float32]]:
         """Embed one minibatch of strings asynchronously.

         This private helper handles the actual API call for a batch of inputs.
@@ -167,10 +166,10 @@ class AsyncBatchEmbeddings:
         is raised.

         Args:
-            inputs (…
+            inputs (list[str]): Input strings to be embedded. Duplicates allowed.

         Returns:
-            …
+            list[NDArray[np.float32]]: Embedding vectors aligned to ``inputs``.

         Raises:
             RateLimitError: Propagated if retries are exhausted.
@@ -179,13 +178,13 @@ class AsyncBatchEmbeddings:
         return [np.array(d.embedding, dtype=np.float32) for d in responses.data]

     @observe(_LOGGER)
-    async def create(self, inputs: …
+    async def create(self, inputs: list[str]) -> list[NDArray[np.float32]]:
         """Generate embeddings for inputs using proxy batching (async).

         Args:
-            inputs (…
+            inputs (list[str]): Input strings. Duplicates allowed.

         Returns:
-            …
+            list[NDArray[np.float32]]: Embedding vectors aligned to ``inputs``.
         """
         return await self.cache.map(inputs, self._embed_chunk)  # type: ignore[arg-type]
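Both `create` variants document the same contract: output order matches input order, and duplicate inputs are served from the cache rather than re-sent to the API. A sketch of how that reads at a call site (the constructor arguments here are inferred from the attributes used in the diff, `self.client` and `self.model_name`, and should be checked against the class definition):

```python
import numpy as np
from openai import OpenAI
from openaivec._embeddings import BatchEmbeddings  # private module; illustrative import

# Constructor signature assumed from the fields referenced above; verify before use.
embedder = BatchEmbeddings(client=OpenAI(), model_name="text-embedding-3-small")

vectors = embedder.create(["cat", "dog", "cat"])   # duplicates are allowed
assert len(vectors) == 3                           # aligned to input order
assert np.array_equal(vectors[0], vectors[2])      # duplicate served from cache
```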
openaivec/_log.py
CHANGED
openaivec/_model.py
CHANGED
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Generic, …
+from typing import Generic, TypeVar

 __all__ = [
     "PreparedTask",
@@ -20,7 +20,7 @@ class PreparedTask(Generic[ResponseFormat]):
     Attributes:
         instructions (str): The prompt or instructions to send to the OpenAI model.
             This should contain clear, specific directions for the task.
-        response_format (…
+        response_format (type[ResponseFormat]): A Pydantic model class or str type that defines the expected
             structure of the response. Can be either a BaseModel subclass or str.
         temperature (float): Controls randomness in the model's output.
             Range: 0.0 to 1.0. Lower values make output more deterministic.
@@ -54,7 +54,7 @@ class PreparedTask(Generic[ResponseFormat]):
     """

     instructions: str
-    response_format: …
+    response_format: type[ResponseFormat]
     temperature: float = 0.0
     top_p: float = 1.0

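Since `response_format` is now annotated as `type[ResponseFormat]`, a task is declared by passing the response model class itself, not an instance. A minimal construction sketch (illustrative; `_model` is a private module, and `Sentiment` is a hypothetical schema for this example):

```python
from pydantic import BaseModel
from openaivec._model import PreparedTask  # private module; illustrative import

class Sentiment(BaseModel):  # hypothetical response schema
    label: str
    score: float

task = PreparedTask(
    instructions="Classify the sentiment of the given text.",
    response_format=Sentiment,  # the class object satisfies type[ResponseFormat]
    temperature=0.0,
    top_p=1.0,
)
```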
openaivec/_optimize.py
CHANGED
@@ -3,7 +3,6 @@ import time
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
-from typing import List

 __all__ = []

@@ -24,7 +23,7 @@ class BatchSizeSuggester:
     max_duration: float = 60.0
     step_ratio: float = 0.2
     sample_size: int = 4
-    _history: …
+    _history: list[PerformanceMetric] = field(default_factory=list)
     _lock: threading.RLock = field(default_factory=threading.RLock, repr=False)
     _batch_size_changed_at: datetime | None = field(default=None, init=False)

@@ -65,9 +64,9 @@ class BatchSizeSuggester:
         )

     @property
-    def samples(self) -> …
+    def samples(self) -> list[PerformanceMetric]:
         with self._lock:
-            selected: …
+            selected: list[PerformanceMetric] = []
             for metric in reversed(self._history):
                 if metric.exception is not None:
                     continue
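The hunk above shows `samples` walking `_history` newest-first and skipping any metric that recorded an exception; the return and the `sample_size` cap fall outside the shown context. A standalone sketch of that selection logic (the cap is an assumption based on the `sample_size` field, and the real `PerformanceMetric` may carry different fields):

```python
from dataclasses import dataclass

@dataclass
class PerformanceMetric:  # stand-in with assumed fields, for illustration
    batch_size: int
    duration: float
    exception: Exception | None = None

def recent_successes(
    history: list[PerformanceMetric], sample_size: int = 4
) -> list[PerformanceMetric]:
    selected: list[PerformanceMetric] = []
    for metric in reversed(history):      # newest entries first
        if metric.exception is not None:  # failed runs never inform sizing
            continue
        selected.append(metric)
        if len(selected) >= sample_size:  # assumed cap, not shown in the hunk
            break
    return selected
```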
openaivec/_prompt.py
CHANGED
@@ -44,7 +44,6 @@ this will produce an XML string that looks like this:

 import difflib
 import logging
-from typing import List
 from xml.etree import ElementTree

 from openai import OpenAI
@@ -90,8 +89,8 @@ class FewShotPrompt(BaseModel):
     """

     purpose: str
-    cautions: …
-    examples: …
+    cautions: list[str]
+    examples: list[Example]


 class Step(BaseModel):
@@ -116,7 +115,7 @@ class Request(BaseModel):


 class Response(BaseModel):
-    iterations: …
+    iterations: list[Step]


 _PROMPT: str = """
@@ -358,7 +357,7 @@ class FewShotPromptBuilder:
     """

     _prompt: FewShotPrompt
-    _steps: …
+    _steps: list[Step]

     def __init__(self):
         """Initialize an empty FewShotPromptBuilder.
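`FewShotPrompt` is a plain Pydantic model, so the retyped `cautions` and `examples` fields validate exactly as before. A minimal construction sketch (illustrative; the field names of `Example` are assumed here to pair an input with an output and may differ in the source):

```python
from openaivec._prompt import Example, FewShotPrompt  # private module; illustrative import

prompt = FewShotPrompt(
    purpose="Normalize product names to English.",
    cautions=["Keep brand names unchanged."],
    examples=[
        # Example's field names are an assumption for illustration.
        Example(input="りんごジュース", output="apple juice"),
    ],
)
```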
openaivec/_proxy.py
CHANGED
@@ -1,8 +1,8 @@
 import asyncio
 import threading
-from collections.abc import Hashable
+from collections.abc import Awaitable, Callable, Hashable
 from dataclasses import dataclass, field
-from typing import Any, …
+from typing import Any, Generic, TypeVar

 from openaivec._optimize import BatchSizeSuggester

@@ -130,7 +130,7 @@ class ProxyBase(Generic[S, T]):
         progress_bar.close()

     @staticmethod
-    def _unique_in_order(seq: …
+    def _unique_in_order(seq: list[S]) -> list[S]:
         """Return unique items preserving their first-occurrence order.

         Args:
@@ -141,7 +141,7 @@ class ProxyBase(Generic[S, T]):
             once, in the order of their first occurrence.
         """
         seen: set[S] = set()
-        out: …
+        out: list[S] = []
         for x in seq:
             if x not in seen:
                 seen.add(x)
@@ -186,9 +186,8 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
     performance (targeting 30-60 seconds per batch).

     Example:
-        >>> from typing import List
         >>> p = BatchingMapProxy[int, str](batch_size=3)
-        >>> def f(xs: …
+        >>> def f(xs: list[int]) -> list[str]:
         ...     return [f"v:{x}" for x in xs]
         >>> p.map([1, 2, 2, 3, 4], f)
         ['v:1', 'v:2', 'v:2', 'v:3', 'v:4']
@@ -204,11 +203,11 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
     suggester: BatchSizeSuggester = field(default_factory=BatchSizeSuggester, repr=False)

     # internals
-    _cache: …
+    _cache: dict[S, T] = field(default_factory=dict)
     _lock: threading.RLock = field(default_factory=threading.RLock, repr=False)
-    _inflight: …
+    _inflight: dict[S, threading.Event] = field(default_factory=dict, repr=False)

-    def __all_cached(self, items: …
+    def __all_cached(self, items: list[S]) -> bool:
         """Check whether all items are present in the cache.

         This method acquires the internal lock to perform a consistent check.
@@ -222,7 +221,7 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         with self._lock:
             return all(x in self._cache for x in items)

-    def __values(self, items: …
+    def __values(self, items: list[S]) -> list[T]:
         """Fetch cached values for ``items`` preserving the given order.

         This method acquires the internal lock while reading the cache.
@@ -237,7 +236,7 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         with self._lock:
             return [self._cache[x] for x in items]

-    def __acquire_ownership(self, items: …
+    def __acquire_ownership(self, items: list[S]) -> tuple[list[S], list[S]]:
         """Acquire ownership for missing items and identify keys to wait for.

         For each unique item, if it's already cached, it is ignored. If it's
@@ -253,8 +252,8 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         - ``owned`` are items this thread is responsible for computing.
         - ``wait_for`` are items that another thread is already computing.
         """
-        owned: …
-        wait_for: …
+        owned: list[S] = []
+        wait_for: list[S] = []
         with self._lock:
             for x in items:
                 if x in self._cache:
@@ -266,7 +265,7 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
                 owned.append(x)
             return owned, wait_for

-    def __finalize_success(self, to_call: …
+    def __finalize_success(self, to_call: list[S], results: list[T]) -> None:
         """Populate cache with results and signal completion events.

         Args:
@@ -285,7 +284,7 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
                 if ev:
                     ev.set()

-    def __finalize_failure(self, to_call: …
+    def __finalize_failure(self, to_call: list[S]) -> None:
         """Release in-flight events on failure to avoid deadlocks.

         Args:
@@ -316,7 +315,7 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         """Alias for clear()."""
         self.clear()

-    def __process_owned(self, owned: …
+    def __process_owned(self, owned: list[S], map_func: Callable[[list[S]], list[T]]) -> None:
         """Process owned items in mini-batches and fill the cache.

         Before calling ``map_func`` for each batch, the cache is re-checked
@@ -339,7 +338,7 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         progress_bar = self._create_progress_bar(len(owned))

         # Accumulate uncached items to maximize batch size utilization
-        pending_to_call: …
+        pending_to_call: list[S] = []

         i = 0
         while i < len(owned):
@@ -395,7 +394,7 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         # Close progress bar
         self._close_progress_bar(progress_bar)

-    def __wait_for(self, keys: …
+    def __wait_for(self, keys: list[S], map_func: Callable[[list[S]], list[T]]) -> None:
         """Wait for other threads to complete computations for the given keys.

         If a key is neither cached nor in-flight, this method now claims ownership
@@ -407,7 +406,7 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         Args:
             keys (list[S]): Items whose computations are owned by other threads.
         """
-        rescued: …
+        rescued: list[S] = []  # keys we claim to batch-process
         for x in keys:
             while True:
                 with self._lock:
@@ -431,7 +430,7 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
                     raise

     # ---- public API ------------------------------------------------------
-    def map(self, items: …
+    def map(self, items: list[S], map_func: Callable[[list[S]], list[T]]) -> list[T]:
         """Map ``items`` to values using caching and optional mini-batching.

         This method is thread-safe. It deduplicates inputs while preserving order,
@@ -494,7 +493,7 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         >>> import asyncio
         >>> from typing import List
         >>> p = AsyncBatchingMapProxy[int, str](batch_size=2)
-        >>> async def af(xs: …
+        >>> async def af(xs: list[int]) -> list[str]:
         ...     await asyncio.sleep(0)
         ...     return [f"v:{x}" for x in xs]
         >>> async def run():
@@ -514,9 +513,9 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
     suggester: BatchSizeSuggester = field(default_factory=BatchSizeSuggester, repr=False)

     # internals
-    _cache: …
+    _cache: dict[S, T] = field(default_factory=dict, repr=False)
     _lock: asyncio.Lock = field(default_factory=asyncio.Lock, repr=False)
-    _inflight: …
+    _inflight: dict[S, asyncio.Event] = field(default_factory=dict, repr=False)
     __sema: asyncio.Semaphore | None = field(default=None, init=False, repr=False)

     def __post_init__(self) -> None:
@@ -537,7 +536,7 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         else:
             self.__sema = None

-    async def __all_cached(self, items: …
+    async def __all_cached(self, items: list[S]) -> bool:
         """Check whether all items are present in the cache.

         This method acquires the internal asyncio lock for a consistent view
@@ -552,7 +551,7 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         async with self._lock:
             return all(x in self._cache for x in items)

-    async def __values(self, items: …
+    async def __values(self, items: list[S]) -> list[T]:
         """Get cached values for ``items`` preserving their given order.

         The internal asyncio lock is held while reading the cache to preserve
@@ -567,7 +566,7 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         async with self._lock:
             return [self._cache[x] for x in items]

-    async def __acquire_ownership(self, items: …
+    async def __acquire_ownership(self, items: list[S]) -> tuple[list[S], list[S]]:
         """Acquire ownership for missing keys and identify keys to wait for.

         Args:
@@ -578,8 +577,8 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
             keys this coroutine should compute, and wait_for are keys currently
             being computed elsewhere.
         """
-        owned: …
-        wait_for: …
+        owned: list[S] = []
+        wait_for: list[S] = []
         async with self._lock:
             for x in items:
                 if x in self._cache:
@@ -591,7 +590,7 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
                 owned.append(x)
             return owned, wait_for

-    async def __finalize_success(self, to_call: …
+    async def __finalize_success(self, to_call: list[S], results: list[T]) -> None:
         """Populate cache and signal completion for successfully computed keys.

         Args:
@@ -609,7 +608,7 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
                 if ev:
                     ev.set()

-    async def __finalize_failure(self, to_call: …
+    async def __finalize_failure(self, to_call: list[S]) -> None:
         """Release in-flight events on failure to avoid deadlocks.

         Args:
@@ -640,7 +639,7 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         """Alias for clear()."""
         await self.clear()

-    async def __process_owned(self, owned: …
+    async def __process_owned(self, owned: list[S], map_func: Callable[[list[S]], Awaitable[list[T]]]) -> None:
         """Process owned keys using Producer-Consumer pattern with dynamic batch sizing.

         Args:
@@ -681,7 +680,7 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         self._close_progress_bar(progress_bar)

     async def __process_single_batch(
-        self, to_call: …
+        self, to_call: list[S], map_func: Callable[[list[S]], Awaitable[list[T]]], progress_bar
     ) -> None:
         """Process a single batch with semaphore control."""
         acquired = False
@@ -703,7 +702,7 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         # Update progress bar
         self._update_progress_bar(progress_bar, len(to_call))

-    async def __wait_for(self, keys: …
+    async def __wait_for(self, keys: list[S], map_func: Callable[[list[S]], Awaitable[list[T]]]) -> None:
         """Wait for computations owned by other coroutines to complete.

         If a key is neither cached nor in-flight, this method now claims ownership
@@ -715,7 +714,7 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         Args:
             keys (list[S]): Items whose computations are owned by other coroutines.
         """
-        rescued: …
+        rescued: list[S] = []  # keys we claim to batch-process
         for x in keys:
             while True:
                 async with self._lock:
@@ -738,7 +737,7 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
                     raise

     # ---- public API ------------------------------------------------------
-    async def map(self, items: …
+    async def map(self, items: list[S], map_func: Callable[[list[S]], Awaitable[list[T]]]) -> list[T]:
         """Async map with caching, de-duplication, and optional mini-batching.

         Args:
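Beyond the doctests embedded above, the ownership and wait-for machinery means concurrent `map` calls share both the cache and in-flight computations: a key being computed by one caller is awaited by others instead of being recomputed. A sketch of that guarantee (illustrative; exact scheduling can vary, but each unique key should reach the backend once):

```python
import asyncio
from openaivec._proxy import AsyncBatchingMapProxy  # private module; illustrative import

seen: list[list[int]] = []

async def backend(xs: list[int]) -> list[str]:
    seen.append(xs)            # record which keys actually get computed
    await asyncio.sleep(0.01)  # simulate an API round trip
    return [f"v:{x}" for x in xs]

async def main() -> None:
    proxy = AsyncBatchingMapProxy[int, str](batch_size=8)
    # Overlapping key sets: keys 2 and 3 are in flight when the second call arrives.
    r1, r2 = await asyncio.gather(
        proxy.map([1, 2, 3], backend),
        proxy.map([2, 3, 4], backend),
    )
    print(r1)  # ['v:1', 'v:2', 'v:3']
    print(r2)  # ['v:2', 'v:3', 'v:4']
    # Expected: 4 unique keys computed in total, thanks to in-flight dedup.
    print(sum(len(batch) for batch in seen))

asyncio.run(main())
```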