sdg-hub 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sdg_hub/__init__.py +3 -0
- sdg_hub/_version.py +21 -0
- sdg_hub/blocks/__init__.py +36 -0
- sdg_hub/blocks/block.py +96 -0
- sdg_hub/blocks/llmblock.py +375 -0
- sdg_hub/blocks/utilblocks.py +597 -0
- sdg_hub/checkpointer.py +139 -0
- sdg_hub/configs/__init__.py +0 -0
- sdg_hub/configs/annotations/__init__.py +0 -0
- sdg_hub/configs/annotations/cot_reflection.yaml +34 -0
- sdg_hub/configs/annotations/detailed_annotations.yaml +28 -0
- sdg_hub/configs/annotations/detailed_description.yaml +10 -0
- sdg_hub/configs/annotations/detailed_description_icl.yaml +32 -0
- sdg_hub/configs/annotations/simple_annotations.yaml +9 -0
- sdg_hub/configs/knowledge/__init__.py +0 -0
- sdg_hub/configs/knowledge/atomic_facts.yaml +45 -0
- sdg_hub/configs/knowledge/auxilary_instructions.yaml +35 -0
- sdg_hub/configs/knowledge/detailed_summary.yaml +17 -0
- sdg_hub/configs/knowledge/evaluate_faithfulness.yaml +68 -0
- sdg_hub/configs/knowledge/evaluate_question.yaml +38 -0
- sdg_hub/configs/knowledge/evaluate_relevancy.yaml +85 -0
- sdg_hub/configs/knowledge/extractive_summary.yaml +17 -0
- sdg_hub/configs/knowledge/generate_code_questions_responses.yaml +39 -0
- sdg_hub/configs/knowledge/generate_questions_responses.yaml +56 -0
- sdg_hub/configs/knowledge/mcq_generation.yaml +83 -0
- sdg_hub/configs/knowledge/router.yaml +12 -0
- sdg_hub/configs/knowledge/simple_generate_qa.yaml +34 -0
- sdg_hub/configs/reasoning/__init__.py +0 -0
- sdg_hub/configs/reasoning/dynamic_cot.yaml +40 -0
- sdg_hub/configs/skills/__init__.py +0 -0
- sdg_hub/configs/skills/analyzer.yaml +48 -0
- sdg_hub/configs/skills/annotation.yaml +36 -0
- sdg_hub/configs/skills/contexts.yaml +28 -0
- sdg_hub/configs/skills/critic.yaml +60 -0
- sdg_hub/configs/skills/evaluate_freeform_pair.yaml +111 -0
- sdg_hub/configs/skills/evaluate_freeform_questions.yaml +78 -0
- sdg_hub/configs/skills/evaluate_grounded_pair.yaml +119 -0
- sdg_hub/configs/skills/evaluate_grounded_questions.yaml +51 -0
- sdg_hub/configs/skills/freeform_questions.yaml +34 -0
- sdg_hub/configs/skills/freeform_responses.yaml +39 -0
- sdg_hub/configs/skills/grounded_questions.yaml +38 -0
- sdg_hub/configs/skills/grounded_responses.yaml +59 -0
- sdg_hub/configs/skills/icl_examples/STEM.yaml +56 -0
- sdg_hub/configs/skills/icl_examples/__init__.py +0 -0
- sdg_hub/configs/skills/icl_examples/coding.yaml +97 -0
- sdg_hub/configs/skills/icl_examples/extraction.yaml +36 -0
- sdg_hub/configs/skills/icl_examples/humanities.yaml +71 -0
- sdg_hub/configs/skills/icl_examples/math.yaml +85 -0
- sdg_hub/configs/skills/icl_examples/reasoning.yaml +30 -0
- sdg_hub/configs/skills/icl_examples/roleplay.yaml +45 -0
- sdg_hub/configs/skills/icl_examples/writing.yaml +80 -0
- sdg_hub/configs/skills/judge.yaml +53 -0
- sdg_hub/configs/skills/planner.yaml +67 -0
- sdg_hub/configs/skills/respond.yaml +8 -0
- sdg_hub/configs/skills/revised_responder.yaml +78 -0
- sdg_hub/configs/skills/router.yaml +59 -0
- sdg_hub/configs/skills/simple_generate_qa_freeform.yaml +27 -0
- sdg_hub/configs/skills/simple_generate_qa_grounded.yaml +31 -0
- sdg_hub/flow.py +306 -0
- sdg_hub/flow_runner.py +204 -0
- sdg_hub/flows/generation/knowledge/mmlu_bench.yaml +13 -0
- sdg_hub/flows/generation/knowledge/simple_knowledge.yaml +12 -0
- sdg_hub/flows/generation/knowledge/synth_knowledge.yaml +89 -0
- sdg_hub/flows/generation/knowledge/synth_knowledge1.5.yaml +136 -0
- sdg_hub/flows/generation/skills/improve_responses.yaml +103 -0
- sdg_hub/flows/generation/skills/simple_freeform_skill.yaml +12 -0
- sdg_hub/flows/generation/skills/simple_grounded_skill.yaml +12 -0
- sdg_hub/flows/generation/skills/synth_grounded_skills.yaml +80 -0
- sdg_hub/flows/generation/skills/synth_skills.yaml +59 -0
- sdg_hub/logger_config.py +20 -0
- sdg_hub/pipeline.py +121 -0
- sdg_hub/prompts.py +43 -0
- sdg_hub/py.typed +0 -0
- sdg_hub/registry.py +122 -0
- sdg_hub/sdg.py +206 -0
- sdg_hub/utils/__init__.py +5 -0
- sdg_hub/utils/datautils.py +14 -0
- sdg_hub-0.1.0.dist-info/METADATA +190 -0
- sdg_hub-0.1.0.dist-info/RECORD +82 -0
- sdg_hub-0.1.0.dist-info/WHEEL +5 -0
- sdg_hub-0.1.0.dist-info/licenses/LICENSE +201 -0
- sdg_hub-0.1.0.dist-info/top_level.txt +1 -0
sdg_hub/__init__.py
ADDED
sdg_hub/_version.py
ADDED
@@ -0,0 +1,21 @@
|
|
1
|
+
# file generated by setuptools-scm
|
2
|
+
# don't change, don't track in version control
|
3
|
+
|
4
|
+
# Version metadata written by setuptools-scm when the package is built;
# regenerate it through the build rather than editing by hand.
__all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]

