openaivec 0.12.5__py3-none-any.whl → 1.0.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openaivec/__init__.py +13 -4
- openaivec/_cache/__init__.py +12 -0
- openaivec/_cache/optimize.py +109 -0
- openaivec/_cache/proxy.py +806 -0
- openaivec/{di.py → _di.py} +36 -12
- openaivec/_embeddings.py +203 -0
- openaivec/{log.py → _log.py} +2 -2
- openaivec/_model.py +113 -0
- openaivec/{prompt.py → _prompt.py} +95 -28
- openaivec/_provider.py +207 -0
- openaivec/_responses.py +511 -0
- openaivec/_schema/__init__.py +9 -0
- openaivec/_schema/infer.py +340 -0
- openaivec/_schema/spec.py +350 -0
- openaivec/_serialize.py +234 -0
- openaivec/{util.py → _util.py} +25 -85
- openaivec/pandas_ext.py +1496 -318
- openaivec/spark.py +485 -183
- openaivec/task/__init__.py +9 -7
- openaivec/task/customer_support/__init__.py +9 -15
- openaivec/task/customer_support/customer_sentiment.py +17 -15
- openaivec/task/customer_support/inquiry_classification.py +23 -22
- openaivec/task/customer_support/inquiry_summary.py +14 -13
- openaivec/task/customer_support/intent_analysis.py +21 -19
- openaivec/task/customer_support/response_suggestion.py +16 -16
- openaivec/task/customer_support/urgency_analysis.py +24 -25
- openaivec/task/nlp/__init__.py +4 -4
- openaivec/task/nlp/dependency_parsing.py +10 -12
- openaivec/task/nlp/keyword_extraction.py +11 -14
- openaivec/task/nlp/morphological_analysis.py +12 -14
- openaivec/task/nlp/named_entity_recognition.py +16 -18
- openaivec/task/nlp/sentiment_analysis.py +14 -11
- openaivec/task/nlp/translation.py +6 -9
- openaivec/task/table/__init__.py +2 -2
- openaivec/task/table/fillna.py +11 -11
- openaivec-1.0.10.dist-info/METADATA +399 -0
- openaivec-1.0.10.dist-info/RECORD +39 -0
- {openaivec-0.12.5.dist-info → openaivec-1.0.10.dist-info}/WHEEL +1 -1
- openaivec/embeddings.py +0 -172
- openaivec/model.py +0 -67
- openaivec/provider.py +0 -45
- openaivec/responses.py +0 -393
- openaivec/serialize.py +0 -225
- openaivec-0.12.5.dist-info/METADATA +0 -696
- openaivec-0.12.5.dist-info/RECORD +0 -33
- {openaivec-0.12.5.dist-info → openaivec-1.0.10.dist-info}/licenses/LICENSE +0 -0
openaivec/spark.py
CHANGED
|
@@ -1,37 +1,51 @@
|
|
|
1
1
|
"""Asynchronous Spark UDFs for the OpenAI and Azure OpenAI APIs.
|
|
2
2
|
|
|
3
|
-
This module provides functions (`responses_udf`, `task_udf`, `embeddings_udf
|
|
3
|
+
This module provides functions (`responses_udf`, `task_udf`, `embeddings_udf`,
|
|
4
|
+
`count_tokens_udf`, `split_to_chunks_udf`, `similarity_udf`, `parse_udf`)
|
|
4
5
|
for creating asynchronous Spark UDFs that communicate with either the public
|
|
5
6
|
OpenAI API or Azure OpenAI using the `openaivec.spark` subpackage.
|
|
6
|
-
It supports UDFs for generating responses
|
|
7
|
-
The UDFs operate on Spark DataFrames
|
|
8
|
-
improved performance in I/O-bound operations.
|
|
7
|
+
It supports UDFs for generating responses, creating embeddings, parsing text,
|
|
8
|
+
and computing similarities asynchronously. The UDFs operate on Spark DataFrames
|
|
9
|
+
and leverage asyncio for improved performance in I/O-bound operations.
|
|
10
|
+
|
|
11
|
+
**Performance Optimization**: All AI-powered UDFs (`responses_udf`, `task_udf`, `embeddings_udf`, `parse_udf`)
|
|
12
|
+
automatically cache duplicate inputs within each partition, significantly reducing
|
|
13
|
+
API calls and costs when processing datasets with overlapping content.
|
|
14
|
+
|
|
9
15
|
|
|
10
16
|
## Setup
|
|
11
17
|
|
|
12
18
|
First, obtain a Spark session and configure authentication:
|
|
13
19
|
|
|
14
20
|
```python
|
|
15
|
-
import os
|
|
16
21
|
from pyspark.sql import SparkSession
|
|
22
|
+
from openaivec.spark import setup, setup_azure
|
|
17
23
|
|
|
18
24
|
spark = SparkSession.builder.getOrCreate()
|
|
19
|
-
sc = spark.sparkContext
|
|
20
25
|
|
|
21
|
-
# Configure authentication via SparkContext environment variables
|
|
22
26
|
# Option 1: Using OpenAI
|
|
23
|
-
|
|
27
|
+
setup(
|
|
28
|
+
spark,
|
|
29
|
+
api_key="your-openai-api-key",
|
|
30
|
+
responses_model_name="gpt-4.1-mini", # Optional: set default model
|
|
31
|
+
embeddings_model_name="text-embedding-3-small" # Optional: set default model
|
|
32
|
+
)
|
|
24
33
|
|
|
25
34
|
# Option 2: Using Azure OpenAI
|
|
26
|
-
#
|
|
27
|
-
#
|
|
28
|
-
#
|
|
35
|
+
# setup_azure(
|
|
36
|
+
# spark,
|
|
37
|
+
# api_key="your-azure-openai-api-key",
|
|
38
|
+
# base_url="https://YOUR-RESOURCE-NAME.services.ai.azure.com/openai/v1/",
|
|
39
|
+
# api_version="preview",
|
|
40
|
+
# responses_model_name="my-gpt4-deployment", # Optional: set default deployment
|
|
41
|
+
# embeddings_model_name="my-embedding-deployment" # Optional: set default deployment
|
|
42
|
+
# )
|
|
29
43
|
```
|
|
30
44
|
|
|
31
45
|
Next, create UDFs and register them:
|
|
32
46
|
|
|
33
47
|
```python
|
|
34
|
-
from openaivec.spark import responses_udf, task_udf, embeddings_udf
|
|
48
|
+
from openaivec.spark import responses_udf, task_udf, embeddings_udf, count_tokens_udf, split_to_chunks_udf
|
|
35
49
|
from pydantic import BaseModel
|
|
36
50
|
|
|
37
51
|
# Define a Pydantic model for structured responses (optional)
|
|
@@ -46,7 +60,7 @@ spark.udf.register(
|
|
|
46
60
|
responses_udf(
|
|
47
61
|
instructions="Translate the text to multiple languages.",
|
|
48
62
|
response_format=Translation,
|
|
49
|
-
model_name="gpt-4.1-mini", #
|
|
63
|
+
model_name="gpt-4.1-mini", # For Azure: deployment name, for OpenAI: model name
|
|
50
64
|
batch_size=64, # Rows per API request within partition
|
|
51
65
|
max_concurrency=8 # Concurrent requests PER EXECUTOR
|
|
52
66
|
),
|
|
@@ -63,11 +77,16 @@ spark.udf.register(
|
|
|
63
77
|
spark.udf.register(
|
|
64
78
|
"embed_async",
|
|
65
79
|
embeddings_udf(
|
|
66
|
-
model_name="text-embedding-3-small", #
|
|
80
|
+
model_name="text-embedding-3-small", # For Azure: deployment name, for OpenAI: model name
|
|
67
81
|
batch_size=128, # Larger batches for embeddings
|
|
68
82
|
max_concurrency=8 # Concurrent requests PER EXECUTOR
|
|
69
83
|
),
|
|
70
84
|
)
|
|
85
|
+
|
|
86
|
+
# Register token counting, text chunking, and similarity UDFs
|
|
87
|
+
spark.udf.register("count_tokens", count_tokens_udf())
|
|
88
|
+
spark.udf.register("split_chunks", split_to_chunks_udf(max_tokens=512, sep=[".", "!", "?"]))
|
|
89
|
+
spark.udf.register("compute_similarity", similarity_udf())
|
|
71
90
|
```
|
|
72
91
|
|
|
73
92
|
You can now invoke the UDFs from Spark SQL:
|
|
@@ -77,7 +96,10 @@ SELECT
|
|
|
77
96
|
text,
|
|
78
97
|
translate_async(text) AS translation,
|
|
79
98
|
sentiment_async(text) AS sentiment,
|
|
80
|
-
embed_async(text) AS embedding
|
|
99
|
+
embed_async(text) AS embedding,
|
|
100
|
+
count_tokens(text) AS token_count,
|
|
101
|
+
split_chunks(text) AS chunks,
|
|
102
|
+
compute_similarity(embed_async(text1), embed_async(text2)) AS similarity
|
|
81
103
|
FROM your_table;
|
|
82
104
|
```
|
|
83
105
|
|
|
@@ -103,26 +125,38 @@ Note: This module provides asynchronous support through the pandas extensions.
|
|
|
103
125
|
|
|
104
126
|
import asyncio
|
|
105
127
|
import logging
|
|
128
|
+
import os
|
|
129
|
+
from collections.abc import Iterator
|
|
106
130
|
from enum import Enum
|
|
107
|
-
from typing import
|
|
131
|
+
from typing import Union, get_args, get_origin
|
|
108
132
|
|
|
133
|
+
import numpy as np
|
|
109
134
|
import pandas as pd
|
|
110
135
|
import tiktoken
|
|
111
136
|
from pydantic import BaseModel
|
|
137
|
+
from pyspark import SparkContext
|
|
138
|
+
from pyspark.sql import SparkSession
|
|
112
139
|
from pyspark.sql.pandas.functions import pandas_udf
|
|
113
140
|
from pyspark.sql.types import ArrayType, BooleanType, FloatType, IntegerType, StringType, StructField, StructType
|
|
114
141
|
from pyspark.sql.udf import UserDefinedFunction
|
|
115
142
|
from typing_extensions import Literal
|
|
116
143
|
|
|
117
|
-
from
|
|
118
|
-
from .
|
|
119
|
-
from .
|
|
120
|
-
from .
|
|
144
|
+
from openaivec import pandas_ext
|
|
145
|
+
from openaivec._cache import AsyncBatchingMapProxy
|
|
146
|
+
from openaivec._model import EmbeddingsModelName, PreparedTask, ResponseFormat, ResponsesModelName
|
|
147
|
+
from openaivec._provider import CONTAINER
|
|
148
|
+
from openaivec._schema import SchemaInferenceInput, SchemaInferenceOutput, SchemaInferer
|
|
149
|
+
from openaivec._serialize import deserialize_base_model, serialize_base_model
|
|
150
|
+
from openaivec._util import TextChunker
|
|
121
151
|
|
|
122
152
|
__all__ = [
|
|
153
|
+
"setup",
|
|
154
|
+
"setup_azure",
|
|
123
155
|
"responses_udf",
|
|
124
156
|
"task_udf",
|
|
125
157
|
"embeddings_udf",
|
|
158
|
+
"infer_schema",
|
|
159
|
+
"parse_udf",
|
|
126
160
|
"split_to_chunks_udf",
|
|
127
161
|
"count_tokens_udf",
|
|
128
162
|
"similarity_udf",
|
|
@@ -130,21 +164,126 @@ __all__ = [
|
|
|
130
164
|
|
|
131
165
|
|
|
132
166
|
_LOGGER: logging.Logger = logging.getLogger(__name__)
|
|
133
|
-
_TIKTOKEN_ENC: tiktoken.Encoding | None = None
|
|
134
167
|
|
|
135
168
|
|
|
169
|
+
def setup(
|
|
170
|
+
spark: SparkSession, api_key: str, responses_model_name: str | None = None, embeddings_model_name: str | None = None
|
|
171
|
+
):
|
|
172
|
+
"""Setup OpenAI authentication and default model names in Spark environment.
|
|
173
|
+
1. Configures OpenAI API key in SparkContext environment.
|
|
174
|
+
2. Configures OpenAI API key in local process environment.
|
|
175
|
+
3. Optionally registers default model names for responses and embeddings in the DI container.
|
|
176
|
+
|
|
177
|
+
Args:
|
|
178
|
+
spark (SparkSession): The Spark session to configure.
|
|
179
|
+
api_key (str): OpenAI API key for authentication.
|
|
180
|
+
responses_model_name (str | None): Default model name for response generation.
|
|
181
|
+
If provided, registers `ResponsesModelName` in the DI container.
|
|
182
|
+
embeddings_model_name (str | None): Default model name for embeddings.
|
|
183
|
+
If provided, registers `EmbeddingsModelName` in the DI container.
|
|
184
|
+
|
|
185
|
+
Example:
|
|
186
|
+
```python
|
|
187
|
+
from pyspark.sql import SparkSession
|
|
188
|
+
from openaivec.spark import setup
|
|
189
|
+
|
|
190
|
+
spark = SparkSession.builder.getOrCreate()
|
|
191
|
+
setup(
|
|
192
|
+
spark,
|
|
193
|
+
api_key="sk-***",
|
|
194
|
+
responses_model_name="gpt-4.1-mini",
|
|
195
|
+
embeddings_model_name="text-embedding-3-small",
|
|
196
|
+
)
|
|
197
|
+
```
|
|
198
|
+
"""
|
|
199
|
+
|
|
200
|
+
CONTAINER.register(SparkSession, lambda: spark)
|
|
201
|
+
CONTAINER.register(SparkContext, lambda: CONTAINER.resolve(SparkSession).sparkContext)
|
|
202
|
+
|
|
203
|
+
sc = CONTAINER.resolve(SparkContext)
|
|
204
|
+
sc.environment["OPENAI_API_KEY"] = api_key
|
|
205
|
+
|
|
206
|
+
os.environ["OPENAI_API_KEY"] = api_key
|
|
207
|
+
|
|
208
|
+
if responses_model_name:
|
|
209
|
+
CONTAINER.register(ResponsesModelName, lambda: ResponsesModelName(responses_model_name))
|
|
210
|
+
|
|
211
|
+
if embeddings_model_name:
|
|
212
|
+
CONTAINER.register(EmbeddingsModelName, lambda: EmbeddingsModelName(embeddings_model_name))
|
|
213
|
+
|
|
214
|
+
CONTAINER.clear_singletons()
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def setup_azure(
|
|
218
|
+
spark: SparkSession,
|
|
219
|
+
api_key: str,
|
|
220
|
+
base_url: str,
|
|
221
|
+
api_version: str = "preview",
|
|
222
|
+
responses_model_name: str | None = None,
|
|
223
|
+
embeddings_model_name: str | None = None,
|
|
224
|
+
):
|
|
225
|
+
"""Setup Azure OpenAI authentication and default model names in Spark environment.
|
|
226
|
+
1. Configures Azure OpenAI API key, base URL, and API version in SparkContext environment.
|
|
227
|
+
2. Configures Azure OpenAI API key, base URL, and API version in local process environment.
|
|
228
|
+
3. Optionally registers default model names for responses and embeddings in the DI container.
|
|
229
|
+
Args:
|
|
230
|
+
spark (SparkSession): The Spark session to configure.
|
|
231
|
+
api_key (str): Azure OpenAI API key for authentication.
|
|
232
|
+
base_url (str): Base URL for the Azure OpenAI resource.
|
|
233
|
+
api_version (str): API version to use. Defaults to "preview".
|
|
234
|
+
responses_model_name (str | None): Default model name for response generation.
|
|
235
|
+
If provided, registers `ResponsesModelName` in the DI container.
|
|
236
|
+
embeddings_model_name (str | None): Default model name for embeddings.
|
|
237
|
+
If provided, registers `EmbeddingsModelName` in the DI container.
|
|
238
|
+
|
|
239
|
+
Example:
|
|
240
|
+
```python
|
|
241
|
+
from pyspark.sql import SparkSession
|
|
242
|
+
from openaivec.spark import setup_azure
|
|
243
|
+
|
|
244
|
+
spark = SparkSession.builder.getOrCreate()
|
|
245
|
+
setup_azure(
|
|
246
|
+
spark,
|
|
247
|
+
api_key="azure-key",
|
|
248
|
+
base_url="https://YOUR-RESOURCE-NAME.services.ai.azure.com/openai/v1/",
|
|
249
|
+
api_version="preview",
|
|
250
|
+
responses_model_name="gpt4-deployment",
|
|
251
|
+
embeddings_model_name="embedding-deployment",
|
|
252
|
+
)
|
|
253
|
+
```
|
|
254
|
+
"""
|
|
255
|
+
|
|
256
|
+
CONTAINER.register(SparkSession, lambda: spark)
|
|
257
|
+
CONTAINER.register(SparkContext, lambda: CONTAINER.resolve(SparkSession).sparkContext)
|
|
258
|
+
|
|
259
|
+
sc = CONTAINER.resolve(SparkContext)
|
|
260
|
+
sc.environment["AZURE_OPENAI_API_KEY"] = api_key
|
|
261
|
+
sc.environment["AZURE_OPENAI_BASE_URL"] = base_url
|
|
262
|
+
sc.environment["AZURE_OPENAI_API_VERSION"] = api_version
|
|
263
|
+
|
|
264
|
+
os.environ["AZURE_OPENAI_API_KEY"] = api_key
|
|
265
|
+
os.environ["AZURE_OPENAI_BASE_URL"] = base_url
|
|
266
|
+
os.environ["AZURE_OPENAI_API_VERSION"] = api_version
|
|
267
|
+
|
|
268
|
+
if responses_model_name:
|
|
269
|
+
CONTAINER.register(ResponsesModelName, lambda: ResponsesModelName(responses_model_name))
|
|
270
|
+
|
|
271
|
+
if embeddings_model_name:
|
|
272
|
+
CONTAINER.register(EmbeddingsModelName, lambda: EmbeddingsModelName(embeddings_model_name))
|
|
273
|
+
|
|
274
|
+
CONTAINER.clear_singletons()
|
|
136
275
|
|
|
137
276
|
|
|
138
277
|
def _python_type_to_spark(python_type):
|
|
139
278
|
origin = get_origin(python_type)
|
|
140
279
|
|
|
141
|
-
# For list types (e.g.,
|
|
142
|
-
if origin is list
|
|
280
|
+
# For list types (e.g., list[int])
|
|
281
|
+
if origin is list:
|
|
143
282
|
# Retrieve the inner type and recursively convert it
|
|
144
283
|
inner_type = get_args(python_type)[0]
|
|
145
284
|
return ArrayType(_python_type_to_spark(inner_type))
|
|
146
285
|
|
|
147
|
-
# For Optional types (Union
|
|
286
|
+
# For Optional types (T | None via Union internally)
|
|
148
287
|
elif origin is Union:
|
|
149
288
|
non_none_args = [arg for arg in get_args(python_type) if arg is not type(None)]
|
|
150
289
|
if len(non_none_args) == 1:
|
|
@@ -177,7 +316,7 @@ def _python_type_to_spark(python_type):
|
|
|
177
316
|
raise ValueError(f"Unsupported type: {python_type}")
|
|
178
317
|
|
|
179
318
|
|
|
180
|
-
def _pydantic_to_spark_schema(model:
|
|
319
|
+
def _pydantic_to_spark_schema(model: type[BaseModel]) -> StructType:
|
|
181
320
|
fields = []
|
|
182
321
|
for field_name, field in model.model_fields.items():
|
|
183
322
|
field_type = field.annotation
|
|
@@ -188,7 +327,7 @@ def _pydantic_to_spark_schema(model: Type[BaseModel]) -> StructType:
|
|
|
188
327
|
return StructType(fields)
|
|
189
328
|
|
|
190
329
|
|
|
191
|
-
def _safe_cast_str(x:
|
|
330
|
+
def _safe_cast_str(x: str | None) -> str | None:
|
|
192
331
|
try:
|
|
193
332
|
if x is None:
|
|
194
333
|
return None
|
|
@@ -199,7 +338,7 @@ def _safe_cast_str(x: Optional[str]) -> Optional[str]:
|
|
|
199
338
|
return None
|
|
200
339
|
|
|
201
340
|
|
|
202
|
-
def _safe_dump(x:
|
|
341
|
+
def _safe_dump(x: BaseModel | None) -> dict:
|
|
203
342
|
try:
|
|
204
343
|
if x is None:
|
|
205
344
|
return {}
|
|
@@ -212,45 +351,52 @@ def _safe_dump(x: Optional[BaseModel]) -> Dict:
|
|
|
212
351
|
|
|
213
352
|
def responses_udf(
|
|
214
353
|
instructions: str,
|
|
215
|
-
response_format:
|
|
216
|
-
model_name: str =
|
|
217
|
-
batch_size: int =
|
|
218
|
-
temperature: float = 0.0,
|
|
219
|
-
top_p: float = 1.0,
|
|
354
|
+
response_format: type[ResponseFormat] = str,
|
|
355
|
+
model_name: str | None = None,
|
|
356
|
+
batch_size: int | None = None,
|
|
220
357
|
max_concurrency: int = 8,
|
|
358
|
+
**api_kwargs,
|
|
221
359
|
) -> UserDefinedFunction:
|
|
222
360
|
"""Create an asynchronous Spark pandas UDF for generating responses.
|
|
223
361
|
|
|
224
|
-
Configures and builds UDFs that leverage `pandas_ext.aio.
|
|
362
|
+
Configures and builds UDFs that leverage `pandas_ext.aio.responses_with_cache`
|
|
225
363
|
to generate text or structured responses from OpenAI models asynchronously.
|
|
364
|
+
Each partition maintains its own cache to eliminate duplicate API calls within
|
|
365
|
+
the partition, significantly reducing API usage and costs when processing
|
|
366
|
+
datasets with overlapping content.
|
|
226
367
|
|
|
227
368
|
Note:
|
|
228
369
|
Authentication must be configured via SparkContext environment variables.
|
|
229
370
|
Set the appropriate environment variables on the SparkContext:
|
|
230
|
-
|
|
371
|
+
|
|
231
372
|
For OpenAI:
|
|
232
373
|
sc.environment["OPENAI_API_KEY"] = "your-openai-api-key"
|
|
233
|
-
|
|
374
|
+
|
|
234
375
|
For Azure OpenAI:
|
|
235
376
|
sc.environment["AZURE_OPENAI_API_KEY"] = "your-azure-openai-api-key"
|
|
236
|
-
sc.environment["
|
|
237
|
-
sc.environment["AZURE_OPENAI_API_VERSION"] = "
|
|
377
|
+
sc.environment["AZURE_OPENAI_BASE_URL"] = "https://YOUR-RESOURCE-NAME.services.ai.azure.com/openai/v1/"
|
|
378
|
+
sc.environment["AZURE_OPENAI_API_VERSION"] = "preview"
|
|
238
379
|
|
|
239
380
|
Args:
|
|
240
381
|
instructions (str): The system prompt or instructions for the model.
|
|
241
|
-
response_format (
|
|
382
|
+
response_format (type[ResponseFormat]): The desired output format. Either `str` for plain text
|
|
242
383
|
or a Pydantic `BaseModel` for structured JSON output. Defaults to `str`.
|
|
243
|
-
model_name (str):
|
|
244
|
-
|
|
245
|
-
|
|
384
|
+
model_name (str | None): For Azure OpenAI, use your deployment name (e.g., "my-gpt4-deployment").
|
|
385
|
+
For OpenAI, use the model name (e.g., "gpt-4.1-mini"). Defaults to configured model in DI container
|
|
386
|
+
via ResponsesModelName if not provided.
|
|
387
|
+
batch_size (int | None): Number of rows per async batch request within each partition.
|
|
246
388
|
Larger values reduce API call overhead but increase memory usage.
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
389
|
+
Defaults to None (automatic batch size optimization that dynamically
|
|
390
|
+
adjusts based on execution time, targeting 30-60 seconds per batch).
|
|
391
|
+
Set to a positive integer (e.g., 32-128) for fixed batch size.
|
|
250
392
|
max_concurrency (int): Maximum number of concurrent API requests **PER EXECUTOR**.
|
|
251
393
|
Total cluster concurrency = max_concurrency × number_of_executors.
|
|
252
394
|
Higher values increase throughput but may hit OpenAI rate limits.
|
|
253
395
|
Recommended: 4-12 per executor. Defaults to 8.
|
|
396
|
+
**api_kwargs: Additional OpenAI API parameters (e.g. ``temperature``, ``top_p``,
|
|
397
|
+
``frequency_penalty``, ``presence_penalty``, ``seed``, ``max_output_tokens``, etc.)
|
|
398
|
+
forwarded verbatim to the underlying API calls. These parameters are applied to
|
|
399
|
+
all API requests made by the UDF.
|
|
254
400
|
|
|
255
401
|
Returns:
|
|
256
402
|
UserDefinedFunction: A Spark pandas UDF configured to generate responses asynchronously.
|
|
@@ -259,89 +405,130 @@ def responses_udf(
|
|
|
259
405
|
Raises:
|
|
260
406
|
ValueError: If `response_format` is not `str` or a Pydantic `BaseModel`.
|
|
261
407
|
|
|
408
|
+
Example:
|
|
409
|
+
```python
|
|
410
|
+
from pyspark.sql import SparkSession
|
|
411
|
+
from openaivec.spark import responses_udf, setup
|
|
412
|
+
|
|
413
|
+
spark = SparkSession.builder.getOrCreate()
|
|
414
|
+
setup(spark, api_key="sk-***", responses_model_name="gpt-4.1-mini")
|
|
415
|
+
udf = responses_udf("Reply with one word.")
|
|
416
|
+
spark.udf.register("short_answer", udf)
|
|
417
|
+
df = spark.createDataFrame([("hello",), ("bye",)], ["text"])
|
|
418
|
+
df.selectExpr("short_answer(text) as reply").show()
|
|
419
|
+
```
|
|
420
|
+
|
|
262
421
|
Note:
|
|
263
422
|
For optimal performance in distributed environments:
|
|
423
|
+
- **Automatic Caching**: Duplicate inputs within each partition are cached,
|
|
424
|
+
reducing API calls and costs significantly on datasets with repeated content
|
|
264
425
|
- Monitor OpenAI API rate limits when scaling executor count
|
|
265
426
|
- Consider your OpenAI tier limits: total_requests = max_concurrency × executors
|
|
266
427
|
- Use Spark UI to optimize partition sizes relative to batch_size
|
|
267
428
|
"""
|
|
429
|
+
_model_name = model_name or CONTAINER.resolve(ResponsesModelName).value
|
|
430
|
+
|
|
268
431
|
if issubclass(response_format, BaseModel):
|
|
269
432
|
spark_schema = _pydantic_to_spark_schema(response_format)
|
|
270
433
|
json_schema_string = serialize_base_model(response_format)
|
|
271
434
|
|
|
272
|
-
@pandas_udf(returnType=spark_schema)
|
|
435
|
+
@pandas_udf(returnType=spark_schema) # type: ignore[call-overload]
|
|
273
436
|
def structure_udf(col: Iterator[pd.Series]) -> Iterator[pd.DataFrame]:
|
|
274
|
-
pandas_ext.
|
|
437
|
+
pandas_ext.set_responses_model(_model_name)
|
|
438
|
+
response_format = deserialize_base_model(json_schema_string)
|
|
439
|
+
cache = AsyncBatchingMapProxy[str, response_format](
|
|
440
|
+
batch_size=batch_size,
|
|
441
|
+
max_concurrency=max_concurrency,
|
|
442
|
+
)
|
|
275
443
|
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
444
|
+
try:
|
|
445
|
+
for part in col:
|
|
446
|
+
predictions: pd.Series = asyncio.run(
|
|
447
|
+
part.aio.responses_with_cache(
|
|
448
|
+
instructions=instructions,
|
|
449
|
+
response_format=response_format,
|
|
450
|
+
cache=cache,
|
|
451
|
+
**api_kwargs,
|
|
452
|
+
)
|
|
285
453
|
)
|
|
286
|
-
|
|
287
|
-
|
|
454
|
+
yield pd.DataFrame(predictions.map(_safe_dump).tolist())
|
|
455
|
+
finally:
|
|
456
|
+
asyncio.run(cache.clear())
|
|
288
457
|
|
|
289
|
-
return structure_udf
|
|
458
|
+
return structure_udf # type: ignore[return-value]
|
|
290
459
|
|
|
291
460
|
elif issubclass(response_format, str):
|
|
292
461
|
|
|
293
|
-
@pandas_udf(returnType=StringType())
|
|
462
|
+
@pandas_udf(returnType=StringType()) # type: ignore[call-overload]
|
|
294
463
|
def string_udf(col: Iterator[pd.Series]) -> Iterator[pd.Series]:
|
|
295
|
-
pandas_ext.
|
|
464
|
+
pandas_ext.set_responses_model(_model_name)
|
|
465
|
+
cache = AsyncBatchingMapProxy[str, str](
|
|
466
|
+
batch_size=batch_size,
|
|
467
|
+
max_concurrency=max_concurrency,
|
|
468
|
+
)
|
|
296
469
|
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
470
|
+
try:
|
|
471
|
+
for part in col:
|
|
472
|
+
predictions: pd.Series = asyncio.run(
|
|
473
|
+
part.aio.responses_with_cache(
|
|
474
|
+
instructions=instructions,
|
|
475
|
+
response_format=str,
|
|
476
|
+
cache=cache,
|
|
477
|
+
**api_kwargs,
|
|
478
|
+
)
|
|
306
479
|
)
|
|
307
|
-
|
|
308
|
-
|
|
480
|
+
yield predictions.map(_safe_cast_str)
|
|
481
|
+
finally:
|
|
482
|
+
asyncio.run(cache.clear())
|
|
309
483
|
|
|
310
|
-
return string_udf
|
|
484
|
+
return string_udf # type: ignore[return-value]
|
|
311
485
|
|
|
312
486
|
else:
|
|
313
487
|
raise ValueError(f"Unsupported response_format: {response_format}")
|
|
314
488
|
|
|
315
489
|
|
|
316
|
-
|
|
317
490
|
def task_udf(
|
|
318
|
-
task: PreparedTask,
|
|
319
|
-
model_name: str =
|
|
320
|
-
batch_size: int =
|
|
491
|
+
task: PreparedTask[ResponseFormat],
|
|
492
|
+
model_name: str | None = None,
|
|
493
|
+
batch_size: int | None = None,
|
|
321
494
|
max_concurrency: int = 8,
|
|
495
|
+
**api_kwargs,
|
|
322
496
|
) -> UserDefinedFunction:
|
|
323
497
|
"""Create an asynchronous Spark pandas UDF from a predefined task.
|
|
324
498
|
|
|
325
499
|
This function allows users to create UDFs from predefined tasks such as sentiment analysis,
|
|
326
500
|
translation, or other common NLP operations defined in the openaivec.task module.
|
|
501
|
+
Each partition maintains its own cache to eliminate duplicate API calls within
|
|
502
|
+
the partition, significantly reducing API usage and costs when processing
|
|
503
|
+
datasets with overlapping content.
|
|
327
504
|
|
|
328
505
|
Args:
|
|
329
|
-
task (PreparedTask): A predefined task configuration containing instructions
|
|
330
|
-
response format
|
|
331
|
-
model_name (str):
|
|
332
|
-
|
|
333
|
-
|
|
506
|
+
task (PreparedTask): A predefined task configuration containing instructions
|
|
507
|
+
and response format.
|
|
508
|
+
model_name (str | None): For Azure OpenAI, use your deployment name (e.g., "my-gpt4-deployment").
|
|
509
|
+
For OpenAI, use the model name (e.g., "gpt-4.1-mini"). Defaults to configured model in DI container
|
|
510
|
+
via ResponsesModelName if not provided.
|
|
511
|
+
batch_size (int | None): Number of rows per async batch request within each partition.
|
|
334
512
|
Larger values reduce API call overhead but increase memory usage.
|
|
335
|
-
|
|
513
|
+
Defaults to None (automatic batch size optimization that dynamically
|
|
514
|
+
adjusts based on execution time, targeting 30-60 seconds per batch).
|
|
515
|
+
Set to a positive integer (e.g., 32-128) for fixed batch size.
|
|
336
516
|
max_concurrency (int): Maximum number of concurrent API requests **PER EXECUTOR**.
|
|
337
517
|
Total cluster concurrency = max_concurrency × number_of_executors.
|
|
338
518
|
Higher values increase throughput but may hit OpenAI rate limits.
|
|
339
519
|
Recommended: 4-12 per executor. Defaults to 8.
|
|
340
520
|
|
|
521
|
+
Additional Keyword Args:
|
|
522
|
+
Arbitrary OpenAI Responses API parameters (e.g. ``temperature``, ``top_p``,
|
|
523
|
+
``frequency_penalty``, ``presence_penalty``, ``seed``, ``max_output_tokens``, etc.)
|
|
524
|
+
are forwarded verbatim to the underlying API calls. These parameters are applied to
|
|
525
|
+
all API requests made by the UDF.
|
|
526
|
+
|
|
341
527
|
Returns:
|
|
342
528
|
UserDefinedFunction: A Spark pandas UDF configured to execute the specified task
|
|
343
|
-
asynchronously
|
|
344
|
-
|
|
529
|
+
asynchronously with automatic caching for duplicate inputs within each partition.
|
|
530
|
+
Output schema is StringType for str response format or a struct derived from
|
|
531
|
+
the task's response format for BaseModel.
|
|
345
532
|
|
|
346
533
|
Example:
|
|
347
534
|
```python
|
|
@@ -351,186 +538,301 @@ def task_udf(
|
|
|
351
538
|
|
|
352
539
|
spark.udf.register("analyze_sentiment", sentiment_udf)
|
|
353
540
|
```
|
|
541
|
+
|
|
542
|
+
Note:
|
|
543
|
+
**Automatic Caching**: Duplicate inputs within each partition are cached,
|
|
544
|
+
reducing API calls and costs significantly on datasets with repeated content.
|
|
354
545
|
"""
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
# Deserialize the response format from JSON
|
|
364
|
-
response_format = deserialize_base_model(task_response_format_json)
|
|
365
|
-
spark_schema = _pydantic_to_spark_schema(response_format)
|
|
546
|
+
return responses_udf(
|
|
547
|
+
instructions=task.instructions,
|
|
548
|
+
response_format=task.response_format,
|
|
549
|
+
model_name=model_name,
|
|
550
|
+
batch_size=batch_size,
|
|
551
|
+
max_concurrency=max_concurrency,
|
|
552
|
+
**api_kwargs,
|
|
553
|
+
)
|
|
366
554
|
|
|
367
|
-
@pandas_udf(returnType=spark_schema)
|
|
368
|
-
def task_udf(col: Iterator[pd.Series]) -> Iterator[pd.DataFrame]:
|
|
369
|
-
pandas_ext.responses_model(model_name)
|
|
370
555
|
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
top_p=task_top_p,
|
|
379
|
-
max_concurrency=max_concurrency,
|
|
380
|
-
)
|
|
381
|
-
)
|
|
382
|
-
yield pd.DataFrame(predictions.map(_safe_dump).tolist())
|
|
556
|
+
def infer_schema(
|
|
557
|
+
instructions: str,
|
|
558
|
+
example_table_name: str,
|
|
559
|
+
example_field_name: str,
|
|
560
|
+
max_examples: int = 100,
|
|
561
|
+
) -> SchemaInferenceOutput:
|
|
562
|
+
"""Infer the schema for a response format based on example data.
|
|
383
563
|
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
@pandas_udf(returnType=StringType())
|
|
389
|
-
def task_string_udf(col: Iterator[pd.Series]) -> Iterator[pd.Series]:
|
|
390
|
-
pandas_ext.responses_model(model_name)
|
|
564
|
+
This function retrieves examples from a Spark table and infers the schema
|
|
565
|
+
for the response format using the provided instructions. It is useful when
|
|
566
|
+
you want to dynamically generate a schema based on existing data.
|
|
391
567
|
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
batch_size=batch_size,
|
|
398
|
-
temperature=task_temperature,
|
|
399
|
-
top_p=task_top_p,
|
|
400
|
-
max_concurrency=max_concurrency,
|
|
401
|
-
)
|
|
402
|
-
)
|
|
403
|
-
yield predictions.map(_safe_cast_str)
|
|
568
|
+
Args:
|
|
569
|
+
instructions (str): Instructions for the model to infer the schema.
|
|
570
|
+
example_table_name (str | None): Name of the Spark table containing example data.
|
|
571
|
+
example_field_name (str | None): Name of the field in the table to use as examples.
|
|
572
|
+
max_examples (int): Maximum number of examples to retrieve for schema inference.
|
|
404
573
|
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
else:
|
|
408
|
-
raise ValueError(f"Unsupported response_format in task: {task.response_format}")
|
|
574
|
+
Returns:
|
|
575
|
+
InferredSchema: An object containing the inferred schema and response format.
|
|
409
576
|
|
|
577
|
+
Example:
|
|
578
|
+
```python
|
|
579
|
+
from pyspark.sql import SparkSession
|
|
580
|
+
|
|
581
|
+
spark = SparkSession.builder.getOrCreate()
|
|
582
|
+
spark.createDataFrame([("great product",), ("bad service",)], ["text"]).createOrReplaceTempView("examples")
|
|
583
|
+
infer_schema(
|
|
584
|
+
instructions="Classify sentiment as positive or negative.",
|
|
585
|
+
example_table_name="examples",
|
|
586
|
+
example_field_name="text",
|
|
587
|
+
max_examples=2,
|
|
588
|
+
)
|
|
589
|
+
```
|
|
590
|
+
"""
|
|
410
591
|
|
|
592
|
+
spark = CONTAINER.resolve(SparkSession)
|
|
593
|
+
examples: list[str] = (
|
|
594
|
+
spark.table(example_table_name).rdd.map(lambda row: row[example_field_name]).takeSample(False, max_examples)
|
|
595
|
+
)
|
|
411
596
|
|
|
597
|
+
input = SchemaInferenceInput(
|
|
598
|
+
instructions=instructions,
|
|
599
|
+
examples=examples,
|
|
600
|
+
)
|
|
601
|
+
inferer = CONTAINER.resolve(SchemaInferer)
|
|
602
|
+
return inferer.infer_schema(input)
|
|
603
|
+
|
|
604
|
+
|
|
605
|
+
def parse_udf(
    instructions: str,
    response_format: type[ResponseFormat] | None = None,
    example_table_name: str | None = None,
    example_field_name: str | None = None,
    max_examples: int = 100,
    model_name: str | None = None,
    batch_size: int | None = None,
    max_concurrency: int = 8,
    **api_kwargs,
) -> UserDefinedFunction:
    """Build an asynchronous Spark pandas UDF that parses text into structured responses.

    The output schema comes either from an explicit ``response_format`` (``str`` for
    plain text or a Pydantic ``BaseModel`` for structured JSON) or, when none is
    given, from a schema inferred from example rows stored in a Spark table. Each
    partition keeps its own cache, so duplicate inputs within a partition trigger
    only one API call, reducing API usage and cost on overlapping datasets.

    Args:
        instructions (str): The system prompt or instructions for the model.
        response_format (type[ResponseFormat] | None): Desired output format. Either
            ``str`` for plain text or a Pydantic ``BaseModel`` for structured JSON.
            When omitted, the schema is inferred from the example data.
        example_table_name (str | None): Name of the Spark table containing example
            data; requires ``example_field_name`` as well.
        example_field_name (str | None): Name of the field in the table to use as
            examples; requires ``example_table_name`` as well.
        max_examples (int): Maximum number of examples sampled for schema inference.
            Defaults to 100.
        model_name (str | None): For Azure OpenAI, the deployment name; for OpenAI,
            the model name (e.g., "gpt-4.1-mini"). Falls back to the DI-configured
            ResponsesModelName when omitted.
        batch_size (int | None): Rows per async batch request within each partition.
            ``None`` (the default) enables automatic batch-size optimization that
            targets 30-60 seconds per batch; a positive integer fixes the size.
        max_concurrency (int): Maximum concurrent API requests **PER EXECUTOR**;
            total cluster concurrency = max_concurrency × number_of_executors.
            Defaults to 8.
        **api_kwargs: Additional OpenAI API parameters (e.g. ``temperature``,
            ``top_p``, ``seed``, ``max_output_tokens``) forwarded verbatim to the
            underlying API calls.

    Returns:
        UserDefinedFunction: A Spark pandas UDF that parses responses asynchronously.
            Output schema is ``StringType`` for ``str`` response format or a struct
            derived from the response format for ``BaseModel``.

    Raises:
        ValueError: If neither ``response_format`` nor both ``example_table_name``
            and ``example_field_name`` are provided.
    """
    has_examples = bool(example_table_name and example_field_name)
    if response_format is None and not has_examples:
        raise ValueError("Either response_format or example_table_name and example_field_name must be provided.")

    inferred: SchemaInferenceOutput | None = None
    if response_format is None:
        # No explicit format: derive one from sampled example rows.
        inferred = infer_schema(
            instructions=instructions,
            example_table_name=example_table_name,
            example_field_name=example_field_name,
            max_examples=max_examples,
        )

    effective_instructions = inferred.inference_prompt if inferred else instructions
    effective_format = inferred.model if inferred else response_format
    return responses_udf(
        instructions=effective_instructions,
        response_format=effective_format,
        model_name=model_name,
        batch_size=batch_size,
        max_concurrency=max_concurrency,
        **api_kwargs,
    )
|
|
698
|
+
|
|
699
|
+
|
|
700
|
+
def embeddings_udf(
    model_name: str | None = None,
    batch_size: int | None = None,
    max_concurrency: int = 8,
    **api_kwargs,
) -> UserDefinedFunction:
    """Build an asynchronous Spark pandas UDF that generates OpenAI embeddings.

    Wraps ``pandas_ext.aio.embeddings_with_cache`` so each partition shares one
    cache, eliminating duplicate API calls for repeated inputs within that
    partition and significantly reducing API usage and cost.

    Note:
        Authentication must be configured via SparkContext environment variables:

        For OpenAI:
            sc.environment["OPENAI_API_KEY"] = "your-openai-api-key"

        For Azure OpenAI:
            sc.environment["AZURE_OPENAI_API_KEY"] = "your-azure-openai-api-key"
            sc.environment["AZURE_OPENAI_BASE_URL"] = "https://YOUR-RESOURCE-NAME.services.ai.azure.com/openai/v1/"
            sc.environment["AZURE_OPENAI_API_VERSION"] = "preview"

    Args:
        model_name (str | None): For Azure OpenAI, the deployment name; for OpenAI,
            the model name (e.g., "text-embedding-3-small"). Falls back to the
            DI-configured EmbeddingsModelName when omitted.
        batch_size (int | None): Rows per async batch request within each partition.
            ``None`` (the default) enables automatic batch-size optimization that
            targets 30-60 seconds per batch; a positive integer (e.g., 64-256)
            fixes the size. Embeddings typically handle larger batches efficiently.
        max_concurrency (int): Maximum concurrent API requests **PER EXECUTOR**;
            total cluster concurrency = max_concurrency × number_of_executors.
            Recommended: 4-12 per executor. Defaults to 8.
        **api_kwargs: Additional OpenAI API parameters (e.g., ``dimensions`` for
            text-embedding-3 models).

    Returns:
        UserDefinedFunction: A Spark pandas UDF that generates embeddings
            asynchronously with per-partition caching, returning an
            ``ArrayType(FloatType())`` column.

    Note:
        For optimal performance in distributed environments, monitor OpenAI rate
        limits when scaling executors (total_requests = max_concurrency × executors)
        and prefer larger batch sizes than for response generation.
    """
    resolved_model = model_name or CONTAINER.resolve(EmbeddingsModelName).value

    @pandas_udf(returnType=ArrayType(FloatType()))  # type: ignore[call-overload,misc]
    def _embeddings_udf(col: Iterator[pd.Series]) -> Iterator[pd.Series]:
        pandas_ext.set_embeddings_model(resolved_model)
        # One cache per partition: duplicate strings in this partition hit the API once.
        cache = AsyncBatchingMapProxy[str, np.ndarray](
            batch_size=batch_size,
            max_concurrency=max_concurrency,
        )

        try:
            for part in col:
                vectors: pd.Series = asyncio.run(part.aio.embeddings_with_cache(cache=cache, **api_kwargs))
                yield vectors.map(lambda v: v.tolist())
        finally:
            # Release cached vectors even if a batch fails mid-partition.
            asyncio.run(cache.clear())

    return _embeddings_udf  # type: ignore[return-value]
|
|
465
775
|
|
|
466
776
|
|
|
467
|
-
def split_to_chunks_udf(
|
|
777
|
+
def split_to_chunks_udf(max_tokens: int, sep: list[str]) -> UserDefinedFunction:
    """Build a pandas UDF that splits text into chunks bounded by a token budget.

    Args:
        max_tokens (int): Maximum tokens allowed per chunk.
        sep (list[str]): Ordered list of separator strings used by ``TextChunker``.

    Returns:
        A pandas UDF producing an ``ArrayType(StringType())`` column whose values
        are lists of chunks respecting the ``max_tokens`` limit; non-string cells
        map to an empty list.
    """

    @pandas_udf(ArrayType(StringType()))  # type: ignore[call-overload,misc]
    def fn(col: Iterator[pd.Series]) -> Iterator[pd.Series]:
        # Built lazily on the executor; o200k_base is tiktoken's current OpenAI encoding.
        chunker = TextChunker(tiktoken.get_encoding("o200k_base"))

        def _split(cell: object) -> list[str]:
            if not isinstance(cell, str):
                return []
            return chunker.split(cell, max_tokens=max_tokens, sep=sep)

        for part in col:
            yield part.map(_split)

    return fn  # type: ignore[return-value]
|
|
492
798
|
|
|
493
799
|
|
|
494
|
-
def count_tokens_udf(
|
|
800
|
+
def count_tokens_udf() -> UserDefinedFunction:
    """Build a pandas UDF that counts tokens for every string cell.

    Uses *tiktoken*'s ``o200k_base`` encoding to approximate tokenisation; the
    ``Encoding`` object is created once per UDF invocation on the executor.

    Returns:
        A pandas UDF producing an ``IntegerType`` column with token counts
        (0 for non-string cells).
    """

    @pandas_udf(IntegerType())  # type: ignore[call-overload]
    def fn(col: Iterator[pd.Series]) -> Iterator[pd.Series]:
        encoding = tiktoken.get_encoding("o200k_base")

        def _count(cell: object) -> int:
            return len(encoding.encode(cell)) if isinstance(cell, str) else 0

        for part in col:
            yield part.map(_count)

    return fn  # type: ignore[return-value]
|
|
517
818
|
|
|
518
819
|
|
|
519
820
|
def similarity_udf() -> UserDefinedFunction:
|
|
520
|
-
|
|
521
|
-
def fn(a: pd.Series, b: pd.Series) -> pd.Series:
|
|
522
|
-
"""Compute cosine similarity between two vectors.
|
|
821
|
+
"""Create a pandas-UDF that computes cosine similarity between embedding vectors.
|
|
523
822
|
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
823
|
+
Returns:
|
|
824
|
+
UserDefinedFunction: A Spark pandas UDF that takes two embedding vector columns
|
|
825
|
+
and returns their cosine similarity as a FloatType column.
|
|
826
|
+
"""
|
|
527
827
|
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
"""
|
|
828
|
+
@pandas_udf(FloatType()) # type: ignore[call-overload]
|
|
829
|
+
def fn(a: pd.Series, b: pd.Series) -> pd.Series:
|
|
531
830
|
# Import pandas_ext to ensure .ai accessor is available in Spark workers
|
|
532
|
-
from
|
|
831
|
+
from openaivec import pandas_ext
|
|
832
|
+
|
|
833
|
+
# Explicitly reference pandas_ext to satisfy linters
|
|
834
|
+
assert pandas_ext is not None
|
|
533
835
|
|
|
534
836
|
return pd.DataFrame({"a": a, "b": b}).ai.similarity("a", "b")
|
|
535
837
|
|
|
536
|
-
return fn
|
|
838
|
+
return fn # type: ignore[return-value]
|