openaivec 0.14.12__py3-none-any.whl → 0.14.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
openaivec/spark.py CHANGED
@@ -134,6 +134,7 @@ import numpy as np
134
134
  import pandas as pd
135
135
  import tiktoken
136
136
  from pydantic import BaseModel
137
+ from pyspark import SparkContext
137
138
  from pyspark.sql import SparkSession
138
139
  from pyspark.sql.pandas.functions import pandas_udf
139
140
  from pyspark.sql.types import ArrayType, BooleanType, FloatType, IntegerType, StringType, StructField, StructType
@@ -180,7 +181,10 @@ def setup(
180
181
  If provided, registers `EmbeddingsModelName` in the DI container.
181
182
  """
182
183
 
183
- sc = spark.sparkContext
184
+ CONTAINER.register(SparkSession, lambda: spark)
185
+ CONTAINER.register(SparkContext, lambda: CONTAINER.resolve(SparkSession).sparkContext)
186
+
187
+ sc = CONTAINER.resolve(SparkContext)
184
188
  sc.environment["OPENAI_API_KEY"] = api_key
185
189
 
186
190
  os.environ["OPENAI_API_KEY"] = api_key
@@ -189,8 +193,6 @@ def setup(
189
193
  CONTAINER.register(ResponsesModelName, lambda: ResponsesModelName(responses_model_name))
190
194
 
191
195
  if embeddings_model_name:
192
- from openaivec._model import EmbeddingsModelName
193
-
194
196
  CONTAINER.register(EmbeddingsModelName, lambda: EmbeddingsModelName(embeddings_model_name))
195
197
 
196
198
  CONTAINER.clear_singletons()
@@ -219,7 +221,10 @@ def setup_azure(
219
221
  If provided, registers `EmbeddingsModelName` in the DI container.
220
222
  """
221
223
 
222
- sc = spark.sparkContext
224
+ CONTAINER.register(SparkSession, lambda: spark)
225
+ CONTAINER.register(SparkContext, lambda: CONTAINER.resolve(SparkSession).sparkContext)
226
+
227
+ sc = CONTAINER.resolve(SparkContext)
223
228
  sc.environment["AZURE_OPENAI_API_KEY"] = api_key
224
229
  sc.environment["AZURE_OPENAI_BASE_URL"] = base_url
225
230
  sc.environment["AZURE_OPENAI_API_VERSION"] = api_version
@@ -237,6 +242,50 @@ def setup_azure(
237
242
  CONTAINER.clear_singletons()
238
243
 
239
244
 
245
+ def set_responses_model(model_name: str):
246
+ """Set the default model name for response generation in the DI container.
247
+
248
+ Args:
249
+ model_name (str): The model name to set as default for responses.
250
+ """
251
+ CONTAINER.register(ResponsesModelName, lambda: ResponsesModelName(model_name))
252
+ CONTAINER.clear_singletons()
253
+
254
+
255
+ def get_responses_model() -> str | None:
256
+ """Get the default model name for response generation from the DI container.
257
+
258
+ Returns:
259
+ str | None: The default model name for responses, or None if not set.
260
+ """
261
+ try:
262
+ return CONTAINER.resolve(ResponsesModelName).value
263
+ except Exception:
264
+ return None
265
+
266
+
267
+ def set_embeddings_model(model_name: str):
268
+ """Set the default model name for embeddings in the DI container.
269
+
270
+ Args:
271
+ model_name (str): The model name to set as default for embeddings.
272
+ """
273
+ CONTAINER.register(EmbeddingsModelName, lambda: EmbeddingsModelName(model_name))
274
+ CONTAINER.clear_singletons()
275
+
276
+
277
+ def get_embeddings_model() -> str | None:
278
+ """Get the default model name for embeddings from the DI container.
279
+
280
+ Returns:
281
+ str | None: The default model name for embeddings, or None if not set.
282
+ """
283
+ try:
284
+ return CONTAINER.resolve(EmbeddingsModelName).value
285
+ except Exception:
286
+ return None
287
+
288
+
240
289
  def _python_type_to_spark(python_type):
241
290
  origin = get_origin(python_type)
242
291
 
@@ -315,10 +364,8 @@ def _safe_dump(x: BaseModel | None) -> dict:
315
364
  def responses_udf(
316
365
  instructions: str,
317
366
  response_format: type[ResponseFormat] = str,
318
- model_name: str = CONTAINER.resolve(ResponsesModelName).value,
367
+ model_name: str | None = None,
319
368
  batch_size: int | None = None,
320
- temperature: float | None = 0.0,
321
- top_p: float = 1.0,
322
369
  max_concurrency: int = 8,
323
370
  **api_kwargs,
324
371
  ) -> UserDefinedFunction:
@@ -346,24 +393,22 @@ def responses_udf(
346
393
  instructions (str): The system prompt or instructions for the model.
347
394
  response_format (type[ResponseFormat]): The desired output format. Either `str` for plain text
348
395
  or a Pydantic `BaseModel` for structured JSON output. Defaults to `str`.
349
- model_name (str): For Azure OpenAI, use your deployment name (e.g., "my-gpt4-deployment").
350
- For OpenAI, use the model name (e.g., "gpt-4.1-mini"). Defaults to configured model in DI container.
396
+ model_name (str | None): For Azure OpenAI, use your deployment name (e.g., "my-gpt4-deployment").
397
+ For OpenAI, use the model name (e.g., "gpt-4.1-mini"). Defaults to configured model in DI container
398
+ via ResponsesModelName if not provided.
351
399
  batch_size (int | None): Number of rows per async batch request within each partition.
352
400
  Larger values reduce API call overhead but increase memory usage.
353
401
  Defaults to None (automatic batch size optimization that dynamically
354
402
  adjusts based on execution time, targeting 30-60 seconds per batch).
355
403
  Set to a positive integer (e.g., 32-128) for fixed batch size.
356
- temperature (float): Sampling temperature (0.0 to 2.0). Defaults to 0.0.
357
- top_p (float): Nucleus sampling parameter. Defaults to 1.0.
358
404
  max_concurrency (int): Maximum number of concurrent API requests **PER EXECUTOR**.
359
405
  Total cluster concurrency = max_concurrency × number_of_executors.
360
406
  Higher values increase throughput but may hit OpenAI rate limits.
361
407
  Recommended: 4-12 per executor. Defaults to 8.
362
-
363
- Additional Keyword Args:
364
- Arbitrary OpenAI Responses API parameters (e.g. ``frequency_penalty``, ``presence_penalty``,
365
- ``seed``, ``max_output_tokens``, etc.) are forwarded verbatim to the underlying API calls.
366
- These parameters are applied to all API requests made by the UDF.
408
+ **api_kwargs: Additional OpenAI API parameters (e.g. ``temperature``, ``top_p``,
409
+ ``frequency_penalty``, ``presence_penalty``, ``seed``, ``max_output_tokens``, etc.)
410
+ forwarded verbatim to the underlying API calls. These parameters are applied to
411
+ all API requests made by the UDF.
367
412
 
368
413
  Returns:
369
414
  UserDefinedFunction: A Spark pandas UDF configured to generate responses asynchronously.
@@ -380,13 +425,15 @@ def responses_udf(
380
425
  - Consider your OpenAI tier limits: total_requests = max_concurrency × executors
381
426
  - Use Spark UI to optimize partition sizes relative to batch_size
382
427
  """
428
+ _model_name = model_name or CONTAINER.resolve(ResponsesModelName).value
429
+
383
430
  if issubclass(response_format, BaseModel):
384
431
  spark_schema = _pydantic_to_spark_schema(response_format)
385
432
  json_schema_string = serialize_base_model(response_format)
386
433
 
387
434
  @pandas_udf(returnType=spark_schema) # type: ignore[call-overload]
388
435
  def structure_udf(col: Iterator[pd.Series]) -> Iterator[pd.DataFrame]:
389
- pandas_ext.responses_model(model_name)
436
+ pandas_ext.responses_model(_model_name)
390
437
  response_format = deserialize_base_model(json_schema_string)
391
438
  cache = AsyncBatchingMapProxy[str, response_format](
392
439
  batch_size=batch_size,
@@ -399,8 +446,6 @@ def responses_udf(
399
446
  part.aio.responses_with_cache(
400
447
  instructions=instructions,
401
448
  response_format=response_format,
402
- temperature=temperature,
403
- top_p=top_p,
404
449
  cache=cache,
405
450
  **api_kwargs,
406
451
  )
@@ -415,7 +460,7 @@ def responses_udf(
415
460
 
416
461
  @pandas_udf(returnType=StringType()) # type: ignore[call-overload]
417
462
  def string_udf(col: Iterator[pd.Series]) -> Iterator[pd.Series]:
418
- pandas_ext.responses_model(model_name)
463
+ pandas_ext.responses_model(_model_name)
419
464
  cache = AsyncBatchingMapProxy[str, str](
420
465
  batch_size=batch_size,
421
466
  max_concurrency=max_concurrency,
@@ -427,8 +472,6 @@ def responses_udf(
427
472
  part.aio.responses_with_cache(
428
473
  instructions=instructions,
429
474
  response_format=str,
430
- temperature=temperature,
431
- top_p=top_p,
432
475
  cache=cache,
433
476
  **api_kwargs,
434
477
  )
@@ -445,7 +488,7 @@ def responses_udf(
445
488
 
446
489
  def task_udf(
447
490
  task: PreparedTask[ResponseFormat],
448
- model_name: str = CONTAINER.resolve(ResponsesModelName).value,
491
+ model_name: str | None = None,
449
492
  batch_size: int | None = None,
450
493
  max_concurrency: int = 8,
451
494
  **api_kwargs,
@@ -460,9 +503,10 @@ def task_udf(
460
503
 
461
504
  Args:
462
505
  task (PreparedTask): A predefined task configuration containing instructions,
463
- response format, temperature, and top_p settings.
464
- model_name (str): For Azure OpenAI, use your deployment name (e.g., "my-gpt4-deployment").
465
- For OpenAI, use the model name (e.g., "gpt-4.1-mini"). Defaults to configured model in DI container.
506
+ response format, and API parameters.
507
+ model_name (str | None): For Azure OpenAI, use your deployment name (e.g., "my-gpt4-deployment").
508
+ For OpenAI, use the model name (e.g., "gpt-4.1-mini"). Defaults to configured model in DI container
509
+ via ResponsesModelName if not provided.
466
510
  batch_size (int | None): Number of rows per async batch request within each partition.
467
511
  Larger values reduce API call overhead but increase memory usage.
468
512
  Defaults to None (automatic batch size optimization that dynamically
@@ -474,10 +518,10 @@ def task_udf(
474
518
  Recommended: 4-12 per executor. Defaults to 8.
475
519
 
476
520
  Additional Keyword Args:
477
- Arbitrary OpenAI Responses API parameters (e.g. ``frequency_penalty``, ``presence_penalty``,
478
- ``seed``, ``max_output_tokens``, etc.) are forwarded verbatim to the underlying API calls.
479
- These parameters are applied to all API requests made by the UDF and override any
480
- parameters set in the task configuration.
521
+ Arbitrary OpenAI Responses API parameters (e.g. ``temperature``, ``top_p``,
522
+ ``frequency_penalty``, ``presence_penalty``, ``seed``, ``max_output_tokens``, etc.)
523
+ are forwarded verbatim to the underlying API calls. These parameters are applied to
524
+ all API requests made by the UDF and override any parameters set in the task configuration.
481
525
 
482
526
  Returns:
483
527
  UserDefinedFunction: A Spark pandas UDF configured to execute the specified task
@@ -498,15 +542,16 @@ def task_udf(
498
542
  **Automatic Caching**: Duplicate inputs within each partition are cached,
499
543
  reducing API calls and costs significantly on datasets with repeated content.
500
544
  """
545
+ # Merge task's api_kwargs with caller's api_kwargs (caller takes precedence)
546
+ merged_kwargs = {**task.api_kwargs, **api_kwargs}
547
+
501
548
  return responses_udf(
502
549
  instructions=task.instructions,
503
550
  response_format=task.response_format,
504
551
  model_name=model_name,
505
552
  batch_size=batch_size,
506
- temperature=task.temperature,
507
- top_p=task.top_p,
508
553
  max_concurrency=max_concurrency,
509
- **api_kwargs,
554
+ **merged_kwargs,
510
555
  )
511
556
 
512
557
 
@@ -532,15 +577,13 @@ def infer_schema(
532
577
  InferredSchema: An object containing the inferred schema and response format.
533
578
  """
534
579
 
535
- from pyspark.sql import SparkSession
536
-
537
- spark = SparkSession.builder.getOrCreate()
580
+ spark = CONTAINER.resolve(SparkSession)
538
581
  examples: list[str] = (
539
582
  spark.table(example_table_name).rdd.map(lambda row: row[example_field_name]).takeSample(False, max_examples)
540
583
  )
541
584
 
542
585
  input = SchemaInferenceInput(
543
- purpose=instructions,
586
+ instructions=instructions,
544
587
  examples=examples,
545
588
  )
546
589
  inferer = CONTAINER.resolve(SchemaInferer)
@@ -553,10 +596,8 @@ def parse_udf(
553
596
  example_table_name: str | None = None,
554
597
  example_field_name: str | None = None,
555
598
  max_examples: int = 100,
556
- model_name: str = CONTAINER.resolve(ResponsesModelName).value,
599
+ model_name: str | None = None,
557
600
  batch_size: int | None = None,
558
- temperature: float | None = 0.0,
559
- top_p: float = 1.0,
560
601
  max_concurrency: int = 8,
561
602
  **api_kwargs,
562
603
  ) -> UserDefinedFunction:
@@ -579,24 +620,23 @@ def parse_udf(
579
620
  If provided, `example_table_name` must also be specified.
580
621
  max_examples (int): Maximum number of examples to retrieve for schema inference.
581
622
  Defaults to 100.
582
- model_name (str): For Azure OpenAI, use your deployment name (e.g., "my-gpt4-deployment").
583
- For OpenAI, use the model name (e.g., "gpt-4.1-mini"). Defaults to configured model in DI container.
623
+ model_name (str | None): For Azure OpenAI, use your deployment name (e.g., "my-gpt4-deployment").
624
+ For OpenAI, use the model name (e.g., "gpt-4.1-mini"). Defaults to configured model in DI container
625
+ via ResponsesModelName if not provided.
584
626
  batch_size (int | None): Number of rows per async batch request within each partition.
585
627
  Larger values reduce API call overhead but increase memory usage.
586
628
  Defaults to None (automatic batch size optimization that dynamically
587
629
  adjusts based on execution time, targeting 30-60 seconds per batch).
588
630
  Set to a positive integer (e.g., 32-128) for fixed batch size
589
- temperature (float | None): Sampling temperature (0.0 to 2.0). Defaults to 0.0.
590
- top_p (float): Nucleus sampling parameter. Defaults to 1.0.
591
631
  max_concurrency (int): Maximum number of concurrent API requests **PER EXECUTOR**.
592
632
  Total cluster concurrency = max_concurrency × number_of_executors.
593
633
  Higher values increase throughput but may hit OpenAI rate limits.
594
634
  Recommended: 4-12 per executor. Defaults to 8.
595
- Additional Keyword Args:
596
- Arbitrary OpenAI Responses API parameters (e.g. ``frequency_penalty``, ``presence_penalty``,
597
- ``seed``, ``max_output_tokens``, etc.) are forwarded verbatim to the underlying API calls.
598
- These parameters are applied to all API requests made by the UDF and override any
599
- parameters set in the response_format or example data.
635
+ **api_kwargs: Additional OpenAI API parameters (e.g. ``temperature``, ``top_p``,
636
+ ``frequency_penalty``, ``presence_penalty``, ``seed``, ``max_output_tokens``, etc.)
637
+ forwarded verbatim to the underlying API calls. These parameters are applied to
638
+ all API requests made by the UDF and override any parameters set in the
639
+ response_format or example data.
600
640
  Returns:
601
641
  UserDefinedFunction: A Spark pandas UDF configured to parse responses asynchronously.
602
642
  Output schema is `StringType` for str response format or a struct derived from
@@ -623,17 +663,16 @@ def parse_udf(
623
663
  response_format=schema.model if schema else response_format,
624
664
  model_name=model_name,
625
665
  batch_size=batch_size,
626
- temperature=temperature,
627
- top_p=top_p,
628
666
  max_concurrency=max_concurrency,
629
667
  **api_kwargs,
630
668
  )
631
669
 
632
670
 
633
671
  def embeddings_udf(
634
- model_name: str = CONTAINER.resolve(EmbeddingsModelName).value,
672
+ model_name: str | None = None,
635
673
  batch_size: int | None = None,
636
674
  max_concurrency: int = 8,
675
+ **api_kwargs,
637
676
  ) -> UserDefinedFunction:
638
677
  """Create an asynchronous Spark pandas UDF for generating embeddings.
639
678
 
@@ -656,9 +695,9 @@ def embeddings_udf(
656
695
  sc.environment["AZURE_OPENAI_API_VERSION"] = "preview"
657
696
 
658
697
  Args:
659
- model_name (str): For Azure OpenAI, use your deployment name (e.g., "my-embedding-deployment").
698
+ model_name (str | None): For Azure OpenAI, use your deployment name (e.g., "my-embedding-deployment").
660
699
  For OpenAI, use the model name (e.g., "text-embedding-3-small").
661
- Defaults to configured model in DI container.
700
+ Defaults to configured model in DI container via EmbeddingsModelName if not provided.
662
701
  batch_size (int | None): Number of rows per async batch request within each partition.
663
702
  Larger values reduce API call overhead but increase memory usage.
664
703
  Defaults to None (automatic batch size optimization that dynamically
@@ -669,6 +708,7 @@ def embeddings_udf(
669
708
  Total cluster concurrency = max_concurrency × number_of_executors.
670
709
  Higher values increase throughput but may hit OpenAI rate limits.
671
710
  Recommended: 4-12 per executor. Defaults to 8.
711
+ **api_kwargs: Additional OpenAI API parameters (e.g., dimensions for text-embedding-3 models).
672
712
 
673
713
  Returns:
674
714
  UserDefinedFunction: A Spark pandas UDF configured to generate embeddings asynchronously
@@ -685,9 +725,11 @@ def embeddings_udf(
685
725
  - Use larger batch_size for embeddings compared to response generation
686
726
  """
687
727
 
728
+ _model_name = model_name or CONTAINER.resolve(EmbeddingsModelName).value
729
+
688
730
  @pandas_udf(returnType=ArrayType(FloatType())) # type: ignore[call-overload,misc]
689
731
  def _embeddings_udf(col: Iterator[pd.Series]) -> Iterator[pd.Series]:
690
- pandas_ext.embeddings_model(model_name)
732
+ pandas_ext.embeddings_model(_model_name)
691
733
  cache = AsyncBatchingMapProxy[str, np.ndarray](
692
734
  batch_size=batch_size,
693
735
  max_concurrency=max_concurrency,
@@ -695,7 +737,7 @@ def embeddings_udf(
695
737
 
696
738
  try:
697
739
  for part in col:
698
- embeddings: pd.Series = asyncio.run(part.aio.embeddings_with_cache(cache=cache))
740
+ embeddings: pd.Series = asyncio.run(part.aio.embeddings_with_cache(cache=cache, **api_kwargs))
699
741
  yield embeddings.map(lambda x: x.tolist())
700
742
  finally:
701
743
  asyncio.run(cache.clear())
@@ -117,7 +117,7 @@ All tasks are built using the `PreparedTask` dataclass:
117
117
  @dataclass(frozen=True)
118
118
  class PreparedTask:
119
119
  instructions: str # Detailed prompt for the LLM
120
- response_format: Type[ResponseFormat] # Pydantic model or str for structured/plain output
120
+ response_format: type[ResponseFormat] # Pydantic model or str for structured/plain output
121
121
  temperature: float = 0.0 # Sampling temperature
122
122
  top_p: float = 1.0 # Nucleus sampling parameter
123
123
  ```
@@ -95,15 +95,12 @@ class CustomerSentiment(BaseModel):
95
95
  )
96
96
 
97
97
 
98
- def customer_sentiment(
99
- business_context: str = "general customer support", temperature: float = 0.0, top_p: float = 1.0
100
- ) -> PreparedTask:
98
+ def customer_sentiment(business_context: str = "general customer support", **api_kwargs) -> PreparedTask:
101
99
  """Create a configurable customer sentiment analysis task.
102
100
 
103
101
  Args:
104
102
  business_context (str): Business context for sentiment analysis.
105
- temperature (float): Sampling temperature (0.0-1.0).
106
- top_p (float): Nucleus sampling parameter (0.0-1.0).
103
+ **api_kwargs: Additional OpenAI API parameters (temperature, top_p, etc.).
107
104
 
108
105
  Returns:
109
106
  PreparedTask configured for customer sentiment analysis.
@@ -169,10 +166,8 @@ values like "positive" for sentiment.
169
166
 
170
167
  Provide comprehensive sentiment analysis with business context and recommended response strategy."""
171
168
 
172
- return PreparedTask(
173
- instructions=instructions, response_format=CustomerSentiment, temperature=temperature, top_p=top_p
174
- )
169
+ return PreparedTask(instructions=instructions, response_format=CustomerSentiment, api_kwargs=api_kwargs)
175
170
 
176
171
 
177
172
  # Backward compatibility - default configuration
178
- CUSTOMER_SENTIMENT = customer_sentiment()
173
+ CUSTOMER_SENTIMENT = customer_sentiment(temperature=0.0, top_p=1.0)
@@ -119,8 +119,7 @@ def inquiry_classification(
119
119
  priority_rules: Dict[str, str] | None = None,
120
120
  business_context: str = "general customer support",
121
121
  custom_keywords: Dict[str, list[str]] | None = None,
122
- temperature: float = 0.0,
123
- top_p: float = 1.0,
122
+ **api_kwargs,
124
123
  ) -> PreparedTask:
125
124
  """Create a configurable inquiry classification task.
126
125
 
@@ -133,8 +132,8 @@ def inquiry_classification(
133
132
  Default uses standard priority indicators.
134
133
  business_context (str): Description of the business context to help with classification.
135
134
  custom_keywords (dict[str, list[str]] | None): Dictionary mapping categories to relevant keywords.
136
- temperature (float): Sampling temperature (0.0-1.0).
137
- top_p (float): Nucleus sampling parameter (0.0-1.0).
135
+ **api_kwargs: Additional keyword arguments to pass to the OpenAI API,
136
+ such as temperature, top_p, etc.
138
137
 
139
138
  Returns:
140
139
  PreparedTask configured for inquiry classification.
@@ -254,10 +253,8 @@ language where appropriate, but priority must use English values like "high".
254
253
 
255
254
  Provide accurate classification with detailed reasoning."""
256
255
 
257
- return PreparedTask(
258
- instructions=instructions, response_format=InquiryClassification, temperature=temperature, top_p=top_p
259
- )
256
+ return PreparedTask(instructions=instructions, response_format=InquiryClassification, api_kwargs=api_kwargs)
260
257
 
261
258
 
262
259
  # Backward compatibility - default configuration
263
- INQUIRY_CLASSIFICATION = inquiry_classification()
260
+ INQUIRY_CLASSIFICATION = inquiry_classification(temperature=0.0, top_p=1.0)
@@ -87,16 +87,15 @@ class InquirySummary(BaseModel):
87
87
  def inquiry_summary(
88
88
  summary_length: str = "concise",
89
89
  business_context: str = "general customer support",
90
- temperature: float = 0.0,
91
- top_p: float = 1.0,
90
+ **api_kwargs,
92
91
  ) -> PreparedTask:
93
92
  """Create a configurable inquiry summary task.
94
93
 
95
94
  Args:
96
95
  summary_length (str): Length of summary (concise, detailed, bullet_points).
97
96
  business_context (str): Business context for summary.
98
- temperature (float): Sampling temperature (0.0-1.0).
99
- top_p (float): Nucleus sampling parameter (0.0-1.0).
97
+ **api_kwargs: Additional keyword arguments to pass to the OpenAI API,
98
+ such as temperature, top_p, etc.
100
99
 
101
100
  Returns:
102
101
  PreparedTask configured for inquiry summarization.
@@ -163,8 +162,8 @@ input is in German, provide all summary content in German, but use English value
163
162
 
164
163
  Provide accurate, actionable summary that enables efficient support resolution."""
165
164
 
166
- return PreparedTask(instructions=instructions, response_format=InquirySummary, temperature=temperature, top_p=top_p)
165
+ return PreparedTask(instructions=instructions, response_format=InquirySummary, api_kwargs=api_kwargs)
167
166
 
168
167
 
169
168
  # Backward compatibility - default configuration
170
- INQUIRY_SUMMARY = inquiry_summary()
169
+ INQUIRY_SUMMARY = inquiry_summary(temperature=0.0, top_p=1.0)
@@ -100,15 +100,13 @@ class IntentAnalysis(BaseModel):
100
100
  )
101
101
 
102
102
 
103
- def intent_analysis(
104
- business_context: str = "general customer support", temperature: float = 0.0, top_p: float = 1.0
105
- ) -> PreparedTask:
103
+ def intent_analysis(business_context: str = "general customer support", **api_kwargs) -> PreparedTask:
106
104
  """Create a configurable intent analysis task.
107
105
 
108
106
  Args:
109
107
  business_context (str): Business context for intent analysis.
110
- temperature (float): Sampling temperature (0.0-1.0).
111
- top_p (float): Nucleus sampling parameter (0.0-1.0).
108
+ **api_kwargs: Additional keyword arguments to pass to the OpenAI API,
109
+ such as temperature, top_p, etc.
112
110
 
113
111
  Returns:
114
112
  PreparedTask configured for intent analysis.
@@ -171,8 +169,8 @@ next_steps, and reasoning in Japanese, but use English values like "get_help" fo
171
169
 
172
170
  Provide comprehensive intent analysis with actionable recommendations."""
173
171
 
174
- return PreparedTask(instructions=instructions, response_format=IntentAnalysis, temperature=temperature, top_p=top_p)
172
+ return PreparedTask(instructions=instructions, response_format=IntentAnalysis, api_kwargs=api_kwargs)
175
173
 
176
174
 
177
175
  # Backward compatibility - default configuration
178
- INTENT_ANALYSIS = intent_analysis()
176
+ INTENT_ANALYSIS = intent_analysis(temperature=0.0, top_p=1.0)
@@ -92,8 +92,7 @@ def response_suggestion(
92
92
  response_style: str = "professional",
93
93
  company_name: str = "our company",
94
94
  business_context: str = "general customer support",
95
- temperature: float = 0.0,
96
- top_p: float = 1.0,
95
+ **api_kwargs,
97
96
  ) -> PreparedTask:
98
97
  """Create a configurable response suggestion task.
99
98
 
@@ -101,8 +100,8 @@ def response_suggestion(
101
100
  response_style (str): Style of response (professional, friendly, empathetic, formal).
102
101
  company_name (str): Name of the company for personalization.
103
102
  business_context (str): Business context for responses.
104
- temperature (float): Sampling temperature (0.0-1.0).
105
- top_p (float): Nucleus sampling parameter (0.0-1.0).
103
+ **api_kwargs: Additional keyword arguments to pass to the OpenAI API,
104
+ such as temperature, top_p, etc.
106
105
 
107
106
  Returns:
108
107
  PreparedTask configured for response suggestions.
@@ -190,10 +189,8 @@ but use English values like "empathetic" for tone.
190
189
  Generate helpful, professional response that moves toward resolution while maintaining
191
190
  positive customer relationship."""
192
191
 
193
- return PreparedTask(
194
- instructions=instructions, response_format=ResponseSuggestion, temperature=temperature, top_p=top_p
195
- )
192
+ return PreparedTask(instructions=instructions, response_format=ResponseSuggestion, api_kwargs=api_kwargs)
196
193
 
197
194
 
198
195
  # Backward compatibility - default configuration
199
- RESPONSE_SUGGESTION = response_suggestion()
196
+ RESPONSE_SUGGESTION = response_suggestion(temperature=0.0, top_p=1.0)
@@ -135,8 +135,7 @@ def urgency_analysis(
135
135
  business_context: str = "general customer support",
136
136
  business_hours: str = "24/7 support",
137
137
  sla_rules: Dict[str, str] | None = None,
138
- temperature: float = 0.0,
139
- top_p: float = 1.0,
138
+ **api_kwargs,
140
139
  ) -> PreparedTask:
141
140
  """Create a configurable urgency analysis task.
142
141
 
@@ -149,8 +148,8 @@ def urgency_analysis(
149
148
  business_context (str): Description of the business context.
150
149
  business_hours (str): Description of business hours for response time calculation.
151
150
  sla_rules (dict[str, str] | None): Dictionary mapping customer tiers to SLA requirements.
152
- temperature (float): Sampling temperature (0.0-1.0).
153
- top_p (float): Nucleus sampling parameter (0.0-1.0).
151
+ **api_kwargs: Additional keyword arguments to pass to the OpenAI API,
152
+ such as temperature, top_p, etc.
154
153
 
155
154
  Returns:
156
155
  PreparedTask configured for urgency analysis.
@@ -287,10 +286,8 @@ urgency_level.
287
286
 
288
287
  Provide detailed analysis with clear reasoning for urgency level and response time recommendations."""
289
288
 
290
- return PreparedTask(
291
- instructions=instructions, response_format=UrgencyAnalysis, temperature=temperature, top_p=top_p
292
- )
289
+ return PreparedTask(instructions=instructions, response_format=UrgencyAnalysis, api_kwargs=api_kwargs)
293
290
 
294
291
 
295
292
  # Backward compatibility - default configuration
296
- URGENCY_ANALYSIS = urgency_analysis()
293
+ URGENCY_ANALYSIS = urgency_analysis(temperature=0.0, top_p=1.0)
@@ -75,6 +75,5 @@ DEPENDENCY_PARSING = PreparedTask(
75
75
  "relations between words, determine the root word, and provide a tree representation of the "
76
76
  "syntactic structure.",
77
77
  response_format=DependencyParsing,
78
- temperature=0.0,
79
- top_p=1.0,
78
+ api_kwargs={"temperature": 0.0, "top_p": 1.0},
80
79
  )
@@ -75,6 +75,5 @@ KEYWORD_EXTRACTION = PreparedTask(
75
75
  instructions="Extract important keywords and phrases from the following text. Rank them "
76
76
  "by importance, provide frequency counts, identify main topics, and generate a brief summary.",
77
77
  response_format=KeywordExtraction,
78
- temperature=0.0,
79
- top_p=1.0,
78
+ api_kwargs={"temperature": 0.0, "top_p": 1.0},
80
79
  )
@@ -70,6 +70,5 @@ MORPHOLOGICAL_ANALYSIS = PreparedTask(
70
70
  "identify part-of-speech tags, provide lemmatized forms, and extract morphological features "
71
71
  "for each token.",
72
72
  response_format=MorphologicalAnalysis,
73
- temperature=0.0,
74
- top_p=1.0,
73
+ api_kwargs={"temperature": 0.0, "top_p": 1.0},
75
74
  )
@@ -78,6 +78,5 @@ NAMED_ENTITY_RECOGNITION = PreparedTask(
78
78
  "organizations, locations, dates, money, percentages, and other miscellaneous entities "
79
79
  "with their positions and confidence scores.",
80
80
  response_format=NamedEntityRecognition,
81
- temperature=0.0,
82
- top_p=1.0,
81
+ api_kwargs={"temperature": 0.0, "top_p": 1.0},
83
82
  )
@@ -78,6 +78,5 @@ SENTIMENT_ANALYSIS = PreparedTask(
78
78
  "English values specified (positive/negative/neutral for sentiment, and "
79
79
  "joy/sadness/anger/fear/surprise/disgust for emotions).",
80
80
  response_format=SentimentAnalysis,
81
- temperature=0.0,
82
- top_p=1.0,
81
+ api_kwargs={"temperature": 0.0, "top_p": 1.0},
83
82
  )
@@ -157,5 +157,5 @@ class TranslatedString(BaseModel):
157
157
  instructions = "Translate the following text into multiple languages. "
158
158
 
159
159
  MULTILINGUAL_TRANSLATION = PreparedTask(
160
- instructions=instructions, response_format=TranslatedString, temperature=0.0, top_p=1.0
160
+ instructions=instructions, response_format=TranslatedString, api_kwargs={"temperature": 0.0, "top_p": 1.0}
161
161
  )
@@ -125,7 +125,7 @@ class FillNaResponse(BaseModel):
125
125
  )
126
126
 
127
127
 
128
- def fillna(df: pd.DataFrame, target_column_name: str, max_examples: int = 500) -> PreparedTask:
128
+ def fillna(df: pd.DataFrame, target_column_name: str, max_examples: int = 500, **api_kwargs) -> PreparedTask:
129
129
  """Create a prepared task for filling missing values in a DataFrame column.
130
130
 
131
131
  Analyzes the provided DataFrame to understand data patterns and creates
@@ -141,12 +141,14 @@ def fillna(df: pd.DataFrame, target_column_name: str, max_examples: int = 500) -
141
141
  max_examples (int): Maximum number of example rows to use for few-shot
142
142
  learning. Defaults to 500. Higher values provide more context
143
143
  but increase token usage and processing time.
144
+ **api_kwargs: Additional keyword arguments to pass to the OpenAI API,
145
+ such as temperature, top_p, etc.
144
146
 
145
147
  Returns:
146
148
  PreparedTask configured for missing value imputation with:
147
149
  - Instructions based on DataFrame patterns
148
150
  - FillNaResponse format for structured output
149
- - Temperature=0.0 and top_p=1.0 for deterministic results
151
+ - Default deterministic settings (temperature=0.0, top_p=1.0)
150
152
 
151
153
  Raises:
152
154
  ValueError: If target_column_name doesn't exist in DataFrame,
@@ -180,4 +182,7 @@ def fillna(df: pd.DataFrame, target_column_name: str, max_examples: int = 500) -
180
182
  if df[target_column_name].notna().sum() == 0:
181
183
  raise ValueError(f"Column '{target_column_name}' contains no non-null values for training examples.")
182
184
  instructions = get_instructions(df, target_column_name, max_examples)
183
- return PreparedTask(instructions=instructions, response_format=FillNaResponse, temperature=0.0, top_p=1.0)
185
+ # Set default values for deterministic results if not provided
186
+ if not api_kwargs:
187
+ api_kwargs = {"temperature": 0.0, "top_p": 1.0}
188
+ return PreparedTask(instructions=instructions, response_format=FillNaResponse, api_kwargs=api_kwargs)