openaivec 0.14.10__py3-none-any.whl → 0.14.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
openaivec/_di.py CHANGED
@@ -303,3 +303,24 @@ class Container:
         self._providers.clear()
         self._instances.clear()
         self._resolving.clear()
+
+    def clear_singletons(self) -> None:
+        """Clear all cached singleton instances from the container.
+
+        Removes all cached singleton instances while keeping the registered
+        providers intact. After calling this method, the next resolve call
+        for any service will create a new instance using the provider function.
+
+        Example:
+            ```python
+            container = Container()
+            container.register(str, lambda: "Hello")
+            instance1 = container.resolve(str)
+            container.clear_singletons()
+            instance2 = container.resolve(str)
+            print(instance1 is instance2)
+            # False - different instances after clearing singletons
+            ```
+        """
+        with self._lock:
+            self._instances.clear()
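
For context, a minimal sketch of how the new `clear_singletons` behaves against the existing `register`/`resolve` API. `lambda: object()` is used as the provider here because CPython interns a literal like `"Hello"` (the docstring example's identity check would not actually differ between calls); a fresh `object()` per call makes the before/after identity visible:

```python
# Sketch only: exercises Container.register / resolve / clear_singletons from this diff.
from openaivec._di import Container

container = Container()
container.register(object, lambda: object())  # fresh instance on every provider call

first = container.resolve(object)              # provider runs; result cached as a singleton
assert container.resolve(object) is first      # cache hit: identical object returned

container.clear_singletons()                   # drop cached instances, keep providers
assert container.resolve(object) is not first  # provider runs again: a new instance
```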
openaivec/_provider.py CHANGED
@@ -130,35 +130,9 @@ def provide_async_openai_client() -> AsyncOpenAI:
     )
 
 
-CONTAINER.register(ResponsesModelName, lambda: ResponsesModelName("gpt-4.1-mini"))
-CONTAINER.register(EmbeddingsModelName, lambda: EmbeddingsModelName("text-embedding-3-small"))
-CONTAINER.register(OpenAIAPIKey, lambda: OpenAIAPIKey(os.getenv("OPENAI_API_KEY")))
-CONTAINER.register(AzureOpenAIAPIKey, lambda: AzureOpenAIAPIKey(os.getenv("AZURE_OPENAI_API_KEY")))
-CONTAINER.register(AzureOpenAIBaseURL, lambda: AzureOpenAIBaseURL(os.getenv("AZURE_OPENAI_BASE_URL")))
-CONTAINER.register(
-    cls=AzureOpenAIAPIVersion,
-    provider=lambda: AzureOpenAIAPIVersion(os.getenv("AZURE_OPENAI_API_VERSION", "preview")),
-)
-CONTAINER.register(OpenAI, provide_openai_client)
-CONTAINER.register(AsyncOpenAI, provide_async_openai_client)
-CONTAINER.register(tiktoken.Encoding, lambda: tiktoken.get_encoding("o200k_base"))
-CONTAINER.register(TextChunker, lambda: TextChunker(CONTAINER.resolve(tiktoken.Encoding)))
-CONTAINER.register(
-    SchemaInferer,
-    lambda: SchemaInferer(
-        client=CONTAINER.resolve(OpenAI),
-        model_name=CONTAINER.resolve(ResponsesModelName).value,
-    ),
-)
-
-
-def reset_environment_registrations():
-    """Reset environment variable related registrations in the container.
-
-    This function re-registers environment variable dependent services to pick up
-    current environment variable values. Useful for testing when environment
-    variables are changed after initial container setup.
-    """
+def set_default_registrations():
+    CONTAINER.register(ResponsesModelName, lambda: ResponsesModelName("gpt-4.1-mini"))
+    CONTAINER.register(EmbeddingsModelName, lambda: EmbeddingsModelName("text-embedding-3-small"))
     CONTAINER.register(OpenAIAPIKey, lambda: OpenAIAPIKey(os.getenv("OPENAI_API_KEY")))
     CONTAINER.register(AzureOpenAIAPIKey, lambda: AzureOpenAIAPIKey(os.getenv("AZURE_OPENAI_API_KEY")))
     CONTAINER.register(AzureOpenAIBaseURL, lambda: AzureOpenAIBaseURL(os.getenv("AZURE_OPENAI_BASE_URL")))
@@ -168,6 +142,8 @@ def reset_environment_registrations():
     )
     CONTAINER.register(OpenAI, provide_openai_client)
     CONTAINER.register(AsyncOpenAI, provide_async_openai_client)
+    CONTAINER.register(tiktoken.Encoding, lambda: tiktoken.get_encoding("o200k_base"))
+    CONTAINER.register(TextChunker, lambda: TextChunker(CONTAINER.resolve(tiktoken.Encoding)))
     CONTAINER.register(
         SchemaInferer,
         lambda: SchemaInferer(
@@ -175,3 +151,6 @@ def reset_environment_registrations():
             model_name=CONTAINER.resolve(ResponsesModelName).value,
         ),
     )
+
+
+set_default_registrations()
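
Because defaults now come from `set_default_registrations()` at import time, a caller can override a single registration and flush cached singletons so dependent services rebuild. A hedged sketch using only names visible in this diff (the API key value is a placeholder):

```python
# Sketch only: override one default registration after importing openaivec.
import os

from openaivec._model import ResponsesModelName
from openaivec._provider import CONTAINER

os.environ["OPENAI_API_KEY"] = "sk-placeholder"  # hypothetical key, illustration only

CONTAINER.register(ResponsesModelName, lambda: ResponsesModelName("gpt-4.1"))
CONTAINER.clear_singletons()  # cached services rebuild against the new default

assert CONTAINER.resolve(ResponsesModelName).value == "gpt-4.1"
```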
openaivec/pandas_ext.py CHANGED
@@ -454,6 +454,7 @@ class OpenAIVecSeriesAccessor:
         """Parse Series values using an LLM with a provided cache.
         This method allows you to parse the Series content into structured data
         using an LLM, optionally inferring a schema based on the provided purpose.
+
         Args:
             instructions (str): System prompt for the LLM.
             cache (BatchingMapProxy[str, BaseModel]): Explicit cache instance for
openaivec/spark.py CHANGED
@@ -1,45 +1,45 @@
 """Asynchronous Spark UDFs for the OpenAI and Azure OpenAI APIs.
 
 This module provides functions (`responses_udf`, `task_udf`, `embeddings_udf`,
-`count_tokens_udf`, `split_to_chunks_udf`)
+`count_tokens_udf`, `split_to_chunks_udf`, `similarity_udf`, `parse_udf`)
 for creating asynchronous Spark UDFs that communicate with either the public
 OpenAI API or Azure OpenAI using the `openaivec.spark` subpackage.
-It supports UDFs for generating responses and creating embeddings asynchronously.
-The UDFs operate on Spark DataFrames and leverage asyncio for potentially
-improved performance in I/O-bound operations.
+It supports UDFs for generating responses, creating embeddings, parsing text,
+and computing similarities asynchronously. The UDFs operate on Spark DataFrames
+and leverage asyncio for improved performance in I/O-bound operations.
 
-**Performance Optimization**: All AI-powered UDFs (`responses_udf`, `task_udf`, `embeddings_udf`)
+**Performance Optimization**: All AI-powered UDFs (`responses_udf`, `task_udf`, `embeddings_udf`, `parse_udf`)
 automatically cache duplicate inputs within each partition, significantly reducing
 API calls and costs when processing datasets with overlapping content.
 
-__all__ = [
-    "count_tokens_udf",
-    "embeddings_udf",
-    "responses_udf",
-    "similarity_udf",
-    "split_to_chunks_udf",
-    "task_udf",
-]
 
 ## Setup
 
 First, obtain a Spark session and configure authentication:
 
 ```python
-import os
 from pyspark.sql import SparkSession
+from openaivec.spark import setup, setup_azure
 
 spark = SparkSession.builder.getOrCreate()
-sc = spark.sparkContext
 
-# Configure authentication via SparkContext environment variables
 # Option 1: Using OpenAI
-sc.environment["OPENAI_API_KEY"] = "your-openai-api-key"
+setup(
+    spark,
+    api_key="your-openai-api-key",
+    responses_model_name="gpt-4.1-mini",  # Optional: set default model
+    embeddings_model_name="text-embedding-3-small"  # Optional: set default model
+)
 
 # Option 2: Using Azure OpenAI
-# sc.environment["AZURE_OPENAI_API_KEY"] = "your-azure-openai-api-key"
-# sc.environment["AZURE_OPENAI_BASE_URL"] = "https://YOUR-RESOURCE-NAME.services.ai.azure.com/openai/v1/"
-# sc.environment["AZURE_OPENAI_API_VERSION"] = "preview"
+# setup_azure(
+#     spark,
+#     api_key="your-azure-openai-api-key",
+#     base_url="https://YOUR-RESOURCE-NAME.services.ai.azure.com/openai/v1/",
+#     api_version="preview",
+#     responses_model_name="my-gpt4-deployment",  # Optional: set default deployment
+#     embeddings_model_name="my-embedding-deployment"  # Optional: set default deployment
+# )
 ```
 
 Next, create UDFs and register them:
@@ -83,9 +83,10 @@ spark.udf.register(
     ),
 )
 
-# Register token counting and text chunking UDFs
+# Register token counting, text chunking, and similarity UDFs
 spark.udf.register("count_tokens", count_tokens_udf())
 spark.udf.register("split_chunks", split_to_chunks_udf(max_tokens=512, sep=[".", "!", "?"]))
+spark.udf.register("compute_similarity", similarity_udf())
 ```
 
 You can now invoke the UDFs from Spark SQL:
@@ -97,7 +98,8 @@ SELECT
     sentiment_async(text) AS sentiment,
     embed_async(text) AS embedding,
     count_tokens(text) AS token_count,
-    split_chunks(text) AS chunks
+    split_chunks(text) AS chunks,
+    compute_similarity(embed_async(text1), embed_async(text2)) AS similarity
 FROM your_table;
 ```
 
@@ -123,6 +125,7 @@ Note: This module provides asynchronous support through the pandas extensions.
 
 import asyncio
 import logging
+import os
 from collections.abc import Iterator
 from enum import Enum
 from typing import Union, get_args, get_origin
@@ -131,14 +134,17 @@ import numpy as np
 import pandas as pd
 import tiktoken
 from pydantic import BaseModel
+from pyspark.sql import SparkSession
 from pyspark.sql.pandas.functions import pandas_udf
 from pyspark.sql.types import ArrayType, BooleanType, FloatType, IntegerType, StringType, StructField, StructType
 from pyspark.sql.udf import UserDefinedFunction
 from typing_extensions import Literal
 
 from openaivec import pandas_ext
-from openaivec._model import PreparedTask, ResponseFormat
+from openaivec._model import EmbeddingsModelName, PreparedTask, ResponseFormat, ResponsesModelName
+from openaivec._provider import CONTAINER
 from openaivec._proxy import AsyncBatchingMapProxy
+from openaivec._schema import InferredSchema, SchemaInferenceInput, SchemaInferer
 from openaivec._serialize import deserialize_base_model, serialize_base_model
 from openaivec._util import TextChunker
 
@@ -146,6 +152,8 @@ __all__ = [
     "responses_udf",
     "task_udf",
     "embeddings_udf",
+    "infer_schema",
+    "parse_udf",
     "split_to_chunks_udf",
     "count_tokens_udf",
     "similarity_udf",
@@ -155,6 +163,80 @@ __all__ = [
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 
+def setup(
+    spark: SparkSession, api_key: str, responses_model_name: str | None = None, embeddings_model_name: str | None = None
+):
+    """Set up OpenAI authentication and default model names in the Spark environment.
+
+    1. Configures the OpenAI API key in the SparkContext environment.
+    2. Configures the OpenAI API key in the local process environment.
+    3. Optionally registers default model names for responses and embeddings in the DI container.
+
+    Args:
+        spark (SparkSession): The Spark session to configure.
+        api_key (str): OpenAI API key for authentication.
+        responses_model_name (str | None): Default model name for response generation.
+            If provided, registers `ResponsesModelName` in the DI container.
+        embeddings_model_name (str | None): Default model name for embeddings.
+            If provided, registers `EmbeddingsModelName` in the DI container.
+    """
+
+    sc = spark.sparkContext
+    sc.environment["OPENAI_API_KEY"] = api_key
+
+    os.environ["OPENAI_API_KEY"] = api_key
+
+    if responses_model_name:
+        CONTAINER.register(ResponsesModelName, lambda: ResponsesModelName(responses_model_name))
+
+    if embeddings_model_name:
+        from openaivec._model import EmbeddingsModelName
+
+        CONTAINER.register(EmbeddingsModelName, lambda: EmbeddingsModelName(embeddings_model_name))
+
+    CONTAINER.clear_singletons()
+
+
+def setup_azure(
+    spark: SparkSession,
+    api_key: str,
+    base_url: str,
+    api_version: str = "preview",
+    responses_model_name: str | None = None,
+    embeddings_model_name: str | None = None,
+):
+    """Set up Azure OpenAI authentication and default model names in the Spark environment.
+
+    1. Configures the Azure OpenAI API key, base URL, and API version in the SparkContext environment.
+    2. Configures the Azure OpenAI API key, base URL, and API version in the local process environment.
+    3. Optionally registers default model names for responses and embeddings in the DI container.
+
+    Args:
+        spark (SparkSession): The Spark session to configure.
+        api_key (str): Azure OpenAI API key for authentication.
+        base_url (str): Base URL for the Azure OpenAI resource.
+        api_version (str): API version to use. Defaults to "preview".
+        responses_model_name (str | None): Default model name for response generation.
+            If provided, registers `ResponsesModelName` in the DI container.
+        embeddings_model_name (str | None): Default model name for embeddings.
+            If provided, registers `EmbeddingsModelName` in the DI container.
+    """
+
+    sc = spark.sparkContext
+    sc.environment["AZURE_OPENAI_API_KEY"] = api_key
+    sc.environment["AZURE_OPENAI_BASE_URL"] = base_url
+    sc.environment["AZURE_OPENAI_API_VERSION"] = api_version
+
+    os.environ["AZURE_OPENAI_API_KEY"] = api_key
+    os.environ["AZURE_OPENAI_BASE_URL"] = base_url
+    os.environ["AZURE_OPENAI_API_VERSION"] = api_version
+
+    if responses_model_name:
+        CONTAINER.register(ResponsesModelName, lambda: ResponsesModelName(responses_model_name))
+
+    if embeddings_model_name:
+        CONTAINER.register(EmbeddingsModelName, lambda: EmbeddingsModelName(embeddings_model_name))
+
+    CONTAINER.clear_singletons()
+
+
 def _python_type_to_spark(python_type):
     origin = get_origin(python_type)
 
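
Putting `setup` together with the UDF factories below, the intended flow is: configure authentication once on the driver, then build and register UDFs. A sketch under that assumption; model names are passed explicitly here because how late registrations interact with defaults evaluated at import time is not shown in this diff:

```python
# Sketch only: setup() followed by UDF registration, per the docstrings above.
from pyspark.sql import SparkSession

from openaivec.spark import embeddings_udf, responses_udf, setup

spark = SparkSession.builder.getOrCreate()
setup(spark, api_key="your-openai-api-key")  # propagates OPENAI_API_KEY to executors

spark.udf.register("summarize", responses_udf(instructions="Summarize in one sentence.", model_name="gpt-4.1-mini"))
spark.udf.register("embed", embeddings_udf(model_name="text-embedding-3-small"))
```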
@@ -233,7 +315,7 @@ def _safe_dump(x: BaseModel | None) -> dict:
 def responses_udf(
     instructions: str,
     response_format: type[ResponseFormat] = str,
-    model_name: str = "gpt-4.1-mini",
+    model_name: str = CONTAINER.resolve(ResponsesModelName).value,
     batch_size: int | None = None,
     temperature: float | None = 0.0,
     top_p: float = 1.0,
@@ -265,7 +347,7 @@ def responses_udf(
         response_format (type[ResponseFormat]): The desired output format. Either `str` for plain text
             or a Pydantic `BaseModel` for structured JSON output. Defaults to `str`.
         model_name (str): For Azure OpenAI, use your deployment name (e.g., "my-gpt4-deployment").
-            For OpenAI, use the model name (e.g., "gpt-4.1-mini"). Defaults to "gpt-4.1-mini".
+            For OpenAI, use the model name (e.g., "gpt-4.1-mini"). Defaults to the model configured in the DI container.
         batch_size (int | None): Number of rows per async batch request within each partition.
             Larger values reduce API call overhead but increase memory usage.
             Defaults to None (automatic batch size optimization that dynamically
@@ -363,7 +445,7 @@ def responses_udf(
 
 def task_udf(
     task: PreparedTask[ResponseFormat],
-    model_name: str = "gpt-4.1-mini",
+    model_name: str = CONTAINER.resolve(ResponsesModelName).value,
     batch_size: int | None = None,
     max_concurrency: int = 8,
     **api_kwargs,
@@ -380,7 +462,7 @@ def task_udf(
         task (PreparedTask): A predefined task configuration containing instructions,
             response format, temperature, and top_p settings.
         model_name (str): For Azure OpenAI, use your deployment name (e.g., "my-gpt4-deployment").
-            For OpenAI, use the model name (e.g., "gpt-4.1-mini"). Defaults to "gpt-4.1-mini".
+            For OpenAI, use the model name (e.g., "gpt-4.1-mini"). Defaults to the model configured in the DI container.
         batch_size (int | None): Number of rows per async batch request within each partition.
             Larger values reduce API call overhead but increase memory usage.
             Defaults to None (automatic batch size optimization that dynamically
@@ -416,78 +498,142 @@ def task_udf(
     **Automatic Caching**: Duplicate inputs within each partition are cached,
     reducing API calls and costs significantly on datasets with repeated content.
     """
-    # Serialize task parameters for Spark serialization compatibility
-    task_instructions = task.instructions
-    task_temperature = task.temperature
-    task_top_p = task.top_p
+    return responses_udf(
+        instructions=task.instructions,
+        response_format=task.response_format,
+        model_name=model_name,
+        batch_size=batch_size,
+        temperature=task.temperature,
+        top_p=task.top_p,
+        max_concurrency=max_concurrency,
+        **api_kwargs,
+    )
+
+
+def infer_schema(
+    instructions: str,
+    example_table_name: str,
+    example_field_name: str,
+    max_examples: int = 100,
+) -> InferredSchema:
+    """Infer the schema for a response format based on example data.
 
-    if issubclass(task.response_format, BaseModel):
-        task_response_format_json = serialize_base_model(task.response_format)
+    This function retrieves examples from a Spark table and infers the schema
+    for the response format using the provided instructions. It is useful when
+    you want to dynamically generate a schema based on existing data.
 
-        # Deserialize the response format from JSON
-        response_format = deserialize_base_model(task_response_format_json)
-        spark_schema = _pydantic_to_spark_schema(response_format)
+    Args:
+        instructions (str): Instructions for the model to infer the schema.
+        example_table_name (str): Name of the Spark table containing example data.
+        example_field_name (str): Name of the field in the table to use as examples.
+        max_examples (int): Maximum number of examples to retrieve for schema inference.
 
-        @pandas_udf(returnType=spark_schema)  # type: ignore[call-overload]
-        def task_udf(col: Iterator[pd.Series]) -> Iterator[pd.DataFrame]:
-            pandas_ext.responses_model(model_name)
-            cache = AsyncBatchingMapProxy[str, response_format](
-                batch_size=batch_size,
-                max_concurrency=max_concurrency,
-            )
+    Returns:
+        InferredSchema: An object containing the inferred schema and response format.
+    """
 
-            try:
-                for part in col:
-                    predictions: pd.Series = asyncio.run(
-                        part.aio.responses_with_cache(
-                            instructions=task_instructions,
-                            response_format=response_format,
-                            temperature=task_temperature,
-                            top_p=task_top_p,
-                            cache=cache,
-                            **api_kwargs,
-                        )
-                    )
-                    yield pd.DataFrame(predictions.map(_safe_dump).tolist())
-            finally:
-                asyncio.run(cache.clear())
+    from pyspark.sql import SparkSession
 
-        return task_udf  # type: ignore[return-value]
+    spark = SparkSession.builder.getOrCreate()
+    examples: list[str] = (
+        spark.table(example_table_name).rdd.map(lambda row: row[example_field_name]).takeSample(False, max_examples)
+    )
 
-    elif issubclass(task.response_format, str):
+    input = SchemaInferenceInput(
+        purpose=instructions,
+        examples=examples,
+    )
+    inferer = CONTAINER.resolve(SchemaInferer)
+    return inferer.infer_schema(input)
 
-        @pandas_udf(returnType=StringType())  # type: ignore[call-overload]
-        def task_string_udf(col: Iterator[pd.Series]) -> Iterator[pd.Series]:
-            pandas_ext.responses_model(model_name)
-            cache = AsyncBatchingMapProxy[str, str](
-                batch_size=batch_size,
-                max_concurrency=max_concurrency,
-            )
 
-            try:
-                for part in col:
-                    predictions: pd.Series = asyncio.run(
-                        part.aio.responses_with_cache(
-                            instructions=task_instructions,
-                            response_format=str,
-                            temperature=task_temperature,
-                            top_p=task_top_p,
-                            cache=cache,
-                            **api_kwargs,
-                        )
-                    )
-                    yield predictions.map(_safe_cast_str)
-            finally:
-                asyncio.run(cache.clear())
+def parse_udf(
+    instructions: str,
+    response_format: type[ResponseFormat] | None = None,
+    example_table_name: str | None = None,
+    example_field_name: str | None = None,
+    max_examples: int = 100,
+    model_name: str = CONTAINER.resolve(ResponsesModelName).value,
+    batch_size: int | None = None,
+    temperature: float | None = 0.0,
+    top_p: float = 1.0,
+    max_concurrency: int = 8,
+    **api_kwargs,
+) -> UserDefinedFunction:
+    """Create an asynchronous Spark pandas UDF for parsing responses.
+
+    This function allows users to create UDFs that parse responses based on
+    provided instructions and either a predefined response format or example data.
+    It supports both structured responses using Pydantic models and plain text responses.
+    Each partition maintains its own cache to eliminate duplicate API calls within
+    the partition, significantly reducing API usage and costs when processing
+    datasets with overlapping content.
 
-        return task_string_udf  # type: ignore[return-value]
+    Args:
+        instructions (str): The system prompt or instructions for the model.
+        response_format (type[ResponseFormat] | None): The desired output format.
+            Either `str` for plain text or a Pydantic `BaseModel` for structured JSON output.
+            If not provided, the schema will be inferred from example data.
+        example_table_name (str | None): Name of the Spark table containing example data.
+            If provided, `example_field_name` must also be specified.
+        example_field_name (str | None): Name of the field in the table to use as examples.
+            If provided, `example_table_name` must also be specified.
+        max_examples (int): Maximum number of examples to retrieve for schema inference.
+            Defaults to 100.
+        model_name (str): For Azure OpenAI, use your deployment name (e.g., "my-gpt4-deployment").
+            For OpenAI, use the model name (e.g., "gpt-4.1-mini"). Defaults to the model
+            configured in the DI container.
+        batch_size (int | None): Number of rows per async batch request within each partition.
+            Larger values reduce API call overhead but increase memory usage.
+            Defaults to None (automatic batch size optimization that dynamically
+            adjusts based on execution time, targeting 30-60 seconds per batch).
+            Set to a positive integer (e.g., 32-128) for a fixed batch size.
+        temperature (float | None): Sampling temperature (0.0 to 2.0). Defaults to 0.0.
+        top_p (float): Nucleus sampling parameter. Defaults to 1.0.
+        max_concurrency (int): Maximum number of concurrent API requests **PER EXECUTOR**.
+            Total cluster concurrency = max_concurrency × number_of_executors.
+            Higher values increase throughput but may hit OpenAI rate limits.
+            Recommended: 4-12 per executor. Defaults to 8.
+
+    Additional Keyword Args:
+        Arbitrary OpenAI Responses API parameters (e.g. ``frequency_penalty``, ``presence_penalty``,
+        ``seed``, ``max_output_tokens``, etc.) are forwarded verbatim to the underlying API calls.
+        These parameters are applied to all API requests made by the UDF.
+
+    Returns:
+        UserDefinedFunction: A Spark pandas UDF configured to parse responses asynchronously.
+            Output schema is `StringType` for str response format or a struct derived from
+            the response_format for BaseModel.
+
+    Raises:
+        ValueError: If neither `response_format` nor `example_table_name` and
+            `example_field_name` are provided.
+    """
 
-    else:
-        raise ValueError(f"Unsupported response_format in task: {task.response_format}")
+    if not response_format and not (example_field_name and example_table_name):
+        raise ValueError("Either response_format or example_table_name and example_field_name must be provided.")
+
+    schema: InferredSchema | None = None
+
+    if not response_format:
+        schema = infer_schema(
+            instructions=instructions,
+            example_table_name=example_table_name,
+            example_field_name=example_field_name,
+            max_examples=max_examples,
+        )
+
+    return responses_udf(
+        instructions=schema.inference_prompt if schema else instructions,
+        response_format=schema.model if schema else response_format,
+        model_name=model_name,
+        batch_size=batch_size,
+        temperature=temperature,
+        top_p=top_p,
+        max_concurrency=max_concurrency,
+        **api_kwargs,
+    )
 
 
 def embeddings_udf(
-    model_name: str = "text-embedding-3-small", batch_size: int | None = None, max_concurrency: int = 8
+    model_name: str = CONTAINER.resolve(EmbeddingsModelName).value,
+    batch_size: int | None = None,
+    max_concurrency: int = 8,
 ) -> UserDefinedFunction:
     """Create an asynchronous Spark pandas UDF for generating embeddings.
 
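
`parse_udf` also works with an explicit Pydantic schema, in which case `infer_schema` is never called. A sketch; the `Product` model, registered name, and column names are hypothetical:

```python
# Sketch only: parse_udf with an explicit response_format instead of schema inference.
from pydantic import BaseModel

from openaivec.spark import parse_udf

class Product(BaseModel):
    name: str
    price_usd: float

spark.udf.register(
    "parse_product",
    parse_udf(
        instructions="Extract the product name and its USD price from the text",
        response_format=Product,  # explicit schema: no example table needed
    ),
)
# SQL: SELECT parse_product(description).name, parse_product(description).price_usd FROM listings;
```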
@@ -511,7 +657,8 @@ def embeddings_udf(
 
     Args:
         model_name (str): For Azure OpenAI, use your deployment name (e.g., "my-embedding-deployment").
-            For OpenAI, use the model name (e.g., "text-embedding-3-small"). Defaults to "text-embedding-3-small".
+            For OpenAI, use the model name (e.g., "text-embedding-3-small").
+            Defaults to the model configured in the DI container.
         batch_size (int | None): Number of rows per async batch request within each partition.
             Larger values reduce API call overhead but increase memory usage.
             Defaults to None (automatic batch size optimization that dynamically
@@ -600,17 +747,15 @@ def count_tokens_udf() -> UserDefinedFunction:
 
 
 def similarity_udf() -> UserDefinedFunction:
-    @pandas_udf(FloatType())  # type: ignore[call-overload]
-    def fn(a: pd.Series, b: pd.Series) -> pd.Series:
-        """Compute cosine similarity between two vectors.
+    """Create a pandas UDF that computes cosine similarity between embedding vectors.
 
-        Args:
-            a: First vector.
-            b: Second vector.
+    Returns:
+        UserDefinedFunction: A Spark pandas UDF that takes two embedding vector columns
+            and returns their cosine similarity as a FloatType column.
+    """
 
-        Returns:
-            Cosine similarity between the two vectors.
-        """
+    @pandas_udf(FloatType())  # type: ignore[call-overload]
+    def fn(a: pd.Series, b: pd.Series) -> pd.Series:
         # Import pandas_ext to ensure .ai accessor is available in Spark workers
         from openaivec import pandas_ext
 
openaivec-0.14.10.dist-info/METADATA → openaivec-0.14.12.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: openaivec
-Version: 0.14.10
+Version: 0.14.12
 Summary: Generative mutation for tabular calculation
 Project-URL: Homepage, https://microsoft.github.io/openaivec/
 Project-URL: Repository, https://github.com/microsoft/openaivec
@@ -334,26 +334,34 @@ Scale to enterprise datasets with distributed processing:
 First, obtain a Spark session and configure authentication:
 
 ```python
-import os
 from pyspark.sql import SparkSession
+from openaivec.spark import setup, setup_azure
 
 spark = SparkSession.builder.getOrCreate()
-sc = spark.sparkContext
 
-# Configure authentication via SparkContext environment variables
 # Option 1: Using OpenAI
-sc.environment["OPENAI_API_KEY"] = os.environ.get("OPENAI_API_KEY")
+setup(
+    spark,
+    api_key="your-openai-api-key",
+    responses_model_name="gpt-4.1-mini",  # Optional: set default model
+    embeddings_model_name="text-embedding-3-small"  # Optional: set default model
+)
 
 # Option 2: Using Azure OpenAI
-# sc.environment["AZURE_OPENAI_API_KEY"] = os.environ.get("AZURE_OPENAI_API_KEY")
-# sc.environment["AZURE_OPENAI_BASE_URL"] = os.environ.get("AZURE_OPENAI_BASE_URL")
-# sc.environment["AZURE_OPENAI_API_VERSION"] = os.environ.get("AZURE_OPENAI_API_VERSION")
+# setup_azure(
+#     spark,
+#     api_key="your-azure-openai-api-key",
+#     base_url="https://YOUR-RESOURCE-NAME.services.ai.azure.com/openai/v1/",
+#     api_version="preview",
+#     responses_model_name="my-gpt4-deployment",  # Optional: set default deployment
+#     embeddings_model_name="my-embedding-deployment"  # Optional: set default deployment
+# )
 ```
 
 Next, create and register UDFs using the provided functions:
 
 ```python
-from openaivec.spark import responses_udf, task_udf, embeddings_udf, count_tokens_udf
+from openaivec.spark import responses_udf, task_udf, embeddings_udf, count_tokens_udf, similarity_udf, parse_udf
 from pydantic import BaseModel
 
 # --- Register Responses UDF (String Output) ---
@@ -387,6 +395,9 @@ spark.udf.register(
 # --- Register Token Counting UDF ---
 spark.udf.register("count_tokens", count_tokens_udf())
 
+# --- Register Similarity UDF ---
+spark.udf.register("compute_similarity", similarity_udf())
+
 # --- Register UDFs with Pre-configured Tasks ---
 from openaivec.task import nlp, customer_support
 
@@ -414,6 +425,17 @@ spark.udf.register(
     )
 )
 
+# --- Register Parse UDF (Dynamic Schema Inference) ---
+spark.udf.register(
+    "parse_dynamic",
+    parse_udf(
+        instructions="Extract key entities and attributes from the text",
+        example_table_name="sample_texts",  # Infer schema from examples
+        example_field_name="text",
+        max_examples=50
+    )
+)
+
 ```
 
 You can now use these UDFs in Spark SQL:
@@ -691,17 +713,19 @@ steps:
 - In the notebook, import and use `openaivec.spark` functions as you normally would. For example:
 
 ```python
-import os
-from openaivec.spark import responses_udf, embeddings_udf
+from openaivec.spark import setup_azure, responses_udf, embeddings_udf
 
 # In Microsoft Fabric, spark session is automatically available
 # spark = SparkSession.builder.getOrCreate()
-sc = spark.sparkContext
-
+
 # Configure Azure OpenAI authentication
-sc.environment["AZURE_OPENAI_API_KEY"] = "<your-api-key>"
-sc.environment["AZURE_OPENAI_BASE_URL"] = "https://YOUR-RESOURCE-NAME.services.ai.azure.com/openai/v1/"
-sc.environment["AZURE_OPENAI_API_VERSION"] = "preview"
+setup_azure(
+    spark,
+    api_key="<your-api-key>",
+    base_url="https://YOUR-RESOURCE-NAME.services.ai.azure.com/openai/v1/",
+    api_version="preview",
+    responses_model_name="my-gpt4-deployment"  # Your Azure deployment name
+)
 
 # Register UDFs
 spark.udf.register(
openaivec-0.14.10.dist-info/RECORD → openaivec-0.14.12.dist-info/RECORD CHANGED
@@ -1,19 +1,19 @@
 openaivec/__init__.py,sha256=mXCGNNTjYbmE4CAXGvAs78soxUsoy_mxxnvaCk_CL6Y,361
-openaivec/_di.py,sha256=1MXaBzaH_ZenQnWKQzBY2z-egHwiteMvg7byoUH3ZZI,10658
+openaivec/_di.py,sha256=Cl1ZoNBlQsJL1bpzoMDl08uT9pZFVSlqOdLbS3_MwPE,11462
 openaivec/_dynamic.py,sha256=7ZaC59w2Edemnao57XeZVO4qmSOA-Kus6TchZC3Dd5o,14821
 openaivec/_embeddings.py,sha256=upCjl8m9h1CihP6t7wvIH_vivOAPSgmgooAxIhnUMUw,7449
 openaivec/_log.py,sha256=LHNs6AbJzM4weaRARZFroigxR6D148d7WSIMLk1IhbU,1439
 openaivec/_model.py,sha256=toS2oBubrJa9jrdYy-87Fb2XivjXUlk_8Zn5gKUAcFI,3345
 openaivec/_optimize.py,sha256=3nS8VehbS7iGC1tPDDQh-iAgyKHbVYmMbCRBWM77U_U,3827
 openaivec/_prompt.py,sha256=zLv13q47CKV3jnETUyWAIlnjXFSEMs70c8m0yN7_Hek,20820
-openaivec/_provider.py,sha256=YLrEcb4aWBD1fj0n6PNcJpCtEXK6jkUuRH_WxcLDCuI,7145
+openaivec/_provider.py,sha256=8z8gPYY5-Z7rzDlj_NC6hR__DUqVAH7VLHJn6LalzRg,6158
 openaivec/_proxy.py,sha256=AiGuC1MCFjZCRXCac-pHUI3Np3nf1HIpWY6nC9ZVCFY,29671
 openaivec/_responses.py,sha256=lVJRa_Uc7hQJnYJRgumqwBbu6GToZqsLFS6tIAFO1Fc,24014
 openaivec/_schema.py,sha256=RKjDPqet1TlReYibah0R0NIvCV1VWN5SZxiaBeV0gCY,15492
 openaivec/_serialize.py,sha256=u2Om94Sc_QgJkTlW2BAGw8wd6gYDhc6IRqvS-qevFSs,8399
 openaivec/_util.py,sha256=XfueAycVCQvgRLS7wF7e306b53lebORvZOBzbQjy4vE,6438
-openaivec/pandas_ext.py,sha256=_MdiZWokius62zI_sTp_nd-33fMNlnRHbyqso0eF_Hw,85406
-openaivec/spark.py,sha256=Dbuhlk8Z89Fwk3fbWp1Ud9uTpfNyfjZOIx8ARJMnQf0,25371
+openaivec/pandas_ext.py,sha256=fjBW_TU4zsew3j7g7x67t9ESCwZ0fIuxbh9bZdOmRA0,85407
+openaivec/spark.py,sha256=V0Gg9b9Q-2ycet33ENAN21aA-GltNj57tWoE2pCZIRQ,32601
 openaivec/task/__init__.py,sha256=lrgoc9UIox7XnxZ96dQRl88a-8QfuZRFBHshxctpMB8,6178
 openaivec/task/customer_support/__init__.py,sha256=KWfGyXPdZyfGdRH17x7hPpJJ1N2EP9PPhZx0fvBAwSI,884
 openaivec/task/customer_support/customer_sentiment.py,sha256=NHIr9nm2d2Bu1MSpxFsM3_w1UuQrQEwnHrClVbhdCUw,7612
@@ -31,7 +31,7 @@ openaivec/task/nlp/sentiment_analysis.py,sha256=Np-yY0d4Kr5WEjGjq4tNFHDNarBLajJr
 openaivec/task/nlp/translation.py,sha256=VYgiXtr2TL1tbqZkBpyVAy4ahrgd8UO4ZjhIL6xMdkI,6609
 openaivec/task/table/__init__.py,sha256=kJz15WDJXjyC7UIHKBvlTRhCf347PCDMH5T5fONV2sU,83
 openaivec/task/table/fillna.py,sha256=g_CpLnLzK1C5rCiVq15L3X0kywJK6CtSrKRYxQFuhn8,6606
-openaivec-0.14.10.dist-info/METADATA,sha256=BXQWevriu4qabbZM1paMO1PV_i8zmFPqiodTMwzeJnQ,27567
-openaivec-0.14.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-openaivec-0.14.10.dist-info/licenses/LICENSE,sha256=ws_MuBL-SCEBqPBFl9_FqZkaaydIJmxHrJG2parhU4M,1141
-openaivec-0.14.10.dist-info/RECORD,,
+openaivec-0.14.12.dist-info/METADATA,sha256=GC5evUtog4LhK1XhJXfF-jO9DeyDq7l9Ii8KN1sVIBo,28216
+openaivec-0.14.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+openaivec-0.14.12.dist-info/licenses/LICENSE,sha256=ws_MuBL-SCEBqPBFl9_FqZkaaydIJmxHrJG2parhU4M,1141
+openaivec-0.14.12.dist-info/RECORD,,