openaivec 0.14.12__py3-none-any.whl → 0.14.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openaivec/_embeddings.py +17 -4
- openaivec/_model.py +7 -12
- openaivec/_prompt.py +3 -6
- openaivec/_responses.py +39 -117
- openaivec/_schema.py +27 -23
- openaivec/pandas_ext.py +355 -343
- openaivec/spark.py +98 -56
- openaivec/task/__init__.py +1 -1
- openaivec/task/customer_support/customer_sentiment.py +4 -9
- openaivec/task/customer_support/inquiry_classification.py +5 -8
- openaivec/task/customer_support/inquiry_summary.py +5 -6
- openaivec/task/customer_support/intent_analysis.py +5 -7
- openaivec/task/customer_support/response_suggestion.py +5 -8
- openaivec/task/customer_support/urgency_analysis.py +5 -8
- openaivec/task/nlp/dependency_parsing.py +1 -2
- openaivec/task/nlp/keyword_extraction.py +1 -2
- openaivec/task/nlp/morphological_analysis.py +1 -2
- openaivec/task/nlp/named_entity_recognition.py +1 -2
- openaivec/task/nlp/sentiment_analysis.py +1 -2
- openaivec/task/nlp/translation.py +1 -1
- openaivec/task/table/fillna.py +8 -3
- {openaivec-0.14.12.dist-info → openaivec-0.14.14.dist-info}/METADATA +1 -1
- openaivec-0.14.14.dist-info/RECORD +37 -0
- openaivec-0.14.12.dist-info/RECORD +0 -37
- {openaivec-0.14.12.dist-info → openaivec-0.14.14.dist-info}/WHEEL +0 -0
- {openaivec-0.14.12.dist-info → openaivec-0.14.14.dist-info}/licenses/LICENSE +0 -0
openaivec/_embeddings.py
CHANGED

@@ -26,14 +26,16 @@ class BatchEmbeddings:
         model_name (str): For Azure OpenAI, use your deployment name. For OpenAI, use the model name
             (e.g., ``"text-embedding-3-small"``).
         cache (BatchingMapProxy[str, NDArray[np.float32]]): Batching proxy for ordered, cached mapping.
+        api_kwargs (dict[str, Any]): Additional OpenAI API parameters stored at initialization.
     """

     client: OpenAI
     model_name: str
     cache: BatchingMapProxy[str, NDArray[np.float32]] = field(default_factory=lambda: BatchingMapProxy(batch_size=None))
+    api_kwargs: dict[str, int | float | str | bool] = field(default_factory=dict)

     @classmethod
-    def of(cls, client: OpenAI, model_name: str, batch_size: int | None = None) -> "BatchEmbeddings":
+    def of(cls, client: OpenAI, model_name: str, batch_size: int | None = None, **api_kwargs) -> "BatchEmbeddings":
         """Factory constructor.

         Args:
@@ -41,11 +43,17 @@ class BatchEmbeddings:
             model_name (str): For Azure OpenAI, use your deployment name. For OpenAI, use the model name.
             batch_size (int | None, optional): Max unique inputs per API call. Defaults to None
                 (automatic batch size optimization). Set to a positive integer for fixed batch size.
+            **api_kwargs: Additional OpenAI API parameters (e.g., dimensions for text-embedding-3 models).

         Returns:
             BatchEmbeddings: Configured instance backed by a batching proxy.
         """
-        return cls(
+        return cls(
+            client=client,
+            model_name=model_name,
+            cache=BatchingMapProxy(batch_size=batch_size),
+            api_kwargs=api_kwargs,
+        )

     @observe(_LOGGER)
     @backoff(exceptions=[RateLimitError, InternalServerError], scale=1, max_retries=12)
@@ -62,7 +70,7 @@ class BatchEmbeddings:
         Returns:
             list[NDArray[np.float32]]: Embedding vectors aligned to ``inputs``.
         """
-        responses = self.client.embeddings.create(input=inputs, model=self.model_name)
+        responses = self.client.embeddings.create(input=inputs, model=self.model_name, **self.api_kwargs)
         return [np.array(d.embedding, dtype=np.float32) for d in responses.data]

     @observe(_LOGGER)
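The hunks above replace the old fixed keyword set with a generic `**api_kwargs` passthrough: whatever is given to `of` is stored on the dataclass and expanded into `client.embeddings.create` for every mini-batch. A minimal construction sketch, assuming the private import path `openaivec._embeddings` from this diff (the package may re-export the class elsewhere) and using `dimensions`, the example parameter named in the new docstring:

```python
from openai import OpenAI

from openaivec._embeddings import BatchEmbeddings  # import path assumed from the file path in this diff

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

# Keyword arguments given to `of` land in the new `api_kwargs` field and are
# expanded into `client.embeddings.create(...)` on every batched request.
embeddings = BatchEmbeddings.of(
    client,
    model_name="text-embedding-3-small",
    batch_size=64,
    dimensions=256,  # illustrative passthrough parameter for text-embedding-3 models
)

assert embeddings.api_kwargs == {"dimensions": 256}
```

Because the kwargs are captured once at construction, retried and cached batch calls all reuse the same parameters. The async counterpart below receives the same treatment.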
@@ -122,6 +130,7 @@ class AsyncBatchEmbeddings:
         client (AsyncOpenAI): Configured OpenAI async client.
         model_name (str): For Azure OpenAI, use your deployment name. For OpenAI, use the model name.
         cache (AsyncBatchingMapProxy[str, NDArray[np.float32]]): Async batching proxy.
+        api_kwargs (dict): Additional OpenAI API parameters stored at initialization.
     """

     client: AsyncOpenAI
@@ -129,6 +138,7 @@ class AsyncBatchEmbeddings:
     cache: AsyncBatchingMapProxy[str, NDArray[np.float32]] = field(
         default_factory=lambda: AsyncBatchingMapProxy(batch_size=None, max_concurrency=8)
     )
+    api_kwargs: dict[str, int | float | str | bool] = field(default_factory=dict)

     @classmethod
     def of(
@@ -137,6 +147,7 @@ class AsyncBatchEmbeddings:
         model_name: str,
         batch_size: int | None = None,
         max_concurrency: int = 8,
+        **api_kwargs,
     ) -> "AsyncBatchEmbeddings":
         """Factory constructor.

@@ -146,6 +157,7 @@ class AsyncBatchEmbeddings:
             batch_size (int | None, optional): Max unique inputs per API call. Defaults to None
                 (automatic batch size optimization). Set to a positive integer for fixed batch size.
             max_concurrency (int, optional): Max concurrent API calls. Defaults to 8.
+            **api_kwargs: Additional OpenAI API parameters (e.g., dimensions for text-embedding-3 models).

         Returns:
             AsyncBatchEmbeddings: Configured instance with an async batching proxy.
@@ -154,6 +166,7 @@ class AsyncBatchEmbeddings:
             client=client,
             model_name=model_name,
             cache=AsyncBatchingMapProxy(batch_size=batch_size, max_concurrency=max_concurrency),
+            api_kwargs=api_kwargs,
         )

     @backoff_async(exceptions=[RateLimitError, InternalServerError], scale=1, max_retries=12)
@@ -174,7 +187,7 @@ class AsyncBatchEmbeddings:
         Raises:
             RateLimitError: Propagated if retries are exhausted.
         """
-        responses = await self.client.embeddings.create(input=inputs, model=self.model_name)
+        responses = await self.client.embeddings.create(input=inputs, model=self.model_name, **self.api_kwargs)
         return [np.array(d.embedding, dtype=np.float32) for d in responses.data]

     @observe(_LOGGER)
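The async class mirrors the sync one, with `max_concurrency` bounding parallel requests. A sketch under the same assumptions as above:

```python
import asyncio

from openai import AsyncOpenAI

from openaivec._embeddings import AsyncBatchEmbeddings  # import path assumed from this diff


async def main() -> None:
    client = AsyncOpenAI()

    # Extra keyword arguments are stored once and expanded into every
    # concurrent `client.embeddings.create(...)` call.
    embeddings = AsyncBatchEmbeddings.of(
        client,
        model_name="text-embedding-3-small",
        batch_size=64,
        max_concurrency=4,
        dimensions=256,  # illustrative passthrough parameter
    )
    assert embeddings.api_kwargs == {"dimensions": 256}


asyncio.run(main())
```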
openaivec/_model.py
CHANGED

@@ -1,4 +1,4 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Generic, TypeVar

 __all__ = [
@@ -14,7 +14,7 @@ class PreparedTask(Generic[ResponseFormat]):

     This class encapsulates all the necessary parameters for executing a task,
     including the instructions to be sent to the model, the expected response
-    format using Pydantic models, and
+    format using Pydantic models, and API parameters for controlling
     the model's output behavior.

     Attributes:
@@ -22,12 +22,9 @@ class PreparedTask(Generic[ResponseFormat]):
             This should contain clear, specific directions for the task.
         response_format (type[ResponseFormat]): A Pydantic model class or str type that defines the expected
             structure of the response. Can be either a BaseModel subclass or str.
-
-
-            Defaults to
-        top_p (float): Controls diversity via nucleus sampling. Only tokens
-            comprising the top_p probability mass are considered.
-            Range: 0.0 to 1.0. Defaults to 1.0.
+        api_kwargs (dict[str, int | float | str | bool]): Additional OpenAI API parameters
+            such as temperature, top_p, frequency_penalty, presence_penalty, seed, etc.
+            Defaults to an empty dict.

     Example:
         Creating a custom task:
@@ -43,8 +40,7 @@ class PreparedTask(Generic[ResponseFormat]):
         custom_task = PreparedTask(
             instructions="Translate the following text to French:",
             response_format=TranslationResponse,
-            temperature
-            top_p=0.9
+            api_kwargs={"temperature": 0.1, "top_p": 0.9}
         )
         ```

@@ -55,8 +51,7 @@ class PreparedTask(Generic[ResponseFormat]):

     instructions: str
     response_format: type[ResponseFormat]
-
-    top_p: float = 1.0
+    api_kwargs: dict[str, int | float | str | bool] = field(default_factory=dict)


 @dataclass(frozen=True)
openaivec/_prompt.py
CHANGED

@@ -445,8 +445,7 @@ class FewShotPromptBuilder:
         self,
         client: OpenAI | None = None,
         model_name: str | None = None,
-
-        top_p: float | None = None,
+        **api_kwargs,
     ) -> "FewShotPromptBuilder":
         """Iteratively refine the prompt using an LLM.

@@ -460,8 +459,7 @@ class FewShotPromptBuilder:
         Args:
             client (OpenAI | None): Configured OpenAI client. If None, uses DI container with environment variables.
             model_name (str | None): Model identifier. If None, uses default ``gpt-4.1-mini``.
-
-            top_p (float | None): Nucleus sampling parameter. If None, uses model default.
+            **api_kwargs: Additional OpenAI API parameters (temperature, top_p, etc.).

         Returns:
             FewShotPromptBuilder: The current builder instance containing the refined prompt and iteration history.
@@ -479,9 +477,8 @@ class FewShotPromptBuilder:
             model=_model_name,
             instructions=_PROMPT,
             input=Request(prompt=self._prompt).model_dump_json(),
-            temperature=temperature,
-            top_p=top_p,
             text_format=Response,
+            **api_kwargs,
         )

         # keep the original prompt
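`improve()` now forwards arbitrary keyword arguments straight into the underlying `responses.parse` call instead of exposing named sampling parameters. A hedged sketch of the call site; the builder's construction is not part of this diff and is taken as given, and the model name and values are placeholders:

```python
from openai import OpenAI

from openaivec._prompt import FewShotPromptBuilder  # import path assumed from this diff


def refine(builder: FewShotPromptBuilder, client: OpenAI) -> FewShotPromptBuilder:
    # Any OpenAI Responses API parameter can be passed here; it is forwarded
    # verbatim to `client.responses.parse(...)` during each refinement step.
    return builder.improve(
        client=client,
        model_name="gpt-4.1-mini",
        temperature=0.2,
        top_p=0.9,
    )
```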
openaivec/_responses.py
CHANGED

@@ -1,7 +1,7 @@
 import warnings
 from dataclasses import dataclass, field
 from logging import Logger, getLogger
-from typing import
+from typing import Generic, cast

 from openai import AsyncOpenAI, BadRequestError, InternalServerError, OpenAI, RateLimitError
 from openai.types.responses import ParsedResponse
@@ -148,8 +148,6 @@ class BatchResponses(Generic[ResponseFormat]):
         client (OpenAI): Initialised OpenAI client.
         model_name (str): For Azure OpenAI, use your deployment name. For OpenAI, use the model name.
         system_message (str): System prompt prepended to every request.
-        temperature (float): Sampling temperature.
-        top_p (float): Nucleus‑sampling parameter.
         response_format (type[ResponseFormat]): Expected Pydantic model class or ``str`` for each assistant message.
         cache (BatchingMapProxy[str, ResponseFormat]): Order‑preserving batching proxy with de‑duplication and caching.

@@ -163,10 +161,9 @@ class BatchResponses(Generic[ResponseFormat]):
     client: OpenAI
     model_name: str  # For Azure: deployment name, for OpenAI: model name
     system_message: str
-    temperature: float | None = None
-    top_p: float = 1.0
     response_format: type[ResponseFormat] = str  # type: ignore[assignment]
     cache: BatchingMapProxy[str, ResponseFormat] = field(default_factory=lambda: BatchingMapProxy(batch_size=None))
+    api_kwargs: dict[str, int | float | str | bool] = field(default_factory=dict)
     _vectorized_system_message: str = field(init=False)
     _model_json_schema: dict = field(init=False)

@@ -176,10 +173,9 @@ class BatchResponses(Generic[ResponseFormat]):
         client: OpenAI,
         model_name: str,
         system_message: str,
-        temperature: float | None = 0.0,
-        top_p: float = 1.0,
         response_format: type[ResponseFormat] = str,
         batch_size: int | None = None,
+        **api_kwargs,
     ) -> "BatchResponses":
         """Factory constructor.

@@ -187,11 +183,10 @@ class BatchResponses(Generic[ResponseFormat]):
             client (OpenAI): OpenAI client.
             model_name (str): For Azure OpenAI, use your deployment name. For OpenAI, use the model name.
             system_message (str): System prompt for the model.
-            temperature (float, optional): Sampling temperature. Defaults to 0.0.
-            top_p (float, optional): Nucleus sampling parameter. Defaults to 1.0.
             response_format (type[ResponseFormat], optional): Expected output type. Defaults to ``str``.
             batch_size (int | None, optional): Max unique prompts per API call. Defaults to None
                 (automatic batch size optimization). Set to a positive integer for fixed batch size.
+            **api_kwargs: Additional OpenAI API parameters (temperature, top_p, etc.).

         Returns:
             BatchResponses: Configured instance backed by a batching proxy.
@@ -200,10 +195,9 @@ class BatchResponses(Generic[ResponseFormat]):
             client=client,
             model_name=model_name,
             system_message=system_message,
-            temperature=temperature,
-            top_p=top_p,
             response_format=response_format,
             cache=BatchingMapProxy(batch_size=batch_size),
+            api_kwargs=api_kwargs,
         )

     @classmethod
@@ -226,10 +220,9 @@ class BatchResponses(Generic[ResponseFormat]):
             client=client,
             model_name=model_name,
             system_message=task.instructions,
-            temperature=task.temperature,
-            top_p=task.top_p,
             response_format=task.response_format,
             cache=BatchingMapProxy(batch_size=batch_size),
+            api_kwargs=task.api_kwargs,
         )

     def __post_init__(self):
@@ -241,9 +234,7 @@ class BatchResponses(Generic[ResponseFormat]):

     @observe(_LOGGER)
     @backoff(exceptions=[RateLimitError, InternalServerError], scale=1, max_retries=12)
-    def _request_llm(
-        self, user_messages: list[Message[str]], **extra_api_params: Any
-    ) -> ParsedResponse[Response[ResponseFormat]]:
+    def _request_llm(self, user_messages: list[Message[str]]) -> ParsedResponse[Response[ResponseFormat]]:
         """Make a single call to the OpenAI JSON‑mode endpoint.

         Args:
@@ -267,40 +258,22 @@ class BatchResponses(Generic[ResponseFormat]):
         class ResponseT(BaseModel):
             assistant_messages: list[MessageT]

-        # Build base API parameters (cannot be overridden by caller)
-        api_params: dict[str, Any] = {
-            "model": self.model_name,
-            "instructions": self._vectorized_system_message,
-            "input": Request(user_messages=user_messages).model_dump_json(),
-            "text_format": ResponseT,
-        }
-
-        # Resolve nucleus sampling (caller can override)
-        top_p = extra_api_params.pop("top_p", self.top_p)
-        if top_p is not None:
-            api_params["top_p"] = top_p
-
-        # Resolve temperature (caller can override). If None, omit entirely for reasoning models.
-        temperature = extra_api_params.pop("temperature", self.temperature)
-        if temperature is not None:
-            api_params["temperature"] = temperature
-
-        # Merge remaining user supplied params, excluding protected keys
-        for k, v in extra_api_params.items():
-            if k in {"model", "instructions", "input", "text_format"}:
-                continue  # ignore attempts to override core batching contract
-            api_params[k] = v
-
         try:
-
+            response: ParsedResponse[ResponseT] = self.client.responses.parse(
+                instructions=self._vectorized_system_message,
+                model=self.model_name,
+                input=Request(user_messages=user_messages).model_dump_json(),
+                text_format=ResponseT,
+                **self.api_kwargs,
+            )
         except BadRequestError as e:
-            _handle_temperature_error(e, self.model_name, self.temperature
+            _handle_temperature_error(e, self.model_name, self.api_kwargs.get("temperature", 0.0))
             raise  # Re-raise if it wasn't a temperature error

-        return cast(ParsedResponse[Response[ResponseFormat]],
+        return cast(ParsedResponse[Response[ResponseFormat]], response)

     @observe(_LOGGER)
-    def _predict_chunk(self, user_messages: list[str]
+    def _predict_chunk(self, user_messages: list[str]) -> list[ResponseFormat | None]:
         """Helper executed for every unique minibatch.

         This method:
@@ -312,7 +285,7 @@ class BatchResponses(Generic[ResponseFormat]):
         only on its arguments – which allows safe reuse.
         """
         messages = [Message(id=i, body=message) for i, message in enumerate(user_messages)]
-        responses: ParsedResponse[Response[ResponseFormat]] = self._request_llm(messages
+        responses: ParsedResponse[Response[ResponseFormat]] = self._request_llm(messages)
         if not responses.output_parsed:
             return [None] * len(messages)
         response_dict = {message.id: message.body for message in responses.output_parsed.assistant_messages}
@@ -320,28 +293,16 @@ class BatchResponses(Generic[ResponseFormat]):
         return sorted_responses

     @observe(_LOGGER)
-    def parse(self, inputs: list[str]
+    def parse(self, inputs: list[str]) -> list[ResponseFormat | None]:
         """Batched predict.

-        Accepts arbitrary keyword arguments that are forwarded to the underlying
-        ``OpenAI.responses.parse`` call for future‑proofing (e.g., ``max_output_tokens``,
-        penalties, etc.). ``top_p`` and ``temperature`` default to the instance's
-        configured values but can be overridden explicitly.
-
         Args:
             inputs (list[str]): Prompts that require responses. Duplicates are de‑duplicated.
-            **api_kwargs: Extra keyword args forwarded to the OpenAI Responses API.

         Returns:
             list[ResponseFormat | None]: Assistant responses aligned to ``inputs``.
         """
-
-        return self.cache.map(inputs, self._predict_chunk)  # type: ignore[return-value]
-
-        def _predict_with(xs: list[str]) -> list[ResponseFormat | None]:
-            return self._predict_chunk(xs, **api_kwargs)
-
-        return self.cache.map(inputs, _predict_with)  # type: ignore[return-value]
+        return self.cache.map(inputs, self._predict_chunk)  # type: ignore[return-value]


 @dataclass(frozen=True)
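For the synchronous `BatchResponses` above, per-call sampling parameters now travel through the factory's `**api_kwargs` into every `client.responses.parse` call, and `parse()` itself no longer accepts extra keyword arguments. A minimal sketch, assuming the private import path from this diff and `gpt-4.1-mini` as a placeholder model name; the `AsyncBatchResponses` hunks follow after it:

```python
from openai import OpenAI

from openaivec._responses import BatchResponses  # import path assumed from this diff

client = OpenAI()

# `temperature` and `top_p` are no longer dedicated parameters; they are stored
# in `api_kwargs` and expanded into each batched `responses.parse(...)` call.
responses = BatchResponses.of(
    client,
    model_name="gpt-4.1-mini",
    system_message="Answer in one short sentence.",
    batch_size=32,
    temperature=0.0,
    top_p=1.0,
)

# `response_format` defaults to str, so this returns list[str | None];
# duplicate prompts are de-duplicated by the batching proxy.
answers = responses.parse(["What is pandas?", "What is pandas?"])
print(answers)
```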
@@ -383,8 +344,6 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
         client (AsyncOpenAI): Initialised OpenAI async client.
         model_name (str): For Azure OpenAI, use your deployment name. For OpenAI, use the model name.
         system_message (str): System prompt prepended to every request.
-        temperature (float): Sampling temperature.
-        top_p (float): Nucleus‑sampling parameter.
         response_format (type[ResponseFormat]): Expected Pydantic model class or ``str`` for each assistant message.
         cache (AsyncBatchingMapProxy[str, ResponseFormat]): Async batching proxy with de‑duplication
             and concurrency control.
@@ -393,12 +352,11 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
     client: AsyncOpenAI
     model_name: str  # For Azure: deployment name, for OpenAI: model name
     system_message: str
-    temperature: float | None = 0.0
-    top_p: float = 1.0
     response_format: type[ResponseFormat] = str  # type: ignore[assignment]
     cache: AsyncBatchingMapProxy[str, ResponseFormat] = field(
         default_factory=lambda: AsyncBatchingMapProxy(batch_size=None, max_concurrency=8)
     )
+    api_kwargs: dict[str, int | float | str | bool] = field(default_factory=dict)
     _vectorized_system_message: str = field(init=False)
     _model_json_schema: dict = field(init=False)

@@ -408,11 +366,10 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
         client: AsyncOpenAI,
         model_name: str,
         system_message: str,
-        temperature: float | None = None,
-        top_p: float = 1.0,
         response_format: type[ResponseFormat] = str,
         batch_size: int | None = None,
         max_concurrency: int = 8,
+        **api_kwargs,
     ) -> "AsyncBatchResponses":
         """Factory constructor.

@@ -420,12 +377,11 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
             client (AsyncOpenAI): OpenAI async client.
             model_name (str): For Azure OpenAI, use your deployment name. For OpenAI, use the model name.
             system_message (str): System prompt.
-            temperature (float, optional): Sampling temperature. Defaults to 0.0.
-            top_p (float, optional): Nucleus sampling parameter. Defaults to 1.0.
             response_format (type[ResponseFormat], optional): Expected output type. Defaults to ``str``.
             batch_size (int | None, optional): Max unique prompts per API call. Defaults to None
                 (automatic batch size optimization). Set to a positive integer for fixed batch size.
             max_concurrency (int, optional): Max concurrent API calls. Defaults to 8.
+            **api_kwargs: Additional OpenAI API parameters (temperature, top_p, etc.).

         Returns:
             AsyncBatchResponses: Configured instance backed by an async batching proxy.
@@ -434,10 +390,9 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
             client=client,
             model_name=model_name,
             system_message=system_message,
-            temperature=temperature,
-            top_p=top_p,
             response_format=response_format,
             cache=AsyncBatchingMapProxy(batch_size=batch_size, max_concurrency=max_concurrency),
+            api_kwargs=api_kwargs,
         )

     @classmethod
@@ -466,10 +421,9 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
             client=client,
             model_name=model_name,
             system_message=task.instructions,
-            temperature=task.temperature,
-            top_p=task.top_p,
             response_format=task.response_format,
             cache=AsyncBatchingMapProxy(batch_size=batch_size, max_concurrency=max_concurrency),
+            api_kwargs=task.api_kwargs,
         )

     def __post_init__(self):
@@ -481,9 +435,7 @@ class AsyncBatchResponses(Generic[ResponseFormat]):

     @backoff_async(exceptions=[RateLimitError, InternalServerError], scale=1, max_retries=12)
     @observe(_LOGGER)
-    async def _request_llm(
-        self, user_messages: list[Message[str]], **extra_api_params: Any
-    ) -> ParsedResponse[Response[ResponseFormat]]:
+    async def _request_llm(self, user_messages: list[Message[str]]) -> ParsedResponse[Response[ResponseFormat]]:
         """Make a single async call to the OpenAI JSON‑mode endpoint.

         Args:
@@ -504,40 +456,22 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
         class ResponseT(BaseModel):
             assistant_messages: list[MessageT]

-        # Build base API parameters (cannot be overridden by caller)
-        api_params: dict[str, Any] = {
-            "model": self.model_name,
-            "instructions": self._vectorized_system_message,
-            "input": Request(user_messages=user_messages).model_dump_json(),
-            "text_format": ResponseT,
-        }
-
-        # Resolve nucleus sampling (caller can override)
-        top_p = extra_api_params.pop("top_p", self.top_p)
-        if top_p is not None:
-            api_params["top_p"] = top_p
-
-        # Resolve temperature (caller can override). If None, omit entirely for reasoning models.
-        temperature = extra_api_params.pop("temperature", self.temperature)
-        if temperature is not None:
-            api_params["temperature"] = temperature
-
-        # Merge remaining user supplied params, excluding protected keys
-        for k, v in extra_api_params.items():
-            if k in {"model", "instructions", "input", "text_format"}:
-                continue
-            api_params[k] = v
-
         try:
-
+            response: ParsedResponse[ResponseT] = await self.client.responses.parse(
+                instructions=self._vectorized_system_message,
+                model=self.model_name,
+                input=Request(user_messages=user_messages).model_dump_json(),
+                text_format=ResponseT,
+                **self.api_kwargs,
+            )
         except BadRequestError as e:
-            _handle_temperature_error(e, self.model_name, self.temperature
+            _handle_temperature_error(e, self.model_name, self.api_kwargs.get("temperature", 0.0))
             raise  # Re-raise if it wasn't a temperature error

-        return cast(ParsedResponse[Response[ResponseFormat]],
+        return cast(ParsedResponse[Response[ResponseFormat]], response)

     @observe(_LOGGER)
-    async def _predict_chunk(self, user_messages: list[str]
+    async def _predict_chunk(self, user_messages: list[str]) -> list[ResponseFormat | None]:
         """Async helper executed for every unique minibatch.

         This method:
@@ -548,7 +482,7 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
         The function is pure – it has no side‑effects and the result depends only on its arguments.
         """
         messages = [Message(id=i, body=message) for i, message in enumerate(user_messages)]
-        responses: ParsedResponse[Response[ResponseFormat]] = await self._request_llm(messages
+        responses: ParsedResponse[Response[ResponseFormat]] = await self._request_llm(messages)
         if not responses.output_parsed:
             return [None] * len(messages)
         response_dict = {message.id: message.body for message in responses.output_parsed.assistant_messages}
@@ -557,25 +491,13 @@ class AsyncBatchResponses(Generic[ResponseFormat]):
         return sorted_responses

     @observe(_LOGGER)
-    async def parse(self, inputs: list[str]
+    async def parse(self, inputs: list[str]) -> list[ResponseFormat | None]:
         """Batched predict (async).

-        Accepts arbitrary keyword arguments forwarded to ``AsyncOpenAI.responses.parse``.
-        ``top_p`` and ``temperature`` default to instance configuration but can be
-        overridden per call. This prepares for future API parameters without
-        changing the public surface again.
-
         Args:
             inputs (list[str]): Prompts that require responses. Duplicates are de‑duplicated.
-            **api_kwargs: Extra keyword args for the OpenAI Responses API.

         Returns:
             list[ResponseFormat | None]: Assistant responses aligned to ``inputs``.
         """
-
-        return await self.cache.map(inputs, self._predict_chunk)  # type: ignore[return-value]
-
-        async def _predict_with(xs: list[str]) -> list[ResponseFormat | None]:
-            return await self._predict_chunk(xs, **api_kwargs)
-
-        return await self.cache.map(inputs, _predict_with)  # type: ignore[return-value]
+        return await self.cache.map(inputs, self._predict_chunk)  # type: ignore[return-value]