sdg-hub 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94) hide show
  1. sdg_hub/__init__.py +4 -0
  2. sdg_hub/_version.py +21 -0
  3. sdg_hub/blocks/__init__.py +6 -0
  4. sdg_hub/blocks/block.py +54 -0
  5. sdg_hub/blocks/filterblock.py +76 -0
  6. sdg_hub/blocks/iterblock.py +31 -0
  7. sdg_hub/blocks/llmblock.py +430 -0
  8. sdg_hub/blocks/rmblocks.py +194 -0
  9. sdg_hub/blocks/utilblocks.py +140 -0
  10. sdg_hub/configs/__init__.py +0 -0
  11. sdg_hub/configs/annotations/__init__.py +0 -0
  12. sdg_hub/configs/annotations/cot_reflection.yaml +34 -0
  13. sdg_hub/configs/annotations/detailed_description.yaml +10 -0
  14. sdg_hub/configs/annotations/detailed_description_icl.yaml +32 -0
  15. sdg_hub/configs/annotations/simple.yaml +10 -0
  16. sdg_hub/configs/knowledge/__init__.py +0 -0
  17. sdg_hub/configs/knowledge/atomic_facts.yaml +45 -0
  18. sdg_hub/configs/knowledge/auxilary_instructions.yaml +35 -0
  19. sdg_hub/configs/knowledge/data_recipe/__init__.py +0 -0
  20. sdg_hub/configs/knowledge/data_recipe/default_recipe.yaml +3 -0
  21. sdg_hub/configs/knowledge/detailed_summary.yaml +17 -0
  22. sdg_hub/configs/knowledge/evaluate_faithfulness.yaml +68 -0
  23. sdg_hub/configs/knowledge/evaluate_question.yaml +38 -0
  24. sdg_hub/configs/knowledge/evaluate_relevancy.yaml +85 -0
  25. sdg_hub/configs/knowledge/extractive_summary.yaml +17 -0
  26. sdg_hub/configs/knowledge/generate_code_questions_responses.yaml +39 -0
  27. sdg_hub/configs/knowledge/generate_questions_responses.yaml +56 -0
  28. sdg_hub/configs/knowledge/mcq_generation.yaml +83 -0
  29. sdg_hub/configs/knowledge/router.yaml +12 -0
  30. sdg_hub/configs/knowledge/simple_generate_qa.yaml +34 -0
  31. sdg_hub/configs/reasoning/dynamic_cot.yaml +40 -0
  32. sdg_hub/configs/skills/_A_.yaml +97 -0
  33. sdg_hub/configs/skills/_B_.yaml +36 -0
  34. sdg_hub/configs/skills/_C_.yaml +71 -0
  35. sdg_hub/configs/skills/_D_.yaml +85 -0
  36. sdg_hub/configs/skills/_E_.yaml +30 -0
  37. sdg_hub/configs/skills/_F_.yaml +45 -0
  38. sdg_hub/configs/skills/_G_.yaml +56 -0
  39. sdg_hub/configs/skills/_H_.yaml +80 -0
  40. sdg_hub/configs/skills/__init__.py +0 -0
  41. sdg_hub/configs/skills/analyzer.yaml +48 -0
  42. sdg_hub/configs/skills/annotation.yaml +36 -0
  43. sdg_hub/configs/skills/contexts.yaml +21 -0
  44. sdg_hub/configs/skills/critic.yaml +60 -0
  45. sdg_hub/configs/skills/data_recipe/__init__.py +0 -0
  46. sdg_hub/configs/skills/data_recipe/default_recipe.yaml +6 -0
  47. sdg_hub/configs/skills/evaluate_freeform_pair.yaml +44 -0
  48. sdg_hub/configs/skills/evaluate_freeform_questions.yaml +46 -0
  49. sdg_hub/configs/skills/evaluate_grounded_pair.yaml +54 -0
  50. sdg_hub/configs/skills/evaluate_grounded_questions.yaml +51 -0
  51. sdg_hub/configs/skills/freeform_questions.yaml +29 -0
  52. sdg_hub/configs/skills/freeform_responses.yaml +45 -0
  53. sdg_hub/configs/skills/grounded_questions.yaml +38 -0
  54. sdg_hub/configs/skills/grounded_responses.yaml +59 -0
  55. sdg_hub/configs/skills/judge.yaml +53 -0
  56. sdg_hub/configs/skills/planner.yaml +67 -0
  57. sdg_hub/configs/skills/respond.yaml +8 -0
  58. sdg_hub/configs/skills/revised_responder.yaml +78 -0
  59. sdg_hub/configs/skills/router.yaml +12 -0
  60. sdg_hub/configs/skills/simple_generate_qa_freeform.yaml +27 -0
  61. sdg_hub/configs/skills/simple_generate_qa_grounded.yaml +31 -0
  62. sdg_hub/flow.py +127 -0
  63. sdg_hub/flows/annotation/emotion/detailed_description.yaml +19 -0
  64. sdg_hub/flows/annotation/emotion/detailed_description_icl.yaml +19 -0
  65. sdg_hub/flows/annotation/emotion/simple.yaml +19 -0
  66. sdg_hub/flows/generation/knowledge/mmlu_bench.yaml +13 -0
  67. sdg_hub/flows/generation/knowledge/simple_knowledge.yaml +12 -0
  68. sdg_hub/flows/generation/knowledge/synth_knowledge.yaml +89 -0
  69. sdg_hub/flows/generation/knowledge/synth_knowledge1.5.yaml +136 -0
  70. sdg_hub/flows/generation/skills/agentic_improve_skill.yaml +108 -0
  71. sdg_hub/flows/generation/skills/simple_freeform_skill.yaml +12 -0
  72. sdg_hub/flows/generation/skills/simple_grounded_skill.yaml +12 -0
  73. sdg_hub/flows/generation/skills/synth_grounded_skills.yaml +80 -0
  74. sdg_hub/flows/generation/skills/synth_skills.yaml +59 -0
  75. sdg_hub/logger_config.py +20 -0
  76. sdg_hub/pipeline.py +66 -0
  77. sdg_hub/prompts.py +17 -0
  78. sdg_hub/py.typed +0 -0
  79. sdg_hub/registry.py +122 -0
  80. sdg_hub/sdg.py +164 -0
  81. sdg_hub/utils/__init__.py +5 -0
  82. sdg_hub/utils/chunking.py +73 -0
  83. sdg_hub/utils/datamixing.py +123 -0
  84. sdg_hub/utils/datautils.py +14 -0
  85. sdg_hub/utils/docprocessor.py +357 -0
  86. sdg_hub/utils/json.py +48 -0
  87. sdg_hub/utils/models.py +31 -0
  88. sdg_hub/utils/parse_and_convert.py +392 -0
  89. sdg_hub/utils/taxonomy.py +489 -0
  90. sdg_hub-0.1.0a1.dist-info/METADATA +154 -0
  91. sdg_hub-0.1.0a1.dist-info/RECORD +94 -0
  92. sdg_hub-0.1.0a1.dist-info/WHEEL +5 -0
  93. sdg_hub-0.1.0a1.dist-info/licenses/LICENSE +201 -0
  94. sdg_hub-0.1.0a1.dist-info/top_level.txt +1 -0
