sdg-hub 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sdg_hub/__init__.py +28 -1
- sdg_hub/_version.py +2 -2
- sdg_hub/core/__init__.py +22 -0
- sdg_hub/core/blocks/__init__.py +58 -0
- sdg_hub/core/blocks/base.py +313 -0
- sdg_hub/core/blocks/deprecated_blocks/__init__.py +29 -0
- sdg_hub/core/blocks/deprecated_blocks/combine_columns.py +93 -0
- sdg_hub/core/blocks/deprecated_blocks/duplicate_columns.py +88 -0
- sdg_hub/core/blocks/deprecated_blocks/filter_by_value.py +103 -0
- sdg_hub/core/blocks/deprecated_blocks/flatten_columns.py +94 -0
- sdg_hub/core/blocks/deprecated_blocks/llmblock.py +479 -0
- sdg_hub/core/blocks/deprecated_blocks/rename_columns.py +88 -0
- sdg_hub/core/blocks/deprecated_blocks/sample_populator.py +58 -0
- sdg_hub/core/blocks/deprecated_blocks/selector.py +97 -0
- sdg_hub/core/blocks/deprecated_blocks/set_to_majority_value.py +88 -0
- sdg_hub/core/blocks/evaluation/__init__.py +9 -0
- sdg_hub/core/blocks/evaluation/evaluate_faithfulness_block.py +564 -0
- sdg_hub/core/blocks/evaluation/evaluate_relevancy_block.py +564 -0
- sdg_hub/core/blocks/evaluation/verify_question_block.py +564 -0
- sdg_hub/core/blocks/filtering/__init__.py +12 -0
- sdg_hub/core/blocks/filtering/column_value_filter.py +188 -0
- sdg_hub/core/blocks/llm/__init__.py +25 -0
- sdg_hub/core/blocks/llm/client_manager.py +398 -0
- sdg_hub/core/blocks/llm/config.py +336 -0
- sdg_hub/core/blocks/llm/error_handler.py +368 -0
- sdg_hub/core/blocks/llm/llm_chat_block.py +542 -0
- sdg_hub/core/blocks/llm/prompt_builder_block.py +368 -0
- sdg_hub/core/blocks/llm/text_parser_block.py +310 -0
- sdg_hub/core/blocks/registry.py +331 -0
- sdg_hub/core/blocks/transform/__init__.py +23 -0
- sdg_hub/core/blocks/transform/duplicate_columns.py +88 -0
- sdg_hub/core/blocks/transform/index_based_mapper.py +225 -0
- sdg_hub/core/blocks/transform/melt_columns.py +126 -0
- sdg_hub/core/blocks/transform/rename_columns.py +69 -0
- sdg_hub/core/blocks/transform/text_concat.py +102 -0
- sdg_hub/core/blocks/transform/uniform_col_val_setter.py +101 -0
- sdg_hub/core/flow/__init__.py +20 -0
- sdg_hub/core/flow/base.py +980 -0
- sdg_hub/core/flow/metadata.py +344 -0
- sdg_hub/core/flow/migration.py +187 -0
- sdg_hub/core/flow/registry.py +330 -0
- sdg_hub/core/flow/validation.py +265 -0
- sdg_hub/{utils → core/utils}/__init__.py +6 -4
- sdg_hub/{utils → core/utils}/datautils.py +1 -3
- sdg_hub/core/utils/error_handling.py +208 -0
- sdg_hub/{utils → core/utils}/path_resolution.py +2 -2
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/atomic_facts.yaml +40 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/detailed_summary.yaml +13 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_faithfulness.yaml +64 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_question.yaml +29 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_relevancy.yaml +81 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/extractive_summary.yaml +13 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/flow.yaml +191 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/generate_questions_responses.yaml +54 -0
- sdg_hub-0.2.0.dist-info/METADATA +218 -0
- sdg_hub-0.2.0.dist-info/RECORD +63 -0
- sdg_hub/blocks/__init__.py +0 -42
- sdg_hub/blocks/block.py +0 -96
- sdg_hub/blocks/llmblock.py +0 -375
- sdg_hub/blocks/openaichatblock.py +0 -556
- sdg_hub/blocks/utilblocks.py +0 -597
- sdg_hub/checkpointer.py +0 -139
- sdg_hub/configs/annotations/cot_reflection.yaml +0 -34
- sdg_hub/configs/annotations/detailed_annotations.yaml +0 -28
- sdg_hub/configs/annotations/detailed_description.yaml +0 -10
- sdg_hub/configs/annotations/detailed_description_icl.yaml +0 -32
- sdg_hub/configs/annotations/simple_annotations.yaml +0 -9
- sdg_hub/configs/knowledge/__init__.py +0 -0
- sdg_hub/configs/knowledge/atomic_facts.yaml +0 -46
- sdg_hub/configs/knowledge/auxilary_instructions.yaml +0 -35
- sdg_hub/configs/knowledge/detailed_summary.yaml +0 -18
- sdg_hub/configs/knowledge/evaluate_faithfulness.yaml +0 -68
- sdg_hub/configs/knowledge/evaluate_question.yaml +0 -38
- sdg_hub/configs/knowledge/evaluate_relevancy.yaml +0 -84
- sdg_hub/configs/knowledge/extractive_summary.yaml +0 -18
- sdg_hub/configs/knowledge/generate_code_questions_responses.yaml +0 -39
- sdg_hub/configs/knowledge/generate_questions.yaml +0 -82
- sdg_hub/configs/knowledge/generate_questions_responses.yaml +0 -56
- sdg_hub/configs/knowledge/generate_responses.yaml +0 -86
- sdg_hub/configs/knowledge/mcq_generation.yaml +0 -83
- sdg_hub/configs/knowledge/router.yaml +0 -12
- sdg_hub/configs/knowledge/simple_generate_qa.yaml +0 -34
- sdg_hub/configs/reasoning/__init__.py +0 -0
- sdg_hub/configs/reasoning/dynamic_cot.yaml +0 -40
- sdg_hub/configs/skills/__init__.py +0 -0
- sdg_hub/configs/skills/analyzer.yaml +0 -48
- sdg_hub/configs/skills/annotation.yaml +0 -36
- sdg_hub/configs/skills/contexts.yaml +0 -28
- sdg_hub/configs/skills/critic.yaml +0 -60
- sdg_hub/configs/skills/evaluate_freeform_pair.yaml +0 -111
- sdg_hub/configs/skills/evaluate_freeform_questions.yaml +0 -78
- sdg_hub/configs/skills/evaluate_grounded_pair.yaml +0 -119
- sdg_hub/configs/skills/evaluate_grounded_questions.yaml +0 -51
- sdg_hub/configs/skills/freeform_questions.yaml +0 -34
- sdg_hub/configs/skills/freeform_responses.yaml +0 -39
- sdg_hub/configs/skills/grounded_questions.yaml +0 -38
- sdg_hub/configs/skills/grounded_responses.yaml +0 -59
- sdg_hub/configs/skills/icl_examples/STEM.yaml +0 -56
- sdg_hub/configs/skills/icl_examples/__init__.py +0 -0
- sdg_hub/configs/skills/icl_examples/coding.yaml +0 -97
- sdg_hub/configs/skills/icl_examples/extraction.yaml +0 -36
- sdg_hub/configs/skills/icl_examples/humanities.yaml +0 -71
- sdg_hub/configs/skills/icl_examples/math.yaml +0 -85
- sdg_hub/configs/skills/icl_examples/reasoning.yaml +0 -30
- sdg_hub/configs/skills/icl_examples/roleplay.yaml +0 -45
- sdg_hub/configs/skills/icl_examples/writing.yaml +0 -80
- sdg_hub/configs/skills/judge.yaml +0 -53
- sdg_hub/configs/skills/planner.yaml +0 -67
- sdg_hub/configs/skills/respond.yaml +0 -8
- sdg_hub/configs/skills/revised_responder.yaml +0 -78
- sdg_hub/configs/skills/router.yaml +0 -59
- sdg_hub/configs/skills/simple_generate_qa_freeform.yaml +0 -27
- sdg_hub/configs/skills/simple_generate_qa_grounded.yaml +0 -31
- sdg_hub/flow.py +0 -477
- sdg_hub/flow_runner.py +0 -450
- sdg_hub/flows/generation/knowledge/mmlu_bench.yaml +0 -13
- sdg_hub/flows/generation/knowledge/simple_knowledge.yaml +0 -12
- sdg_hub/flows/generation/knowledge/synth_knowledge.yaml +0 -89
- sdg_hub/flows/generation/knowledge/synth_knowledge1.5.yaml +0 -136
- sdg_hub/flows/generation/skills/improve_responses.yaml +0 -103
- sdg_hub/flows/generation/skills/simple_freeform_skill.yaml +0 -12
- sdg_hub/flows/generation/skills/simple_grounded_skill.yaml +0 -12
- sdg_hub/flows/generation/skills/synth_grounded_skills.yaml +0 -80
- sdg_hub/flows/generation/skills/synth_skills.yaml +0 -59
- sdg_hub/pipeline.py +0 -121
- sdg_hub/prompts.py +0 -80
- sdg_hub/registry.py +0 -122
- sdg_hub/sdg.py +0 -206
- sdg_hub/utils/config_validation.py +0 -91
- sdg_hub/utils/error_handling.py +0 -94
- sdg_hub/utils/validation_result.py +0 -10
- sdg_hub-0.1.4.dist-info/METADATA +0 -190
- sdg_hub-0.1.4.dist-info/RECORD +0 -89
- sdg_hub/{logger_config.py → core/utils/logger_config.py} +1 -1
- /sdg_hub/{configs/__init__.py → flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/README.md} +0 -0
- /sdg_hub/{configs/annotations → flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab}/__init__.py +0 -0
- {sdg_hub-0.1.4.dist-info → sdg_hub-0.2.0.dist-info}/WHEEL +0 -0
- {sdg_hub-0.1.4.dist-info → sdg_hub-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {sdg_hub-0.1.4.dist-info → sdg_hub-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,556 +0,0 @@
|
|
1
|
-
# SPDX-License-Identifier: Apache-2.0
|
2
|
-
"""OpenAI-specific blocks for text generation.
|
3
|
-
|
4
|
-
This module provides blocks for interacting with OpenAI's Chat Completions API.
|
5
|
-
"""
|
6
|
-
|
7
|
-
# Standard
|
8
|
-
from typing import Any, Dict, List, Optional, Union
|
9
|
-
import asyncio
|
10
|
-
|
11
|
-
# Third Party
|
12
|
-
from datasets import Dataset
|
13
|
-
from tenacity import (
|
14
|
-
retry,
|
15
|
-
retry_if_exception_type,
|
16
|
-
stop_after_attempt,
|
17
|
-
wait_random_exponential,
|
18
|
-
)
|
19
|
-
import openai
|
20
|
-
|
21
|
-
# Local
|
22
|
-
from ..logger_config import setup_logger
|
23
|
-
from ..registry import BlockRegistry
|
24
|
-
from .block import Block
|
25
|
-
|
26
|
-
# Module-level logger built by the package's ``setup_logger`` helper
# (imported from ``..logger_config``); shared by both chat blocks below.
logger = setup_logger(__name__)
|
27
|
-
|
28
|
-
|
29
|
-
@BlockRegistry.register("OpenAIChatBlock")
class OpenAIChatBlock(Block):
    """Block for generating text using OpenAI Chat Completions API.

    This block takes a column containing OpenAI message format and makes
    direct (synchronous) calls to the chat completions endpoint, one request
    per row.

    Parameters
    ----------
    block_name : str
        Name of the block.
    input_cols : Union[str, List[str]]
        Input column name(s). Should contain the messages list. Exactly one
        column is required.
    output_cols : Union[str, List[str]]
        Output column name(s) for the response. Exactly one column is required.
    client : openai.OpenAI
        OpenAI client instance.
    model_id : str
        Model ID to use.

    ### Text-relevant OpenAI Chat Completions API parameters ###

    frequency_penalty : Optional[float], optional
        Penalize frequent tokens (-2.0 to 2.0).
    logit_bias : Optional[Dict[str, int]], optional
        Modify likelihood of specified tokens.
    logprobs : Optional[bool], optional
        Whether to return log probabilities.
    max_completion_tokens : Optional[int], optional
        Maximum tokens in completion.
    max_tokens : Optional[int], optional
        Maximum tokens in completion (legacy).
    n : Optional[int], optional
        Number of completions to generate. Only the first choice is stored.
    presence_penalty : Optional[float], optional
        Penalize repeated tokens (-2.0 to 2.0).
    response_format : Optional[Dict[str, Any]], optional
        Response format specification (e.g., JSON mode).
    seed : Optional[int], optional
        Seed for deterministic outputs.
    stop : Optional[Union[str, List[str]]], optional
        Stop sequences.
    stream : Optional[bool], optional
        Whether to stream responses. When enabled, the streamed text deltas
        of the first choice are concatenated into one string per row.
    temperature : Optional[float], optional
        Sampling temperature (0.0 to 2.0).
    tool_choice : Optional[Union[str, Dict[str, Any]]], optional
        Tool selection strategy.
    tools : Optional[List[Dict[str, Any]]], optional
        Available tools for function calling.
    top_logprobs : Optional[int], optional
        Number of top log probabilities to return.
    top_p : Optional[float], optional
        Nucleus sampling parameter (0.0 to 1.0).
    user : Optional[str], optional
        End-user identifier.
    extra_body : Optional[dict], optional
        Dictionary of additional parameters if supported by inference backend

    Raises
    ------
    ValueError
        If more than one input column or output column is given.
    """

    # Parameters accepted both by ``__init__`` and as runtime overrides in
    # ``generate``. Keeping one set avoids the two lists drifting apart.
    _VALID_PARAMS = frozenset(
        {
            "frequency_penalty",
            "logit_bias",
            "logprobs",
            "max_completion_tokens",
            "max_tokens",
            "n",
            "presence_penalty",
            "response_format",
            "seed",
            "stop",
            "stream",
            "temperature",
            "tool_choice",
            "tools",
            "top_logprobs",
            "top_p",
            "user",
            "extra_body",
        }
    )

    # Subset of parameters echoed in the "Starting generation" log record.
    _LOGGED_PARAMS = (
        "temperature",
        "max_tokens",
        "max_completion_tokens",
        "top_p",
        "n",
        "seed",
    )

    def __init__(
        self,
        block_name: str,
        input_cols: Union[str, List[str]],
        output_cols: Union[str, List[str]],
        client: openai.OpenAI,
        model_id: str,
        # Text-relevant OpenAI Chat Completions API parameters
        frequency_penalty: Optional[float] = None,
        logit_bias: Optional[Dict[str, int]] = None,
        logprobs: Optional[bool] = None,
        max_completion_tokens: Optional[int] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        response_format: Optional[Dict[str, Any]] = None,
        seed: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        stream: Optional[bool] = None,
        temperature: Optional[float] = None,
        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        top_logprobs: Optional[int] = None,
        top_p: Optional[float] = None,
        user: Optional[str] = None,
        extra_body: Optional[dict] = None,
    ) -> None:
        super().__init__(block_name)
        self.input_cols = [input_cols] if isinstance(input_cols, str) else input_cols
        self.output_cols = (
            [output_cols] if isinstance(output_cols, str) else output_cols
        )
        self.client = client
        self.model_id = model_id

        # For this block, we expect exactly one input column (messages) and one output column
        if len(self.input_cols) != 1:
            raise ValueError("OpenAIChatBlock expects exactly one input column")
        if len(self.output_cols) != 1:
            raise ValueError("OpenAIChatBlock expects exactly one output column")

        self.messages_column = self.input_cols[0]
        self.output_column = self.output_cols[0]

        params = {
            "frequency_penalty": frequency_penalty,
            "logit_bias": logit_bias,
            "logprobs": logprobs,
            "max_completion_tokens": max_completion_tokens,
            "max_tokens": max_tokens,
            "n": n,
            "presence_penalty": presence_penalty,
            "response_format": response_format,
            "seed": seed,
            "stop": stop,
            "stream": stream,
            "temperature": temperature,
            "tool_choice": tool_choice,
            "tools": tools,
            "top_logprobs": top_logprobs,
            "top_p": top_p,
            "user": user,
            "extra_body": extra_body,
        }
        # Only non-None values are forwarded so the API applies its own defaults.
        self.gen_kwargs = {key: value for key, value in params.items() if value is not None}

        # Log initialization with model and parameters
        logger.info(
            f"Initialized OpenAIChatBlock '{block_name}' with model '{model_id}'",
            extra={
                "block_name": block_name,
                "model_id": model_id,
                "generation_params": self.gen_kwargs,
            },
        )

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        retry=retry_if_exception_type(
            (openai.RateLimitError, openai.APITimeoutError, openai.APIConnectionError)
        ),
    )
    def _create_completion_with_retry(self, **kwargs: Any) -> Any:
        """Create completion with retry logic.

        Retries up to 6 attempts with random exponential backoff (1-60s) on
        rate-limit, timeout and connection errors; other errors propagate.
        """
        return self.client.chat.completions.create(**kwargs)

    @staticmethod
    def _response_to_text(response: Any, stream: bool) -> Optional[str]:
        """Return the first choice's text from a completion or a chunk stream.

        With ``stream`` the client returns an iterable of chunks instead of a
        completion object; the text deltas of choice 0 are concatenated.
        """
        if not stream:
            return response.choices[0].message.content

        parts: List[str] = []
        for chunk in response:
            # Some chunks (e.g. trailing usage chunks) carry no choices.
            for choice in chunk.choices or []:
                delta = getattr(choice, "delta", None)
                content = getattr(delta, "content", None) if delta is not None else None
                if choice.index == 0 and content:
                    parts.append(content)
        return "".join(parts)

    def generate(self, samples: Dataset, **override_kwargs: Any) -> Dataset:
        """Generate the output from the block.

        Parameters set at runtime override those set during initialization.
        Supports all text-relevant OpenAI Chat Completions API parameters.

        Parameters
        ----------
        samples : Dataset
            Input dataset containing the messages column.
        **override_kwargs : Any
            Runtime parameters that override initialization defaults.
            Valid parameters: frequency_penalty, logit_bias, logprobs,
            max_completion_tokens, max_tokens, n, presence_penalty,
            response_format, seed, stop, stream, temperature, tool_choice,
            tools, top_logprobs, top_p, user, extra_body.
            Unknown names are ignored with a warning.

        Returns
        -------
        Dataset
            Dataset with the response added to the output column.
        """
        valid_params = self._VALID_PARAMS

        # Filter and validate override parameters
        filtered_kwargs = {
            k: v for k, v in override_kwargs.items() if k in valid_params
        }

        # Warn about invalid parameters
        invalid_params = set(override_kwargs.keys()) - valid_params
        if invalid_params:
            logger.warning(f"Ignoring invalid parameters: {invalid_params}")

        # Merge kwargs with priority: runtime > init > defaults
        final_kwargs = {**self.gen_kwargs, **filtered_kwargs}
        final_kwargs["model"] = self.model_id

        # Extract all messages
        messages_list = samples[self.messages_column]

        # Log generation start with model and effective parameters
        logger.info(
            f"Starting generation for {len(messages_list)} samples",
            extra={
                "block_name": self.block_name,
                "model_id": self.model_id,
                "batch_size": len(messages_list),
                "effective_params": {
                    k: v for k, v in final_kwargs.items() if k in self._LOGGED_PARAMS
                },
            },
        )

        is_streaming = bool(final_kwargs.get("stream"))

        # Requests are issued sequentially, one per row.
        responses = []
        for messages in messages_list:
            response = self._create_completion_with_retry(
                messages=messages, **final_kwargs
            )
            responses.append(self._response_to_text(response, is_streaming))

        # Log completion
        logger.info(
            f"Generation completed successfully for {len(responses)} samples",
            extra={
                "block_name": self.block_name,
                "model_id": self.model_id,
                "batch_size": len(responses),
            },
        )

        # Add responses as new column
        return samples.add_column(self.output_column, responses)
|
286
|
-
|
287
|
-
|
288
|
-
@BlockRegistry.register("OpenAIAsyncChatBlock")
class OpenAIAsyncChatBlock(Block):
    """Async block for generating text using OpenAI Chat Completions API.

    This block takes a column containing OpenAI message format and makes
    asynchronous calls to the chat completions endpoint for better performance.
    All rows are submitted concurrently via ``asyncio.gather``.

    Parameters
    ----------
    block_name : str
        Name of the block.
    input_cols : Union[str, List[str]]
        Input column name(s). Should contain the messages list. Exactly one
        column is required.
    output_cols : Union[str, List[str]]
        Output column name(s) for the response. Exactly one column is required.
    async_client : openai.AsyncOpenAI
        Async OpenAI client instance.
    model_id : str
        Model ID to use.

    ### Text-relevant OpenAI Chat Completions API parameters ###

    frequency_penalty : Optional[float], optional
        Penalize frequent tokens (-2.0 to 2.0).
    logit_bias : Optional[Dict[str, int]], optional
        Modify likelihood of specified tokens.
    logprobs : Optional[bool], optional
        Whether to return log probabilities.
    max_completion_tokens : Optional[int], optional
        Maximum tokens in completion.
    max_tokens : Optional[int], optional
        Maximum tokens in completion (legacy).
    n : Optional[int], optional
        Number of completions to generate. Only the first choice is stored.
    presence_penalty : Optional[float], optional
        Penalize repeated tokens (-2.0 to 2.0).
    response_format : Optional[Dict[str, Any]], optional
        Response format specification (e.g., JSON mode).
    seed : Optional[int], optional
        Seed for deterministic outputs.
    stop : Optional[Union[str, List[str]]], optional
        Stop sequences.
    stream : Optional[bool], optional
        Whether to stream responses. When enabled, the streamed text deltas
        of the first choice are concatenated into one string per row.
    temperature : Optional[float], optional
        Sampling temperature (0.0 to 2.0).
    tool_choice : Optional[Union[str, Dict[str, Any]]], optional
        Tool selection strategy.
    tools : Optional[List[Dict[str, Any]]], optional
        Available tools for function calling.
    top_logprobs : Optional[int], optional
        Number of top log probabilities to return.
    top_p : Optional[float], optional
        Nucleus sampling parameter (0.0 to 1.0).
    user : Optional[str], optional
        End-user identifier.
    extra_body : Optional[dict], optional
        Dictionary of additional parameters if supported by inference backend

    Raises
    ------
    ValueError
        If more than one input column or output column is given.
    """

    # Parameters accepted both by ``__init__`` and as runtime overrides in
    # ``generate``. Previously ``extra_body`` was missing from the runtime
    # list, so overriding it at generate time was silently ignored.
    _VALID_PARAMS = frozenset(
        {
            "frequency_penalty",
            "logit_bias",
            "logprobs",
            "max_completion_tokens",
            "max_tokens",
            "n",
            "presence_penalty",
            "response_format",
            "seed",
            "stop",
            "stream",
            "temperature",
            "tool_choice",
            "tools",
            "top_logprobs",
            "top_p",
            "user",
            "extra_body",
        }
    )

    # Subset of parameters echoed in the "Starting async generation" log record.
    _LOGGED_PARAMS = (
        "temperature",
        "max_tokens",
        "max_completion_tokens",
        "top_p",
        "n",
        "seed",
    )

    def __init__(
        self,
        block_name: str,
        input_cols: Union[str, List[str]],
        output_cols: Union[str, List[str]],
        async_client: openai.AsyncOpenAI,
        model_id: str,
        # Text-relevant OpenAI Chat Completions API parameters
        frequency_penalty: Optional[float] = None,
        logit_bias: Optional[Dict[str, int]] = None,
        logprobs: Optional[bool] = None,
        max_completion_tokens: Optional[int] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        response_format: Optional[Dict[str, Any]] = None,
        seed: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        stream: Optional[bool] = None,
        temperature: Optional[float] = None,
        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        top_logprobs: Optional[int] = None,
        top_p: Optional[float] = None,
        user: Optional[str] = None,
        extra_body: Optional[dict] = None,
    ) -> None:
        super().__init__(block_name)
        self.input_cols = [input_cols] if isinstance(input_cols, str) else input_cols
        self.output_cols = (
            [output_cols] if isinstance(output_cols, str) else output_cols
        )
        self.async_client = async_client
        self.model_id = model_id

        # For this block, we expect exactly one input column (messages) and one output column
        if len(self.input_cols) != 1:
            raise ValueError("OpenAIAsyncChatBlock expects exactly one input column")
        if len(self.output_cols) != 1:
            raise ValueError("OpenAIAsyncChatBlock expects exactly one output column")

        self.messages_column = self.input_cols[0]
        self.output_column = self.output_cols[0]

        params = {
            "frequency_penalty": frequency_penalty,
            "logit_bias": logit_bias,
            "logprobs": logprobs,
            "max_completion_tokens": max_completion_tokens,
            "max_tokens": max_tokens,
            "n": n,
            "presence_penalty": presence_penalty,
            "response_format": response_format,
            "seed": seed,
            "stop": stop,
            "stream": stream,
            "temperature": temperature,
            "tool_choice": tool_choice,
            "tools": tools,
            "top_logprobs": top_logprobs,
            "top_p": top_p,
            "user": user,
            "extra_body": extra_body,
        }
        # Only non-None values are forwarded so the API applies its own defaults.
        self.gen_kwargs = {key: value for key, value in params.items() if value is not None}

        # Log initialization with model and parameters
        logger.info(
            f"Initialized OpenAIAsyncChatBlock '{block_name}' with model '{model_id}'",
            extra={
                "block_name": block_name,
                "model_id": model_id,
                "generation_params": self.gen_kwargs,
            },
        )

    @staticmethod
    async def _collect_stream_text(stream: Any) -> str:
        """Concatenate the text deltas of choice 0 from an async chunk stream."""
        parts: List[str] = []
        async for chunk in stream:
            # Some chunks (e.g. trailing usage chunks) carry no choices.
            for choice in chunk.choices or []:
                delta = getattr(choice, "delta", None)
                content = getattr(delta, "content", None) if delta is not None else None
                if choice.index == 0 and content:
                    parts.append(content)
        return "".join(parts)

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        retry=retry_if_exception_type(
            (openai.RateLimitError, openai.APITimeoutError, openai.APIConnectionError)
        ),
    )
    async def _generate_single(
        self, messages: List[Dict[str, Any]], **final_kwargs: Any
    ) -> Optional[str]:
        """Generate a single response asynchronously.

        Retries up to 6 attempts with random exponential backoff (1-60s) on
        rate-limit, timeout and connection errors. When ``stream`` is set the
        whole stream is consumed inside the retried call, so a connection drop
        mid-stream restarts the request.
        """
        response = await self.async_client.chat.completions.create(
            messages=messages, **final_kwargs
        )
        if final_kwargs.get("stream"):
            return await self._collect_stream_text(response)
        return response.choices[0].message.content

    def generate(self, samples: Dataset, **override_kwargs: Any) -> Dataset:
        """Generate the output from the block using async calls.

        Parameters set at runtime override those set during initialization.
        Supports all text-relevant OpenAI Chat Completions API parameters.

        Note: this drives the coroutines with ``asyncio.run``, which raises
        ``RuntimeError`` if called from a thread that already has a running
        event loop (e.g. inside another coroutine).

        Parameters
        ----------
        samples : Dataset
            Input dataset containing the messages column.
        **override_kwargs : Any
            Runtime parameters that override initialization defaults.
            Valid parameters: frequency_penalty, logit_bias, logprobs,
            max_completion_tokens, max_tokens, n, presence_penalty,
            response_format, seed, stop, stream, temperature, tool_choice,
            tools, top_logprobs, top_p, user, extra_body.
            Unknown names are ignored with a warning.

        Returns
        -------
        Dataset
            Dataset with the response added to the output column.
        """
        valid_params = self._VALID_PARAMS

        # Filter and validate override parameters
        filtered_kwargs = {
            k: v for k, v in override_kwargs.items() if k in valid_params
        }

        # Warn about invalid parameters
        invalid_params = set(override_kwargs.keys()) - valid_params
        if invalid_params:
            logger.warning(f"Ignoring invalid parameters: {invalid_params}")

        # Merge kwargs with priority: runtime > init > defaults
        final_kwargs = {**self.gen_kwargs, **filtered_kwargs}
        final_kwargs["model"] = self.model_id

        # Log generation start with model and effective parameters
        logger.info(
            f"Starting async generation for {len(samples)} samples",
            extra={
                "block_name": self.block_name,
                "model_id": self.model_id,
                "batch_size": len(samples),
                "effective_params": {
                    k: v for k, v in final_kwargs.items() if k in self._LOGGED_PARAMS
                },
            },
        )

        # Run async generation
        return asyncio.run(self._generate_async(samples, final_kwargs))

    async def _generate_async(
        self, samples: Dataset, final_kwargs: Dict[str, Any]
    ) -> Dataset:
        """Internal async method to generate all responses concurrently.

        ``asyncio.gather`` preserves input order, so response ``i`` matches
        row ``i``; the first exception that escapes retries aborts the batch.
        """
        # Extract all messages
        messages_list = samples[self.messages_column]

        # Create all tasks
        tasks = [
            self._generate_single(messages, **final_kwargs)
            for messages in messages_list
        ]

        # Execute all tasks concurrently and collect responses
        responses = await asyncio.gather(*tasks)

        # Log completion
        logger.info(
            f"Async generation completed successfully for {len(responses)} samples",
            extra={
                "block_name": self.block_name,
                "model_id": final_kwargs["model"],
                "batch_size": len(responses),
            },
        )

        # Add responses as new column
        return samples.add_column(self.output_column, responses)
|