# A local flag (rather than typing.TYPE_CHECKING) keeps the typing import out
# of the runtime path while static type checkers still take the True branch.
TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple, Union

    VERSION_TUPLE = Tuple[Union[int, str], ...]
else:
    VERSION_TUPLE = object

# The plain and dunder spellings are aliases of the same values.
version: str = '0.1.0'
__version__: str = version

version_tuple: VERSION_TUPLE = (0, 1, 0)
__version_tuple__: VERSION_TUPLE = version_tuple
|
@@ -0,0 +1,36 @@
|
|
1
|
+
"""Block implementations for SDG Hub.
|
2
|
+
|
3
|
+
This package provides various block implementations for data generation, processing, and transformation.
|
4
|
+
"""
|
5
|
+
|
6
|
+
# Local
|
7
|
+
from .block import Block
|
8
|
+
from .llmblock import LLMBlock, ConditionalLLMBlock
|
9
|
+
from .utilblocks import (
|
10
|
+
SamplePopulatorBlock,
|
11
|
+
SelectorBlock,
|
12
|
+
CombineColumnsBlock,
|
13
|
+
FlattenColumnsBlock,
|
14
|
+
DuplicateColumns,
|
15
|
+
RenameColumns,
|
16
|
+
SetToMajorityValue,
|
17
|
+
FilterByValueBlock,
|
18
|
+
IterBlock,
|
19
|
+
)
|
20
|
+
from ..registry import BlockRegistry
|
21
|
+
|
22
|
+
# Names exported by ``from sdg_hub.blocks import *``; each comment records
# which module the name is imported from above.
__all__ = [
    "Block",  # .block
    "FilterByValueBlock",  # .utilblocks
    "IterBlock",  # .utilblocks
    "LLMBlock",  # .llmblock
    "ConditionalLLMBlock",  # .llmblock
    "SamplePopulatorBlock",  # .utilblocks
    "SelectorBlock",  # .utilblocks
    "CombineColumnsBlock",  # .utilblocks
    "FlattenColumnsBlock",  # .utilblocks
    "DuplicateColumns",  # .utilblocks
    "RenameColumns",  # .utilblocks
    "SetToMajorityValue",  # .utilblocks
    "BlockRegistry",  # ..registry (re-exported)
]
|
sdg_hub/blocks/block.py
ADDED
@@ -0,0 +1,96 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
"""Base block implementation for the SDG Hub system.
|
3
|
+
|
4
|
+
This module provides the abstract base class for all blocks in the system,
|
5
|
+
including functionality for template validation and configuration management.
|
6
|
+
"""
|
7
|
+
|
8
|
+
# Standard
|
9
|
+
from abc import ABC
|
10
|
+
from collections import ChainMap
|
11
|
+
from typing import Any, Dict, Optional
|
12
|
+
|
13
|
+
# Third Party
|
14
|
+
from jinja2 import Template, UndefinedError
|
15
|
+
import yaml
|
16
|
+
|
17
|
+
# Local
|
18
|
+
from ..registry import BlockRegistry
|
19
|
+
from ..logger_config import setup_logger
|
20
|
+
|
21
|
+
logger = setup_logger(__name__)
|
22
|
+
|
23
|
+
|
24
|
+
@BlockRegistry.register("Block")
class Block(ABC):
    """Shared foundation for every block in SDG Hub.

    Concrete blocks subclass this to inherit two helpers: ``_validate``, which
    test-renders a sample through a Jinja2 prompt template, and
    ``_load_config``, which reads a block's YAML configuration file.
    """

    def __init__(self, block_name: str) -> None:
        """Remember the block's name.

        Parameters
        ----------
        block_name : str
            Identifier for this block instance, stored as ``self.block_name``.
        """
        self.block_name = block_name

    @staticmethod
    def _validate(prompt_template: Template, input_dict: Dict[str, Any]) -> bool:
        """Check that ``input_dict`` can be rendered through ``prompt_template``.

        Parameters
        ----------
        prompt_template : Template
            Compiled Jinja2 template the sample will later be rendered with.
        input_dict : Dict[str, Any]
            Sample whose values feed the template's variables.

        Returns
        -------
        bool
            False (after logging the error) when rendering raises
            ``jinja2.UndefinedError``; True when rendering succeeds. Any other
            exception propagates to the caller.

        Notes
        -----
        NOTE(review): ``Template.render`` presumably copies its argument into a
        plain dict, which would bypass the ``__missing__`` hook below, and with
        Jinja2's default ``Undefined`` a missing top-level variable renders as
        an empty string rather than raising. If so, only samples that trigger
        an ``UndefinedError`` some other way (e.g. attribute access on an
        undefined value) are rejected — confirm before relying on this to catch
        missing keys.
        """

        class _RaiseOnMissing(dict):
            # Fallback layer of the lookup chain: looking a key up here raises.
            def __missing__(self, key: str) -> None:
                raise KeyError(key)

        context = ChainMap(input_dict, _RaiseOnMissing())
        try:
            prompt_template.render(context)
        except UndefinedError as err:
            logger.error(f"Missing key: {err}")
            return False
        return True

    def _load_config(self, config_path: str) -> Optional[Dict[str, Any]]:
        """Read and parse this block's YAML configuration file.

        Parameters
        ----------
        config_path : str
            Filesystem path of the YAML file.

        Returns
        -------
        Optional[Dict[str, Any]]
            Whatever ``yaml.safe_load`` produces for the file, or None when the
            YAML is malformed or any error other than a missing file occurs
            (the error is logged in both cases).

        Raises
        ------
        FileNotFoundError
            Re-raised, after logging, when ``config_path`` does not exist.
        """
        try:
            with open(config_path, "r", encoding="utf-8") as stream:
                return yaml.safe_load(stream)
        except yaml.YAMLError as err:
            logger.error(f"Error parsing YAML from {config_path}: {err}")
        except FileNotFoundError:
            logger.error(f"Configuration file not found: {config_path}")
            raise
        except Exception as err:  # pylint: disable=broad-except
            logger.error(f"Unexpected error reading config file {config_path}: {err}")
        return None