sdg_hub/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # Local
3
+ from .sdg import SDG
4
+ from ._version import __version__
sdg_hub/_version.py ADDED
@@ -0,0 +1,21 @@
# file generated by setuptools-scm
# don't change, don't track in version control
# NOTE(review): this module is regenerated by setuptools-scm, so any comments
# added here are explanatory only and will be lost on regeneration.

# Public names exported by this module.
__all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]

# Local stand-in for typing.TYPE_CHECKING: it is False at runtime, so the
# typing imports below never execute; the branch exists for static checkers.
TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple
    from typing import Union

    # A version tuple is a variable-length sequence of int / str components.
    VERSION_TUPLE = Tuple[Union[int, str], ...]
else:
    # At runtime the alias is only referenced in annotations below.
    VERSION_TUPLE = object

# Declared types of the version attributes (values are assigned below).
version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

# NOTE(review): the tuple form does not carry the "a1" pre-release marker
# that appears in the string form.
__version__ = version = '0.1.0a1'
__version_tuple__ = version_tuple = (0, 1, 0)
@@ -0,0 +1,6 @@
1
+ # Local
2
+ from .block import *
3
+ from .filterblock import *
4
+ from .iterblock import *
5
+ from .llmblock import *
6
+ from .utilblocks import *
@@ -0,0 +1,54 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # Standard
3
+ from abc import ABC
4
+ from collections import ChainMap
5
+ from typing import Any, Dict, Union
6
+
7
+ # Third Party
8
+ from jinja2 import Template, UndefinedError
9
+ import yaml
10
+
11
+ # Local
12
+ from ..registry import BlockRegistry
13
+ from ..logger_config import setup_logger
14
+
logger = setup_logger(__name__)


