sdg-hub 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to one of the supported public registries. It is provided for informational purposes only.
- sdg_hub/_version.py +2 -2
- sdg_hub/blocks/__init__.py +6 -0
- sdg_hub/blocks/openaichatblock.py +556 -0
- sdg_hub/flow.py +21 -18
- sdg_hub/flow_runner.py +273 -52
- sdg_hub/flows/generation/knowledge/synth_knowledge1.5.yaml +0 -5
- sdg_hub/prompts.py +31 -0
- sdg_hub/utils/__init__.py +5 -0
- sdg_hub/utils/error_handling.py +94 -0
- sdg_hub/utils/path_resolution.py +62 -0
- {sdg_hub-0.1.1.dist-info → sdg_hub-0.1.2.dist-info}/METADATA +1 -1
- {sdg_hub-0.1.1.dist-info → sdg_hub-0.1.2.dist-info}/RECORD +15 -12
- {sdg_hub-0.1.1.dist-info → sdg_hub-0.1.2.dist-info}/WHEEL +0 -0
- {sdg_hub-0.1.1.dist-info → sdg_hub-0.1.2.dist-info}/licenses/LICENSE +0 -0
- {sdg_hub-0.1.1.dist-info → sdg_hub-0.1.2.dist-info}/top_level.txt +0 -0
sdg_hub/_version.py
CHANGED
sdg_hub/blocks/__init__.py
CHANGED
@@ -6,6 +6,10 @@ This package provides various block implementations for data generation, process
 # Local
 from .block import Block
 from .llmblock import LLMBlock, ConditionalLLMBlock
+from .openaichatblock import (
+    OpenAIChatBlock,
+    OpenAIAsyncChatBlock
+)
 from .utilblocks import (
     SamplePopulatorBlock,
     SelectorBlock,
@@ -33,4 +37,6 @@ __all__ = [
     "RenameColumns",
     "SetToMajorityValue",
     "BlockRegistry",
+    "OpenAIChatBlock",
+    "OpenAIAsyncChatBlock"
 ]
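With the re-export above, the new blocks become importable from the package namespace. A minimal sketch, assuming the 0.1.2 wheel is installed:

from sdg_hub.blocks import OpenAIChatBlock, OpenAIAsyncChatBlock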
sdg_hub/blocks/openaichatblock.py
ADDED
@@ -0,0 +1,556 @@
+# SPDX-License-Identifier: Apache-2.0
+"""OpenAI-specific blocks for text generation.
+
+This module provides blocks for interacting with OpenAI's Chat Completions API.
+"""
+
+# Standard
+from typing import Any, Dict, List, Optional, Union
+import asyncio
+
+# Third Party
+from datasets import Dataset
+from tenacity import (
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_random_exponential,
+)
+import openai
+
+# Local
+from ..logger_config import setup_logger
+from ..registry import BlockRegistry
+from .block import Block
+
+logger = setup_logger(__name__)
+
+
+@BlockRegistry.register("OpenAIChatBlock")
+class OpenAIChatBlock(Block):
+    """Block for generating text using OpenAI Chat Completions API.
+
+    This block takes a column containing OpenAI message format and makes
+    direct calls to the chat completions endpoint.
+
+    Parameters
+    ----------
+    block_name : str
+        Name of the block.
+    input_cols : Union[str, List[str]]
+        Input column name(s). Should contain the messages list.
+    output_cols : Union[str, List[str]]
+        Output column name(s) for the response.
+    client : openai.OpenAI
+        OpenAI client instance.
+    model_id : str
+        Model ID to use.
+
+    ### Text-relevant OpenAI Chat Completions API parameters ###
+
+    frequency_penalty : Optional[float], optional
+        Penalize frequent tokens (-2.0 to 2.0).
+    logit_bias : Optional[Dict[str, int]], optional
+        Modify likelihood of specified tokens.
+    logprobs : Optional[bool], optional
+        Whether to return log probabilities.
+    max_completion_tokens : Optional[int], optional
+        Maximum tokens in completion.
+    max_tokens : Optional[int], optional
+        Maximum tokens in completion (legacy).
+    n : Optional[int], optional
+        Number of completions to generate.
+    presence_penalty : Optional[float], optional
+        Penalize repeated tokens (-2.0 to 2.0).
+    response_format : Optional[Dict[str, Any]], optional
+        Response format specification (e.g., JSON mode).
+    seed : Optional[int], optional
+        Seed for deterministic outputs.
+    stop : Optional[Union[str, List[str]]], optional
+        Stop sequences.
+    stream : Optional[bool], optional
+        Whether to stream responses.
+    temperature : Optional[float], optional
+        Sampling temperature (0.0 to 2.0).
+    tool_choice : Optional[Union[str, Dict[str, Any]]], optional
+        Tool selection strategy.
+    tools : Optional[List[Dict[str, Any]]], optional
+        Available tools for function calling.
+    top_logprobs : Optional[int], optional
+        Number of top log probabilities to return.
+    top_p : Optional[float], optional
+        Nucleus sampling parameter (0.0 to 1.0).
+    user : Optional[str], optional
+        End-user identifier.
+    extra_body : Optional[dict], optional
+        Dictionary of additional parameters if supported by inference backend
+    """
+
+    def __init__(
+        self,
+        block_name: str,
+        input_cols: Union[str, List[str]],
+        output_cols: Union[str, List[str]],
+        client: openai.OpenAI,
+        model_id: str,
+        # Text-relevant OpenAI Chat Completions API parameters
+        frequency_penalty: Optional[float] = None,
+        logit_bias: Optional[Dict[str, int]] = None,
+        logprobs: Optional[bool] = None,
+        max_completion_tokens: Optional[int] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        presence_penalty: Optional[float] = None,
+        response_format: Optional[Dict[str, Any]] = None,
+        seed: Optional[int] = None,
+        stop: Optional[Union[str, List[str]]] = None,
+        stream: Optional[bool] = None,
+        temperature: Optional[float] = None,
+        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
+        top_logprobs: Optional[int] = None,
+        top_p: Optional[float] = None,
+        user: Optional[str] = None,
+        extra_body: Optional[dict] = None,
+    ) -> None:
+        super().__init__(block_name)
+        self.input_cols = [input_cols] if isinstance(input_cols, str) else input_cols
+        self.output_cols = (
+            [output_cols] if isinstance(output_cols, str) else output_cols
+        )
+        self.client = client
+        self.model_id = model_id
+
+        # For this block, we expect exactly one input column (messages) and one output column
+        if len(self.input_cols) != 1:
+            raise ValueError("OpenAIChatBlock expects exactly one input column")
+        if len(self.output_cols) != 1:
+            raise ValueError("OpenAIChatBlock expects exactly one output column")
+
+        self.messages_column = self.input_cols[0]
+        self.output_column = self.output_cols[0]
+
+        # Store all generation parameters (only non-None values)
+        self.gen_kwargs = {}
+        params = {
+            "frequency_penalty": frequency_penalty,
+            "logit_bias": logit_bias,
+            "logprobs": logprobs,
+            "max_completion_tokens": max_completion_tokens,
+            "max_tokens": max_tokens,
+            "n": n,
+            "presence_penalty": presence_penalty,
+            "response_format": response_format,
+            "seed": seed,
+            "stop": stop,
+            "stream": stream,
+            "temperature": temperature,
+            "tool_choice": tool_choice,
+            "tools": tools,
+            "top_logprobs": top_logprobs,
+            "top_p": top_p,
+            "user": user,
+            "extra_body": extra_body,
+        }
+
+        # Only include non-None parameters
+        for key, value in params.items():
+            if value is not None:
+                self.gen_kwargs[key] = value
+
+        # Log initialization with model and parameters
+        logger.info(
+            f"Initialized OpenAIChatBlock '{block_name}' with model '{model_id}'",
+            extra={
+                "block_name": block_name,
+                "model_id": model_id,
+                "generation_params": self.gen_kwargs,
+            },
+        )
+
+    @retry(
+        wait=wait_random_exponential(min=1, max=60),
+        stop=stop_after_attempt(6),
+        retry=retry_if_exception_type(
+            (openai.RateLimitError, openai.APITimeoutError, openai.APIConnectionError)
+        ),
+    )
+    def _create_completion_with_retry(self, **kwargs):
+        """Create completion with retry logic."""
+        return self.client.chat.completions.create(**kwargs)
+
+    def generate(self, samples: Dataset, **override_kwargs: Dict[str, Any]) -> Dataset:
+        """Generate the output from the block.
+
+        Parameters set at runtime override those set during initialization.
+        Supports all text-relevant OpenAI Chat Completions API parameters.
+
+        Parameters
+        ----------
+        samples : Dataset
+            Input dataset containing the messages column.
+        **override_kwargs : Dict[str, Any]
+            Runtime parameters that override initialization defaults.
+            Valid parameters: frequency_penalty, logit_bias, logprobs,
+            max_completion_tokens, max_tokens, n, presence_penalty,
+            response_format, seed, stop, stream, temperature, tool_choice,
+            tools, top_logprobs, top_p, user.
+
+        Returns
+        -------
+        Dataset
+            Dataset with the response added to the output column.
+        """
+        # Define valid parameters for validation
+        valid_params = {
+            "frequency_penalty",
+            "logit_bias",
+            "logprobs",
+            "max_completion_tokens",
+            "max_tokens",
+            "n",
+            "presence_penalty",
+            "response_format",
+            "seed",
+            "stop",
+            "stream",
+            "temperature",
+            "tool_choice",
+            "tools",
+            "top_logprobs",
+            "top_p",
+            "user",
+            "extra_body",
+        }
+
+        # Filter and validate override parameters
+        filtered_kwargs = {
+            k: v for k, v in override_kwargs.items() if k in valid_params
+        }
+
+        # Warn about invalid parameters
+        invalid_params = set(override_kwargs.keys()) - valid_params
+        if invalid_params:
+            logger.warning(f"Ignoring invalid parameters: {invalid_params}")
+
+        # Merge kwargs with priority: runtime > init > defaults
+        final_kwargs = {**self.gen_kwargs, **filtered_kwargs}
+        final_kwargs["model"] = self.model_id
+
+        # Extract all messages
+        messages_list = samples[self.messages_column]
+
+        # Log generation start with model and effective parameters
+        logger.info(
+            f"Starting generation for {len(messages_list)} samples",
+            extra={
+                "block_name": self.block_name,
+                "model_id": self.model_id,
+                "batch_size": len(messages_list),
+                "effective_params": {
+                    k: v
+                    for k, v in final_kwargs.items()
+                    if k
+                    in [
+                        "temperature",
+                        "max_tokens",
+                        "max_completion_tokens",
+                        "top_p",
+                        "n",
+                        "seed",
+                    ]
+                },
+            },
+        )
+
+        # Get all responses
+        responses = []
+        for messages in messages_list:
+            response = self._create_completion_with_retry(
+                messages=messages, **final_kwargs
+            )
+            responses.append(response.choices[0].message.content)
+
+        # Log completion
+        logger.info(
+            f"Generation completed successfully for {len(responses)} samples",
+            extra={
+                "block_name": self.block_name,
+                "model_id": self.model_id,
+                "batch_size": len(responses),
+            },
+        )
+
+        # Add responses as new column
+        return samples.add_column(self.output_column, responses)
+
+
+@BlockRegistry.register("OpenAIAsyncChatBlock")
+class OpenAIAsyncChatBlock(Block):
+    """Async block for generating text using OpenAI Chat Completions API.
+
+    This block takes a column containing OpenAI message format and makes
+    asynchronous calls to the chat completions endpoint for better performance.
+
+    Parameters
+    ----------
+    block_name : str
+        Name of the block.
+    input_cols : Union[str, List[str]]
+        Input column name(s). Should contain the messages list.
+    output_cols : Union[str, List[str]]
+        Output column name(s) for the response.
+    async_client : openai.AsyncOpenAI
+        Async OpenAI client instance.
+    model_id : str
+        Model ID to use.
+
+    ### Text-relevant OpenAI Chat Completions API parameters ###
+
+    frequency_penalty : Optional[float], optional
+        Penalize frequent tokens (-2.0 to 2.0).
+    logit_bias : Optional[Dict[str, int]], optional
+        Modify likelihood of specified tokens.
+    logprobs : Optional[bool], optional
+        Whether to return log probabilities.
+    max_completion_tokens : Optional[int], optional
+        Maximum tokens in completion.
+    max_tokens : Optional[int], optional
+        Maximum tokens in completion (legacy).
+    n : Optional[int], optional
+        Number of completions to generate.
+    presence_penalty : Optional[float], optional
+        Penalize repeated tokens (-2.0 to 2.0).
+    response_format : Optional[Dict[str, Any]], optional
+        Response format specification (e.g., JSON mode).
+    seed : Optional[int], optional
+        Seed for deterministic outputs.
+    stop : Optional[Union[str, List[str]]], optional
+        Stop sequences.
+    stream : Optional[bool], optional
+        Whether to stream responses.
+    temperature : Optional[float], optional
+        Sampling temperature (0.0 to 2.0).
+    tool_choice : Optional[Union[str, Dict[str, Any]]], optional
+        Tool selection strategy.
+    tools : Optional[List[Dict[str, Any]]], optional
+        Available tools for function calling.
+    top_logprobs : Optional[int], optional
+        Number of top log probabilities to return.
+    top_p : Optional[float], optional
+        Nucleus sampling parameter (0.0 to 1.0).
+    user : Optional[str], optional
+        End-user identifier.
+    extra_body : Optional[dict], optional
+        Dictionary of additional parameters if supported by inference backend
+    """
+
+    def __init__(
+        self,
+        block_name: str,
+        input_cols: Union[str, List[str]],
+        output_cols: Union[str, List[str]],
+        async_client: openai.AsyncOpenAI,
+        model_id: str,
+        # Text-relevant OpenAI Chat Completions API parameters
+        frequency_penalty: Optional[float] = None,
+        logit_bias: Optional[Dict[str, int]] = None,
+        logprobs: Optional[bool] = None,
+        max_completion_tokens: Optional[int] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        presence_penalty: Optional[float] = None,
+        response_format: Optional[Dict[str, Any]] = None,
+        seed: Optional[int] = None,
+        stop: Optional[Union[str, List[str]]] = None,
+        stream: Optional[bool] = None,
+        temperature: Optional[float] = None,
+        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
+        top_logprobs: Optional[int] = None,
+        top_p: Optional[float] = None,
+        user: Optional[str] = None,
+        extra_body: Optional[dict] = None,
+    ) -> None:
+        super().__init__(block_name)
+        self.input_cols = [input_cols] if isinstance(input_cols, str) else input_cols
+        self.output_cols = (
+            [output_cols] if isinstance(output_cols, str) else output_cols
+        )
+        self.async_client = async_client
+        self.model_id = model_id
+
+        # For this block, we expect exactly one input column (messages) and one output column
+        if len(self.input_cols) != 1:
+            raise ValueError("OpenAIAsyncChatBlock expects exactly one input column")
+        if len(self.output_cols) != 1:
+            raise ValueError("OpenAIAsyncChatBlock expects exactly one output column")
+
+        self.messages_column = self.input_cols[0]
+        self.output_column = self.output_cols[0]
+
+        # Store all generation parameters (only non-None values)
+        self.gen_kwargs = {}
+        params = {
+            "frequency_penalty": frequency_penalty,
+            "logit_bias": logit_bias,
+            "logprobs": logprobs,
+            "max_completion_tokens": max_completion_tokens,
+            "max_tokens": max_tokens,
+            "n": n,
+            "presence_penalty": presence_penalty,
+            "response_format": response_format,
+            "seed": seed,
+            "stop": stop,
+            "stream": stream,
+            "temperature": temperature,
+            "tool_choice": tool_choice,
+            "tools": tools,
+            "top_logprobs": top_logprobs,
+            "top_p": top_p,
+            "user": user,
+            "extra_body": extra_body,
+        }
+
+        # Only include non-None parameters
+        for key, value in params.items():
+            if value is not None:
+                self.gen_kwargs[key] = value
+
+        # Log initialization with model and parameters
+        logger.info(
+            f"Initialized OpenAIAsyncChatBlock '{block_name}' with model '{model_id}'",
+            extra={
+                "block_name": block_name,
+                "model_id": model_id,
+                "generation_params": self.gen_kwargs,
+            },
+        )
+
+    @retry(
+        wait=wait_random_exponential(min=1, max=60),
+        stop=stop_after_attempt(6),
+        retry=retry_if_exception_type(
+            (openai.RateLimitError, openai.APITimeoutError, openai.APIConnectionError)
+        ),
+    )
+    async def _generate_single(
+        self, messages: List[Dict[str, Any]], **final_kwargs: Dict[str, Any]
+    ) -> str:
+        """Generate a single response asynchronously."""
+        response = await self.async_client.chat.completions.create(
+            messages=messages, **final_kwargs
+        )
+        return response.choices[0].message.content
+
+    def generate(self, samples: Dataset, **override_kwargs: Dict[str, Any]) -> Dataset:
+        """Generate the output from the block using async calls.
+
+        Parameters set at runtime override those set during initialization.
+        Supports all text-relevant OpenAI Chat Completions API parameters.
+
+        Parameters
+        ----------
+        samples : Dataset
+            Input dataset containing the messages column.
+        **override_kwargs : Dict[str, Any]
+            Runtime parameters that override initialization defaults.
+            Valid parameters: frequency_penalty, logit_bias, logprobs,
+            max_completion_tokens, max_tokens, n, presence_penalty,
+            response_format, seed, stop, stream, temperature, tool_choice,
+            tools, top_logprobs, top_p, user.
+
+        Returns
+        -------
+        Dataset
+            Dataset with the response added to the output column.
+        """
+        # Define valid parameters for validation
+        valid_params = {
+            "frequency_penalty",
+            "logit_bias",
+            "logprobs",
+            "max_completion_tokens",
+            "max_tokens",
+            "n",
+            "presence_penalty",
+            "response_format",
+            "seed",
+            "stop",
+            "stream",
+            "temperature",
+            "tool_choice",
+            "tools",
+            "top_logprobs",
+            "top_p",
+            "user",
+        }
+
+        # Filter and validate override parameters
+        filtered_kwargs = {
+            k: v for k, v in override_kwargs.items() if k in valid_params
+        }
+
+        # Warn about invalid parameters
+        invalid_params = set(override_kwargs.keys()) - valid_params
+        if invalid_params:
+            logger.warning(f"Ignoring invalid parameters: {invalid_params}")
+
+        # Merge kwargs with priority: runtime > init > defaults
+        final_kwargs = {**self.gen_kwargs, **filtered_kwargs}
+        final_kwargs["model"] = self.model_id
+
+        # Log generation start with model and effective parameters
+        logger.info(
+            f"Starting async generation for {len(samples)} samples",
+            extra={
+                "block_name": self.block_name,
+                "model_id": self.model_id,
+                "batch_size": len(samples),
+                "effective_params": {
+                    k: v
+                    for k, v in final_kwargs.items()
+                    if k
+                    in [
+                        "temperature",
+                        "max_tokens",
+                        "max_completion_tokens",
+                        "top_p",
+                        "n",
+                        "seed",
+                    ]
+                },
+            },
+        )
+
+        # Run async generation
+        return asyncio.run(self._generate_async(samples, final_kwargs))
+
+    async def _generate_async(
+        self, samples: Dataset, final_kwargs: Dict[str, Any]
+    ) -> Dataset:
+        """Internal async method to generate all responses concurrently."""
+        # Extract all messages
+        messages_list = samples[self.messages_column]
+
+        # Create all tasks
+        tasks = [
+            self._generate_single(messages, **final_kwargs)
+            for messages in messages_list
+        ]
+
+        # Execute all tasks concurrently and collect responses
+        responses = await asyncio.gather(*tasks)
+
+        # Log completion
+        logger.info(
+            f"Async generation completed successfully for {len(responses)} samples",
+            extra={
+                "block_name": self.block_name,
+                "model_id": final_kwargs["model"],
+                "batch_size": len(responses),
+            },
+        )
+
+        # Add responses as new column
+        return samples.add_column(self.output_column, responses)
sdg_hub/flow.py
CHANGED
@@ -38,6 +38,7 @@ from .logger_config import setup_logger
 from .prompts import *  # needed to register prompts
 from .registry import BlockRegistry, PromptRegistry
 from .utils.config_validation import validate_prompt_config_schema
+from .utils.path_resolution import resolve_path
 from .utils.validation_result import ValidationResult
 
 logger = setup_logger(__name__)
@@ -141,15 +142,7 @@ class Flow(ABC):
         str
             Selected file path.
         """
-
-        return filename
-        for d in dirs:
-            full_file_path = os.path.join(d, filename)
-            if os.path.isfile(full_file_path):
-                return full_file_path
-        # If not found above then return the path unchanged i.e.
-        # assume the path is relative to the current directory
-        return filename
+        return resolve_path(filename, dirs)
 
     def _drop_duplicates(self, dataset: Dataset, cols: List[str]) -> Dataset:
         """Drop duplicates from the dataset based on the columns provided.
@@ -273,7 +266,9 @@ class Flow(ABC):
         try:
             with open(path, "r", encoding="utf-8") as f:
                 config_data = yaml.safe_load(f)
-            _, validation_errors = validate_prompt_config_schema(
+            _, validation_errors = validate_prompt_config_schema(
+                config_data, path
+            )
 
             if validation_errors:
                 errors.extend(validation_errors)
@@ -320,9 +315,7 @@ class Flow(ABC):
         KeyError
             If a required block or prompt is not found in the registry.
         """
-
-        if os.path.isfile(yaml_path_relative_to_base):
-            yaml_path = yaml_path_relative_to_base
+        yaml_path = resolve_path(yaml_path, self.base_path)
         yaml_dir = os.path.dirname(yaml_path)
 
         try:
@@ -433,7 +426,11 @@ class Flow(ABC):
             config = block["block_config"]
 
             # LLM Block: parse Jinja vars
-            cls_name =
+            cls_name = (
+                block_type.__name__
+                if isinstance(block_type, type)
+                else block_type.__class__.__name__
+            )
             logger.info(f"Validating block: {name} ({cls_name})")
             if "LLM" in cls_name:
                 config_path = config.get("config_path")
@@ -445,7 +442,9 @@ class Flow(ABC):
                     vars_found = meta.find_undeclared_variables(ast)
                     for var in vars_found:
                         if var not in all_columns:
-                            errors.append(
+                            errors.append(
+                                f"[{name}] Missing column for prompt var: '{var}'"
+                            )
 
             # FilterByValueBlock
             if "FilterByValueBlock" in str(block_type):
@@ -462,13 +461,17 @@ class Flow(ABC):
                 choice_map = config.get("choice_map", {})
                 for col in choice_map.values():
                     if col not in all_columns:
-                        errors.append(
+                        errors.append(
+                            f"[{name}] choice_map references missing column: '{col}'"
+                        )
 
             # CombineColumnsBlock
             if "CombineColumnsBlock" in str(block_type):
                 cols = config.get("columns", [])
                 for col in cols:
                     if col not in all_columns:
-                        errors.append(
+                        errors.append(
+                            f"[{name}] CombineColumnsBlock requires column: '{col}'"
+                        )
 
-        return ValidationResult(valid=(len(errors) == 0), errors=errors)
+        return ValidationResult(valid=(len(errors) == 0), errors=errors)