openaivec 0.12.5__py3-none-any.whl → 1.0.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. openaivec/__init__.py +13 -4
  2. openaivec/_cache/__init__.py +12 -0
  3. openaivec/_cache/optimize.py +109 -0
  4. openaivec/_cache/proxy.py +806 -0
  5. openaivec/{di.py → _di.py} +36 -12
  6. openaivec/_embeddings.py +203 -0
  7. openaivec/{log.py → _log.py} +2 -2
  8. openaivec/_model.py +113 -0
  9. openaivec/{prompt.py → _prompt.py} +95 -28
  10. openaivec/_provider.py +207 -0
  11. openaivec/_responses.py +511 -0
  12. openaivec/_schema/__init__.py +9 -0
  13. openaivec/_schema/infer.py +340 -0
  14. openaivec/_schema/spec.py +350 -0
  15. openaivec/_serialize.py +234 -0
  16. openaivec/{util.py → _util.py} +25 -85
  17. openaivec/pandas_ext.py +1496 -318
  18. openaivec/spark.py +485 -183
  19. openaivec/task/__init__.py +9 -7
  20. openaivec/task/customer_support/__init__.py +9 -15
  21. openaivec/task/customer_support/customer_sentiment.py +17 -15
  22. openaivec/task/customer_support/inquiry_classification.py +23 -22
  23. openaivec/task/customer_support/inquiry_summary.py +14 -13
  24. openaivec/task/customer_support/intent_analysis.py +21 -19
  25. openaivec/task/customer_support/response_suggestion.py +16 -16
  26. openaivec/task/customer_support/urgency_analysis.py +24 -25
  27. openaivec/task/nlp/__init__.py +4 -4
  28. openaivec/task/nlp/dependency_parsing.py +10 -12
  29. openaivec/task/nlp/keyword_extraction.py +11 -14
  30. openaivec/task/nlp/morphological_analysis.py +12 -14
  31. openaivec/task/nlp/named_entity_recognition.py +16 -18
  32. openaivec/task/nlp/sentiment_analysis.py +14 -11
  33. openaivec/task/nlp/translation.py +6 -9
  34. openaivec/task/table/__init__.py +2 -2
  35. openaivec/task/table/fillna.py +11 -11
  36. openaivec-1.0.10.dist-info/METADATA +399 -0
  37. openaivec-1.0.10.dist-info/RECORD +39 -0
  38. {openaivec-0.12.5.dist-info → openaivec-1.0.10.dist-info}/WHEEL +1 -1
  39. openaivec/embeddings.py +0 -172
  40. openaivec/model.py +0 -67
  41. openaivec/provider.py +0 -45
  42. openaivec/responses.py +0 -393
  43. openaivec/serialize.py +0 -225
  44. openaivec-0.12.5.dist-info/METADATA +0 -696
  45. openaivec-0.12.5.dist-info/RECORD +0 -33
  46. {openaivec-0.12.5.dist-info → openaivec-1.0.10.dist-info}/licenses/LICENSE +0 -0
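Note on the restructuring visible in the file list: the former top-level modules (`embeddings.py`, `model.py`, `provider.py`, `responses.py`, `serialize.py`, plus the renamed `di.py`, `log.py`, `prompt.py`, `util.py`) are removed or made private (underscore-prefixed) in 1.0.10. A minimal sketch of the import impact follows, assuming code that imported those modules directly under 0.12.5; the old import paths are taken from the 0.12.5 `spark.py` shown below, and the public entry points from the 1.0.10 side of the same diff:

```python
# 0.12.5 -- these imports resolved against public top-level modules:
# from openaivec.model import PreparedTask   # openaivec/model.py, removed in 1.0.10
# from openaivec.util import TextChunker     # openaivec/util.py, now openaivec/_util.py (private)

# 1.0.10 -- import through the public surface instead; _model/_serialize/_util
# and friends are implementation details:
from openaivec.spark import setup, responses_udf, embeddings_udf, parse_udf
from openaivec import pandas_ext  # pandas accessor entry point, unchanged
```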
openaivec/spark.py CHANGED
@@ -1,37 +1,51 @@
  """Asynchronous Spark UDFs for the OpenAI and Azure OpenAI APIs.
 
- This module provides functions (`responses_udf`, `task_udf`, `embeddings_udf`)
+ This module provides functions (`responses_udf`, `task_udf`, `embeddings_udf`,
+ `count_tokens_udf`, `split_to_chunks_udf`, `similarity_udf`, `parse_udf`)
  for creating asynchronous Spark UDFs that communicate with either the public
  OpenAI API or Azure OpenAI using the `openaivec.spark` subpackage.
- It supports UDFs for generating responses and creating embeddings asynchronously.
- The UDFs operate on Spark DataFrames and leverage asyncio for potentially
- improved performance in I/O-bound operations.
+ It supports UDFs for generating responses, creating embeddings, parsing text,
+ and computing similarities asynchronously. The UDFs operate on Spark DataFrames
+ and leverage asyncio for improved performance in I/O-bound operations.
+
+ **Performance Optimization**: All AI-powered UDFs (`responses_udf`, `task_udf`, `embeddings_udf`, `parse_udf`)
+ automatically cache duplicate inputs within each partition, significantly reducing
+ API calls and costs when processing datasets with overlapping content.
+
 
  ## Setup
 
  First, obtain a Spark session and configure authentication:
 
  ```python
- import os
  from pyspark.sql import SparkSession
+ from openaivec.spark import setup, setup_azure
 
  spark = SparkSession.builder.getOrCreate()
- sc = spark.sparkContext
 
- # Configure authentication via SparkContext environment variables
  # Option 1: Using OpenAI
- sc.environment["OPENAI_API_KEY"] = "your-openai-api-key"
+ setup(
+     spark,
+     api_key="your-openai-api-key",
+     responses_model_name="gpt-4.1-mini",  # Optional: set default model
+     embeddings_model_name="text-embedding-3-small"  # Optional: set default model
+ )
 
  # Option 2: Using Azure OpenAI
- # sc.environment["AZURE_OPENAI_API_KEY"] = "your-azure-openai-api-key"
- # sc.environment["AZURE_OPENAI_API_ENDPOINT"] = "your-azure-openai-endpoint"
- # sc.environment["AZURE_OPENAI_API_VERSION"] = "your-azure-openai-api-version"
+ # setup_azure(
+ #     spark,
+ #     api_key="your-azure-openai-api-key",
+ #     base_url="https://YOUR-RESOURCE-NAME.services.ai.azure.com/openai/v1/",
+ #     api_version="preview",
+ #     responses_model_name="my-gpt4-deployment",  # Optional: set default deployment
+ #     embeddings_model_name="my-embedding-deployment"  # Optional: set default deployment
+ # )
  ```
 
  Next, create UDFs and register them:
 
  ```python
- from openaivec.spark import responses_udf, task_udf, embeddings_udf
+ from openaivec.spark import responses_udf, task_udf, embeddings_udf, count_tokens_udf, split_to_chunks_udf
  from pydantic import BaseModel
 
  # Define a Pydantic model for structured responses (optional)
@@ -46,7 +60,7 @@ spark.udf.register(
      responses_udf(
          instructions="Translate the text to multiple languages.",
          response_format=Translation,
-         model_name="gpt-4.1-mini",  # Optional, defaults to gpt-4.1-mini
+         model_name="gpt-4.1-mini",  # For Azure: deployment name, for OpenAI: model name
          batch_size=64,  # Rows per API request within partition
          max_concurrency=8  # Concurrent requests PER EXECUTOR
      ),
@@ -63,11 +77,16 @@ spark.udf.register(
  spark.udf.register(
      "embed_async",
      embeddings_udf(
-         model_name="text-embedding-3-small",  # Optional, defaults to text-embedding-3-small
+         model_name="text-embedding-3-small",  # For Azure: deployment name, for OpenAI: model name
          batch_size=128,  # Larger batches for embeddings
          max_concurrency=8  # Concurrent requests PER EXECUTOR
      ),
  )
+
+ # Register token counting, text chunking, and similarity UDFs
+ spark.udf.register("count_tokens", count_tokens_udf())
+ spark.udf.register("split_chunks", split_to_chunks_udf(max_tokens=512, sep=[".", "!", "?"]))
+ spark.udf.register("compute_similarity", similarity_udf())
  ```
 
  You can now invoke the UDFs from Spark SQL:
@@ -77,7 +96,10 @@ SELECT
      text,
      translate_async(text) AS translation,
      sentiment_async(text) AS sentiment,
-     embed_async(text) AS embedding
+     embed_async(text) AS embedding,
+     count_tokens(text) AS token_count,
+     split_chunks(text) AS chunks,
+     compute_similarity(embed_async(text1), embed_async(text2)) AS similarity
  FROM your_table;
  ```
 
@@ -103,26 +125,38 @@ Note: This module provides asynchronous support through the pandas extensions.
 
  import asyncio
  import logging
+ import os
+ from collections.abc import Iterator
  from enum import Enum
- from typing import Dict, Iterator, List, Optional, Type, Union, get_args, get_origin
+ from typing import Union, get_args, get_origin
 
+ import numpy as np
  import pandas as pd
  import tiktoken
  from pydantic import BaseModel
+ from pyspark import SparkContext
+ from pyspark.sql import SparkSession
  from pyspark.sql.pandas.functions import pandas_udf
  from pyspark.sql.types import ArrayType, BooleanType, FloatType, IntegerType, StringType, StructField, StructType
  from pyspark.sql.udf import UserDefinedFunction
  from typing_extensions import Literal
 
- from . import pandas_ext
- from .model import PreparedTask, ResponseFormat
- from .serialize import deserialize_base_model, serialize_base_model
- from .util import TextChunker
+ from openaivec import pandas_ext
+ from openaivec._cache import AsyncBatchingMapProxy
+ from openaivec._model import EmbeddingsModelName, PreparedTask, ResponseFormat, ResponsesModelName
+ from openaivec._provider import CONTAINER
+ from openaivec._schema import SchemaInferenceInput, SchemaInferenceOutput, SchemaInferer
+ from openaivec._serialize import deserialize_base_model, serialize_base_model
+ from openaivec._util import TextChunker
 
  __all__ = [
+     "setup",
+     "setup_azure",
      "responses_udf",
      "task_udf",
      "embeddings_udf",
+     "infer_schema",
+     "parse_udf",
      "split_to_chunks_udf",
      "count_tokens_udf",
      "similarity_udf",
@@ -130,21 +164,126 @@ __all__ = [
 
 
  _LOGGER: logging.Logger = logging.getLogger(__name__)
- _TIKTOKEN_ENC: tiktoken.Encoding | None = None
 
 
+ def setup(
+     spark: SparkSession, api_key: str, responses_model_name: str | None = None, embeddings_model_name: str | None = None
+ ):
+     """Setup OpenAI authentication and default model names in Spark environment.
+     1. Configures OpenAI API key in SparkContext environment.
+     2. Configures OpenAI API key in local process environment.
+     3. Optionally registers default model names for responses and embeddings in the DI container.
+
+     Args:
+         spark (SparkSession): The Spark session to configure.
+         api_key (str): OpenAI API key for authentication.
+         responses_model_name (str | None): Default model name for response generation.
+             If provided, registers `ResponsesModelName` in the DI container.
+         embeddings_model_name (str | None): Default model name for embeddings.
+             If provided, registers `EmbeddingsModelName` in the DI container.
+
+     Example:
+         ```python
+         from pyspark.sql import SparkSession
+         from openaivec.spark import setup
+
+         spark = SparkSession.builder.getOrCreate()
+         setup(
+             spark,
+             api_key="sk-***",
+             responses_model_name="gpt-4.1-mini",
+             embeddings_model_name="text-embedding-3-small",
+         )
+         ```
+     """
+
+     CONTAINER.register(SparkSession, lambda: spark)
+     CONTAINER.register(SparkContext, lambda: CONTAINER.resolve(SparkSession).sparkContext)
+
+     sc = CONTAINER.resolve(SparkContext)
+     sc.environment["OPENAI_API_KEY"] = api_key
+
+     os.environ["OPENAI_API_KEY"] = api_key
+
+     if responses_model_name:
+         CONTAINER.register(ResponsesModelName, lambda: ResponsesModelName(responses_model_name))
+
+     if embeddings_model_name:
+         CONTAINER.register(EmbeddingsModelName, lambda: EmbeddingsModelName(embeddings_model_name))
+
+     CONTAINER.clear_singletons()
+
+
+ def setup_azure(
+     spark: SparkSession,
+     api_key: str,
+     base_url: str,
+     api_version: str = "preview",
+     responses_model_name: str | None = None,
+     embeddings_model_name: str | None = None,
+ ):
+     """Setup Azure OpenAI authentication and default model names in Spark environment.
+     1. Configures Azure OpenAI API key, base URL, and API version in SparkContext environment.
+     2. Configures Azure OpenAI API key, base URL, and API version in local process environment.
+     3. Optionally registers default model names for responses and embeddings in the DI container.
+     Args:
+         spark (SparkSession): The Spark session to configure.
+         api_key (str): Azure OpenAI API key for authentication.
+         base_url (str): Base URL for the Azure OpenAI resource.
+         api_version (str): API version to use. Defaults to "preview".
+         responses_model_name (str | None): Default model name for response generation.
+             If provided, registers `ResponsesModelName` in the DI container.
+         embeddings_model_name (str | None): Default model name for embeddings.
+             If provided, registers `EmbeddingsModelName` in the DI container.
+
+     Example:
+         ```python
+         from pyspark.sql import SparkSession
+         from openaivec.spark import setup_azure
+
+         spark = SparkSession.builder.getOrCreate()
+         setup_azure(
+             spark,
+             api_key="azure-key",
+             base_url="https://YOUR-RESOURCE-NAME.services.ai.azure.com/openai/v1/",
+             api_version="preview",
+             responses_model_name="gpt4-deployment",
+             embeddings_model_name="embedding-deployment",
+         )
+         ```
+     """
+
+     CONTAINER.register(SparkSession, lambda: spark)
+     CONTAINER.register(SparkContext, lambda: CONTAINER.resolve(SparkSession).sparkContext)
+
+     sc = CONTAINER.resolve(SparkContext)
+     sc.environment["AZURE_OPENAI_API_KEY"] = api_key
+     sc.environment["AZURE_OPENAI_BASE_URL"] = base_url
+     sc.environment["AZURE_OPENAI_API_VERSION"] = api_version
+
+     os.environ["AZURE_OPENAI_API_KEY"] = api_key
+     os.environ["AZURE_OPENAI_BASE_URL"] = base_url
+     os.environ["AZURE_OPENAI_API_VERSION"] = api_version
+
+     if responses_model_name:
+         CONTAINER.register(ResponsesModelName, lambda: ResponsesModelName(responses_model_name))
+
+     if embeddings_model_name:
+         CONTAINER.register(EmbeddingsModelName, lambda: EmbeddingsModelName(embeddings_model_name))
+
+     CONTAINER.clear_singletons()
 
 
  def _python_type_to_spark(python_type):
      origin = get_origin(python_type)
 
-     # For list types (e.g., List[int])
-     if origin is list or origin is List:
+     # For list types (e.g., list[int])
+     if origin is list:
          # Retrieve the inner type and recursively convert it
          inner_type = get_args(python_type)[0]
          return ArrayType(_python_type_to_spark(inner_type))
 
-     # For Optional types (Union[..., None])
+     # For Optional types (T | None via Union internally)
      elif origin is Union:
          non_none_args = [arg for arg in get_args(python_type) if arg is not type(None)]
          if len(non_none_args) == 1:
@@ -177,7 +316,7 @@ def _python_type_to_spark(python_type):
      raise ValueError(f"Unsupported type: {python_type}")
 
 
- def _pydantic_to_spark_schema(model: Type[BaseModel]) -> StructType:
+ def _pydantic_to_spark_schema(model: type[BaseModel]) -> StructType:
      fields = []
      for field_name, field in model.model_fields.items():
          field_type = field.annotation
@@ -188,7 +327,7 @@ def _pydantic_to_spark_schema(model: Type[BaseModel]) -> StructType:
      return StructType(fields)
 
 
- def _safe_cast_str(x: Optional[str]) -> Optional[str]:
+ def _safe_cast_str(x: str | None) -> str | None:
      try:
          if x is None:
              return None
@@ -199,7 +338,7 @@ def _safe_cast_str(x: Optional[str]) -> Optional[str]:
          return None
 
 
- def _safe_dump(x: Optional[BaseModel]) -> Dict:
+ def _safe_dump(x: BaseModel | None) -> dict:
      try:
          if x is None:
              return {}
@@ -212,45 +351,52 @@ def _safe_dump(x: Optional[BaseModel]) -> Dict:
 
  def responses_udf(
      instructions: str,
-     response_format: Type[ResponseFormat] = str,
-     model_name: str = "gpt-4.1-mini",
-     batch_size: int = 128,
-     temperature: float = 0.0,
-     top_p: float = 1.0,
+     response_format: type[ResponseFormat] = str,
+     model_name: str | None = None,
+     batch_size: int | None = None,
      max_concurrency: int = 8,
+     **api_kwargs,
  ) -> UserDefinedFunction:
      """Create an asynchronous Spark pandas UDF for generating responses.
 
-     Configures and builds UDFs that leverage `pandas_ext.aio.responses`
+     Configures and builds UDFs that leverage `pandas_ext.aio.responses_with_cache`
      to generate text or structured responses from OpenAI models asynchronously.
+     Each partition maintains its own cache to eliminate duplicate API calls within
+     the partition, significantly reducing API usage and costs when processing
+     datasets with overlapping content.
 
      Note:
          Authentication must be configured via SparkContext environment variables.
          Set the appropriate environment variables on the SparkContext:
-
+
          For OpenAI:
              sc.environment["OPENAI_API_KEY"] = "your-openai-api-key"
-
+
          For Azure OpenAI:
              sc.environment["AZURE_OPENAI_API_KEY"] = "your-azure-openai-api-key"
-             sc.environment["AZURE_OPENAI_API_ENDPOINT"] = "your-azure-openai-endpoint"
-             sc.environment["AZURE_OPENAI_API_VERSION"] = "your-azure-openai-api-version"
+             sc.environment["AZURE_OPENAI_BASE_URL"] = "https://YOUR-RESOURCE-NAME.services.ai.azure.com/openai/v1/"
+             sc.environment["AZURE_OPENAI_API_VERSION"] = "preview"
 
      Args:
          instructions (str): The system prompt or instructions for the model.
-         response_format (Type[ResponseFormat]): The desired output format. Either `str` for plain text
+         response_format (type[ResponseFormat]): The desired output format. Either `str` for plain text
              or a Pydantic `BaseModel` for structured JSON output. Defaults to `str`.
-         model_name (str): Deployment name (Azure) or model name (OpenAI) for responses.
-             Defaults to "gpt-4.1-mini".
-         batch_size (int): Number of rows per async batch request within each partition.
+         model_name (str | None): For Azure OpenAI, use your deployment name (e.g., "my-gpt4-deployment").
+             For OpenAI, use the model name (e.g., "gpt-4.1-mini"). Defaults to configured model in DI container
+             via ResponsesModelName if not provided.
+         batch_size (int | None): Number of rows per async batch request within each partition.
              Larger values reduce API call overhead but increase memory usage.
-             Recommended: 32-128 depending on data complexity. Defaults to 128.
-         temperature (float): Sampling temperature (0.0 to 2.0). Defaults to 0.0.
-         top_p (float): Nucleus sampling parameter. Defaults to 1.0.
+             Defaults to None (automatic batch size optimization that dynamically
+             adjusts based on execution time, targeting 30-60 seconds per batch).
+             Set to a positive integer (e.g., 32-128) for fixed batch size.
          max_concurrency (int): Maximum number of concurrent API requests **PER EXECUTOR**.
              Total cluster concurrency = max_concurrency × number_of_executors.
              Higher values increase throughput but may hit OpenAI rate limits.
              Recommended: 4-12 per executor. Defaults to 8.
+         **api_kwargs: Additional OpenAI API parameters (e.g. ``temperature``, ``top_p``,
+             ``frequency_penalty``, ``presence_penalty``, ``seed``, ``max_output_tokens``, etc.)
+             forwarded verbatim to the underlying API calls. These parameters are applied to
+             all API requests made by the UDF.
 
      Returns:
          UserDefinedFunction: A Spark pandas UDF configured to generate responses asynchronously.
@@ -259,89 +405,130 @@ def responses_udf(
      Raises:
          ValueError: If `response_format` is not `str` or a Pydantic `BaseModel`.
 
+     Example:
+         ```python
+         from pyspark.sql import SparkSession
+         from openaivec.spark import responses_udf, setup
+
+         spark = SparkSession.builder.getOrCreate()
+         setup(spark, api_key="sk-***", responses_model_name="gpt-4.1-mini")
+         udf = responses_udf("Reply with one word.")
+         spark.udf.register("short_answer", udf)
+         df = spark.createDataFrame([("hello",), ("bye",)], ["text"])
+         df.selectExpr("short_answer(text) as reply").show()
+         ```
+
      Note:
          For optimal performance in distributed environments:
+         - **Automatic Caching**: Duplicate inputs within each partition are cached,
+           reducing API calls and costs significantly on datasets with repeated content
          - Monitor OpenAI API rate limits when scaling executor count
          - Consider your OpenAI tier limits: total_requests = max_concurrency × executors
          - Use Spark UI to optimize partition sizes relative to batch_size
      """
+     _model_name = model_name or CONTAINER.resolve(ResponsesModelName).value
+
      if issubclass(response_format, BaseModel):
          spark_schema = _pydantic_to_spark_schema(response_format)
          json_schema_string = serialize_base_model(response_format)
 
-         @pandas_udf(returnType=spark_schema)
+         @pandas_udf(returnType=spark_schema)  # type: ignore[call-overload]
          def structure_udf(col: Iterator[pd.Series]) -> Iterator[pd.DataFrame]:
-             pandas_ext.responses_model(model_name)
+             pandas_ext.set_responses_model(_model_name)
+             response_format = deserialize_base_model(json_schema_string)
+             cache = AsyncBatchingMapProxy[str, response_format](
+                 batch_size=batch_size,
+                 max_concurrency=max_concurrency,
+             )
 
-             for part in col:
-                 predictions: pd.Series = asyncio.run(
-                     part.aio.responses(
-                         instructions=instructions,
-                         response_format=deserialize_base_model(json_schema_string),
-                         batch_size=batch_size,
-                         temperature=temperature,
-                         top_p=top_p,
-                         max_concurrency=max_concurrency,
+             try:
+                 for part in col:
+                     predictions: pd.Series = asyncio.run(
+                         part.aio.responses_with_cache(
+                             instructions=instructions,
+                             response_format=response_format,
+                             cache=cache,
+                             **api_kwargs,
+                         )
                      )
-                 )
-                 yield pd.DataFrame(predictions.map(_safe_dump).tolist())
+                     yield pd.DataFrame(predictions.map(_safe_dump).tolist())
+             finally:
+                 asyncio.run(cache.clear())
 
-         return structure_udf
+         return structure_udf  # type: ignore[return-value]
 
      elif issubclass(response_format, str):
 
-         @pandas_udf(returnType=StringType())
+         @pandas_udf(returnType=StringType())  # type: ignore[call-overload]
          def string_udf(col: Iterator[pd.Series]) -> Iterator[pd.Series]:
-             pandas_ext.responses_model(model_name)
+             pandas_ext.set_responses_model(_model_name)
+             cache = AsyncBatchingMapProxy[str, str](
+                 batch_size=batch_size,
+                 max_concurrency=max_concurrency,
+             )
 
-             for part in col:
-                 predictions: pd.Series = asyncio.run(
-                     part.aio.responses(
-                         instructions=instructions,
-                         response_format=str,
-                         batch_size=batch_size,
-                         temperature=temperature,
-                         top_p=top_p,
-                         max_concurrency=max_concurrency,
+             try:
+                 for part in col:
+                     predictions: pd.Series = asyncio.run(
+                         part.aio.responses_with_cache(
+                             instructions=instructions,
+                             response_format=str,
+                             cache=cache,
+                             **api_kwargs,
+                         )
                      )
-                 )
-                 yield predictions.map(_safe_cast_str)
+                     yield predictions.map(_safe_cast_str)
+             finally:
+                 asyncio.run(cache.clear())
 
-         return string_udf
+         return string_udf  # type: ignore[return-value]
 
      else:
          raise ValueError(f"Unsupported response_format: {response_format}")
 
 
-
  def task_udf(
-     task: PreparedTask,
-     model_name: str = "gpt-4.1-mini",
-     batch_size: int = 128,
+     task: PreparedTask[ResponseFormat],
+     model_name: str | None = None,
+     batch_size: int | None = None,
      max_concurrency: int = 8,
+     **api_kwargs,
  ) -> UserDefinedFunction:
      """Create an asynchronous Spark pandas UDF from a predefined task.
 
      This function allows users to create UDFs from predefined tasks such as sentiment analysis,
      translation, or other common NLP operations defined in the openaivec.task module.
+     Each partition maintains its own cache to eliminate duplicate API calls within
+     the partition, significantly reducing API usage and costs when processing
+     datasets with overlapping content.
 
      Args:
-         task (PreparedTask): A predefined task configuration containing instructions,
-             response format, temperature, and top_p settings.
-         model_name (str): Deployment name (Azure) or model name (OpenAI) for responses.
-             Defaults to "gpt-4.1-mini".
-         batch_size (int): Number of rows per async batch request within each partition.
+         task (PreparedTask): A predefined task configuration containing instructions
+             and response format.
+         model_name (str | None): For Azure OpenAI, use your deployment name (e.g., "my-gpt4-deployment").
+             For OpenAI, use the model name (e.g., "gpt-4.1-mini"). Defaults to configured model in DI container
+             via ResponsesModelName if not provided.
+         batch_size (int | None): Number of rows per async batch request within each partition.
              Larger values reduce API call overhead but increase memory usage.
-             Recommended: 32-128 depending on task complexity. Defaults to 128.
+             Defaults to None (automatic batch size optimization that dynamically
+             adjusts based on execution time, targeting 30-60 seconds per batch).
+             Set to a positive integer (e.g., 32-128) for fixed batch size.
          max_concurrency (int): Maximum number of concurrent API requests **PER EXECUTOR**.
              Total cluster concurrency = max_concurrency × number_of_executors.
             Higher values increase throughput but may hit OpenAI rate limits.
             Recommended: 4-12 per executor. Defaults to 8.
 
+     Additional Keyword Args:
+         Arbitrary OpenAI Responses API parameters (e.g. ``temperature``, ``top_p``,
+         ``frequency_penalty``, ``presence_penalty``, ``seed``, ``max_output_tokens``, etc.)
+         are forwarded verbatim to the underlying API calls. These parameters are applied to
+         all API requests made by the UDF.
+
      Returns:
          UserDefinedFunction: A Spark pandas UDF configured to execute the specified task
-             asynchronously. Output schema is StringType for str response format or
-             a struct derived from the task's response format for BaseModel.
+             asynchronously with automatic caching for duplicate inputs within each partition.
+             Output schema is StringType for str response format or a struct derived from
+             the task's response format for BaseModel.
 
      Example:
          ```python
@@ -351,186 +538,301 @@ def task_udf(
 
          spark.udf.register("analyze_sentiment", sentiment_udf)
          ```
+
+     Note:
+         **Automatic Caching**: Duplicate inputs within each partition are cached,
+         reducing API calls and costs significantly on datasets with repeated content.
      """
-     # Serialize task parameters for Spark serialization compatibility
-     task_instructions = task.instructions
-     task_temperature = task.temperature
-     task_top_p = task.top_p
-
-     if issubclass(task.response_format, BaseModel):
-         task_response_format_json = serialize_base_model(task.response_format)
-
-         # Deserialize the response format from JSON
-         response_format = deserialize_base_model(task_response_format_json)
-         spark_schema = _pydantic_to_spark_schema(response_format)
+     return responses_udf(
+         instructions=task.instructions,
+         response_format=task.response_format,
+         model_name=model_name,
+         batch_size=batch_size,
+         max_concurrency=max_concurrency,
+         **api_kwargs,
+     )
 
-         @pandas_udf(returnType=spark_schema)
-         def task_udf(col: Iterator[pd.Series]) -> Iterator[pd.DataFrame]:
-             pandas_ext.responses_model(model_name)
 
-             for part in col:
-                 predictions: pd.Series = asyncio.run(
-                     part.aio.responses(
-                         instructions=task_instructions,
-                         response_format=response_format,
-                         batch_size=batch_size,
-                         temperature=task_temperature,
-                         top_p=task_top_p,
-                         max_concurrency=max_concurrency,
-                     )
-                 )
-                 yield pd.DataFrame(predictions.map(_safe_dump).tolist())
+ def infer_schema(
+     instructions: str,
+     example_table_name: str,
+     example_field_name: str,
+     max_examples: int = 100,
+ ) -> SchemaInferenceOutput:
+     """Infer the schema for a response format based on example data.
 
-         return task_udf
-
-     elif issubclass(task.response_format, str):
-
-         @pandas_udf(returnType=StringType())
-         def task_string_udf(col: Iterator[pd.Series]) -> Iterator[pd.Series]:
-             pandas_ext.responses_model(model_name)
+     This function retrieves examples from a Spark table and infers the schema
+     for the response format using the provided instructions. It is useful when
+     you want to dynamically generate a schema based on existing data.
 
-             for part in col:
-                 predictions: pd.Series = asyncio.run(
-                     part.aio.responses(
-                         instructions=task_instructions,
-                         response_format=str,
-                         batch_size=batch_size,
-                         temperature=task_temperature,
-                         top_p=task_top_p,
-                         max_concurrency=max_concurrency,
-                     )
-                 )
-                 yield predictions.map(_safe_cast_str)
+     Args:
+         instructions (str): Instructions for the model to infer the schema.
+         example_table_name (str | None): Name of the Spark table containing example data.
+         example_field_name (str | None): Name of the field in the table to use as examples.
+         max_examples (int): Maximum number of examples to retrieve for schema inference.
 
-         return task_string_udf
-
-     else:
-         raise ValueError(f"Unsupported response_format in task: {task.response_format}")
+     Returns:
+         InferredSchema: An object containing the inferred schema and response format.
 
+     Example:
+         ```python
+         from pyspark.sql import SparkSession
+
+         spark = SparkSession.builder.getOrCreate()
+         spark.createDataFrame([("great product",), ("bad service",)], ["text"]).createOrReplaceTempView("examples")
+         infer_schema(
+             instructions="Classify sentiment as positive or negative.",
+             example_table_name="examples",
+             example_field_name="text",
+             max_examples=2,
+         )
+         ```
+     """
 
+     spark = CONTAINER.resolve(SparkSession)
+     examples: list[str] = (
+         spark.table(example_table_name).rdd.map(lambda row: row[example_field_name]).takeSample(False, max_examples)
+     )
 
+     input = SchemaInferenceInput(
+         instructions=instructions,
+         examples=examples,
+     )
+     inferer = CONTAINER.resolve(SchemaInferer)
+     return inferer.infer_schema(input)
+
+
+ def parse_udf(
+     instructions: str,
+     response_format: type[ResponseFormat] | None = None,
+     example_table_name: str | None = None,
+     example_field_name: str | None = None,
+     max_examples: int = 100,
+     model_name: str | None = None,
+     batch_size: int | None = None,
+     max_concurrency: int = 8,
+     **api_kwargs,
+ ) -> UserDefinedFunction:
+     """Create an asynchronous Spark pandas UDF for parsing responses.
+     This function allows users to create UDFs that parse responses based on
+     provided instructions and either a predefined response format or example data.
+     It supports both structured responses using Pydantic models and plain text responses.
+     Each partition maintains its own cache to eliminate duplicate API calls within
+     the partition, significantly reducing API usage and costs when processing
+     datasets with overlapping content.
 
- def embeddings_udf(model_name: str = "text-embedding-3-small", batch_size: int = 128, max_concurrency: int = 8) -> UserDefinedFunction:
+     Args:
+         instructions (str): The system prompt or instructions for the model.
+         response_format (type[ResponseFormat] | None): The desired output format.
+             Either `str` for plain text or a Pydantic `BaseModel` for structured JSON output.
+             If not provided, the schema will be inferred from example data.
+         example_table_name (str | None): Name of the Spark table containing example data.
+             If provided, `example_field_name` must also be specified.
+         example_field_name (str | None): Name of the field in the table to use as examples.
+             If provided, `example_table_name` must also be specified.
+         max_examples (int): Maximum number of examples to retrieve for schema inference.
+             Defaults to 100.
+         model_name (str | None): For Azure OpenAI, use your deployment name (e.g., "my-gpt4-deployment").
+             For OpenAI, use the model name (e.g., "gpt-4.1-mini"). Defaults to configured model in DI container
+             via ResponsesModelName if not provided.
+         batch_size (int | None): Number of rows per async batch request within each partition.
+             Larger values reduce API call overhead but increase memory usage.
+             Defaults to None (automatic batch size optimization that dynamically
+             adjusts based on execution time, targeting 30-60 seconds per batch).
+             Set to a positive integer (e.g., 32-128) for fixed batch size
+         max_concurrency (int): Maximum number of concurrent API requests **PER EXECUTOR**.
+             Total cluster concurrency = max_concurrency × number_of_executors.
+             Higher values increase throughput but may hit OpenAI rate limits.
+             Recommended: 4-12 per executor. Defaults to 8.
+         **api_kwargs: Additional OpenAI API parameters (e.g. ``temperature``, ``top_p``,
+             ``frequency_penalty``, ``presence_penalty``, ``seed``, ``max_output_tokens``, etc.)
+             forwarded verbatim to the underlying API calls. These parameters are applied to
+             all API requests made by the UDF and override any parameters set in the
+             response_format or example data.
+     Example:
+         ```python
+         from pyspark.sql import SparkSession
+
+         spark = SparkSession.builder.getOrCreate()
+         spark.createDataFrame(
+             [("Order #123 delivered",), ("Order #456 delayed",)],
+             ["body"],
+         ).createOrReplaceTempView("messages")
+         udf = parse_udf(
+             instructions="Extract order id as `order_id` and status as `status`.",
+             example_table_name="messages",
+             example_field_name="body",
+         )
+         spark.udf.register("parse_ticket", udf)
+         spark.sql("SELECT parse_ticket(body) AS parsed FROM messages").show()
+         ```
+     Returns:
+         UserDefinedFunction: A Spark pandas UDF configured to parse responses asynchronously.
+             Output schema is `StringType` for str response format or a struct derived from
+             the response_format for BaseModel.
+     Raises:
+         ValueError: If neither `response_format` nor `example_table_name` and `example_field_name` are provided.
+     """
+
+     if not response_format and not (example_field_name and example_table_name):
+         raise ValueError("Either response_format or example_table_name and example_field_name must be provided.")
+
+     schema: SchemaInferenceOutput | None = None
+
+     if not response_format:
+         schema = infer_schema(
+             instructions=instructions,
+             example_table_name=example_table_name,
+             example_field_name=example_field_name,
+             max_examples=max_examples,
+         )
+
+     return responses_udf(
+         instructions=schema.inference_prompt if schema else instructions,
+         response_format=schema.model if schema else response_format,
+         model_name=model_name,
+         batch_size=batch_size,
+         max_concurrency=max_concurrency,
+         **api_kwargs,
+     )
+
+
+ def embeddings_udf(
+     model_name: str | None = None,
+     batch_size: int | None = None,
+     max_concurrency: int = 8,
+     **api_kwargs,
+ ) -> UserDefinedFunction:
      """Create an asynchronous Spark pandas UDF for generating embeddings.
 
-     Configures and builds UDFs that leverage `pandas_ext.aio.embeddings`
+     Configures and builds UDFs that leverage `pandas_ext.aio.embeddings_with_cache`
      to generate vector embeddings from OpenAI models asynchronously.
+     Each partition maintains its own cache to eliminate duplicate API calls within
+     the partition, significantly reducing API usage and costs when processing
+     datasets with overlapping content.
 
      Note:
          Authentication must be configured via SparkContext environment variables.
         Set the appropriate environment variables on the SparkContext:
-
+
          For OpenAI:
              sc.environment["OPENAI_API_KEY"] = "your-openai-api-key"
-
+
          For Azure OpenAI:
              sc.environment["AZURE_OPENAI_API_KEY"] = "your-azure-openai-api-key"
-             sc.environment["AZURE_OPENAI_API_ENDPOINT"] = "your-azure-openai-endpoint"
-             sc.environment["AZURE_OPENAI_API_VERSION"] = "your-azure-openai-api-version"
+             sc.environment["AZURE_OPENAI_BASE_URL"] = "https://YOUR-RESOURCE-NAME.services.ai.azure.com/openai/v1/"
+             sc.environment["AZURE_OPENAI_API_VERSION"] = "preview"
 
      Args:
-         model_name (str): Deployment name (Azure) or model name (OpenAI) for embeddings.
-             Defaults to "text-embedding-3-small".
-         batch_size (int): Number of rows per async batch request within each partition.
+         model_name (str | None): For Azure OpenAI, use your deployment name (e.g., "my-embedding-deployment").
+             For OpenAI, use the model name (e.g., "text-embedding-3-small").
+             Defaults to configured model in DI container via EmbeddingsModelName if not provided.
+         batch_size (int | None): Number of rows per async batch request within each partition.
              Larger values reduce API call overhead but increase memory usage.
+             Defaults to None (automatic batch size optimization that dynamically
+             adjusts based on execution time, targeting 30-60 seconds per batch).
+             Set to a positive integer (e.g., 64-256) for fixed batch size.
              Embeddings typically handle larger batches efficiently.
-             Recommended: 64-256 depending on text length. Defaults to 128.
          max_concurrency (int): Maximum number of concurrent API requests **PER EXECUTOR**.
             Total cluster concurrency = max_concurrency × number_of_executors.
             Higher values increase throughput but may hit OpenAI rate limits.
             Recommended: 4-12 per executor. Defaults to 8.
+         **api_kwargs: Additional OpenAI API parameters (e.g., dimensions for text-embedding-3 models).
 
      Returns:
-         UserDefinedFunction: A Spark pandas UDF configured to generate embeddings asynchronously,
+         UserDefinedFunction: A Spark pandas UDF configured to generate embeddings asynchronously
+             with automatic caching for duplicate inputs within each partition,
             returning an `ArrayType(FloatType())` column.
 
      Note:
         For optimal performance in distributed environments:
+         - **Automatic Caching**: Duplicate inputs within each partition are cached,
+           reducing API calls and costs significantly on datasets with repeated content
         - Monitor OpenAI API rate limits when scaling executor count
         - Consider your OpenAI tier limits: total_requests = max_concurrency × executors
         - Embeddings API typically has higher throughput than chat completions
         - Use larger batch_size for embeddings compared to response generation
      """
-     @pandas_udf(returnType=ArrayType(FloatType()))
+
+     _model_name = model_name or CONTAINER.resolve(EmbeddingsModelName).value
+
+     @pandas_udf(returnType=ArrayType(FloatType()))  # type: ignore[call-overload,misc]
      def _embeddings_udf(col: Iterator[pd.Series]) -> Iterator[pd.Series]:
-         pandas_ext.embeddings_model(model_name)
+         pandas_ext.set_embeddings_model(_model_name)
+         cache = AsyncBatchingMapProxy[str, np.ndarray](
+             batch_size=batch_size,
+             max_concurrency=max_concurrency,
+         )
 
-         for part in col:
-             embeddings: pd.Series = asyncio.run(
-                 part.aio.embeddings(batch_size=batch_size, max_concurrency=max_concurrency)
-             )
-             yield embeddings.map(lambda x: x.tolist())
+         try:
+             for part in col:
+                 embeddings: pd.Series = asyncio.run(part.aio.embeddings_with_cache(cache=cache, **api_kwargs))
+                 yield embeddings.map(lambda x: x.tolist())
+         finally:
+             asyncio.run(cache.clear())
 
-     return _embeddings_udf
+     return _embeddings_udf  # type: ignore[return-value]
 
 
- def split_to_chunks_udf(model_name: str, max_tokens: int, sep: List[str]) -> UserDefinedFunction:
+ def split_to_chunks_udf(max_tokens: int, sep: list[str]) -> UserDefinedFunction:
      """Create a pandas‑UDF that splits text into token‑bounded chunks.
 
      Args:
-         model_name (str): Model identifier passed to *tiktoken*.
          max_tokens (int): Maximum tokens allowed per chunk.
-         sep (List[str]): Ordered list of separator strings used by ``TextChunker``.
+         sep (list[str]): Ordered list of separator strings used by ``TextChunker``.
 
      Returns:
          A pandas UDF producing an ``ArrayType(StringType())`` column whose
         values are lists of chunks respecting the ``max_tokens`` limit.
      """
 
-     @pandas_udf(ArrayType(StringType()))
+     @pandas_udf(ArrayType(StringType()))  # type: ignore[call-overload,misc]
      def fn(col: Iterator[pd.Series]) -> Iterator[pd.Series]:
-         global _TIKTOKEN_ENC
-         if _TIKTOKEN_ENC is None:
-             _TIKTOKEN_ENC = tiktoken.encoding_for_model(model_name)
-
-         chunker = TextChunker(_TIKTOKEN_ENC)
+         encoding = tiktoken.get_encoding("o200k_base")
+         chunker = TextChunker(encoding)
 
         for part in col:
             yield part.map(lambda x: chunker.split(x, max_tokens=max_tokens, sep=sep) if isinstance(x, str) else [])
 
-     return fn
+     return fn  # type: ignore[return-value]
 
 
- def count_tokens_udf(model_name: str = "gpt-4o") -> UserDefinedFunction:
+ def count_tokens_udf() -> UserDefinedFunction:
      """Create a pandas‑UDF that counts tokens for every string cell.
 
      The UDF uses *tiktoken* to approximate tokenisation and caches the
     resulting ``Encoding`` object per executor.
 
-     Args:
-         model_name (str): Model identifier understood by ``tiktoken``.
-
      Returns:
          A pandas UDF producing an ``IntegerType`` column with token counts.
      """
 
-     @pandas_udf(IntegerType())
+     @pandas_udf(IntegerType())  # type: ignore[call-overload]
      def fn(col: Iterator[pd.Series]) -> Iterator[pd.Series]:
-         global _TIKTOKEN_ENC
-         if _TIKTOKEN_ENC is None:
-             _TIKTOKEN_ENC = tiktoken.encoding_for_model(model_name)
+         encoding = tiktoken.get_encoding("o200k_base")
 
         for part in col:
-             yield part.map(lambda x: len(_TIKTOKEN_ENC.encode(x)) if isinstance(x, str) else 0)
+             yield part.map(lambda x: len(encoding.encode(x)) if isinstance(x, str) else 0)
 
-     return fn
+     return fn  # type: ignore[return-value]
 
 
  def similarity_udf() -> UserDefinedFunction:
-     @pandas_udf(FloatType())
-     def fn(a: pd.Series, b: pd.Series) -> pd.Series:
-         """Compute cosine similarity between two vectors.
+     """Create a pandas-UDF that computes cosine similarity between embedding vectors.
 
-         Args:
-             a: First vector.
-             b: Second vector.
+     Returns:
+         UserDefinedFunction: A Spark pandas UDF that takes two embedding vector columns
+             and returns their cosine similarity as a FloatType column.
+     """
 
-         Returns:
-             Cosine similarity between the two vectors.
-         """
+     @pandas_udf(FloatType())  # type: ignore[call-overload]
+     def fn(a: pd.Series, b: pd.Series) -> pd.Series:
         # Import pandas_ext to ensure .ai accessor is available in Spark workers
-         from . import pandas_ext  # noqa: F401
+         from openaivec import pandas_ext
+
+         # Explicitly reference pandas_ext to satisfy linters
+         assert pandas_ext is not None
 
         return pd.DataFrame({"a": a, "b": b}).ai.similarity("a", "b")
 
-     return fn
+     return fn  # type: ignore[return-value]
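
To close, a hedged sketch of the `parse_udf` path the diff adds but does not exemplify: when an explicit Pydantic `response_format` is passed, `parse_udf` skips schema inference and forwards directly to `responses_udf` (per its body above), so `example_table_name`/`example_field_name` are not needed. The `Order` model, the `parse_order` registration name, and the API key/model values below are placeholders, not part of the package:

```python
from pydantic import BaseModel
from pyspark.sql import SparkSession

from openaivec.spark import parse_udf, setup


class Order(BaseModel):  # hypothetical response format for illustration
    order_id: str
    status: str


spark = SparkSession.builder.getOrCreate()
setup(spark, api_key="sk-***", responses_model_name="gpt-4.1-mini")  # placeholder key/model

# With an explicit response_format, no example table is read and no schema is inferred.
spark.udf.register(
    "parse_order",
    parse_udf(
        instructions="Extract order id as `order_id` and status as `status`.",
        response_format=Order,
    ),
)
spark.sql("SELECT parse_order('Order #123 delivered') AS parsed").show()
```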