@BlockRegistry.register("Block")
class Block(ABC):
    """Common base class for pipeline blocks.

    Stores the block's name and offers two helpers used by subclasses:
    checking that a Jinja2 template renders against a sample, and loading a
    YAML configuration file.
    """

    def __init__(self, block_name: str) -> None:
        self.block_name = block_name

    @staticmethod
    def _validate(prompt_template: Template, input_dict: Dict[str, Any]) -> bool:
        """Report whether ``prompt_template`` renders with ``input_dict``.

        :param prompt_template: Compiled Jinja2 template to try rendering.
        :param input_dict: Values exposed to the template during rendering.
        :return: ``True`` if rendering completes; ``False`` if Jinja2 raises
            ``UndefinedError`` (the error is logged first).

        NOTE(review): with Jinja2's default ``Undefined`` a plain ``{{ name }}``
        reference to a missing key renders as an empty string instead of
        raising, so this only catches undefined values that are *used*
        (attribute access, calls, ...). Confirm this strictness is intended.
        """
        # A ChainMap over the sample behaves like the sample itself here:
        # unknown keys raise KeyError and render() copies it into a dict.
        try:
            prompt_template.render(ChainMap(input_dict))
        except UndefinedError as err:
            logger.error(f"Missing key: {err}")
            return False
        return True

    def _load_config(self, config_path: str) -> Union[Dict[str, Any], None]:
        """Parse the YAML file at ``config_path`` with ``yaml.safe_load``.

        :param config_path: Filesystem path of the YAML configuration.
        :return: The parsed document (``None`` when the file holds no document).
        """
        with open(config_path, "r", encoding="utf-8") as handle:
            parsed = yaml.safe_load(handle)
        return parsed
@@ -0,0 +1,76 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # Standard
3
+ import operator
4
+
5
+ # Third Party
6
+ from datasets import Dataset
7
+
8
+ # Local
9
+ from .block import Block
10
+ from ..registry import BlockRegistry
11
+ from ..logger_config import setup_logger
12
+
logger = setup_logger(__name__)


