sdg-hub 0.1.0a2.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sdg_hub/__init__.py +4 -0
- sdg_hub/_version.py +21 -0
- sdg_hub/blocks/__init__.py +6 -0
- sdg_hub/blocks/block.py +54 -0
- sdg_hub/blocks/filterblock.py +76 -0
- sdg_hub/blocks/iterblock.py +31 -0
- sdg_hub/blocks/llmblock.py +430 -0
- sdg_hub/blocks/rmblocks.py +194 -0
- sdg_hub/blocks/utilblocks.py +140 -0
- sdg_hub/configs/__init__.py +0 -0
- sdg_hub/configs/annotations/__init__.py +0 -0
- sdg_hub/configs/annotations/cot_reflection.yaml +34 -0
- sdg_hub/configs/annotations/detailed_description.yaml +10 -0
- sdg_hub/configs/annotations/detailed_description_icl.yaml +32 -0
- sdg_hub/configs/annotations/simple.yaml +10 -0
- sdg_hub/configs/knowledge/__init__.py +0 -0
- sdg_hub/configs/knowledge/atomic_facts.yaml +45 -0
- sdg_hub/configs/knowledge/auxilary_instructions.yaml +35 -0
- sdg_hub/configs/knowledge/data_recipe/__init__.py +0 -0
- sdg_hub/configs/knowledge/data_recipe/default_recipe.yaml +3 -0
- sdg_hub/configs/knowledge/detailed_summary.yaml +17 -0
- sdg_hub/configs/knowledge/evaluate_faithfulness.yaml +68 -0
- sdg_hub/configs/knowledge/evaluate_question.yaml +38 -0
- sdg_hub/configs/knowledge/evaluate_relevancy.yaml +85 -0
- sdg_hub/configs/knowledge/extractive_summary.yaml +17 -0
- sdg_hub/configs/knowledge/generate_code_questions_responses.yaml +39 -0
- sdg_hub/configs/knowledge/generate_questions_responses.yaml +56 -0
- sdg_hub/configs/knowledge/mcq_generation.yaml +83 -0
- sdg_hub/configs/knowledge/router.yaml +12 -0
- sdg_hub/configs/knowledge/simple_generate_qa.yaml +34 -0
- sdg_hub/configs/reasoning/dynamic_cot.yaml +40 -0
- sdg_hub/configs/skills/_A_.yaml +97 -0
- sdg_hub/configs/skills/_B_.yaml +36 -0
- sdg_hub/configs/skills/_C_.yaml +71 -0
- sdg_hub/configs/skills/_D_.yaml +85 -0
- sdg_hub/configs/skills/_E_.yaml +30 -0
- sdg_hub/configs/skills/_F_.yaml +45 -0
- sdg_hub/configs/skills/_G_.yaml +56 -0
- sdg_hub/configs/skills/_H_.yaml +80 -0
- sdg_hub/configs/skills/__init__.py +0 -0
- sdg_hub/configs/skills/analyzer.yaml +48 -0
- sdg_hub/configs/skills/annotation.yaml +36 -0
- sdg_hub/configs/skills/contexts.yaml +21 -0
- sdg_hub/configs/skills/critic.yaml +60 -0
- sdg_hub/configs/skills/data_recipe/__init__.py +0 -0
- sdg_hub/configs/skills/data_recipe/default_recipe.yaml +6 -0
- sdg_hub/configs/skills/evaluate_freeform_pair.yaml +44 -0
- sdg_hub/configs/skills/evaluate_freeform_questions.yaml +46 -0
- sdg_hub/configs/skills/evaluate_grounded_pair.yaml +54 -0
- sdg_hub/configs/skills/evaluate_grounded_questions.yaml +51 -0
- sdg_hub/configs/skills/freeform_questions.yaml +29 -0
- sdg_hub/configs/skills/freeform_responses.yaml +45 -0
- sdg_hub/configs/skills/grounded_questions.yaml +38 -0
- sdg_hub/configs/skills/grounded_responses.yaml +59 -0
- sdg_hub/configs/skills/judge.yaml +53 -0
- sdg_hub/configs/skills/planner.yaml +67 -0
- sdg_hub/configs/skills/respond.yaml +8 -0
- sdg_hub/configs/skills/revised_responder.yaml +78 -0
- sdg_hub/configs/skills/router.yaml +12 -0
- sdg_hub/configs/skills/simple_generate_qa_freeform.yaml +27 -0
- sdg_hub/configs/skills/simple_generate_qa_grounded.yaml +31 -0
- sdg_hub/flow.py +127 -0
- sdg_hub/flows/annotation/emotion/detailed_description.yaml +19 -0
- sdg_hub/flows/annotation/emotion/detailed_description_icl.yaml +19 -0
- sdg_hub/flows/annotation/emotion/simple.yaml +19 -0
- sdg_hub/flows/generation/knowledge/mmlu_bench.yaml +13 -0
- sdg_hub/flows/generation/knowledge/simple_knowledge.yaml +12 -0
- sdg_hub/flows/generation/knowledge/synth_knowledge.yaml +89 -0
- sdg_hub/flows/generation/knowledge/synth_knowledge1.5.yaml +136 -0
- sdg_hub/flows/generation/skills/agentic_improve_skill.yaml +108 -0
- sdg_hub/flows/generation/skills/simple_freeform_skill.yaml +12 -0
- sdg_hub/flows/generation/skills/simple_grounded_skill.yaml +12 -0
- sdg_hub/flows/generation/skills/synth_grounded_skills.yaml +80 -0
- sdg_hub/flows/generation/skills/synth_skills.yaml +59 -0
- sdg_hub/logger_config.py +20 -0
- sdg_hub/pipeline.py +66 -0
- sdg_hub/prompts.py +17 -0
- sdg_hub/py.typed +0 -0
- sdg_hub/registry.py +122 -0
- sdg_hub/sdg.py +164 -0
- sdg_hub/utils/__init__.py +5 -0
- sdg_hub/utils/chunking.py +73 -0
- sdg_hub/utils/datamixing.py +123 -0
- sdg_hub/utils/datautils.py +14 -0
- sdg_hub/utils/docprocessor.py +357 -0
- sdg_hub/utils/json.py +48 -0
- sdg_hub/utils/models.py +31 -0
- sdg_hub/utils/parse_and_convert.py +392 -0
- sdg_hub/utils/taxonomy.py +489 -0
- sdg_hub-0.1.0a2.dev0.dist-info/METADATA +154 -0
- sdg_hub-0.1.0a2.dev0.dist-info/RECORD +94 -0
- sdg_hub-0.1.0a2.dev0.dist-info/WHEEL +5 -0
- sdg_hub-0.1.0a2.dev0.dist-info/licenses/LICENSE +201 -0
- sdg_hub-0.1.0a2.dev0.dist-info/top_level.txt +1 -0
sdg_hub/__init__.py
ADDED
sdg_hub/_version.py
ADDED
@@ -0,0 +1,21 @@
+# file generated by setuptools-scm
+# don't change, don't track in version control
+
+__all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]
+
+TYPE_CHECKING = False
+if TYPE_CHECKING:
+    from typing import Tuple
+    from typing import Union
+
+    VERSION_TUPLE = Tuple[Union[int, str], ...]
+else:
+    VERSION_TUPLE = object
+
+version: str
+__version__: str
+__version_tuple__: VERSION_TUPLE
+version_tuple: VERSION_TUPLE
+
+__version__ = version = '0.1.0a2.dev0'
+__version_tuple__ = version_tuple = (0, 1, 0, 'dev0')
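Since `_version.py` assigns both forms at import time, the recorded version can be checked directly; a minimal sketch, assuming this wheel is installed in the current environment:

# Values below are taken from the _version.py diff above.
from sdg_hub._version import __version__, __version_tuple__

assert __version__ == "0.1.0a2.dev0"
assert __version_tuple__ == (0, 1, 0, "dev0")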
sdg_hub/blocks/block.py
ADDED
@@ -0,0 +1,54 @@
+# SPDX-License-Identifier: Apache-2.0
+# Standard
+from abc import ABC
+from collections import ChainMap
+from typing import Any, Dict, Union
+
+# Third Party
+from jinja2 import Template, UndefinedError
+import yaml
+
+# Local
+from ..registry import BlockRegistry
+from ..logger_config import setup_logger
+
+logger = setup_logger(__name__)
+
+
+@BlockRegistry.register("Block")
+class Block(ABC):
+    def __init__(self, block_name: str) -> None:
+        self.block_name = block_name
+
+    @staticmethod
+    def _validate(prompt_template: Template, input_dict: Dict[str, Any]) -> bool:
+        """
+        Validate the input data for this block. This method validates whether all required
+        variables in the Jinja template are provided in the input_dict.
+
+        :param prompt_template: The Jinja2 template object.
+        :param input_dict: A dictionary of input values to check against the template.
+        :return: True if the input data is valid (i.e., no missing variables), False otherwise.
+        """
+
+        class Default(dict):
+            def __missing__(self, key: str) -> None:
+                raise KeyError(key)
+
+        try:
+            # Try rendering the template with the input_dict
+            prompt_template.render(ChainMap(input_dict, Default()))
+            return True
+        except UndefinedError as e:
+            logger.error(f"Missing key: {e}")
+            return False
+
+    def _load_config(self, config_path: str) -> Union[Dict[str, Any], None]:
+        """
+        Load the configuration file for this block.
+
+        :param config_path: The path to the configuration file.
+        :return: The loaded configuration.
+        """
+        with open(config_path, "r", encoding="utf-8") as config_file:
+            return yaml.safe_load(config_file)
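To make the validation contract concrete, a minimal sketch (the template text and field names are hypothetical, not from the package). One caveat worth knowing: with Jinja's default `Undefined`, a bare `{{ var }}` renders as an empty string, so `_validate` only reports False when the template actually operates on the missing value:

from jinja2 import Template

from sdg_hub.blocks.block import Block

# Hypothetical template: "document" is operated on via .strip(), so a
# missing key raises jinja2.UndefinedError during rendering.
tmpl = Template("Summarize the following:\n{{ document.strip() }}")

print(Block._validate(tmpl, {"document": "  some passage  "}))  # True
print(Block._validate(tmpl, {}))  # False, after logging "Missing key: ..."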
sdg_hub/blocks/filterblock.py
ADDED
@@ -0,0 +1,76 @@
+# SPDX-License-Identifier: Apache-2.0
+# Standard
+import operator
+
+# Third Party
+from datasets import Dataset
+
+# Local
+from .block import Block
+from ..registry import BlockRegistry
+from ..logger_config import setup_logger
+
+logger = setup_logger(__name__)
+
+
+@BlockRegistry.register("FilterByValueBlock")
+class FilterByValueBlock(Block):
+    def __init__(
+        self, filter_column, filter_value, operation, convert_dtype=None, **batch_kwargs
+    ) -> None:
+        """
+        Initializes a new instance of the FilterByValueBlock class.
+
+        Parameters:
+        - filter_column (str): The name of the column in the dataset to apply the filter on.
+        - filter_value (any or list of any): The value(s) to filter by.
+        - operation (callable): A function that takes two arguments (column value and filter value) and returns a boolean indicating whether the row should be included in the filtered dataset.
+        - convert_dtype (callable, optional): A function to convert the data type of the filter column before applying the filter. Defaults to None.
+        - **batch_kwargs: Additional kwargs for batch processing.
+
+        Returns:
+        None
+        """
+        super().__init__(block_name=self.__class__.__name__)
+        self.value = filter_value if isinstance(filter_value, list) else [filter_value]
+        self.column_name = filter_column
+        self.operation = operation
+        self.convert_dtype = convert_dtype
+        self.num_procs = batch_kwargs.get("num_procs", 1)
+
+    def _convert_dtype(self, sample):
+        try:
+            sample[self.column_name] = self.convert_dtype(sample[self.column_name])
+        except ValueError as e:
+            logger.error(
+                "Error converting dtype: %s, filling with None to be filtered later", e
+            )
+            sample[self.column_name] = None
+        return sample
+
+    def generate(self, samples) -> Dataset:
+        if self.convert_dtype:
+            samples = samples.map(
+                self._convert_dtype,
+                num_proc=self.num_procs,
+            )
+
+        if self.operation == operator.contains:
+            samples = samples.filter(
+                lambda x: self.operation(self.value, x[self.column_name]),
+                num_proc=self.num_procs,
+            )
+
+        samples = samples.filter(
+            lambda x: x[self.column_name] is not None,
+            num_proc=self.num_procs,
+        )
+
+        samples = samples.filter(
+            lambda x: any(
+                self.operation(x[self.column_name], value) for value in self.value
+            ),
+            num_proc=self.num_procs,
+        )
+
+        return samples
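A minimal usage sketch for the block above, assuming `datasets` is installed; the column names and values are illustrative only:

import operator

from datasets import Dataset

from sdg_hub.blocks.filterblock import FilterByValueBlock

ds = Dataset.from_list(
    [
        {"question": "q1", "score": "2.0"},
        {"question": "q2", "score": "1.0"},
        {"question": "q3", "score": "bad"},  # fails float() -> None -> dropped
    ]
)

# Keep rows whose score equals 2.0, converting the string column first.
block = FilterByValueBlock(
    filter_column="score",
    filter_value=2.0,
    operation=operator.eq,
    convert_dtype=float,
)
print(block.generate(ds)["question"])  # ['q1']

Note that `generate` always applies the final `any(...)` filter, so for `operator.contains` the dataset is filtered twice: first with the arguments ordered as `(filter_values, column_value)` and then as `(column_value, filter_value)`.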
sdg_hub/blocks/iterblock.py
ADDED
@@ -0,0 +1,31 @@
+# Third Party
+from datasets import Dataset
+
+# Local
+from .block import Block
+from ..registry import BlockRegistry
+from ..logger_config import setup_logger
+
+logger = setup_logger(__name__)
+
+
+@BlockRegistry.register("IterBlock")
+class IterBlock(Block):
+    def __init__(self, block_name, num_iters, block_type, block_kwargs, **kwargs):
+        super().__init__(block_name)
+        self.num_iters = num_iters
+        self.block = block_type(**block_kwargs)
+        self.gen_kwargs = kwargs.get("gen_kwargs", {})
+        self.gen_kwargs = kwargs.get("gen_kwargs", {})
+
+    def generate(self, samples, **gen_kwargs) -> Dataset:
+        generated_samples = []
+        num_iters = self.num_iters
+
+        for _ in range(num_iters):
+            batch_generated = self.block.generate(
+                samples, **{**self.gen_kwargs, **gen_kwargs}
+            )
+            generated_samples.extend(batch_generated)
+
+        return Dataset.from_list(generated_samples)
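`IterBlock` re-runs a wrapped block `num_iters` times and concatenates the resulting rows (the duplicated `self.gen_kwargs` assignment in `__init__` is redundant but harmless). A minimal sketch with a stub inner block; `EchoBlock` is hypothetical, not part of the package:

from datasets import Dataset

from sdg_hub.blocks.iterblock import IterBlock


class EchoBlock:
    """Stub block: returns its input rows unchanged."""

    def generate(self, samples, **gen_kwargs):
        return samples


ds = Dataset.from_list([{"text": "a"}, {"text": "b"}])
block = IterBlock(
    block_name="iter_demo", num_iters=3, block_type=EchoBlock, block_kwargs={}
)
print(len(block.generate(ds)))  # 6 rows: 3 iterations x 2 inputs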
sdg_hub/blocks/llmblock.py
ADDED
@@ -0,0 +1,430 @@
+# SPDX-License-Identifier: Apache-2.0
+# Standard
+from collections import Counter
+from typing import Any, Dict, List
+import json
+import re
+
+# Third Party
+from datasets import Dataset
+from jinja2 import Template
+import openai
+
+# Local
+from .block import Block
+from ..logger_config import setup_logger
+from ..registry import BlockRegistry, PromptRegistry
+
+logger = setup_logger(__name__)
+
+
+def server_supports_batched(client, model_id: str) -> bool:
+    supported = getattr(client, "server_supports_batched", None)
+    if supported is not None:
+        return supported
+    try:
+        # Make a test call to the server to determine whether it supports
+        # multiple input prompts per request and also the n parameter
+        response = client.completions.create(
+            model=model_id, prompt=["test1", "test2"], max_tokens=1, n=3
+        )
+        # Number outputs should be 2 * 3 = 6
+        supported = len(response.choices) == 6
+    except openai.InternalServerError:
+        supported = False
+    client.server_supports_batched = supported
+    logger.info(f"LLM server supports batched inputs: {client.server_supports_batched}")
+    return supported
+
+
+@BlockRegistry.register("LLMBlock")
+# pylint: disable=dangerous-default-value
+class LLMBlock(Block):
+    # pylint: disable=too-many-instance-attributes
+    def __init__(
+        self,
+        block_name,
+        config_path,
+        client,
+        output_cols,
+        parser_kwargs={},
+        model_prompt="{prompt}",
+        model_id=None,
+        **batch_kwargs,
+    ) -> None:
+        super().__init__(block_name)
+        self.block_config = self._load_config(config_path)
+        self.prompt_struct = (
+            """{system}\n{introduction}\n{principles}\n{examples}\n{generation}"""
+        )
+        filtered_config = {
+            k: (v if v is not None else "") for k, v in self.block_config.items()
+        }
+        self.prompt_template = Template(self.prompt_struct.format(**filtered_config))
+        self.client = client
+        if model_id:
+            self.model = model_id
+        else:
+            # get the default model id from client
+            self.model = self.client.models.list().data[0].id
+
+        self.model_prompt = model_prompt
+        self.output_cols = output_cols
+        self.batch_params = batch_kwargs.get("batch_kwargs", {})
+        self.parser_name = parser_kwargs.get("parser_name", None)
+        self.parsing_pattern = parser_kwargs.get("parsing_pattern", None)
+        self.parser_cleanup_tags = parser_kwargs.get("parser_cleanup_tags", None)
+        self.defaults = {
+            "model": self.model,
+            "temperature": 0,
+            "max_tokens": 4096,
+        }
+
+        # Whether the LLM server supports a list of input prompts
+        # and supports the n parameter to generate n outputs per input
+        self.server_supports_batched = server_supports_batched(client, self.model)
+
+    def _parse(self, generated_string) -> dict:
+        matches = {}
+
+        if self.parser_name is not None and self.parser_name == "custom":
+            pattern = re.compile(self.parsing_pattern, re.DOTALL)
+            all_matches = pattern.findall(generated_string)
+            matches = {column_name: [] for column_name in self.output_cols}
+            if all_matches and isinstance(all_matches[0], tuple):
+                for match in all_matches:
+                    for column_name, value in zip(self.output_cols, match):
+                        value = value.strip()
+                        for clean_tag in self.parser_cleanup_tags:
+                            value = value.replace(clean_tag, "")
+                        matches[column_name].append(value)
+            else:
+                matches[self.output_cols[0]] = (
+                    [match.strip() for match in all_matches] if all_matches else []
+                )
+        else:
+            for start_tag, end_tag, output_col in zip(
+                self.block_config.get("start_tags", []),
+                self.block_config.get("end_tags", []),
+                self.output_cols,
+            ):
+                if not start_tag and not end_tag:
+                    matches[output_col] = [
+                        generated_string.strip() if generated_string else None
+                    ]
+                else:
+                    pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag)
+                    all_matches = re.findall(pattern, generated_string, re.DOTALL)
+                    matches[output_col] = (
+                        [match.strip() for match in all_matches] if all_matches else []
+                    )
+
+        return matches
+
+    def _format_prompt(self, sample: Dict) -> str:
+        prompt_templated_str = self.prompt_template.render(sample).strip()
+        return PromptRegistry.render_template(
+            self.model_prompt, prompt_templated_str, add_generation_prompt=True
+        ).strip()
+
+    def _generate(self, samples, **gen_kwargs) -> list:
+        prompts = [self._format_prompt(sample) for sample in samples]
+        logger.debug("Prompt: %s", prompts[0])
+        generate_args = {**self.defaults, **gen_kwargs}
+
+        if self.server_supports_batched:
+            response = self.client.completions.create(prompt=prompts, **generate_args)
+            # if stop is provided, then we need to add the stop token to the generated text,
+            # this is because the stop token is not included in the generated text - this is a limitation of the openai api
+            # we need to add the stop token to the generated text to make it consistent for the parser
+            if "stop" in generate_args:
+                return [
+                    choice.text.strip() + "".join(generate_args["stop"])
+                    for choice in response.choices
+                ]
+            return [choice.text.strip() for choice in response.choices]
+
+        n = gen_kwargs.get("n", 1)
+        results = []
+        for prompt in prompts:
+            for _ in range(n):
+                response = self.client.completions.create(
+                    prompt=prompt, **generate_args
+                )
+                if "stop" in generate_args:
+                    results.append(
+                        response.choices[0].text.strip()
+                        + "".join(generate_args["stop"])
+                    )
+                results.append(response.choices[0].text.strip())
+        return results
+
+    def generate(self, samples: Dataset, **gen_kwargs) -> Dataset:
+        """
+        Generate the output from the block. This method should first validate the input data,
+        then generate the output, and finally parse the generated output before returning it.
+
+        :return: The parsed output after generation.
+        """
+        num_samples = self.block_config.get("num_samples", None)
+        logger.debug("Generating outputs for {} samples".format(len(samples)))
+
+        if (num_samples is not None) and ("num_samples" not in samples.column_names):
+            samples = samples.add_column("num_samples", [num_samples] * len(samples))
+
+        # validate each sample
+        # Log errors and remove invalid samples
+        valid_samples = []
+
+        for sample in samples:
+            if self._validate(self.prompt_template, sample):
+                valid_samples.append(sample)
+            else:
+                logger.warning(
+                    f"Sample failed validation: {sample}"
+                )  # Log details of the failed sample
+
+        samples = valid_samples
+
+        if len(samples) == 0:
+            logger.warning(
+                "No valid samples to generate outputs for, returning empty dataset"
+            )
+            return Dataset.from_list([])
+
+        # generate the output
+
+        outputs = self._generate(samples, **gen_kwargs)
+
+        logger.debug("Generated outputs: %s", outputs)
+
+        num_parallel_samples = gen_kwargs.get("n", 1)
+        extended_samples = []
+
+        # Duplicate each input sample n times, where n is the number
+        # of output sequences generated per input, so that we can
+        # pair up the inputs and outputs.
+        for item in samples:
+            extended_samples.extend([item] * num_parallel_samples)
+
+        new_data = []
+        for sample, output in zip(extended_samples, outputs):
+            parsed_outputs = self._parse(output)
+            max_length = max(len(value) for value in parsed_outputs.values())
+            for values in zip(*(lst[:max_length] for lst in parsed_outputs.values())):
+                new_data.append({**sample, **dict(zip(parsed_outputs.keys(), values))})
+
+        return Dataset.from_list(new_data)
+
+
+@BlockRegistry.register("ConditionalLLMBlock")
+class ConditionalLLMBlock(LLMBlock):
+    def __init__(
+        self,
+        block_name,
+        config_paths,
+        client,
+        model_id,
+        output_cols,
+        selector_column_name,
+        model_prompt="{prompt}",
+        **batch_kwargs,
+    ) -> None:
+        super().__init__(
+            block_name=block_name,
+            config_path=list(config_paths.values())[0],
+            client=client,
+            model_id=model_id,
+            output_cols=output_cols,
+            model_prompt=model_prompt,
+            **batch_kwargs,
+        )
+        self.selector_column_name = selector_column_name
+        self.prompt_template = {}
+        if "All" in config_paths:
+            self.prompt_template = self.prompt_struct.format(**self.block_config)
+        else:
+            for config_key, config in config_paths.items():
+                # Template(self.prompt_struct.format(**filtered_config))
+                filtered_config = {
+                    k: (v if v is not None else "") for k, v in self.block_config.items()
+                }
+                self.prompt_template[config_key] = Template(self.prompt_struct.format(
+                    **self._load_config(config)
+                ))
+
+    def _format_prompt(self, sample: Dict) -> str:
+        if isinstance(self.prompt_template, dict):
+            return (
+                self.prompt_template[sample[self.selector_column_name]]
+                .render(**sample)
+                .strip()
+            )
+
+        return self.prompt_template.render(**sample).strip()
+
+    def _validate(self, prompt_template: str, input_dict: Dict[str, Any]) -> bool:
+        if isinstance(prompt_template, dict):
+            prompt_template = prompt_template[input_dict[self.selector_column_name]]
+        return super()._validate(prompt_template, input_dict)
+
+
+@BlockRegistry.register("LLMLogProbBlock")
+class LLMLogProbBlock(LLMBlock):
+    # init with init of the parent class
+    def __init__(
+        self,
+        block_name,
+        config_path,
+        client,
+        output_cols,
+        parser_kwargs={},
+        model_prompt="{prompt}",
+        model_id=None,
+        **batch_kwargs,
+    ) -> None:
+        super().__init__(
+            block_name=block_name,
+            config_path=config_path,
+            client=client,
+            output_cols=output_cols,
+            parser_kwargs=parser_kwargs,
+            model_prompt=model_prompt,
+            model_id=model_id,
+            **batch_kwargs,
+        )
+
+    def _generate_logprobs(self, samples, **gen_kwargs):
+        prompts = [
+            self.model_prompt.format(prompt=self._format_prompt(sample))
+            for sample in samples
+        ]
+        generate_args = {**self.defaults, **gen_kwargs}
+
+        # verify if logprobs is mentioned in the generate_args, if not add it and return top10 logprobs
+        if "logprobs" not in generate_args:
+            generate_args["logprobs"] = 10
+
+        if self.server_supports_batched:
+            response = self.client.completions.create(prompt=prompts, **generate_args)
+            return [choice.logprobs.top_logprobs for choice in response.choices]
+
+        n = gen_kwargs.get("n", 1)
+        results = []
+        for prompt in prompts:
+            for _ in range(n):
+                response = self.client.completions.create(
+                    prompt=prompt, **generate_args
+                )
+                results.append(response.choices[0].logprobs.top_logprobs)
+        return results
+
+    def _parse(self, generations: List[List[Dict]]) -> List[List[str]]:
+        # override the parse method to convert the generations to json string
+        # convert the generations to json string to save as dataset
+        # this is because the dataset can only store key value pairs which are consistent
+        return [[json.dumps(item) for item in sublist] for sublist in generations]
+
+    def generate(self, samples: Dataset, **gen_kwargs) -> Dataset:
+        """
+        Generate the output from the block. This method should first validate the input data,
+        then generate the output, and finally parse the generated output before returning it.
+
+        :return: The parsed output after generation.
+        """
+        num_samples = self.block_config.get("num_samples", None)
+        logger.debug("Generating outputs for {} samples".format(len(samples)))
+
+        if (num_samples is not None) and ("num_samples" not in samples.column_names):
+            samples = samples.add_column("num_samples", [num_samples] * len(samples))
+
+        # validate each sample
+        # Log errors and remove invalid samples
+        valid_samples = []
+
+        for sample in samples:
+            if self._validate(self.prompt_template, sample):
+                valid_samples.append(sample)
+            else:
+                logger.warning(
+                    f"Sample failed validation: {sample}"
+                )  # Log details of the failed sample
+
+        samples = valid_samples
+
+        if len(samples) == 0:
+            logger.warning(
+                "No valid samples to generate outputs for, returning empty dataset"
+            )
+            return Dataset.from_list([])
+
+        # generate the output
+
+        outputs = self._generate_logprobs(samples, **gen_kwargs)
+        logger.debug("Generated outputs: %s", outputs)
+
+        output_dataset = Dataset.from_list(samples)
+        output_dataset = output_dataset.add_column(
+            self.output_cols[0],
+            self._parse(outputs),  # pylint: disable=no-value-for-parameter
+        )
+
+        return output_dataset
+
+
+@BlockRegistry.register("LLMMessagesBlock")
+class LLMMessagesBlock(Block):
+    def __init__(
+        self,
+        block_name,
+        client,
+        input_col,
+        output_col,
+        model_prompt=None,
+        model_id=None,
+        **batch_kwargs,
+    ) -> None:
+        self.block_name = block_name
+        self.model_prompt = model_prompt
+        self.batch_params = batch_kwargs.get("batch_kwargs", {})
+        self.input_col = input_col
+        self.output_col = output_col
+        self.client = client
+
+        if model_id:
+            self.model = model_id
+        else:
+            self.model = self.client.models.list().data[0].id
+
+        self.defaults = {
+            "model": self.model,
+            "temperature": 0,
+            "max_tokens": 4096,
+        }
+        self.server_supports_batched = server_supports_batched(client, self.model)
+
+    def _generate(self, samples, **gen_kwargs) -> list:
+        generate_args = {**self.defaults, **gen_kwargs}
+
+        if "n" in generate_args and generate_args.get("temperature", 0) <= 0:
+            generate_args["temperature"] = 0.7
+            logger.warning(
+                "Temperature should be greater than 0 for n > 1, setting temperature to 0.7"
+            )
+
+        messages = samples[self.input_col]
+
+        results = []
+        n = gen_kwargs.get("n", 1)
+        for message in messages:
+            responses = self.client.chat.completions.create(messages=message, **generate_args)
+            if n > 1:
+                results.append([choice.message.content for choice in responses.choices])
+            else:
+                results.append(responses.choices[0].message.content)
+        return results
+
+    def generate(self, samples: Dataset, **gen_kwargs) -> Dataset:
+        outputs = self._generate(samples, **gen_kwargs)
+        samples = samples.add_column(self.output_col, outputs)
+        return samples