sdg-hub 0.1.3__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sdg_hub/__init__.py +28 -1
- sdg_hub/_version.py +2 -2
- sdg_hub/core/__init__.py +22 -0
- sdg_hub/core/blocks/__init__.py +58 -0
- sdg_hub/core/blocks/base.py +313 -0
- sdg_hub/core/blocks/deprecated_blocks/__init__.py +29 -0
- sdg_hub/core/blocks/deprecated_blocks/combine_columns.py +93 -0
- sdg_hub/core/blocks/deprecated_blocks/duplicate_columns.py +88 -0
- sdg_hub/core/blocks/deprecated_blocks/filter_by_value.py +103 -0
- sdg_hub/core/blocks/deprecated_blocks/flatten_columns.py +94 -0
- sdg_hub/core/blocks/deprecated_blocks/llmblock.py +479 -0
- sdg_hub/core/blocks/deprecated_blocks/rename_columns.py +88 -0
- sdg_hub/core/blocks/deprecated_blocks/sample_populator.py +58 -0
- sdg_hub/core/blocks/deprecated_blocks/selector.py +97 -0
- sdg_hub/core/blocks/deprecated_blocks/set_to_majority_value.py +88 -0
- sdg_hub/core/blocks/evaluation/__init__.py +9 -0
- sdg_hub/core/blocks/evaluation/evaluate_faithfulness_block.py +564 -0
- sdg_hub/core/blocks/evaluation/evaluate_relevancy_block.py +564 -0
- sdg_hub/core/blocks/evaluation/verify_question_block.py +564 -0
- sdg_hub/core/blocks/filtering/__init__.py +12 -0
- sdg_hub/core/blocks/filtering/column_value_filter.py +188 -0
- sdg_hub/core/blocks/llm/__init__.py +25 -0
- sdg_hub/core/blocks/llm/client_manager.py +398 -0
- sdg_hub/core/blocks/llm/config.py +336 -0
- sdg_hub/core/blocks/llm/error_handler.py +368 -0
- sdg_hub/core/blocks/llm/llm_chat_block.py +542 -0
- sdg_hub/core/blocks/llm/prompt_builder_block.py +368 -0
- sdg_hub/core/blocks/llm/text_parser_block.py +310 -0
- sdg_hub/core/blocks/registry.py +331 -0
- sdg_hub/core/blocks/transform/__init__.py +23 -0
- sdg_hub/core/blocks/transform/duplicate_columns.py +88 -0
- sdg_hub/core/blocks/transform/index_based_mapper.py +225 -0
- sdg_hub/core/blocks/transform/melt_columns.py +126 -0
- sdg_hub/core/blocks/transform/rename_columns.py +69 -0
- sdg_hub/core/blocks/transform/text_concat.py +102 -0
- sdg_hub/core/blocks/transform/uniform_col_val_setter.py +101 -0
- sdg_hub/core/flow/__init__.py +20 -0
- sdg_hub/core/flow/base.py +980 -0
- sdg_hub/core/flow/metadata.py +344 -0
- sdg_hub/core/flow/migration.py +187 -0
- sdg_hub/core/flow/registry.py +330 -0
- sdg_hub/core/flow/validation.py +265 -0
- sdg_hub/{utils → core/utils}/__init__.py +6 -4
- sdg_hub/{utils → core/utils}/datautils.py +1 -3
- sdg_hub/core/utils/error_handling.py +208 -0
- sdg_hub/{utils → core/utils}/path_resolution.py +2 -2
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/atomic_facts.yaml +40 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/detailed_summary.yaml +13 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_faithfulness.yaml +64 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_question.yaml +29 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_relevancy.yaml +81 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/extractive_summary.yaml +13 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/flow.yaml +191 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/generate_questions_responses.yaml +54 -0
- sdg_hub-0.2.0.dist-info/METADATA +218 -0
- sdg_hub-0.2.0.dist-info/RECORD +63 -0
- sdg_hub/blocks/__init__.py +0 -42
- sdg_hub/blocks/block.py +0 -96
- sdg_hub/blocks/llmblock.py +0 -375
- sdg_hub/blocks/openaichatblock.py +0 -556
- sdg_hub/blocks/utilblocks.py +0 -597
- sdg_hub/checkpointer.py +0 -139
- sdg_hub/configs/annotations/cot_reflection.yaml +0 -34
- sdg_hub/configs/annotations/detailed_annotations.yaml +0 -28
- sdg_hub/configs/annotations/detailed_description.yaml +0 -10
- sdg_hub/configs/annotations/detailed_description_icl.yaml +0 -32
- sdg_hub/configs/annotations/simple_annotations.yaml +0 -9
- sdg_hub/configs/knowledge/__init__.py +0 -0
- sdg_hub/configs/knowledge/atomic_facts.yaml +0 -46
- sdg_hub/configs/knowledge/auxilary_instructions.yaml +0 -35
- sdg_hub/configs/knowledge/detailed_summary.yaml +0 -18
- sdg_hub/configs/knowledge/evaluate_faithfulness.yaml +0 -68
- sdg_hub/configs/knowledge/evaluate_question.yaml +0 -38
- sdg_hub/configs/knowledge/evaluate_relevancy.yaml +0 -84
- sdg_hub/configs/knowledge/extractive_summary.yaml +0 -18
- sdg_hub/configs/knowledge/generate_code_questions_responses.yaml +0 -39
- sdg_hub/configs/knowledge/generate_questions.yaml +0 -82
- sdg_hub/configs/knowledge/generate_questions_responses.yaml +0 -56
- sdg_hub/configs/knowledge/generate_responses.yaml +0 -86
- sdg_hub/configs/knowledge/mcq_generation.yaml +0 -83
- sdg_hub/configs/knowledge/router.yaml +0 -12
- sdg_hub/configs/knowledge/simple_generate_qa.yaml +0 -34
- sdg_hub/configs/reasoning/__init__.py +0 -0
- sdg_hub/configs/reasoning/dynamic_cot.yaml +0 -40
- sdg_hub/configs/skills/__init__.py +0 -0
- sdg_hub/configs/skills/analyzer.yaml +0 -48
- sdg_hub/configs/skills/annotation.yaml +0 -36
- sdg_hub/configs/skills/contexts.yaml +0 -28
- sdg_hub/configs/skills/critic.yaml +0 -60
- sdg_hub/configs/skills/evaluate_freeform_pair.yaml +0 -111
- sdg_hub/configs/skills/evaluate_freeform_questions.yaml +0 -78
- sdg_hub/configs/skills/evaluate_grounded_pair.yaml +0 -119
- sdg_hub/configs/skills/evaluate_grounded_questions.yaml +0 -51
- sdg_hub/configs/skills/freeform_questions.yaml +0 -34
- sdg_hub/configs/skills/freeform_responses.yaml +0 -39
- sdg_hub/configs/skills/grounded_questions.yaml +0 -38
- sdg_hub/configs/skills/grounded_responses.yaml +0 -59
- sdg_hub/configs/skills/icl_examples/STEM.yaml +0 -56
- sdg_hub/configs/skills/icl_examples/__init__.py +0 -0
- sdg_hub/configs/skills/icl_examples/coding.yaml +0 -97
- sdg_hub/configs/skills/icl_examples/extraction.yaml +0 -36
- sdg_hub/configs/skills/icl_examples/humanities.yaml +0 -71
- sdg_hub/configs/skills/icl_examples/math.yaml +0 -85
- sdg_hub/configs/skills/icl_examples/reasoning.yaml +0 -30
- sdg_hub/configs/skills/icl_examples/roleplay.yaml +0 -45
- sdg_hub/configs/skills/icl_examples/writing.yaml +0 -80
- sdg_hub/configs/skills/judge.yaml +0 -53
- sdg_hub/configs/skills/planner.yaml +0 -67
- sdg_hub/configs/skills/respond.yaml +0 -8
- sdg_hub/configs/skills/revised_responder.yaml +0 -78
- sdg_hub/configs/skills/router.yaml +0 -59
- sdg_hub/configs/skills/simple_generate_qa_freeform.yaml +0 -27
- sdg_hub/configs/skills/simple_generate_qa_grounded.yaml +0 -31
- sdg_hub/flow.py +0 -477
- sdg_hub/flow_runner.py +0 -450
- sdg_hub/flows/generation/knowledge/mmlu_bench.yaml +0 -13
- sdg_hub/flows/generation/knowledge/simple_knowledge.yaml +0 -12
- sdg_hub/flows/generation/knowledge/synth_knowledge.yaml +0 -89
- sdg_hub/flows/generation/knowledge/synth_knowledge1.5.yaml +0 -148
- sdg_hub/flows/generation/skills/improve_responses.yaml +0 -103
- sdg_hub/flows/generation/skills/simple_freeform_skill.yaml +0 -12
- sdg_hub/flows/generation/skills/simple_grounded_skill.yaml +0 -12
- sdg_hub/flows/generation/skills/synth_grounded_skills.yaml +0 -80
- sdg_hub/flows/generation/skills/synth_skills.yaml +0 -59
- sdg_hub/pipeline.py +0 -121
- sdg_hub/prompts.py +0 -74
- sdg_hub/registry.py +0 -122
- sdg_hub/sdg.py +0 -206
- sdg_hub/utils/config_validation.py +0 -91
- sdg_hub/utils/error_handling.py +0 -94
- sdg_hub/utils/validation_result.py +0 -10
- sdg_hub-0.1.3.dist-info/METADATA +0 -190
- sdg_hub-0.1.3.dist-info/RECORD +0 -89
- sdg_hub/{logger_config.py → core/utils/logger_config.py} +1 -1
- /sdg_hub/{configs/__init__.py → flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/README.md} +0 -0
- /sdg_hub/{configs/annotations → flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab}/__init__.py +0 -0
- {sdg_hub-0.1.3.dist-info → sdg_hub-0.2.0.dist-info}/WHEEL +0 -0
- {sdg_hub-0.1.3.dist-info → sdg_hub-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {sdg_hub-0.1.3.dist-info → sdg_hub-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,368 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
"""Prompt builder block for formatting prompts into structured chat messages or plain text.
|
3
|
+
|
4
|
+
This module provides the PromptBuilderBlock for handling LLM prompt formatting,
|
5
|
+
including conversion to OpenAI Messages format and template rendering.
|
6
|
+
"""
|
7
|
+
|
8
|
+
# Standard
|
9
|
+
from typing import Any, Literal, Optional
|
10
|
+
|
11
|
+
# Third Party
|
12
|
+
from datasets import Dataset
|
13
|
+
from jinja2 import Template, meta
|
14
|
+
from pydantic import BaseModel, Field, field_validator
|
15
|
+
import yaml
|
16
|
+
|
17
|
+
# Local
|
18
|
+
from ...utils.error_handling import TemplateValidationError
|
19
|
+
from ...utils.logger_config import setup_logger
|
20
|
+
from ..base import BaseBlock
|
21
|
+
from ..registry import BlockRegistry
|
22
|
+
|
23
|
+
# Module-level logger obtained from the project's setup_logger helper.
logger = setup_logger(__name__)
|
24
|
+
|
25
|
+
|
26
|
+
class ChatMessage(BaseModel):
    """A single rendered chat message, validated by Pydantic.

    Attributes
    ----------
    role : Literal["system", "user", "assistant", "tool"]
        Speaker role of the message.
    content : str
        Message text. Surrounding whitespace is stripped, and blank content
        is rejected.
    """

    role: Literal["system", "user", "assistant", "tool"]
    content: str

    @field_validator("content")
    @classmethod
    def validate_content_not_empty(cls, v: str) -> str:
        """Return ``v`` with surrounding whitespace removed.

        Raises
        ------
        ValueError
            If ``v`` is empty or consists only of whitespace.
        """
        trimmed = v.strip()
        if trimmed:
            return trimmed
        raise ValueError("Message content cannot be empty")
|
39
|
+
|
40
|
+
|
41
|
+
class MessageTemplate(BaseModel):
    """Compiled Jinja2 template for one chat message, plus its raw source.

    The raw template text is kept next to the compiled template so callers
    can re-parse it (for example, to discover the variables it references).
    """

    # Role of the message this template produces.
    role: Literal["system", "user", "assistant", "tool"]
    # Compiled Jinja2 template for the message content.
    content_template: Template
    # The un-compiled template text that content_template was built from.
    original_source: str

    # jinja2.Template is not a Pydantic-native type, so arbitrary types must be allowed.
    model_config = dict(arbitrary_types_allowed=True)
|
49
|
+
|
50
|
+
|
51
|
+
class PromptTemplateConfig:
    """Load, compile, and validate a YAML prompt configuration.

    The YAML file must contain a non-empty list of mappings, each with a
    ``role`` and a ``content`` key. Every ``content`` value is compiled into a
    Jinja2 template. The resulting message sequence must contain at least one
    ``user`` message and must end with a ``user`` message.

    Parameters
    ----------
    config_path : str
        Path to the YAML prompt configuration file.

    Raises
    ------
    FileNotFoundError
        If ``config_path`` does not exist.
    yaml.YAMLError
        If the file is not valid YAML.
    ValueError
        If the configuration structure or message flow is invalid, or a
        template fails to compile.
    """

    def __init__(self, config_path: str):
        """Initialize with the path to the YAML config file and load it eagerly."""
        self.config_path = config_path
        self.message_templates: list[MessageTemplate] = []
        self._load_and_validate()

    def _load_and_validate(self) -> None:
        """Load the YAML config, compile its templates, and validate the message flow."""
        try:
            with open(self.config_path, encoding="utf-8") as config_file:
                # safe_load: never construct arbitrary Python objects from YAML.
                config = yaml.safe_load(config_file)

            if not isinstance(config, list):
                raise ValueError(
                    "Template config must be a list of message objects"
                )

            if not config:
                raise ValueError("Prompt configuration cannot be empty")

            self._compile_templates(config)
            self._validate_message_flow()

        except FileNotFoundError:
            logger.error(f"Configuration file not found: {self.config_path}")
            raise
        except yaml.YAMLError as e:
            logger.error(f"Error parsing YAML from {self.config_path}: {e}")
            raise
        except ValueError as e:
            # Structural/validation problems are expected failure modes of a
            # bad config, not "unexpected" read errors; log them accordingly.
            logger.error(f"Invalid prompt configuration in {self.config_path}: {e}")
            raise
        except Exception as e:
            logger.error(
                f"Unexpected error reading config file {self.config_path}: {e}"
            )
            raise

    def _compile_templates(self, config: list[dict[str, Any]]) -> None:
        """Compile a Jinja2 template for each message in ``config``.

        Raises
        ------
        ValueError
            If a message is not a mapping, lacks 'role'/'content', or its
            template cannot be compiled into a valid MessageTemplate.
        """
        for i, message in enumerate(config):
            # Reject non-mapping entries up front. For a string, ``in`` would be
            # a substring test; for other types ``.keys()``/``in`` would raise
            # AttributeError/TypeError instead of a clear validation error.
            if not isinstance(message, dict):
                raise ValueError(
                    f"Message {i} must be a mapping with 'role' and 'content' fields. "
                    f"Got: {type(message).__name__}"
                )
            if "role" not in message or "content" not in message:
                raise ValueError(
                    f"Message {i} must have 'role' and 'content' fields. Got: {message.keys()}"
                )

            try:
                content_source = message["content"]
                message_template = MessageTemplate(
                    role=message["role"],
                    content_template=Template(content_source),
                    original_source=content_source,
                )
                self.message_templates.append(message_template)
            except Exception as e:
                raise ValueError(
                    f"Failed to compile template for message {i}: {e}"
                ) from e

    def _validate_message_flow(self) -> None:
        """Validate that the message sequence is suitable for chat completion.

        Raises
        ------
        ValueError
            If no message has role 'user', or the last message is not 'user'.
        """
        user_messages = [msg for msg in self.message_templates if msg.role == "user"]
        if not user_messages:
            raise ValueError(
                "Template must contain at least one message with role='user' for proper conversation flow."
            )

        if self.message_templates and self.message_templates[-1].role != "user":
            raise ValueError(
                f"The final message must have role='user' for proper chat completion. "
                f"Got role='{self.message_templates[-1].role}' for the last message."
            )

    def get_message_templates(self) -> "list[MessageTemplate]":
        """Return the compiled message templates (the internal list, not a copy)."""
        return self.message_templates
|
127
|
+
|
128
|
+
|
129
|
+
class PromptRenderer:
    """Render compiled message templates into validated chat messages.

    Parameters
    ----------
    message_templates : list[MessageTemplate]
        Compiled templates, rendered in the given order.
    """

    def __init__(self, message_templates: "list[MessageTemplate]"):
        """Initialize with a list of message templates."""
        self.message_templates = message_templates

    def get_required_variables(self) -> set[str]:
        """Return the undeclared Jinja2 variables referenced by any template."""
        required_vars: set[str] = set()
        for msg_template in self.message_templates:
            # Re-parse the raw source with the template's own environment so the
            # same syntax settings apply as when the template was compiled.
            ast = msg_template.content_template.environment.parse(
                msg_template.original_source
            )
            required_vars.update(meta.find_undeclared_variables(ast))
        return required_vars

    def resolve_template_vars(
        self, sample: dict[str, Any], input_cols
    ) -> dict[str, Any]:
        """Resolve template variables from dataset columns based on input_cols.

        Parameters
        ----------
        sample : Dict[str, Any]
            Input sample from dataset.
        input_cols : Union[str, List[str], Dict[str, str], None]
            Input column specification:
            - str: a single column name, used as the variable name
            - List[str]: column names, used as the variable names
            - Dict[str, str]: dataset column name -> template variable name
            - None: no columns, so an empty mapping is returned

        Returns
        -------
        Dict[str, Any]
            Template variables mapped from dataset columns. Columns missing
            from ``sample`` are skipped with a warning.
        """
        if input_cols is None:
            return {}
        if isinstance(input_cols, str):
            # A bare string names one column; iterating it directly would
            # walk its characters instead.
            input_cols = [input_cols]

        if isinstance(input_cols, dict):
            column_to_var = input_cols.items()
        else:
            # Use column names directly as template variables.
            column_to_var = ((col, col) for col in input_cols)

        template_vars: dict[str, Any] = {}
        for dataset_col, template_var in column_to_var:
            if dataset_col in sample:
                template_vars[template_var] = sample[dataset_col]
            else:
                logger.warning(f"Dataset column '{dataset_col}' not found in sample")

        return template_vars

    def render_messages(self, template_vars: dict[str, Any]) -> "list[ChatMessage]":
        """Render all message templates with the given variables.

        Messages whose rendered content is empty after stripping are dropped.
        A message that fails to render or validate is logged as a warning and
        skipped, so the result may be shorter than the template list.

        Parameters
        ----------
        template_vars : Dict[str, Any]
            Variables to substitute in templates.

        Returns
        -------
        List[ChatMessage]
            List of rendered and validated chat messages.
        """
        rendered_messages = []

        for i, msg_template in enumerate(self.message_templates):
            try:
                rendered_content = msg_template.content_template.render(
                    template_vars
                ).strip()
                if rendered_content:  # Only add non-empty messages
                    rendered_messages.append(
                        ChatMessage(role=msg_template.role, content=rendered_content)
                    )
            except Exception as e:
                # Best-effort: one bad message should not abort the whole sample.
                logger.warning(f"Failed to render message {i}: {e}")

        return rendered_messages
|
216
|
+
|
217
|
+
|
218
|
+
@BlockRegistry.register(
    "PromptBuilderBlock",
    "llm",
    "Formats prompts into structured chat messages or plain text using Jinja templates",
)
class PromptBuilderBlock(BaseBlock):
    """Render dataset rows into chat messages or plain text using Jinja templates.

    The prompt is defined by a YAML file holding a list of message objects,
    each with 'role' and 'content'. For every row, the configured input columns
    are mapped to template variables and each message is rendered. The result
    is stored in the block's single output column.

    Parameters
    ----------
    block_name : str
        Name of the block.
    input_cols : Union[str, List[str], Dict[str, str]]
        Which dataset columns feed the templates:
        - str: one column name
        - List[str]: several column names, used as-is as variable names
        - Dict[str, str]: dataset column name -> template variable name
    output_cols : str
        Column that receives the formatted prompt (exactly one is allowed).
    prompt_config_path : str
        YAML file with a list of message objects carrying 'role' and 'content'.
    format_as_messages : bool, optional
        Output shape (default True). True yields a list of dicts with 'role'
        and 'content' keys. False yields one string of "role: content" entries
        joined by blank lines.
    """

    prompt_config_path: str = Field(
        ..., description="Path to YAML file containing the Jinja template configuration"
    )
    format_as_messages: bool = Field(
        True, description="Whether to format output as chat messages"
    )

    # Built in model_post_init from prompt_config_path; excluded from serialization.
    prompt_template_config: Optional[PromptTemplateConfig] = Field(
        None, description="Loaded prompt template configuration", exclude=True
    )
    prompt_renderer: Optional[PromptRenderer] = Field(
        None, description="Prompt renderer instance", exclude=True
    )

    @field_validator("output_cols", mode="after")
    @classmethod
    def validate_single_output_col(cls, v):
        """Reject an ``output_cols`` value that does not hold exactly one column."""
        count = len(v)
        if count == 1:
            return v
        raise ValueError(
            f"PromptBuilderBlock expects exactly one output column, got {count}: {v}"
        )

    def model_post_init(self, __context: Any) -> None:
        """Load the prompt config and build the renderer once fields are validated."""
        config = PromptTemplateConfig(self.prompt_config_path)
        self.prompt_template_config = config
        self.prompt_renderer = PromptRenderer(config.get_message_templates())

    def _validate_custom(self, dataset: Dataset) -> None:
        """Check that every template variable can be resolved from the dataset.

        Only the first row is inspected, and empty datasets are not checked.

        Raises
        ------
        TemplateValidationError
            If a variable referenced by the templates is not supplied by the
            configured input columns.
        """
        if len(dataset) == 0:
            return

        renderer = self.prompt_renderer
        needed = renderer.get_required_variables()
        resolved = renderer.resolve_template_vars(dataset[0], self.input_cols)
        missing = needed - set(resolved.keys())
        if not missing:
            return

        raise TemplateValidationError(
            block_name=self.block_name,
            missing_variables=list(missing),
            available_variables=list(resolved.keys()),
        )

    def _generate(self, sample: dict[str, Any]) -> dict[str, Any]:
        """Format one sample and store the result in the output column.

        The input columns are resolved to template variables and every message
        template is rendered. The output is then shaped according to
        ``format_as_messages``. Any failure is logged, and an empty output
        ([] or "") is stored instead.

        Parameters
        ----------
        sample : Dict[str, Any]
            Input sample from dataset (modified in place).

        Returns
        -------
        Dict[str, Any]
            The same sample with the formatted output column set.
        """
        output_col = self.output_cols[0]

        try:
            renderer = self.prompt_renderer
            variables = renderer.resolve_template_vars(sample, self.input_cols)
            messages = renderer.render_messages(variables)

            if not messages:
                logger.warning(f"No valid messages generated for sample: {sample}")
                sample[output_col] = [] if self.format_as_messages else ""
            elif not self.format_as_messages:
                # Plain-text mode: "role: content" blocks separated by blank lines.
                sample[output_col] = "\n\n".join(
                    f"{m.role}: {m.content}" for m in messages
                )
            else:
                # Message mode: plain dicts so the dataset can serialize them.
                sample[output_col] = [m.model_dump() for m in messages]
        except Exception as e:
            logger.error(f"Failed to format sample: {e}")
            sample[output_col] = [] if self.format_as_messages else ""

        return sample

    def generate(self, samples: Dataset, **_kwargs: Any) -> Dataset:
        """Format every sample by mapping ``_generate`` over the dataset.

        Parameters
        ----------
        samples : Dataset
            Input dataset containing samples to be formatted.
        **_kwargs : Any
            Accepted for interface compatibility; ignored.

        Returns
        -------
        Dataset
            Dataset with the formatted output added to the output column.
        """
        logger.debug(f"Formatting prompts for {len(samples)} samples")
        result = samples.map(self._generate)
        logger.debug(f"Successfully formatted {len(result)} samples")
        return result
|
@@ -0,0 +1,310 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
"""Text parser block for parsing and post-processing LLM outputs.
|
3
|
+
|
4
|
+
This module provides the TextParserBlock for handling output parsing using
|
5
|
+
start/end tags, custom regex patterns, and cleanup operations.
|
6
|
+
"""
|
7
|
+
|
8
|
+
# Standard
|
9
|
+
from typing import Any, Optional
|
10
|
+
import re
|
11
|
+
|
12
|
+
# Third Party
|
13
|
+
from datasets import Dataset
|
14
|
+
from pydantic import Field, field_validator, model_validator
|
15
|
+
|
16
|
+
# Local
|
17
|
+
from ...utils.logger_config import setup_logger
|
18
|
+
from ..base import BaseBlock
|
19
|
+
from ..registry import BlockRegistry
|
20
|
+
|
21
|
+
# Module-level logger obtained from the project's setup_logger helper.
logger = setup_logger(__name__)
|
22
|
+
|
23
|
+
|
24
|
+
@BlockRegistry.register(
    "TextParserBlock",
    "llm",
    "Parses and post-processes LLM outputs using tags or regex patterns",
)
class TextParserBlock(BaseBlock):
    """Block for parsing and post-processing LLM outputs.

    This block handles output parsing using start/end tags, custom regex
    patterns, and cleanup operations. It expects exactly one input column
    containing raw LLM output: a string, or a list of strings (e.g. several
    completions). Each input row expands into zero or more output rows, one
    per aligned set of parsed matches.

    Attributes
    ----------
    block_name : str
        Unique identifier for this block instance.
    input_cols : Union[str, List[str], Dict[str, Any], None]
        Input column name(s) containing raw LLM output. Must specify exactly one column.
    output_cols : Union[str, List[str], Dict[str, Any], None]
        Output column name(s) for parsed results.
    start_tags : List[str]
        Start tags for tag-based parsing, one per output column.
    end_tags : List[str]
        End tags for tag-based parsing; must be the same length as ``start_tags``.
    parsing_pattern : Optional[str]
        Regex pattern for custom parsing, compiled with ``re.DOTALL``. Takes
        precedence over tags when both are configured.
    parser_cleanup_tags : Optional[List[str]]
        Substrings removed from each value produced by regex parsing.
        (Tag-based parsing does not apply them.)
    """

    start_tags: list[str] = Field(
        default_factory=list, description="List of start tags for tag-based parsing"
    )
    end_tags: list[str] = Field(
        default_factory=list, description="List of end tags for tag-based parsing"
    )
    parsing_pattern: Optional[str] = Field(
        default=None, description="Regex pattern for custom parsing"
    )
    parser_cleanup_tags: Optional[list[str]] = Field(
        default=None, description="List of tags to clean from parsed output"
    )

    @field_validator("start_tags", "end_tags", mode="before")
    @classmethod
    def normalize_tags(cls, v):
        """Normalize tag lists to ensure they are always lists."""
        if v is None:
            return []
        if isinstance(v, str):
            return [v]
        if isinstance(v, list):
            return v
        raise ValueError(f"Tags must be a string, list, or None, got {type(v)}")

    @field_validator("parser_cleanup_tags", mode="before")
    @classmethod
    def normalize_cleanup_tags(cls, v):
        """Normalize cleanup tags to ensure they are always lists when not None."""
        if v is None:
            return None
        if isinstance(v, str):
            return [v]
        if isinstance(v, list):
            return v
        raise ValueError(f"Cleanup tags must be a string, list, or None, got {type(v)}")

    @model_validator(mode="after")
    def validate_parsing_configuration(self):
        """Validate that the parsing configuration is usable and consistent.

        Raises
        ------
        ValueError
            If no parsing method is configured, ``parsing_pattern`` is not a
            valid regex, or ``start_tags``/``end_tags`` differ in length.
        """
        has_regex = self.parsing_pattern is not None
        has_tags = bool(self.start_tags) or bool(self.end_tags)

        if not has_regex and not has_tags:
            raise ValueError(
                "TextParserBlock requires at least one parsing method: "
                "either 'parsing_pattern' (regex) or 'start_tags'/'end_tags' (tag-based parsing)"
            )

        if has_regex:
            # Fail fast on a malformed pattern instead of on the first sample.
            try:
                re.compile(self.parsing_pattern, re.DOTALL)
            except re.error as e:
                raise ValueError(
                    f"Invalid parsing_pattern {self.parsing_pattern!r}: {e}"
                ) from e

        if has_tags and len(self.start_tags) != len(self.end_tags):
            raise ValueError(
                f"start_tags and end_tags must have the same length. "
                f"Got {len(self.start_tags)} start_tags and {len(self.end_tags)} end_tags"
            )

        # The tag-pair count vs. output_cols check lives in _validate_custom,
        # since output_cols may not be normalized yet at this point.
        return self

    def _validate_custom(self, dataset: Dataset) -> None:
        """Validate TextParserBlock specific requirements.

        Parameters
        ----------
        dataset : Dataset
            The dataset to validate.

        Raises
        ------
        ValueError
            If no input column is configured, or the number of tag pairs does
            not match the number of output columns.
        """
        if len(self.input_cols) == 0:
            raise ValueError("TextParserBlock expects at least one input column")
        if len(self.input_cols) > 1:
            logger.warning(
                f"TextParserBlock expects exactly one input column, but got {len(self.input_cols)}. "
                f"Using the first column: {self.input_cols[0]}"
            )

        has_tags = bool(self.start_tags) or bool(self.end_tags)
        if has_tags and len(self.start_tags) != len(self.output_cols):
            raise ValueError(
                f"When using tag-based parsing, the number of tag pairs must match output_cols. "
                f"Got {len(self.start_tags)} tag pairs and {len(self.output_cols)} output columns"
            )

    def _extract_matches(
        self, text: str, start_tag: Optional[str], end_tag: Optional[str]
    ) -> list[str]:
        """Return all stripped, non-overlapping segments delimited by the tags.

        With both tags, the text between each start/end pair is captured (non-greedy).
        With only a start tag, the capture runs to the end of the text. With only
        an end tag, each segment up to an end tag is captured. With neither, the
        whole stripped text is returned. Empty text yields an empty list.
        """
        if not text:
            return []
        if not start_tag and not end_tag:
            return [text.strip()]

        pattern = ""
        if start_tag:
            pattern += re.escape(start_tag)
        pattern += r"(.*?)"
        if end_tag:
            pattern += re.escape(end_tag)
        elif start_tag:
            pattern += "$"

        return [match.strip() for match in re.findall(pattern, text, re.DOTALL)]

    def _parse(self, generated_string: str) -> dict[str, list[str]]:
        """Parse with the regex if one is configured, otherwise with tags."""
        if self.parsing_pattern is not None:
            return self._parse_with_regex(generated_string)
        return self._parse_with_tags(generated_string)

    def _parse_with_regex(self, generated_string: str) -> dict[str, list[str]]:
        """Parse using regex pattern.

        Tuple matches (multiple groups) are distributed across output columns.
        Single-value matches all go to the first output column.
        """
        if self.parsing_pattern is None:
            raise ValueError("parsing_pattern is required for regex parsing")
        # re caches compiled patterns, so recompiling per call is cheap.
        pattern = re.compile(self.parsing_pattern, re.DOTALL)
        all_matches = pattern.findall(generated_string)
        matches: dict[str, list[str]] = {
            column_name: [] for column_name in self.output_cols
        }

        logger.debug(
            f"Regex parsing found {len(all_matches)} matches with pattern: {self.parsing_pattern}"
        )

        if all_matches and isinstance(all_matches[0], tuple):
            return self._process_tuple_matches(all_matches, matches)
        return self._process_single_matches(all_matches, matches)

    def _parse_with_tags(self, generated_string: str) -> dict[str, list[str]]:
        """Parse using start/end tags, one tag pair per output column."""
        matches: dict[str, list[str]] = {
            column_name: [] for column_name in self.output_cols
        }

        for start_tag, end_tag, output_col in zip(
            self.start_tags, self.end_tags, self.output_cols
        ):
            extracted = self._extract_matches(generated_string, start_tag, end_tag)
            matches[output_col] = extracted
            logger.debug(
                f"Tag parsing for '{output_col}' with tags '{start_tag}'/'{end_tag}' found {len(extracted)} matches"
            )

        return matches

    def _process_tuple_matches(
        self, all_matches: list, matches: dict[str, list[str]]
    ) -> dict[str, list[str]]:
        """Distribute regex group tuples positionally across the output columns."""
        for match in all_matches:
            for column_name, value in zip(self.output_cols, match):
                value = self._clean_value(value.strip())
                matches[column_name].append(value)
        return matches

    def _process_single_matches(
        self, all_matches: list, matches: dict[str, list[str]]
    ) -> dict[str, list[str]]:
        """Store single-value regex matches in the first output column."""
        cleaned_matches = [self._clean_value(match.strip()) for match in all_matches]
        matches[self.output_cols[0]] = cleaned_matches
        return matches

    def _clean_value(self, value: str) -> str:
        """Remove every configured cleanup tag from ``value``."""
        if self.parser_cleanup_tags:
            for clean_tag in self.parser_cleanup_tags:
                value = value.replace(clean_tag, "")
        return value

    def _rows_from_parsed(
        self, sample: dict, parsed_outputs: dict[str, list[str]]
    ) -> list[dict]:
        """Build one output row per aligned set of parsed values.

        Columns are paired positionally. When columns have different match
        counts, only as many rows as the shortest column are produced, and a
        warning reports the dropped values.
        """
        lengths = {column: len(values) for column, values in parsed_outputs.items()}
        if len(set(lengths.values())) > 1:
            logger.warning(
                f"Parsed columns have differing match counts {lengths}; "
                f"keeping only the first {min(lengths.values())} row(s)"
            )
        columns = list(parsed_outputs.keys())
        return [
            {**sample, **dict(zip(columns, values))}
            for values in zip(*parsed_outputs.values())
        ]

    def _generate(self, sample: dict) -> list[dict]:
        """Parse one sample's raw output into zero or more output rows.

        A list input (e.g. several completions) is parsed item by item; invalid
        or unparseable items are skipped with a warning. A string input is
        parsed directly. Any other type yields no rows.
        """
        input_column = self.input_cols[0]
        raw_output = sample[input_column]

        # Handle list inputs (e.g., from LLMChatBlock with n > 1)
        if isinstance(raw_output, list):
            if not raw_output:
                logger.warning(f"Input column '{input_column}' contains empty list")
                return []

            all_results = []
            for i, response in enumerate(raw_output):
                if not response or not isinstance(response, str):
                    logger.warning(
                        f"List item {i} in column '{input_column}' contains invalid data "
                        f"(empty or non-string): {type(response)}"
                    )
                    continue

                parsed_outputs = self._parse(response)
                if not parsed_outputs or not any(parsed_outputs.values()):
                    logger.warning(
                        f"Failed to parse content from list item {i}. Raw output length: {len(response)}, "
                        f"parsing method: {'regex' if self.parsing_pattern else 'tags'}"
                    )
                    continue

                all_results.extend(self._rows_from_parsed(sample, parsed_outputs))

            return all_results

        if isinstance(raw_output, str):
            if not raw_output:
                logger.warning(f"Input column '{input_column}' contains empty string")
                return []

            parsed_outputs = self._parse(raw_output)
            if not parsed_outputs or not any(parsed_outputs.values()):
                logger.warning(
                    f"Failed to parse any content from input. Raw output length: {len(raw_output)}, "
                    f"parsing method: {'regex' if self.parsing_pattern else 'tags'}"
                )
                return []

            return self._rows_from_parsed(sample, parsed_outputs)

        logger.warning(
            f"Input column '{input_column}' contains invalid data type: {type(raw_output)}. "
            "Expected str or List[str]"
        )
        return []

    def generate(self, samples: Dataset, **kwargs: Any) -> Dataset:
        """Parse every sample and return the expanded dataset.

        Parameters
        ----------
        samples : Dataset
            Dataset whose input column holds raw LLM output.
        **kwargs : Any
            Accepted for interface compatibility; ignored.

        Returns
        -------
        Dataset
            One row per parsed match. The result is an empty dataset with no
            columns if there is no input or nothing parses.
        """
        logger.debug(f"Parsing outputs for {len(samples)} samples")
        if len(samples) == 0:
            logger.warning("No samples to parse, returning empty dataset")
            return Dataset.from_list([])

        new_data = []
        for sample in samples:
            new_data.extend(self._generate(sample))
        return Dataset.from_list(new_data)
|