sdg-hub 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (145)
  1. sdg_hub/__init__.py +28 -1
  2. sdg_hub/_version.py +2 -2
  3. sdg_hub/core/__init__.py +22 -0
  4. sdg_hub/core/blocks/__init__.py +58 -0
  5. sdg_hub/core/blocks/base.py +313 -0
  6. sdg_hub/core/blocks/deprecated_blocks/__init__.py +29 -0
  7. sdg_hub/core/blocks/deprecated_blocks/combine_columns.py +93 -0
  8. sdg_hub/core/blocks/deprecated_blocks/duplicate_columns.py +88 -0
  9. sdg_hub/core/blocks/deprecated_blocks/filter_by_value.py +103 -0
  10. sdg_hub/core/blocks/deprecated_blocks/flatten_columns.py +94 -0
  11. sdg_hub/core/blocks/deprecated_blocks/llmblock.py +479 -0
  12. sdg_hub/core/blocks/deprecated_blocks/rename_columns.py +88 -0
  13. sdg_hub/core/blocks/deprecated_blocks/sample_populator.py +58 -0
  14. sdg_hub/core/blocks/deprecated_blocks/selector.py +97 -0
  15. sdg_hub/core/blocks/deprecated_blocks/set_to_majority_value.py +88 -0
  16. sdg_hub/core/blocks/evaluation/__init__.py +9 -0
  17. sdg_hub/core/blocks/evaluation/evaluate_faithfulness_block.py +564 -0
  18. sdg_hub/core/blocks/evaluation/evaluate_relevancy_block.py +564 -0
  19. sdg_hub/core/blocks/evaluation/verify_question_block.py +564 -0
  20. sdg_hub/core/blocks/filtering/__init__.py +12 -0
  21. sdg_hub/core/blocks/filtering/column_value_filter.py +188 -0
  22. sdg_hub/core/blocks/llm/__init__.py +27 -0
  23. sdg_hub/core/blocks/llm/client_manager.py +398 -0
  24. sdg_hub/core/blocks/llm/config.py +336 -0
  25. sdg_hub/core/blocks/llm/error_handler.py +368 -0
  26. sdg_hub/core/blocks/llm/llm_chat_block.py +542 -0
  27. sdg_hub/core/blocks/llm/llm_chat_with_parsing_retry_block.py +491 -0
  28. sdg_hub/core/blocks/llm/prompt_builder_block.py +368 -0
  29. sdg_hub/core/blocks/llm/text_parser_block.py +357 -0
  30. sdg_hub/core/blocks/registry.py +331 -0
  31. sdg_hub/core/blocks/transform/__init__.py +23 -0
  32. sdg_hub/core/blocks/transform/duplicate_columns.py +88 -0
  33. sdg_hub/core/blocks/transform/index_based_mapper.py +225 -0
  34. sdg_hub/core/blocks/transform/melt_columns.py +126 -0
  35. sdg_hub/core/blocks/transform/rename_columns.py +69 -0
  36. sdg_hub/core/blocks/transform/text_concat.py +102 -0
  37. sdg_hub/core/blocks/transform/uniform_col_val_setter.py +101 -0
  38. sdg_hub/core/flow/__init__.py +20 -0
  39. sdg_hub/core/flow/base.py +1209 -0
  40. sdg_hub/core/flow/checkpointer.py +333 -0
  41. sdg_hub/core/flow/metadata.py +389 -0
  42. sdg_hub/core/flow/migration.py +198 -0
  43. sdg_hub/core/flow/registry.py +393 -0
  44. sdg_hub/core/flow/validation.py +277 -0
  45. sdg_hub/{utils → core/utils}/__init__.py +7 -4
  46. sdg_hub/core/utils/datautils.py +63 -0
  47. sdg_hub/core/utils/error_handling.py +208 -0
  48. sdg_hub/core/utils/flow_id_words.yaml +231 -0
  49. sdg_hub/core/utils/flow_identifier.py +94 -0
  50. sdg_hub/{utils → core/utils}/path_resolution.py +2 -2
  51. sdg_hub/core/utils/yaml_utils.py +59 -0
  52. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/atomic_facts.yaml +40 -0
  53. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/detailed_summary.yaml +13 -0
  54. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_faithfulness.yaml +64 -0
  55. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_question.yaml +29 -0
  56. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_relevancy.yaml +81 -0
  57. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/extractive_summary.yaml +13 -0
  58. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/flow.yaml +192 -0
  59. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/generate_questions_responses.yaml +54 -0
  60. sdg_hub-0.2.1.dist-info/METADATA +221 -0
  61. sdg_hub-0.2.1.dist-info/RECORD +68 -0
  62. sdg_hub/blocks/__init__.py +0 -42
  63. sdg_hub/blocks/block.py +0 -96
  64. sdg_hub/blocks/llmblock.py +0 -375
  65. sdg_hub/blocks/openaichatblock.py +0 -556
  66. sdg_hub/blocks/utilblocks.py +0 -597
  67. sdg_hub/checkpointer.py +0 -139
  68. sdg_hub/configs/annotations/cot_reflection.yaml +0 -34
  69. sdg_hub/configs/annotations/detailed_annotations.yaml +0 -28
  70. sdg_hub/configs/annotations/detailed_description.yaml +0 -10
  71. sdg_hub/configs/annotations/detailed_description_icl.yaml +0 -32
  72. sdg_hub/configs/annotations/simple_annotations.yaml +0 -9
  73. sdg_hub/configs/knowledge/__init__.py +0 -0
  74. sdg_hub/configs/knowledge/atomic_facts.yaml +0 -46
  75. sdg_hub/configs/knowledge/auxilary_instructions.yaml +0 -35
  76. sdg_hub/configs/knowledge/detailed_summary.yaml +0 -18
  77. sdg_hub/configs/knowledge/evaluate_faithfulness.yaml +0 -68
  78. sdg_hub/configs/knowledge/evaluate_question.yaml +0 -38
  79. sdg_hub/configs/knowledge/evaluate_relevancy.yaml +0 -84
  80. sdg_hub/configs/knowledge/extractive_summary.yaml +0 -18
  81. sdg_hub/configs/knowledge/generate_code_questions_responses.yaml +0 -39
  82. sdg_hub/configs/knowledge/generate_questions.yaml +0 -82
  83. sdg_hub/configs/knowledge/generate_questions_responses.yaml +0 -56
  84. sdg_hub/configs/knowledge/generate_responses.yaml +0 -86
  85. sdg_hub/configs/knowledge/mcq_generation.yaml +0 -83
  86. sdg_hub/configs/knowledge/router.yaml +0 -12
  87. sdg_hub/configs/knowledge/simple_generate_qa.yaml +0 -34
  88. sdg_hub/configs/reasoning/__init__.py +0 -0
  89. sdg_hub/configs/reasoning/dynamic_cot.yaml +0 -40
  90. sdg_hub/configs/skills/__init__.py +0 -0
  91. sdg_hub/configs/skills/analyzer.yaml +0 -48
  92. sdg_hub/configs/skills/annotation.yaml +0 -36
  93. sdg_hub/configs/skills/contexts.yaml +0 -28
  94. sdg_hub/configs/skills/critic.yaml +0 -60
  95. sdg_hub/configs/skills/evaluate_freeform_pair.yaml +0 -111
  96. sdg_hub/configs/skills/evaluate_freeform_questions.yaml +0 -78
  97. sdg_hub/configs/skills/evaluate_grounded_pair.yaml +0 -119
  98. sdg_hub/configs/skills/evaluate_grounded_questions.yaml +0 -51
  99. sdg_hub/configs/skills/freeform_questions.yaml +0 -34
  100. sdg_hub/configs/skills/freeform_responses.yaml +0 -39
  101. sdg_hub/configs/skills/grounded_questions.yaml +0 -38
  102. sdg_hub/configs/skills/grounded_responses.yaml +0 -59
  103. sdg_hub/configs/skills/icl_examples/STEM.yaml +0 -56
  104. sdg_hub/configs/skills/icl_examples/__init__.py +0 -0
  105. sdg_hub/configs/skills/icl_examples/coding.yaml +0 -97
  106. sdg_hub/configs/skills/icl_examples/extraction.yaml +0 -36
  107. sdg_hub/configs/skills/icl_examples/humanities.yaml +0 -71
  108. sdg_hub/configs/skills/icl_examples/math.yaml +0 -85
  109. sdg_hub/configs/skills/icl_examples/reasoning.yaml +0 -30
  110. sdg_hub/configs/skills/icl_examples/roleplay.yaml +0 -45
  111. sdg_hub/configs/skills/icl_examples/writing.yaml +0 -80
  112. sdg_hub/configs/skills/judge.yaml +0 -53
  113. sdg_hub/configs/skills/planner.yaml +0 -67
  114. sdg_hub/configs/skills/respond.yaml +0 -8
  115. sdg_hub/configs/skills/revised_responder.yaml +0 -78
  116. sdg_hub/configs/skills/router.yaml +0 -59
  117. sdg_hub/configs/skills/simple_generate_qa_freeform.yaml +0 -27
  118. sdg_hub/configs/skills/simple_generate_qa_grounded.yaml +0 -31
  119. sdg_hub/flow.py +0 -477
  120. sdg_hub/flow_runner.py +0 -450
  121. sdg_hub/flows/generation/knowledge/mmlu_bench.yaml +0 -13
  122. sdg_hub/flows/generation/knowledge/simple_knowledge.yaml +0 -12
  123. sdg_hub/flows/generation/knowledge/synth_knowledge.yaml +0 -89
  124. sdg_hub/flows/generation/knowledge/synth_knowledge1.5.yaml +0 -136
  125. sdg_hub/flows/generation/skills/improve_responses.yaml +0 -103
  126. sdg_hub/flows/generation/skills/simple_freeform_skill.yaml +0 -12
  127. sdg_hub/flows/generation/skills/simple_grounded_skill.yaml +0 -12
  128. sdg_hub/flows/generation/skills/synth_grounded_skills.yaml +0 -80
  129. sdg_hub/flows/generation/skills/synth_skills.yaml +0 -59
  130. sdg_hub/pipeline.py +0 -121
  131. sdg_hub/prompts.py +0 -80
  132. sdg_hub/registry.py +0 -122
  133. sdg_hub/sdg.py +0 -206
  134. sdg_hub/utils/config_validation.py +0 -91
  135. sdg_hub/utils/datautils.py +0 -14
  136. sdg_hub/utils/error_handling.py +0 -94
  137. sdg_hub/utils/validation_result.py +0 -10
  138. sdg_hub-0.1.4.dist-info/METADATA +0 -190
  139. sdg_hub-0.1.4.dist-info/RECORD +0 -89
  140. sdg_hub/{logger_config.py → core/utils/logger_config.py} +1 -1
  141. /sdg_hub/{configs/__init__.py → flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/README.md} +0 -0
  142. /sdg_hub/{configs/annotations → flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab}/__init__.py +0 -0
  143. {sdg_hub-0.1.4.dist-info → sdg_hub-0.2.1.dist-info}/WHEEL +0 -0
  144. {sdg_hub-0.1.4.dist-info → sdg_hub-0.2.1.dist-info}/licenses/LICENSE +0 -0
  145. {sdg_hub-0.1.4.dist-info → sdg_hub-0.2.1.dist-info}/top_level.txt +0 -0