@BlockRegistry.register("FilterByValueBlock")
class FilterByValueBlock(Block):
    """Keep only the dataset rows whose ``filter_column`` matches a filter value."""

    def __init__(
        self, filter_column, filter_value, operation, convert_dtype=None, **batch_kwargs
    ) -> None:
        """
        Initializes a new instance of the FilterByValueBlock class.

        Parameters:
        - filter_column (str): The name of the column in the dataset to apply the filter on.
        - filter_value (any or list of any): The value(s) to filter by.
        - operation (callable): A function that takes two arguments (column value and filter value) and returns a boolean indicating whether the row should be included in the filtered dataset. ``operator.contains`` is special-cased to mean "column value is one of the filter values".
        - convert_dtype (callable, optional): A function to convert the data type of the filter column before applying the filter. Defaults to None.
        - **batch_kwargs: Additional kwargs for batch processing (``num_procs`` sets the worker count).

        Returns:
        None
        """
        super().__init__(block_name=self.__class__.__name__)
        self.value = filter_value if isinstance(filter_value, list) else [filter_value]
        self.column_name = filter_column
        self.operation = operation
        self.convert_dtype = convert_dtype
        self.num_procs = batch_kwargs.get("num_procs", 1)

    def _convert_dtype(self, sample):
        """Convert ``sample[column]`` with ``convert_dtype``; store None on failure.

        Failed rows are removed by the None filter in ``generate``.
        """
        try:
            sample[self.column_name] = self.convert_dtype(sample[self.column_name])
        except (ValueError, TypeError) as e:
            # TypeError covers inputs such as None that conversions like
            # int()/float() reject outright.
            logger.error(
                "Error converting dtype: %s, filling with None to be filtered later", e
            )
            sample[self.column_name] = None
        return sample

    def _keep(self, sample) -> bool:
        """Return True if ``sample`` passes the configured comparison."""
        column_value = sample[self.column_name]
        if self.operation == operator.contains:
            # Membership test: is the column value one of the filter values?
            # Handled exclusively so the generic check below (which would call
            # contains(column_value, value)) is never applied to these rows.
            return self.operation(self.value, column_value)
        return any(self.operation(column_value, value) for value in self.value)

    def generate(self, samples) -> Dataset:
        """Return ``samples`` filtered by the configured column, values and operation."""
        if self.convert_dtype:
            samples = samples.map(
                self._convert_dtype,
                num_proc=self.num_procs,
            )

        # Drop missing values (including failed conversions) before comparing.
        samples = samples.filter(
            lambda x: x[self.column_name] is not None,
            num_proc=self.num_procs,
        )

        return samples.filter(self._keep, num_proc=self.num_procs)
@@ -0,0 +1,31 @@
1
+ # Third Party
2
+ from datasets import Dataset
3
+
4
+ # Local
5
+ from .block import Block
6
+ from ..registry import BlockRegistry
7
+ from ..logger_config import setup_logger
8
+
logger = setup_logger(__name__)


@BlockRegistry.register("IterBlock")
class IterBlock(Block):
    """Run a wrapped block ``num_iters`` times on the same input and concatenate.

    :param block_name: Name of this block.
    :param num_iters: How many times the wrapped block is run.
    :param block_type: Block class to instantiate.
    :param block_kwargs: Keyword arguments for ``block_type``.
    :param kwargs: May contain ``gen_kwargs``, default generation kwargs.
    """

    def __init__(self, block_name, num_iters, block_type, block_kwargs, **kwargs):
        super().__init__(block_name)
        self.num_iters = num_iters
        self.block = block_type(**block_kwargs)
        self.gen_kwargs = kwargs.get("gen_kwargs", {})

    def generate(self, samples, **gen_kwargs) -> Dataset:
        """Return the rows from all ``num_iters`` runs of the wrapped block.

        Per-call ``gen_kwargs`` override the defaults given at construction.
        """
        # Loop-invariant: the same merged kwargs are used for every run.
        merged_kwargs = {**self.gen_kwargs, **gen_kwargs}
        generated_samples = []
        for _ in range(self.num_iters):
            generated_samples.extend(self.block.generate(samples, **merged_kwargs))
        return Dataset.from_list(generated_samples)
@@ -0,0 +1,430 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # Standard
3
+ from collections import Counter
4
+ from typing import Any, Dict, List
5
+ import json
6
+ import re
7
+
8
+ # Third Party
9
+ from datasets import Dataset
10
+ from jinja2 import Template
11
+ import openai
12
+
13
+ # Local
14
+ from .block import Block
15
+ from ..logger_config import setup_logger
16
+ from ..registry import BlockRegistry, PromptRegistry
17
+
# Module-level logger for the LLM blocks, created via the package's
# setup_logger helper.
logger = setup_logger(__name__)
19
+
20
+
def server_supports_batched(client, model_id: str) -> bool:
    """Return whether the client's server handles batched completion requests.

    "Batched" means a single ``completions.create`` call may carry a list of
    prompts and an ``n`` parameter, returning ``len(prompts) * n`` choices.
    The result is cached on ``client.server_supports_batched``, so later calls
    with the same client skip the probe request.

    :param client: OpenAI-compatible client.
    :param model_id: Model name used for the probe request.
    :return: True if the probe with 2 prompts and ``n=3`` returned 6 choices.
    """
    supported = getattr(client, "server_supports_batched", None)
    if supported is not None:
        return supported
    try:
        # Probe: 2 prompts x n=3 must yield 6 choices on a batching server.
        response = client.completions.create(
            model=model_id, prompt=["test1", "test2"], max_tokens=1, n=3
        )
        supported = len(response.choices) == 6
    except (
        openai.InternalServerError,
        openai.BadRequestError,
        openai.UnprocessableEntityError,
    ):
        # Servers that do not accept prompt lists or ``n`` may reject the
        # probe with a 400/422 status as well as a 5xx one.
        supported = False
    client.server_supports_batched = supported
    logger.info("LLM server supports batched inputs: %s", supported)
    return supported
