openaivec 0.10.0__py3-none-any.whl → 1.0.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. openaivec/__init__.py +13 -4
  2. openaivec/_cache/__init__.py +12 -0
  3. openaivec/_cache/optimize.py +109 -0
  4. openaivec/_cache/proxy.py +806 -0
  5. openaivec/_di.py +326 -0
  6. openaivec/_embeddings.py +203 -0
  7. openaivec/{log.py → _log.py} +2 -2
  8. openaivec/_model.py +113 -0
  9. openaivec/{prompt.py → _prompt.py} +95 -28
  10. openaivec/_provider.py +207 -0
  11. openaivec/_responses.py +511 -0
  12. openaivec/_schema/__init__.py +9 -0
  13. openaivec/_schema/infer.py +340 -0
  14. openaivec/_schema/spec.py +350 -0
  15. openaivec/_serialize.py +234 -0
  16. openaivec/{util.py → _util.py} +25 -85
  17. openaivec/pandas_ext.py +1635 -425
  18. openaivec/spark.py +604 -335
  19. openaivec/task/__init__.py +27 -29
  20. openaivec/task/customer_support/__init__.py +9 -15
  21. openaivec/task/customer_support/customer_sentiment.py +51 -41
  22. openaivec/task/customer_support/inquiry_classification.py +86 -61
  23. openaivec/task/customer_support/inquiry_summary.py +44 -45
  24. openaivec/task/customer_support/intent_analysis.py +56 -41
  25. openaivec/task/customer_support/response_suggestion.py +49 -43
  26. openaivec/task/customer_support/urgency_analysis.py +76 -71
  27. openaivec/task/nlp/__init__.py +4 -4
  28. openaivec/task/nlp/dependency_parsing.py +19 -20
  29. openaivec/task/nlp/keyword_extraction.py +22 -24
  30. openaivec/task/nlp/morphological_analysis.py +25 -25
  31. openaivec/task/nlp/named_entity_recognition.py +26 -28
  32. openaivec/task/nlp/sentiment_analysis.py +29 -21
  33. openaivec/task/nlp/translation.py +24 -30
  34. openaivec/task/table/__init__.py +3 -0
  35. openaivec/task/table/fillna.py +183 -0
  36. openaivec-1.0.10.dist-info/METADATA +399 -0
  37. openaivec-1.0.10.dist-info/RECORD +39 -0
  38. {openaivec-0.10.0.dist-info → openaivec-1.0.10.dist-info}/WHEEL +1 -1
  39. openaivec/embeddings.py +0 -172
  40. openaivec/responses.py +0 -392
  41. openaivec/serialize.py +0 -225
  42. openaivec/task/model.py +0 -84
  43. openaivec-0.10.0.dist-info/METADATA +0 -546
  44. openaivec-0.10.0.dist-info/RECORD +0 -29
  45. {openaivec-0.10.0.dist-info → openaivec-1.0.10.dist-info}/licenses/LICENSE +0 -0
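The headline change in `openaivec/spark.py` is the removal of the `ResponsesUDFBuilder` and `EmbeddingsUDFBuilder` classes in favor of module-level functions (`setup`, `setup_azure`, `responses_udf`, `task_udf`, `embeddings_udf`, `parse_udf`, and helpers). Below is a minimal migration sketch assembled from the docstrings in the diff that follows; the sample DataFrame and registered UDF names are illustrative only, not part of the package:

```python
from pyspark.sql import SparkSession
from openaivec.spark import setup, responses_udf, embeddings_udf

spark = SparkSession.builder.getOrCreate()

# 0.10.0 configured credentials per builder:
#   ResponsesUDFBuilder.of_openai(api_key=..., model_name="gpt-4o-mini").build(...)
# 1.0.10 configures credentials and default models once per session:
setup(
    spark,
    api_key="your-openai-api-key",
    responses_model_name="gpt-4.1-mini",
    embeddings_model_name="text-embedding-3-small",
)

# UDFs are now created by plain functions instead of builder methods.
spark.udf.register("summarize", responses_udf("Summarize the text in one sentence."))
spark.udf.register("embed", embeddings_udf())

# Illustrative usage.
df = spark.createDataFrame([("openaivec pairs Spark with the OpenAI API",)], ["text"])
df.selectExpr("summarize(text) AS summary", "embed(text) AS embedding").show()
```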
openaivec/spark.py CHANGED
@@ -1,53 +1,52 @@
 """Asynchronous Spark UDFs for the OpenAI and Azure OpenAI APIs.

-This module provides builder classes (`ResponsesUDFBuilder`, `EmbeddingsUDFBuilder`)
+This module provides functions (`responses_udf`, `task_udf`, `embeddings_udf`,
+`count_tokens_udf`, `split_to_chunks_udf`, `similarity_udf`, `parse_udf`)
 for creating asynchronous Spark UDFs that communicate with either the public
 OpenAI API or Azure OpenAI using the `openaivec.spark` subpackage.
-It supports UDFs for generating responses and creating embeddings asynchronously.
-The UDFs operate on Spark DataFrames and leverage asyncio for potentially
-improved performance in I/O-bound operations.
+It supports UDFs for generating responses, creating embeddings, parsing text,
+and computing similarities asynchronously. The UDFs operate on Spark DataFrames
+and leverage asyncio for improved performance in I/O-bound operations.
+
+**Performance Optimization**: All AI-powered UDFs (`responses_udf`, `task_udf`, `embeddings_udf`, `parse_udf`)
+automatically cache duplicate inputs within each partition, significantly reducing
+API calls and costs when processing datasets with overlapping content.
+

 ## Setup

-First, obtain a Spark session:
+First, obtain a Spark session and configure authentication:

 ```python
 from pyspark.sql import SparkSession
+from openaivec.spark import setup, setup_azure

 spark = SparkSession.builder.getOrCreate()
-```
-
-Next, instantiate UDF builders with your OpenAI API key (or Azure credentials)
-and model/deployment names, then register the desired UDFs:
-
-```python
-import os
-from openaivec.spark import ResponsesUDFBuilder, EmbeddingsUDFBuilder
-from pydantic import BaseModel

 # Option 1: Using OpenAI
-resp_builder = ResponsesUDFBuilder.of_openai(
-    api_key=os.getenv("OPENAI_API_KEY"),
-    model_name="gpt-4o-mini", # Model for responses
-)
-emb_builder = EmbeddingsUDFBuilder.of_openai(
-    api_key=os.getenv("OPENAI_API_KEY"),
-    model_name="text-embedding-3-small", # Model for embeddings
+setup(
+    spark,
+    api_key="your-openai-api-key",
+    responses_model_name="gpt-4.1-mini", # Optional: set default model
+    embeddings_model_name="text-embedding-3-small" # Optional: set default model
 )

 # Option 2: Using Azure OpenAI
-# resp_builder = ResponsesUDFBuilder.of_azure_openai(
-#     api_key=os.getenv("AZURE_OPENAI_KEY"),
-#     endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
-#     api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
-#     model_name="your-resp-deployment-name", # Deployment for responses
-# )
-# emb_builder = EmbeddingsUDFBuilder.of_azure_openai(
-#     api_key=os.getenv("AZURE_OPENAI_KEY"),
-#     endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
-#     api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
-#     model_name="your-emb-deployment-name", # Deployment for embeddings
+# setup_azure(
+#     spark,
+#     api_key="your-azure-openai-api-key",
+#     base_url="https://YOUR-RESOURCE-NAME.services.ai.azure.com/openai/v1/",
+#     api_version="preview",
+#     responses_model_name="my-gpt4-deployment", # Optional: set default deployment
+#     embeddings_model_name="my-embedding-deployment" # Optional: set default deployment
 # )
+```
+
+Next, create UDFs and register them:
+
+```python
+from openaivec.spark import responses_udf, task_udf, embeddings_udf, count_tokens_udf, split_to_chunks_udf
+from pydantic import BaseModel

 # Define a Pydantic model for structured responses (optional)
 class Translation(BaseModel):
@@ -55,27 +54,39 @@ class Translation(BaseModel):
     fr: str
     # ... other languages

-# Register the asynchronous responses UDF
+# Register the asynchronous responses UDF with performance tuning
 spark.udf.register(
     "translate_async",
-    resp_builder.build(
+    responses_udf(
         instructions="Translate the text to multiple languages.",
         response_format=Translation,
+        model_name="gpt-4.1-mini", # For Azure: deployment name, for OpenAI: model name
+        batch_size=64, # Rows per API request within partition
+        max_concurrency=8 # Concurrent requests PER EXECUTOR
     ),
 )

-# Or use a predefined task with build_from_task method
+# Or use a predefined task with task_udf
 from openaivec.task import nlp
 spark.udf.register(
     "sentiment_async",
-    resp_builder.build_from_task(nlp.SENTIMENT_ANALYSIS),
+    task_udf(nlp.SENTIMENT_ANALYSIS),
 )

-# Register the asynchronous embeddings UDF
+# Register the asynchronous embeddings UDF with performance tuning
 spark.udf.register(
     "embed_async",
-    emb_builder.build(),
+    embeddings_udf(
+        model_name="text-embedding-3-small", # For Azure: deployment name, for OpenAI: model name
+        batch_size=128, # Larger batches for embeddings
+        max_concurrency=8 # Concurrent requests PER EXECUTOR
+    ),
 )
+
+# Register token counting, text chunking, and similarity UDFs
+spark.udf.register("count_tokens", count_tokens_udf())
+spark.udf.register("split_chunks", split_to_chunks_udf(max_tokens=512, sep=[".", "!", "?"]))
+spark.udf.register("compute_similarity", similarity_udf())
 ```

 You can now invoke the UDFs from Spark SQL:
@@ -85,79 +96,194 @@ SELECT
     text,
     translate_async(text) AS translation,
     sentiment_async(text) AS sentiment,
-    embed_async(text) AS embedding
+    embed_async(text) AS embedding,
+    count_tokens(text) AS token_count,
+    split_chunks(text) AS chunks,
+    compute_similarity(embed_async(text1), embed_async(text2)) AS similarity
 FROM your_table;
 ```

+## Performance Considerations
+
+When using these UDFs in distributed Spark environments:
+
+- **`batch_size`**: Controls rows processed per API request within each partition.
+  Recommended: 32-128 for responses, 64-256 for embeddings.
+
+- **`max_concurrency`**: Sets concurrent API requests **PER EXECUTOR**, not per cluster.
+  Total cluster concurrency = max_concurrency × number_of_executors.
+  Recommended: 4-12 per executor to avoid overwhelming OpenAI rate limits.
+
+- **Rate Limit Management**: Monitor OpenAI API usage when scaling executors.
+  Consider your OpenAI tier limits and adjust max_concurrency accordingly.
+
+Example for a 5-executor cluster with max_concurrency=8:
+Total concurrent requests = 8 × 5 = 40 simultaneous API calls.
+
 Note: This module provides asynchronous support through the pandas extensions.
 """

 import asyncio
-from dataclasses import dataclass
-from typing import Dict, Iterator, List, Type, TypeVar, Union, get_args, get_origin, Optional
-from typing_extensions import Literal
-from enum import Enum
 import logging
-from pyspark.sql.pandas.functions import pandas_udf
-from pyspark.sql.udf import UserDefinedFunction
-from pyspark.sql.types import BooleanType, IntegerType, StringType, ArrayType, FloatType, StructField, StructType
-from openai import AsyncOpenAI, AsyncAzureOpenAI
-import tiktoken
-from . import pandas_ext
+import os
+from collections.abc import Iterator
+from enum import Enum
+from typing import Union, get_args, get_origin
+
+import numpy as np
 import pandas as pd
+import tiktoken
 from pydantic import BaseModel
+from pyspark import SparkContext
+from pyspark.sql import SparkSession
+from pyspark.sql.pandas.functions import pandas_udf
+from pyspark.sql.types import ArrayType, BooleanType, FloatType, IntegerType, StringType, StructField, StructType
+from pyspark.sql.udf import UserDefinedFunction
+from typing_extensions import Literal

-from .serialize import deserialize_base_model, serialize_base_model
-from .util import TextChunker
-from .task.model import PreparedTask
+from openaivec import pandas_ext
+from openaivec._cache import AsyncBatchingMapProxy
+from openaivec._model import EmbeddingsModelName, PreparedTask, ResponseFormat, ResponsesModelName
+from openaivec._provider import CONTAINER
+from openaivec._schema import SchemaInferenceInput, SchemaInferenceOutput, SchemaInferer
+from openaivec._serialize import deserialize_base_model, serialize_base_model
+from openaivec._util import TextChunker

 __all__ = [
-    "ResponsesUDFBuilder",
-    "EmbeddingsUDFBuilder",
+    "setup",
+    "setup_azure",
+    "responses_udf",
+    "task_udf",
+    "embeddings_udf",
+    "infer_schema",
+    "parse_udf",
     "split_to_chunks_udf",
     "count_tokens_udf",
     "similarity_udf",
 ]

-ResponseFormat = BaseModel | Type[str]
-T = TypeVar("T", bound=BaseModel)

-_INITIALIZED: bool = False
 _LOGGER: logging.Logger = logging.getLogger(__name__)
-_TIKTOKEN_ENC: tiktoken.Encoding | None = None


-def _initialize(api_key: str, endpoint: str | None, api_version: str | None) -> None:
-    """Initializes the OpenAI client for asynchronous operations.
+def setup(
+    spark: SparkSession, api_key: str, responses_model_name: str | None = None, embeddings_model_name: str | None = None
+):
+    """Setup OpenAI authentication and default model names in Spark environment.
+    1. Configures OpenAI API key in SparkContext environment.
+    2. Configures OpenAI API key in local process environment.
+    3. Optionally registers default model names for responses and embeddings in the DI container.
+
+    Args:
+        spark (SparkSession): The Spark session to configure.
+        api_key (str): OpenAI API key for authentication.
+        responses_model_name (str | None): Default model name for response generation.
+            If provided, registers `ResponsesModelName` in the DI container.
+        embeddings_model_name (str | None): Default model name for embeddings.
+            If provided, registers `EmbeddingsModelName` in the DI container.
+
+    Example:
+        ```python
+        from pyspark.sql import SparkSession
+        from openaivec.spark import setup
+
+        spark = SparkSession.builder.getOrCreate()
+        setup(
+            spark,
+            api_key="sk-***",
+            responses_model_name="gpt-4.1-mini",
+            embeddings_model_name="text-embedding-3-small",
+        )
+        ```
+    """
+
+    CONTAINER.register(SparkSession, lambda: spark)
+    CONTAINER.register(SparkContext, lambda: CONTAINER.resolve(SparkSession).sparkContext)
+
+    sc = CONTAINER.resolve(SparkContext)
+    sc.environment["OPENAI_API_KEY"] = api_key
+
+    os.environ["OPENAI_API_KEY"] = api_key
+
+    if responses_model_name:
+        CONTAINER.register(ResponsesModelName, lambda: ResponsesModelName(responses_model_name))
+
+    if embeddings_model_name:
+        CONTAINER.register(EmbeddingsModelName, lambda: EmbeddingsModelName(embeddings_model_name))
+
+    CONTAINER.clear_singletons()

-    This function sets up the global asynchronous OpenAI client instance
-    (either `AsyncOpenAI` or `AsyncAzureOpenAI`) used by the UDFs in this
-    module. It ensures the client is initialized only once.

+def setup_azure(
+    spark: SparkSession,
+    api_key: str,
+    base_url: str,
+    api_version: str = "preview",
+    responses_model_name: str | None = None,
+    embeddings_model_name: str | None = None,
+):
+    """Setup Azure OpenAI authentication and default model names in Spark environment.
+    1. Configures Azure OpenAI API key, base URL, and API version in SparkContext environment.
+    2. Configures Azure OpenAI API key, base URL, and API version in local process environment.
+    3. Optionally registers default model names for responses and embeddings in the DI container.
     Args:
-        api_key: The OpenAI or Azure OpenAI API key.
-        endpoint: The Azure OpenAI endpoint URL. Required for Azure.
-        api_version: The Azure OpenAI API version. Required for Azure.
+        spark (SparkSession): The Spark session to configure.
+        api_key (str): Azure OpenAI API key for authentication.
+        base_url (str): Base URL for the Azure OpenAI resource.
+        api_version (str): API version to use. Defaults to "preview".
+        responses_model_name (str | None): Default model name for response generation.
+            If provided, registers `ResponsesModelName` in the DI container.
+        embeddings_model_name (str | None): Default model name for embeddings.
+            If provided, registers `EmbeddingsModelName` in the DI container.
+
+    Example:
+        ```python
+        from pyspark.sql import SparkSession
+        from openaivec.spark import setup_azure
+
+        spark = SparkSession.builder.getOrCreate()
+        setup_azure(
+            spark,
+            api_key="azure-key",
+            base_url="https://YOUR-RESOURCE-NAME.services.ai.azure.com/openai/v1/",
+            api_version="preview",
+            responses_model_name="gpt4-deployment",
+            embeddings_model_name="embedding-deployment",
+        )
+        ```
     """
-    global _INITIALIZED
-    if not _INITIALIZED:
-        if endpoint and api_version:
-            pandas_ext.use_async(AsyncAzureOpenAI(api_key=api_key, azure_endpoint=endpoint, api_version=api_version))
-        else:
-            pandas_ext.use_async(AsyncOpenAI(api_key=api_key))
-        _INITIALIZED = True
+
+    CONTAINER.register(SparkSession, lambda: spark)
+    CONTAINER.register(SparkContext, lambda: CONTAINER.resolve(SparkSession).sparkContext)
+
+    sc = CONTAINER.resolve(SparkContext)
+    sc.environment["AZURE_OPENAI_API_KEY"] = api_key
+    sc.environment["AZURE_OPENAI_BASE_URL"] = base_url
+    sc.environment["AZURE_OPENAI_API_VERSION"] = api_version
+
+    os.environ["AZURE_OPENAI_API_KEY"] = api_key
+    os.environ["AZURE_OPENAI_BASE_URL"] = base_url
+    os.environ["AZURE_OPENAI_API_VERSION"] = api_version
+
+    if responses_model_name:
+        CONTAINER.register(ResponsesModelName, lambda: ResponsesModelName(responses_model_name))
+
+    if embeddings_model_name:
+        CONTAINER.register(EmbeddingsModelName, lambda: EmbeddingsModelName(embeddings_model_name))
+
+    CONTAINER.clear_singletons()


 def _python_type_to_spark(python_type):
     origin = get_origin(python_type)

-    # For list types (e.g., List[int])
-    if origin is list or origin is List:
+    # For list types (e.g., list[int])
+    if origin is list:
         # Retrieve the inner type and recursively convert it
         inner_type = get_args(python_type)[0]
         return ArrayType(_python_type_to_spark(inner_type))

-    # For Optional types (Union[..., None])
+    # For Optional types (T | None via Union internally)
     elif origin is Union:
         non_none_args = [arg for arg in get_args(python_type) if arg is not type(None)]
         if len(non_none_args) == 1:
@@ -170,7 +296,7 @@ def _python_type_to_spark(python_type):
         return StringType()

     # For Enum types - also treat as StringType since Spark doesn't have enum types
-    elif hasattr(python_type, '__bases__') and Enum in python_type.__bases__:
+    elif hasattr(python_type, "__bases__") and Enum in python_type.__bases__:
         return StringType()

     # For nested Pydantic models (to be treated as Structs)
@@ -190,7 +316,7 @@ def _python_type_to_spark(python_type):
     raise ValueError(f"Unsupported type: {python_type}")


-def _pydantic_to_spark_schema(model: Type[BaseModel]) -> StructType:
+def _pydantic_to_spark_schema(model: type[BaseModel]) -> StructType:
     fields = []
     for field_name, field in model.model_fields.items():
         field_type = field.annotation
@@ -201,7 +327,7 @@ def _pydantic_to_spark_schema(model: Type[BaseModel]) -> StructType:
     return StructType(fields)


-def _safe_cast_str(x: Optional[str]) -> Optional[str]:
+def _safe_cast_str(x: str | None) -> str | None:
     try:
         if x is None:
             return None
@@ -212,7 +338,7 @@ def _safe_cast_str(x: Optional[str]) -> Optional[str]:
         return None


-def _safe_dump(x: Optional[BaseModel]) -> Dict:
+def _safe_dump(x: BaseModel | None) -> dict:
     try:
         if x is None:
             return {}
@@ -223,347 +349,490 @@ def _safe_dump(x: Optional[BaseModel]) -> Dict:
         return {}


-@dataclass(frozen=True)
-class ResponsesUDFBuilder:
-    """Builder for asynchronous Spark pandas UDFs for generating responses.
+def responses_udf(
+    instructions: str,
+    response_format: type[ResponseFormat] = str,
+    model_name: str | None = None,
+    batch_size: int | None = None,
+    max_concurrency: int = 8,
+    **api_kwargs,
+) -> UserDefinedFunction:
+    """Create an asynchronous Spark pandas UDF for generating responses.

-    Configures and builds UDFs that leverage `pandas_ext.aio.responses`
+    Configures and builds UDFs that leverage `pandas_ext.aio.responses_with_cache`
     to generate text or structured responses from OpenAI models asynchronously.
-    An instance stores authentication parameters and the model name.
+    Each partition maintains its own cache to eliminate duplicate API calls within
+    the partition, significantly reducing API usage and costs when processing
+    datasets with overlapping content.
+
+    Note:
+        Authentication must be configured via SparkContext environment variables.
+        Set the appropriate environment variables on the SparkContext:
+
+        For OpenAI:
+            sc.environment["OPENAI_API_KEY"] = "your-openai-api-key"

-    This builder supports two main methods:
-    - `build()`: Creates UDFs with custom instructions and response formats
-    - `build_from_task()`: Creates UDFs from predefined tasks (e.g., sentiment analysis)
+        For Azure OpenAI:
+            sc.environment["AZURE_OPENAI_API_KEY"] = "your-azure-openai-api-key"
+            sc.environment["AZURE_OPENAI_BASE_URL"] = "https://YOUR-RESOURCE-NAME.services.ai.azure.com/openai/v1/"
+            sc.environment["AZURE_OPENAI_API_VERSION"] = "preview"
+
+    Args:
+        instructions (str): The system prompt or instructions for the model.
+        response_format (type[ResponseFormat]): The desired output format. Either `str` for plain text
+            or a Pydantic `BaseModel` for structured JSON output. Defaults to `str`.
+        model_name (str | None): For Azure OpenAI, use your deployment name (e.g., "my-gpt4-deployment").
+            For OpenAI, use the model name (e.g., "gpt-4.1-mini"). Defaults to configured model in DI container
+            via ResponsesModelName if not provided.
+        batch_size (int | None): Number of rows per async batch request within each partition.
+            Larger values reduce API call overhead but increase memory usage.
+            Defaults to None (automatic batch size optimization that dynamically
+            adjusts based on execution time, targeting 30-60 seconds per batch).
+            Set to a positive integer (e.g., 32-128) for fixed batch size.
+        max_concurrency (int): Maximum number of concurrent API requests **PER EXECUTOR**.
+            Total cluster concurrency = max_concurrency × number_of_executors.
+            Higher values increase throughput but may hit OpenAI rate limits.
+            Recommended: 4-12 per executor. Defaults to 8.
+        **api_kwargs: Additional OpenAI API parameters (e.g. ``temperature``, ``top_p``,
+            ``frequency_penalty``, ``presence_penalty``, ``seed``, ``max_output_tokens``, etc.)
+            forwarded verbatim to the underlying API calls. These parameters are applied to
+            all API requests made by the UDF.

-    Attributes:
-        api_key (str): OpenAI or Azure API key.
-        endpoint (Optional[str]): Azure endpoint base URL. None for public OpenAI.
-        api_version (Optional[str]): Azure API version. Ignored for public OpenAI.
-        model_name (str): Deployment name (Azure) or model name (OpenAI) for responses.
+    Returns:
+        UserDefinedFunction: A Spark pandas UDF configured to generate responses asynchronously.
+            Output schema is `StringType` or a struct derived from `response_format`.
+
+    Raises:
+        ValueError: If `response_format` is not `str` or a Pydantic `BaseModel`.
+
+    Example:
+        ```python
+        from pyspark.sql import SparkSession
+        from openaivec.spark import responses_udf, setup
+
+        spark = SparkSession.builder.getOrCreate()
+        setup(spark, api_key="sk-***", responses_model_name="gpt-4.1-mini")
+        udf = responses_udf("Reply with one word.")
+        spark.udf.register("short_answer", udf)
+        df = spark.createDataFrame([("hello",), ("bye",)], ["text"])
+        df.selectExpr("short_answer(text) as reply").show()
+        ```
+
+    Note:
+        For optimal performance in distributed environments:
+        - **Automatic Caching**: Duplicate inputs within each partition are cached,
+          reducing API calls and costs significantly on datasets with repeated content
+        - Monitor OpenAI API rate limits when scaling executor count
+        - Consider your OpenAI tier limits: total_requests = max_concurrency × executors
+        - Use Spark UI to optimize partition sizes relative to batch_size
     """
+    _model_name = model_name or CONTAINER.resolve(ResponsesModelName).value

-    # Params for OpenAI SDK
-    api_key: str
-    endpoint: str | None
-    api_version: str | None
-
-    # Params for Responses API
-    model_name: str
-
-    @classmethod
-    def of_openai(cls, api_key: str, model_name: str) -> "ResponsesUDFBuilder":
-        """Creates a builder configured for the public OpenAI API.
-
-        Args:
-            api_key (str): The OpenAI API key.
-            model_name (str): The OpenAI model name for responses (e.g., "gpt-4o-mini").
-
-        Returns:
-            ResponsesUDFBuilder: A builder instance configured for OpenAI responses.
-        """
-        return cls(api_key=api_key, endpoint=None, api_version=None, model_name=model_name)
-
-    @classmethod
-    def of_azure_openai(cls, api_key: str, endpoint: str, api_version: str, model_name: str) -> "ResponsesUDFBuilder":
-        """Creates a builder configured for Azure OpenAI.
-
-        Args:
-            api_key (str): The Azure OpenAI API key.
-            endpoint (str): The Azure OpenAI endpoint URL.
-            api_version (str): The Azure OpenAI API version (e.g., "2024-02-01").
-            model_name (str): The Azure OpenAI deployment name for responses.
-
-        Returns:
-            ResponsesUDFBuilder: A builder instance configured for Azure OpenAI responses.
-        """
-        return cls(api_key=api_key, endpoint=endpoint, api_version=api_version, model_name=model_name)
-
-    def build(
-        self,
-        instructions: str,
-        response_format: Type[T] = str,
-        batch_size: int = 128, # Default batch size for async might differ
-        temperature: float = 0.0,
-        top_p: float = 1.0,
-        max_concurrency: int = 8,
-    ) -> UserDefinedFunction:
-        """Builds the asynchronous pandas UDF for generating responses.
-
-        Args:
-            instructions (str): The system prompt or instructions for the model.
-            response_format (Type[T]): The desired output format. Either `str` for plain text
-                or a Pydantic `BaseModel` for structured JSON output. Defaults to `str`.
-            batch_size (int): Number of rows per async batch request passed to the underlying
-                `pandas_ext` function. Defaults to 128.
-            temperature (float): Sampling temperature (0.0 to 2.0). Defaults to 0.0.
-            top_p (float): Nucleus sampling parameter. Defaults to 1.0.
-
-        Returns:
-            UserDefinedFunction: A Spark pandas UDF configured to generate responses asynchronously.
-                Output schema is `StringType` or a struct derived from `response_format`.
-
-        Raises:
-            ValueError: If `response_format` is not `str` or a Pydantic `BaseModel`.
-        """
-        if issubclass(response_format, BaseModel):
-            spark_schema = _pydantic_to_spark_schema(response_format)
-            json_schema_string = serialize_base_model(response_format)
-
-            @pandas_udf(returnType=spark_schema)
-            def structure_udf(col: Iterator[pd.Series]) -> Iterator[pd.DataFrame]:
-                _initialize(self.api_key, self.endpoint, self.api_version)
-                pandas_ext.responses_model(self.model_name)
+    if issubclass(response_format, BaseModel):
+        spark_schema = _pydantic_to_spark_schema(response_format)
+        json_schema_string = serialize_base_model(response_format)
+
+        @pandas_udf(returnType=spark_schema) # type: ignore[call-overload]
+        def structure_udf(col: Iterator[pd.Series]) -> Iterator[pd.DataFrame]:
+            pandas_ext.set_responses_model(_model_name)
+            response_format = deserialize_base_model(json_schema_string)
+            cache = AsyncBatchingMapProxy[str, response_format](
+                batch_size=batch_size,
+                max_concurrency=max_concurrency,
+            )

+            try:
                 for part in col:
                     predictions: pd.Series = asyncio.run(
-                        part.aio.responses(
+                        part.aio.responses_with_cache(
                             instructions=instructions,
-                            response_format=deserialize_base_model(json_schema_string),
-                            batch_size=batch_size,
-                            temperature=temperature,
-                            top_p=top_p,
-                            max_concurrency=max_concurrency,
+                            response_format=response_format,
+                            cache=cache,
+                            **api_kwargs,
                         )
                     )
                     yield pd.DataFrame(predictions.map(_safe_dump).tolist())
+            finally:
+                asyncio.run(cache.clear())

-            return structure_udf
+        return structure_udf # type: ignore[return-value]

-        elif issubclass(response_format, str):
+    elif issubclass(response_format, str):

-            @pandas_udf(returnType=StringType())
-            def string_udf(col: Iterator[pd.Series]) -> Iterator[pd.Series]:
-                _initialize(self.api_key, self.endpoint, self.api_version)
-                pandas_ext.responses_model(self.model_name)
+        @pandas_udf(returnType=StringType()) # type: ignore[call-overload]
+        def string_udf(col: Iterator[pd.Series]) -> Iterator[pd.Series]:
+            pandas_ext.set_responses_model(_model_name)
+            cache = AsyncBatchingMapProxy[str, str](
+                batch_size=batch_size,
+                max_concurrency=max_concurrency,
+            )

+            try:
                 for part in col:
                     predictions: pd.Series = asyncio.run(
-                        part.aio.responses(
+                        part.aio.responses_with_cache(
                             instructions=instructions,
                             response_format=str,
-                            batch_size=batch_size,
-                            temperature=temperature,
-                            top_p=top_p,
-                            max_concurrency=max_concurrency,
+                            cache=cache,
+                            **api_kwargs,
                         )
                     )
                     yield predictions.map(_safe_cast_str)
+            finally:
+                asyncio.run(cache.clear())

-            return string_udf
+        return string_udf # type: ignore[return-value]

-        else:
-            raise ValueError(f"Unsupported response_format: {response_format}")
-
-    def build_from_task(
-        self,
-        task: PreparedTask,
-        batch_size: int = 128,
-        max_concurrency: int = 8,
-    ) -> UserDefinedFunction:
-        """Builds the asynchronous pandas UDF from a predefined task.
-
-        This method allows users to create UDFs from predefined tasks such as sentiment analysis,
-        translation, or other common NLP operations defined in the openaivec.task module.
-
-        Args:
-            task (PreparedTask): A predefined task configuration containing instructions,
-                response format, temperature, and top_p settings.
-            batch_size (int): Number of rows per async batch request passed to the underlying
-                `pandas_ext` function. Defaults to 128.
-            max_concurrency (int): Maximum number of concurrent requests. Defaults to 8.
-
-        Returns:
-            UserDefinedFunction: A Spark pandas UDF configured to execute the specified task
-                asynchronously, returning a struct derived from the task's response format.
-
-        Example:
-            ```python
-            from openaivec.task import nlp
-
-            builder = ResponsesUDFBuilder.of_openai(
-                api_key="your-api-key",
-                model_name="gpt-4o-mini"
-            )
-
-            sentiment_udf = builder.build_from_task(nlp.SENTIMENT_ANALYSIS)
-
-            spark.udf.register("analyze_sentiment", sentiment_udf)
-            ```
-        """
-        # Serialize task parameters for Spark serialization compatibility
-        task_instructions = task.instructions
-        task_response_format_json = serialize_base_model(task.response_format)
-        task_temperature = task.temperature
-        task_top_p = task.top_p
-
-        # Deserialize the response format from JSON
-        response_format = deserialize_base_model(task_response_format_json)
-        spark_schema = _pydantic_to_spark_schema(response_format)
+    else:
+        raise ValueError(f"Unsupported response_format: {response_format}")

-        @pandas_udf(returnType=spark_schema)
-        def task_udf(col: Iterator[pd.Series]) -> Iterator[pd.DataFrame]:
-            _initialize(self.api_key, self.endpoint, self.api_version)
-            pandas_ext.responses_model(self.model_name)

-            for part in col:
-                predictions: pd.Series = asyncio.run(
-                    part.aio.responses(
-                        instructions=task_instructions,
-                        response_format=response_format,
-                        batch_size=batch_size,
-                        temperature=task_temperature,
-                        top_p=task_top_p,
-                        max_concurrency=max_concurrency,
-                    )
-                )
-                yield pd.DataFrame(predictions.map(_safe_dump).tolist())
+def task_udf(
+    task: PreparedTask[ResponseFormat],
+    model_name: str | None = None,
+    batch_size: int | None = None,
+    max_concurrency: int = 8,
+    **api_kwargs,
+) -> UserDefinedFunction:
+    """Create an asynchronous Spark pandas UDF from a predefined task.
+
+    This function allows users to create UDFs from predefined tasks such as sentiment analysis,
+    translation, or other common NLP operations defined in the openaivec.task module.
+    Each partition maintains its own cache to eliminate duplicate API calls within
+    the partition, significantly reducing API usage and costs when processing
+    datasets with overlapping content.

-        return task_udf
+    Args:
+        task (PreparedTask): A predefined task configuration containing instructions
+            and response format.
+        model_name (str | None): For Azure OpenAI, use your deployment name (e.g., "my-gpt4-deployment").
+            For OpenAI, use the model name (e.g., "gpt-4.1-mini"). Defaults to configured model in DI container
+            via ResponsesModelName if not provided.
+        batch_size (int | None): Number of rows per async batch request within each partition.
+            Larger values reduce API call overhead but increase memory usage.
+            Defaults to None (automatic batch size optimization that dynamically
+            adjusts based on execution time, targeting 30-60 seconds per batch).
+            Set to a positive integer (e.g., 32-128) for fixed batch size.
+        max_concurrency (int): Maximum number of concurrent API requests **PER EXECUTOR**.
+            Total cluster concurrency = max_concurrency × number_of_executors.
+            Higher values increase throughput but may hit OpenAI rate limits.
+            Recommended: 4-12 per executor. Defaults to 8.
+
+    Additional Keyword Args:
+        Arbitrary OpenAI Responses API parameters (e.g. ``temperature``, ``top_p``,
+        ``frequency_penalty``, ``presence_penalty``, ``seed``, ``max_output_tokens``, etc.)
+        are forwarded verbatim to the underlying API calls. These parameters are applied to
+        all API requests made by the UDF.

+    Returns:
+        UserDefinedFunction: A Spark pandas UDF configured to execute the specified task
+            asynchronously with automatic caching for duplicate inputs within each partition.
+            Output schema is StringType for str response format or a struct derived from
+            the task's response format for BaseModel.

-@dataclass(frozen=True)
-class EmbeddingsUDFBuilder:
-    """Builder for asynchronous Spark pandas UDFs for creating embeddings.
+    Example:
+        ```python
+        from openaivec.task import nlp

-    Configures and builds UDFs that leverage `pandas_ext.aio.embeddings`
-    to generate vector embeddings from OpenAI models asynchronously.
-    An instance stores authentication parameters and the model name.
+        sentiment_udf = task_udf(nlp.SENTIMENT_ANALYSIS)

-    Attributes:
-        api_key (str): OpenAI or Azure API key.
-        endpoint (Optional[str]): Azure endpoint base URL. None for public OpenAI.
-        api_version (Optional[str]): Azure API version. Ignored for public OpenAI.
-        model_name (str): Deployment name (Azure) or model name (OpenAI) for embeddings.
+        spark.udf.register("analyze_sentiment", sentiment_udf)
+        ```
+
+    Note:
+        **Automatic Caching**: Duplicate inputs within each partition are cached,
+        reducing API calls and costs significantly on datasets with repeated content.
     """
+    return responses_udf(
+        instructions=task.instructions,
+        response_format=task.response_format,
+        model_name=model_name,
+        batch_size=batch_size,
+        max_concurrency=max_concurrency,
+        **api_kwargs,
+    )
+
+
+def infer_schema(
+    instructions: str,
+    example_table_name: str,
+    example_field_name: str,
+    max_examples: int = 100,
+) -> SchemaInferenceOutput:
+    """Infer the schema for a response format based on example data.
+
+    This function retrieves examples from a Spark table and infers the schema
+    for the response format using the provided instructions. It is useful when
+    you want to dynamically generate a schema based on existing data.

-    # Params for OpenAI SDK
-    api_key: str
-    endpoint: str | None
-    api_version: str | None
+    Args:
+        instructions (str): Instructions for the model to infer the schema.
+        example_table_name (str | None): Name of the Spark table containing example data.
+        example_field_name (str | None): Name of the field in the table to use as examples.
+        max_examples (int): Maximum number of examples to retrieve for schema inference.

-    # Params for Embeddings API
-    model_name: str
+    Returns:
+        InferredSchema: An object containing the inferred schema and response format.
+
+    Example:
+        ```python
+        from pyspark.sql import SparkSession
+
+        spark = SparkSession.builder.getOrCreate()
+        spark.createDataFrame([("great product",), ("bad service",)], ["text"]).createOrReplaceTempView("examples")
+        infer_schema(
+            instructions="Classify sentiment as positive or negative.",
+            example_table_name="examples",
+            example_field_name="text",
+            max_examples=2,
+        )
+        ```
+    """

-    @classmethod
-    def of_openai(cls, api_key: str, model_name: str) -> "EmbeddingsUDFBuilder":
-        """Creates a builder configured for the public OpenAI API.
+    spark = CONTAINER.resolve(SparkSession)
+    examples: list[str] = (
+        spark.table(example_table_name).rdd.map(lambda row: row[example_field_name]).takeSample(False, max_examples)
+    )
+
+    input = SchemaInferenceInput(
+        instructions=instructions,
+        examples=examples,
+    )
+    inferer = CONTAINER.resolve(SchemaInferer)
+    return inferer.infer_schema(input)
+
+
+def parse_udf(
+    instructions: str,
+    response_format: type[ResponseFormat] | None = None,
+    example_table_name: str | None = None,
+    example_field_name: str | None = None,
+    max_examples: int = 100,
+    model_name: str | None = None,
+    batch_size: int | None = None,
+    max_concurrency: int = 8,
+    **api_kwargs,
+) -> UserDefinedFunction:
+    """Create an asynchronous Spark pandas UDF for parsing responses.
+    This function allows users to create UDFs that parse responses based on
+    provided instructions and either a predefined response format or example data.
+    It supports both structured responses using Pydantic models and plain text responses.
+    Each partition maintains its own cache to eliminate duplicate API calls within
+    the partition, significantly reducing API usage and costs when processing
+    datasets with overlapping content.

-        Args:
-            api_key (str): The OpenAI API key.
-            model_name (str): The OpenAI model name for embeddings (e.g., "text-embedding-3-small").
+    Args:
+        instructions (str): The system prompt or instructions for the model.
+        response_format (type[ResponseFormat] | None): The desired output format.
+            Either `str` for plain text or a Pydantic `BaseModel` for structured JSON output.
+            If not provided, the schema will be inferred from example data.
+        example_table_name (str | None): Name of the Spark table containing example data.
+            If provided, `example_field_name` must also be specified.
+        example_field_name (str | None): Name of the field in the table to use as examples.
+            If provided, `example_table_name` must also be specified.
+        max_examples (int): Maximum number of examples to retrieve for schema inference.
+            Defaults to 100.
+        model_name (str | None): For Azure OpenAI, use your deployment name (e.g., "my-gpt4-deployment").
+            For OpenAI, use the model name (e.g., "gpt-4.1-mini"). Defaults to configured model in DI container
+            via ResponsesModelName if not provided.
+        batch_size (int | None): Number of rows per async batch request within each partition.
+            Larger values reduce API call overhead but increase memory usage.
+            Defaults to None (automatic batch size optimization that dynamically
+            adjusts based on execution time, targeting 30-60 seconds per batch).
+            Set to a positive integer (e.g., 32-128) for fixed batch size
+        max_concurrency (int): Maximum number of concurrent API requests **PER EXECUTOR**.
+            Total cluster concurrency = max_concurrency × number_of_executors.
+            Higher values increase throughput but may hit OpenAI rate limits.
+            Recommended: 4-12 per executor. Defaults to 8.
+        **api_kwargs: Additional OpenAI API parameters (e.g. ``temperature``, ``top_p``,
+            ``frequency_penalty``, ``presence_penalty``, ``seed``, ``max_output_tokens``, etc.)
+            forwarded verbatim to the underlying API calls. These parameters are applied to
+            all API requests made by the UDF and override any parameters set in the
+            response_format or example data.
+    Example:
+        ```python
+        from pyspark.sql import SparkSession
+
+        spark = SparkSession.builder.getOrCreate()
+        spark.createDataFrame(
+            [("Order #123 delivered",), ("Order #456 delayed",)],
+            ["body"],
+        ).createOrReplaceTempView("messages")
+        udf = parse_udf(
+            instructions="Extract order id as `order_id` and status as `status`.",
+            example_table_name="messages",
+            example_field_name="body",
+        )
+        spark.udf.register("parse_ticket", udf)
+        spark.sql("SELECT parse_ticket(body) AS parsed FROM messages").show()
+        ```
+    Returns:
+        UserDefinedFunction: A Spark pandas UDF configured to parse responses asynchronously.
+            Output schema is `StringType` for str response format or a struct derived from
+            the response_format for BaseModel.
+    Raises:
+        ValueError: If neither `response_format` nor `example_table_name` and `example_field_name` are provided.
+    """

-        Returns:
-            EmbeddingsUDFBuilder: A builder instance configured for OpenAI embeddings.
-        """
-        return cls(api_key=api_key, endpoint=None, api_version=None, model_name=model_name)
+    if not response_format and not (example_field_name and example_table_name):
+        raise ValueError("Either response_format or example_table_name and example_field_name must be provided.")
+
+    schema: SchemaInferenceOutput | None = None
+
+    if not response_format:
+        schema = infer_schema(
+            instructions=instructions,
+            example_table_name=example_table_name,
+            example_field_name=example_field_name,
+            max_examples=max_examples,
+        )
+
+    return responses_udf(
+        instructions=schema.inference_prompt if schema else instructions,
+        response_format=schema.model if schema else response_format,
+        model_name=model_name,
+        batch_size=batch_size,
+        max_concurrency=max_concurrency,
+        **api_kwargs,
+    )
+
+
+def embeddings_udf(
+    model_name: str | None = None,
+    batch_size: int | None = None,
+    max_concurrency: int = 8,
+    **api_kwargs,
+) -> UserDefinedFunction:
+    """Create an asynchronous Spark pandas UDF for generating embeddings.
+
+    Configures and builds UDFs that leverage `pandas_ext.aio.embeddings_with_cache`
+    to generate vector embeddings from OpenAI models asynchronously.
+    Each partition maintains its own cache to eliminate duplicate API calls within
+    the partition, significantly reducing API usage and costs when processing
+    datasets with overlapping content.

-    @classmethod
-    def of_azure_openai(cls, api_key: str, endpoint: str, api_version: str, model_name: str) -> "EmbeddingsUDFBuilder":
-        """Creates a builder configured for Azure OpenAI.
+    Note:
+        Authentication must be configured via SparkContext environment variables.
+        Set the appropriate environment variables on the SparkContext:

-        Args:
-            api_key (str): The Azure OpenAI API key.
-            endpoint (str): The Azure OpenAI endpoint URL.
-            api_version (str): The Azure OpenAI API version (e.g., "2024-02-01").
-            model_name (str): The Azure OpenAI deployment name for embeddings.
+        For OpenAI:
+            sc.environment["OPENAI_API_KEY"] = "your-openai-api-key"

-        Returns:
-            EmbeddingsUDFBuilder: A builder instance configured for Azure OpenAI embeddings.
-        """
-        return cls(api_key=api_key, endpoint=endpoint, api_version=api_version, model_name=model_name)
+        For Azure OpenAI:
+            sc.environment["AZURE_OPENAI_API_KEY"] = "your-azure-openai-api-key"
+            sc.environment["AZURE_OPENAI_BASE_URL"] = "https://YOUR-RESOURCE-NAME.services.ai.azure.com/openai/v1/"
+            sc.environment["AZURE_OPENAI_API_VERSION"] = "preview"

-    def build(self, batch_size: int = 128, max_concurrency: int = 8) -> UserDefinedFunction:
-        """Builds the asynchronous pandas UDF for generating embeddings.
+    Args:
+        model_name (str | None): For Azure OpenAI, use your deployment name (e.g., "my-embedding-deployment").
+            For OpenAI, use the model name (e.g., "text-embedding-3-small").
+            Defaults to configured model in DI container via EmbeddingsModelName if not provided.
+        batch_size (int | None): Number of rows per async batch request within each partition.
+            Larger values reduce API call overhead but increase memory usage.
+            Defaults to None (automatic batch size optimization that dynamically
+            adjusts based on execution time, targeting 30-60 seconds per batch).
+            Set to a positive integer (e.g., 64-256) for fixed batch size.
+            Embeddings typically handle larger batches efficiently.
+        max_concurrency (int): Maximum number of concurrent API requests **PER EXECUTOR**.
+            Total cluster concurrency = max_concurrency × number_of_executors.
+            Higher values increase throughput but may hit OpenAI rate limits.
+            Recommended: 4-12 per executor. Defaults to 8.
+        **api_kwargs: Additional OpenAI API parameters (e.g., dimensions for text-embedding-3 models).

-        Args:
-            batch_size (int): Number of rows per async batch request passed to the underlying
-                `pandas_ext` function. Defaults to 128.
+    Returns:
+        UserDefinedFunction: A Spark pandas UDF configured to generate embeddings asynchronously
+            with automatic caching for duplicate inputs within each partition,
+            returning an `ArrayType(FloatType())` column.
+
+    Note:
+        For optimal performance in distributed environments:
+        - **Automatic Caching**: Duplicate inputs within each partition are cached,
+          reducing API calls and costs significantly on datasets with repeated content
+        - Monitor OpenAI API rate limits when scaling executor count
+        - Consider your OpenAI tier limits: total_requests = max_concurrency × executors
+        - Embeddings API typically has higher throughput than chat completions
+        - Use larger batch_size for embeddings compared to response generation
+    """

-        Returns:
-            UserDefinedFunction: A Spark pandas UDF configured to generate embeddings asynchronously,
-                returning an `ArrayType(FloatType())` column.
-        """
+    _model_name = model_name or CONTAINER.resolve(EmbeddingsModelName).value

-        @pandas_udf(returnType=ArrayType(FloatType()))
-        def embeddings_udf(col: Iterator[pd.Series]) -> Iterator[pd.Series]:
-            _initialize(self.api_key, self.endpoint, self.api_version)
-            pandas_ext.embeddings_model(self.model_name)
+    @pandas_udf(returnType=ArrayType(FloatType())) # type: ignore[call-overload,misc]
+    def _embeddings_udf(col: Iterator[pd.Series]) -> Iterator[pd.Series]:
+        pandas_ext.set_embeddings_model(_model_name)
+        cache = AsyncBatchingMapProxy[str, np.ndarray](
+            batch_size=batch_size,
+            max_concurrency=max_concurrency,
+        )

+        try:
             for part in col:
-                embeddings: pd.Series = asyncio.run(
-                    part.aio.embeddings(batch_size=batch_size, max_concurrency=max_concurrency)
-                )
+                embeddings: pd.Series = asyncio.run(part.aio.embeddings_with_cache(cache=cache, **api_kwargs))
                 yield embeddings.map(lambda x: x.tolist())
+        finally:
+            asyncio.run(cache.clear())

-        return embeddings_udf
+    return _embeddings_udf # type: ignore[return-value]


-
-def split_to_chunks_udf(model_name: str, max_tokens: int, sep: List[str]) -> UserDefinedFunction:
+def split_to_chunks_udf(max_tokens: int, sep: list[str]) -> UserDefinedFunction:
     """Create a pandas‑UDF that splits text into token‑bounded chunks.

     Args:
-        model_name: Model identifier passed to *tiktoken*.
-        max_tokens: Maximum tokens allowed per chunk.
-        sep: Ordered list of separator strings used by ``TextChunker``.
+        max_tokens (int): Maximum tokens allowed per chunk.
+        sep (list[str]): Ordered list of separator strings used by ``TextChunker``.

     Returns:
         A pandas UDF producing an ``ArrayType(StringType())`` column whose
         values are lists of chunks respecting the ``max_tokens`` limit.
     """

-    @pandas_udf(ArrayType(StringType()))
+    @pandas_udf(ArrayType(StringType())) # type: ignore[call-overload,misc]
     def fn(col: Iterator[pd.Series]) -> Iterator[pd.Series]:
-        global _TIKTOKEN_ENC
-        if _TIKTOKEN_ENC is None:
-            _TIKTOKEN_ENC = tiktoken.encoding_for_model(model_name)
-
-        chunker = TextChunker(_TIKTOKEN_ENC)
+        encoding = tiktoken.get_encoding("o200k_base")
+        chunker = TextChunker(encoding)

         for part in col:
             yield part.map(lambda x: chunker.split(x, max_tokens=max_tokens, sep=sep) if isinstance(x, str) else [])

-    return fn
+    return fn # type: ignore[return-value]


-def count_tokens_udf(model_name: str = "gpt-4o") -> UserDefinedFunction:
+def count_tokens_udf() -> UserDefinedFunction:
     """Create a pandas‑UDF that counts tokens for every string cell.

     The UDF uses *tiktoken* to approximate tokenisation and caches the
     resulting ``Encoding`` object per executor.

-    Args:
-        model_name: Model identifier understood by ``tiktoken``.
-
     Returns:
         A pandas UDF producing an ``IntegerType`` column with token counts.
     """

-    @pandas_udf(IntegerType())
+    @pandas_udf(IntegerType()) # type: ignore[call-overload]
     def fn(col: Iterator[pd.Series]) -> Iterator[pd.Series]:
-        global _TIKTOKEN_ENC
-        if _TIKTOKEN_ENC is None:
-            _TIKTOKEN_ENC = tiktoken.encoding_for_model(model_name)
+        encoding = tiktoken.get_encoding("o200k_base")

         for part in col:
-            yield part.map(lambda x: len(_TIKTOKEN_ENC.encode(x)) if isinstance(x, str) else 0)
+            yield part.map(lambda x: len(encoding.encode(x)) if isinstance(x, str) else 0)

-    return fn
+    return fn # type: ignore[return-value]


 def similarity_udf() -> UserDefinedFunction:
-    @pandas_udf(FloatType())
+    """Create a pandas-UDF that computes cosine similarity between embedding vectors.
+
+    Returns:
+        UserDefinedFunction: A Spark pandas UDF that takes two embedding vector columns
+            and returns their cosine similarity as a FloatType column.
+    """
+
+    @pandas_udf(FloatType()) # type: ignore[call-overload]
     def fn(a: pd.Series, b: pd.Series) -> pd.Series:
-        """Compute cosine similarity between two vectors.
+        # Import pandas_ext to ensure .ai accessor is available in Spark workers
+        from openaivec import pandas_ext

-        Args:
-            a: First vector.
-            b: Second vector.
+        # Explicitly reference pandas_ext to satisfy linters
+        assert pandas_ext is not None

-        Returns:
-            Cosine similarity between the two vectors.
-        """
-        pandas_ext._wakeup()
         return pd.DataFrame({"a": a, "b": b}).ai.similarity("a", "b")

-    return fn
+    return fn # type: ignore[return-value]
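The helper UDFs at the end of the module also change shape: `count_tokens_udf` and `split_to_chunks_udf` no longer take a `model_name` and now use the fixed `o200k_base` encoding on the worker, while `similarity_udf` gains a docstring and imports `pandas_ext` inside the worker instead of calling the removed `_wakeup()`. A short sketch of how these helpers combine with `embeddings_udf` from Spark SQL, following the module docstring; the `faq` view and its columns are illustrative only:

```python
from pyspark.sql import SparkSession
from openaivec.spark import (
    setup,
    embeddings_udf,
    count_tokens_udf,
    split_to_chunks_udf,
    similarity_udf,
)

spark = SparkSession.builder.getOrCreate()
setup(spark, api_key="your-openai-api-key", embeddings_model_name="text-embedding-3-small")

spark.udf.register("embed_async", embeddings_udf())
spark.udf.register("count_tokens", count_tokens_udf())
spark.udf.register("split_chunks", split_to_chunks_udf(max_tokens=512, sep=[".", "!", "?"]))
spark.udf.register("compute_similarity", similarity_udf())

# Illustrative question/answer pairs.
spark.createDataFrame(
    [("How do I reset my password?", "Open account settings and choose 'Reset password'.")],
    ["question", "answer"],
).createOrReplaceTempView("faq")

spark.sql("""
    SELECT
        count_tokens(question) AS question_tokens,
        split_chunks(answer) AS answer_chunks,
        compute_similarity(embed_async(question), embed_async(answer)) AS qa_similarity
    FROM faq
""").show(truncate=False)
```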