@@ -1,556 +0,0 @@
1
- # SPDX-License-Identifier: Apache-2.0
2
- """OpenAI-specific blocks for text generation.
3
-
4
- This module provides blocks for interacting with OpenAI's Chat Completions API.
5
- """
6
-
7
- # Standard
8
- from typing import Any, Dict, List, Optional, Union
9
- import asyncio
10
-
11
- # Third Party
12
- from datasets import Dataset
13
- from tenacity import (
14
- retry,
15
- retry_if_exception_type,
16
- stop_after_attempt,
17
- wait_random_exponential,
18
- )
19
- import openai
20
-
21
- # Local
22
- from ..logger_config import setup_logger
23
- from ..registry import BlockRegistry
24
- from .block import Block
25
-
26
- logger = setup_logger(__name__)
27
-
28
-
29
- @BlockRegistry.register("OpenAIChatBlock")
30
- class OpenAIChatBlock(Block):
31
- """Block for generating text using OpenAI Chat Completions API.
32
-
33
- This block takes a column containing OpenAI message format and makes
34
- direct calls to the chat completions endpoint.
35
-
36
- Parameters
37
- ----------
38
- block_name : str
39
- Name of the block.
40
- input_cols : Union[str, List[str]]
41
- Input column name(s). Should contain the messages list.
42
- output_cols : Union[str, List[str]]
43
- Output column name(s) for the response.
44
- client : openai.OpenAI
45
- OpenAI client instance.
46
- model_id : str
47
- Model ID to use.
48
-
49
- ### Text-relevant OpenAI Chat Completions API parameters ###
50
-
51
- frequency_penalty : Optional[float], optional
52
- Penalize frequent tokens (-2.0 to 2.0).
53
- logit_bias : Optional[Dict[str, int]], optional
54
- Modify likelihood of specified tokens.
55
- logprobs : Optional[bool], optional
56
- Whether to return log probabilities.
57
- max_completion_tokens : Optional[int], optional
58
- Maximum tokens in completion.
59
- max_tokens : Optional[int], optional
60
- Maximum tokens in completion (legacy).
61
- n : Optional[int], optional
62
- Number of completions to generate.
63
- presence_penalty : Optional[float], optional
64
- Penalize repeated tokens (-2.0 to 2.0).
65
- response_format : Optional[Dict[str, Any]], optional
66
- Response format specification (e.g., JSON mode).
67
- seed : Optional[int], optional
68
- Seed for deterministic outputs.
69
- stop : Optional[Union[str, List[str]]], optional
70
- Stop sequences.
71
- stream : Optional[bool], optional
72
- Whether to stream responses.
73
- temperature : Optional[float], optional
74
- Sampling temperature (0.0 to 2.0).
75
- tool_choice : Optional[Union[str, Dict[str, Any]]], optional
76
- Tool selection strategy.
77
- tools : Optional[List[Dict[str, Any]]], optional
78
- Available tools for function calling.
79
- top_logprobs : Optional[int], optional
80
- Number of top log probabilities to return.
81
- top_p : Optional[float], optional
82
- Nucleus sampling parameter (0.0 to 1.0).
83
- user : Optional[str], optional
84
- End-user identifier.
85
- extra_body : Optional[dict], optional
86
- Dictionary of additional parameters if supported by inference backend
87
- """
88
-
89
- def __init__(
90
- self,
91
- block_name: str,
92
- input_cols: Union[str, List[str]],
93
- output_cols: Union[str, List[str]],
94
- client: openai.OpenAI,
95
- model_id: str,
96
- # Text-relevant OpenAI Chat Completions API parameters
97
- frequency_penalty: Optional[float] = None,
98
- logit_bias: Optional[Dict[str, int]] = None,
99
- logprobs: Optional[bool] = None,
100
- max_completion_tokens: Optional[int] = None,
101
- max_tokens: Optional[int] = None,
102
- n: Optional[int] = None,
103
- presence_penalty: Optional[float] = None,
104
- response_format: Optional[Dict[str, Any]] = None,
105
- seed: Optional[int] = None,
106
- stop: Optional[Union[str, List[str]]] = None,
107
- stream: Optional[bool] = None,
108
- temperature: Optional[float] = None,
109
- tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
110
- tools: Optional[List[Dict[str, Any]]] = None,
111
- top_logprobs: Optional[int] = None,
112
- top_p: Optional[float] = None,
113
- user: Optional[str] = None,
114
- extra_body: Optional[dict] = None,
115
- ) -> None:
116
- super().__init__(block_name)
117
- self.input_cols = [input_cols] if isinstance(input_cols, str) else input_cols
118
- self.output_cols = (
119
- [output_cols] if isinstance(output_cols, str) else output_cols
120
- )
121
- self.client = client
122
- self.model_id = model_id
123
-
124
- # For this block, we expect exactly one input column (messages) and one output column
125
- if len(self.input_cols) != 1:
126
- raise ValueError("OpenAIChatBlock expects exactly one input column")
127
- if len(self.output_cols) != 1:
128
- raise ValueError("OpenAIChatBlock expects exactly one output column")
129
-
130
- self.messages_column = self.input_cols[0]
131
- self.output_column = self.output_cols[0]
132
-
133
- # Store all generation parameters (only non-None values)
134
- self.gen_kwargs = {}
135
- params = {
136
- "frequency_penalty": frequency_penalty,
137
- "logit_bias": logit_bias,
138
- "logprobs": logprobs,
139
- "max_completion_tokens": max_completion_tokens,
140
- "max_tokens": max_tokens,
141
- "n": n,
142
- "presence_penalty": presence_penalty,
143
- "response_format": response_format,
144
- "seed": seed,
145
- "stop": stop,
146
- "stream": stream,
147
- "temperature": temperature,
148
- "tool_choice": tool_choice,
149
- "tools": tools,
150
- "top_logprobs": top_logprobs,
151
- "top_p": top_p,
152
- "user": user,
153
- "extra_body": extra_body,
154
- }
155
-
156
- # Only include non-None parameters
157
- for key, value in params.items():
158
- if value is not None:
159
- self.gen_kwargs[key] = value
160
-
161
- # Log initialization with model and parameters
162
- logger.info(
163
- f"Initialized OpenAIChatBlock '{block_name}' with model '{model_id}'",
164
- extra={
165
- "block_name": block_name,
166
- "model_id": model_id,
167
- "generation_params": self.gen_kwargs,
168
- },
169
- )
170
-
171
- @retry(
172
- wait=wait_random_exponential(min=1, max=60),
173
- stop=stop_after_attempt(6),
174
- retry=retry_if_exception_type(
175
- (openai.RateLimitError, openai.APITimeoutError, openai.APIConnectionError)
176
- ),
177
- )
178
- def _create_completion_with_retry(self, **kwargs):
179
- """Create completion with retry logic."""
180
- return self.client.chat.completions.create(**kwargs)
181
-
182
- def generate(self, samples: Dataset, **override_kwargs: Dict[str, Any]) -> Dataset:
183
- """Generate the output from the block.
184
-
185
- Parameters set at runtime override those set during initialization.
186
- Supports all text-relevant OpenAI Chat Completions API parameters.
187
-
188
- Parameters
189
- ----------
190
- samples : Dataset
191
- Input dataset containing the messages column.
192
- **override_kwargs : Dict[str, Any]
193
- Runtime parameters that override initialization defaults.
194
- Valid parameters: frequency_penalty, logit_bias, logprobs,
195
- max_completion_tokens, max_tokens, n, presence_penalty,
196
- response_format, seed, stop, stream, temperature, tool_choice,
197
- tools, top_logprobs, top_p, user.
198
-
199
- Returns
200
- -------
201
- Dataset
202
- Dataset with the response added to the output column.
203
- """
204
- # Define valid parameters for validation
205
- valid_params = {
206
- "frequency_penalty",
207
- "logit_bias",
208
- "logprobs",
209
- "max_completion_tokens",
210
- "max_tokens",
211
- "n",
212
- "presence_penalty",
213
- "response_format",
214
- "seed",
215
- "stop",
216
- "stream",
217
- "temperature",
218
- "tool_choice",
219
- "tools",
220
- "top_logprobs",
221
- "top_p",
222
- "user",
223
- "extra_body",
224
- }
225
-
226
- # Filter and validate override parameters
227
- filtered_kwargs = {
228
- k: v for k, v in override_kwargs.items() if k in valid_params
229
- }
230
-
231
- # Warn about invalid parameters
232
- invalid_params = set(override_kwargs.keys()) - valid_params
233
- if invalid_params:
234
- logger.warning(f"Ignoring invalid parameters: {invalid_params}")
235
-
236
- # Merge kwargs with priority: runtime > init > defaults
237
- final_kwargs = {**self.gen_kwargs, **filtered_kwargs}
238
- final_kwargs["model"] = self.model_id
239
-
240
- # Extract all messages
241
- messages_list = samples[self.messages_column]
242
-
243
- # Log generation start with model and effective parameters
244
- logger.info(
245
- f"Starting generation for {len(messages_list)} samples",
246
- extra={
247
- "block_name": self.block_name,
248
- "model_id": self.model_id,
249
- "batch_size": len(messages_list),
250
- "effective_params": {
251
- k: v
252
- for k, v in final_kwargs.items()
253
- if k
254
- in [
255
- "temperature",
256
- "max_tokens",
257
- "max_completion_tokens",
258
- "top_p",
259
- "n",
260
- "seed",
261
- ]
262
- },
263
- },
264
- )
265
-
266
- # Get all responses
267
- responses = []
268
- for messages in messages_list:
269
- response = self._create_completion_with_retry(
270
- messages=messages, **final_kwargs
271
- )
272
- responses.append(response.choices[0].message.content)
273
-
274
- # Log completion
275
- logger.info(
276
- f"Generation completed successfully for {len(responses)} samples",
277
- extra={
278
- "block_name": self.block_name,
279
- "model_id": self.model_id,
280
- "batch_size": len(responses),
281
- },
282
- )
283
-
284
- # Add responses as new column
285
- return samples.add_column(self.output_column, responses)
286
-
287
-
288
- @BlockRegistry.register("OpenAIAsyncChatBlock")
289
- class OpenAIAsyncChatBlock(Block):
290
- """Async block for generating text using OpenAI Chat Completions API.
291
-
292
- This block takes a column containing OpenAI message format and makes
293
- asynchronous calls to the chat completions endpoint for better performance.
294
-
295
- Parameters
296
- ----------
297
- block_name : str
298
- Name of the block.
299
- input_cols : Union[str, List[str]]
300
- Input column name(s). Should contain the messages list.
301
- output_cols : Union[str, List[str]]
302
- Output column name(s) for the response.
303
- async_client : openai.AsyncOpenAI
304
- Async OpenAI client instance.
305
- model_id : str
306
- Model ID to use.
307
-
308
- ### Text-relevant OpenAI Chat Completions API parameters ###
309
-
310
- frequency_penalty : Optional[float], optional
311
- Penalize frequent tokens (-2.0 to 2.0).
312
- logit_bias : Optional[Dict[str, int]], optional
313
- Modify likelihood of specified tokens.
314
- logprobs : Optional[bool], optional
315
- Whether to return log probabilities.
316
- max_completion_tokens : Optional[int], optional
317
- Maximum tokens in completion.
318
- max_tokens : Optional[int], optional
319
- Maximum tokens in completion (legacy).
320
- n : Optional[int], optional
321
- Number of completions to generate.
322
- presence_penalty : Optional[float], optional
323
- Penalize repeated tokens (-2.0 to 2.0).
324
- response_format : Optional[Dict[str, Any]], optional
325
- Response format specification (e.g., JSON mode).
326
- seed : Optional[int], optional
327
- Seed for deterministic outputs.
328
- stop : Optional[Union[str, List[str]]], optional
329
- Stop sequences.
330
- stream : Optional[bool], optional
331
- Whether to stream responses.
332
- temperature : Optional[float], optional
333
- Sampling temperature (0.0 to 2.0).
334
- tool_choice : Optional[Union[str, Dict[str, Any]]], optional
335
- Tool selection strategy.
336
- tools : Optional[List[Dict[str, Any]]], optional
337
- Available tools for function calling.
338
- top_logprobs : Optional[int], optional
339
- Number of top log probabilities to return.
340
- top_p : Optional[float], optional
341
- Nucleus sampling parameter (0.0 to 1.0).
342
- user : Optional[str], optional
343
- End-user identifier.
344
- extra_body : Optional[dict], optional
345
- Dictionary of additional parameters if supported by inference backend
346
- """
347
-
348
- def __init__(
349
- self,
350
- block_name: str,
351
- input_cols: Union[str, List[str]],
352
- output_cols: Union[str, List[str]],
353
- async_client: openai.AsyncOpenAI,
354
- model_id: str,
355
- # Text-relevant OpenAI Chat Completions API parameters
356
- frequency_penalty: Optional[float] = None,
357
- logit_bias: Optional[Dict[str, int]] = None,
358
- logprobs: Optional[bool] = None,
359
- max_completion_tokens: Optional[int] = None,
360
- max_tokens: Optional[int] = None,
361
- n: Optional[int] = None,
362
- presence_penalty: Optional[float] = None,
363
- response_format: Optional[Dict[str, Any]] = None,
364
- seed: Optional[int] = None,
365
- stop: Optional[Union[str, List[str]]] = None,
366
- stream: Optional[bool] = None,
367
- temperature: Optional[float] = None,
368
- tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
369
- tools: Optional[List[Dict[str, Any]]] = None,
370
- top_logprobs: Optional[int] = None,
371
- top_p: Optional[float] = None,
372
- user: Optional[str] = None,
373
- extra_body: Optional[dict] = None,
374
- ) -> None:
375
- super().__init__(block_name)
376
- self.input_cols = [input_cols] if isinstance(input_cols, str) else input_cols
377
- self.output_cols = (
378
- [output_cols] if isinstance(output_cols, str) else output_cols
379
- )
380
- self.async_client = async_client
381
- self.model_id = model_id
382
-
383
- # For this block, we expect exactly one input column (messages) and one output column
384
- if len(self.input_cols) != 1:
385
- raise ValueError("OpenAIAsyncChatBlock expects exactly one input column")
386
- if len(self.output_cols) != 1:
387
- raise ValueError("OpenAIAsyncChatBlock expects exactly one output column")
388
-
389
- self.messages_column = self.input_cols[0]
390
- self.output_column = self.output_cols[0]
391
-
392
- # Store all generation parameters (only non-None values)
393
- self.gen_kwargs = {}
394
- params = {
395
- "frequency_penalty": frequency_penalty,
396
- "logit_bias": logit_bias,
397
- "logprobs": logprobs,
398
- "max_completion_tokens": max_completion_tokens,
399
- "max_tokens": max_tokens,
400
- "n": n,
401
- "presence_penalty": presence_penalty,
402
- "response_format": response_format,
403
- "seed": seed,
404
- "stop": stop,
405
- "stream": stream,
406
- "temperature": temperature,
407
- "tool_choice": tool_choice,
408
- "tools": tools,
409
- "top_logprobs": top_logprobs,
410
- "top_p": top_p,
411
- "user": user,
412
- "extra_body": extra_body,
413
- }
414
-
415
- # Only include non-None parameters
416
- for key, value in params.items():
417
- if value is not None:
418
- self.gen_kwargs[key] = value
419
-
420
- # Log initialization with model and parameters
421
- logger.info(
422
- f"Initialized OpenAIAsyncChatBlock '{block_name}' with model '{model_id}'",
423
- extra={
424
- "block_name": block_name,
425
- "model_id": model_id,
426
- "generation_params": self.gen_kwargs,
427
- },
428
- )
429
-
430
- @retry(
431
- wait=wait_random_exponential(min=1, max=60),
432
- stop=stop_after_attempt(6),
433
- retry=retry_if_exception_type(
434
- (openai.RateLimitError, openai.APITimeoutError, openai.APIConnectionError)
435
- ),
436
- )
437
- async def _generate_single(
438
- self, messages: List[Dict[str, Any]], **final_kwargs: Dict[str, Any]
439
- ) -> str:
440
- """Generate a single response asynchronously."""
441
- response = await self.async_client.chat.completions.create(
442
- messages=messages, **final_kwargs
443
- )
444
- return response.choices[0].message.content
445
-
446
- def generate(self, samples: Dataset, **override_kwargs: Dict[str, Any]) -> Dataset:
447
- """Generate the output from the block using async calls.
448
-
449
- Parameters set at runtime override those set during initialization.
450
- Supports all text-relevant OpenAI Chat Completions API parameters.
451
-
452
- Parameters
453
- ----------
454
- samples : Dataset
455
- Input dataset containing the messages column.
456
- **override_kwargs : Dict[str, Any]
457
- Runtime parameters that override initialization defaults.
458
- Valid parameters: frequency_penalty, logit_bias, logprobs,
459
- max_completion_tokens, max_tokens, n, presence_penalty,
460
- response_format, seed, stop, stream, temperature, tool_choice,
461
- tools, top_logprobs, top_p, user.
462
-
463
- Returns
464
- -------
465
- Dataset
466
- Dataset with the response added to the output column.
467
- """
468
- # Define valid parameters for validation
469
- valid_params = {
470
- "frequency_penalty",
471
- "logit_bias",
472
- "logprobs",
473
- "max_completion_tokens",
474
- "max_tokens",
475
- "n",
476
- "presence_penalty",
477
- "response_format",
478
- "seed",
479
- "stop",
480
- "stream",
481
- "temperature",
482
- "tool_choice",
483
- "tools",
484
- "top_logprobs",
485
- "top_p",
486
- "user",
487
- }
488
-
489
- # Filter and validate override parameters
490
- filtered_kwargs = {
491
- k: v for k, v in override_kwargs.items() if k in valid_params
492
- }
493
-
494
- # Warn about invalid parameters
495
- invalid_params = set(override_kwargs.keys()) - valid_params
496
- if invalid_params:
497
- logger.warning(f"Ignoring invalid parameters: {invalid_params}")
498
-
499
- # Merge kwargs with priority: runtime > init > defaults
500
- final_kwargs = {**self.gen_kwargs, **filtered_kwargs}
501
- final_kwargs["model"] = self.model_id
502
-
503
- # Log generation start with model and effective parameters
504
- logger.info(
505
- f"Starting async generation for {len(samples)} samples",
506
- extra={
507
- "block_name": self.block_name,
508
- "model_id": self.model_id,
509
- "batch_size": len(samples),
510
- "effective_params": {
511
- k: v
512
- for k, v in final_kwargs.items()
513
- if k
514
- in [
515
- "temperature",
516
- "max_tokens",
517
- "max_completion_tokens",
518
- "top_p",
519
- "n",
520
- "seed",
521
- ]
522
- },
523
- },
524
- )
525
-
526
- # Run async generation
527
- return asyncio.run(self._generate_async(samples, final_kwargs))
528
-
529
- async def _generate_async(
530
- self, samples: Dataset, final_kwargs: Dict[str, Any]
531
- ) -> Dataset:
532
- """Internal async method to generate all responses concurrently."""
533
- # Extract all messages
534
- messages_list = samples[self.messages_column]
535
-
536
- # Create all tasks
537
- tasks = [
538
- self._generate_single(messages, **final_kwargs)
539
- for messages in messages_list
540
- ]
541
-
542
- # Execute all tasks concurrently and collect responses
543
- responses = await asyncio.gather(*tasks)
544
-
545
- # Log completion
546
- logger.info(
547
- f"Async generation completed successfully for {len(responses)} samples",
548
- extra={
549
- "block_name": self.block_name,
550
- "model_id": final_kwargs["model"],
551
- "batch_size": len(responses),
552
- },
553
- )
554
-
555
- # Add responses as new column
556
- return samples.add_column(self.output_column, responses)