38
+
39
+
@BlockRegistry.register("LLMBlock")
# pylint: disable=dangerous-default-value
class LLMBlock(Block):
    """Prompt an OpenAI-compatible completions endpoint once per sample and
    parse each generated text into ``output_cols``.
    """

    # pylint: disable=too-many-instance-attributes
    def __init__(
        self,
        block_name,
        config_path,
        client,
        output_cols,
        parser_kwargs={},
        model_prompt="{prompt}",
        model_id=None,
        **batch_kwargs,
    ) -> None:
        """
        :param block_name: Name of this block.
        :param config_path: YAML config providing the ``system``, ``introduction``,
            ``principles``, ``examples`` and ``generation`` prompt sections, plus
            optional ``start_tags``/``end_tags``/``num_samples``.
        :param client: OpenAI-compatible client.
        :param output_cols: Columns the parsed generations are written to.
        :param parser_kwargs: Optional ``parser_name`` (``"custom"`` enables regex
            parsing), ``parsing_pattern`` and ``parser_cleanup_tags``. Only read.
        :param model_prompt: Passed to ``PromptRegistry.render_template``.
        :param model_id: Model to query; defaults to the first model the client lists.
        :param batch_kwargs: May contain ``batch_kwargs``, stored as ``batch_params``.
        """
        super().__init__(block_name)
        self.block_config = self._load_config(config_path)
        self.prompt_struct = (
            """{system}\n{introduction}\n{principles}\n{examples}\n{generation}"""
        )
        # Sections set to None in the YAML render as empty strings.
        filtered_config = {
            k: (v if v is not None else "") for k, v in self.block_config.items()
        }
        self.prompt_template = Template(self.prompt_struct.format(**filtered_config))
        self.client = client
        if model_id:
            self.model = model_id
        else:
            # get the default model id from client
            self.model = self.client.models.list().data[0].id

        self.model_prompt = model_prompt
        self.output_cols = output_cols
        self.batch_params = batch_kwargs.get("batch_kwargs", {})
        self.parser_name = parser_kwargs.get("parser_name", None)
        self.parsing_pattern = parser_kwargs.get("parsing_pattern", None)
        self.parser_cleanup_tags = parser_kwargs.get("parser_cleanup_tags", None)
        self.defaults = {
            "model": self.model,
            "temperature": 0,
            "max_tokens": 4096,
        }

        # Whether the LLM server supports a list of input prompts
        # and supports the n parameter to generate n outputs per input
        self.server_supports_batched = server_supports_batched(client, self.model)

    def _parse(self, generated_string) -> dict:
        """Extract per-column value lists from one generated string.

        With ``parser_name == "custom"`` the regex ``parsing_pattern`` is used:
        multi-group matches fill ``output_cols`` positionally (after stripping
        ``parser_cleanup_tags``), single-group matches fill the first column.
        Otherwise each output column is taken from the text between its
        configured start/end tags, or the whole string when both tags are empty.

        :return: Mapping of column name to the list of extracted values.
        """
        matches = {}

        if self.parser_name == "custom":
            pattern = re.compile(self.parsing_pattern, re.DOTALL)
            all_matches = pattern.findall(generated_string)
            matches = {column_name: [] for column_name in self.output_cols}
            if all_matches and isinstance(all_matches[0], tuple):
                # parser_cleanup_tags is optional; treat "not given" as no tags.
                cleanup_tags = self.parser_cleanup_tags or []
                for match in all_matches:
                    for column_name, value in zip(self.output_cols, match):
                        value = value.strip()
                        for clean_tag in cleanup_tags:
                            value = value.replace(clean_tag, "")
                        matches[column_name].append(value)
            else:
                matches[self.output_cols[0]] = [match.strip() for match in all_matches]
        else:
            for start_tag, end_tag, output_col in zip(
                self.block_config.get("start_tags", []),
                self.block_config.get("end_tags", []),
                self.output_cols,
            ):
                if not start_tag and not end_tag:
                    matches[output_col] = [
                        generated_string.strip() if generated_string else None
                    ]
                else:
                    pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag)
                    all_matches = re.findall(pattern, generated_string, re.DOTALL)
                    matches[output_col] = [match.strip() for match in all_matches]

        return matches

    def _format_prompt(self, sample: Dict) -> str:
        """Render the block template for ``sample`` and apply the model prompt."""
        prompt_templated_str = self.prompt_template.render(sample).strip()
        return PromptRegistry.render_template(
            self.model_prompt, prompt_templated_str, add_generation_prompt=True
        ).strip()

    def _generate(self, samples, **gen_kwargs) -> list:
        """Return generated texts: ``n`` per sample, grouped by sample, in order."""
        prompts = [self._format_prompt(sample) for sample in samples]
        logger.debug("Prompt: %s", prompts[0])
        generate_args = {**self.defaults, **gen_kwargs}

        # The completions API does not include the stop sequence in the
        # returned text; re-append it so tag-based parsing still sees the
        # closing tag.
        stop = generate_args.get("stop")
        stop_suffix = "".join(stop) if stop else ""

        if self.server_supports_batched:
            response = self.client.completions.create(prompt=prompts, **generate_args)
            return [choice.text.strip() + stop_suffix for choice in response.choices]

        # One request per (prompt, output); each request asks for a single
        # completion, so ``n`` is handled by the loop, not sent to the server.
        n = generate_args.pop("n", 1)
        results = []
        for prompt in prompts:
            for _ in range(n):
                response = self.client.completions.create(
                    prompt=prompt, **generate_args
                )
                results.append(response.choices[0].text.strip() + stop_suffix)
        return results

    def generate(self, samples: Dataset, **gen_kwargs) -> Dataset:
        """
        Generate the output from the block. This method should first validate the input data,
        then generate the output, and finally parse the generated output before returning it.

        Samples that fail template validation are logged and skipped.

        :return: The parsed output after generation, one row per parsed value
            group, carrying the originating sample's columns.
        """
        num_samples = self.block_config.get("num_samples", None)
        logger.debug("Generating outputs for %s samples", len(samples))

        if (num_samples is not None) and ("num_samples" not in samples.column_names):
            samples = samples.add_column("num_samples", [num_samples] * len(samples))

        valid_samples = []
        for sample in samples:
            if self._validate(self.prompt_template, sample):
                valid_samples.append(sample)
            else:
                logger.warning("Sample failed validation: %s", sample)

        if not valid_samples:
            logger.warning(
                "No valid samples to generate outputs for, returning empty dataset"
            )
            return Dataset.from_list([])

        outputs = self._generate(valid_samples, **gen_kwargs)
        logger.debug("Generated outputs: %s", outputs)

        # Outputs arrive as n consecutive generations per input; repeat each
        # input n times so inputs and outputs can be paired by position.
        num_parallel_samples = gen_kwargs.get("n", 1)
        extended_samples = [
            sample for sample in valid_samples for _ in range(num_parallel_samples)
        ]

        new_data = []
        for sample, output in zip(extended_samples, outputs):
            parsed_outputs = self._parse(output)
            # zip() stops at the shortest column list, so only complete value
            # groups become rows (and no rows when nothing was parsed).
            for values in zip(*parsed_outputs.values()):
                new_data.append({**sample, **dict(zip(parsed_outputs.keys(), values))})

        return Dataset.from_list(new_data)