|
@@ -0,0 +1,375 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
"""LLM-based blocks for text generation and processing.
|
3
|
+
|
4
|
+
This module provides blocks for interacting with language models.
|
5
|
+
"""
|
6
|
+
|
7
|
+
# Standard
|
8
|
+
from typing import Any, Dict, List, Optional, Union
|
9
|
+
import json
|
10
|
+
import re
|
11
|
+
|
12
|
+
# Third Party
|
13
|
+
from datasets import Dataset
|
14
|
+
from jinja2 import Template
|
15
|
+
import openai
|
16
|
+
|
17
|
+
# Local
|
18
|
+
from .block import Block
|
19
|
+
from ..logger_config import setup_logger
|
20
|
+
from ..registry import BlockRegistry, PromptRegistry
|
21
|
+
|
22
|
+
logger = setup_logger(__name__)
|
23
|
+
|
24
|
+
|
25
|
+
def server_supports_batched(client: openai.OpenAI, model_id: str) -> bool:
    """Check whether the server accepts batched prompts and the ``n`` parameter.

    A single probe request (two prompts, ``n=3``, one token each) is sent, and
    the answer is cached on the client as ``client.server_supports_batched``.
    Later calls with the same client return the cached value without
    contacting the server.

    Parameters
    ----------
    client : openai.OpenAI
        The client to use to make the test call.
    model_id : str
        The model ID to use for the test call.

    Returns
    -------
    bool
        True if the probe returned 2 * 3 = 6 choices. False if it returned a
        different number, or if the server rejected the probe with a 400, 422
        or 5xx error.
    """
    supported = getattr(client, "server_supports_batched", None)
    if supported is not None:
        return supported
    try:
        # Make a test call to the server to determine whether it supports
        # multiple input prompts per request and also the n parameter
        response = client.completions.create(
            model=model_id, prompt=["test1", "test2"], max_tokens=1, n=3
        )
        # Number outputs should be 2 * 3 = 6
        supported = len(response.choices) == 6
    except (
        openai.BadRequestError,
        openai.UnprocessableEntityError,
        openai.InternalServerError,
    ):
        # Servers that cannot handle list prompts or ``n`` may reject the probe
        # with a client-side validation error (400/422) rather than a 5xx.
        # Either way, fall back to one request per output instead of failing
        # block construction. Genuine problems (bad model, auth) still surface
        # on the first real request.
        supported = False
    client.server_supports_batched = supported
    logger.info(f"LLM server supports batched inputs: {client.server_supports_batched}")
    return supported
