openaivec 0.13.4__py3-none-any.whl → 0.13.6__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
openaivec/proxy.py CHANGED
@@ -2,7 +2,7 @@ import asyncio
  import threading
  from collections.abc import Hashable
  from dataclasses import dataclass, field
- from typing import Awaitable, Callable, Dict, Generic, List, Optional, TypeVar
+ from typing import Any, Awaitable, Callable, Dict, Generic, List, TypeVar
 
  from openaivec.optimize import BatchSizeSuggester
 
@@ -22,9 +22,9 @@ class ProxyBase(Generic[S, T]):
  should process the entire input in a single call.
  """
 
- batch_size: Optional[int] = None # subclasses may override via dataclass
- show_progress: bool = False # Enable progress bar display
- suggester: BatchSizeSuggester = None # Batch size optimization, initialized by subclasses
+ batch_size: int | None # subclasses may override via dataclass
+ show_progress: bool # Enable progress bar display
+ suggester: BatchSizeSuggester # Batch size optimization, initialized by subclasses
 
  def _is_notebook_environment(self) -> bool:
  """Check if running in a Jupyter notebook environment.
@@ -33,7 +33,7 @@ class ProxyBase(Generic[S, T]):
  bool: True if running in a notebook, False otherwise.
  """
  try:
- from IPython import get_ipython
+ from IPython.core.getipython import get_ipython
 
  ipython = get_ipython()
  if ipython is not None:
@@ -89,7 +89,7 @@ class ProxyBase(Generic[S, T]):
 
  return False
 
- def _create_progress_bar(self, total: int, desc: str = "Processing batches") -> Optional[object]:
+ def _create_progress_bar(self, total: int, desc: str = "Processing batches") -> Any:
  """Create a progress bar if conditions are met.
 
  Args:
@@ -97,7 +97,7 @@ class ProxyBase(Generic[S, T]):
  desc (str): Description for the progress bar.
 
  Returns:
- Optional[object]: Progress bar instance or None if not available.
+ Any: Progress bar instance or None if not available.
  """
  try:
  from tqdm.auto import tqdm as tqdm_progress
@@ -108,21 +108,21 @@ class ProxyBase(Generic[S, T]):
  pass
  return None
 
- def _update_progress_bar(self, progress_bar: Optional[object], increment: int) -> None:
+ def _update_progress_bar(self, progress_bar: Any, increment: int) -> None:
  """Update progress bar with the given increment.
 
  Args:
- progress_bar (Optional[object]): Progress bar instance.
+ progress_bar (Any): Progress bar instance.
  increment (int): Number of items to increment.
  """
  if progress_bar:
  progress_bar.update(increment)
 
- def _close_progress_bar(self, progress_bar: Optional[object]) -> None:
+ def _close_progress_bar(self, progress_bar: Any) -> None:
  """Close the progress bar.
 
  Args:
- progress_bar (Optional[object]): Progress bar instance.
+ progress_bar (Any): Progress bar instance.
  """
  if progress_bar:
  progress_bar.close()
@@ -179,6 +179,10 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
  not duplicate work via an in-flight registry. All public behavior is preserved
  while minimizing redundant requests and maintaining input order in the output.
 
+ When ``batch_size=None``, automatic batch size optimization is enabled,
+ dynamically adjusting batch sizes based on execution time to maintain optimal
+ performance (targeting 30-60 seconds per batch).
+
  Example:
  >>> from typing import List
  >>> p = BatchingMapProxy[int, str](batch_size=3)
@@ -188,15 +192,19 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
  ['v:1', 'v:2', 'v:2', 'v:3', 'v:4']
  """
 
- # Number of items to process per call to map_func. If None or <= 0, process all at once.
- batch_size: Optional[int] = None
+ # Number of items to process per call to map_func.
+ # - If None (default): Enables automatic batch size optimization, dynamically adjusting
+ # based on execution time (targeting 30-60 seconds per batch)
+ # - If positive integer: Fixed batch size
+ # - If <= 0: Process all items at once
+ batch_size: int | None = None
  show_progress: bool = False
  suggester: BatchSizeSuggester = field(default_factory=BatchSizeSuggester, repr=False)
 
  # internals
- __cache: Dict[S, T] = field(default_factory=dict)
- __lock: threading.RLock = field(default_factory=threading.RLock, repr=False)
- __inflight: Dict[S, threading.Event] = field(default_factory=dict, repr=False)
+ _cache: Dict[S, T] = field(default_factory=dict)
+ _lock: threading.RLock = field(default_factory=threading.RLock, repr=False)
+ _inflight: Dict[S, threading.Event] = field(default_factory=dict, repr=False)
 
  def __all_cached(self, items: List[S]) -> bool:
  """Check whether all items are present in the cache.
@@ -209,8 +217,8 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
  Returns:
  bool: True if every item is already cached, False otherwise.
  """
- with self.__lock:
- return all(x in self.__cache for x in items)
+ with self._lock:
+ return all(x in self._cache for x in items)
 
  def __values(self, items: List[S]) -> List[T]:
  """Fetch cached values for ``items`` preserving the given order.
@@ -224,8 +232,8 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
  list[T]: The cached values corresponding to ``items`` in the same
  order.
  """
- with self.__lock:
- return [self.__cache[x] for x in items]
+ with self._lock:
+ return [self._cache[x] for x in items]
 
  def __acquire_ownership(self, items: List[S]) -> tuple[List[S], List[S]]:
  """Acquire ownership for missing items and identify keys to wait for.
@@ -245,14 +253,14 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
  """
  owned: List[S] = []
  wait_for: List[S] = []
- with self.__lock:
+ with self._lock:
  for x in items:
- if x in self.__cache:
+ if x in self._cache:
  continue
- if x in self.__inflight:
+ if x in self._inflight:
  wait_for.append(x)
  else:
- self.__inflight[x] = threading.Event()
+ self._inflight[x] = threading.Event()
  owned.append(x)
  return owned, wait_for
 
@@ -268,10 +276,10 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
  # Release waiters and surface a clear error.
  self.__finalize_failure(to_call)
  raise ValueError("map_func must return a list of results with the same length and order as inputs")
- with self.__lock:
+ with self._lock:
  for x, y in zip(to_call, results):
- self.__cache[x] = y
- ev = self.__inflight.pop(x, None)
+ self._cache[x] = y
+ ev = self._inflight.pop(x, None)
  if ev:
  ev.set()
 
@@ -282,9 +290,9 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
  to_call (list[S]): Items that were intended to be computed when an
  error occurred.
  """
- with self.__lock:
+ with self._lock:
  for x in to_call:
- ev = self.__inflight.pop(x, None)
+ ev = self._inflight.pop(x, None)
  if ev:
  ev.set()
 
@@ -296,11 +304,11 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
  - Do not call concurrently with active map() calls to avoid
  unnecessary recomputation or racy wake-ups.
  """
- with self.__lock:
- for ev in self.__inflight.values():
+ with self._lock:
+ for ev in self._inflight.values():
  ev.set()
- self.__inflight.clear()
- self.__cache.clear()
+ self._inflight.clear()
+ self._cache.clear()
 
  def close(self) -> None:
  """Alias for clear()."""
@@ -337,8 +345,8 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
  current_batch_size = self._normalized_batch_size(len(owned))
  batch = owned[i : i + current_batch_size]
  # Double-check cache right before processing
- with self.__lock:
- uncached_in_batch = [x for x in batch if x not in self.__cache]
+ with self._lock:
+ uncached_in_batch = [x for x in batch if x not in self._cache]
 
  pending_to_call.extend(uncached_in_batch)
 
@@ -400,13 +408,13 @@ class BatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
  rescued: List[S] = [] # keys we claim to batch-process
  for x in keys:
  while True:
- with self.__lock:
- if x in self.__cache:
+ with self._lock:
+ if x in self._cache:
  break
- ev = self.__inflight.get(x)
+ ev = self._inflight.get(x)
  if ev is None:
  # Not cached and no one computing; claim ownership to batch later.
- self.__inflight[x] = threading.Event()
+ self._inflight[x] = threading.Event()
  rescued.append(x)
  break
  # Someone else is computing; wait for completion.
@@ -463,6 +471,10 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
  coordinates concurrent coroutines to avoid duplicate work via an in-flight
  registry of asyncio events.
 
+ When ``batch_size=None``, automatic batch size optimization is enabled,
+ dynamically adjusting batch sizes based on execution time to maintain optimal
+ performance (targeting 30-60 seconds per batch).
+
  Example:
  >>> import asyncio
  >>> from typing import List
@@ -476,16 +488,21 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
  ['v:1', 'v:2', 'v:3']
  """
 
- batch_size: Optional[int] = None
+ # Number of items to process per call to map_func.
+ # - If None (default): Enables automatic batch size optimization, dynamically adjusting
+ # based on execution time (targeting 30-60 seconds per batch)
+ # - If positive integer: Fixed batch size
+ # - If <= 0: Process all items at once
+ batch_size: int | None = None
  max_concurrency: int = 8
  show_progress: bool = False
  suggester: BatchSizeSuggester = field(default_factory=BatchSizeSuggester, repr=False)
 
  # internals
- __cache: Dict[S, T] = field(default_factory=dict, repr=False)
- __lock: asyncio.Lock = field(default_factory=asyncio.Lock, repr=False)
- __inflight: Dict[S, asyncio.Event] = field(default_factory=dict, repr=False)
- __sema: Optional[asyncio.Semaphore] = field(default=None, init=False, repr=False)
+ _cache: Dict[S, T] = field(default_factory=dict, repr=False)
+ _lock: asyncio.Lock = field(default_factory=asyncio.Lock, repr=False)
+ _inflight: Dict[S, asyncio.Event] = field(default_factory=dict, repr=False)
+ __sema: asyncio.Semaphore | None = field(default=None, init=False, repr=False)
 
  def __post_init__(self) -> None:
  """Initialize internal semaphore based on ``max_concurrency``.
@@ -517,8 +534,8 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
  Returns:
  bool: True if every item in ``items`` is already cached, False otherwise.
  """
- async with self.__lock:
- return all(x in self.__cache for x in items)
+ async with self._lock:
+ return all(x in self._cache for x in items)
 
  async def __values(self, items: List[S]) -> List[T]:
  """Get cached values for ``items`` preserving their given order.
@@ -532,8 +549,8 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
  Returns:
  list[T]: Cached values corresponding to ``items`` in the same order.
  """
- async with self.__lock:
- return [self.__cache[x] for x in items]
+ async with self._lock:
+ return [self._cache[x] for x in items]
 
  async def __acquire_ownership(self, items: List[S]) -> tuple[List[S], List[S]]:
  """Acquire ownership for missing keys and identify keys to wait for.
@@ -548,14 +565,14 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
  """
  owned: List[S] = []
  wait_for: List[S] = []
- async with self.__lock:
+ async with self._lock:
  for x in items:
- if x in self.__cache:
+ if x in self._cache:
  continue
- if x in self.__inflight:
+ if x in self._inflight:
  wait_for.append(x)
  else:
- self.__inflight[x] = asyncio.Event()
+ self._inflight[x] = asyncio.Event()
  owned.append(x)
  return owned, wait_for
 
@@ -570,10 +587,10 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
  # Prevent deadlocks if map_func violates the contract.
  await self.__finalize_failure(to_call)
  raise ValueError("map_func must return a list of results with the same length and order as inputs")
- async with self.__lock:
+ async with self._lock:
  for x, y in zip(to_call, results):
- self.__cache[x] = y
- ev = self.__inflight.pop(x, None)
+ self._cache[x] = y
+ ev = self._inflight.pop(x, None)
  if ev:
  ev.set()
 
@@ -584,9 +601,9 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
  to_call (list[S]): Items whose computation failed; their waiters will
  be released.
  """
- async with self.__lock:
+ async with self._lock:
  for x in to_call:
- ev = self.__inflight.pop(x, None)
+ ev = self._inflight.pop(x, None)
  if ev:
  ev.set()
 
@@ -598,11 +615,11 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
  - Do not call concurrently with active map() calls to avoid
  unnecessary recomputation or racy wake-ups.
  """
- async with self.__lock:
- for ev in self.__inflight.values():
+ async with self._lock:
+ for ev in self._inflight.values():
  ev.set()
- self.__inflight.clear()
- self.__cache.clear()
+ self._inflight.clear()
+ self._cache.clear()
 
  async def aclose(self) -> None:
  """Alias for clear()."""
@@ -686,13 +703,13 @@ class AsyncBatchingMapProxy(ProxyBase[S, T], Generic[S, T]):
  rescued: List[S] = [] # keys we claim to batch-process
  for x in keys:
  while True:
- async with self.__lock:
- if x in self.__cache:
+ async with self._lock:
+ if x in self._cache:
  break
- ev = self.__inflight.get(x)
+ ev = self._inflight.get(x)
  if ev is None:
  # Not cached and no one computing; claim ownership to batch later.
- self.__inflight[x] = asyncio.Event()
+ self._inflight[x] = asyncio.Event()
  rescued.append(x)
  break
  # Someone else is computing; wait for completion.
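Note on the proxy.py changes: ``batch_size=None`` (now the default) no longer means "process everything at once"; it hands sizing to ``BatchSizeSuggester``, while ``<= 0`` keeps the all-at-once behavior. A minimal sketch of driving the updated ``BatchingMapProxy``, based only on what is visible in this diff (the ``map(items, map_func)`` call shape is inferred from ``cache.map(inputs, self._predict_chunk)`` in responses.py; ``map_func`` here is a hypothetical stand-in):

    from openaivec.proxy import BatchingMapProxy

    # Hypothetical map_func: any callable that maps a batch of inputs to results
    # of the same length and order (the contract the proxy enforces).
    def map_func(xs: list[int]) -> list[str]:
        return [f"v:{x}" for x in xs]

    # batch_size=None (the new default) enables suggester-driven sizing, targeting
    # roughly 30-60 seconds per underlying call; a positive integer pins a fixed
    # size, and <= 0 processes all items in a single call.
    proxy = BatchingMapProxy[int, str](batch_size=None)
    print(proxy.map([1, 2, 2, 3, 4], map_func))
    # expected, per the class docstring: ['v:1', 'v:2', 'v:2', 'v:3', 'v:4']

The duplicate ``2`` is computed once and served from the cache, which is the deduplication behavior the class docstring describes.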
openaivec/responses.py CHANGED
@@ -165,8 +165,8 @@ class BatchResponses(Generic[ResponseFormat]):
  system_message: str
  temperature: float | None = 0.0
  top_p: float = 1.0
- response_format: Type[ResponseFormat] = str
- cache: BatchingMapProxy[str, ResponseFormat] = field(default_factory=lambda: BatchingMapProxy(batch_size=128))
+ response_format: Type[ResponseFormat] = str # type: ignore[assignment]
+ cache: BatchingMapProxy[str, ResponseFormat] = field(default_factory=lambda: BatchingMapProxy(batch_size=None))
  _vectorized_system_message: str = field(init=False)
  _model_json_schema: dict = field(init=False)
 
@@ -179,7 +179,7 @@ class BatchResponses(Generic[ResponseFormat]):
  temperature: float | None = 0.0,
  top_p: float = 1.0,
  response_format: Type[ResponseFormat] = str,
- batch_size: int = 128,
+ batch_size: int | None = None,
  ) -> "BatchResponses":
  """Factory constructor.
 
@@ -190,7 +190,8 @@ class BatchResponses(Generic[ResponseFormat]):
  temperature (float, optional): Sampling temperature. Defaults to 0.0.
  top_p (float, optional): Nucleus sampling parameter. Defaults to 1.0.
  response_format (Type[ResponseFormat], optional): Expected output type. Defaults to ``str``.
- batch_size (int, optional): Max unique prompts per API call. Defaults to 128.
+ batch_size (int | None, optional): Max unique prompts per API call. Defaults to None
+ (automatic batch size optimization). Set to a positive integer for fixed batch size.
 
  Returns:
  BatchResponses: Configured instance backed by a batching proxy.
@@ -206,14 +207,17 @@ class BatchResponses(Generic[ResponseFormat]):
  )
 
  @classmethod
- def of_task(cls, client: OpenAI, model_name: str, task: PreparedTask, batch_size: int = 128) -> "BatchResponses":
+ def of_task(
+ cls, client: OpenAI, model_name: str, task: PreparedTask[ResponseFormat], batch_size: int | None = None
+ ) -> "BatchResponses":
  """Factory from a PreparedTask.
 
  Args:
  client (OpenAI): OpenAI client.
  model_name (str): For Azure OpenAI, use your deployment name. For OpenAI, use the model name.
  task (PreparedTask): Prepared task with instructions and response format.
- batch_size (int, optional): Max unique prompts per API call. Defaults to 128.
+ batch_size (int | None, optional): Max unique prompts per API call. Defaults to None
+ (automatic batch size optimization). Set to a positive integer for fixed batch size.
 
  Returns:
  BatchResponses: Configured instance backed by a batching proxy.
@@ -294,8 +298,10 @@ class BatchResponses(Generic[ResponseFormat]):
  """
  messages = [Message(id=i, body=message) for i, message in enumerate(user_messages)]
  responses: ParsedResponse[Response[ResponseFormat]] = self._request_llm(messages)
+ if not responses.output_parsed:
+ return [None] * len(messages)
  response_dict = {message.id: message.body for message in responses.output_parsed.assistant_messages}
- sorted_responses = [response_dict.get(m.id, None) for m in messages]
+ sorted_responses: List[ResponseFormat | None] = [response_dict.get(m.id, None) for m in messages]
  return sorted_responses
 
  @observe(_LOGGER)
@@ -308,7 +314,8 @@ class BatchResponses(Generic[ResponseFormat]):
  Returns:
  List[ResponseFormat | None]: Assistant responses aligned to ``inputs``.
  """
- return self.cache.map(inputs, self._predict_chunk)
+ result = self.cache.map(inputs, self._predict_chunk)
+ return result # type: ignore[return-value]
 
 
  @dataclass(frozen=True)
@@ -362,9 +369,9 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
  system_message: str
  temperature: float | None = 0.0
  top_p: float = 1.0
- response_format: Type[ResponseFormat] = str
+ response_format: Type[ResponseFormat] = str # type: ignore[assignment]
  cache: AsyncBatchingMapProxy[str, ResponseFormat] = field(
- default_factory=lambda: AsyncBatchingMapProxy(batch_size=128, max_concurrency=8)
+ default_factory=lambda: AsyncBatchingMapProxy(batch_size=None, max_concurrency=8)
  )
  _vectorized_system_message: str = field(init=False)
  _model_json_schema: dict = field(init=False)
@@ -378,7 +385,7 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
  temperature: float | None = 0.0,
  top_p: float = 1.0,
  response_format: Type[ResponseFormat] = str,
- batch_size: int = 128,
+ batch_size: int | None = None,
  max_concurrency: int = 8,
  ) -> "AsyncBatchResponses":
  """Factory constructor.
@@ -390,7 +397,8 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
  temperature (float, optional): Sampling temperature. Defaults to 0.0.
  top_p (float, optional): Nucleus sampling parameter. Defaults to 1.0.
  response_format (Type[ResponseFormat], optional): Expected output type. Defaults to ``str``.
- batch_size (int, optional): Max unique prompts per API call. Defaults to 128.
+ batch_size (int | None, optional): Max unique prompts per API call. Defaults to None
+ (automatic batch size optimization). Set to a positive integer for fixed batch size.
  max_concurrency (int, optional): Max concurrent API calls. Defaults to 8.
 
  Returns:
@@ -408,7 +416,12 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
 
  @classmethod
  def of_task(
- cls, client: AsyncOpenAI, model_name: str, task: PreparedTask, batch_size: int = 128, max_concurrency: int = 8
+ cls,
+ client: AsyncOpenAI,
+ model_name: str,
+ task: PreparedTask[ResponseFormat],
+ batch_size: int | None = None,
+ max_concurrency: int = 8,
  ) -> "AsyncBatchResponses":
  """Factory from a PreparedTask.
 
@@ -416,7 +429,8 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
  client (AsyncOpenAI): OpenAI async client.
  model_name (str): For Azure OpenAI, use your deployment name. For OpenAI, use the model name.
  task (PreparedTask): Prepared task with instructions and response format.
- batch_size (int, optional): Max unique prompts per API call. Defaults to 128.
+ batch_size (int | None, optional): Max unique prompts per API call. Defaults to None
+ (automatic batch size optimization). Set to a positive integer for fixed batch size.
  max_concurrency (int, optional): Max concurrent API calls. Defaults to 8.
 
  Returns:
@@ -439,8 +453,8 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
  _vectorize_system_message(self.system_message),
  )
 
- @observe(_LOGGER)
  @backoff_async(exceptions=[RateLimitError, InternalServerError], scale=1, max_retries=12)
+ @observe(_LOGGER)
  async def _request_llm(self, user_messages: List[Message[str]]) -> ParsedResponse[Response[ResponseFormat]]:
  """Make a single async call to the OpenAI JSON‑mode endpoint.
 
@@ -493,10 +507,12 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
  The function is pure – it has no side‑effects and the result depends only on its arguments.
  """
  messages = [Message(id=i, body=message) for i, message in enumerate(user_messages)]
- responses: ParsedResponse[Response[ResponseFormat]] = await self._request_llm(messages)
+ responses: ParsedResponse[Response[ResponseFormat]] = await self._request_llm(messages) # type: ignore[call-issue]
+ if not responses.output_parsed:
+ return [None] * len(messages)
  response_dict = {message.id: message.body for message in responses.output_parsed.assistant_messages}
  # Ensure proper handling for missing IDs - this shouldn't happen in normal operation
- sorted_responses = [response_dict.get(m.id, None) for m in messages]
+ sorted_responses: List[ResponseFormat | None] = [response_dict.get(m.id, None) for m in messages]
  return sorted_responses
 
  @observe(_LOGGER)
@@ -509,4 +525,5 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
  Returns:
  List[ResponseFormat | None]: Assistant responses aligned to ``inputs``.
  """
- return await self.cache.map(inputs, self._predict_chunk)
+ result = await self.cache.map(inputs, self._predict_chunk)
+ return result # type: ignore[return-value]
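Downstream, both ``BatchResponses`` and ``AsyncBatchResponses`` now build their internal proxies with ``batch_size=None``, so callers get automatic batch sizing unless they pin a value. A hedged sketch of the two options through the ``create`` factory shown above (keyword usage and the model name are assumptions; the public prediction method is not part of this hunk and is omitted):

    from openai import OpenAI
    from openaivec.responses import BatchResponses

    client = OpenAI()

    # New default: batch_size=None lets BatchSizeSuggester pick batch sizes dynamically.
    auto = BatchResponses.create(
        client=client,
        model_name="gpt-4.1-mini",  # assumption: any model/deployment name accepted by create()
        system_message="Answer concisely.",
    )

    # Previous 0.13.4 behavior: pin the fixed batch size that used to be the default.
    fixed = BatchResponses.create(
        client=client,
        model_name="gpt-4.1-mini",
        system_message="Answer concisely.",
        batch_size=128,
    )

    # AsyncBatchResponses.create mirrors this signature and adds max_concurrency (default 8).

The ``of_task`` factories change in the same way, so any caller relying on the old fixed default of 128 must now pass it explicitly.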