218
+
219
+
@BlockRegistry.register("ConditionalLLMBlock")
class ConditionalLLMBlock(LLMBlock):
    """LLMBlock that selects its prompt template per sample.

    ``config_paths`` maps values of ``selector_column_name`` to config files.
    Each sample is rendered with the template registered for its selector
    value. If the key ``"All"`` is present, that single config is used for
    every sample.
    """

    def __init__(
        self,
        block_name,
        config_paths,
        client,
        model_id,
        output_cols,
        selector_column_name,
        model_prompt="{prompt}",
        **batch_kwargs,
    ) -> None:
        super().__init__(
            block_name=block_name,
            config_path=list(config_paths.values())[0],
            client=client,
            model_id=model_id,
            output_cols=output_cols,
            model_prompt=model_prompt,
            **batch_kwargs,
        )
        self.selector_column_name = selector_column_name
        if "All" in config_paths:
            self.prompt_template = self._build_template(config_paths["All"])
        else:
            self.prompt_template = {
                config_key: self._build_template(config)
                for config_key, config in config_paths.items()
            }

    def _build_template(self, config_path) -> Template:
        """Load ``config_path`` and compile its sections into a Jinja2 Template."""
        config = self._load_config(config_path)
        # Match LLMBlock: sections set to None render as "" rather than "None".
        filtered_config = {k: (v if v is not None else "") for k, v in config.items()}
        return Template(self.prompt_struct.format(**filtered_config))

    def _format_prompt(self, sample: Dict) -> str:
        """Render the template chosen by ``sample[selector_column_name]``.

        NOTE(review): unlike LLMBlock._format_prompt, ``self.model_prompt`` is
        not applied here (no PromptRegistry.render_template call) — confirm
        this is intended.
        """
        if isinstance(self.prompt_template, dict):
            template = self.prompt_template[sample[self.selector_column_name]]
        else:
            template = self.prompt_template
        return template.render(**sample).strip()

    def _validate(self, prompt_template: Any, input_dict: Dict[str, Any]) -> bool:
        """Validate ``input_dict`` against the template its selector value picks.

        Samples whose selector value has no configured template are rejected
        (logged, returns False) instead of raising KeyError.
        """
        if isinstance(prompt_template, dict):
            selector_value = input_dict.get(self.selector_column_name)
            if selector_value not in prompt_template:
                logger.error(
                    "No prompt template for %s=%r", self.selector_column_name, selector_value
                )
                return False
            prompt_template = prompt_template[selector_value]
        return super()._validate(prompt_template, input_dict)
