sdg-hub 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sdg_hub/__init__.py +28 -1
- sdg_hub/_version.py +2 -2
- sdg_hub/core/__init__.py +22 -0
- sdg_hub/core/blocks/__init__.py +58 -0
- sdg_hub/core/blocks/base.py +313 -0
- sdg_hub/core/blocks/deprecated_blocks/__init__.py +29 -0
- sdg_hub/core/blocks/deprecated_blocks/combine_columns.py +93 -0
- sdg_hub/core/blocks/deprecated_blocks/duplicate_columns.py +88 -0
- sdg_hub/core/blocks/deprecated_blocks/filter_by_value.py +103 -0
- sdg_hub/core/blocks/deprecated_blocks/flatten_columns.py +94 -0
- sdg_hub/core/blocks/deprecated_blocks/llmblock.py +479 -0
- sdg_hub/core/blocks/deprecated_blocks/rename_columns.py +88 -0
- sdg_hub/core/blocks/deprecated_blocks/sample_populator.py +58 -0
- sdg_hub/core/blocks/deprecated_blocks/selector.py +97 -0
- sdg_hub/core/blocks/deprecated_blocks/set_to_majority_value.py +88 -0
- sdg_hub/core/blocks/evaluation/__init__.py +9 -0
- sdg_hub/core/blocks/evaluation/evaluate_faithfulness_block.py +564 -0
- sdg_hub/core/blocks/evaluation/evaluate_relevancy_block.py +564 -0
- sdg_hub/core/blocks/evaluation/verify_question_block.py +564 -0
- sdg_hub/core/blocks/filtering/__init__.py +12 -0
- sdg_hub/core/blocks/filtering/column_value_filter.py +188 -0
- sdg_hub/core/blocks/llm/__init__.py +27 -0
- sdg_hub/core/blocks/llm/client_manager.py +398 -0
- sdg_hub/core/blocks/llm/config.py +336 -0
- sdg_hub/core/blocks/llm/error_handler.py +368 -0
- sdg_hub/core/blocks/llm/llm_chat_block.py +542 -0
- sdg_hub/core/blocks/llm/llm_chat_with_parsing_retry_block.py +491 -0
- sdg_hub/core/blocks/llm/prompt_builder_block.py +368 -0
- sdg_hub/core/blocks/llm/text_parser_block.py +357 -0
- sdg_hub/core/blocks/registry.py +331 -0
- sdg_hub/core/blocks/transform/__init__.py +23 -0
- sdg_hub/core/blocks/transform/duplicate_columns.py +88 -0
- sdg_hub/core/blocks/transform/index_based_mapper.py +225 -0
- sdg_hub/core/blocks/transform/melt_columns.py +126 -0
- sdg_hub/core/blocks/transform/rename_columns.py +69 -0
- sdg_hub/core/blocks/transform/text_concat.py +102 -0
- sdg_hub/core/blocks/transform/uniform_col_val_setter.py +101 -0
- sdg_hub/core/flow/__init__.py +20 -0
- sdg_hub/core/flow/base.py +1209 -0
- sdg_hub/core/flow/checkpointer.py +333 -0
- sdg_hub/core/flow/metadata.py +389 -0
- sdg_hub/core/flow/migration.py +198 -0
- sdg_hub/core/flow/registry.py +393 -0
- sdg_hub/core/flow/validation.py +277 -0
- sdg_hub/{utils → core/utils}/__init__.py +7 -4
- sdg_hub/core/utils/datautils.py +63 -0
- sdg_hub/core/utils/error_handling.py +208 -0
- sdg_hub/core/utils/flow_id_words.yaml +231 -0
- sdg_hub/core/utils/flow_identifier.py +94 -0
- sdg_hub/{utils → core/utils}/path_resolution.py +2 -2
- sdg_hub/core/utils/yaml_utils.py +59 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/atomic_facts.yaml +40 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/detailed_summary.yaml +13 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_faithfulness.yaml +64 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_question.yaml +29 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_relevancy.yaml +81 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/extractive_summary.yaml +13 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/flow.yaml +192 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/generate_questions_responses.yaml +54 -0
- sdg_hub-0.2.1.dist-info/METADATA +221 -0
- sdg_hub-0.2.1.dist-info/RECORD +68 -0
- sdg_hub/blocks/__init__.py +0 -42
- sdg_hub/blocks/block.py +0 -96
- sdg_hub/blocks/llmblock.py +0 -375
- sdg_hub/blocks/openaichatblock.py +0 -556
- sdg_hub/blocks/utilblocks.py +0 -597
- sdg_hub/checkpointer.py +0 -139
- sdg_hub/configs/annotations/cot_reflection.yaml +0 -34
- sdg_hub/configs/annotations/detailed_annotations.yaml +0 -28
- sdg_hub/configs/annotations/detailed_description.yaml +0 -10
- sdg_hub/configs/annotations/detailed_description_icl.yaml +0 -32
- sdg_hub/configs/annotations/simple_annotations.yaml +0 -9
- sdg_hub/configs/knowledge/__init__.py +0 -0
- sdg_hub/configs/knowledge/atomic_facts.yaml +0 -46
- sdg_hub/configs/knowledge/auxilary_instructions.yaml +0 -35
- sdg_hub/configs/knowledge/detailed_summary.yaml +0 -18
- sdg_hub/configs/knowledge/evaluate_faithfulness.yaml +0 -68
- sdg_hub/configs/knowledge/evaluate_question.yaml +0 -38
- sdg_hub/configs/knowledge/evaluate_relevancy.yaml +0 -84
- sdg_hub/configs/knowledge/extractive_summary.yaml +0 -18
- sdg_hub/configs/knowledge/generate_code_questions_responses.yaml +0 -39
- sdg_hub/configs/knowledge/generate_questions.yaml +0 -82
- sdg_hub/configs/knowledge/generate_questions_responses.yaml +0 -56
- sdg_hub/configs/knowledge/generate_responses.yaml +0 -86
- sdg_hub/configs/knowledge/mcq_generation.yaml +0 -83
- sdg_hub/configs/knowledge/router.yaml +0 -12
- sdg_hub/configs/knowledge/simple_generate_qa.yaml +0 -34
- sdg_hub/configs/reasoning/__init__.py +0 -0
- sdg_hub/configs/reasoning/dynamic_cot.yaml +0 -40
- sdg_hub/configs/skills/__init__.py +0 -0
- sdg_hub/configs/skills/analyzer.yaml +0 -48
- sdg_hub/configs/skills/annotation.yaml +0 -36
- sdg_hub/configs/skills/contexts.yaml +0 -28
- sdg_hub/configs/skills/critic.yaml +0 -60
- sdg_hub/configs/skills/evaluate_freeform_pair.yaml +0 -111
- sdg_hub/configs/skills/evaluate_freeform_questions.yaml +0 -78
- sdg_hub/configs/skills/evaluate_grounded_pair.yaml +0 -119
- sdg_hub/configs/skills/evaluate_grounded_questions.yaml +0 -51
- sdg_hub/configs/skills/freeform_questions.yaml +0 -34
- sdg_hub/configs/skills/freeform_responses.yaml +0 -39
- sdg_hub/configs/skills/grounded_questions.yaml +0 -38
- sdg_hub/configs/skills/grounded_responses.yaml +0 -59
- sdg_hub/configs/skills/icl_examples/STEM.yaml +0 -56
- sdg_hub/configs/skills/icl_examples/__init__.py +0 -0
- sdg_hub/configs/skills/icl_examples/coding.yaml +0 -97
- sdg_hub/configs/skills/icl_examples/extraction.yaml +0 -36
- sdg_hub/configs/skills/icl_examples/humanities.yaml +0 -71
- sdg_hub/configs/skills/icl_examples/math.yaml +0 -85
- sdg_hub/configs/skills/icl_examples/reasoning.yaml +0 -30
- sdg_hub/configs/skills/icl_examples/roleplay.yaml +0 -45
- sdg_hub/configs/skills/icl_examples/writing.yaml +0 -80
- sdg_hub/configs/skills/judge.yaml +0 -53
- sdg_hub/configs/skills/planner.yaml +0 -67
- sdg_hub/configs/skills/respond.yaml +0 -8
- sdg_hub/configs/skills/revised_responder.yaml +0 -78
- sdg_hub/configs/skills/router.yaml +0 -59
- sdg_hub/configs/skills/simple_generate_qa_freeform.yaml +0 -27
- sdg_hub/configs/skills/simple_generate_qa_grounded.yaml +0 -31
- sdg_hub/flow.py +0 -477
- sdg_hub/flow_runner.py +0 -450
- sdg_hub/flows/generation/knowledge/mmlu_bench.yaml +0 -13
- sdg_hub/flows/generation/knowledge/simple_knowledge.yaml +0 -12
- sdg_hub/flows/generation/knowledge/synth_knowledge.yaml +0 -89
- sdg_hub/flows/generation/knowledge/synth_knowledge1.5.yaml +0 -136
- sdg_hub/flows/generation/skills/improve_responses.yaml +0 -103
- sdg_hub/flows/generation/skills/simple_freeform_skill.yaml +0 -12
- sdg_hub/flows/generation/skills/simple_grounded_skill.yaml +0 -12
- sdg_hub/flows/generation/skills/synth_grounded_skills.yaml +0 -80
- sdg_hub/flows/generation/skills/synth_skills.yaml +0 -59
- sdg_hub/pipeline.py +0 -121
- sdg_hub/prompts.py +0 -80
- sdg_hub/registry.py +0 -122
- sdg_hub/sdg.py +0 -206
- sdg_hub/utils/config_validation.py +0 -91
- sdg_hub/utils/datautils.py +0 -14
- sdg_hub/utils/error_handling.py +0 -94
- sdg_hub/utils/validation_result.py +0 -10
- sdg_hub-0.1.4.dist-info/METADATA +0 -190
- sdg_hub-0.1.4.dist-info/RECORD +0 -89
- sdg_hub/{logger_config.py → core/utils/logger_config.py} +1 -1
- /sdg_hub/{configs/__init__.py → flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/README.md} +0 -0
- /sdg_hub/{configs/annotations → flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab}/__init__.py +0 -0
- {sdg_hub-0.1.4.dist-info → sdg_hub-0.2.1.dist-info}/WHEEL +0 -0
- {sdg_hub-0.1.4.dist-info → sdg_hub-0.2.1.dist-info}/licenses/LICENSE +0 -0
- {sdg_hub-0.1.4.dist-info → sdg_hub-0.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1209 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
"""Pydantic-based Flow class for managing data generation pipelines."""
|
3
|
+
|
4
|
+
# Standard
|
5
|
+
from pathlib import Path
|
6
|
+
from typing import Any, Optional, Union
|
7
|
+
import time
|
8
|
+
|
9
|
+
# Third Party
|
10
|
+
from datasets import Dataset
|
11
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
12
|
+
from rich.console import Console
|
13
|
+
from rich.panel import Panel
|
14
|
+
from rich.table import Table
|
15
|
+
from rich.tree import Tree
|
16
|
+
import yaml
|
17
|
+
|
18
|
+
# Local
|
19
|
+
from ..blocks.base import BaseBlock
|
20
|
+
from ..blocks.registry import BlockRegistry
|
21
|
+
from ..utils.datautils import safe_concatenate_with_validation
|
22
|
+
from ..utils.error_handling import EmptyDatasetError, FlowValidationError
|
23
|
+
from ..utils.logger_config import setup_logger
|
24
|
+
from ..utils.path_resolution import resolve_path
|
25
|
+
from ..utils.yaml_utils import save_flow_yaml
|
26
|
+
from .checkpointer import FlowCheckpointer
|
27
|
+
from .metadata import FlowMetadata, FlowParameter
|
28
|
+
from .migration import FlowMigration
|
29
|
+
from .validation import FlowValidator
|
30
|
+
|
31
|
+
logger = setup_logger(__name__)
|
32
|
+
|
33
|
+
|
34
|
+
class Flow(BaseModel):
    """Pydantic-based flow for chaining data generation blocks.

    A Flow represents a complete data generation pipeline with proper validation,
    metadata tracking, and execution capabilities. All configuration is validated
    using Pydantic models for type safety and better error messages.

    Attributes
    ----------
    blocks : List[BaseBlock]
        Ordered list of blocks to execute in the flow.
    metadata : FlowMetadata
        Flow metadata including name, version, author, etc.
    parameters : Dict[str, FlowParameter]
        Runtime parameters that can be overridden during execution.
    """

    blocks: list[BaseBlock] = Field(
        default_factory=list,
        description="Ordered list of blocks to execute in the flow",
    )
    metadata: FlowMetadata = Field(
        description="Flow metadata including name, version, author, etc."
    )
    parameters: dict[str, FlowParameter] = Field(
        default_factory=dict,
        description="Runtime parameters that can be overridden during execution",
    )

    # extra="forbid": reject unknown fields in flow construction.
    # arbitrary_types_allowed: BaseBlock instances are not pydantic models.
    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

    # Private attributes (not serialized)
    # Runtime params recovered while migrating an old-format YAML,
    # keyed by block name (merged into generate()'s runtime_params).
    _migrated_runtime_params: dict[str, dict[str, Any]] = {}
    _llm_client: Any = None  # Only used for backward compatibility with old YAMLs
    _model_config_set: bool = False  # Track if model configuration has been set
@field_validator("blocks")
@classmethod
def validate_blocks(cls, v: list[BaseBlock]) -> list[BaseBlock]:
    """Ensure every entry of the ``blocks`` field is a BaseBlock instance."""
    if v:
        for idx, candidate in enumerate(v):
            if isinstance(candidate, BaseBlock):
                continue
            raise ValueError(
                f"Block at index {idx} is not a BaseBlock instance: {type(candidate)}"
            )
    return v
@field_validator("parameters")
@classmethod
def validate_parameters(
    cls, v: dict[str, FlowParameter]
) -> dict[str, FlowParameter]:
    """Check parameter names/values; return a dict keyed by stripped names."""
    if not v:
        return v

    result: dict[str, FlowParameter] = {}
    for name, value in v.items():
        # A name must be a string with at least one non-whitespace character.
        if not (isinstance(name, str) and name.strip()):
            raise ValueError(
                f"Parameter name must be a non-empty string: {name}"
            )

        if not isinstance(value, FlowParameter):
            raise ValueError(
                f"Parameter '{name}' must be a FlowParameter instance, "
                f"got: {type(value)}"
            )

        result[name.strip()] = value

    return result
@model_validator(mode="after")
def validate_block_names_unique(self) -> "Flow":
    """Reject flows in which two blocks share the same block_name."""
    names_seen: set = set()
    for idx, blk in enumerate(self.blocks or []):
        if blk.block_name in names_seen:
            raise ValueError(
                f"Duplicate block name '{blk.block_name}' at index {idx}. "
                f"All block names must be unique within a flow."
            )
        names_seen.add(blk.block_name)
    return self
@classmethod
def from_yaml(cls, yaml_path: str, client: Any = None) -> "Flow":
    """Load flow from YAML configuration file.

    Side effect: may rewrite the YAML file on disk — once when migrating an
    old-format flow, and once when persisting a newly generated metadata id.

    Parameters
    ----------
    yaml_path : str
        Path to the YAML flow configuration file.
    client : Any, optional
        LLM client instance. Required for backward compatibility with old format YAMLs
        that use deprecated LLMBlocks. Ignored for new format YAMLs.

    Returns
    -------
    Flow
        Validated Flow instance.

    Raises
    ------
    FlowValidationError
        If yaml_path is None or the file doesn't exist.
    """
    if yaml_path is None:
        raise FlowValidationError(
            "Flow path cannot be None. Please provide a valid YAML file path or check that the flow exists in the registry."
        )

    yaml_path = resolve_path(yaml_path, [])
    # Relative config paths inside block configs are resolved against this dir.
    yaml_dir = Path(yaml_path).parent

    logger.info(f"Loading flow from: {yaml_path}")

    # Load YAML file
    try:
        with open(yaml_path, encoding="utf-8") as f:
            flow_config = yaml.safe_load(f)
    except FileNotFoundError as exc:
        raise FileNotFoundError(f"Flow file not found: {yaml_path}") from exc
    except yaml.YAMLError as exc:
        raise FlowValidationError(f"Invalid YAML in {yaml_path}: {exc}") from exc

    # Check if this is an old format flow and migrate if necessary
    migrated_runtime_params = None
    is_old_format = FlowMigration.is_old_format(flow_config)
    if is_old_format:
        logger.info(f"Detected old format flow, migrating: {yaml_path}")
        if client is None:
            logger.warning(
                "Old format YAML detected but no client provided. LLMBlocks may fail."
            )
        flow_config, migrated_runtime_params = FlowMigration.migrate_to_new_format(
            flow_config, yaml_path
        )
        # Save migrated config back to YAML to persist id
        save_flow_yaml(yaml_path, flow_config, "migrated to new format")

    # Validate YAML structure
    validator = FlowValidator()
    validation_errors = validator.validate_yaml_structure(flow_config)
    if validation_errors:
        raise FlowValidationError(
            "Invalid flow configuration:\n" + "\n".join(validation_errors)
        )

    # Extract and validate metadata
    metadata_dict = flow_config.get("metadata", {})
    if "name" not in metadata_dict:
        # Fall back to the file stem when the YAML declares no name.
        metadata_dict["name"] = Path(yaml_path).stem

    # Note: Old format compatibility removed - only new RecommendedModels format supported

    try:
        metadata = FlowMetadata(**metadata_dict)
    except Exception as exc:
        raise FlowValidationError(f"Invalid metadata configuration: {exc}") from exc

    # Extract and validate parameters
    parameters = {}
    params_dict = flow_config.get("parameters", {})
    for param_name, param_config in params_dict.items():
        try:
            parameters[param_name] = FlowParameter(**param_config)
        except Exception as exc:
            raise FlowValidationError(
                f"Invalid parameter '{param_name}': {exc}"
            ) from exc

    # Create blocks with validation
    blocks = []
    block_configs = flow_config.get("blocks", [])

    for i, block_config in enumerate(block_configs):
        try:
            # Inject client for deprecated LLMBlocks if this is an old format flow
            if (
                is_old_format
                and block_config.get("block_type") == "LLMBlock"
                and client is not None
            ):
                if "block_config" not in block_config:
                    block_config["block_config"] = {}
                block_config["block_config"]["client"] = client
                logger.debug(
                    f"Injected client for deprecated LLMBlock: {block_config['block_config'].get('block_name')}"
                )

            block = cls._create_block_from_config(block_config, yaml_dir)
            blocks.append(block)
        except Exception as exc:
            raise FlowValidationError(
                f"Failed to create block at index {i}: {exc}"
            ) from exc

    # Create and validate the flow
    try:
        flow = cls(blocks=blocks, metadata=metadata, parameters=parameters)
        # Persist generated id back to the YAML file (only on initial load)
        # If the file had no metadata.id originally, update and rewrite
        if not flow_config.get("metadata", {}).get("id"):
            flow_config.setdefault("metadata", {})["id"] = flow.metadata.id
            save_flow_yaml(
                yaml_path,
                flow_config,
                f"added generated id: {flow.metadata.id}",
            )
        else:
            logger.debug(f"Flow already had id: {flow.metadata.id}")
        # Store migrated runtime params and client for backward compatibility
        if migrated_runtime_params:
            flow._migrated_runtime_params = migrated_runtime_params
        if is_old_format and client is not None:
            flow._llm_client = client

        # Check if this is a flow without LLM blocks
        llm_blocks = flow._detect_llm_blocks()
        if not llm_blocks:
            # No LLM blocks, so no model config needed
            flow._model_config_set = True
        else:
            # LLM blocks present - user must call set_model_config()
            flow._model_config_set = False

        return flow
    except Exception as exc:
        # NOTE(review): this also wraps failures from save_flow_yaml and
        # _detect_llm_blocks, not just pydantic validation errors.
        raise FlowValidationError(f"Flow validation failed: {exc}") from exc
@classmethod
def _create_block_from_config(
    cls,
    block_config: dict[str, Any],
    yaml_dir: Path,
) -> BaseBlock:
    """Create a block instance from configuration with validation.

    Parameters
    ----------
    block_config : Dict[str, Any]
        Block configuration from YAML.
    yaml_dir : Path
        Directory containing the flow YAML file.

    Returns
    -------
    BaseBlock
        Validated block instance.

    Raises
    ------
    FlowValidationError
        If block creation fails.
    """
    # The YAML entry itself must be a mapping with a 'block_type' key.
    if not isinstance(block_config, dict):
        raise FlowValidationError("Block configuration must be a dictionary")

    type_name = block_config.get("block_type")
    if not type_name:
        raise FlowValidationError("Block configuration missing 'block_type'")

    # Look up the block class; on a miss, list everything that IS registered.
    try:
        registered_cls = BlockRegistry.get(type_name)
    except KeyError as exc:
        catalog = BlockRegistry.all()
        listing = ", ".join(
            name for group in catalog.values() for name in group
        )
        raise FlowValidationError(
            f"Block type '{type_name}' not found in registry. "
            f"Available blocks: {listing}"
        ) from exc

    raw = block_config.get("block_config", {})
    if not isinstance(raw, dict):
        raise FlowValidationError("'block_config' must be a dictionary")

    # Copy before mutating so the caller's parsed YAML stays untouched.
    prepared = dict(raw)

    # Config file references are interpreted relative to the flow YAML's dir.
    for key in ("config_path", "config_paths", "prompt_config_path"):
        if key in prepared:
            prepared[key] = cls._resolve_config_paths(prepared[key], yaml_dir)

    # Instantiate with Pydantic validation; surface failures uniformly.
    try:
        return registered_cls(**prepared)
    except Exception as exc:
        raise FlowValidationError(
            f"Failed to create block '{type_name}' with config {prepared}: {exc}"
        ) from exc
@classmethod
def _resolve_config_paths(
    cls, paths: Union[str, list[str], dict[str, str]], yaml_dir: Path
) -> Union[str, list[str], dict[str, str]]:
    """Rebase path value(s) onto the directory containing the flow YAML.

    Accepts a single path string, a list of paths, or a mapping of names
    to paths; any other type is returned unchanged.
    """
    if isinstance(paths, dict):
        return {name: str(yaml_dir / entry) for name, entry in paths.items()}
    if isinstance(paths, list):
        return [str(yaml_dir / entry) for entry in paths]
    if isinstance(paths, str):
        return str(yaml_dir / paths)
    return paths
def generate(
    self,
    dataset: Dataset,
    runtime_params: Optional[dict[str, dict[str, Any]]] = None,
    checkpoint_dir: Optional[str] = None,
    save_freq: Optional[int] = None,
) -> Dataset:
    """Execute the flow blocks in sequence to generate data.

    Note: For flows with LLM blocks, set_model_config() must be called first
    to configure model settings before calling generate().

    Parameters
    ----------
    dataset : Dataset
        Input dataset to process.
    runtime_params : Optional[Dict[str, Dict[str, Any]]], optional
        Runtime parameters organized by block name. Format:
        {
            "block_name": {"param1": value1, "param2": value2},
            "other_block": {"param3": value3}
        }
    checkpoint_dir : Optional[str], optional
        Directory to save/load checkpoints. If provided, enables checkpointing.
    save_freq : Optional[int], optional
        Number of completed samples after which to save a checkpoint.
        If None, only saves final results when checkpointing is enabled.

    Returns
    -------
    Dataset
        Processed dataset after all blocks have been executed.

    Raises
    ------
    EmptyDatasetError
        If input dataset is empty or any block produces an empty dataset.
    FlowValidationError
        If flow validation fails or if model configuration is required but not set.
    """
    # Validate save_freq parameter early to prevent range() errors
    if save_freq is not None and save_freq <= 0:
        raise FlowValidationError(
            f"save_freq must be greater than 0, got {save_freq}"
        )

    # Validate preconditions
    if not self.blocks:
        raise FlowValidationError("Cannot generate with empty flow")

    if len(dataset) == 0:
        raise EmptyDatasetError("Input dataset is empty")

    # Check if model configuration has been set for flows with LLM blocks
    llm_blocks = self._detect_llm_blocks()
    if llm_blocks and not self._model_config_set:
        raise FlowValidationError(
            f"Model configuration required before generate(). "
            f"Found {len(llm_blocks)} LLM blocks: {sorted(llm_blocks)}. "
            f"Call flow.set_model_config() first."
        )

    # Validate dataset requirements
    dataset_errors = self.validate_dataset(dataset)
    if dataset_errors:
        raise FlowValidationError(
            "Dataset validation failed:\n" + "\n".join(dataset_errors)
        )

    # Initialize checkpointer if enabled
    checkpointer = None
    completed_dataset = None
    if checkpoint_dir:
        checkpointer = FlowCheckpointer(
            checkpoint_dir=checkpoint_dir,
            save_freq=save_freq,
            flow_id=self.metadata.id,
        )

        # Load existing progress; splits input into already-done vs remaining.
        remaining_dataset, completed_dataset = checkpointer.load_existing_progress(
            dataset
        )

        if len(remaining_dataset) == 0:
            logger.info("All samples already completed, returning existing results")
            return completed_dataset

        dataset = remaining_dataset
        logger.info(f"Resuming with {len(dataset)} remaining samples")

    logger.info(
        f"Starting flow '{self.metadata.name}' v{self.metadata.version} "
        f"with {len(dataset)} samples across {len(self.blocks)} blocks"
    )

    # Merge migrated runtime params with provided ones (provided ones take precedence)
    merged_runtime_params = self._migrated_runtime_params.copy()
    if runtime_params:
        merged_runtime_params.update(runtime_params)
    runtime_params = merged_runtime_params

    # Process dataset in chunks if checkpointing with save_freq
    if checkpointer and save_freq:
        all_processed = []

        # Process in chunks of save_freq
        for i in range(0, len(dataset), save_freq):
            chunk_end = min(i + save_freq, len(dataset))
            chunk_dataset = dataset.select(range(i, chunk_end))

            logger.info(
                f"Processing chunk {i // save_freq + 1}: samples {i} to {chunk_end - 1}"
            )

            # Execute all blocks on this chunk
            processed_chunk = self._execute_blocks_on_dataset(
                chunk_dataset, runtime_params
            )
            all_processed.append(processed_chunk)

            # Save checkpoint after chunk completion
            checkpointer.add_completed_samples(processed_chunk)

        # Save final checkpoint for any remaining samples
        checkpointer.save_final_checkpoint()

        # Combine all processed chunks
        final_dataset = safe_concatenate_with_validation(
            all_processed, "processed chunks from flow execution"
        )

        # Combine with previously completed samples if any
        # NOTE(review): checkpointer is always truthy inside this branch,
        # so the extra check is redundant but harmless.
        if checkpointer and completed_dataset:
            final_dataset = safe_concatenate_with_validation(
                [completed_dataset, final_dataset],
                "completed checkpoint data with newly processed data",
            )

    else:
        # Process entire dataset at once
        final_dataset = self._execute_blocks_on_dataset(dataset, runtime_params)

        # Save final checkpoint if checkpointing enabled
        if checkpointer:
            checkpointer.add_completed_samples(final_dataset)
            checkpointer.save_final_checkpoint()

            # Combine with previously completed samples if any
            if completed_dataset:
                final_dataset = safe_concatenate_with_validation(
                    [completed_dataset, final_dataset],
                    "completed checkpoint data with newly processed data",
                )

    logger.info(
        f"Flow '{self.metadata.name}' completed successfully: "
        f"{len(final_dataset)} final samples, "
        f"{len(final_dataset.column_names)} final columns"
    )

    return final_dataset
def _execute_blocks_on_dataset(
    self, dataset: Dataset, runtime_params: dict[str, dict[str, Any]]
) -> Dataset:
    """Run every block of this flow, in order, over *dataset*.

    Each block receives the output of the previous one, plus any runtime
    overrides registered under its name. A block that raises, or that
    yields an empty dataset, aborts the run with FlowValidationError.

    Parameters
    ----------
    dataset : Dataset
        Dataset to process through all blocks.
    runtime_params : Dict[str, Dict[str, Any]]
        Runtime parameters for block execution.

    Returns
    -------
    Dataset
        Dataset after processing through all blocks.
    """
    ds = dataset

    for position, blk in enumerate(self.blocks):
        logger.info(
            f"Executing block {position + 1}/{len(self.blocks)}: "
            f"{blk.block_name} ({blk.__class__.__name__})"
        )

        overrides = self._prepare_block_kwargs(blk, runtime_params)

        try:
            # Blocks living under the deprecated_blocks package bypass the
            # validation wrapper: call generate() on them directly.
            from_deprecated_module = (
                hasattr(blk, "__class__")
                and hasattr(blk.__class__, "__module__")
                and "deprecated_blocks" in blk.__class__.__module__
            )

            if from_deprecated_module:
                logger.debug(
                    f"Skipping validations for deprecated block: {blk.block_name}"
                )
                ds = blk.generate(ds, **overrides)
            else:
                ds = blk(ds, **overrides)

            # An empty intermediate dataset means downstream blocks would
            # have nothing to work with — fail loudly here.
            if len(ds) == 0:
                raise EmptyDatasetError(
                    f"Block '{blk.block_name}' produced empty dataset"
                )

            logger.info(
                f"Block '{blk.block_name}' completed successfully: "
                f"{len(ds)} samples, "
                f"{len(ds.column_names)} columns"
            )

        except Exception as exc:
            logger.error(
                f"Block '{blk.block_name}' failed during execution: {exc}"
            )
            raise FlowValidationError(
                f"Block '{blk.block_name}' execution failed: {exc}"
            ) from exc

    return ds
def _prepare_block_kwargs(
    self, block: BaseBlock, runtime_params: dict[str, dict[str, Any]]
) -> dict[str, Any]:
    """Look up the runtime overrides registered under *block*'s name."""
    overrides = runtime_params.get(block.block_name, {})
    return overrides
def set_model_config(
    self,
    model: Optional[str] = None,
    api_base: Optional[str] = None,
    api_key: Optional[str] = None,
    blocks: Optional[list[str]] = None,
    **kwargs: Any,
) -> None:
    """Configure model settings for LLM blocks in this flow (in-place).

    This method is designed to work with model-agnostic flow definitions where
    LLM blocks don't have hardcoded model configurations in the YAML. Instead,
    model settings are configured at runtime using this method.

    Based on LiteLLM's basic usage pattern, this method focuses on the core
    parameters (model, api_base, api_key) with additional parameters passed via kwargs.

    By default, auto-detects all LLM blocks in the flow and applies configuration to them.
    Optionally allows targeting specific blocks only.

    Security note: ``api_key`` values are redacted in all log output so that
    credentials never end up in log files.

    Parameters
    ----------
    model : Optional[str]
        Model name to configure (e.g., "hosted_vllm/openai/gpt-oss-120b").
    api_base : Optional[str]
        API base URL to configure (e.g., "http://localhost:8101/v1").
    api_key : Optional[str]
        API key to configure.
    blocks : Optional[List[str]]
        Specific block names to target. If None, auto-detects all LLM blocks.
    **kwargs : Any
        Additional model parameters (e.g., temperature, max_tokens, top_p, etc.).

    Examples
    --------
    >>> # Recommended workflow: discover -> initialize -> set_model_config -> generate
    >>> flow = Flow.from_yaml("path/to/flow.yaml")  # Initialize flow
    >>> flow.set_model_config(  # Configure model settings
    ...     model="hosted_vllm/openai/gpt-oss-120b",
    ...     api_base="http://localhost:8101/v1",
    ...     api_key="your_key",
    ...     temperature=0.7,
    ...     max_tokens=2048
    ... )
    >>> result = flow.generate(dataset)  # Generate data

    >>> # Configure only specific blocks
    >>> flow.set_model_config(
    ...     model="hosted_vllm/openai/gpt-oss-120b",
    ...     api_base="http://localhost:8101/v1",
    ...     blocks=["gen_detailed_summary", "knowledge_generation"]
    ... )

    Raises
    ------
    ValueError
        If no configuration parameters are provided or if specified blocks don't exist.
    """

    def _loggable(param_name: str, param_value: Any) -> Any:
        # Never write credentials to the logs: mask api_key values.
        if param_name == "api_key" and param_value:
            return "***REDACTED***"
        return param_value

    # Build the configuration parameters dictionary
    config_params = {}
    if model is not None:
        config_params["model"] = model
    if api_base is not None:
        config_params["api_base"] = api_base
    if api_key is not None:
        config_params["api_key"] = api_key

    # Add any additional kwargs (temperature, max_tokens, etc.)
    config_params.update(kwargs)

    # Validate that at least one parameter is provided
    if not config_params:
        raise ValueError(
            "At least one configuration parameter must be provided "
            "(model, api_base, api_key, or **kwargs)"
        )

    # Determine target blocks
    if blocks is not None:
        # Validate that specified blocks exist in the flow
        existing_block_names = {block.block_name for block in self.blocks}
        invalid_blocks = set(blocks) - existing_block_names
        if invalid_blocks:
            raise ValueError(
                f"Specified blocks not found in flow: {sorted(invalid_blocks)}. "
                f"Available blocks: {sorted(existing_block_names)}"
            )
        target_block_names = set(blocks)
        logger.info(
            f"Targeting specific blocks for configuration: {sorted(target_block_names)}"
        )
    else:
        # Auto-detect LLM blocks
        target_block_names = set(self._detect_llm_blocks())
        logger.info(
            f"Auto-detected {len(target_block_names)} LLM blocks for configuration: {sorted(target_block_names)}"
        )

    # Apply configuration to target blocks
    modified_count = 0
    for block in self.blocks:
        if block.block_name in target_block_names:
            for param_name, param_value in config_params.items():
                if hasattr(block, param_name):
                    old_value = getattr(block, param_name)
                    setattr(block, param_name, param_value)
                    # Redact secrets before logging old/new values.
                    logger.debug(
                        f"Block '{block.block_name}': {param_name} "
                        f"'{_loggable(param_name, old_value)}' -> '{_loggable(param_name, param_value)}'"
                    )
                else:
                    logger.warning(
                        f"Block '{block.block_name}' ({block.__class__.__name__}) "
                        f"does not have attribute '{param_name}' - skipping"
                    )

            # Reinitialize client manager for LLM blocks after updating config
            if hasattr(block, "_reinitialize_client_manager"):
                block._reinitialize_client_manager()

            modified_count += 1

    if modified_count > 0:
        # Enhanced logging showing what was configured (credentials redacted)
        param_summary = []
        for param_name, param_value in config_params.items():
            safe_value = _loggable(param_name, param_value)
            if param_name in ("model", "api_base"):
                param_summary.append(f"{param_name}: '{safe_value}'")
            else:
                param_summary.append(f"{param_name}: {safe_value}")

        logger.info(
            f"Successfully configured {modified_count} LLM blocks with: {', '.join(param_summary)}"
        )
        logger.info(f"Configured blocks: {sorted(target_block_names)}")

        # Mark that model configuration has been set
        self._model_config_set = True
    else:
        logger.warning(
            "No blocks were modified - check block names or LLM block detection"
        )
|
736
|
+
|
737
|
+
def _detect_llm_blocks(self) -> list[str]:
    """Detect LLM blocks in the flow by checking for model-related attribute existence.

    A block counts as an LLM block when it exposes any of the attributes
    ``model``, ``api_base``, or ``api_key`` — regardless of their values,
    since they may still be None until set_model_config() runs.

    Returns
    -------
    List[str]
        Names of blocks that carry LLM-related attributes.
    """
    detected: list[str] = []
    llm_attrs = ("model", "api_base", "api_key")

    for candidate in self.blocks:
        # Presence of the attribute matters, not its value.
        present = {attr: hasattr(candidate, attr) for attr in llm_attrs}
        if not any(present.values()):
            continue

        detected.append(candidate.block_name)
        logger.debug(
            f"Detected LLM block '{candidate.block_name}' ({candidate.__class__.__name__}): "
            f"has_model_attr={present['model']}, has_api_base_attr={present['api_base']}, has_api_key_attr={present['api_key']}"
        )

    return detected
|
770
|
+
|
771
|
+
def is_model_config_required(self) -> bool:
    """Check if model configuration is required for this flow.

    Returns
    -------
    bool
        True when the flow contains at least one LLM block and therefore
        needs model configuration before generation.
    """
    return bool(self._detect_llm_blocks())
|
780
|
+
|
781
|
+
def is_model_config_set(self) -> bool:
    """Check whether model configuration has already been applied.

    Returns
    -------
    bool
        True if model configuration has been set or is not required.
    """
    configured = self._model_config_set
    return configured
|
790
|
+
|
791
|
+
def reset_model_config(self) -> None:
    """Reset the model-configuration flag (useful for testing or reconfiguration).

    After calling this, set_model_config() must be called again before generate().
    """
    # Nothing to reset when the flow has no LLM blocks.
    if not self.is_model_config_required():
        return

    self._model_config_set = False
    logger.info(
        "Model configuration flag reset - call set_model_config() before generate()"
    )
|
801
|
+
|
802
|
+
def get_default_model(self) -> Optional[str]:
    """Get the default recommended model for this flow.

    Returns
    -------
    Optional[str]
        Default model name, or None when the flow metadata declares no
        recommended models.

    Examples
    --------
    >>> flow = Flow.from_yaml("path/to/flow.yaml")
    >>> default_model = flow.get_default_model()
    >>> print(f"Default model: {default_model}")
    """
    recommended = self.metadata.recommended_models
    return recommended.default if recommended else None
|
819
|
+
|
820
|
+
def get_model_recommendations(self) -> dict[str, Any]:
    """Get a clean summary of model recommendations for this flow.

    Returns
    -------
    Dict[str, Any]
        Dictionary with keys ``default``, ``compatible`` and ``experimental``.
        When the metadata declares no recommended models, the values are
        None / empty lists.

    Examples
    --------
    >>> flow = Flow.from_yaml("path/to/flow.yaml")
    >>> recommendations = flow.get_model_recommendations()
    >>> print(f" Default: {recommendations['default']}")
    """
    recommended = self.metadata.recommended_models
    if not recommended:
        return {"default": None, "compatible": [], "experimental": []}

    return {
        field: getattr(recommended, field)
        for field in ("default", "compatible", "experimental")
    }
|
849
|
+
|
850
|
+
def validate_dataset(self, dataset: Dataset) -> list[str]:
    """Validate *dataset* against the flow's declared requirements.

    Returns a list of human-readable error strings; an empty list means
    the dataset passed validation.
    """
    problems: list[str] = []

    if len(dataset) == 0:
        problems.append("Dataset is empty")

    requirements = self.metadata.dataset_requirements
    if requirements:
        # Delegate column/size checks to the metadata's requirement object.
        problems.extend(
            requirements.validate_dataset(dataset.column_names, len(dataset))
        )

    return problems
|
865
|
+
|
866
|
+
def dry_run(
    self,
    dataset: Dataset,
    sample_size: int = 2,
    runtime_params: Optional[dict[str, dict[str, Any]]] = None,
) -> dict[str, Any]:
    """Perform a dry run of the flow with a subset of data.

    Executes every block in order on a small sample so problems surface
    before a full (and potentially expensive) generation run. Per-block
    timing and row/column counts are collected along the way.

    Parameters
    ----------
    dataset : Dataset
        Input dataset to test with.
    sample_size : int, default=2
        Number of samples to use for dry run testing. Clamped to the
        dataset size when the dataset is smaller.
    runtime_params : Optional[Dict[str, Dict[str, Any]]], optional
        Runtime parameters organized by block name.

    Returns
    -------
    Dict[str, Any]
        Dry run results with execution info and sample outputs.

    Raises
    ------
    EmptyDatasetError
        If input dataset is empty.
    FlowValidationError
        If the flow has no blocks or any block fails during dry run execution.
    """
    # Validate preconditions
    if not self.blocks:
        raise FlowValidationError("Cannot dry run empty flow")

    if len(dataset) == 0:
        raise EmptyDatasetError("Input dataset is empty")

    # Use smaller sample size if dataset is smaller
    actual_sample_size = min(sample_size, len(dataset))

    logger.info(
        f"Starting dry run for flow '{self.metadata.name}' "
        f"with {actual_sample_size} samples"
    )

    # Create subset dataset
    sample_dataset = dataset.select(range(actual_sample_size))

    # Initialize dry run results; this dict is mutated as blocks run and
    # is only returned on full success (on failure the exception carries
    # the error instead).
    dry_run_results = {
        "flow_name": self.metadata.name,
        "flow_version": self.metadata.version,
        "sample_size": actual_sample_size,
        "original_dataset_size": len(dataset),
        "input_columns": dataset.column_names,
        "blocks_executed": [],
        "final_dataset": None,
        "execution_successful": True,
        "execution_time_seconds": 0,
    }

    start_time = time.time()

    try:
        # Execute the flow with sample data
        current_dataset = sample_dataset
        runtime_params = runtime_params or {}

        for i, block in enumerate(self.blocks):
            block_start_time = time.time()
            input_rows = len(current_dataset)

            logger.info(
                f"Dry run executing block {i + 1}/{len(self.blocks)}: "
                f"{block.block_name} ({block.__class__.__name__})"
            )

            # Prepare block execution parameters
            block_kwargs = self._prepare_block_kwargs(block, runtime_params)

            # Check if this is a deprecated block and skip validations
            is_deprecated_block = (
                hasattr(block, "__class__")
                and hasattr(block.__class__, "__module__")
                and "deprecated_blocks" in block.__class__.__module__
            )

            if is_deprecated_block:
                logger.debug(
                    f"Dry run: Skipping validations for deprecated block: {block.block_name}"
                )
                # Call generate() directly to skip validations, but keep the runtime params
                current_dataset = block.generate(current_dataset, **block_kwargs)
            else:
                # Execute block with validation and logging
                current_dataset = block(current_dataset, **block_kwargs)

            block_execution_time = time.time() - block_start_time

            # Record block execution info
            block_info = {
                "block_name": block.block_name,
                "block_type": block.__class__.__name__,
                "execution_time_seconds": block_execution_time,
                "input_rows": input_rows,
                "output_rows": len(current_dataset),
                "output_columns": current_dataset.column_names,
                "parameters_used": block_kwargs,
            }

            dry_run_results["blocks_executed"].append(block_info)

            logger.info(
                f"Dry run block '{block.block_name}' completed: "
                f"{len(current_dataset)} samples, "
                f"{len(current_dataset.column_names)} columns, "
                f"{block_execution_time:.2f}s"
            )

        # Store final results (full sample data is included, so keep
        # sample_size small for flows that produce large rows)
        dry_run_results["final_dataset"] = {
            "rows": len(current_dataset),
            "columns": current_dataset.column_names,
            "sample_data": current_dataset.to_dict()
            if len(current_dataset) > 0
            else {},
        }

        execution_time = time.time() - start_time
        dry_run_results["execution_time_seconds"] = execution_time

        logger.info(
            f"Dry run completed successfully for flow '{self.metadata.name}' "
            f"in {execution_time:.2f}s"
        )

        return dry_run_results

    except Exception as exc:
        # Record failure details for logging, then re-raise wrapped so
        # callers only need to handle FlowValidationError.
        execution_time = time.time() - start_time
        dry_run_results["execution_successful"] = False
        dry_run_results["execution_time_seconds"] = execution_time
        dry_run_results["error"] = str(exc)

        logger.error(f"Dry run failed for flow '{self.metadata.name}': {exc}")

        raise FlowValidationError(f"Dry run failed: {exc}") from exc
|
1012
|
+
|
1013
|
+
def add_block(self, block: BaseBlock) -> "Flow":
    """Add a block to the flow, returning a new Flow instance.

    Parameters
    ----------
    block : BaseBlock
        Block to add to the flow.

    Returns
    -------
    Flow
        New Flow instance with the added block.

    Raises
    ------
    ValueError
        If the block is invalid or creates naming conflicts.
    """
    if not isinstance(block, BaseBlock):
        raise ValueError(f"Block must be a BaseBlock instance, got: {type(block)}")

    # Reject duplicates: block names must be unique within a flow.
    if any(existing.block_name == block.block_name for existing in self.blocks):
        raise ValueError(
            f"Block name '{block.block_name}' already exists in flow. "
            f"Block names must be unique."
        )

    # Flows are treated as immutable: build a fresh instance with the
    # new block appended rather than mutating self.
    return Flow(
        blocks=[*self.blocks, block],
        metadata=self.metadata,
        parameters=self.parameters,
    )
|
1048
|
+
|
1049
|
+
def get_info(self) -> dict[str, Any]:
    """Return a serializable summary of the flow (metadata, parameters, blocks)."""
    block_summaries = []
    for blk in self.blocks:
        block_summaries.append(
            {
                "block_type": blk.__class__.__name__,
                "block_name": blk.block_name,
                "input_cols": getattr(blk, "input_cols", None),
                "output_cols": getattr(blk, "output_cols", None),
            }
        )

    return {
        "metadata": self.metadata.model_dump(),
        "parameters": {
            name: param.model_dump() for name, param in self.parameters.items()
        },
        "blocks": block_summaries,
        "total_blocks": len(self.blocks),
        "block_names": [blk.block_name for blk in self.blocks],
    }
|
1068
|
+
|
1069
|
+
def print_info(self) -> None:
    """
    Print an interactive summary of the Flow in the console.

    The summary contains:
    1. Flow metadata (name, version, author, description)
    2. Defined runtime parameters with type hints and defaults
    3. A table of all blocks with their input and output columns

    Notes
    -----
    Uses the `rich` library for colourised output; install with
    `pip install rich` if not already present.

    Returns
    -------
    None
    """

    console = Console()

    # Create main tree structure
    flow_tree = Tree(
        f"[bold bright_blue]{self.metadata.name}[/bold bright_blue] Flow"
    )

    # Metadata section
    metadata_branch = flow_tree.add(
        "[bold bright_green]Metadata[/bold bright_green]"
    )
    metadata_branch.add(
        f"Version: [bright_cyan]{self.metadata.version}[/bright_cyan]"
    )
    metadata_branch.add(
        f"Author: [bright_cyan]{self.metadata.author}[/bright_cyan]"
    )
    # Description is optional metadata; only shown when present.
    if self.metadata.description:
        metadata_branch.add(
            f"Description: [white]{self.metadata.description}[/white]"
        )

    # Parameters section (omitted entirely when the flow has none)
    if self.parameters:
        params_branch = flow_tree.add(
            "[bold bright_yellow]Parameters[/bold bright_yellow]"
        )
        for name, param in self.parameters.items():
            param_info = f"[bright_cyan]{name}[/bright_cyan]: [white]{param.type_hint}[/white]"
            # Show the default only when one is defined.
            if param.default is not None:
                param_info += f" = [bright_white]{param.default}[/bright_white]"
            params_branch.add(param_info)

    # Blocks overview
    flow_tree.add(
        f"[bold bright_magenta]Blocks[/bold bright_magenta] ({len(self.blocks)} total)"
    )

    # Create blocks table
    blocks_table = Table(show_header=True, header_style="bold bright_white")
    blocks_table.add_column("Block Name", style="bright_cyan")
    blocks_table.add_column("Type", style="bright_green")
    blocks_table.add_column("Input Cols", style="bright_yellow")
    blocks_table.add_column("Output Cols", style="bright_red")

    for block in self.blocks:
        # Not every block declares input/output columns; fall back to None.
        input_cols = getattr(block, "input_cols", None)
        output_cols = getattr(block, "output_cols", None)

        blocks_table.add_row(
            block.block_name,
            block.__class__.__name__,
            str(input_cols) if input_cols else "[bright_black]None[/bright_black]",
            str(output_cols)
            if output_cols
            else "[bright_black]None[/bright_black]",
        )

    # Print everything: tree panel first, then the per-block table.
    console.print()
    console.print(
        Panel(
            flow_tree,
            title="[bold bright_white]Flow Information[/bold bright_white]",
            border_style="bright_blue",
        )
    )
    console.print()
    console.print(
        Panel(
            blocks_table,
            title="[bold bright_white]Block Details[/bold bright_white]",
            border_style="bright_magenta",
        )
    )
    console.print()
|
1164
|
+
|
1165
|
+
def to_yaml(self, output_path: str) -> None:
    """Save flow configuration to YAML file.

    Note: This creates a basic YAML structure. For exact reproduction
    of original YAML, save the original file separately.
    """
    block_entries = [
        {
            "block_type": blk.__class__.__name__,
            "block_config": blk.model_dump(),
        }
        for blk in self.blocks
    ]

    config = {
        "metadata": self.metadata.model_dump(),
        "blocks": block_entries,
    }

    # Parameters are optional; only serialized when present.
    if self.parameters:
        config["parameters"] = {
            key: value.model_dump() for key, value in self.parameters.items()
        }

    save_flow_yaml(output_path, config)
|
1188
|
+
|
1189
|
+
def __len__(self) -> int:
    """Return the number of blocks in this flow."""
    block_count = len(self.blocks)
    return block_count
|
1192
|
+
|
1193
|
+
def __repr__(self) -> str:
    """Developer-oriented representation: name, version and block count."""
    meta = self.metadata
    return f"Flow(name='{meta.name}', version='{meta.version}', blocks={len(self.blocks)})"
|
1200
|
+
|
1201
|
+
def __str__(self) -> str:
    """Human-readable multi-line summary of the flow."""
    names = [blk.block_name for blk in self.blocks]
    chain = " -> ".join(names) if names else "None"
    lines = [
        f"Flow '{self.metadata.name}' v{self.metadata.version}",
        f"Blocks: {chain}",
        f"Author: {self.metadata.author or 'Unknown'}",
        f"Description: {self.metadata.description or 'No description'}",
    ]
    return "\n".join(lines)
|