sdg-hub 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sdg_hub/__init__.py +28 -1
- sdg_hub/_version.py +2 -2
- sdg_hub/core/__init__.py +22 -0
- sdg_hub/core/blocks/__init__.py +58 -0
- sdg_hub/core/blocks/base.py +313 -0
- sdg_hub/core/blocks/deprecated_blocks/__init__.py +29 -0
- sdg_hub/core/blocks/deprecated_blocks/combine_columns.py +93 -0
- sdg_hub/core/blocks/deprecated_blocks/duplicate_columns.py +88 -0
- sdg_hub/core/blocks/deprecated_blocks/filter_by_value.py +103 -0
- sdg_hub/core/blocks/deprecated_blocks/flatten_columns.py +94 -0
- sdg_hub/core/blocks/deprecated_blocks/llmblock.py +479 -0
- sdg_hub/core/blocks/deprecated_blocks/rename_columns.py +88 -0
- sdg_hub/core/blocks/deprecated_blocks/sample_populator.py +58 -0
- sdg_hub/core/blocks/deprecated_blocks/selector.py +97 -0
- sdg_hub/core/blocks/deprecated_blocks/set_to_majority_value.py +88 -0
- sdg_hub/core/blocks/evaluation/__init__.py +9 -0
- sdg_hub/core/blocks/evaluation/evaluate_faithfulness_block.py +564 -0
- sdg_hub/core/blocks/evaluation/evaluate_relevancy_block.py +564 -0
- sdg_hub/core/blocks/evaluation/verify_question_block.py +564 -0
- sdg_hub/core/blocks/filtering/__init__.py +12 -0
- sdg_hub/core/blocks/filtering/column_value_filter.py +188 -0
- sdg_hub/core/blocks/llm/__init__.py +27 -0
- sdg_hub/core/blocks/llm/client_manager.py +398 -0
- sdg_hub/core/blocks/llm/config.py +336 -0
- sdg_hub/core/blocks/llm/error_handler.py +368 -0
- sdg_hub/core/blocks/llm/llm_chat_block.py +542 -0
- sdg_hub/core/blocks/llm/llm_chat_with_parsing_retry_block.py +491 -0
- sdg_hub/core/blocks/llm/prompt_builder_block.py +368 -0
- sdg_hub/core/blocks/llm/text_parser_block.py +357 -0
- sdg_hub/core/blocks/registry.py +331 -0
- sdg_hub/core/blocks/transform/__init__.py +23 -0
- sdg_hub/core/blocks/transform/duplicate_columns.py +88 -0
- sdg_hub/core/blocks/transform/index_based_mapper.py +225 -0
- sdg_hub/core/blocks/transform/melt_columns.py +126 -0
- sdg_hub/core/blocks/transform/rename_columns.py +69 -0
- sdg_hub/core/blocks/transform/text_concat.py +102 -0
- sdg_hub/core/blocks/transform/uniform_col_val_setter.py +101 -0
- sdg_hub/core/flow/__init__.py +20 -0
- sdg_hub/core/flow/base.py +1209 -0
- sdg_hub/core/flow/checkpointer.py +333 -0
- sdg_hub/core/flow/metadata.py +389 -0
- sdg_hub/core/flow/migration.py +198 -0
- sdg_hub/core/flow/registry.py +393 -0
- sdg_hub/core/flow/validation.py +277 -0
- sdg_hub/{utils → core/utils}/__init__.py +7 -4
- sdg_hub/core/utils/datautils.py +63 -0
- sdg_hub/core/utils/error_handling.py +208 -0
- sdg_hub/core/utils/flow_id_words.yaml +231 -0
- sdg_hub/core/utils/flow_identifier.py +94 -0
- sdg_hub/{utils → core/utils}/path_resolution.py +2 -2
- sdg_hub/core/utils/yaml_utils.py +59 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/atomic_facts.yaml +40 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/detailed_summary.yaml +13 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_faithfulness.yaml +64 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_question.yaml +29 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_relevancy.yaml +81 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/extractive_summary.yaml +13 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/flow.yaml +192 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/generate_questions_responses.yaml +54 -0
- sdg_hub-0.2.1.dist-info/METADATA +221 -0
- sdg_hub-0.2.1.dist-info/RECORD +68 -0
- sdg_hub/blocks/__init__.py +0 -42
- sdg_hub/blocks/block.py +0 -96
- sdg_hub/blocks/llmblock.py +0 -375
- sdg_hub/blocks/openaichatblock.py +0 -556
- sdg_hub/blocks/utilblocks.py +0 -597
- sdg_hub/checkpointer.py +0 -139
- sdg_hub/configs/annotations/cot_reflection.yaml +0 -34
- sdg_hub/configs/annotations/detailed_annotations.yaml +0 -28
- sdg_hub/configs/annotations/detailed_description.yaml +0 -10
- sdg_hub/configs/annotations/detailed_description_icl.yaml +0 -32
- sdg_hub/configs/annotations/simple_annotations.yaml +0 -9
- sdg_hub/configs/knowledge/__init__.py +0 -0
- sdg_hub/configs/knowledge/atomic_facts.yaml +0 -46
- sdg_hub/configs/knowledge/auxilary_instructions.yaml +0 -35
- sdg_hub/configs/knowledge/detailed_summary.yaml +0 -18
- sdg_hub/configs/knowledge/evaluate_faithfulness.yaml +0 -68
- sdg_hub/configs/knowledge/evaluate_question.yaml +0 -38
- sdg_hub/configs/knowledge/evaluate_relevancy.yaml +0 -84
- sdg_hub/configs/knowledge/extractive_summary.yaml +0 -18
- sdg_hub/configs/knowledge/generate_code_questions_responses.yaml +0 -39
- sdg_hub/configs/knowledge/generate_questions.yaml +0 -82
- sdg_hub/configs/knowledge/generate_questions_responses.yaml +0 -56
- sdg_hub/configs/knowledge/generate_responses.yaml +0 -86
- sdg_hub/configs/knowledge/mcq_generation.yaml +0 -83
- sdg_hub/configs/knowledge/router.yaml +0 -12
- sdg_hub/configs/knowledge/simple_generate_qa.yaml +0 -34
- sdg_hub/configs/reasoning/__init__.py +0 -0
- sdg_hub/configs/reasoning/dynamic_cot.yaml +0 -40
- sdg_hub/configs/skills/__init__.py +0 -0
- sdg_hub/configs/skills/analyzer.yaml +0 -48
- sdg_hub/configs/skills/annotation.yaml +0 -36
- sdg_hub/configs/skills/contexts.yaml +0 -28
- sdg_hub/configs/skills/critic.yaml +0 -60
- sdg_hub/configs/skills/evaluate_freeform_pair.yaml +0 -111
- sdg_hub/configs/skills/evaluate_freeform_questions.yaml +0 -78
- sdg_hub/configs/skills/evaluate_grounded_pair.yaml +0 -119
- sdg_hub/configs/skills/evaluate_grounded_questions.yaml +0 -51
- sdg_hub/configs/skills/freeform_questions.yaml +0 -34
- sdg_hub/configs/skills/freeform_responses.yaml +0 -39
- sdg_hub/configs/skills/grounded_questions.yaml +0 -38
- sdg_hub/configs/skills/grounded_responses.yaml +0 -59
- sdg_hub/configs/skills/icl_examples/STEM.yaml +0 -56
- sdg_hub/configs/skills/icl_examples/__init__.py +0 -0
- sdg_hub/configs/skills/icl_examples/coding.yaml +0 -97
- sdg_hub/configs/skills/icl_examples/extraction.yaml +0 -36
- sdg_hub/configs/skills/icl_examples/humanities.yaml +0 -71
- sdg_hub/configs/skills/icl_examples/math.yaml +0 -85
- sdg_hub/configs/skills/icl_examples/reasoning.yaml +0 -30
- sdg_hub/configs/skills/icl_examples/roleplay.yaml +0 -45
- sdg_hub/configs/skills/icl_examples/writing.yaml +0 -80
- sdg_hub/configs/skills/judge.yaml +0 -53
- sdg_hub/configs/skills/planner.yaml +0 -67
- sdg_hub/configs/skills/respond.yaml +0 -8
- sdg_hub/configs/skills/revised_responder.yaml +0 -78
- sdg_hub/configs/skills/router.yaml +0 -59
- sdg_hub/configs/skills/simple_generate_qa_freeform.yaml +0 -27
- sdg_hub/configs/skills/simple_generate_qa_grounded.yaml +0 -31
- sdg_hub/flow.py +0 -477
- sdg_hub/flow_runner.py +0 -450
- sdg_hub/flows/generation/knowledge/mmlu_bench.yaml +0 -13
- sdg_hub/flows/generation/knowledge/simple_knowledge.yaml +0 -12
- sdg_hub/flows/generation/knowledge/synth_knowledge.yaml +0 -89
- sdg_hub/flows/generation/knowledge/synth_knowledge1.5.yaml +0 -136
- sdg_hub/flows/generation/skills/improve_responses.yaml +0 -103
- sdg_hub/flows/generation/skills/simple_freeform_skill.yaml +0 -12
- sdg_hub/flows/generation/skills/simple_grounded_skill.yaml +0 -12
- sdg_hub/flows/generation/skills/synth_grounded_skills.yaml +0 -80
- sdg_hub/flows/generation/skills/synth_skills.yaml +0 -59
- sdg_hub/pipeline.py +0 -121
- sdg_hub/prompts.py +0 -80
- sdg_hub/registry.py +0 -122
- sdg_hub/sdg.py +0 -206
- sdg_hub/utils/config_validation.py +0 -91
- sdg_hub/utils/datautils.py +0 -14
- sdg_hub/utils/error_handling.py +0 -94
- sdg_hub/utils/validation_result.py +0 -10
- sdg_hub-0.1.4.dist-info/METADATA +0 -190
- sdg_hub-0.1.4.dist-info/RECORD +0 -89
- sdg_hub/{logger_config.py → core/utils/logger_config.py} +1 -1
- /sdg_hub/{configs/__init__.py → flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/README.md} +0 -0
- /sdg_hub/{configs/annotations → flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab}/__init__.py +0 -0
- {sdg_hub-0.1.4.dist-info → sdg_hub-0.2.1.dist-info}/WHEEL +0 -0
- {sdg_hub-0.1.4.dist-info → sdg_hub-0.2.1.dist-info}/licenses/LICENSE +0 -0
- {sdg_hub-0.1.4.dist-info → sdg_hub-0.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,368 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
"""Prompt builder block for formatting prompts into structured chat messages or plain text.
|
3
|
+
|
4
|
+
This module provides the PromptBuilderBlock for handling LLM prompt formatting,
|
5
|
+
including conversion to OpenAI Messages format and template rendering.
|
6
|
+
"""
|
7
|
+
|
8
|
+
# Standard
|
9
|
+
from typing import Any, Literal, Optional
|
10
|
+
|
11
|
+
# Third Party
|
12
|
+
from datasets import Dataset
|
13
|
+
from jinja2 import Template, meta
|
14
|
+
from pydantic import BaseModel, Field, field_validator
|
15
|
+
import yaml
|
16
|
+
|
17
|
+
# Local
|
18
|
+
from ...utils.error_handling import TemplateValidationError
|
19
|
+
from ...utils.logger_config import setup_logger
|
20
|
+
from ..base import BaseBlock
|
21
|
+
from ..registry import BlockRegistry
|
22
|
+
|
23
|
+
logger = setup_logger(__name__)
|
24
|
+
|
25
|
+
|
26
|
+
class ChatMessage(BaseModel):
    """A single rendered chat message, validated by Pydantic.

    Attributes
    ----------
    role : {"system", "user", "assistant", "tool"}
        Speaker of the message.
    content : str
        Message text; surrounding whitespace is removed and blank text is rejected.
    """

    role: Literal["system", "user", "assistant", "tool"]
    content: str

    @field_validator("content")
    @classmethod
    def validate_content_not_empty(cls, v: str) -> str:
        """Return ``v`` stripped of surrounding whitespace.

        Raises
        ------
        ValueError
            If ``v`` is empty or consists only of whitespace.
        """
        stripped = v.strip() if v else ""
        if not stripped:
            raise ValueError("Message content cannot be empty")
        return stripped
|
39
|
+
|
40
|
+
|
41
|
+
class MessageTemplate(BaseModel):
    """A chat message whose content is a compiled Jinja2 template.

    Attributes
    ----------
    role : {"system", "user", "assistant", "tool"}
        Speaker the rendered message will be attributed to.
    content_template : jinja2.Template
        Compiled template used to render the message content.
    original_source : str
        The raw template text ``content_template`` was compiled from; kept so the
        source can be re-parsed (e.g. to discover required variables).
    """

    role: Literal["system", "user", "assistant", "tool"]
    content_template: Template
    original_source: str

    # jinja2.Template is not a Pydantic-native type, so arbitrary types must be allowed.
    model_config = dict(arbitrary_types_allowed=True)
|
49
|
+
|
50
|
+
|
51
|
+
class PromptTemplateConfig:
    """Self-contained class for loading and validating YAML prompt configurations.

    The YAML file must contain a non-empty list of mappings, each with a
    ``role`` and a ``content`` key. Every ``content`` value is compiled into a
    Jinja2 ``Template`` and the resulting message sequence must contain at
    least one ``user`` message and end with a ``user`` message.

    Parameters
    ----------
    config_path : str
        Path to the YAML prompt configuration file. It is loaded eagerly in
        ``__init__``.

    Raises
    ------
    FileNotFoundError
        If ``config_path`` does not exist.
    yaml.YAMLError
        If the file is not valid YAML.
    ValueError
        If the configuration does not have the expected structure.
    """

    def __init__(self, config_path: str):
        """Initialize with path to YAML config file."""
        self.config_path = config_path
        self.message_templates: list[MessageTemplate] = []
        self._load_and_validate()

    def _load_and_validate(self) -> None:
        """Load YAML config and validate format."""
        config = self._read_config()

        try:
            if not isinstance(config, list):
                raise ValueError(
                    "Template config must be a list of message objects"
                )

            if not config:
                raise ValueError("Prompt configuration cannot be empty")

            self._compile_templates(config)
            self._validate_message_flow()

        except ValueError as e:
            # Structural problems are expected validation failures, not
            # "unexpected" I/O errors; log them accordingly.
            logger.error(f"Invalid prompt configuration in {self.config_path}: {e}")
            raise
        except Exception as e:
            logger.error(
                f"Unexpected error reading config file {self.config_path}: {e}"
            )
            raise

    def _read_config(self) -> Any:
        """Read and parse the YAML file, logging and re-raising any failure.

        Returns
        -------
        Any
            Whatever ``yaml.safe_load`` produced (``None`` for an empty file).
        """
        try:
            # Only reading/parsing needs the handle; it is closed before
            # templates are compiled.
            with open(self.config_path, encoding="utf-8") as config_file:
                return yaml.safe_load(config_file)
        except FileNotFoundError:
            logger.error(f"Configuration file not found: {self.config_path}")
            raise
        except yaml.YAMLError as e:
            logger.error(f"Error parsing YAML from {self.config_path}: {e}")
            raise
        except Exception as e:
            logger.error(
                f"Unexpected error reading config file {self.config_path}: {e}"
            )
            raise

    def _compile_templates(self, config: list[dict[str, Any]]) -> None:
        """Compile Jinja templates for each message in the config.

        Raises
        ------
        ValueError
            If an entry is not a mapping, lacks ``role``/``content``, or its
            template cannot be built (e.g. bad Jinja syntax or unsupported role).
        """
        for i, message in enumerate(config):
            # Reject scalars/lists up front: for a string the membership test
            # below would be a substring match, and for a list ``.keys()``
            # would raise AttributeError and hide the real problem.
            if not isinstance(message, dict):
                raise ValueError(
                    f"Message {i} must be a mapping with 'role' and 'content' fields. "
                    f"Got: {type(message).__name__}"
                )

            if "role" not in message or "content" not in message:
                raise ValueError(
                    f"Message {i} must have 'role' and 'content' fields. Got: {message.keys()}"
                )

            try:
                content_source = message["content"]
                message_template = MessageTemplate(
                    role=message["role"],
                    content_template=Template(content_source),
                    original_source=content_source,
                )
                self.message_templates.append(message_template)
            except Exception as e:
                raise ValueError(
                    f"Failed to compile template for message {i}: {e}"
                ) from e

    def _validate_message_flow(self) -> None:
        """Validate that message flow is appropriate for chat completion.

        Raises
        ------
        ValueError
            If there is no ``user`` message or the last message is not ``user``.
        """
        user_messages = [msg for msg in self.message_templates if msg.role == "user"]
        if not user_messages:
            raise ValueError(
                "Template must contain at least one message with role='user' for proper conversation flow."
            )

        if self.message_templates and self.message_templates[-1].role != "user":
            raise ValueError(
                f"The final message must have role='user' for proper chat completion. "
                f"Got role='{self.message_templates[-1].role}' for the last message."
            )

    def get_message_templates(self) -> list[MessageTemplate]:
        """Return the compiled message templates."""
        return self.message_templates
|
127
|
+
|
128
|
+
|
129
|
+
class PromptRenderer:
    """Handles rendering of message templates with variable substitution.

    Parameters
    ----------
    message_templates : list[MessageTemplate]
        Compiled message templates, in conversation order.
    """

    def __init__(self, message_templates: list[MessageTemplate]):
        """Initialize with a list of message templates."""
        self.message_templates = message_templates

    def get_required_variables(self) -> set[str]:
        """Extract all required variables from message templates.

        Returns
        -------
        set[str]
            Names that any template reads from its rendering context, as
            reported by ``jinja2.meta.find_undeclared_variables``.
        """
        required_vars: set[str] = set()
        for msg_template in self.message_templates:
            # Parse the original source with the template's own environment so
            # the analysis matches how the template is actually rendered.
            ast = msg_template.content_template.environment.parse(
                msg_template.original_source
            )
            required_vars.update(meta.find_undeclared_variables(ast))
        return required_vars

    def resolve_template_vars(
        self, sample: dict[str, Any], input_cols
    ) -> dict[str, Any]:
        """Resolve template variables from dataset columns based on input_cols.

        Parameters
        ----------
        sample : Dict[str, Any]
            Input sample from dataset.
        input_cols : Union[str, List[str], Dict[str, str], None]
            Input column specification:
            - Dict[str, str]: maps dataset column name -> template variable name.
            - str: a single column name, used directly as the variable name.
            - List[str]: column names used directly as variable names.
            - None: no columns are mapped.

        Returns
        -------
        Dict[str, Any]
            Template variables mapped from dataset columns. Columns missing from
            ``sample`` are skipped with a warning.
        """
        template_vars = {}

        if input_cols is None:
            return template_vars
        if isinstance(input_cols, str):
            # A bare string names one column; iterating it directly would treat
            # each character as a separate column name.
            input_cols = [input_cols]

        if isinstance(input_cols, dict):
            # Map dataset columns to template variables
            for dataset_col, template_var in input_cols.items():
                if dataset_col in sample:
                    template_vars[template_var] = sample[dataset_col]
                else:
                    logger.warning(
                        f"Dataset column '{dataset_col}' not found in sample"
                    )
        else:
            # Use column names directly as template variables
            for col in input_cols:
                if col in sample:
                    template_vars[col] = sample[col]
                else:
                    logger.warning(f"Dataset column '{col}' not found in sample")

        return template_vars

    def render_messages(self, template_vars: dict[str, Any]) -> list[ChatMessage]:
        """Render all message templates with the given variables.

        Rendering is best-effort: a message that renders to blank text is
        dropped, and a message whose rendering or validation raises is logged
        as a warning and skipped.

        Parameters
        ----------
        template_vars : Dict[str, Any]
            Variables to substitute in templates.

        Returns
        -------
        List[ChatMessage]
            List of rendered and validated chat messages.
        """
        rendered_messages = []

        for i, msg_template in enumerate(self.message_templates):
            try:
                rendered_content = msg_template.content_template.render(
                    template_vars
                ).strip()
                if rendered_content:  # Only add non-empty messages
                    chat_message = ChatMessage(
                        role=msg_template.role, content=rendered_content
                    )
                    rendered_messages.append(chat_message)
            except Exception as e:
                logger.warning(f"Failed to render message {i}: {e}")
                continue

        return rendered_messages
|
216
|
+
|
217
|
+
|
218
|
+
@BlockRegistry.register(
    "PromptBuilderBlock",
    "llm",
    "Formats prompts into structured chat messages or plain text using Jinja templates",
)
class PromptBuilderBlock(BaseBlock):
    """Block for formatting prompts into structured chat messages or plain text.

    This block takes input from dataset columns, applies Jinja templates from a YAML config
    containing a list of messages, and outputs either structured chat messages or formatted text.

    Parameters
    ----------
    block_name : str
        Name of the block.
    input_cols : Union[str, List[str], Dict[str, str]]
        Input column specification:
        - str: Single column name
        - List[str]: List of column names
        - Dict[str, str]: Mapping from dataset column names to template variables
    output_cols : str
        Name of the output column where formatted content will be saved.
        Exactly one output column is accepted.
    prompt_config_path : str
        Path to YAML file containing list of message objects with 'role' and 'content' fields.
    format_as_messages : bool, optional
        Whether to format output as chat messages (default True).
        If True, outputs List[Dict[str, str]] with 'role' and 'content' keys.
        If False, outputs concatenated string with role prefixes.
    """

    prompt_config_path: str = Field(
        ..., description="Path to YAML file containing the Jinja template configuration"
    )
    format_as_messages: bool = Field(
        True, description="Whether to format output as chat messages"
    )

    # Internal fields for configuration and renderer; populated in
    # model_post_init and excluded from serialization.
    prompt_template_config: Optional[PromptTemplateConfig] = Field(
        None, description="Loaded prompt template configuration", exclude=True
    )
    prompt_renderer: Optional[PromptRenderer] = Field(
        None, description="Prompt renderer instance", exclude=True
    )

    @field_validator("output_cols", mode="after")
    @classmethod
    def validate_single_output_col(cls, v):
        """Validate that exactly one output column is specified.

        Raises
        ------
        ValueError
            If ``v`` is None or does not contain exactly one entry.
        """
        # ``len(None)`` would raise a bare TypeError, which Pydantic does not
        # turn into a ValidationError; report it as a normal validation failure.
        if v is None:
            raise ValueError(
                "PromptBuilderBlock expects exactly one output column, got None"
            )
        if len(v) != 1:
            raise ValueError(
                f"PromptBuilderBlock expects exactly one output column, got {len(v)}: {v}"
            )
        return v

    def model_post_init(self, __context: Any) -> None:
        """Initialize the block after Pydantic validation.

        Loads the YAML prompt configuration and builds the renderer. Errors from
        ``PromptTemplateConfig`` (missing file, bad YAML, invalid structure)
        propagate to the caller.
        """
        # Run any post-init logic defined by parent classes before our own.
        super().model_post_init(__context)

        # Load and validate prompt configuration
        self.prompt_template_config = PromptTemplateConfig(self.prompt_config_path)

        # Initialize prompt renderer
        message_templates = self.prompt_template_config.get_message_templates()
        self.prompt_renderer = PromptRenderer(message_templates)

    def _validate_custom(self, dataset: Dataset) -> None:
        """Check that the dataset supplies every variable the templates need.

        Only the first row is inspected; an empty dataset is not checked.

        Raises
        ------
        TemplateValidationError
            If any template variable cannot be resolved from ``input_cols``.
        """
        if len(dataset) > 0:
            # Get required variables from all message templates
            required_vars = self.prompt_renderer.get_required_variables()

            sample = dataset[0]
            template_vars = self.prompt_renderer.resolve_template_vars(
                sample, self.input_cols
            )
            missing_vars = required_vars - set(template_vars.keys())

            if missing_vars:
                # Sorted so the reported error is deterministic across runs.
                raise TemplateValidationError(
                    block_name=self.block_name,
                    missing_variables=sorted(missing_vars),
                    available_variables=list(template_vars.keys()),
                )

    def _generate(self, sample: dict[str, Any]) -> dict[str, Any]:
        """Generate formatted output for a single sample.

        1. Resolve columns needed for prompt templating
        2. Render each message template with the variables
        3. Format as messages or concatenated string based on format_as_messages

        Any exception raised while formatting is logged and results in an empty
        output (``[]`` or ``""``) rather than failing the whole dataset.

        Parameters
        ----------
        sample : Dict[str, Any]
            Input sample from dataset.

        Returns
        -------
        Dict[str, Any]
            Sample with formatted output added to specified output column.
        """
        output_col = self.output_cols[0]

        try:
            # Step 1: Resolve template variables from dataset columns
            template_vars = self.prompt_renderer.resolve_template_vars(
                sample, self.input_cols
            )

            # Step 2: Render messages using the prompt renderer
            rendered_messages = self.prompt_renderer.render_messages(template_vars)

            # Step 3: Format output based on format_as_messages setting
            if not rendered_messages:
                logger.warning(f"No valid messages generated for sample: {sample}")
                sample[output_col] = [] if self.format_as_messages else ""
            elif self.format_as_messages:
                # Convert to dict format for serialization
                sample[output_col] = [msg.model_dump() for msg in rendered_messages]
            else:
                # Concatenate all messages into a single string
                sample[output_col] = "\n\n".join(
                    [f"{msg.role}: {msg.content}" for msg in rendered_messages]
                )

        except Exception as e:
            logger.error(f"Failed to format sample: {e}")
            sample[output_col] = [] if self.format_as_messages else ""

        return sample

    def generate(self, samples: Dataset, **_kwargs: Any) -> Dataset:
        """Generate formatted output for all samples using dataset map.

        Parameters
        ----------
        samples : Dataset
            Input dataset containing samples to be formatted.
        **_kwargs : Dict[str, Any]
            Additional keyword arguments (unused in this block).

        Returns
        -------
        Dataset
            Dataset with the formatted output added to the specified column.
        """
        logger.debug(f"Formatting prompts for {len(samples)} samples")

        # Use dataset map for efficient processing
        formatted_dataset = samples.map(self._generate)

        logger.debug(f"Successfully formatted {len(formatted_dataset)} samples")
        return formatted_dataset
|