sdg-hub 0.1.0a4__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sdg_hub/_version.py +2 -2
- sdg_hub/blocks/__init__.py +35 -5
- sdg_hub/blocks/block.py +58 -16
- sdg_hub/blocks/llmblock.py +121 -193
- sdg_hub/blocks/utilblocks.py +500 -43
- sdg_hub/checkpointer.py +139 -0
- sdg_hub/configs/annotations/detailed_annotations.yaml +28 -0
- sdg_hub/configs/annotations/simple_annotations.yaml +9 -0
- sdg_hub/configs/knowledge/atomic_facts.yaml +1 -0
- sdg_hub/configs/knowledge/detailed_summary.yaml +1 -0
- sdg_hub/configs/knowledge/extractive_summary.yaml +1 -0
- sdg_hub/configs/knowledge/generate_questions.yaml +82 -0
- sdg_hub/configs/knowledge/generate_responses.yaml +86 -0
- sdg_hub/configs/skills/contexts.yaml +18 -11
- sdg_hub/configs/skills/evaluate_freeform_pair.yaml +79 -12
- sdg_hub/configs/skills/evaluate_freeform_questions.yaml +60 -28
- sdg_hub/configs/skills/evaluate_grounded_pair.yaml +95 -30
- sdg_hub/configs/skills/freeform_questions.yaml +21 -16
- sdg_hub/configs/skills/freeform_responses.yaml +19 -25
- sdg_hub/configs/skills/router.yaml +53 -6
- sdg_hub/flow.py +351 -21
- sdg_hub/flow_runner.py +216 -0
- sdg_hub/flows/generation/knowledge/synth_knowledge1.5.yaml +26 -9
- sdg_hub/flows/generation/skills/{agentic_improve_skill.yaml → improve_responses.yaml} +26 -31
- sdg_hub/flows/generation/skills/synth_skills.yaml +4 -4
- sdg_hub/pipeline.py +67 -12
- sdg_hub/prompts.py +21 -0
- sdg_hub/sdg.py +128 -86
- sdg_hub/utils/config_validation.py +91 -0
- sdg_hub/utils/validation_result.py +10 -0
- sdg_hub-0.1.1.dist-info/METADATA +190 -0
- sdg_hub-0.1.1.dist-info/RECORD +86 -0
- {sdg_hub-0.1.0a4.dist-info → sdg_hub-0.1.1.dist-info}/WHEEL +1 -1
- sdg_hub/blocks/filterblock.py +0 -76
- sdg_hub/blocks/iterblock.py +0 -31
- sdg_hub/blocks/rmblocks.py +0 -194
- sdg_hub/configs/annotations/simple.yaml +0 -10
- sdg_hub/configs/knowledge/data_recipe/default_recipe.yaml +0 -3
- sdg_hub/configs/skills/data_recipe/default_recipe.yaml +0 -6
- sdg_hub/flows/annotation/emotion/detailed_description.yaml +0 -19
- sdg_hub/flows/annotation/emotion/detailed_description_icl.yaml +0 -19
- sdg_hub/flows/annotation/emotion/simple.yaml +0 -19
- sdg_hub/utils/chunking.py +0 -73
- sdg_hub/utils/docprocessor.py +0 -357
- sdg_hub/utils/parse_and_convert.py +0 -392
- sdg_hub-0.1.0a4.dist-info/METADATA +0 -309
- sdg_hub-0.1.0a4.dist-info/RECORD +0 -90
- /sdg_hub/configs/{knowledge/data_recipe → reasoning}/__init__.py +0 -0
- /sdg_hub/configs/skills/{_G_.yaml → icl_examples/STEM.yaml} +0 -0
- /sdg_hub/configs/skills/{data_recipe → icl_examples}/__init__.py +0 -0
- /sdg_hub/configs/skills/{_A_.yaml → icl_examples/coding.yaml} +0 -0
- /sdg_hub/configs/skills/{_B_.yaml → icl_examples/extraction.yaml} +0 -0
- /sdg_hub/configs/skills/{_C_.yaml → icl_examples/humanities.yaml} +0 -0
- /sdg_hub/configs/skills/{_D_.yaml → icl_examples/math.yaml} +0 -0
- /sdg_hub/configs/skills/{_E_.yaml → icl_examples/reasoning.yaml} +0 -0
- /sdg_hub/configs/skills/{_F_.yaml → icl_examples/roleplay.yaml} +0 -0
- /sdg_hub/configs/skills/{_H_.yaml → icl_examples/writing.yaml} +0 -0
- {sdg_hub-0.1.0a4.dist-info → sdg_hub-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {sdg_hub-0.1.0a4.dist-info → sdg_hub-0.1.1.dist-info}/top_level.txt +0 -0
sdg_hub/_version.py
CHANGED
sdg_hub/blocks/__init__.py
CHANGED
@@ -1,6 +1,36 @@
+"""Block implementations for SDG Hub.
+
+This package provides various block implementations for data generation, processing, and transformation.
+"""
+
 # Local
-from .block import
-from .
-from .
-
-
+from .block import Block
+from .llmblock import LLMBlock, ConditionalLLMBlock
+from .utilblocks import (
+    SamplePopulatorBlock,
+    SelectorBlock,
+    CombineColumnsBlock,
+    FlattenColumnsBlock,
+    DuplicateColumns,
+    RenameColumns,
+    SetToMajorityValue,
+    FilterByValueBlock,
+    IterBlock,
+)
+from ..registry import BlockRegistry
+
+__all__ = [
+    "Block",
+    "FilterByValueBlock",
+    "IterBlock",
+    "LLMBlock",
+    "ConditionalLLMBlock",
+    "SamplePopulatorBlock",
+    "SelectorBlock",
+    "CombineColumnsBlock",
+    "FlattenColumnsBlock",
+    "DuplicateColumns",
+    "RenameColumns",
+    "SetToMajorityValue",
+    "BlockRegistry",
+]
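With 0.1.1 the package exposes every block from a single import surface: `filterblock.py` and `iterblock.py` are deleted, and `FilterByValueBlock` and `IterBlock` are re-exported from `utilblocks` alongside the other utility blocks. A minimal sketch of what consumer imports look like against the new `__all__` (assuming sdg-hub 0.1.1 is installed):

```python
# Sketch of the consolidated 0.1.1 import surface; all names below come
# from the new __all__ shown above.
from sdg_hub.blocks import (
    Block,
    LLMBlock,
    ConditionalLLMBlock,
    FilterByValueBlock,  # moved from the deleted filterblock.py into utilblocks
    IterBlock,           # moved from the deleted iterblock.py into utilblocks
    BlockRegistry,
)

# Importing the package also runs the @BlockRegistry.register(...) decorators
# seen in the diffs below, so blocks are registered by name as a side effect.
```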
sdg_hub/blocks/block.py
CHANGED
@@ -1,8 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
+"""Base block implementation for the SDG Hub system.
+
+This module provides the abstract base class for all blocks in the system,
+including functionality for template validation and configuration management.
+"""
+
 # Standard
 from abc import ABC
 from collections import ChainMap
-from typing import Any, Dict,
+from typing import Any, Dict, Optional
 
 # Third Party
 from jinja2 import Template, UndefinedError
@@ -17,24 +23,38 @@ logger = setup_logger(__name__)
 
 @BlockRegistry.register("Block")
 class Block(ABC):
+    """Base abstract class for all blocks in the system.
+
+    This class provides common functionality for block validation and configuration loading.
+    All specific block implementations should inherit from this class.
+    """
+
     def __init__(self, block_name: str) -> None:
         self.block_name = block_name
 
     @staticmethod
     def _validate(prompt_template: Template, input_dict: Dict[str, Any]) -> bool:
-        """
-
-        variables in the Jinja template are provided in the input_dict.
+        """Validate the input data for this block.
+
+        This method validates whether all required variables in the Jinja template are provided in the input_dict.
+
+        Parameters
+        ----------
+        prompt_template : Template
+            The Jinja2 template object.
+        input_dict : Dict[str, Any]
+            A dictionary of input values to check against the template.
 
-
-
-
+        Returns
+        -------
+        bool
+            True if the input data is valid (i.e., no missing variables), False otherwise.
         """
-
+
         class Default(dict):
             def __missing__(self, key: str) -> None:
                 raise KeyError(key)
-
+
         try:
             # Try rendering the template with the input_dict
             prompt_template.render(ChainMap(input_dict, Default()))
@@ -43,12 +63,34 @@ class Block(ABC):
             logger.error(f"Missing key: {e}")
             return False
 
-    def _load_config(self, config_path: str) ->
-        """
-        Load the configuration file for this block.
+    def _load_config(self, config_path: str) -> Optional[Dict[str, Any]]:
+        """Load the configuration file for this block.
 
-
-
+        Parameters
+        ----------
+        config_path : str
+            The path to the configuration file.
+
+        Returns
+        -------
+        Optional[Dict[str, Any]]
+            The loaded configuration. Returns None if file cannot be read or parsed.
+
+        Raises
+        ------
+        FileNotFoundError
+            If the configuration file does not exist.
         """
-
-
+        try:
+            with open(config_path, "r", encoding="utf-8") as config_file:
+                try:
+                    return yaml.safe_load(config_file)
+                except yaml.YAMLError as e:
+                    logger.error(f"Error parsing YAML from {config_path}: {e}")
+                    return None
+        except FileNotFoundError:
+            logger.error(f"Configuration file not found: {config_path}")
+            raise
+        except Exception as e:
+            logger.error(f"Unexpected error reading config file {config_path}: {e}")
+            return None
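The rewritten `_load_config` gives callers a split error contract: a missing file logs and re-raises `FileNotFoundError`, while unparseable YAML or any other read failure logs and returns `None`. A hypothetical subclass showing how calling code might handle both cases (the class and its check are illustrative, not part of the package):

```python
# Hypothetical subclass exercising the 0.1.1 _load_config contract:
# FileNotFoundError propagates; malformed YAML surfaces as None.
from sdg_hub.blocks import Block


class StrictConfigBlock(Block):  # illustrative name, not in sdg-hub
    def __init__(self, block_name: str, config_path: str) -> None:
        super().__init__(block_name)
        config = self._load_config(config_path)  # raises if the file is missing
        if config is None:
            # Parse/read errors were already logged inside _load_config.
            raise ValueError(f"Invalid or unreadable config: {config_path}")
        self.block_config = config
```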
sdg_hub/blocks/llmblock.py
CHANGED
@@ -1,7 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
+"""LLM-based blocks for text generation and processing.
+
+This module provides blocks for interacting with language models.
+"""
+
 # Standard
-from typing import Any, Dict, List
-from typing import Optional
+from typing import Any, Dict, List, Optional, Union
 import json
 import re
@@ -18,7 +22,18 @@ from ..registry import BlockRegistry, PromptRegistry
 logger = setup_logger(__name__)
 
 
-def server_supports_batched(client, model_id: str) -> bool:
+def server_supports_batched(client: openai.OpenAI, model_id: str) -> bool:
+    """Check if the server supports batched inputs.
+
+    This function checks if the server supports batched inputs by making a test call to the server.
+
+    Parameters
+    ----------
+    client : openai.OpenAI
+        The client to use to make the test call.
+    model_id : str
+        The model ID to use for the test call.
+    """
     supported = getattr(client, "server_supports_batched", None)
     if supported is not None:
         return supported
@@ -38,19 +53,43 @@ def server_supports_batched(client, model_id: str) -> bool:
 
 
 @BlockRegistry.register("LLMBlock")
-# pylint: disable=dangerous-default-value
 class LLMBlock(Block):
+    """Block for generating text using language models.
+
+    This block handles text generation, prompt formatting, and output parsing
+    for language model interactions.
+
+    Parameters
+    ----------
+    block_name : str
+        Name of the block.
+    config_path : str
+        Path to the configuration file.
+    client : openai.OpenAI
+        OpenAI client instance.
+    output_cols : List[str]
+        List of output column names.
+    parser_kwargs : Dict[str, Any], optional
+        Keyword arguments for the parser, by default {}.
+    model_prompt : str, optional
+        Template string for model prompt, by default "{prompt}".
+    model_id : Optional[str], optional
+        Model ID to use, by default None.
+    **batch_kwargs : Dict[str, Any]
+        Additional keyword arguments for batch processing.
+    """
+
     # pylint: disable=too-many-instance-attributes
     def __init__(
         self,
-        block_name,
-        config_path,
-        client,
-        output_cols,
-        parser_kwargs={},
-        model_prompt="{prompt}",
-        model_id=None,
-        **batch_kwargs,
+        block_name: str,
+        config_path: str,
+        client: openai.OpenAI,
+        output_cols: List[str],
+        parser_kwargs: Dict[str, Any] = {},
+        model_prompt: str = "{prompt}",
+        model_id: Optional[str] = None,
+        **batch_kwargs: Dict[str, Any],
     ) -> None:
         super().__init__(block_name)
         self.block_config = self._load_config(config_path)
@@ -84,7 +123,6 @@ class LLMBlock(Block):
         # and supports the n parameter to generate n outputs per input
         self.server_supports_batched = server_supports_batched(client, self.model)
 
-
     def _extract_matches(
         self, text: str, start_tag: Optional[str], end_tag: Optional[str]
     ) -> List[str]:
@@ -105,7 +143,7 @@ class LLMBlock(Block):
 
         return [match.strip() for match in re.findall(pattern, text, re.DOTALL)]
 
-    def _parse(self, generated_string) -> dict:
+    def _parse(self, generated_string: str) -> dict:
         matches = {}
 
         if self.parser_name is not None and self.parser_name == "custom":
@@ -141,7 +179,7 @@ class LLMBlock(Block):
             self.model_prompt, prompt_templated_str, add_generation_prompt=True
         ).strip()
 
-    def _generate(self, samples, **gen_kwargs) -> list:
+    def _generate(self, samples: Dataset, **gen_kwargs: Dict[str, Any]) -> list:
         prompts = [self._format_prompt(sample) for sample in samples]
         logger.debug("Prompt: %s", prompts[0])
         generate_args = {**self.defaults, **gen_kwargs}
@@ -173,12 +211,16 @@ class LLMBlock(Block):
             results.append(response.choices[0].text.strip())
         return results
 
-    def generate(self, samples: Dataset, **gen_kwargs) -> Dataset:
-        """
-
+    def generate(self, samples: Dataset, **gen_kwargs: Dict[str, Any]) -> Dataset:
+        """Generate the output from the block.
+
+        This method should first validate the input data,
         then generate the output, and finally parse the generated output before returning it.
 
-
+        Returns
+        -------
+        Dataset
+            The parsed output after generation.
         """
         num_samples = self.block_config.get("num_samples", None)
         logger.debug("Generating outputs for {} samples".format(len(samples)))
@@ -233,16 +275,40 @@ class LLMBlock(Block):
 
 @BlockRegistry.register("ConditionalLLMBlock")
 class ConditionalLLMBlock(LLMBlock):
+    """Block for conditional text generation using language models.
+
+    This block selects different prompt templates based on a selector column value.
+
+    Parameters
+    ----------
+    block_name : str
+        Name of the block.
+    config_paths : Dict[str, str]
+        Dictionary mapping selector values to their config file paths.
+    client : openai.OpenAI
+        OpenAI client instance.
+    model_id : str
+        Model ID to use.
+    output_cols : List[str]
+        List of output column names.
+    selector_column_name : str
+        Name of the column used to select the prompt template.
+    model_prompt : str, optional
+        Template string for model prompt, by default "{prompt}".
+    **batch_kwargs : Dict[str, Any]
+        Additional keyword arguments for batch processing.
+    """
+
     def __init__(
         self,
-        block_name,
-        config_paths,
-        client,
-        model_id,
-        output_cols,
-        selector_column_name,
-        model_prompt="{prompt}",
-        **batch_kwargs,
+        block_name: str,
+        config_paths: Dict[str, str],
+        client: openai.OpenAI,
+        model_id: str,
+        output_cols: List[str],
+        selector_column_name: str,
+        model_prompt: str = "{prompt}",
+        **batch_kwargs: Dict[str, Any],
     ) -> None:
         super().__init__(
             block_name=block_name,
@@ -259,7 +325,6 @@ class ConditionalLLMBlock(LLMBlock):
             self.prompt_template = self.prompt_struct.format(**self.block_config)
         else:
             for config_key, config in config_paths.items():
-                # Template(self.prompt_struct.format(**filtered_config))
                 filtered_config = {
                     k: (v if v is not None else "")
                     for k, v in self.block_config.items()
@@ -268,7 +333,19 @@ class ConditionalLLMBlock(LLMBlock):
                     self.prompt_struct.format(**self._load_config(config))
                 )
 
-    def _format_prompt(self, sample: Dict) -> str:
+    def _format_prompt(self, sample: Dict[str, Any]) -> str:
+        """Format the prompt based on the selector column value.
+
+        Parameters
+        ----------
+        sample : Dict[str, Any]
+            Input sample containing the selector column.
+
+        Returns
+        -------
+        str
+            Formatted prompt string.
+        """
         if isinstance(self.prompt_template, dict):
             return (
                 self.prompt_template[sample[self.selector_column_name]]
@@ -278,170 +355,21 @@ class ConditionalLLMBlock(LLMBlock):
 
         return self.prompt_template.render(**sample).strip()
 
-    def _validate(self, prompt_template: str, input_dict: Dict[str, Any]) -> bool:
-
-        prompt_template = prompt_template[input_dict[self.selector_column_name]]
-        return super()._validate(prompt_template, input_dict)
-
-
-@BlockRegistry.register("LLMLogProbBlock")
-class LLMLogProbBlock(LLMBlock):
-    # init with init of the parent class
-    def __init__(
-        self,
-        block_name,
-        config_path,
-        client,
-        output_cols,
-        parser_kwargs={},
-        model_prompt="{prompt}",
-        model_id=None,
-        **batch_kwargs,
-    ) -> None:
-        super().__init__(
-            block_name=block_name,
-            config_path=config_path,
-            client=client,
-            output_cols=output_cols,
-            parser_kwargs=parser_kwargs,
-            model_prompt=model_prompt,
-            model_id=model_id,
-            **batch_kwargs,
-        )
-
-    def _generate_logprobs(self, samples, **gen_kwargs):
-        prompts = [
-            self.model_prompt.format(prompt=self._format_prompt(sample))
-            for sample in samples
-        ]
-        generate_args = {**self.defaults, **gen_kwargs}
-
-        # verify if logprobs is mentioned in the generate_args, if not add it and return top10 logprobs
-        if "logprobs" not in generate_args:
-            generate_args["logprobs"] = 10
-
-        if self.server_supports_batched:
-            response = self.client.completions.create(prompt=prompts, **generate_args)
-            return [choice.logprobs.top_logprobs for choice in response.choices]
-
-        n = gen_kwargs.get("n", 1)
-        results = []
-        for prompt in prompts:
-            for _ in range(n):
-                response = self.client.completions.create(
-                    prompt=prompt, **generate_args
-                )
-                results.append(response.choices[0].logprobs.top_logprobs)
-        return results
+    def _validate(self, prompt_template: Union[str, Template], input_dict: Dict[str, Any]) -> bool:
+        """Validate the input data for this block.
 
-
-
-
-
-
+        Parameters
+        ----------
+        prompt_template : Union[str, Template]
+            The template to validate against.
+        input_dict : Dict[str, Any]
+            Input data to validate.
 
-
+        Returns
+        -------
+        bool
+            True if the input data is valid, False otherwise.
         """
-
-
-
-        :return: The parsed output after generation.
-        """
-        num_samples = self.block_config.get("num_samples", None)
-        logger.debug("Generating outputs for {} samples".format(len(samples)))
-
-        if (num_samples is not None) and ("num_samples" not in samples.column_names):
-            samples = samples.add_column("num_samples", [num_samples] * len(samples))
-
-        # validate each sample
-        # Log errors and remove invalid samples
-        valid_samples = []
-
-        for sample in samples:
-            if self._validate(self.prompt_template, sample):
-                valid_samples.append(sample)
-            else:
-                logger.warning(
-                    f"Sample failed validation: {sample}"
-                )  # Log details of the failed sample
-
-        samples = valid_samples
-
-        if len(samples) == 0:
-            logger.warning(
-                "No valid samples to generate outputs for, returning empty dataset"
-            )
-            return Dataset.from_list([])
-
-        # generate the output
-
-        outputs = self._generate_logprobs(samples, **gen_kwargs)
-        logger.debug("Generated outputs: %s", outputs)
-
-        output_dataset = Dataset.from_list(samples)
-        output_dataset = output_dataset.add_column(
-            self.output_cols[0],
-            self._parse(outputs),  # pylint: disable=no-value-for-parameter
-        )
-
-        return output_dataset
-
-
-@BlockRegistry.register("LLMMessagesBlock")
-class LLMMessagesBlock(Block):
-    def __init__(
-        self,
-        block_name,
-        client,
-        input_col,
-        output_col,
-        model_prompt=None,
-        model_id=None,
-        **batch_kwargs,
-    ) -> None:
-        self.block_name = block_name
-        self.model_prompt = model_prompt
-        self.batch_params = batch_kwargs.get("batch_kwargs", {})
-        self.input_col = input_col
-        self.output_col = output_col
-        self.client = client
-
-        if model_id:
-            self.model = model_id
-        else:
-            self.model = self.client.models.list().data[0].id
-
-        self.defaults = {
-            "model": self.model,
-            "temperature": 0,
-            "max_tokens": 4096,
-        }
-        self.server_supports_batched = server_supports_batched(client, self.model)
-
-    def _generate(self, samples, **gen_kwargs) -> list:
-        generate_args = {**self.defaults, **gen_kwargs}
-
-        if "n" in generate_args and generate_args.get("temperature", 0) <= 0:
-            generate_args["temperature"] = 0.7
-            logger.warning(
-                "Temperature should be greater than 0 for n > 1, setting temperature to 0.7"
-            )
-
-        messages = samples[self.input_col]
-
-        results = []
-        n = gen_kwargs.get("n", 1)
-        for message in messages:
-            responses = self.client.chat.completions.create(
-                messages=message, **generate_args
-            )
-            if n > 1:
-                results.append([choice.message.content for choice in responses.choices])
-            else:
-                results.append(responses.choices[0].message.content)
-        return results
-
-    def generate(self, samples: Dataset, **gen_kwargs) -> Dataset:
-        outputs = self._generate(samples, **gen_kwargs)
-        samples = samples.add_column(self.output_col, outputs)
-        return samples
+        if isinstance(prompt_template, dict):
+            prompt_template = prompt_template[input_dict[self.selector_column_name]]
+        return super()._validate(prompt_template, input_dict)
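Net effect for callers: `LLMBlock` and `ConditionalLLMBlock` keep their behavior but gain full type hints and numpydoc docstrings, while `LLMLogProbBlock` and `LLMMessagesBlock` are removed entirely in 0.1.1. A construction sketch against the new `LLMBlock` signature; the endpoint, block name, and output column are placeholders, not values from this diff:

```python
# Sketch against the typed 0.1.1 LLMBlock signature shown above.
import openai

from sdg_hub.blocks import LLMBlock

# Placeholder endpoint/key for a local OpenAI-compatible server.
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

block = LLMBlock(
    block_name="gen_questions",  # hypothetical block name
    config_path="sdg_hub/configs/knowledge/generate_questions.yaml",
    client=client,
    output_cols=["question"],    # hypothetical output column
    model_id=None,               # per the signature, defaults to None
)

# Per its docstring, generate() validates each sample, calls the model,
# and parses the generated output before returning a Dataset.
```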