openaivec 0.14.10__py3-none-any.whl → 0.14.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
openaivec/spark.py CHANGED
@@ -1,45 +1,45 @@
  """Asynchronous Spark UDFs for the OpenAI and Azure OpenAI APIs.
 
  This module provides functions (`responses_udf`, `task_udf`, `embeddings_udf`,
- `count_tokens_udf`, `split_to_chunks_udf`)
+ `count_tokens_udf`, `split_to_chunks_udf`, `similarity_udf`, `parse_udf`)
  for creating asynchronous Spark UDFs that communicate with either the public
  OpenAI API or Azure OpenAI using the `openaivec.spark` subpackage.
- It supports UDFs for generating responses and creating embeddings asynchronously.
- The UDFs operate on Spark DataFrames and leverage asyncio for potentially
- improved performance in I/O-bound operations.
+ It supports UDFs for generating responses, creating embeddings, parsing text,
+ and computing similarities asynchronously. The UDFs operate on Spark DataFrames
+ and leverage asyncio for improved performance in I/O-bound operations.
 
- **Performance Optimization**: All AI-powered UDFs (`responses_udf`, `task_udf`, `embeddings_udf`)
+ **Performance Optimization**: All AI-powered UDFs (`responses_udf`, `task_udf`, `embeddings_udf`, `parse_udf`)
  automatically cache duplicate inputs within each partition, significantly reducing
  API calls and costs when processing datasets with overlapping content.
 
- __all__ = [
-     "count_tokens_udf",
-     "embeddings_udf",
-     "responses_udf",
-     "similarity_udf",
-     "split_to_chunks_udf",
-     "task_udf",
- ]
 
  ## Setup
 
  First, obtain a Spark session and configure authentication:
 
  ```python
- import os
  from pyspark.sql import SparkSession
+ from openaivec.spark import setup, setup_azure
 
  spark = SparkSession.builder.getOrCreate()
- sc = spark.sparkContext
 
- # Configure authentication via SparkContext environment variables
  # Option 1: Using OpenAI
- sc.environment["OPENAI_API_KEY"] = "your-openai-api-key"
+ setup(
+     spark,
+     api_key="your-openai-api-key",
+     responses_model_name="gpt-4.1-mini",  # Optional: set default model
+     embeddings_model_name="text-embedding-3-small"  # Optional: set default model
+ )
 
  # Option 2: Using Azure OpenAI
- # sc.environment["AZURE_OPENAI_API_KEY"] = "your-azure-openai-api-key"
- # sc.environment["AZURE_OPENAI_BASE_URL"] = "https://YOUR-RESOURCE-NAME.services.ai.azure.com/openai/v1/"
- # sc.environment["AZURE_OPENAI_API_VERSION"] = "preview"
+ # setup_azure(
+ #     spark,
+ #     api_key="your-azure-openai-api-key",
+ #     base_url="https://YOUR-RESOURCE-NAME.services.ai.azure.com/openai/v1/",
+ #     api_version="preview",
+ #     responses_model_name="my-gpt4-deployment",  # Optional: set default deployment
+ #     embeddings_model_name="my-embedding-deployment"  # Optional: set default deployment
+ # )
  ```
 
  Next, create UDFs and register them:
@@ -83,9 +83,10 @@ spark.udf.register(
      ),
  )
 
- # Register token counting and text chunking UDFs
+ # Register token counting, text chunking, and similarity UDFs
  spark.udf.register("count_tokens", count_tokens_udf())
  spark.udf.register("split_chunks", split_to_chunks_udf(max_tokens=512, sep=[".", "!", "?"]))
+ spark.udf.register("compute_similarity", similarity_udf())
  ```
 
  You can now invoke the UDFs from Spark SQL:
@@ -97,7 +98,8 @@ SELECT
      sentiment_async(text) AS sentiment,
      embed_async(text) AS embedding,
      count_tokens(text) AS token_count,
-     split_chunks(text) AS chunks
+     split_chunks(text) AS chunks,
+     compute_similarity(embed_async(text1), embed_async(text2)) AS similarity
  FROM your_table;
  ```
 
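The same registered UDFs also work through the DataFrame API. A minimal sketch, assuming the registrations above and an illustrative DataFrame (the column names are hypothetical):

```python
from pyspark.sql import functions as F

# Hypothetical input; text1/text2 mirror the SQL example above.
df = spark.createDataFrame(
    [("I love this product", "fast shipping", "quick delivery")],
    ["text", "text1", "text2"],
)

df.select(
    F.expr("sentiment_async(text)").alias("sentiment"),
    F.expr("count_tokens(text)").alias("token_count"),
    F.expr("compute_similarity(embed_async(text1), embed_async(text2))").alias("similarity"),
).show()
```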
@@ -123,6 +125,7 @@ Note: This module provides asynchronous support through the pandas extensions.
 
  import asyncio
  import logging
+ import os
  from collections.abc import Iterator
  from enum import Enum
  from typing import Union, get_args, get_origin
@@ -131,14 +134,18 @@ import numpy as np
  import pandas as pd
  import tiktoken
  from pydantic import BaseModel
+ from pyspark import SparkContext
+ from pyspark.sql import SparkSession
  from pyspark.sql.pandas.functions import pandas_udf
  from pyspark.sql.types import ArrayType, BooleanType, FloatType, IntegerType, StringType, StructField, StructType
  from pyspark.sql.udf import UserDefinedFunction
  from typing_extensions import Literal
 
  from openaivec import pandas_ext
- from openaivec._model import PreparedTask, ResponseFormat
+ from openaivec._model import EmbeddingsModelName, PreparedTask, ResponseFormat, ResponsesModelName
+ from openaivec._provider import CONTAINER
  from openaivec._proxy import AsyncBatchingMapProxy
+ from openaivec._schema import InferredSchema, SchemaInferenceInput, SchemaInferer
  from openaivec._serialize import deserialize_base_model, serialize_base_model
  from openaivec._util import TextChunker
 
@@ -146,6 +153,8 @@ __all__ = [
      "responses_udf",
      "task_udf",
      "embeddings_udf",
+     "infer_schema",
+     "parse_udf",
      "split_to_chunks_udf",
      "count_tokens_udf",
      "similarity_udf",
@@ -155,6 +164,86 @@ __all__ = [
  _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 
+ def setup(
+     spark: SparkSession, api_key: str, responses_model_name: str | None = None, embeddings_model_name: str | None = None
+ ):
+     """Set up OpenAI authentication and default model names in the Spark environment.
+     1. Configures the OpenAI API key in the SparkContext environment.
+     2. Configures the OpenAI API key in the local process environment.
+     3. Optionally registers default model names for responses and embeddings in the DI container.
+
+     Args:
+         spark (SparkSession): The Spark session to configure.
+         api_key (str): OpenAI API key for authentication.
+         responses_model_name (str | None): Default model name for response generation.
+             If provided, registers `ResponsesModelName` in the DI container.
+         embeddings_model_name (str | None): Default model name for embeddings.
+             If provided, registers `EmbeddingsModelName` in the DI container.
+     """
+
+     CONTAINER.register(SparkSession, lambda: spark)
+     CONTAINER.register(SparkContext, lambda: CONTAINER.resolve(SparkSession).sparkContext)
+
+     sc = CONTAINER.resolve(SparkContext)
+     sc.environment["OPENAI_API_KEY"] = api_key
+
+     os.environ["OPENAI_API_KEY"] = api_key
+
+     if responses_model_name:
+         CONTAINER.register(ResponsesModelName, lambda: ResponsesModelName(responses_model_name))
+
+     if embeddings_model_name:
+         from openaivec._model import EmbeddingsModelName
+
+         CONTAINER.register(EmbeddingsModelName, lambda: EmbeddingsModelName(embeddings_model_name))
+
+     CONTAINER.clear_singletons()
+
+
+ def setup_azure(
+     spark: SparkSession,
+     api_key: str,
+     base_url: str,
+     api_version: str = "preview",
+     responses_model_name: str | None = None,
+     embeddings_model_name: str | None = None,
+ ):
+     """Set up Azure OpenAI authentication and default model names in the Spark environment.
+     1. Configures the Azure OpenAI API key, base URL, and API version in the SparkContext environment.
+     2. Configures the Azure OpenAI API key, base URL, and API version in the local process environment.
+     3. Optionally registers default model names for responses and embeddings in the DI container.
+     Args:
+         spark (SparkSession): The Spark session to configure.
+         api_key (str): Azure OpenAI API key for authentication.
+         base_url (str): Base URL for the Azure OpenAI resource.
+         api_version (str): API version to use. Defaults to "preview".
+         responses_model_name (str | None): Default model name for response generation.
+             If provided, registers `ResponsesModelName` in the DI container.
+         embeddings_model_name (str | None): Default model name for embeddings.
+             If provided, registers `EmbeddingsModelName` in the DI container.
+     """
+
+     CONTAINER.register(SparkSession, lambda: spark)
+     CONTAINER.register(SparkContext, lambda: CONTAINER.resolve(SparkSession).sparkContext)
+
+     sc = CONTAINER.resolve(SparkContext)
+     sc.environment["AZURE_OPENAI_API_KEY"] = api_key
+     sc.environment["AZURE_OPENAI_BASE_URL"] = base_url
+     sc.environment["AZURE_OPENAI_API_VERSION"] = api_version
+
+     os.environ["AZURE_OPENAI_API_KEY"] = api_key
+     os.environ["AZURE_OPENAI_BASE_URL"] = base_url
+     os.environ["AZURE_OPENAI_API_VERSION"] = api_version
+
+     if responses_model_name:
+         CONTAINER.register(ResponsesModelName, lambda: ResponsesModelName(responses_model_name))
+
+     if embeddings_model_name:
+         CONTAINER.register(EmbeddingsModelName, lambda: EmbeddingsModelName(embeddings_model_name))
+
+     CONTAINER.clear_singletons()
+
+
  def _python_type_to_spark(python_type):
      origin = get_origin(python_type)
 
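Taken together, the new helpers are meant to be called once per session before any UDF factory. A minimal sketch, assuming the model name registered here is picked up as the container default described in the docstrings below:

```python
from pyspark.sql import SparkSession
from openaivec.spark import setup, responses_udf

spark = SparkSession.builder.getOrCreate()
setup(spark, api_key="your-openai-api-key", responses_model_name="gpt-4.1-mini")

# Per the docstrings below, model_name can now be omitted; the DI container
# supplies the default registered by setup().
spark.udf.register("translate_fr", responses_udf("Translate the input text to French."))
```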
@@ -233,10 +322,8 @@ def _safe_dump(x: BaseModel | None) -> dict:
  def responses_udf(
      instructions: str,
      response_format: type[ResponseFormat] = str,
-     model_name: str = "gpt-4.1-mini",
+     model_name: str = CONTAINER.resolve(ResponsesModelName).value,
      batch_size: int | None = None,
-     temperature: float | None = 0.0,
-     top_p: float = 1.0,
      max_concurrency: int = 8,
      **api_kwargs,
  ) -> UserDefinedFunction:
@@ -265,23 +352,20 @@ def responses_udf(
          response_format (type[ResponseFormat]): The desired output format. Either `str` for plain text
              or a Pydantic `BaseModel` for structured JSON output. Defaults to `str`.
          model_name (str): For Azure OpenAI, use your deployment name (e.g., "my-gpt4-deployment").
-             For OpenAI, use the model name (e.g., "gpt-4.1-mini"). Defaults to "gpt-4.1-mini".
+             For OpenAI, use the model name (e.g., "gpt-4.1-mini"). Defaults to the configured model in the DI container.
          batch_size (int | None): Number of rows per async batch request within each partition.
              Larger values reduce API call overhead but increase memory usage.
              Defaults to None (automatic batch size optimization that dynamically
              adjusts based on execution time, targeting 30-60 seconds per batch).
              Set to a positive integer (e.g., 32-128) for fixed batch size.
-         temperature (float): Sampling temperature (0.0 to 2.0). Defaults to 0.0.
-         top_p (float): Nucleus sampling parameter. Defaults to 1.0.
          max_concurrency (int): Maximum number of concurrent API requests **PER EXECUTOR**.
              Total cluster concurrency = max_concurrency × number_of_executors.
              Higher values increase throughput but may hit OpenAI rate limits.
              Recommended: 4-12 per executor. Defaults to 8.
-
-     Additional Keyword Args:
-         Arbitrary OpenAI Responses API parameters (e.g. ``frequency_penalty``, ``presence_penalty``,
-         ``seed``, ``max_output_tokens``, etc.) are forwarded verbatim to the underlying API calls.
-         These parameters are applied to all API requests made by the UDF.
+         **api_kwargs: Additional OpenAI API parameters (e.g. ``temperature``, ``top_p``,
+             ``frequency_penalty``, ``presence_penalty``, ``seed``, ``max_output_tokens``, etc.)
+             forwarded verbatim to the underlying API calls. These parameters are applied to
+             all API requests made by the UDF.
 
      Returns:
          UserDefinedFunction: A Spark pandas UDF configured to generate responses asynchronously.
@@ -317,8 +401,6 @@ def responses_udf(
                      part.aio.responses_with_cache(
                          instructions=instructions,
                          response_format=response_format,
-                         temperature=temperature,
-                         top_p=top_p,
                          cache=cache,
                          **api_kwargs,
                      )
@@ -345,8 +427,6 @@ def responses_udf(
                      part.aio.responses_with_cache(
                          instructions=instructions,
                          response_format=str,
-                         temperature=temperature,
-                         top_p=top_p,
                          cache=cache,
                          **api_kwargs,
                      )
@@ -363,7 +443,7 @@ def responses_udf(
 
  def task_udf(
      task: PreparedTask[ResponseFormat],
-     model_name: str = "gpt-4.1-mini",
+     model_name: str = CONTAINER.resolve(ResponsesModelName).value,
      batch_size: int | None = None,
      max_concurrency: int = 8,
      **api_kwargs,
@@ -378,9 +458,9 @@ def task_udf(
 
      Args:
          task (PreparedTask): A predefined task configuration containing instructions,
-             response format, temperature, and top_p settings.
+             response format, and API parameters.
          model_name (str): For Azure OpenAI, use your deployment name (e.g., "my-gpt4-deployment").
-             For OpenAI, use the model name (e.g., "gpt-4.1-mini"). Defaults to "gpt-4.1-mini".
+             For OpenAI, use the model name (e.g., "gpt-4.1-mini"). Defaults to the configured model in the DI container.
          batch_size (int | None): Number of rows per async batch request within each partition.
              Larger values reduce API call overhead but increase memory usage.
              Defaults to None (automatic batch size optimization that dynamically
@@ -392,10 +472,10 @@ def task_udf(
              Recommended: 4-12 per executor. Defaults to 8.
 
      Additional Keyword Args:
-         Arbitrary OpenAI Responses API parameters (e.g. ``frequency_penalty``, ``presence_penalty``,
-         ``seed``, ``max_output_tokens``, etc.) are forwarded verbatim to the underlying API calls.
-         These parameters are applied to all API requests made by the UDF and override any
-         parameters set in the task configuration.
+         Arbitrary OpenAI Responses API parameters (e.g. ``temperature``, ``top_p``,
+         ``frequency_penalty``, ``presence_penalty``, ``seed``, ``max_output_tokens``, etc.)
+         are forwarded verbatim to the underlying API calls. These parameters are applied to
+         all API requests made by the UDF and override any parameters set in the task configuration.
 
      Returns:
          UserDefinedFunction: A Spark pandas UDF configured to execute the specified task
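A sketch of the override rule described above. The import path for the bundled task factory is an assumption; the point is that caller kwargs win over the task's stored `api_kwargs`:

```python
from openaivec.spark import task_udf
from openaivec.task.customer_support import customer_sentiment  # assumed module path

task = customer_sentiment(business_context="e-commerce returns")
# temperature passed here overrides any temperature stored in task.api_kwargs.
spark.udf.register("sentiment_async", task_udf(task, temperature=0.2, seed=42))
```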
@@ -416,78 +496,136 @@ def task_udf(
      **Automatic Caching**: Duplicate inputs within each partition are cached,
      reducing API calls and costs significantly on datasets with repeated content.
      """
-     # Serialize task parameters for Spark serialization compatibility
-     task_instructions = task.instructions
-     task_temperature = task.temperature
-     task_top_p = task.top_p
+     # Merge the task's api_kwargs with the caller's api_kwargs (caller takes precedence)
+     merged_kwargs = {**task.api_kwargs, **api_kwargs}
 
-     if issubclass(task.response_format, BaseModel):
-         task_response_format_json = serialize_base_model(task.response_format)
+     return responses_udf(
+         instructions=task.instructions,
+         response_format=task.response_format,
+         model_name=model_name,
+         batch_size=batch_size,
+         max_concurrency=max_concurrency,
+         **merged_kwargs,
+     )
 
-         # Deserialize the response format from JSON
-         response_format = deserialize_base_model(task_response_format_json)
-         spark_schema = _pydantic_to_spark_schema(response_format)
 
-         @pandas_udf(returnType=spark_schema)  # type: ignore[call-overload]
-         def task_udf(col: Iterator[pd.Series]) -> Iterator[pd.DataFrame]:
-             pandas_ext.responses_model(model_name)
-             cache = AsyncBatchingMapProxy[str, response_format](
-                 batch_size=batch_size,
-                 max_concurrency=max_concurrency,
-             )
+ def infer_schema(
+     instructions: str,
+     example_table_name: str,
+     example_field_name: str,
+     max_examples: int = 100,
+ ) -> InferredSchema:
+     """Infer the schema for a response format based on example data.
 
-             try:
-                 for part in col:
-                     predictions: pd.Series = asyncio.run(
-                         part.aio.responses_with_cache(
-                             instructions=task_instructions,
-                             response_format=response_format,
-                             temperature=task_temperature,
-                             top_p=task_top_p,
-                             cache=cache,
-                             **api_kwargs,
-                         )
-                     )
-                     yield pd.DataFrame(predictions.map(_safe_dump).tolist())
-             finally:
-                 asyncio.run(cache.clear())
+     This function retrieves examples from a Spark table and infers the schema
+     for the response format using the provided instructions. It is useful when
+     you want to dynamically generate a schema based on existing data.
 
-         return task_udf  # type: ignore[return-value]
+     Args:
+         instructions (str): Instructions for the model to infer the schema.
+         example_table_name (str): Name of the Spark table containing example data.
+         example_field_name (str): Name of the field in the table to use as examples.
+         max_examples (int): Maximum number of examples to retrieve for schema inference.
 
-     elif issubclass(task.response_format, str):
+     Returns:
+         InferredSchema: An object containing the inferred schema and response format.
+     """
 
-         @pandas_udf(returnType=StringType())  # type: ignore[call-overload]
-         def task_string_udf(col: Iterator[pd.Series]) -> Iterator[pd.Series]:
-             pandas_ext.responses_model(model_name)
-             cache = AsyncBatchingMapProxy[str, str](
-                 batch_size=batch_size,
-                 max_concurrency=max_concurrency,
-             )
+     spark = CONTAINER.resolve(SparkSession)
+     examples: list[str] = (
+         spark.table(example_table_name).rdd.map(lambda row: row[example_field_name]).takeSample(False, max_examples)
+     )
 
-             try:
-                 for part in col:
-                     predictions: pd.Series = asyncio.run(
-                         part.aio.responses_with_cache(
-                             instructions=task_instructions,
-                             response_format=str,
-                             temperature=task_temperature,
-                             top_p=task_top_p,
-                             cache=cache,
-                             **api_kwargs,
-                         )
-                     )
-                     yield predictions.map(_safe_cast_str)
-             finally:
-                 asyncio.run(cache.clear())
+     input = SchemaInferenceInput(
+         instructions=instructions,
+         examples=examples,
+     )
+     inferer = CONTAINER.resolve(SchemaInferer)
+     return inferer.infer_schema(input)
 
-         return task_string_udf  # type: ignore[return-value]
 
-     else:
-         raise ValueError(f"Unsupported response_format in task: {task.response_format}")
+ def parse_udf(
+     instructions: str,
+     response_format: type[ResponseFormat] | None = None,
+     example_table_name: str | None = None,
+     example_field_name: str | None = None,
+     max_examples: int = 100,
+     model_name: str = CONTAINER.resolve(ResponsesModelName).value,
+     batch_size: int | None = None,
+     max_concurrency: int = 8,
+     **api_kwargs,
+ ) -> UserDefinedFunction:
+     """Create an asynchronous Spark pandas UDF for parsing responses.
+     This function allows users to create UDFs that parse responses based on
+     provided instructions and either a predefined response format or example data.
+     It supports both structured responses using Pydantic models and plain text responses.
+     Each partition maintains its own cache to eliminate duplicate API calls within
+     the partition, significantly reducing API usage and costs when processing
+     datasets with overlapping content.
+
+     Args:
+         instructions (str): The system prompt or instructions for the model.
+         response_format (type[ResponseFormat] | None): The desired output format.
+             Either `str` for plain text or a Pydantic `BaseModel` for structured JSON output.
+             If not provided, the schema will be inferred from example data.
+         example_table_name (str | None): Name of the Spark table containing example data.
+             If provided, `example_field_name` must also be specified.
+         example_field_name (str | None): Name of the field in the table to use as examples.
+             If provided, `example_table_name` must also be specified.
+         max_examples (int): Maximum number of examples to retrieve for schema inference.
+             Defaults to 100.
+         model_name (str): For Azure OpenAI, use your deployment name (e.g., "my-gpt4-deployment").
+             For OpenAI, use the model name (e.g., "gpt-4.1-mini"). Defaults to the configured model in the DI container.
+         batch_size (int | None): Number of rows per async batch request within each partition.
+             Larger values reduce API call overhead but increase memory usage.
+             Defaults to None (automatic batch size optimization that dynamically
+             adjusts based on execution time, targeting 30-60 seconds per batch).
+             Set to a positive integer (e.g., 32-128) for fixed batch size.
+         max_concurrency (int): Maximum number of concurrent API requests **PER EXECUTOR**.
+             Total cluster concurrency = max_concurrency × number_of_executors.
+             Higher values increase throughput but may hit OpenAI rate limits.
+             Recommended: 4-12 per executor. Defaults to 8.
+         **api_kwargs: Additional OpenAI API parameters (e.g. ``temperature``, ``top_p``,
+             ``frequency_penalty``, ``presence_penalty``, ``seed``, ``max_output_tokens``, etc.)
+             forwarded verbatim to the underlying API calls. These parameters are applied to
+             all API requests made by the UDF and override any parameters set in the
+             response_format or example data.
+     Returns:
+         UserDefinedFunction: A Spark pandas UDF configured to parse responses asynchronously.
+             Output schema is `StringType` for a str response format, or a struct derived from
+             the response_format for a BaseModel.
+     Raises:
+         ValueError: If neither `response_format` nor both `example_table_name` and `example_field_name` are provided.
+     """
+
+     if not response_format and not (example_field_name and example_table_name):
+         raise ValueError("Either response_format or example_table_name and example_field_name must be provided.")
+
+     schema: InferredSchema | None = None
+
+     if not response_format:
+         schema = infer_schema(
+             instructions=instructions,
+             example_table_name=example_table_name,
+             example_field_name=example_field_name,
+             max_examples=max_examples,
+         )
+
+     return responses_udf(
+         instructions=schema.inference_prompt if schema else instructions,
+         response_format=schema.model if schema else response_format,
+         model_name=model_name,
+         batch_size=batch_size,
+         max_concurrency=max_concurrency,
+         **api_kwargs,
+     )
 
 
  def embeddings_udf(
-     model_name: str = "text-embedding-3-small", batch_size: int | None = None, max_concurrency: int = 8
+     model_name: str = CONTAINER.resolve(EmbeddingsModelName).value,
+     batch_size: int | None = None,
+     max_concurrency: int = 8,
+     **api_kwargs,
  ) -> UserDefinedFunction:
      """Create an asynchronous Spark pandas UDF for generating embeddings.
 
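Two usage sketches for the new `parse_udf`, one per mode (the table and field names are illustrative):

```python
from pydantic import BaseModel
from openaivec.spark import parse_udf

class Ticket(BaseModel):
    product: str
    issue: str

# Mode 1: explicit response format.
spark.udf.register(
    "parse_ticket",
    parse_udf("Extract the product and the reported issue.", response_format=Ticket),
)

# Mode 2: schema inferred from up to 50 sampled values of support_tickets.body.
spark.udf.register(
    "parse_ticket_inferred",
    parse_udf(
        "Extract the key fields from a support ticket.",
        example_table_name="support_tickets",
        example_field_name="body",
        max_examples=50,
    ),
)
```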
@@ -511,7 +649,8 @@ def embeddings_udf(
 
      Args:
          model_name (str): For Azure OpenAI, use your deployment name (e.g., "my-embedding-deployment").
-             For OpenAI, use the model name (e.g., "text-embedding-3-small"). Defaults to "text-embedding-3-small".
+             For OpenAI, use the model name (e.g., "text-embedding-3-small").
+             Defaults to the configured model in the DI container.
          batch_size (int | None): Number of rows per async batch request within each partition.
              Larger values reduce API call overhead but increase memory usage.
              Defaults to None (automatic batch size optimization that dynamically
@@ -522,6 +661,7 @@ def embeddings_udf(
              Total cluster concurrency = max_concurrency × number_of_executors.
              Higher values increase throughput but may hit OpenAI rate limits.
              Recommended: 4-12 per executor. Defaults to 8.
+         **api_kwargs: Additional OpenAI API parameters (e.g., ``dimensions`` for text-embedding-3 models).
 
      Returns:
          UserDefinedFunction: A Spark pandas UDF configured to generate embeddings asynchronously
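For example, a sketch that shrinks the vectors via the embeddings API's `dimensions` parameter (supported by the text-embedding-3 models):

```python
from openaivec.spark import embeddings_udf

# Request 256-dimensional vectors instead of the model's default size.
spark.udf.register("embed_256", embeddings_udf(model_name="text-embedding-3-small", dimensions=256))
```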
@@ -548,7 +688,7 @@ def embeddings_udf(
 
          try:
              for part in col:
-                 embeddings: pd.Series = asyncio.run(part.aio.embeddings_with_cache(cache=cache))
+                 embeddings: pd.Series = asyncio.run(part.aio.embeddings_with_cache(cache=cache, **api_kwargs))
                  yield embeddings.map(lambda x: x.tolist())
          finally:
              asyncio.run(cache.clear())
@@ -600,17 +740,15 @@ def count_tokens_udf() -> UserDefinedFunction:
 
 
  def similarity_udf() -> UserDefinedFunction:
-     @pandas_udf(FloatType())  # type: ignore[call-overload]
-     def fn(a: pd.Series, b: pd.Series) -> pd.Series:
-         """Compute cosine similarity between two vectors.
+     """Create a pandas UDF that computes cosine similarity between embedding vectors.
 
-         Args:
-             a: First vector.
-             b: Second vector.
+     Returns:
+         UserDefinedFunction: A Spark pandas UDF that takes two embedding vector columns
+             and returns their cosine similarity as a FloatType column.
+     """
 
-         Returns:
-             Cosine similarity between the two vectors.
-         """
+     @pandas_udf(FloatType())  # type: ignore[call-overload]
+     def fn(a: pd.Series, b: pd.Series) -> pd.Series:
          # Import pandas_ext to ensure .ai accessor is available in Spark workers
          from openaivec import pandas_ext
 
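A short sketch of the refactored UDF in use (Spark pandas UDFs are callable on column names directly):

```python
from openaivec.spark import embeddings_udf, similarity_udf

embed = embeddings_udf(model_name="text-embedding-3-small")
sim = similarity_udf()

pairs = spark.createDataFrame([("cat", "kitten"), ("cat", "carburetor")], ["a", "b"])
pairs.select(sim(embed("a"), embed("b")).alias("cosine")).show()
```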
@@ -117,7 +117,7 @@ All tasks are built using the `PreparedTask` dataclass:
  @dataclass(frozen=True)
  class PreparedTask:
      instructions: str  # Detailed prompt for the LLM
-     response_format: Type[ResponseFormat]  # Pydantic model or str for structured/plain output
+     response_format: type[ResponseFormat]  # Pydantic model or str for structured/plain output
      temperature: float = 0.0  # Sampling temperature
      top_p: float = 1.0  # Nucleus sampling parameter
  ```
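Elsewhere in this diff the task factories construct tasks with an `api_kwargs` mapping rather than the dedicated `temperature`/`top_p` fields shown above, so a hand-rolled task would presumably look like this (the field set is an assumption based on those call sites):

```python
from pydantic import BaseModel
from openaivec._model import PreparedTask

class Keywords(BaseModel):
    keywords: list[str]

task = PreparedTask(
    instructions="Extract up to five keywords from the text.",
    response_format=Keywords,
    api_kwargs={"temperature": 0.0, "top_p": 1.0},  # assumed field, mirroring the factories below
)
```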
@@ -95,15 +95,12 @@ class CustomerSentiment(BaseModel):
  )
 
 
- def customer_sentiment(
-     business_context: str = "general customer support", temperature: float = 0.0, top_p: float = 1.0
- ) -> PreparedTask:
+ def customer_sentiment(business_context: str = "general customer support", **api_kwargs) -> PreparedTask:
      """Create a configurable customer sentiment analysis task.
 
      Args:
          business_context (str): Business context for sentiment analysis.
-         temperature (float): Sampling temperature (0.0-1.0).
-         top_p (float): Nucleus sampling parameter (0.0-1.0).
+         **api_kwargs: Additional OpenAI API parameters (temperature, top_p, etc.).
 
      Returns:
          PreparedTask configured for customer sentiment analysis.
@@ -169,10 +166,8 @@ values like "positive" for sentiment.
 
  Provide comprehensive sentiment analysis with business context and recommended response strategy."""
 
-     return PreparedTask(
-         instructions=instructions, response_format=CustomerSentiment, temperature=temperature, top_p=top_p
-     )
+     return PreparedTask(instructions=instructions, response_format=CustomerSentiment, api_kwargs=api_kwargs)
 
 
  # Backward compatibility - default configuration
- CUSTOMER_SENTIMENT = customer_sentiment()
+ CUSTOMER_SENTIMENT = customer_sentiment(temperature=0.0, top_p=1.0)
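A sketch of the new calling convention, with sampling parameters flowing through `**api_kwargs` (the `task_udf` pairing assumes the Spark setup shown earlier):

```python
from openaivec.spark import task_udf

task = customer_sentiment(business_context="SaaS billing support", temperature=0.0)
spark.udf.register("billing_sentiment", task_udf(task))
```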
@@ -119,8 +119,7 @@ def inquiry_classification(
      priority_rules: Dict[str, str] | None = None,
      business_context: str = "general customer support",
      custom_keywords: Dict[str, list[str]] | None = None,
-     temperature: float = 0.0,
-     top_p: float = 1.0,
+     **api_kwargs,
  ) -> PreparedTask:
      """Create a configurable inquiry classification task.
 
@@ -133,8 +132,8 @@ def inquiry_classification(
              Default uses standard priority indicators.
          business_context (str): Description of the business context to help with classification.
          custom_keywords (dict[str, list[str]] | None): Dictionary mapping categories to relevant keywords.
-         temperature (float): Sampling temperature (0.0-1.0).
-         top_p (float): Nucleus sampling parameter (0.0-1.0).
+         **api_kwargs: Additional keyword arguments to pass to the OpenAI API,
+             such as temperature, top_p, etc.
 
      Returns:
          PreparedTask configured for inquiry classification.
@@ -254,10 +253,8 @@ language where appropriate, but priority must use English values like "high".
 
  Provide accurate classification with detailed reasoning."""
 
-     return PreparedTask(
-         instructions=instructions, response_format=InquiryClassification, temperature=temperature, top_p=top_p
-     )
+     return PreparedTask(instructions=instructions, response_format=InquiryClassification, api_kwargs=api_kwargs)
 
 
  # Backward compatibility - default configuration
- INQUIRY_CLASSIFICATION = inquiry_classification()
+ INQUIRY_CLASSIFICATION = inquiry_classification(temperature=0.0, top_p=1.0)
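Likewise for classification, a sketch passing domain keywords plus API parameters through the new signature:

```python
task = inquiry_classification(
    business_context="telecom support",
    custom_keywords={"billing": ["invoice", "charge"], "outage": ["down", "no signal"]},
    temperature=0.0,
)
spark.udf.register("classify_inquiry", task_udf(task))
```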
@@ -87,16 +87,15 @@ class InquirySummary(BaseModel):
  def inquiry_summary(
      summary_length: str = "concise",
      business_context: str = "general customer support",
-     temperature: float = 0.0,
-     top_p: float = 1.0,
+     **api_kwargs,
  ) -> PreparedTask:
      """Create a configurable inquiry summary task.
 
      Args:
          summary_length (str): Length of summary (concise, detailed, bullet_points).
          business_context (str): Business context for summary.
-         temperature (float): Sampling temperature (0.0-1.0).
-         top_p (float): Nucleus sampling parameter (0.0-1.0).
+         **api_kwargs: Additional keyword arguments to pass to the OpenAI API,
+             such as temperature, top_p, etc.
 
      Returns:
          PreparedTask configured for inquiry summarization.
@@ -163,8 +162,8 @@ input is in German, provide all summary content in German, but use English value
 
  Provide accurate, actionable summary that enables efficient support resolution."""
 
-     return PreparedTask(instructions=instructions, response_format=InquirySummary, temperature=temperature, top_p=top_p)
+     return PreparedTask(instructions=instructions, response_format=InquirySummary, api_kwargs=api_kwargs)
 
 
  # Backward compatibility - default configuration
- INQUIRY_SUMMARY = inquiry_summary()
+ INQUIRY_SUMMARY = inquiry_summary(temperature=0.0, top_p=1.0)
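And for summaries, the same pattern (a sketch; per the UDF docstrings, extra parameters such as `seed` are forwarded verbatim to the Responses API):

```python
summary_task = inquiry_summary(summary_length="bullet_points", temperature=0.0, seed=7)
spark.udf.register("summarize_inquiry", task_udf(summary_task))
```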