|
53
|
+
|
54
|
+
|
55
|
+
@BlockRegistry.register("LLMBlock")
class LLMBlock(Block):
    """Block for generating text using language models.

    Each input sample is rendered through a Jinja2 prompt assembled from the
    block's YAML config. The prompts are sent to the client's completions
    endpoint, and each generated text is parsed into the ``output_cols``
    columns.

    Parameters
    ----------
    block_name : str
        Name of the block.
    config_path : str
        Path to the YAML configuration file. Its ``system``, ``introduction``,
        ``principles``, ``examples`` and ``generation`` entries are joined to
        form the prompt template. Optional ``start_tags``/``end_tags`` drive
        tag-based parsing, and ``num_samples`` (if present) is added as a
        column.
    client : openai.OpenAI
        OpenAI client instance.
    output_cols : List[str]
        List of output column names.
    parser_kwargs : Optional[Dict[str, Any]], optional
        Parser options: ``parser_name`` (``"custom"`` enables regex parsing),
        ``parsing_pattern`` and ``parser_cleanup_tags``. Defaults to no options.
    model_prompt : str, optional
        Template string for model prompt, by default "{prompt}".
    model_id : Optional[str], optional
        Model ID to use. When omitted, the first model listed by the server is
        used.
    **batch_kwargs : Dict[str, Any]
        Additional keyword arguments. A nested ``batch_kwargs`` entry is
        stored as ``self.batch_params``.

    Raises
    ------
    ValueError
        If the configuration file cannot be loaded (``_load_config`` returned
        None).
    """

    # pylint: disable=too-many-instance-attributes
    def __init__(
        self,
        block_name: str,
        config_path: str,
        client: openai.OpenAI,
        output_cols: List[str],
        parser_kwargs: Optional[Dict[str, Any]] = None,
        model_prompt: str = "{prompt}",
        model_id: Optional[str] = None,
        **batch_kwargs: Dict[str, Any],
    ) -> None:
        super().__init__(block_name)
        # None default instead of a shared mutable ``{}`` default.
        parser_kwargs = parser_kwargs or {}

        self.block_config = self._load_config(config_path)
        if self.block_config is None:
            # _load_config logs and returns None on malformed/unreadable YAML;
            # fail here with a clear message rather than an AttributeError below.
            raise ValueError(f"Could not load block config from {config_path}")

        self.prompt_struct = (
            """{system}\n{introduction}\n{principles}\n{examples}\n{generation}"""
        )
        # Config values that are None render as "" instead of the text "None".
        filtered_config = {
            k: (v if v is not None else "") for k, v in self.block_config.items()
        }
        self.prompt_template = Template(self.prompt_struct.format(**filtered_config))
        self.client = client
        if model_id:
            self.model = model_id
        else:
            # get the default model id from client
            self.model = self.client.models.list().data[0].id

        self.model_prompt = model_prompt
        self.output_cols = output_cols
        self.batch_params = batch_kwargs.get("batch_kwargs", {})
        self.parser_name = parser_kwargs.get("parser_name", None)
        self.parsing_pattern = parser_kwargs.get("parsing_pattern", None)
        self.parser_cleanup_tags = parser_kwargs.get("parser_cleanup_tags", None)
        self.defaults = {
            "model": self.model,
            "temperature": 0,
            "max_tokens": 4096,
        }

        # Whether the LLM server supports a list of input prompts
        # and supports the n parameter to generate n outputs per input
        self.server_supports_batched = server_supports_batched(client, self.model)

    def _extract_matches(
        self, text: str, start_tag: Optional[str], end_tag: Optional[str]
    ) -> List[str]:
        """Return the stripped spans of ``text`` delimited by the given tags.

        With neither tag, the whole stripped text is the single match. With
        only ``start_tag``, each match runs from that tag to the end of the
        string. With only ``end_tag``, each match ends at that tag. Empty
        ``text`` yields no matches.
        """
        if not text:
            return []
        if not start_tag and not end_tag:
            return [text.strip()]

        pattern = ""
        if start_tag:
            pattern += re.escape(start_tag)
        pattern += r"(.*?)"
        if end_tag:
            pattern += re.escape(end_tag)
        elif start_tag:
            # Enforce matching till end of string when only start_tag is provided.
            pattern += "$"

        return [match.strip() for match in re.findall(pattern, text, re.DOTALL)]

    def _parse(self, generated_string: str) -> dict:
        """Split one generated text into a mapping of output column -> matches.

        With ``parser_name == "custom"``, ``parsing_pattern`` is applied with
        ``re.DOTALL``. Tuple matches (multiple groups) fill ``output_cols`` in
        order, with each ``parser_cleanup_tags`` string removed. Single-group
        matches go to the first output column. Otherwise, each output column is
        paired with the config's ``start_tags``/``end_tags`` at the same index.
        """
        matches = {}

        if self.parser_name == "custom":
            pattern = re.compile(self.parsing_pattern, re.DOTALL)
            all_matches = pattern.findall(generated_string)
            matches = {column_name: [] for column_name in self.output_cols}
            # parser_cleanup_tags is optional in parser_kwargs (defaults to None).
            cleanup_tags = self.parser_cleanup_tags or []
            if all_matches and isinstance(all_matches[0], tuple):
                for match in all_matches:
                    for column_name, value in zip(self.output_cols, match):
                        value = value.strip()
                        for clean_tag in cleanup_tags:
                            value = value.replace(clean_tag, "")
                        matches[column_name].append(value)
            else:
                matches[self.output_cols[0]] = [match.strip() for match in all_matches]
        else:
            for start_tag, end_tag, output_col in zip(
                self.block_config.get("start_tags", []),
                self.block_config.get("end_tags", []),
                self.output_cols,
            ):
                matches[output_col] = self._extract_matches(
                    generated_string, start_tag, end_tag
                )

        return matches

    def _format_prompt(self, sample: Dict) -> str:
        """Render ``sample`` through the prompt template, then wrap it with ``model_prompt``."""
        prompt_templated_str = self.prompt_template.render(sample).strip()
        return PromptRegistry.render_template(
            self.model_prompt, prompt_templated_str, add_generation_prompt=True
        ).strip()

    def _generate(self, samples: Dataset, **gen_kwargs: Dict[str, Any]) -> list:
        """Return a flat list of completions, ``n`` consecutive outputs per sample.

        ``gen_kwargs`` override ``self.defaults`` and are forwarded to
        ``completions.create``. ``n`` defaults to 1. On the batched path the
        order of the outputs is the order of ``response.choices``. This
        presumes the server groups choices per prompt in input order — confirm
        for non-vLLM servers.
        """
        prompts = [self._format_prompt(sample) for sample in samples]
        if not prompts:
            return []
        logger.debug("Prompt: %s", prompts[0])
        generate_args = {**self.defaults, **gen_kwargs}

        # if stop is provided, then we need to add the stop token to the generated text,
        # this is because the stop token is not included in the generated text - this is a limitation of the openai api
        # we need to add the stop token to the generated text to make it consistent for the parser
        stop = generate_args.get("stop")
        stop_suffix = "".join(stop) if stop else ""

        if self.server_supports_batched:
            response = self.client.completions.create(prompt=prompts, **generate_args)
            return [choice.text.strip() + stop_suffix for choice in response.choices]

        # Fallback: one request per output. The n outputs per prompt come from
        # repeating the request, so ``n`` is not forwarded. Forwarding it would
        # request n*n completions, or be rejected by servers without ``n``.
        n = gen_kwargs.get("n", 1)
        single_args = {k: v for k, v in generate_args.items() if k != "n"}
        results = []
        for prompt in prompts:
            for _ in range(n):
                response = self.client.completions.create(
                    prompt=prompt, **single_args
                )
                results.append(response.choices[0].text.strip() + stop_suffix)
        return results

    def generate(self, samples: Dataset, **gen_kwargs: Dict[str, Any]) -> Dataset:
        """Generate the output from the block.

        This method first validates the input data, then generates the output,
        and finally parses the generated output before returning it.

        Parameters
        ----------
        samples : Dataset
            Input rows; each must supply the prompt template's variables.
        **gen_kwargs : Dict[str, Any]
            Forwarded to the completions request. ``n`` also sets how many
            outputs are produced per input row.

        Returns
        -------
        Dataset
            One row per parsed match. Each row holds the input sample's columns
            plus the output columns. An empty Dataset is returned if no sample
            passes validation.
        """
        num_samples = self.block_config.get("num_samples", None)
        logger.debug("Generating outputs for %s samples", len(samples))

        if (num_samples is not None) and ("num_samples" not in samples.column_names):
            samples = samples.add_column("num_samples", [num_samples] * len(samples))

        # validate each sample
        # Log errors and remove invalid samples
        valid_samples = []
        for sample in samples:
            if self._validate(self.prompt_template, sample):
                valid_samples.append(sample)
            else:
                # Log details of the failed sample
                logger.warning(f"Sample failed validation: {sample}")
        samples = valid_samples

        if not samples:
            logger.warning(
                "No valid samples to generate outputs for, returning empty dataset"
            )
            return Dataset.from_list([])

        # generate the output
        outputs = self._generate(samples, **gen_kwargs)
        logger.debug("Generated outputs: %s", outputs)

        # Duplicate each input sample n times, where n is the number
        # of output sequences generated per input, so that we can
        # pair up the inputs and outputs.
        num_parallel_samples = gen_kwargs.get("n", 1)
        extended_samples = [
            item for item in samples for _ in range(num_parallel_samples)
        ]

        new_data = []
        for sample, output in zip(extended_samples, outputs):
            parsed_outputs = self._parse(output)
            # zip pairs the i-th match of every output column and stops at the
            # shortest column. With no parsed columns it yields no rows, where
            # the old max() call raised ValueError.
            for values in zip(*parsed_outputs.values()):
                new_data.append({**sample, **dict(zip(parsed_outputs.keys(), values))})

        return Dataset.from_list(new_data)
