sdg-hub: 0.1.0a3-py3-none-any.whl → 0.1.1-py3-none-any.whl
This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
- sdg_hub/_version.py +2 -2
- sdg_hub/blocks/__init__.py +35 -5
- sdg_hub/blocks/block.py +58 -16
- sdg_hub/blocks/llmblock.py +149 -204
- sdg_hub/blocks/utilblocks.py +500 -43
- sdg_hub/checkpointer.py +139 -0
- sdg_hub/configs/annotations/detailed_annotations.yaml +28 -0
- sdg_hub/configs/annotations/simple_annotations.yaml +9 -0
- sdg_hub/configs/knowledge/atomic_facts.yaml +1 -0
- sdg_hub/configs/knowledge/detailed_summary.yaml +1 -0
- sdg_hub/configs/knowledge/extractive_summary.yaml +1 -0
- sdg_hub/configs/knowledge/generate_questions.yaml +82 -0
- sdg_hub/configs/knowledge/generate_responses.yaml +86 -0
- sdg_hub/configs/skills/contexts.yaml +18 -11
- sdg_hub/configs/skills/evaluate_freeform_pair.yaml +79 -12
- sdg_hub/configs/skills/evaluate_freeform_questions.yaml +60 -28
- sdg_hub/configs/skills/evaluate_grounded_pair.yaml +95 -30
- sdg_hub/configs/skills/freeform_questions.yaml +21 -16
- sdg_hub/configs/skills/freeform_responses.yaml +19 -25
- sdg_hub/configs/skills/router.yaml +53 -6
- sdg_hub/flow.py +351 -21
- sdg_hub/flow_runner.py +216 -0
- sdg_hub/flows/generation/knowledge/synth_knowledge1.5.yaml +26 -9
- sdg_hub/flows/generation/skills/{agentic_improve_skill.yaml → improve_responses.yaml} +26 -31
- sdg_hub/flows/generation/skills/synth_skills.yaml +4 -4
- sdg_hub/pipeline.py +67 -12
- sdg_hub/prompts.py +26 -0
- sdg_hub/sdg.py +128 -86
- sdg_hub/utils/config_validation.py +91 -0
- sdg_hub/utils/validation_result.py +10 -0
- sdg_hub-0.1.1.dist-info/METADATA +190 -0
- sdg_hub-0.1.1.dist-info/RECORD +86 -0
- {sdg_hub-0.1.0a3.dist-info → sdg_hub-0.1.1.dist-info}/WHEEL +1 -1
- sdg_hub/blocks/filterblock.py +0 -76
- sdg_hub/blocks/iterblock.py +0 -31
- sdg_hub/blocks/rmblocks.py +0 -194
- sdg_hub/configs/annotations/simple.yaml +0 -10
- sdg_hub/configs/knowledge/data_recipe/default_recipe.yaml +0 -3
- sdg_hub/configs/skills/data_recipe/default_recipe.yaml +0 -6
- sdg_hub/flows/annotation/emotion/detailed_description.yaml +0 -19
- sdg_hub/flows/annotation/emotion/detailed_description_icl.yaml +0 -19
- sdg_hub/flows/annotation/emotion/simple.yaml +0 -19
- sdg_hub/utils/chunking.py +0 -73
- sdg_hub/utils/docprocessor.py +0 -357
- sdg_hub/utils/parse_and_convert.py +0 -392
- sdg_hub-0.1.0a3.dist-info/METADATA +0 -154
- sdg_hub-0.1.0a3.dist-info/RECORD +0 -90
- /sdg_hub/configs/{knowledge/data_recipe → reasoning}/__init__.py +0 -0
- /sdg_hub/configs/skills/{_G_.yaml → icl_examples/STEM.yaml} +0 -0
- /sdg_hub/configs/skills/{data_recipe → icl_examples}/__init__.py +0 -0
- /sdg_hub/configs/skills/{_A_.yaml → icl_examples/coding.yaml} +0 -0
- /sdg_hub/configs/skills/{_B_.yaml → icl_examples/extraction.yaml} +0 -0
- /sdg_hub/configs/skills/{_C_.yaml → icl_examples/humanities.yaml} +0 -0
- /sdg_hub/configs/skills/{_D_.yaml → icl_examples/math.yaml} +0 -0
- /sdg_hub/configs/skills/{_E_.yaml → icl_examples/reasoning.yaml} +0 -0
- /sdg_hub/configs/skills/{_F_.yaml → icl_examples/roleplay.yaml} +0 -0
- /sdg_hub/configs/skills/{_H_.yaml → icl_examples/writing.yaml} +0 -0
- {sdg_hub-0.1.0a3.dist-info → sdg_hub-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {sdg_hub-0.1.0a3.dist-info → sdg_hub-0.1.1.dist-info}/top_level.txt +0 -0
sdg_hub/_version.py
CHANGED
sdg_hub/blocks/__init__.py
CHANGED
@@ -1,6 +1,36 @@
+"""Block implementations for SDG Hub.
+
+This package provides various block implementations for data generation, processing, and transformation.
+"""
+
 # Local
-from .block import
-from .
-from .
-
-
+from .block import Block
+from .llmblock import LLMBlock, ConditionalLLMBlock
+from .utilblocks import (
+    SamplePopulatorBlock,
+    SelectorBlock,
+    CombineColumnsBlock,
+    FlattenColumnsBlock,
+    DuplicateColumns,
+    RenameColumns,
+    SetToMajorityValue,
+    FilterByValueBlock,
+    IterBlock,
+)
+from ..registry import BlockRegistry
+
+__all__ = [
+    "Block",
+    "FilterByValueBlock",
+    "IterBlock",
+    "LLMBlock",
+    "ConditionalLLMBlock",
+    "SamplePopulatorBlock",
+    "SelectorBlock",
+    "CombineColumnsBlock",
+    "FlattenColumnsBlock",
+    "DuplicateColumns",
+    "RenameColumns",
+    "SetToMajorityValue",
+    "BlockRegistry",
+]
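Note: the restructured package flattens the import surface. Blocks that previously lived in their own modules (`filterblock.py`, `iterblock.py`, both deleted in this release) are now re-exported from `sdg_hub.blocks` via `utilblocks`. A minimal sketch of what the new `__all__` enables; illustrative only, constructor arguments omitted:

```python
# These imports mirror the __all__ list introduced above.
# FilterByValueBlock and IterBlock now come from utilblocks; the
# standalone filterblock.py and iterblock.py modules were removed.
from sdg_hub.blocks import (
    Block,
    LLMBlock,
    ConditionalLLMBlock,
    FilterByValueBlock,
    IterBlock,
    BlockRegistry,
)
```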
sdg_hub/blocks/block.py
CHANGED
@@ -1,8 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
+"""Base block implementation for the SDG Hub system.
+
+This module provides the abstract base class for all blocks in the system,
+including functionality for template validation and configuration management.
+"""
+
 # Standard
 from abc import ABC
 from collections import ChainMap
-from typing import Any, Dict,
+from typing import Any, Dict, Optional
 
 # Third Party
 from jinja2 import Template, UndefinedError
@@ -17,24 +23,38 @@ logger = setup_logger(__name__)
 
 @BlockRegistry.register("Block")
 class Block(ABC):
+    """Base abstract class for all blocks in the system.
+
+    This class provides common functionality for block validation and configuration loading.
+    All specific block implementations should inherit from this class.
+    """
+
     def __init__(self, block_name: str) -> None:
         self.block_name = block_name
 
     @staticmethod
     def _validate(prompt_template: Template, input_dict: Dict[str, Any]) -> bool:
-        """
-
-        variables in the Jinja template are provided in the input_dict.
+        """Validate the input data for this block.
+
+        This method validates whether all required variables in the Jinja template are provided in the input_dict.
+
+        Parameters
+        ----------
+        prompt_template : Template
+            The Jinja2 template object.
+        input_dict : Dict[str, Any]
+            A dictionary of input values to check against the template.
 
-
-
-
+        Returns
+        -------
+        bool
+            True if the input data is valid (i.e., no missing variables), False otherwise.
         """
-
+
         class Default(dict):
             def __missing__(self, key: str) -> None:
                 raise KeyError(key)
-
+
         try:
             # Try rendering the template with the input_dict
             prompt_template.render(ChainMap(input_dict, Default()))
@@ -43,12 +63,34 @@ class Block(ABC):
             logger.error(f"Missing key: {e}")
             return False
 
-    def _load_config(self, config_path: str) ->
-        """
-        Load the configuration file for this block.
+    def _load_config(self, config_path: str) -> Optional[Dict[str, Any]]:
+        """Load the configuration file for this block.
 
-
-
+        Parameters
+        ----------
+        config_path : str
+            The path to the configuration file.
+
+        Returns
+        -------
+        Optional[Dict[str, Any]]
+            The loaded configuration. Returns None if file cannot be read or parsed.
+
+        Raises
+        ------
+        FileNotFoundError
+            If the configuration file does not exist.
         """
-
-
+        try:
+            with open(config_path, "r", encoding="utf-8") as config_file:
+                try:
+                    return yaml.safe_load(config_file)
+                except yaml.YAMLError as e:
+                    logger.error(f"Error parsing YAML from {config_path}: {e}")
+                    return None
+        except FileNotFoundError:
+            logger.error(f"Configuration file not found: {config_path}")
+            raise
+        except Exception as e:
+            logger.error(f"Unexpected error reading config file {config_path}: {e}")
+            return None
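Note: the rewritten `_load_config` gives callers three distinct outcomes: a parsed dict on success, `None` when the file exists but cannot be parsed (or an unexpected read error occurs), and a raised `FileNotFoundError` when the path is missing. A sketch of the caller-side contract this implies; the subclass instance and config path are illustrative (`Block` itself is abstract):

```python
# Hypothetical caller; any concrete Block subclass loads configs this way.
try:
    config = block._load_config("configs/knowledge/atomic_facts.yaml")
except FileNotFoundError:
    raise  # missing files fail fast
if config is None:
    # The file existed but held invalid YAML (or another read error):
    # the block has no usable configuration and should not proceed.
    ...
```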
sdg_hub/blocks/llmblock.py
CHANGED
@@ -1,7 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
+"""LLM-based blocks for text generation and processing.
+
+This module provides blocks for interacting with language models.
+"""
+
 # Standard
-from
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional, Union
 import json
 import re
 
@@ -18,7 +22,18 @@ from ..registry import BlockRegistry, PromptRegistry
 logger = setup_logger(__name__)
 
 
-def server_supports_batched(client, model_id: str) -> bool:
+def server_supports_batched(client: openai.OpenAI, model_id: str) -> bool:
+    """Check if the server supports batched inputs.
+
+    This function checks if the server supports batched inputs by making a test call to the server.
+
+    Parameters
+    ----------
+    client : openai.OpenAI
+        The client to use to make the test call.
+    model_id : str
+        The model ID to use for the test call.
+    """
     supported = getattr(client, "server_supports_batched", None)
     if supported is not None:
         return supported
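Note: only the cached-attribute fast path of `server_supports_batched` is visible in this hunk; per the docstring, the rest of the function makes a test call. A plausible standalone sketch of such a probe, assuming the OpenAI completions API; this is illustrative, not sdg_hub's exact code:

```python
import openai

def probe_batched_support(client: openai.OpenAI, model_id: str) -> bool:
    """Illustrative probe: send two prompts in one completions request.

    Servers that accept a list-typed prompt and return one choice per
    prompt support batching; servers that reject it do not.
    """
    try:
        response = client.completions.create(
            model=model_id,
            prompt=["test", "test"],
            max_tokens=1,
        )
        return len(response.choices) >= 2
    except openai.APIError:
        return False
```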
@@ -38,19 +53,43 @@ def server_supports_batched(client, model_id: str) -> bool:
 
 
 @BlockRegistry.register("LLMBlock")
-# pylint: disable=dangerous-default-value
 class LLMBlock(Block):
+    """Block for generating text using language models.
+
+    This block handles text generation, prompt formatting, and output parsing
+    for language model interactions.
+
+    Parameters
+    ----------
+    block_name : str
+        Name of the block.
+    config_path : str
+        Path to the configuration file.
+    client : openai.OpenAI
+        OpenAI client instance.
+    output_cols : List[str]
+        List of output column names.
+    parser_kwargs : Dict[str, Any], optional
+        Keyword arguments for the parser, by default {}.
+    model_prompt : str, optional
+        Template string for model prompt, by default "{prompt}".
+    model_id : Optional[str], optional
+        Model ID to use, by default None.
+    **batch_kwargs : Dict[str, Any]
+        Additional keyword arguments for batch processing.
+    """
+
     # pylint: disable=too-many-instance-attributes
     def __init__(
         self,
-        block_name,
-        config_path,
-        client,
-        output_cols,
-        parser_kwargs={},
-        model_prompt="{prompt}",
-        model_id=None,
-        **batch_kwargs,
+        block_name: str,
+        config_path: str,
+        client: openai.OpenAI,
+        output_cols: List[str],
+        parser_kwargs: Dict[str, Any] = {},
+        model_prompt: str = "{prompt}",
+        model_id: Optional[str] = None,
+        **batch_kwargs: Dict[str, Any],
     ) -> None:
         super().__init__(block_name)
         self.block_config = self._load_config(config_path)
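Note: with the typed signature above, constructing an LLMBlock reads as follows. This is a hedged sketch: the server URL, block name, and output column are made up, and the config path is one of the files added in this release. Observe that the mutable default `parser_kwargs: Dict[str, Any] = {}` survives the typing pass even though the `dangerous-default-value` pragma was dropped.

```python
import openai

# Assumed local OpenAI-compatible endpoint; adjust for your deployment.
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

block = LLMBlock(
    block_name="gen_questions",                               # illustrative
    config_path="configs/knowledge/generate_questions.yaml",  # added in 0.1.1
    client=client,
    output_cols=["question"],                                 # illustrative
    model_id=None,  # per the docstring, defaults to None
)
```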
@@ -84,7 +123,27 @@ class LLMBlock(Block):
         # and supports the n parameter to generate n outputs per input
         self.server_supports_batched = server_supports_batched(client, self.model)
 
-    def
+    def _extract_matches(
+        self, text: str, start_tag: Optional[str], end_tag: Optional[str]
+    ) -> List[str]:
+        if not text:
+            return []
+        if not start_tag and not end_tag:
+            return [text.strip()]
+
+        pattern = ""
+        if start_tag:
+            pattern += re.escape(start_tag)
+        pattern += r"(.*?)"
+        if end_tag:
+            pattern += re.escape(end_tag)
+        elif start_tag:
+            # Enforce matching till end of string when only start_tag is provided.
+            pattern += "$"
+
+        return [match.strip() for match in re.findall(pattern, text, re.DOTALL)]
+
+    def _parse(self, generated_string: str) -> dict:
         matches = {}
 
         if self.parser_name is not None and self.parser_name == "custom":
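Note: the new `_extract_matches` helper centralizes the tag-based parsing that `_parse` previously inlined (see the next hunk). Its behavior, reproduced standalone for illustration:

```python
import re
from typing import List, Optional

def extract_matches(text: str, start_tag: Optional[str], end_tag: Optional[str]) -> List[str]:
    # Standalone copy of the logic added above, for illustration only.
    if not text:
        return []
    if not start_tag and not end_tag:
        return [text.strip()]
    pattern = ""
    if start_tag:
        pattern += re.escape(start_tag)
    pattern += r"(.*?)"
    if end_tag:
        pattern += re.escape(end_tag)
    elif start_tag:
        pattern += "$"  # with only a start tag, match through to the end
    return [m.strip() for m in re.findall(pattern, text, re.DOTALL)]

extract_matches("[A] one [B] junk [A] two [B]", "[A]", "[B]")  # ["one", "two"]
extract_matches("preamble [A] tail", "[A]", None)              # ["tail"]
extract_matches("no tags at all", None, None)                  # ["no tags at all"]
```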
@@ -108,16 +167,9 @@ class LLMBlock(Block):
             self.block_config.get("end_tags", []),
             self.output_cols,
         ):
-
-
-
-            ]
-            else:
-                pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag)
-                all_matches = re.findall(pattern, generated_string, re.DOTALL)
-                matches[output_col] = (
-                    [match.strip() for match in all_matches] if all_matches else []
-                )
+            matches[output_col] = self._extract_matches(
+                generated_string, start_tag, end_tag
+            )
 
         return matches
 
@@ -127,7 +179,7 @@ class LLMBlock(Block):
             self.model_prompt, prompt_templated_str, add_generation_prompt=True
         ).strip()
 
-    def _generate(self, samples, **gen_kwargs) -> list:
+    def _generate(self, samples: Dataset, **gen_kwargs: Dict[str, Any]) -> list:
         prompts = [self._format_prompt(sample) for sample in samples]
         logger.debug("Prompt: %s", prompts[0])
         generate_args = {**self.defaults, **gen_kwargs}
@@ -159,12 +211,16 @@ class LLMBlock(Block):
             results.append(response.choices[0].text.strip())
         return results
 
-    def generate(self, samples: Dataset, **gen_kwargs) -> Dataset:
-        """
-
+    def generate(self, samples: Dataset, **gen_kwargs: Dict[str, Any]) -> Dataset:
+        """Generate the output from the block.
+
+        This method should first validate the input data,
         then generate the output, and finally parse the generated output before returning it.
 
-
+        Returns
+        -------
+        Dataset
+            The parsed output after generation.
         """
         num_samples = self.block_config.get("num_samples", None)
         logger.debug("Generating outputs for {} samples".format(len(samples)))
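Note: the docstring spells out a validate → generate → parse contract. LLMBlock.generate's full body is not shown in this diff, but the removed LLMLogProbBlock.generate (later in this file's diff) follows the same shape; paraphrased:

```python
# Paraphrase of the validate -> generate -> parse contract; not the exact body.
valid = [s for s in samples if self._validate(self.prompt_template, s)]
if not valid:
    return Dataset.from_list([])
raw = self._generate(valid, **gen_kwargs)      # one generation per sample
parsed = [self._parse(text) for text in raw]   # dict of output_col -> matches
```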
@@ -219,16 +275,40 @@ class LLMBlock(Block):
 
 @BlockRegistry.register("ConditionalLLMBlock")
 class ConditionalLLMBlock(LLMBlock):
+    """Block for conditional text generation using language models.
+
+    This block selects different prompt templates based on a selector column value.
+
+    Parameters
+    ----------
+    block_name : str
+        Name of the block.
+    config_paths : Dict[str, str]
+        Dictionary mapping selector values to their config file paths.
+    client : openai.OpenAI
+        OpenAI client instance.
+    model_id : str
+        Model ID to use.
+    output_cols : List[str]
+        List of output column names.
+    selector_column_name : str
+        Name of the column used to select the prompt template.
+    model_prompt : str, optional
+        Template string for model prompt, by default "{prompt}".
+    **batch_kwargs : Dict[str, Any]
+        Additional keyword arguments for batch processing.
+    """
+
     def __init__(
         self,
-        block_name,
-        config_paths,
-        client,
-        model_id,
-        output_cols,
-        selector_column_name,
-        model_prompt="{prompt}",
-        **batch_kwargs,
+        block_name: str,
+        config_paths: Dict[str, str],
+        client: openai.OpenAI,
+        model_id: str,
+        output_cols: List[str],
+        selector_column_name: str,
+        model_prompt: str = "{prompt}",
+        **batch_kwargs: Dict[str, Any],
     ) -> None:
         super().__init__(
             block_name=block_name,
@@ -245,15 +325,27 @@ class ConditionalLLMBlock(LLMBlock):
             self.prompt_template = self.prompt_struct.format(**self.block_config)
         else:
             for config_key, config in config_paths.items():
-                # Template(self.prompt_struct.format(**filtered_config))
                 filtered_config = {
-                    k: (v if v is not None else "")
+                    k: (v if v is not None else "")
+                    for k, v in self.block_config.items()
                 }
-                self.prompt_template[config_key] = Template(
-                    **self._load_config(config)
-                )
+                self.prompt_template[config_key] = Template(
+                    self.prompt_struct.format(**self._load_config(config))
+                )
 
-    def _format_prompt(self, sample: Dict) -> str:
+    def _format_prompt(self, sample: Dict[str, Any]) -> str:
+        """Format the prompt based on the selector column value.
+
+        Parameters
+        ----------
+        sample : Dict[str, Any]
+            Input sample containing the selector column.
+
+        Returns
+        -------
+        str
+            Formatted prompt string.
+        """
         if isinstance(self.prompt_template, dict):
             return (
                 self.prompt_template[sample[self.selector_column_name]]
@@ -263,168 +355,21 @@ class ConditionalLLMBlock(LLMBlock):
 
         return self.prompt_template.render(**sample).strip()
 
-    def _validate(self, prompt_template: str, input_dict: Dict[str, Any]) -> bool:
-
-        prompt_template = prompt_template[input_dict[self.selector_column_name]]
-        return super()._validate(prompt_template, input_dict)
-
-
-@BlockRegistry.register("LLMLogProbBlock")
-class LLMLogProbBlock(LLMBlock):
-    # init with init of the parent class
-    def __init__(
-        self,
-        block_name,
-        config_path,
-        client,
-        output_cols,
-        parser_kwargs={},
-        model_prompt="{prompt}",
-        model_id=None,
-        **batch_kwargs,
-    ) -> None:
-        super().__init__(
-            block_name=block_name,
-            config_path=config_path,
-            client=client,
-            output_cols=output_cols,
-            parser_kwargs=parser_kwargs,
-            model_prompt=model_prompt,
-            model_id=model_id,
-            **batch_kwargs,
-        )
-
-    def _generate_logprobs(self, samples, **gen_kwargs):
-        prompts = [
-            self.model_prompt.format(prompt=self._format_prompt(sample))
-            for sample in samples
-        ]
-        generate_args = {**self.defaults, **gen_kwargs}
-
-        # verify if logprobs is mentioned in the generate_args, if not add it and return top10 logprobs
-        if "logprobs" not in generate_args:
-            generate_args["logprobs"] = 10
+    def _validate(self, prompt_template: Union[str, Template], input_dict: Dict[str, Any]) -> bool:
+        """Validate the input data for this block.
 
-
-
-
+        Parameters
+        ----------
+        prompt_template : Union[str, Template]
+            The template to validate against.
+        input_dict : Dict[str, Any]
+            Input data to validate.
 
-
-
-
-
-            response = self.client.completions.create(
-                prompt=prompt, **generate_args
-            )
-            results.append(response.choices[0].logprobs.top_logprobs)
-        return results
-
-    def _parse(self, generations: List[List[Dict]]) -> List[List[str]]:
-        # override the parse method to convert the generations to json string
-        # convert the generations to json string to save as dataset
-        # this is because the dataset can only store key value pairs which are consistent
-        return [[json.dumps(item) for item in sublist] for sublist in generations]
-
-    def generate(self, samples: Dataset, **gen_kwargs) -> Dataset:
+        Returns
+        -------
+        bool
+            True if the input data is valid, False otherwise.
         """
-
-
-
-        :return: The parsed output after generation.
-        """
-        num_samples = self.block_config.get("num_samples", None)
-        logger.debug("Generating outputs for {} samples".format(len(samples)))
-
-        if (num_samples is not None) and ("num_samples" not in samples.column_names):
-            samples = samples.add_column("num_samples", [num_samples] * len(samples))
-
-        # validate each sample
-        # Log errors and remove invalid samples
-        valid_samples = []
-
-        for sample in samples:
-            if self._validate(self.prompt_template, sample):
-                valid_samples.append(sample)
-            else:
-                logger.warning(
-                    f"Sample failed validation: {sample}"
-                )  # Log details of the failed sample
-
-        samples = valid_samples
-
-        if len(samples) == 0:
-            logger.warning(
-                "No valid samples to generate outputs for, returning empty dataset"
-            )
-            return Dataset.from_list([])
-
-        # generate the output
-
-        outputs = self._generate_logprobs(samples, **gen_kwargs)
-        logger.debug("Generated outputs: %s", outputs)
-
-        output_dataset = Dataset.from_list(samples)
-        output_dataset = output_dataset.add_column(
-            self.output_cols[0],
-            self._parse(outputs),  # pylint: disable=no-value-for-parameter
-        )
-
-        return output_dataset
-
-
-@BlockRegistry.register("LLMMessagesBlock")
-class LLMMessagesBlock(Block):
-    def __init__(
-        self,
-        block_name,
-        client,
-        input_col,
-        output_col,
-        model_prompt=None,
-        model_id=None,
-        **batch_kwargs,
-    ) -> None:
-        self.block_name = block_name
-        self.model_prompt = model_prompt
-        self.batch_params = batch_kwargs.get("batch_kwargs", {})
-        self.input_col = input_col
-        self.output_col = output_col
-        self.client = client
-
-        if model_id:
-            self.model = model_id
-        else:
-            self.model = self.client.models.list().data[0].id
-
-        self.defaults = {
-            "model": self.model,
-            "temperature": 0,
-            "max_tokens": 4096,
-        }
-        self.server_supports_batched = server_supports_batched(client, self.model)
-
-    def _generate(self, samples, **gen_kwargs) -> list:
-        generate_args = {**self.defaults, **gen_kwargs}
-
-        if "n" in generate_args and generate_args.get("temperature", 0) <= 0:
-            generate_args["temperature"] = 0.7
-            logger.warning(
-                "Temperature should be greater than 0 for n > 1, setting temperature to 0.7"
-            )
-
-        messages = samples[self.input_col]
-
-        results = []
-        n = gen_kwargs.get("n", 1)
-        for message in messages:
-            responses = self.client.chat.completions.create(messages=message, **generate_args)
-            if n > 1:
-                results.append([choice.message.content for choice in responses.choices])
-            else:
-                results.append(responses.choices[0].message.content)
-        return results
-
-    def generate(self, samples: Dataset, **gen_kwargs) -> Dataset:
-        outputs = self._generate(samples, **gen_kwargs)
-        samples = samples.add_column(self.output_col, outputs)
-        return samples
+        if isinstance(prompt_template, dict):
+            prompt_template = prompt_template[input_dict[self.selector_column_name]]
+        return super()._validate(prompt_template, input_dict)