270
+
271
+
@BlockRegistry.register("LLMLogProbBlock")
class LLMLogProbBlock(LLMBlock):
    """LLMBlock variant that stores per-token top log-probabilities.

    Instead of parsing generated text, each row's ``output_cols[0]`` receives
    the JSON-encoded ``top_logprobs`` entries of one generation.
    """

    # init with init of the parent class
    def __init__(
        self,
        block_name,
        config_path,
        client,
        output_cols,
        parser_kwargs={},
        model_prompt="{prompt}",
        model_id=None,
        **batch_kwargs,
    ) -> None:
        super().__init__(
            block_name=block_name,
            config_path=config_path,
            client=client,
            output_cols=output_cols,
            parser_kwargs=parser_kwargs,
            model_prompt=model_prompt,
            model_id=model_id,
            **batch_kwargs,
        )

    def _generate_logprobs(self, samples, **gen_kwargs):
        """Return ``top_logprobs`` lists: ``n`` per sample, grouped by sample."""
        # _format_prompt already applies self.model_prompt (via PromptRegistry);
        # formatting model_prompt again would double-wrap or replace the prompt.
        prompts = [self._format_prompt(sample) for sample in samples]
        generate_args = {**self.defaults, **gen_kwargs}

        # Ask for the top-10 logprobs per token unless the caller chose a value.
        generate_args.setdefault("logprobs", 10)

        if self.server_supports_batched:
            response = self.client.completions.create(prompt=prompts, **generate_args)
            return [choice.logprobs.top_logprobs for choice in response.choices]

        # One single-completion request per (prompt, output); ``n`` is handled
        # by the loop rather than sent to a server that may not support it.
        n = generate_args.pop("n", 1)
        results = []
        for prompt in prompts:
            for _ in range(n):
                response = self.client.completions.create(
                    prompt=prompt, **generate_args
                )
                results.append(response.choices[0].logprobs.top_logprobs)
        return results

    def _parse(self, generations: List[List[Dict]]) -> List[List[str]]:
        # override the parse method to convert the generations to json string
        # convert the generations to json string to save as dataset
        # this is because the dataset can only store key value pairs which are consistent
        return [[json.dumps(item) for item in sublist] for sublist in generations]

    def generate(self, samples: Dataset, **gen_kwargs) -> Dataset:
        """
        Generate the output from the block. This method should first validate the input data,
        then generate the output, and finally parse the generated output before returning it.

        Samples that fail template validation are logged and skipped.

        :return: The valid samples (repeated ``n`` times when ``n`` > 1) with the
            JSON-encoded logprobs added as ``output_cols[0]``.
        """
        num_samples = self.block_config.get("num_samples", None)
        logger.debug("Generating outputs for %s samples", len(samples))

        if (num_samples is not None) and ("num_samples" not in samples.column_names):
            samples = samples.add_column("num_samples", [num_samples] * len(samples))

        valid_samples = []
        for sample in samples:
            if self._validate(self.prompt_template, sample):
                valid_samples.append(sample)
            else:
                logger.warning("Sample failed validation: %s", sample)

        if not valid_samples:
            logger.warning(
                "No valid samples to generate outputs for, returning empty dataset"
            )
            return Dataset.from_list([])

        outputs = self._generate_logprobs(valid_samples, **gen_kwargs)
        logger.debug("Generated outputs: %s", outputs)

        # There are n outputs per input, back to back; repeat each input n
        # times so the new column has exactly one value per row.
        num_parallel_samples = gen_kwargs.get("n", 1)
        extended_samples = [
            sample for sample in valid_samples for _ in range(num_parallel_samples)
        ]

        output_dataset = Dataset.from_list(extended_samples)
        output_dataset = output_dataset.add_column(
            self.output_cols[0],
            self._parse(outputs),  # pylint: disable=no-value-for-parameter
        )

        return output_dataset