|
274
|
+
|
275
|
+
|
276
|
+
@BlockRegistry.register("ConditionalLLMBlock")
class ConditionalLLMBlock(LLMBlock):
    """Block for conditional text generation using language models.

    Each sample's value in ``selector_column_name`` chooses which prompt
    template renders it. If ``config_paths`` contains the key ``"All"``, that
    single config's template is used for every sample instead.

    Parameters
    ----------
    block_name : str
        Name of the block.
    config_paths : Dict[str, str]
        Dictionary mapping selector values to their config file paths. When an
        ``"All"`` key is present, that config also supplies block-wide
        settings (parsing tags, ``num_samples``). Otherwise the first config
        supplies them.
    client : openai.OpenAI
        OpenAI client instance.
    model_id : str
        Model ID to use.
    output_cols : List[str]
        List of output column names.
    selector_column_name : str
        Name of the column used to select the prompt template.
    model_prompt : str, optional
        Template string for model prompt, by default "{prompt}". It is
        forwarded to LLMBlock. NOTE(review): this class's ``_format_prompt``
        does not apply it — confirm whether that is intended.
    **batch_kwargs : Dict[str, Any]
        Additional keyword arguments for batch processing.

    Raises
    ------
    ValueError
        If a per-selector config file cannot be loaded.
    """

    def __init__(
        self,
        block_name: str,
        config_paths: Dict[str, str],
        client: openai.OpenAI,
        model_id: str,
        output_cols: List[str],
        selector_column_name: str,
        model_prompt: str = "{prompt}",
        **batch_kwargs: Dict[str, Any],
    ) -> None:
        # The config loaded by LLMBlock governs block-wide settings: use "All"
        # when present (it is the one applied to every sample), else the first.
        if "All" in config_paths:
            base_config_path = config_paths["All"]
        else:
            base_config_path = list(config_paths.values())[0]
        super().__init__(
            block_name=block_name,
            config_path=base_config_path,
            client=client,
            model_id=model_id,
            output_cols=output_cols,
            model_prompt=model_prompt,
            **batch_kwargs,
        )
        self.selector_column_name = selector_column_name

        if "All" not in config_paths:
            # One compiled Template per selector value.
            templates: Dict[str, Template] = {}
            for config_key, config_path in config_paths.items():
                config = self._load_config(config_path)
                if config is None:
                    raise ValueError(f"Could not load block config from {config_path}")
                # Config values that are None render as "" instead of "None".
                filtered_config = {
                    k: (v if v is not None else "") for k, v in config.items()
                }
                templates[config_key] = Template(
                    self.prompt_struct.format(**filtered_config)
                )
            self.prompt_template = templates
        # Otherwise self.prompt_template is the Template that LLMBlock.__init__
        # compiled from the "All" config.

    def _format_prompt(self, sample: Dict[str, Any]) -> str:
        """Format the prompt based on the selector column value.

        Parameters
        ----------
        sample : Dict[str, Any]
            Input sample containing the selector column.

        Returns
        -------
        str
            Formatted prompt string.

        Raises
        ------
        KeyError
            If the sample's selector value has no template. Samples filtered
            through ``_validate`` never reach this.
        """
        if isinstance(self.prompt_template, dict):
            template = self.prompt_template[sample[self.selector_column_name]]
            return template.render(**sample).strip()

        return self.prompt_template.render(**sample).strip()

    def _validate(
        self,
        prompt_template: Union[str, Template, Dict[str, Template]],
        input_dict: Dict[str, Any],
    ) -> bool:
        """Validate the input data for this block.

        Parameters
        ----------
        prompt_template : Union[str, Template, Dict[str, Template]]
            The template to validate against, or a mapping from selector value
            to template.
        input_dict : Dict[str, Any]
            Input data to validate.

        Returns
        -------
        bool
            True if the input data is valid, False otherwise. A sample whose
            selector value is missing or has no template is invalid.
        """
        if isinstance(prompt_template, dict):
            selector_value = input_dict.get(self.selector_column_name)
            if selector_value not in prompt_template:
                logger.error(
                    f"No prompt template for {self.selector_column_name}="
                    f"{selector_value!r}; expected one of {list(prompt_template)}"
                )
                return False
            prompt_template = prompt_template[selector_value]
        return super()._validate(prompt_template, input_dict)
|