openaivec 0.14.7__py3-none-any.whl → 0.14.9__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry. The information is provided for informational purposes only.
- openaivec/_di.py +10 -9
- openaivec/_dynamic.py +350 -0
- openaivec/_embeddings.py +12 -13
- openaivec/_log.py +1 -1
- openaivec/_model.py +3 -3
- openaivec/_optimize.py +3 -4
- openaivec/_prompt.py +4 -5
- openaivec/_proxy.py +34 -35
- openaivec/_responses.py +29 -29
- openaivec/_schema.py +105 -244
- openaivec/_serialize.py +19 -15
- openaivec/_util.py +9 -8
- openaivec/pandas_ext.py +20 -19
- openaivec/spark.py +11 -10
- openaivec/task/customer_support/customer_sentiment.py +2 -2
- openaivec/task/customer_support/inquiry_classification.py +8 -8
- openaivec/task/customer_support/inquiry_summary.py +4 -4
- openaivec/task/customer_support/intent_analysis.py +5 -5
- openaivec/task/customer_support/response_suggestion.py +4 -4
- openaivec/task/customer_support/urgency_analysis.py +9 -9
- openaivec/task/nlp/dependency_parsing.py +2 -4
- openaivec/task/nlp/keyword_extraction.py +3 -5
- openaivec/task/nlp/morphological_analysis.py +4 -6
- openaivec/task/nlp/named_entity_recognition.py +7 -9
- openaivec/task/nlp/sentiment_analysis.py +3 -3
- openaivec/task/nlp/translation.py +1 -2
- openaivec/task/table/fillna.py +2 -3
- {openaivec-0.14.7.dist-info → openaivec-0.14.9.dist-info}/METADATA +1 -1
- openaivec-0.14.9.dist-info/RECORD +37 -0
- openaivec-0.14.7.dist-info/RECORD +0 -36
- {openaivec-0.14.7.dist-info → openaivec-0.14.9.dist-info}/WHEEL +0 -0
- {openaivec-0.14.7.dist-info → openaivec-0.14.9.dist-info}/licenses/LICENSE +0 -0
openaivec/_proxy.py
CHANGED
@@ -1,8 +1,8 @@
 import asyncio
 import threading
-from collections.abc import Hashable
+from collections.abc import Awaitable, Callable, Hashable
 from dataclasses import dataclass, field
-from typing import Any,
+from typing import Any, Generic, TypeVar

 from openaivec._optimize import BatchSizeSuggester

@@ -130,7 +130,7 @@ class ProxyBase(Generic[S, T]):
         progress_bar.close()

     @staticmethod
-    def _unique_in_order(seq:
+    def _unique_in_order(seq: list[S]) -> list[S]:
         """Return unique items preserving their first-occurrence order.

         Args:
@@ -141,7 +141,7 @@ class ProxyBase(Generic[S, T]):
             once, in the order of their first occurrence.
         """
         seen: set[S] = set()
-        out:
+        out: list[S] = []
         for x in seq:
             if x not in seen:
                 seen.add(x)
@@ -186,9 +186,8 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         performance (targeting 30-60 seconds per batch).

     Example:
-        >>> from typing import List
         >>> p = BatchingMapProxy[int, str](batch_size=3)
-        >>> def f(xs:
+        >>> def f(xs: list[int]) -> list[str]:
         ...     return [f"v:{x}" for x in xs]
         >>> p.map([1, 2, 2, 3, 4], f)
         ['v:1', 'v:2', 'v:2', 'v:3', 'v:4']
@@ -204,11 +203,11 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
     suggester: BatchSizeSuggester = field(default_factory=BatchSizeSuggester, repr=False)

     # internals
-    _cache:
+    _cache: dict[S, T] = field(default_factory=dict)
     _lock: threading.RLock = field(default_factory=threading.RLock, repr=False)
-    _inflight:
+    _inflight: dict[S, threading.Event] = field(default_factory=dict, repr=False)

-    def __all_cached(self, items:
+    def __all_cached(self, items: list[S]) -> bool:
         """Check whether all items are present in the cache.

         This method acquires the internal lock to perform a consistent check.
@@ -222,7 +221,7 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         with self._lock:
             return all(x in self._cache for x in items)

-    def __values(self, items:
+    def __values(self, items: list[S]) -> list[T]:
         """Fetch cached values for ``items`` preserving the given order.

         This method acquires the internal lock while reading the cache.
@@ -237,7 +236,7 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         with self._lock:
             return [self._cache[x] for x in items]

-    def __acquire_ownership(self, items:
+    def __acquire_ownership(self, items: list[S]) -> tuple[list[S], list[S]]:
         """Acquire ownership for missing items and identify keys to wait for.

         For each unique item, if it's already cached, it is ignored. If it's
@@ -253,8 +252,8 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
             - ``owned`` are items this thread is responsible for computing.
             - ``wait_for`` are items that another thread is already computing.
         """
-        owned:
-        wait_for:
+        owned: list[S] = []
+        wait_for: list[S] = []
         with self._lock:
             for x in items:
                 if x in self._cache:
@@ -266,7 +265,7 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
                 owned.append(x)
         return owned, wait_for

-    def __finalize_success(self, to_call:
+    def __finalize_success(self, to_call: list[S], results: list[T]) -> None:
         """Populate cache with results and signal completion events.

         Args:
@@ -285,7 +284,7 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
                 if ev:
                     ev.set()

-    def __finalize_failure(self, to_call:
+    def __finalize_failure(self, to_call: list[S]) -> None:
         """Release in-flight events on failure to avoid deadlocks.

         Args:
@@ -316,7 +315,7 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         """Alias for clear()."""
         self.clear()

-    def __process_owned(self, owned:
+    def __process_owned(self, owned: list[S], map_func: Callable[[list[S]], list[T]]) -> None:
         """Process owned items in mini-batches and fill the cache.

         Before calling ``map_func`` for each batch, the cache is re-checked
@@ -339,7 +338,7 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         progress_bar = self._create_progress_bar(len(owned))

         # Accumulate uncached items to maximize batch size utilization
-        pending_to_call:
+        pending_to_call: list[S] = []

         i = 0
         while i < len(owned):
@@ -395,7 +394,7 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         # Close progress bar
         self._close_progress_bar(progress_bar)

-    def __wait_for(self, keys:
+    def __wait_for(self, keys: list[S], map_func: Callable[[list[S]], list[T]]) -> None:
         """Wait for other threads to complete computations for the given keys.

         If a key is neither cached nor in-flight, this method now claims ownership
@@ -407,7 +406,7 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         Args:
             keys (list[S]): Items whose computations are owned by other threads.
         """
-        rescued:
+        rescued: list[S] = []  # keys we claim to batch-process
         for x in keys:
             while True:
                 with self._lock:
@@ -431,7 +430,7 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
                     raise

     # ---- public API ------------------------------------------------------
-    def map(self, items:
+    def map(self, items: list[S], map_func: Callable[[list[S]], list[T]]) -> list[T]:
         """Map ``items`` to values using caching and optional mini-batching.

         This method is thread-safe. It deduplicates inputs while preserving order,
@@ -494,7 +493,7 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         >>> import asyncio
         >>> from typing import List
         >>> p = AsyncBatchingMapProxy[int, str](batch_size=2)
-        >>> async def af(xs:
+        >>> async def af(xs: list[int]) -> list[str]:
         ...     await asyncio.sleep(0)
         ...     return [f"v:{x}" for x in xs]
         >>> async def run():
@@ -514,9 +513,9 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
     suggester: BatchSizeSuggester = field(default_factory=BatchSizeSuggester, repr=False)

     # internals
-    _cache:
+    _cache: dict[S, T] = field(default_factory=dict, repr=False)
     _lock: asyncio.Lock = field(default_factory=asyncio.Lock, repr=False)
-    _inflight:
+    _inflight: dict[S, asyncio.Event] = field(default_factory=dict, repr=False)
     __sema: asyncio.Semaphore | None = field(default=None, init=False, repr=False)

     def __post_init__(self) -> None:
@@ -537,7 +536,7 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         else:
             self.__sema = None

-    async def __all_cached(self, items:
+    async def __all_cached(self, items: list[S]) -> bool:
         """Check whether all items are present in the cache.

         This method acquires the internal asyncio lock for a consistent view
@@ -552,7 +551,7 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         async with self._lock:
             return all(x in self._cache for x in items)

-    async def __values(self, items:
+    async def __values(self, items: list[S]) -> list[T]:
         """Get cached values for ``items`` preserving their given order.

         The internal asyncio lock is held while reading the cache to preserve
@@ -567,7 +566,7 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         async with self._lock:
             return [self._cache[x] for x in items]

-    async def __acquire_ownership(self, items:
+    async def __acquire_ownership(self, items: list[S]) -> tuple[list[S], list[S]]:
         """Acquire ownership for missing keys and identify keys to wait for.

         Args:
@@ -578,8 +577,8 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
             keys this coroutine should compute, and wait_for are keys currently
             being computed elsewhere.
         """
-        owned:
-        wait_for:
+        owned: list[S] = []
+        wait_for: list[S] = []
         async with self._lock:
             for x in items:
                 if x in self._cache:
@@ -591,7 +590,7 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
                 owned.append(x)
         return owned, wait_for

-    async def __finalize_success(self, to_call:
+    async def __finalize_success(self, to_call: list[S], results: list[T]) -> None:
         """Populate cache and signal completion for successfully computed keys.

         Args:
@@ -609,7 +608,7 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
                 if ev:
                     ev.set()

-    async def __finalize_failure(self, to_call:
+    async def __finalize_failure(self, to_call: list[S]) -> None:
         """Release in-flight events on failure to avoid deadlocks.

         Args:
@@ -640,7 +639,7 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         """Alias for clear()."""
         await self.clear()

-    async def __process_owned(self, owned:
+    async def __process_owned(self, owned: list[S], map_func: Callable[[list[S]], Awaitable[list[T]]]) -> None:
         """Process owned keys using Producer-Consumer pattern with dynamic batch sizing.

         Args:
@@ -681,7 +680,7 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         self._close_progress_bar(progress_bar)

     async def __process_single_batch(
-        self, to_call:
+        self, to_call: list[S], map_func: Callable[[list[S]], Awaitable[list[T]]], progress_bar
     ) -> None:
         """Process a single batch with semaphore control."""
         acquired = False
@@ -703,7 +702,7 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         # Update progress bar
         self._update_progress_bar(progress_bar, len(to_call))

-    async def __wait_for(self, keys:
+    async def __wait_for(self, keys: list[S], map_func: Callable[[list[S]], Awaitable[list[T]]]) -> None:
         """Wait for computations owned by other coroutines to complete.

         If a key is neither cached nor in-flight, this method now claims ownership
@@ -715,7 +714,7 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
         Args:
             keys (list[S]): Items whose computations are owned by other coroutines.
         """
-        rescued:
+        rescued: list[S] = []  # keys we claim to batch-process
         for x in keys:
             while True:
                 async with self._lock:
@@ -738,7 +737,7 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
                     raise

     # ---- public API ------------------------------------------------------
-    async def map(self, items:
+    async def map(self, items: list[S], map_func: Callable[[list[S]], Awaitable[list[T]]]) -> list[T]:
         """Async map with caching, de-duplication, and optional mini-batching.

         Args:
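The docstring examples in the hunks above show the intended usage of the two proxies. As a quick reference, here is a small self-contained sketch assembled from those doctests (not taken verbatim from the package) showing how the sync and async proxies are driven after the move to builtin generics; the import path `openaivec._proxy` is the private module shown in this diff.

```python
import asyncio

from openaivec._proxy import AsyncBatchingMapProxy, BatchingMapProxy


def f(xs: list[int]) -> list[str]:
    # Called once per mini-batch of unique, uncached items.
    return [f"v:{x}" for x in xs]


async def af(xs: list[int]) -> list[str]:
    await asyncio.sleep(0)  # stand-in for an async API call
    return [f"v:{x}" for x in xs]


# Duplicates are de-duplicated and results come back in input order.
p = BatchingMapProxy[int, str](batch_size=3)
print(p.map([1, 2, 2, 3, 4], f))  # ['v:1', 'v:2', 'v:2', 'v:3', 'v:4']

ap = AsyncBatchingMapProxy[int, str](batch_size=2)
print(asyncio.run(ap.map([1, 2, 3], af)))  # ['v:1', 'v:2', 'v:3']
```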
openaivec/_responses.py
CHANGED
@@ -1,7 +1,7 @@
 import warnings
 from dataclasses import dataclass, field
 from logging import Logger, getLogger
-from typing import Any, Generic,
+from typing import Any, Generic, cast

 from openai import AsyncOpenAI, BadRequestError, InternalServerError, OpenAI, RateLimitError
 from openai.types.responses import ParsedResponse
@@ -120,11 +120,11 @@ class Message(BaseModel, Generic[ResponseFormat]):


 class Request(BaseModel):
-    user_messages:
+    user_messages: list[Message[str]]


 class Response(BaseModel, Generic[ResponseFormat]):
-    assistant_messages:
+    assistant_messages: list[Message[ResponseFormat]]


 @dataclass(frozen=True)
@@ -150,7 +150,7 @@ class BatchResponses(Generic[ResponseFormat]):
         system_message (str): System prompt prepended to every request.
         temperature (float): Sampling temperature.
         top_p (float): Nucleus‑sampling parameter.
-        response_format (
+        response_format (type[ResponseFormat]): Expected Pydantic model class or ``str`` for each assistant message.
         cache (BatchingMapProxy[str, ResponseFormat]): Order‑preserving batching proxy with de‑duplication and caching.

     Notes:
@@ -165,7 +165,7 @@ class BatchResponses(Generic[ResponseFormat]):
     system_message: str
     temperature: float | None = None
     top_p: float = 1.0
-    response_format:
+    response_format: type[ResponseFormat] = str  # type: ignore[assignment]
     cache: BatchingMapProxy[str, ResponseFormat] = field(default_factory=lambda: BatchingMapProxy(batch_size=None))
     _vectorized_system_message: str = field(init=False)
     _model_json_schema: dict = field(init=False)
@@ -178,7 +178,7 @@ class BatchResponses(Generic[ResponseFormat]):
         system_message: str,
         temperature: float | None = 0.0,
         top_p: float = 1.0,
-        response_format:
+        response_format: type[ResponseFormat] = str,
         batch_size: int | None = None,
     ) -> "BatchResponses":
         """Factory constructor.
@@ -189,7 +189,7 @@ class BatchResponses(Generic[ResponseFormat]):
             system_message (str): System prompt for the model.
             temperature (float, optional): Sampling temperature. Defaults to 0.0.
             top_p (float, optional): Nucleus sampling parameter. Defaults to 1.0.
-            response_format (
+            response_format (type[ResponseFormat], optional): Expected output type. Defaults to ``str``.
             batch_size (int | None, optional): Max unique prompts per API call. Defaults to None
                 (automatic batch size optimization). Set to a positive integer for fixed batch size.

@@ -242,12 +242,12 @@ class BatchResponses(Generic[ResponseFormat]):
     @observe(_LOGGER)
     @backoff(exceptions=[RateLimitError, InternalServerError], scale=1, max_retries=12)
     def _request_llm(
-        self, user_messages:
+        self, user_messages: list[Message[str]], **extra_api_params: Any
     ) -> ParsedResponse[Response[ResponseFormat]]:
         """Make a single call to the OpenAI JSON‑mode endpoint.

         Args:
-            user_messages (
+            user_messages (list[Message[str]]): Sequence of ``Message[str]`` representing the
                 prompts for this minibatch. Each message carries a unique `id`
                 so we can restore ordering later.

@@ -265,7 +265,7 @@ class BatchResponses(Generic[ResponseFormat]):
             body: response_format  # type: ignore

         class ResponseT(BaseModel):
-            assistant_messages:
+            assistant_messages: list[MessageT]

         # Build base API parameters (cannot be overridden by caller)
         api_params: dict[str, Any] = {
@@ -300,7 +300,7 @@ class BatchResponses(Generic[ResponseFormat]):
         return cast(ParsedResponse[Response[ResponseFormat]], completion)

     @observe(_LOGGER)
-    def _predict_chunk(self, user_messages:
+    def _predict_chunk(self, user_messages: list[str], **api_kwargs: Any) -> list[ResponseFormat | None]:
         """Helper executed for every unique minibatch.

         This method:
@@ -316,11 +316,11 @@ class BatchResponses(Generic[ResponseFormat]):
         if not responses.output_parsed:
             return [None] * len(messages)
         response_dict = {message.id: message.body for message in responses.output_parsed.assistant_messages}
-        sorted_responses:
+        sorted_responses: list[ResponseFormat | None] = [response_dict.get(m.id, None) for m in messages]
         return sorted_responses

     @observe(_LOGGER)
-    def parse(self, inputs:
+    def parse(self, inputs: list[str], **api_kwargs: Any) -> list[ResponseFormat | None]:
         """Batched predict.

         Accepts arbitrary keyword arguments that are forwarded to the underlying
@@ -329,16 +329,16 @@ class BatchResponses(Generic[ResponseFormat]):
         configured values but can be overridden explicitly.

         Args:
-            inputs (
+            inputs (list[str]): Prompts that require responses. Duplicates are de‑duplicated.
             **api_kwargs: Extra keyword args forwarded to the OpenAI Responses API.

         Returns:
-
+            list[ResponseFormat | None]: Assistant responses aligned to ``inputs``.
         """
         if not api_kwargs:
             return self.cache.map(inputs, self._predict_chunk)  # type: ignore[return-value]

-        def _predict_with(xs:
+        def _predict_with(xs: list[str]) -> list[ResponseFormat | None]:
             return self._predict_chunk(xs, **api_kwargs)

         return self.cache.map(inputs, _predict_with)  # type: ignore[return-value]
@@ -385,7 +385,7 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
         system_message (str): System prompt prepended to every request.
         temperature (float): Sampling temperature.
         top_p (float): Nucleus‑sampling parameter.
-        response_format (
+        response_format (type[ResponseFormat]): Expected Pydantic model class or ``str`` for each assistant message.
         cache (AsyncBatchingMapProxy[str, ResponseFormat]): Async batching proxy with de‑duplication
             and concurrency control.
     """
@@ -395,7 +395,7 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
     system_message: str
     temperature: float | None = 0.0
     top_p: float = 1.0
-    response_format:
+    response_format: type[ResponseFormat] = str  # type: ignore[assignment]
     cache: AsyncBatchingMapProxy[str, ResponseFormat] = field(
         default_factory=lambda: AsyncBatchingMapProxy(batch_size=None, max_concurrency=8)
     )
@@ -410,7 +410,7 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
         system_message: str,
         temperature: float | None = None,
         top_p: float = 1.0,
-        response_format:
+        response_format: type[ResponseFormat] = str,
         batch_size: int | None = None,
         max_concurrency: int = 8,
     ) -> "AsyncBatchResponses":
@@ -422,7 +422,7 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
             system_message (str): System prompt.
             temperature (float, optional): Sampling temperature. Defaults to 0.0.
             top_p (float, optional): Nucleus sampling parameter. Defaults to 1.0.
-            response_format (
+            response_format (type[ResponseFormat], optional): Expected output type. Defaults to ``str``.
             batch_size (int | None, optional): Max unique prompts per API call. Defaults to None
                 (automatic batch size optimization). Set to a positive integer for fixed batch size.
             max_concurrency (int, optional): Max concurrent API calls. Defaults to 8.
@@ -482,12 +482,12 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
     @backoff_async(exceptions=[RateLimitError, InternalServerError], scale=1, max_retries=12)
     @observe(_LOGGER)
     async def _request_llm(
-        self, user_messages:
+        self, user_messages: list[Message[str]], **extra_api_params: Any
     ) -> ParsedResponse[Response[ResponseFormat]]:
         """Make a single async call to the OpenAI JSON‑mode endpoint.

         Args:
-            user_messages (
+            user_messages (list[Message[str]]): Sequence of ``Message[str]`` representing the minibatch prompts.

         Returns:
             ParsedResponse[Response[ResponseFormat]]: Parsed response with assistant messages (arbitrary order).
@@ -502,7 +502,7 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
             body: response_format  # type: ignore

         class ResponseT(BaseModel):
-            assistant_messages:
+            assistant_messages: list[MessageT]

         # Build base API parameters (cannot be overridden by caller)
         api_params: dict[str, Any] = {
@@ -537,7 +537,7 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
         return cast(ParsedResponse[Response[ResponseFormat]], completion)

     @observe(_LOGGER)
-    async def _predict_chunk(self, user_messages:
+    async def _predict_chunk(self, user_messages: list[str], **api_kwargs: Any) -> list[ResponseFormat | None]:
         """Async helper executed for every unique minibatch.

         This method:
@@ -553,11 +553,11 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
             return [None] * len(messages)
         response_dict = {message.id: message.body for message in responses.output_parsed.assistant_messages}
         # Ensure proper handling for missing IDs - this shouldn't happen in normal operation
-        sorted_responses:
+        sorted_responses: list[ResponseFormat | None] = [response_dict.get(m.id, None) for m in messages]
         return sorted_responses

     @observe(_LOGGER)
-    async def parse(self, inputs:
+    async def parse(self, inputs: list[str], **api_kwargs: Any) -> list[ResponseFormat | None]:
         """Batched predict (async).

         Accepts arbitrary keyword arguments forwarded to ``AsyncOpenAI.responses.parse``.
@@ -566,16 +566,16 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
         changing the public surface again.

         Args:
-            inputs (
+            inputs (list[str]): Prompts that require responses. Duplicates are de‑duplicated.
             **api_kwargs: Extra keyword args for the OpenAI Responses API.

         Returns:
-
+            list[ResponseFormat | None]: Assistant responses aligned to ``inputs``.
         """
         if not api_kwargs:
             return await self.cache.map(inputs, self._predict_chunk)  # type: ignore[return-value]

-        async def _predict_with(xs:
+        async def _predict_with(xs: list[str]) -> list[ResponseFormat | None]:
             return await self._predict_chunk(xs, **api_kwargs)

         return await self.cache.map(inputs, _predict_with)  # type: ignore[return-value]
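For context on how these annotation changes surface to callers, here is a hedged sketch of driving `BatchResponses` with a Pydantic `response_format`. Only `system_message`, `temperature`, `top_p`, `response_format`, `batch_size`, and `parse(inputs, **api_kwargs)` are visible in the hunks above; the factory name `of`, the `client` and `model_name` parameters, and the model string are assumptions for illustration.

```python
from openai import OpenAI
from pydantic import BaseModel

from openaivec._responses import BatchResponses


class Sentiment(BaseModel):
    label: str
    score: float


# Hypothetical factory call: `of`, `client`, and `model_name` are assumed,
# not shown in the hunks above.
responses = BatchResponses.of(
    client=OpenAI(),
    model_name="gpt-4.1-mini",      # placeholder model name
    system_message="Classify the sentiment of each review.",
    response_format=Sentiment,      # type[ResponseFormat]; defaults to str
    batch_size=None,                # None enables automatic batch-size optimization
)

# Duplicates are de-duplicated before the API call; results align with the inputs.
results = responses.parse(["Great product!", "Never again.", "Great product!"])
```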