openaivec 0.13.4__py3-none-any.whl → 0.13.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openaivec/embeddings.py +10 -8
- openaivec/model.py +9 -11
- openaivec/optimize.py +1 -1
- openaivec/pandas_ext.py +61 -42
- openaivec/prompt.py +58 -8
- openaivec/provider.py +10 -0
- openaivec/proxy.py +82 -65
- openaivec/responses.py +35 -18
- openaivec/spark.py +40 -34
- openaivec/task/customer_support/inquiry_classification.py +9 -9
- openaivec/task/customer_support/urgency_analysis.py +13 -13
- openaivec/task/nlp/keyword_extraction.py +2 -2
- openaivec/task/nlp/named_entity_recognition.py +2 -2
- openaivec/util.py +2 -2
- {openaivec-0.13.4.dist-info → openaivec-0.13.6.dist-info}/METADATA +9 -9
- {openaivec-0.13.4.dist-info → openaivec-0.13.6.dist-info}/RECORD +18 -18
- {openaivec-0.13.4.dist-info → openaivec-0.13.6.dist-info}/WHEEL +0 -0
- {openaivec-0.13.4.dist-info → openaivec-0.13.6.dist-info}/licenses/LICENSE +0 -0
openaivec/proxy.py
CHANGED
@@ -2,7 +2,7 @@ import asyncio
 import threading
 from collections.abc import Hashable
 from dataclasses import dataclass, field
-from typing import Awaitable, Callable, Dict, Generic, List,
+from typing import Any, Awaitable, Callable, Dict, Generic, List, TypeVar

 from openaivec.optimize import BatchSizeSuggester

@@ -22,9 +22,9 @@ class ProxyBase(Generic[S, T]):
         should process the entire input in a single call.
     """

-    batch_size:
-    show_progress: bool
-    suggester: BatchSizeSuggester
+    batch_size: int | None  # subclasses may override via dataclass
+    show_progress: bool  # Enable progress bar display
+    suggester: BatchSizeSuggester  # Batch size optimization, initialized by subclasses

     def _is_notebook_environment(self) -> bool:
         """Check if running in a Jupyter notebook environment.
@@ -33,7 +33,7 @@ class ProxyBase(Generic[S, T]):
             bool: True if running in a notebook, False otherwise.
         """
         try:
-            from IPython import get_ipython
+            from IPython.core.getipython import get_ipython

             ipython = get_ipython()
             if ipython is not None:
@@ -89,7 +89,7 @@ class ProxyBase(Generic[S, T]):

         return False

-    def _create_progress_bar(self, total: int, desc: str = "Processing batches") ->
+    def _create_progress_bar(self, total: int, desc: str = "Processing batches") -> Any:
         """Create a progress bar if conditions are met.

         Args:
@@ -97,7 +97,7 @@ class ProxyBase(Generic[S, T]):
             desc (str): Description for the progress bar.

         Returns:
-
+            Any: Progress bar instance or None if not available.
         """
         try:
             from tqdm.auto import tqdm as tqdm_progress
@@ -108,21 +108,21 @@ class ProxyBase(Generic[S, T]):
             pass
         return None

-    def _update_progress_bar(self, progress_bar:
+    def _update_progress_bar(self, progress_bar: Any, increment: int) -> None:
         """Update progress bar with the given increment.

         Args:
-            progress_bar (
+            progress_bar (Any): Progress bar instance.
             increment (int): Number of items to increment.
         """
         if progress_bar:
             progress_bar.update(increment)

-    def _close_progress_bar(self, progress_bar:
+    def _close_progress_bar(self, progress_bar: Any) -> None:
         """Close the progress bar.

         Args:
-            progress_bar (
+            progress_bar (Any): Progress bar instance.
         """
         if progress_bar:
             progress_bar.close()
@@ -179,6 +179,10 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
     not duplicate work via an in-flight registry. All public behavior is preserved
     while minimizing redundant requests and maintaining input order in the output.

+    When ``batch_size=None``, automatic batch size optimization is enabled,
+    dynamically adjusting batch sizes based on execution time to maintain optimal
+    performance (targeting 30-60 seconds per batch).
+
     Example:
         >>> from typing import List
         >>> p = BatchingMapProxy[int, str](batch_size=3)
@@ -188,15 +192,19 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         ['v:1', 'v:2', 'v:2', 'v:3', 'v:4']
     """

-    # Number of items to process per call to map_func.
-
+    # Number of items to process per call to map_func.
+    # - If None (default): Enables automatic batch size optimization, dynamically adjusting
+    #   based on execution time (targeting 30-60 seconds per batch)
+    # - If positive integer: Fixed batch size
+    # - If <= 0: Process all items at once
+    batch_size: int | None = None
     show_progress: bool = False
     suggester: BatchSizeSuggester = field(default_factory=BatchSizeSuggester, repr=False)

     # internals
-
-
-
+    _cache: Dict[S, T] = field(default_factory=dict)
+    _lock: threading.RLock = field(default_factory=threading.RLock, repr=False)
+    _inflight: Dict[S, threading.Event] = field(default_factory=dict, repr=False)

     def __all_cached(self, items: List[S]) -> bool:
         """Check whether all items are present in the cache.
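To make the new ``batch_size`` semantics concrete, here is a minimal sketch (not code from the package) that exercises the three modes against a toy ``map_func``; the generic subscription and the ``map(items, map_func)`` call shape are assumed from the docstring example and the ``cache.map(...)`` call sites in this diff.

```python
from openaivec.proxy import BatchingMapProxy

def upper_all(items: list[str]) -> list[str]:
    # Toy map_func: must return one result per input, in the same order.
    return [s.upper() for s in items]

auto = BatchingMapProxy[str, str]()                # batch_size=None: automatic sizing
fixed = BatchingMapProxy[str, str](batch_size=3)   # positive int: fixed batch size
whole = BatchingMapProxy[str, str](batch_size=0)   # <= 0: one call for everything

# Duplicates are computed once and results keep the input order.
print(fixed.map(["a", "b", "a", "c"], upper_all))  # ['A', 'B', 'A', 'C']
```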
@@ -209,8 +217,8 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         Returns:
             bool: True if every item is already cached, False otherwise.
         """
-        with self.
-        return all(x in self.
+        with self._lock:
+            return all(x in self._cache for x in items)

     def __values(self, items: List[S]) -> List[T]:
         """Fetch cached values for ``items`` preserving the given order.
@@ -224,8 +232,8 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
             list[T]: The cached values corresponding to ``items`` in the same
                 order.
         """
-        with self.
-        return [self.
+        with self._lock:
+            return [self._cache[x] for x in items]

     def __acquire_ownership(self, items: List[S]) -> tuple[List[S], List[S]]:
         """Acquire ownership for missing items and identify keys to wait for.
@@ -245,14 +253,14 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         """
         owned: List[S] = []
         wait_for: List[S] = []
-        with self.
+        with self._lock:
             for x in items:
-                if x in self.
+                if x in self._cache:
                     continue
-                if x in self.
+                if x in self._inflight:
                     wait_for.append(x)
                 else:
-                    self.
+                    self._inflight[x] = threading.Event()
                     owned.append(x)
         return owned, wait_for

@@ -268,10 +276,10 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
             # Release waiters and surface a clear error.
             self.__finalize_failure(to_call)
             raise ValueError("map_func must return a list of results with the same length and order as inputs")
-        with self.
+        with self._lock:
             for x, y in zip(to_call, results):
-                self.
-                ev = self.
+                self._cache[x] = y
+                ev = self._inflight.pop(x, None)
                 if ev:
                     ev.set()

@@ -282,9 +290,9 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
             to_call (list[S]): Items that were intended to be computed when an
                 error occurred.
         """
-        with self.
+        with self._lock:
             for x in to_call:
-                ev = self.
+                ev = self._inflight.pop(x, None)
                 if ev:
                     ev.set()

@@ -296,11 +304,11 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
             - Do not call concurrently with active map() calls to avoid
               unnecessary recomputation or racy wake-ups.
         """
-        with self.
-        for ev in self.
+        with self._lock:
+            for ev in self._inflight.values():
                 ev.set()
-            self.
-            self.
+            self._inflight.clear()
+            self._cache.clear()

     def close(self) -> None:
         """Alias for clear()."""
@@ -337,8 +345,8 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
             current_batch_size = self._normalized_batch_size(len(owned))
             batch = owned[i : i + current_batch_size]
             # Double-check cache right before processing
-            with self.
-            uncached_in_batch = [x for x in batch if x not in self.
+            with self._lock:
+                uncached_in_batch = [x for x in batch if x not in self._cache]

             pending_to_call.extend(uncached_in_batch)

@@ -400,13 +408,13 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         rescued: List[S] = []  # keys we claim to batch-process
         for x in keys:
             while True:
-                with self.
-                if x in self.
+                with self._lock:
+                    if x in self._cache:
                         break
-                ev = self.
+                    ev = self._inflight.get(x)
                     if ev is None:
                         # Not cached and no one computing; claim ownership to batch later.
-                        self.
+                        self._inflight[x] = threading.Event()
                         rescued.append(x)
                         break
                     # Someone else is computing; wait for completion.
@@ -463,6 +471,10 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
     coordinates concurrent coroutines to avoid duplicate work via an in-flight
     registry of asyncio events.

+    When ``batch_size=None``, automatic batch size optimization is enabled,
+    dynamically adjusting batch sizes based on execution time to maintain optimal
+    performance (targeting 30-60 seconds per batch).
+
     Example:
         >>> import asyncio
        >>> from typing import List
@@ -476,16 +488,21 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         ['v:1', 'v:2', 'v:3']
     """

-
+    # Number of items to process per call to map_func.
+    # - If None (default): Enables automatic batch size optimization, dynamically adjusting
+    #   based on execution time (targeting 30-60 seconds per batch)
+    # - If positive integer: Fixed batch size
+    # - If <= 0: Process all items at once
+    batch_size: int | None = None
     max_concurrency: int = 8
     show_progress: bool = False
     suggester: BatchSizeSuggester = field(default_factory=BatchSizeSuggester, repr=False)

     # internals
-
-
-
-    __sema:
+    _cache: Dict[S, T] = field(default_factory=dict, repr=False)
+    _lock: asyncio.Lock = field(default_factory=asyncio.Lock, repr=False)
+    _inflight: Dict[S, asyncio.Event] = field(default_factory=dict, repr=False)
+    __sema: asyncio.Semaphore | None = field(default=None, init=False, repr=False)

     def __post_init__(self) -> None:
         """Initialize internal semaphore based on ``max_concurrency``.
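The async variant follows the same pattern; below is a rough sketch (again with a toy ``map_func``, not code from the package) showing automatic batch sizing together with the ``max_concurrency`` cap.

```python
import asyncio

from openaivec.proxy import AsyncBatchingMapProxy

async def label_all(items: list[int]) -> list[str]:
    # Toy async map_func: one result per input, same order.
    return [f"v:{i}" for i in items]

async def main() -> None:
    # batch_size is left at None, so batch sizes are tuned automatically;
    # max_concurrency bounds the number of concurrent map_func calls.
    proxy = AsyncBatchingMapProxy[int, str](max_concurrency=4)
    print(await proxy.map([1, 2, 3], label_all))  # ['v:1', 'v:2', 'v:3']
    await proxy.aclose()  # alias for clear(); releases the cache

asyncio.run(main())
```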
@@ -517,8 +534,8 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         Returns:
             bool: True if every item in ``items`` is already cached, False otherwise.
         """
-        async with self.
-        return all(x in self.
+        async with self._lock:
+            return all(x in self._cache for x in items)

     async def __values(self, items: List[S]) -> List[T]:
         """Get cached values for ``items`` preserving their given order.
@@ -532,8 +549,8 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         Returns:
             list[T]: Cached values corresponding to ``items`` in the same order.
         """
-        async with self.
-        return [self.
+        async with self._lock:
+            return [self._cache[x] for x in items]

     async def __acquire_ownership(self, items: List[S]) -> tuple[List[S], List[S]]:
         """Acquire ownership for missing keys and identify keys to wait for.
@@ -548,14 +565,14 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         """
         owned: List[S] = []
         wait_for: List[S] = []
-        async with self.
+        async with self._lock:
             for x in items:
-                if x in self.
+                if x in self._cache:
                     continue
-                if x in self.
+                if x in self._inflight:
                     wait_for.append(x)
                 else:
-                    self.
+                    self._inflight[x] = asyncio.Event()
                     owned.append(x)
         return owned, wait_for

@@ -570,10 +587,10 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
             # Prevent deadlocks if map_func violates the contract.
             await self.__finalize_failure(to_call)
             raise ValueError("map_func must return a list of results with the same length and order as inputs")
-        async with self.
+        async with self._lock:
             for x, y in zip(to_call, results):
-                self.
-                ev = self.
+                self._cache[x] = y
+                ev = self._inflight.pop(x, None)
                 if ev:
                     ev.set()

@@ -584,9 +601,9 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
             to_call (list[S]): Items whose computation failed; their waiters will
                 be released.
         """
-        async with self.
+        async with self._lock:
             for x in to_call:
-                ev = self.
+                ev = self._inflight.pop(x, None)
                 if ev:
                     ev.set()

@@ -598,11 +615,11 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
             - Do not call concurrently with active map() calls to avoid
               unnecessary recomputation or racy wake-ups.
         """
-        async with self.
-        for ev in self.
+        async with self._lock:
+            for ev in self._inflight.values():
                 ev.set()
-            self.
-            self.
+            self._inflight.clear()
+            self._cache.clear()

     async def aclose(self) -> None:
         """Alias for clear()."""
@@ -686,13 +703,13 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         rescued: List[S] = []  # keys we claim to batch-process
         for x in keys:
             while True:
-                async with self.
-                if x in self.
+                async with self._lock:
+                    if x in self._cache:
                         break
-                ev = self.
+                    ev = self._inflight.get(x)
                     if ev is None:
                         # Not cached and no one computing; claim ownership to batch later.
-                        self.
+                        self._inflight[x] = asyncio.Event()
                         rescued.append(x)
                         break
                     # Someone else is computing; wait for completion.
openaivec/responses.py
CHANGED
@@ -165,8 +165,8 @@ class BatchResponses(Generic[ResponseFormat]):
     system_message: str
     temperature: float | None = 0.0
     top_p: float = 1.0
-    response_format: Type[ResponseFormat] = str
-    cache: BatchingMapProxy[str, ResponseFormat] = field(default_factory=lambda: BatchingMapProxy(batch_size=
+    response_format: Type[ResponseFormat] = str  # type: ignore[assignment]
+    cache: BatchingMapProxy[str, ResponseFormat] = field(default_factory=lambda: BatchingMapProxy(batch_size=None))
     _vectorized_system_message: str = field(init=False)
     _model_json_schema: dict = field(init=False)

@@ -179,7 +179,7 @@ class BatchResponses(Generic[ResponseFormat]):
         temperature: float | None = 0.0,
         top_p: float = 1.0,
         response_format: Type[ResponseFormat] = str,
-        batch_size: int =
+        batch_size: int | None = None,
     ) -> "BatchResponses":
         """Factory constructor.

@@ -190,7 +190,8 @@ class BatchResponses(Generic[ResponseFormat]):
             temperature (float, optional): Sampling temperature. Defaults to 0.0.
             top_p (float, optional): Nucleus sampling parameter. Defaults to 1.0.
             response_format (Type[ResponseFormat], optional): Expected output type. Defaults to ``str``.
-            batch_size (int, optional): Max unique prompts per API call. Defaults to
+            batch_size (int | None, optional): Max unique prompts per API call. Defaults to None
+                (automatic batch size optimization). Set to a positive integer for fixed batch size.

         Returns:
             BatchResponses: Configured instance backed by a batching proxy.
@@ -206,14 +207,17 @@ class BatchResponses(Generic[ResponseFormat]):
         )

     @classmethod
-    def of_task(
+    def of_task(
+        cls, client: OpenAI, model_name: str, task: PreparedTask[ResponseFormat], batch_size: int | None = None
+    ) -> "BatchResponses":
         """Factory from a PreparedTask.

         Args:
             client (OpenAI): OpenAI client.
             model_name (str): For Azure OpenAI, use your deployment name. For OpenAI, use the model name.
             task (PreparedTask): Prepared task with instructions and response format.
-            batch_size (int, optional): Max unique prompts per API call. Defaults to
+            batch_size (int | None, optional): Max unique prompts per API call. Defaults to None
+                (automatic batch size optimization). Set to a positive integer for fixed batch size.

         Returns:
             BatchResponses: Configured instance backed by a batching proxy.
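A hedged sketch of the resulting call shape: the ``of_task`` signature is taken from the hunk above, but the model name is a placeholder and ``my_prepared_task`` stands in for whatever ``PreparedTask`` you already have (none is named in this diff).

```python
from openai import OpenAI

from openaivec.responses import BatchResponses

def build_batch_responses(my_prepared_task) -> BatchResponses:
    """Sketch only: wire a PreparedTask into BatchResponses with the new defaults."""
    client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment
    return BatchResponses.of_task(
        client=client,
        model_name="gpt-4.1-mini",  # placeholder; use your deployment name on Azure OpenAI
        task=my_prepared_task,
        # batch_size omitted -> None -> automatic batch size optimization;
        # pass a positive integer (e.g. batch_size=64) to pin a fixed size.
    )
```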
@@ -294,8 +298,10 @@ class BatchResponses(Generic[ResponseFormat]):
         """
         messages = [Message(id=i, body=message) for i, message in enumerate(user_messages)]
         responses: ParsedResponse[Response[ResponseFormat]] = self._request_llm(messages)
+        if not responses.output_parsed:
+            return [None] * len(messages)
         response_dict = {message.id: message.body for message in responses.output_parsed.assistant_messages}
-        sorted_responses = [response_dict.get(m.id, None) for m in messages]
+        sorted_responses: List[ResponseFormat | None] = [response_dict.get(m.id, None) for m in messages]
         return sorted_responses

     @observe(_LOGGER)
@@ -308,7 +314,8 @@ class BatchResponses(Generic[ResponseFormat]):
         Returns:
             List[ResponseFormat | None]: Assistant responses aligned to ``inputs``.
         """
-
+        result = self.cache.map(inputs, self._predict_chunk)
+        return result  # type: ignore[return-value]


 @dataclass(frozen=True)
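One practical consequence of the new ``output_parsed`` guard, shown as a sketch: a batch whose response cannot be parsed now comes back as a list of ``None`` aligned to the inputs instead of raising, so downstream code should branch on ``None`` explicitly.

```python
# Illustrative only: `results` mimics what a batch call returns when some
# (or all) assistant messages are missing from the parsed response.
inputs = ["first prompt", "second prompt", "third prompt"]
results: list[str | None] = ["ok", None, "ok"]

for prompt, answer in zip(inputs, results):
    if answer is None:
        print(f"{prompt!r}: no parsed response (retry or skip)")
    else:
        print(f"{prompt!r}: {answer}")
```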
@@ -362,9 +369,9 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
     system_message: str
     temperature: float | None = 0.0
     top_p: float = 1.0
-    response_format: Type[ResponseFormat] = str
+    response_format: Type[ResponseFormat] = str  # type: ignore[assignment]
     cache: AsyncBatchingMapProxy[str, ResponseFormat] = field(
-        default_factory=lambda: AsyncBatchingMapProxy(batch_size=
+        default_factory=lambda: AsyncBatchingMapProxy(batch_size=None, max_concurrency=8)
     )
     _vectorized_system_message: str = field(init=False)
     _model_json_schema: dict = field(init=False)
@@ -378,7 +385,7 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
         temperature: float | None = 0.0,
         top_p: float = 1.0,
         response_format: Type[ResponseFormat] = str,
-        batch_size: int =
+        batch_size: int | None = None,
         max_concurrency: int = 8,
     ) -> "AsyncBatchResponses":
         """Factory constructor.
@@ -390,7 +397,8 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
             temperature (float, optional): Sampling temperature. Defaults to 0.0.
             top_p (float, optional): Nucleus sampling parameter. Defaults to 1.0.
             response_format (Type[ResponseFormat], optional): Expected output type. Defaults to ``str``.
-            batch_size (int, optional): Max unique prompts per API call. Defaults to
+            batch_size (int | None, optional): Max unique prompts per API call. Defaults to None
+                (automatic batch size optimization). Set to a positive integer for fixed batch size.
             max_concurrency (int, optional): Max concurrent API calls. Defaults to 8.

         Returns:
@@ -408,7 +416,12 @@ class AsyncBatchResponses(Generic[ResponseFormat]):

     @classmethod
     def of_task(
-        cls,
+        cls,
+        client: AsyncOpenAI,
+        model_name: str,
+        task: PreparedTask[ResponseFormat],
+        batch_size: int | None = None,
+        max_concurrency: int = 8,
     ) -> "AsyncBatchResponses":
         """Factory from a PreparedTask.

@@ -416,7 +429,8 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
             client (AsyncOpenAI): OpenAI async client.
             model_name (str): For Azure OpenAI, use your deployment name. For OpenAI, use the model name.
             task (PreparedTask): Prepared task with instructions and response format.
-            batch_size (int, optional): Max unique prompts per API call. Defaults to
+            batch_size (int | None, optional): Max unique prompts per API call. Defaults to None
+                (automatic batch size optimization). Set to a positive integer for fixed batch size.
             max_concurrency (int, optional): Max concurrent API calls. Defaults to 8.

         Returns:
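The async factory mirrors the sync one and adds ``max_concurrency``; another sketch with the same caveats (placeholder model name, ``my_prepared_task`` as a stand-in for a real ``PreparedTask``):

```python
from openai import AsyncOpenAI

from openaivec.responses import AsyncBatchResponses

def build_async_batch_responses(my_prepared_task) -> AsyncBatchResponses:
    """Sketch only: async variant with automatic batch sizing and a concurrency cap."""
    client = AsyncOpenAI()  # assumes OPENAI_API_KEY is set in the environment
    return AsyncBatchResponses.of_task(
        client=client,
        model_name="gpt-4.1-mini",  # placeholder; use your deployment name on Azure OpenAI
        task=my_prepared_task,
        batch_size=None,            # automatic batch size optimization (the new default)
        max_concurrency=4,          # at most 4 in-flight API calls
    )
```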
@@ -439,8 +453,8 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
             _vectorize_system_message(self.system_message),
         )

-    @observe(_LOGGER)
     @backoff_async(exceptions=[RateLimitError, InternalServerError], scale=1, max_retries=12)
+    @observe(_LOGGER)
     async def _request_llm(self, user_messages: List[Message[str]]) -> ParsedResponse[Response[ResponseFormat]]:
         """Make a single async call to the OpenAI JSON‑mode endpoint.

@@ -493,10 +507,12 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
         The function is pure – it has no side‑effects and the result depends only on its arguments.
         """
         messages = [Message(id=i, body=message) for i, message in enumerate(user_messages)]
-        responses: ParsedResponse[Response[ResponseFormat]] = await self._request_llm(messages)
+        responses: ParsedResponse[Response[ResponseFormat]] = await self._request_llm(messages)  # type: ignore[call-issue]
+        if not responses.output_parsed:
+            return [None] * len(messages)
         response_dict = {message.id: message.body for message in responses.output_parsed.assistant_messages}
         # Ensure proper handling for missing IDs - this shouldn't happen in normal operation
-        sorted_responses = [response_dict.get(m.id, None) for m in messages]
+        sorted_responses: List[ResponseFormat | None] = [response_dict.get(m.id, None) for m in messages]
         return sorted_responses

     @observe(_LOGGER)
|
|
@@ -509,4 +525,5 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
|
|
|
509
525
|
Returns:
|
|
510
526
|
List[ResponseFormat | None]: Assistant responses aligned to ``inputs``.
|
|
511
527
|
"""
|
|
512
|
-
|
|
528
|
+
result = await self.cache.map(inputs, self._predict_chunk)
|
|
529
|
+
return result # type: ignore[return-value]
|