373
+
374
+
@BlockRegistry.register("LLMMessagesBlock")
class LLMMessagesBlock(Block):
    """Send each row's chat messages to a chat-completions endpoint.

    The messages are read from ``input_col`` and the reply is written to
    ``output_col``: a string, or a list of strings when ``n`` > 1.
    ``model_prompt`` is stored but not used by this class.
    """

    def __init__(
        self,
        block_name,
        client,
        input_col,
        output_col,
        model_prompt=None,
        model_id=None,
        **batch_kwargs,
    ) -> None:
        super().__init__(block_name)
        self.model_prompt = model_prompt
        self.batch_params = batch_kwargs.get("batch_kwargs", {})
        self.input_col = input_col
        self.output_col = output_col
        self.client = client

        if model_id:
            self.model = model_id
        else:
            # Default to the first model the client lists.
            self.model = self.client.models.list().data[0].id

        self.defaults = {
            "model": self.model,
            "temperature": 0,
            "max_tokens": 4096,
        }
        self.server_supports_batched = server_supports_batched(client, self.model)

    def _generate(self, samples, **gen_kwargs) -> list:
        """Return one reply per row (a list of ``n`` replies when ``n`` > 1)."""
        generate_args = {**self.defaults, **gen_kwargs}
        n = generate_args.get("n", 1)

        # Only a request for several completions needs a non-zero temperature,
        # matching the warning text; n=1 keeps the caller's temperature.
        if n > 1 and generate_args.get("temperature", 0) <= 0:
            generate_args["temperature"] = 0.7
            logger.warning(
                "Temperature should be greater than 0 for n > 1, setting temperature to 0.7"
            )

        results = []
        for message in samples[self.input_col]:
            responses = self.client.chat.completions.create(
                messages=message, **generate_args
            )
            if n > 1:
                results.append([choice.message.content for choice in responses.choices])
            else:
                results.append(responses.choices[0].message.content)
        return results

    def generate(self, samples: Dataset, **gen_kwargs) -> Dataset:
        """Return ``samples`` with the generated replies added as ``output_col``."""
        outputs = self._generate(samples, **gen_kwargs)
        return samples.add_column(self.output_col, outputs)