sdg-hub 0.1.0a3__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. sdg_hub/_version.py +2 -2
  2. sdg_hub/blocks/__init__.py +35 -5
  3. sdg_hub/blocks/block.py +58 -16
  4. sdg_hub/blocks/llmblock.py +149 -204
  5. sdg_hub/blocks/utilblocks.py +500 -43
  6. sdg_hub/checkpointer.py +139 -0
  7. sdg_hub/configs/annotations/detailed_annotations.yaml +28 -0
  8. sdg_hub/configs/annotations/simple_annotations.yaml +9 -0
  9. sdg_hub/configs/knowledge/atomic_facts.yaml +1 -0
  10. sdg_hub/configs/knowledge/detailed_summary.yaml +1 -0
  11. sdg_hub/configs/knowledge/extractive_summary.yaml +1 -0
  12. sdg_hub/configs/knowledge/generate_questions.yaml +82 -0
  13. sdg_hub/configs/knowledge/generate_responses.yaml +86 -0
  14. sdg_hub/configs/skills/contexts.yaml +18 -11
  15. sdg_hub/configs/skills/evaluate_freeform_pair.yaml +79 -12
  16. sdg_hub/configs/skills/evaluate_freeform_questions.yaml +60 -28
  17. sdg_hub/configs/skills/evaluate_grounded_pair.yaml +95 -30
  18. sdg_hub/configs/skills/freeform_questions.yaml +21 -16
  19. sdg_hub/configs/skills/freeform_responses.yaml +19 -25
  20. sdg_hub/configs/skills/router.yaml +53 -6
  21. sdg_hub/flow.py +351 -21
  22. sdg_hub/flow_runner.py +216 -0
  23. sdg_hub/flows/generation/knowledge/synth_knowledge1.5.yaml +26 -9
  24. sdg_hub/flows/generation/skills/{agentic_improve_skill.yaml → improve_responses.yaml} +26 -31
  25. sdg_hub/flows/generation/skills/synth_skills.yaml +4 -4
  26. sdg_hub/pipeline.py +67 -12
  27. sdg_hub/prompts.py +26 -0
  28. sdg_hub/sdg.py +128 -86
  29. sdg_hub/utils/config_validation.py +91 -0
  30. sdg_hub/utils/validation_result.py +10 -0
  31. sdg_hub-0.1.1.dist-info/METADATA +190 -0
  32. sdg_hub-0.1.1.dist-info/RECORD +86 -0
  33. {sdg_hub-0.1.0a3.dist-info → sdg_hub-0.1.1.dist-info}/WHEEL +1 -1
  34. sdg_hub/blocks/filterblock.py +0 -76
  35. sdg_hub/blocks/iterblock.py +0 -31
  36. sdg_hub/blocks/rmblocks.py +0 -194
  37. sdg_hub/configs/annotations/simple.yaml +0 -10
  38. sdg_hub/configs/knowledge/data_recipe/default_recipe.yaml +0 -3
  39. sdg_hub/configs/skills/data_recipe/default_recipe.yaml +0 -6
  40. sdg_hub/flows/annotation/emotion/detailed_description.yaml +0 -19
  41. sdg_hub/flows/annotation/emotion/detailed_description_icl.yaml +0 -19
  42. sdg_hub/flows/annotation/emotion/simple.yaml +0 -19
  43. sdg_hub/utils/chunking.py +0 -73
  44. sdg_hub/utils/docprocessor.py +0 -357
  45. sdg_hub/utils/parse_and_convert.py +0 -392
  46. sdg_hub-0.1.0a3.dist-info/METADATA +0 -154
  47. sdg_hub-0.1.0a3.dist-info/RECORD +0 -90
  48. /sdg_hub/configs/{knowledge/data_recipe → reasoning}/__init__.py +0 -0
  49. /sdg_hub/configs/skills/{_G_.yaml → icl_examples/STEM.yaml} +0 -0
  50. /sdg_hub/configs/skills/{data_recipe → icl_examples}/__init__.py +0 -0
  51. /sdg_hub/configs/skills/{_A_.yaml → icl_examples/coding.yaml} +0 -0
  52. /sdg_hub/configs/skills/{_B_.yaml → icl_examples/extraction.yaml} +0 -0
  53. /sdg_hub/configs/skills/{_C_.yaml → icl_examples/humanities.yaml} +0 -0
  54. /sdg_hub/configs/skills/{_D_.yaml → icl_examples/math.yaml} +0 -0
  55. /sdg_hub/configs/skills/{_E_.yaml → icl_examples/reasoning.yaml} +0 -0
  56. /sdg_hub/configs/skills/{_F_.yaml → icl_examples/roleplay.yaml} +0 -0
  57. /sdg_hub/configs/skills/{_H_.yaml → icl_examples/writing.yaml} +0 -0
  58. {sdg_hub-0.1.0a3.dist-info → sdg_hub-0.1.1.dist-info}/licenses/LICENSE +0 -0
  59. {sdg_hub-0.1.0a3.dist-info → sdg_hub-0.1.1.dist-info}/top_level.txt +0 -0
sdg_hub/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE

- __version__ = version = '0.1.0a3'
- __version_tuple__ = version_tuple = (0, 1, 0)
+ __version__ = version = '0.1.1'
+ __version_tuple__ = version_tuple = (0, 1, 1)
sdg_hub/blocks/__init__.py CHANGED
@@ -1,6 +1,36 @@
+ """Block implementations for SDG Hub.
+
+ This package provides various block implementations for data generation, processing, and transformation.
+ """
+
  # Local
- from .block import *
- from .filterblock import *
- from .iterblock import *
- from .llmblock import *
- from .utilblocks import *
+ from .block import Block
+ from .llmblock import LLMBlock, ConditionalLLMBlock
+ from .utilblocks import (
+     SamplePopulatorBlock,
+     SelectorBlock,
+     CombineColumnsBlock,
+     FlattenColumnsBlock,
+     DuplicateColumns,
+     RenameColumns,
+     SetToMajorityValue,
+     FilterByValueBlock,
+     IterBlock,
+ )
+ from ..registry import BlockRegistry
+
+ __all__ = [
+     "Block",
+     "FilterByValueBlock",
+     "IterBlock",
+     "LLMBlock",
+     "ConditionalLLMBlock",
+     "SamplePopulatorBlock",
+     "SelectorBlock",
+     "CombineColumnsBlock",
+     "FlattenColumnsBlock",
+     "DuplicateColumns",
+     "RenameColumns",
+     "SetToMajorityValue",
+     "BlockRegistry",
+ ]
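
With the star imports replaced by explicit imports and an __all__ list, the package's public surface is now enumerable, and FilterByValueBlock and IterBlock are re-exported from utilblocks rather than their own deleted modules. A minimal import sketch using only the exported names above:

    # Python — these names come from the new __all__; star imports
    # previously defined this surface implicitly.
    from sdg_hub.blocks import LLMBlock, FilterByValueBlock, CombineColumnsBlock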
sdg_hub/blocks/block.py CHANGED
@@ -1,8 +1,14 @@
  # SPDX-License-Identifier: Apache-2.0
+ """Base block implementation for the SDG Hub system.
+
+ This module provides the abstract base class for all blocks in the system,
+ including functionality for template validation and configuration management.
+ """
+
  # Standard
  from abc import ABC
  from collections import ChainMap
- from typing import Any, Dict, Union
+ from typing import Any, Dict, Optional

  # Third Party
  from jinja2 import Template, UndefinedError
@@ -17,24 +23,38 @@ logger = setup_logger(__name__)

  @BlockRegistry.register("Block")
  class Block(ABC):
+     """Base abstract class for all blocks in the system.
+
+     This class provides common functionality for block validation and configuration loading.
+     All specific block implementations should inherit from this class.
+     """
+
      def __init__(self, block_name: str) -> None:
          self.block_name = block_name

      @staticmethod
      def _validate(prompt_template: Template, input_dict: Dict[str, Any]) -> bool:
-         """
-         Validate the input data for this block. This method validates whether all required
-         variables in the Jinja template are provided in the input_dict.
+         """Validate the input data for this block.
+
+         This method validates whether all required variables in the Jinja template are provided in the input_dict.
+
+         Parameters
+         ----------
+         prompt_template : Template
+             The Jinja2 template object.
+         input_dict : Dict[str, Any]
+             A dictionary of input values to check against the template.

-         :param prompt_template: The Jinja2 template object.
-         :param input_dict: A dictionary of input values to check against the template.
-         :return: True if the input data is valid (i.e., no missing variables), False otherwise.
+         Returns
+         -------
+         bool
+             True if the input data is valid (i.e., no missing variables), False otherwise.
          """
-
+
          class Default(dict):
              def __missing__(self, key: str) -> None:
                  raise KeyError(key)
-
+
          try:
              # Try rendering the template with the input_dict
              prompt_template.render(ChainMap(input_dict, Default()))
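
The validation trick above renders the template against a ChainMap whose fallback dict raises KeyError for any variable the input does not supply. For reference, a standalone sketch that checks for missing template variables a different way, via jinja2's meta API (the template string and inputs are illustrative, not from the shipped configs):

    # Python sketch — detect missing Jinja2 variables with jinja2.meta,
    # an alternative to the ChainMap/KeyError technique in _validate.
    from jinja2 import Environment, meta

    env = Environment()
    source = "Summarize {{ document }} in {{ style }} style."
    required = meta.find_undeclared_variables(env.parse(source))
    provided = {"document": "..."}
    print(required - provided.keys())  # -> {'style'}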
@@ -43,12 +63,34 @@ class Block(ABC):
              logger.error(f"Missing key: {e}")
              return False

-     def _load_config(self, config_path: str) -> Union[Dict[str, Any], None]:
-         """
-         Load the configuration file for this block.
+     def _load_config(self, config_path: str) -> Optional[Dict[str, Any]]:
+         """Load the configuration file for this block.

-         :param config_path: The path to the configuration file.
-         :return: The loaded configuration.
+         Parameters
+         ----------
+         config_path : str
+             The path to the configuration file.
+
+         Returns
+         -------
+         Optional[Dict[str, Any]]
+             The loaded configuration. Returns None if file cannot be read or parsed.
+
+         Raises
+         ------
+         FileNotFoundError
+             If the configuration file does not exist.
          """
-         with open(config_path, "r", encoding="utf-8") as config_file:
-             return yaml.safe_load(config_file)
+         try:
+             with open(config_path, "r", encoding="utf-8") as config_file:
+                 try:
+                     return yaml.safe_load(config_file)
+                 except yaml.YAMLError as e:
+                     logger.error(f"Error parsing YAML from {config_path}: {e}")
+                     return None
+         except FileNotFoundError:
+             logger.error(f"Configuration file not found: {config_path}")
+             raise
+         except Exception as e:
+             logger.error(f"Unexpected error reading config file {config_path}: {e}")
+             return None
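
The behavioral change in _load_config is the error contract: a missing file now logs and re-raises FileNotFoundError, while a malformed file (or any other read error) logs and degrades to None. A condensed sketch of the same shape, with print standing in for the logger and a hypothetical path:

    # Python sketch of the new _load_config error contract.
    import yaml

    def load_config(path):
        try:
            with open(path, "r", encoding="utf-8") as f:
                try:
                    return yaml.safe_load(f)
                except yaml.YAMLError as e:
                    print(f"bad YAML in {path}: {e}")
                    return None
        except FileNotFoundError:
            raise  # callers must handle a missing config explicitly

    config = load_config("configs/knowledge/atomic_facts.yaml")  # dict or None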
sdg_hub/blocks/llmblock.py CHANGED
@@ -1,7 +1,11 @@
  # SPDX-License-Identifier: Apache-2.0
+ """LLM-based blocks for text generation and processing.
+
+ This module provides blocks for interacting with language models.
+ """
+
  # Standard
- from collections import Counter
- from typing import Any, Dict, List
+ from typing import Any, Dict, List, Optional, Union
  import json
  import re

@@ -18,7 +22,18 @@ from ..registry import BlockRegistry, PromptRegistry
  logger = setup_logger(__name__)


- def server_supports_batched(client, model_id: str) -> bool:
+ def server_supports_batched(client: openai.OpenAI, model_id: str) -> bool:
+     """Check if the server supports batched inputs.
+
+     This function checks if the server supports batched inputs by making a test call to the server.
+
+     Parameters
+     ----------
+     client : openai.OpenAI
+         The client to use to make the test call.
+     model_id : str
+         The model ID to use for the test call.
+     """
      supported = getattr(client, "server_supports_batched", None)
      if supported is not None:
          return supported
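
Note the getattr shortcut: if the client object already carries a server_supports_batched attribute, the probe call is skipped entirely. A sketch of pre-seeding that flag, assuming the standard openai client tolerates the extra attribute (the base_url is illustrative):

    # Python sketch — pin batching support so no test call is made.
    import openai

    client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    client.server_supports_batched = False  # getattr(...) above returns this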
@@ -38,19 +53,43 @@ def server_supports_batched(client, model_id: str) -> bool:


  @BlockRegistry.register("LLMBlock")
- # pylint: disable=dangerous-default-value
  class LLMBlock(Block):
+     """Block for generating text using language models.
+
+     This block handles text generation, prompt formatting, and output parsing
+     for language model interactions.
+
+     Parameters
+     ----------
+     block_name : str
+         Name of the block.
+     config_path : str
+         Path to the configuration file.
+     client : openai.OpenAI
+         OpenAI client instance.
+     output_cols : List[str]
+         List of output column names.
+     parser_kwargs : Dict[str, Any], optional
+         Keyword arguments for the parser, by default {}.
+     model_prompt : str, optional
+         Template string for model prompt, by default "{prompt}".
+     model_id : Optional[str], optional
+         Model ID to use, by default None.
+     **batch_kwargs : Dict[str, Any]
+         Additional keyword arguments for batch processing.
+     """
+
      # pylint: disable=too-many-instance-attributes
      def __init__(
          self,
-         block_name,
-         config_path,
-         client,
-         output_cols,
-         parser_kwargs={},
-         model_prompt="{prompt}",
-         model_id=None,
-         **batch_kwargs,
+         block_name: str,
+         config_path: str,
+         client: openai.OpenAI,
+         output_cols: List[str],
+         parser_kwargs: Dict[str, Any] = {},
+         model_prompt: str = "{prompt}",
+         model_id: Optional[str] = None,
+         **batch_kwargs: Dict[str, Any],
      ) -> None:
          super().__init__(block_name)
          self.block_config = self._load_config(config_path)
@@ -84,7 +123,27 @@ class LLMBlock(Block):
          # and supports the n parameter to generate n outputs per input
          self.server_supports_batched = server_supports_batched(client, self.model)

-     def _parse(self, generated_string) -> dict:
+     def _extract_matches(
+         self, text: str, start_tag: Optional[str], end_tag: Optional[str]
+     ) -> List[str]:
+         if not text:
+             return []
+         if not start_tag and not end_tag:
+             return [text.strip()]
+
+         pattern = ""
+         if start_tag:
+             pattern += re.escape(start_tag)
+         pattern += r"(.*?)"
+         if end_tag:
+             pattern += re.escape(end_tag)
+         elif start_tag:
+             # Enforce matching till end of string when only start_tag is provided.
+             pattern += "$"
+
+         return [match.strip() for match in re.findall(pattern, text, re.DOTALL)]
+
+     def _parse(self, generated_string: str) -> dict:
          matches = {}

          if self.parser_name is not None and self.parser_name == "custom":
@@ -108,16 +167,9 @@
              self.block_config.get("end_tags", []),
              self.output_cols,
          ):
-             if not start_tag and not end_tag:
-                 matches[output_col] = [
-                     generated_string.strip() if generated_string else None
-                 ]
-             else:
-                 pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag)
-                 all_matches = re.findall(pattern, generated_string, re.DOTALL)
-                 matches[output_col] = (
-                     [match.strip() for match in all_matches] if all_matches else []
-                 )
+             matches[output_col] = self._extract_matches(
+                 generated_string, start_tag, end_tag
+             )

          return matches

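Pulling tag extraction into _extract_matches also fixes the single-tag case: with only a start_tag configured, the pattern is anchored with $ so the capture runs to the end of the string instead of failing on a missing end tag. A standalone sketch of the helper's behavior (the example output string is illustrative):

    # Python sketch mirroring _extract_matches, outside the class.
    import re

    def extract_matches(text, start_tag, end_tag):
        if not text:
            return []
        if not start_tag and not end_tag:
            return [text.strip()]
        pattern = ""
        if start_tag:
            pattern += re.escape(start_tag)
        pattern += r"(.*?)"
        if end_tag:
            pattern += re.escape(end_tag)
        elif start_tag:
            pattern += "$"  # only a start tag: capture to end of string
        return [m.strip() for m in re.findall(pattern, text, re.DOTALL)]

    out = "[Q] What is SDG? [A] Synthetic Data Generation."
    print(extract_matches(out, "[Q]", "[A]"))  # -> ['What is SDG?']
    print(extract_matches(out, "[A]", None))   # -> ['Synthetic Data Generation.']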
@@ -127,7 +179,7 @@ class LLMBlock(Block):
              self.model_prompt, prompt_templated_str, add_generation_prompt=True
          ).strip()

-     def _generate(self, samples, **gen_kwargs) -> list:
+     def _generate(self, samples: Dataset, **gen_kwargs: Dict[str, Any]) -> list:
          prompts = [self._format_prompt(sample) for sample in samples]
          logger.debug("Prompt: %s", prompts[0])
          generate_args = {**self.defaults, **gen_kwargs}
@@ -159,12 +211,16 @@ class LLMBlock(Block):
              results.append(response.choices[0].text.strip())
          return results

-     def generate(self, samples: Dataset, **gen_kwargs) -> Dataset:
-         """
-         Generate the output from the block. This method should first validate the input data,
+     def generate(self, samples: Dataset, **gen_kwargs: Dict[str, Any]) -> Dataset:
+         """Generate the output from the block.
+
+         This method should first validate the input data,
          then generate the output, and finally parse the generated output before returning it.

-         :return: The parsed output after generation.
+         Returns
+         -------
+         Dataset
+             The parsed output after generation.
          """
          num_samples = self.block_config.get("num_samples", None)
          logger.debug("Generating outputs for {} samples".format(len(samples)))
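
For context, a minimal end-to-end sketch of driving an LLMBlock, matching the typed constructor above; the config path, column names, and server URL are hypothetical:

    # Python sketch — wiring an LLMBlock to an OpenAI-compatible endpoint.
    import openai
    from datasets import Dataset
    from sdg_hub.blocks import LLMBlock

    client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    block = LLMBlock(
        block_name="gen_questions",
        config_path="configs/knowledge/generate_questions.yaml",
        client=client,
        output_cols=["question"],
    )
    ds = block.generate(Dataset.from_list([{"document": "..."}]), temperature=0.7)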
@@ -219,16 +275,40 @@ class LLMBlock(Block):

  @BlockRegistry.register("ConditionalLLMBlock")
  class ConditionalLLMBlock(LLMBlock):
+     """Block for conditional text generation using language models.
+
+     This block selects different prompt templates based on a selector column value.
+
+     Parameters
+     ----------
+     block_name : str
+         Name of the block.
+     config_paths : Dict[str, str]
+         Dictionary mapping selector values to their config file paths.
+     client : openai.OpenAI
+         OpenAI client instance.
+     model_id : str
+         Model ID to use.
+     output_cols : List[str]
+         List of output column names.
+     selector_column_name : str
+         Name of the column used to select the prompt template.
+     model_prompt : str, optional
+         Template string for model prompt, by default "{prompt}".
+     **batch_kwargs : Dict[str, Any]
+         Additional keyword arguments for batch processing.
+     """
+
      def __init__(
          self,
-         block_name,
-         config_paths,
-         client,
-         model_id,
-         output_cols,
-         selector_column_name,
-         model_prompt="{prompt}",
-         **batch_kwargs,
+         block_name: str,
+         config_paths: Dict[str, str],
+         client: openai.OpenAI,
+         model_id: str,
+         output_cols: List[str],
+         selector_column_name: str,
+         model_prompt: str = "{prompt}",
+         **batch_kwargs: Dict[str, Any],
      ) -> None:
          super().__init__(
              block_name=block_name,
@@ -245,15 +325,27 @@ class ConditionalLLMBlock(LLMBlock):
              self.prompt_template = self.prompt_struct.format(**self.block_config)
          else:
              for config_key, config in config_paths.items():
-                 # Template(self.prompt_struct.format(**filtered_config))
                  filtered_config = {
-                     k: (v if v is not None else "") for k, v in self.block_config.items()
+                     k: (v if v is not None else "")
+                     for k, v in self.block_config.items()
                  }
-                 self.prompt_template[config_key] = Template(self.prompt_struct.format(
-                     **self._load_config(config)
-                 ))
+                 self.prompt_template[config_key] = Template(
+                     self.prompt_struct.format(**self._load_config(config))
+                 )

-     def _format_prompt(self, sample: Dict) -> str:
+     def _format_prompt(self, sample: Dict[str, Any]) -> str:
+         """Format the prompt based on the selector column value.
+
+         Parameters
+         ----------
+         sample : Dict[str, Any]
+             Input sample containing the selector column.
+
+         Returns
+         -------
+         str
+             Formatted prompt string.
+         """
          if isinstance(self.prompt_template, dict):
              return (
                  self.prompt_template[sample[self.selector_column_name]]
@@ -263,168 +355,21 @@

          return self.prompt_template.render(**sample).strip()

-     def _validate(self, prompt_template: str, input_dict: Dict[str, Any]) -> bool:
-         if isinstance(prompt_template, dict):
-             prompt_template = prompt_template[input_dict[self.selector_column_name]]
-         return super()._validate(prompt_template, input_dict)
-
-
- @BlockRegistry.register("LLMLogProbBlock")
- class LLMLogProbBlock(LLMBlock):
-     # init with init of the parent class
-     def __init__(
-         self,
-         block_name,
-         config_path,
-         client,
-         output_cols,
-         parser_kwargs={},
-         model_prompt="{prompt}",
-         model_id=None,
-         **batch_kwargs,
-     ) -> None:
-         super().__init__(
-             block_name=block_name,
-             config_path=config_path,
-             client=client,
-             output_cols=output_cols,
-             parser_kwargs=parser_kwargs,
-             model_prompt=model_prompt,
-             model_id=model_id,
-             **batch_kwargs,
-         )
-
-     def _generate_logprobs(self, samples, **gen_kwargs):
-         prompts = [
-             self.model_prompt.format(prompt=self._format_prompt(sample))
-             for sample in samples
-         ]
-         generate_args = {**self.defaults, **gen_kwargs}
-
-         # verify if logprobs is mentioned in the generate_args, if not add it and return top10 logprobs
-         if "logprobs" not in generate_args:
-             generate_args["logprobs"] = 10
+     def _validate(self, prompt_template: Union[str, Template], input_dict: Dict[str, Any]) -> bool:
+         """Validate the input data for this block.

-         if self.server_supports_batched:
-             response = self.client.completions.create(prompt=prompts, **generate_args)
-             return [choice.logprobs.top_logprobs for choice in response.choices]
+         Parameters
+         ----------
+         prompt_template : Union[str, Template]
+             The template to validate against.
+         input_dict : Dict[str, Any]
+             Input data to validate.

-         n = gen_kwargs.get("n", 1)
-         results = []
-         for prompt in prompts:
-             for _ in range(n):
-                 response = self.client.completions.create(
-                     prompt=prompt, **generate_args
-                 )
-                 results.append(response.choices[0].logprobs.top_logprobs)
-         return results
-
-     def _parse(self, generations: List[List[Dict]]) -> List[List[str]]:
-         # override the parse method to convert the generations to json string
-         # convert the generations to json string to save as dataset
-         # this is because the dataset can only store key value pairs which are consistent
-         return [[json.dumps(item) for item in sublist] for sublist in generations]
-
-     def generate(self, samples: Dataset, **gen_kwargs) -> Dataset:
+         Returns
+         -------
+         bool
+             True if the input data is valid, False otherwise.
          """
-         Generate the output from the block. This method should first validate the input data,
-         then generate the output, and finally parse the generated output before returning it.
-
-         :return: The parsed output after generation.
-         """
-         num_samples = self.block_config.get("num_samples", None)
-         logger.debug("Generating outputs for {} samples".format(len(samples)))
-
-         if (num_samples is not None) and ("num_samples" not in samples.column_names):
-             samples = samples.add_column("num_samples", [num_samples] * len(samples))
-
-         # validate each sample
-         # Log errors and remove invalid samples
-         valid_samples = []
-
-         for sample in samples:
-             if self._validate(self.prompt_template, sample):
-                 valid_samples.append(sample)
-             else:
-                 logger.warning(
-                     f"Sample failed validation: {sample}"
-                 )  # Log details of the failed sample
-
-         samples = valid_samples
-
-         if len(samples) == 0:
-             logger.warning(
-                 "No valid samples to generate outputs for, returning empty dataset"
-             )
-             return Dataset.from_list([])
-
-         # generate the output
-
-         outputs = self._generate_logprobs(samples, **gen_kwargs)
-         logger.debug("Generated outputs: %s", outputs)
-
-         output_dataset = Dataset.from_list(samples)
-         output_dataset = output_dataset.add_column(
-             self.output_cols[0],
-             self._parse(outputs),  # pylint: disable=no-value-for-parameter
-         )
-
-         return output_dataset
-
-
- @BlockRegistry.register("LLMMessagesBlock")
- class LLMMessagesBlock(Block):
-     def __init__(
-         self,
-         block_name,
-         client,
-         input_col,
-         output_col,
-         model_prompt=None,
-         model_id=None,
-         **batch_kwargs,
-     ) -> None:
-         self.block_name = block_name
-         self.model_prompt = model_prompt
-         self.batch_params = batch_kwargs.get("batch_kwargs", {})
-         self.input_col = input_col
-         self.output_col = output_col
-         self.client = client
-
-         if model_id:
-             self.model = model_id
-         else:
-             self.model = self.client.models.list().data[0].id
-
-         self.defaults = {
-             "model": self.model,
-             "temperature": 0,
-             "max_tokens": 4096,
-         }
-         self.server_supports_batched = server_supports_batched(client, self.model)
-
-     def _generate(self, samples, **gen_kwargs) -> list:
-         generate_args = {**self.defaults, **gen_kwargs}
-
-         if "n" in generate_args and generate_args.get("temperature", 0) <= 0:
-             generate_args["temperature"] = 0.7
-             logger.warning(
-                 "Temperature should be greater than 0 for n > 1, setting temperature to 0.7"
-             )
-
-         messages = samples[self.input_col]
-
-         results = []
-         n = gen_kwargs.get("n", 1)
-         for message in messages:
-             responses = self.client.chat.completions.create(messages=message, **generate_args)
-             if n > 1:
-                 results.append([choice.message.content for choice in responses.choices])
-             else:
-                 results.append(responses.choices[0].message.content)
-         return results
-
-     def generate(self, samples: Dataset, **gen_kwargs) -> Dataset:
-         outputs = self._generate(samples, **gen_kwargs)
-         samples = samples.add_column(self.output_col, outputs)
-         return samples
+         if isinstance(prompt_template, dict):
+             prompt_template = prompt_template[input_dict[self.selector_column_name]]
+         return super()._validate(prompt_template, input_dict)
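
Net effect: LLMLogProbBlock and LLMMessagesBlock are removed from this module, and ConditionalLLMBlock's _validate override keeps its selector dispatch, now annotated as Union[str, Template]. Both _format_prompt and _validate pick the per-row template from the selector column before deferring to the shared logic; a minimal sketch of that dispatch with illustrative keys and templates:

    # Python sketch of per-row template selection, as in ConditionalLLMBlock.
    from jinja2 import Template

    prompt_template = {  # selector value -> prompt template
        "math": Template("Solve step by step: {{ question }}"),
        "coding": Template("Write a program for: {{ question }}"),
    }
    selector_column_name = "route"

    def format_prompt(sample):
        template = prompt_template[sample[selector_column_name]]
        return template.render(**sample).strip()

    print(format_prompt({"route": "math", "question": "What is 7 * 8?"}))
    # -> Solve step by step: What is 7 * 8?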