sdg-hub 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (139)
  1. sdg_hub/__init__.py +28 -1
  2. sdg_hub/_version.py +2 -2
  3. sdg_hub/core/__init__.py +22 -0
  4. sdg_hub/core/blocks/__init__.py +58 -0
  5. sdg_hub/core/blocks/base.py +313 -0
  6. sdg_hub/core/blocks/deprecated_blocks/__init__.py +29 -0
  7. sdg_hub/core/blocks/deprecated_blocks/combine_columns.py +93 -0
  8. sdg_hub/core/blocks/deprecated_blocks/duplicate_columns.py +88 -0
  9. sdg_hub/core/blocks/deprecated_blocks/filter_by_value.py +103 -0
  10. sdg_hub/core/blocks/deprecated_blocks/flatten_columns.py +94 -0
  11. sdg_hub/core/blocks/deprecated_blocks/llmblock.py +479 -0
  12. sdg_hub/core/blocks/deprecated_blocks/rename_columns.py +88 -0
  13. sdg_hub/core/blocks/deprecated_blocks/sample_populator.py +58 -0
  14. sdg_hub/core/blocks/deprecated_blocks/selector.py +97 -0
  15. sdg_hub/core/blocks/deprecated_blocks/set_to_majority_value.py +88 -0
  16. sdg_hub/core/blocks/evaluation/__init__.py +9 -0
  17. sdg_hub/core/blocks/evaluation/evaluate_faithfulness_block.py +564 -0
  18. sdg_hub/core/blocks/evaluation/evaluate_relevancy_block.py +564 -0
  19. sdg_hub/core/blocks/evaluation/verify_question_block.py +564 -0
  20. sdg_hub/core/blocks/filtering/__init__.py +12 -0
  21. sdg_hub/core/blocks/filtering/column_value_filter.py +188 -0
  22. sdg_hub/core/blocks/llm/__init__.py +25 -0
  23. sdg_hub/core/blocks/llm/client_manager.py +398 -0
  24. sdg_hub/core/blocks/llm/config.py +336 -0
  25. sdg_hub/core/blocks/llm/error_handler.py +368 -0
  26. sdg_hub/core/blocks/llm/llm_chat_block.py +542 -0
  27. sdg_hub/core/blocks/llm/prompt_builder_block.py +368 -0
  28. sdg_hub/core/blocks/llm/text_parser_block.py +310 -0
  29. sdg_hub/core/blocks/registry.py +331 -0
  30. sdg_hub/core/blocks/transform/__init__.py +23 -0
  31. sdg_hub/core/blocks/transform/duplicate_columns.py +88 -0
  32. sdg_hub/core/blocks/transform/index_based_mapper.py +225 -0
  33. sdg_hub/core/blocks/transform/melt_columns.py +126 -0
  34. sdg_hub/core/blocks/transform/rename_columns.py +69 -0
  35. sdg_hub/core/blocks/transform/text_concat.py +102 -0
  36. sdg_hub/core/blocks/transform/uniform_col_val_setter.py +101 -0
  37. sdg_hub/core/flow/__init__.py +20 -0
  38. sdg_hub/core/flow/base.py +980 -0
  39. sdg_hub/core/flow/metadata.py +344 -0
  40. sdg_hub/core/flow/migration.py +187 -0
  41. sdg_hub/core/flow/registry.py +330 -0
  42. sdg_hub/core/flow/validation.py +265 -0
  43. sdg_hub/{utils → core/utils}/__init__.py +6 -4
  44. sdg_hub/{utils → core/utils}/datautils.py +1 -3
  45. sdg_hub/core/utils/error_handling.py +208 -0
  46. sdg_hub/{utils → core/utils}/path_resolution.py +2 -2
  47. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/atomic_facts.yaml +40 -0
  48. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/detailed_summary.yaml +13 -0
  49. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_faithfulness.yaml +64 -0
  50. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_question.yaml +29 -0
  51. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_relevancy.yaml +81 -0
  52. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/extractive_summary.yaml +13 -0
  53. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/flow.yaml +191 -0
  54. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/generate_questions_responses.yaml +54 -0
  55. sdg_hub-0.2.0.dist-info/METADATA +218 -0
  56. sdg_hub-0.2.0.dist-info/RECORD +63 -0
  57. sdg_hub/blocks/__init__.py +0 -42
  58. sdg_hub/blocks/block.py +0 -96
  59. sdg_hub/blocks/llmblock.py +0 -375
  60. sdg_hub/blocks/openaichatblock.py +0 -556
  61. sdg_hub/blocks/utilblocks.py +0 -597
  62. sdg_hub/checkpointer.py +0 -139
  63. sdg_hub/configs/annotations/cot_reflection.yaml +0 -34
  64. sdg_hub/configs/annotations/detailed_annotations.yaml +0 -28
  65. sdg_hub/configs/annotations/detailed_description.yaml +0 -10
  66. sdg_hub/configs/annotations/detailed_description_icl.yaml +0 -32
  67. sdg_hub/configs/annotations/simple_annotations.yaml +0 -9
  68. sdg_hub/configs/knowledge/__init__.py +0 -0
  69. sdg_hub/configs/knowledge/atomic_facts.yaml +0 -46
  70. sdg_hub/configs/knowledge/auxilary_instructions.yaml +0 -35
  71. sdg_hub/configs/knowledge/detailed_summary.yaml +0 -18
  72. sdg_hub/configs/knowledge/evaluate_faithfulness.yaml +0 -68
  73. sdg_hub/configs/knowledge/evaluate_question.yaml +0 -38
  74. sdg_hub/configs/knowledge/evaluate_relevancy.yaml +0 -84
  75. sdg_hub/configs/knowledge/extractive_summary.yaml +0 -18
  76. sdg_hub/configs/knowledge/generate_code_questions_responses.yaml +0 -39
  77. sdg_hub/configs/knowledge/generate_questions.yaml +0 -82
  78. sdg_hub/configs/knowledge/generate_questions_responses.yaml +0 -56
  79. sdg_hub/configs/knowledge/generate_responses.yaml +0 -86
  80. sdg_hub/configs/knowledge/mcq_generation.yaml +0 -83
  81. sdg_hub/configs/knowledge/router.yaml +0 -12
  82. sdg_hub/configs/knowledge/simple_generate_qa.yaml +0 -34
  83. sdg_hub/configs/reasoning/__init__.py +0 -0
  84. sdg_hub/configs/reasoning/dynamic_cot.yaml +0 -40
  85. sdg_hub/configs/skills/__init__.py +0 -0
  86. sdg_hub/configs/skills/analyzer.yaml +0 -48
  87. sdg_hub/configs/skills/annotation.yaml +0 -36
  88. sdg_hub/configs/skills/contexts.yaml +0 -28
  89. sdg_hub/configs/skills/critic.yaml +0 -60
  90. sdg_hub/configs/skills/evaluate_freeform_pair.yaml +0 -111
  91. sdg_hub/configs/skills/evaluate_freeform_questions.yaml +0 -78
  92. sdg_hub/configs/skills/evaluate_grounded_pair.yaml +0 -119
  93. sdg_hub/configs/skills/evaluate_grounded_questions.yaml +0 -51
  94. sdg_hub/configs/skills/freeform_questions.yaml +0 -34
  95. sdg_hub/configs/skills/freeform_responses.yaml +0 -39
  96. sdg_hub/configs/skills/grounded_questions.yaml +0 -38
  97. sdg_hub/configs/skills/grounded_responses.yaml +0 -59
  98. sdg_hub/configs/skills/icl_examples/STEM.yaml +0 -56
  99. sdg_hub/configs/skills/icl_examples/__init__.py +0 -0
  100. sdg_hub/configs/skills/icl_examples/coding.yaml +0 -97
  101. sdg_hub/configs/skills/icl_examples/extraction.yaml +0 -36
  102. sdg_hub/configs/skills/icl_examples/humanities.yaml +0 -71
  103. sdg_hub/configs/skills/icl_examples/math.yaml +0 -85
  104. sdg_hub/configs/skills/icl_examples/reasoning.yaml +0 -30
  105. sdg_hub/configs/skills/icl_examples/roleplay.yaml +0 -45
  106. sdg_hub/configs/skills/icl_examples/writing.yaml +0 -80
  107. sdg_hub/configs/skills/judge.yaml +0 -53
  108. sdg_hub/configs/skills/planner.yaml +0 -67
  109. sdg_hub/configs/skills/respond.yaml +0 -8
  110. sdg_hub/configs/skills/revised_responder.yaml +0 -78
  111. sdg_hub/configs/skills/router.yaml +0 -59
  112. sdg_hub/configs/skills/simple_generate_qa_freeform.yaml +0 -27
  113. sdg_hub/configs/skills/simple_generate_qa_grounded.yaml +0 -31
  114. sdg_hub/flow.py +0 -477
  115. sdg_hub/flow_runner.py +0 -450
  116. sdg_hub/flows/generation/knowledge/mmlu_bench.yaml +0 -13
  117. sdg_hub/flows/generation/knowledge/simple_knowledge.yaml +0 -12
  118. sdg_hub/flows/generation/knowledge/synth_knowledge.yaml +0 -89
  119. sdg_hub/flows/generation/knowledge/synth_knowledge1.5.yaml +0 -136
  120. sdg_hub/flows/generation/skills/improve_responses.yaml +0 -103
  121. sdg_hub/flows/generation/skills/simple_freeform_skill.yaml +0 -12
  122. sdg_hub/flows/generation/skills/simple_grounded_skill.yaml +0 -12
  123. sdg_hub/flows/generation/skills/synth_grounded_skills.yaml +0 -80
  124. sdg_hub/flows/generation/skills/synth_skills.yaml +0 -59
  125. sdg_hub/pipeline.py +0 -121
  126. sdg_hub/prompts.py +0 -80
  127. sdg_hub/registry.py +0 -122
  128. sdg_hub/sdg.py +0 -206
  129. sdg_hub/utils/config_validation.py +0 -91
  130. sdg_hub/utils/error_handling.py +0 -94
  131. sdg_hub/utils/validation_result.py +0 -10
  132. sdg_hub-0.1.4.dist-info/METADATA +0 -190
  133. sdg_hub-0.1.4.dist-info/RECORD +0 -89
  134. sdg_hub/{logger_config.py → core/utils/logger_config.py} +1 -1
  135. /sdg_hub/{configs/__init__.py → flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/README.md} +0 -0
  136. /sdg_hub/{configs/annotations → flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab}/__init__.py +0 -0
  137. {sdg_hub-0.1.4.dist-info → sdg_hub-0.2.0.dist-info}/WHEEL +0 -0
  138. {sdg_hub-0.1.4.dist-info → sdg_hub-0.2.0.dist-info}/licenses/LICENSE +0 -0
  139. {sdg_hub-0.1.4.dist-info → sdg_hub-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,980 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Pydantic-based Flow class for managing data generation pipelines."""
3
+
4
+ # Standard
5
+ from pathlib import Path
6
+ from typing import Any, Optional, Union
7
+
8
+ # Third Party
9
+ from datasets import Dataset
10
+ from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
11
+ import yaml
12
+
13
+ # Local
14
+ from ..blocks.base import BaseBlock
15
+ from ..blocks.registry import BlockRegistry
16
+ from ..utils.error_handling import EmptyDatasetError, FlowValidationError
17
+ from ..utils.logger_config import setup_logger
18
+ from ..utils.path_resolution import resolve_path
19
+ from .metadata import FlowMetadata, FlowParameter
20
+ from .migration import FlowMigration
21
+ from .validation import FlowValidator
22
+
23
+ logger = setup_logger(__name__)
24
+
25
+
26
class Flow(BaseModel):
    """Pydantic-based flow for chaining data generation blocks.

    A Flow represents a complete data generation pipeline with proper validation,
    metadata tracking, and execution capabilities. All configuration is validated
    using Pydantic models for type safety and better error messages.

    Attributes
    ----------
    blocks : List[BaseBlock]
        Ordered list of blocks to execute in the flow.
    metadata : FlowMetadata
        Flow metadata including name, version, author, etc.
    parameters : Dict[str, FlowParameter]
        Runtime parameters that can be overridden during execution.
    """

    blocks: list[BaseBlock] = Field(
        default_factory=list,
        description="Ordered list of blocks to execute in the flow",
    )
    metadata: FlowMetadata = Field(
        description="Flow metadata including name, version, author, etc."
    )
    parameters: dict[str, FlowParameter] = Field(
        default_factory=dict,
        description="Runtime parameters that can be overridden during execution",
    )

    # Forbid unknown fields; allow non-Pydantic field types such as BaseBlock.
    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

    # Private attributes (not serialized)
    # Runtime params recovered while migrating an old-format YAML; merged into
    # user-provided runtime params in generate().
    _migrated_runtime_params: dict[str, dict[str, Any]] = {}
    _llm_client: Any = None  # Only used for backward compatibility with old YAMLs
    _model_config_set: bool = False  # Track if model configuration has been set
61
+
62
+ @field_validator("blocks")
63
+ @classmethod
64
+ def validate_blocks(cls, v: list[BaseBlock]) -> list[BaseBlock]:
65
+ """Validate that all blocks are BaseBlock instances."""
66
+ if not v:
67
+ return v
68
+
69
+ for i, block in enumerate(v):
70
+ if not isinstance(block, BaseBlock):
71
+ raise ValueError(
72
+ f"Block at index {i} is not a BaseBlock instance: {type(block)}"
73
+ )
74
+
75
+ return v
76
+
77
+ @field_validator("parameters")
78
+ @classmethod
79
+ def validate_parameters(
80
+ cls, v: dict[str, FlowParameter]
81
+ ) -> dict[str, FlowParameter]:
82
+ """Validate parameter names and ensure they are FlowParameter instances."""
83
+ if not v:
84
+ return v
85
+
86
+ validated = {}
87
+ for param_name, param_value in v.items():
88
+ if not isinstance(param_name, str) or not param_name.strip():
89
+ raise ValueError(
90
+ f"Parameter name must be a non-empty string: {param_name}"
91
+ )
92
+
93
+ if not isinstance(param_value, FlowParameter):
94
+ raise ValueError(
95
+ f"Parameter '{param_name}' must be a FlowParameter instance, "
96
+ f"got: {type(param_value)}"
97
+ )
98
+
99
+ validated[param_name.strip()] = param_value
100
+
101
+ return validated
102
+
103
+ @model_validator(mode="after")
104
+ def validate_block_names_unique(self) -> "Flow":
105
+ """Ensure all block names are unique within the flow."""
106
+ if not self.blocks:
107
+ return self
108
+
109
+ seen_names = set()
110
+ for i, block in enumerate(self.blocks):
111
+ if block.block_name in seen_names:
112
+ raise ValueError(
113
+ f"Duplicate block name '{block.block_name}' at index {i}. "
114
+ f"All block names must be unique within a flow."
115
+ )
116
+ seen_names.add(block.block_name)
117
+
118
+ return self
119
+
120
    @classmethod
    def from_yaml(cls, yaml_path: str, client: Any = None) -> "Flow":
        """Load flow from YAML configuration file.

        Parameters
        ----------
        yaml_path : str
            Path to the YAML flow configuration file.
        client : Any, optional
            LLM client instance. Required for backward compatibility with old format YAMLs
            that use deprecated LLMBlocks. Ignored for new format YAMLs.

        Returns
        -------
        Flow
            Validated Flow instance.

        Raises
        ------
        FileNotFoundError
            If the YAML file does not exist.
        FlowValidationError
            If the YAML is malformed or any part of the configuration fails validation.
        """
        yaml_path = resolve_path(yaml_path, [])
        # Relative config paths inside the YAML are resolved against its directory.
        yaml_dir = Path(yaml_path).parent

        logger.info(f"Loading flow from: {yaml_path}")

        # Load YAML file
        try:
            with open(yaml_path, encoding="utf-8") as f:
                flow_config = yaml.safe_load(f)
        except FileNotFoundError as exc:
            raise FileNotFoundError(f"Flow file not found: {yaml_path}") from exc
        except yaml.YAMLError as exc:
            raise FlowValidationError(f"Invalid YAML in {yaml_path}: {exc}") from exc

        # Check if this is an old format flow and migrate if necessary
        migrated_runtime_params = None
        is_old_format = FlowMigration.is_old_format(flow_config)
        if is_old_format:
            logger.info(f"Detected old format flow, migrating: {yaml_path}")
            if client is None:
                logger.warning(
                    "Old format YAML detected but no client provided. LLMBlocks may fail."
                )
            flow_config, migrated_runtime_params = FlowMigration.migrate_to_new_format(
                flow_config, yaml_path
            )

        # Validate YAML structure
        validator = FlowValidator()
        validation_errors = validator.validate_yaml_structure(flow_config)
        if validation_errors:
            raise FlowValidationError(
                "Invalid flow configuration:\n" + "\n".join(validation_errors)
            )

        # Extract and validate metadata
        metadata_dict = flow_config.get("metadata", {})
        if "name" not in metadata_dict:
            # Fall back to the YAML file stem when no explicit name is given.
            metadata_dict["name"] = Path(yaml_path).stem

        # Note: Old format compatibility removed - only new RecommendedModels format supported

        try:
            metadata = FlowMetadata(**metadata_dict)
        except Exception as exc:
            raise FlowValidationError(f"Invalid metadata configuration: {exc}") from exc

        # Extract and validate parameters
        parameters = {}
        params_dict = flow_config.get("parameters", {})
        for param_name, param_config in params_dict.items():
            try:
                parameters[param_name] = FlowParameter(**param_config)
            except Exception as exc:
                raise FlowValidationError(
                    f"Invalid parameter '{param_name}': {exc}"
                ) from exc

        # Create blocks with validation
        blocks = []
        block_configs = flow_config.get("blocks", [])

        for i, block_config in enumerate(block_configs):
            try:
                # Inject client for deprecated LLMBlocks if this is an old format flow
                if (
                    is_old_format
                    and block_config.get("block_type") == "LLMBlock"
                    and client is not None
                ):
                    if "block_config" not in block_config:
                        block_config["block_config"] = {}
                    block_config["block_config"]["client"] = client
                    logger.debug(
                        f"Injected client for deprecated LLMBlock: {block_config['block_config'].get('block_name')}"
                    )

                block = cls._create_block_from_config(block_config, yaml_dir)
                blocks.append(block)
            except Exception as exc:
                raise FlowValidationError(
                    f"Failed to create block at index {i}: {exc}"
                ) from exc

        # Create and validate the flow
        try:
            flow = cls(blocks=blocks, metadata=metadata, parameters=parameters)
            # Store migrated runtime params and client for backward compatibility
            if migrated_runtime_params:
                flow._migrated_runtime_params = migrated_runtime_params
            if is_old_format and client is not None:
                flow._llm_client = client

            # Check if this is a flow without LLM blocks
            llm_blocks = flow._detect_llm_blocks()
            if not llm_blocks:
                # No LLM blocks, so no model config needed
                flow._model_config_set = True
            else:
                # LLM blocks present - user must call set_model_config()
                flow._model_config_set = False

            return flow
        except Exception as exc:
            raise FlowValidationError(f"Flow validation failed: {exc}") from exc
242
+
243
+ @classmethod
244
+ def _create_block_from_config(
245
+ cls,
246
+ block_config: dict[str, Any],
247
+ yaml_dir: Path,
248
+ ) -> BaseBlock:
249
+ """Create a block instance from configuration with validation.
250
+
251
+ Parameters
252
+ ----------
253
+ block_config : Dict[str, Any]
254
+ Block configuration from YAML.
255
+ yaml_dir : Path
256
+ Directory containing the flow YAML file.
257
+
258
+ Returns
259
+ -------
260
+ BaseBlock
261
+ Validated block instance.
262
+
263
+ Raises
264
+ ------
265
+ FlowValidationError
266
+ If block creation fails.
267
+ """
268
+ # Validate block configuration structure
269
+ if not isinstance(block_config, dict):
270
+ raise FlowValidationError("Block configuration must be a dictionary")
271
+
272
+ block_type_name = block_config.get("block_type")
273
+ if not block_type_name:
274
+ raise FlowValidationError("Block configuration missing 'block_type'")
275
+
276
+ # Get block class from registry
277
+ try:
278
+ block_class = BlockRegistry.get(block_type_name)
279
+ except KeyError as exc:
280
+ # Get all available blocks from all categories
281
+ all_blocks = BlockRegistry.all()
282
+ available_blocks = ", ".join(
283
+ [block for blocks in all_blocks.values() for block in blocks]
284
+ )
285
+ raise FlowValidationError(
286
+ f"Block type '{block_type_name}' not found in registry. "
287
+ f"Available blocks: {available_blocks}"
288
+ ) from exc
289
+
290
+ # Process block configuration
291
+ config = block_config.get("block_config", {})
292
+ if not isinstance(config, dict):
293
+ raise FlowValidationError("'block_config' must be a dictionary")
294
+
295
+ config = config.copy()
296
+
297
+ # Resolve config file paths relative to YAML directory
298
+ for path_key in ["config_path", "config_paths", "prompt_config_path"]:
299
+ if path_key in config:
300
+ config[path_key] = cls._resolve_config_paths(config[path_key], yaml_dir)
301
+
302
+ # Create block instance with Pydantic validation
303
+ try:
304
+ return block_class(**config)
305
+ except Exception as exc:
306
+ raise FlowValidationError(
307
+ f"Failed to create block '{block_type_name}' with config {config}: {exc}"
308
+ ) from exc
309
+
310
+ @classmethod
311
+ def _resolve_config_paths(
312
+ cls, paths: Union[str, list[str], dict[str, str]], yaml_dir: Path
313
+ ) -> Union[str, list[str], dict[str, str]]:
314
+ """Resolve configuration file paths relative to YAML directory."""
315
+ if isinstance(paths, str):
316
+ return str(yaml_dir / paths)
317
+ elif isinstance(paths, list):
318
+ return [str(yaml_dir / path) for path in paths]
319
+ elif isinstance(paths, dict):
320
+ return {key: str(yaml_dir / path) for key, path in paths.items()}
321
+ return paths
322
+
323
    def generate(
        self,
        dataset: Dataset,
        runtime_params: Optional[dict[str, dict[str, Any]]] = None,
    ) -> Dataset:
        """Execute the flow blocks in sequence to generate data.

        Note: For flows with LLM blocks, set_model_config() must be called first
        to configure model settings before calling generate().

        Parameters
        ----------
        dataset : Dataset
            Input dataset to process.
        runtime_params : Optional[Dict[str, Dict[str, Any]]], optional
            Runtime parameters organized by block name. Format:
            {
                "block_name": {"param1": value1, "param2": value2},
                "other_block": {"param3": value3}
            }

        Returns
        -------
        Dataset
            Processed dataset after all blocks have been executed.

        Raises
        ------
        EmptyDatasetError
            If input dataset is empty or any block produces an empty dataset.
        FlowValidationError
            If flow validation fails or if model configuration is required but not set.
        """
        # Validate preconditions
        if not self.blocks:
            raise FlowValidationError("Cannot generate with empty flow")

        if len(dataset) == 0:
            raise EmptyDatasetError("Input dataset is empty")

        # Check if model configuration has been set for flows with LLM blocks
        llm_blocks = self._detect_llm_blocks()
        if llm_blocks and not self._model_config_set:
            raise FlowValidationError(
                f"Model configuration required before generate(). "
                f"Found {len(llm_blocks)} LLM blocks: {sorted(llm_blocks)}. "
                f"Call flow.set_model_config() first."
            )

        # Validate dataset requirements
        dataset_errors = self.validate_dataset(dataset)
        if dataset_errors:
            raise FlowValidationError(
                "Dataset validation failed:\n" + "\n".join(dataset_errors)
            )

        logger.info(
            f"Starting flow '{self.metadata.name}' v{self.metadata.version} "
            f"with {len(dataset)} samples across {len(self.blocks)} blocks"
        )

        current_dataset = dataset
        # Merge migrated runtime params with provided ones (provided ones take precedence)
        merged_runtime_params = self._migrated_runtime_params.copy()
        if runtime_params:
            merged_runtime_params.update(runtime_params)
        runtime_params = merged_runtime_params

        # Execute blocks in sequence
        for i, block in enumerate(self.blocks):
            logger.info(
                f"Executing block {i + 1}/{len(self.blocks)}: "
                f"{block.block_name} ({block.__class__.__name__})"
            )

            # Prepare block execution parameters
            block_kwargs = self._prepare_block_kwargs(block, runtime_params)

            try:
                # Check if this is a deprecated block and skip validations.
                # Detection is by module path, so any block defined under a
                # "deprecated_blocks" package is treated as deprecated.
                is_deprecated_block = (
                    hasattr(block, "__class__")
                    and hasattr(block.__class__, "__module__")
                    and "deprecated_blocks" in block.__class__.__module__
                )

                if is_deprecated_block:
                    logger.debug(
                        f"Skipping validations for deprecated block: {block.block_name}"
                    )
                    # Call generate() directly to skip validations, but keep the runtime params
                    current_dataset = block.generate(current_dataset, **block_kwargs)
                else:
                    # Execute block with validation and logging
                    current_dataset = block(current_dataset, **block_kwargs)

                # Validate output
                if len(current_dataset) == 0:
                    raise EmptyDatasetError(
                        f"Block '{block.block_name}' produced empty dataset"
                    )

                logger.info(
                    f"Block '{block.block_name}' completed successfully: "
                    f"{len(current_dataset)} samples, "
                    f"{len(current_dataset.column_names)} columns"
                )

            except Exception as exc:
                # NOTE(review): this also wraps the EmptyDatasetError raised just
                # above into a FlowValidationError — confirm that is intended.
                logger.error(
                    f"Block '{block.block_name}' failed during execution: {exc}"
                )
                raise FlowValidationError(
                    f"Block '{block.block_name}' execution failed: {exc}"
                ) from exc

        logger.info(
            f"Flow '{self.metadata.name}' completed successfully: "
            f"{len(current_dataset)} final samples, "
            f"{len(current_dataset.column_names)} final columns"
        )

        return current_dataset
446
+
447
+ def _prepare_block_kwargs(
448
+ self, block: BaseBlock, runtime_params: dict[str, dict[str, Any]]
449
+ ) -> dict[str, Any]:
450
+ """Prepare execution parameters for a block."""
451
+ return runtime_params.get(block.block_name, {})
452
+
453
    def set_model_config(
        self,
        model: Optional[str] = None,
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        blocks: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Configure model settings for LLM blocks in this flow (in-place).

        This method is designed to work with model-agnostic flow definitions where
        LLM blocks don't have hardcoded model configurations in the YAML. Instead,
        model settings are configured at runtime using this method.

        Based on LiteLLM's basic usage pattern, this method focuses on the core
        parameters (model, api_base, api_key) with additional parameters passed via kwargs.

        By default, auto-detects all LLM blocks in the flow and applies configuration to them.
        Optionally allows targeting specific blocks only.

        Parameters
        ----------
        model : Optional[str]
            Model name to configure (e.g., "hosted_vllm/openai/gpt-oss-120b").
        api_base : Optional[str]
            API base URL to configure (e.g., "http://localhost:8101/v1").
        api_key : Optional[str]
            API key to configure.
        blocks : Optional[List[str]]
            Specific block names to target. If None, auto-detects all LLM blocks.
        **kwargs : Any
            Additional model parameters (e.g., temperature, max_tokens, top_p, etc.).

        Examples
        --------
        >>> # Recommended workflow: discover -> initialize -> set_model_config -> generate
        >>> flow = Flow.from_yaml("path/to/flow.yaml")  # Initialize flow
        >>> flow.set_model_config(  # Configure model settings
        ...     model="hosted_vllm/openai/gpt-oss-120b",
        ...     api_base="http://localhost:8101/v1",
        ...     api_key="your_key",
        ...     temperature=0.7,
        ...     max_tokens=2048
        ... )
        >>> result = flow.generate(dataset)  # Generate data

        >>> # Configure only specific blocks
        >>> flow.set_model_config(
        ...     model="hosted_vllm/openai/gpt-oss-120b",
        ...     api_base="http://localhost:8101/v1",
        ...     blocks=["gen_detailed_summary", "knowledge_generation"]
        ... )

        Raises
        ------
        ValueError
            If no configuration parameters are provided or if specified blocks don't exist.
        """
        # Build the configuration parameters dictionary
        config_params = {}
        if model is not None:
            config_params["model"] = model
        if api_base is not None:
            config_params["api_base"] = api_base
        if api_key is not None:
            config_params["api_key"] = api_key

        # Add any additional kwargs (temperature, max_tokens, etc.)
        config_params.update(kwargs)

        # Validate that at least one parameter is provided
        if not config_params:
            raise ValueError(
                "At least one configuration parameter must be provided "
                "(model, api_base, api_key, or **kwargs)"
            )

        # Determine target blocks
        if blocks is not None:
            # Validate that specified blocks exist in the flow
            existing_block_names = {block.block_name for block in self.blocks}
            invalid_blocks = set(blocks) - existing_block_names
            if invalid_blocks:
                raise ValueError(
                    f"Specified blocks not found in flow: {sorted(invalid_blocks)}. "
                    f"Available blocks: {sorted(existing_block_names)}"
                )
            target_block_names = set(blocks)
            logger.info(
                f"Targeting specific blocks for configuration: {sorted(target_block_names)}"
            )
        else:
            # Auto-detect LLM blocks
            target_block_names = set(self._detect_llm_blocks())
            logger.info(
                f"Auto-detected {len(target_block_names)} LLM blocks for configuration: {sorted(target_block_names)}"
            )

        # Apply configuration to target blocks
        modified_count = 0
        for block in self.blocks:
            if block.block_name in target_block_names:
                # Only set attributes the block actually declares; others are
                # logged and skipped rather than added dynamically.
                for param_name, param_value in config_params.items():
                    if hasattr(block, param_name):
                        old_value = getattr(block, param_name)
                        setattr(block, param_name, param_value)
                        logger.debug(
                            f"Block '{block.block_name}': {param_name} "
                            f"'{old_value}' -> '{param_value}'"
                        )
                    else:
                        logger.warning(
                            f"Block '{block.block_name}' ({block.__class__.__name__}) "
                            f"does not have attribute '{param_name}' - skipping"
                        )

                # Reinitialize client manager for LLM blocks after updating config
                if hasattr(block, "_reinitialize_client_manager"):
                    block._reinitialize_client_manager()

                modified_count += 1

        if modified_count > 0:
            # Enhanced logging showing what was configured
            param_summary = []
            for param_name, param_value in config_params.items():
                if param_name == "model":
                    param_summary.append(f"model: '{param_value}'")
                elif param_name == "api_base":
                    param_summary.append(f"api_base: '{param_value}'")
                else:
                    param_summary.append(f"{param_name}: {param_value}")

            logger.info(
                f"Successfully configured {modified_count} LLM blocks with: {', '.join(param_summary)}"
            )
            logger.info(f"Configured blocks: {sorted(target_block_names)}")

            # Mark that model configuration has been set
            self._model_config_set = True
        else:
            logger.warning(
                "No blocks were modified - check block names or LLM block detection"
            )
597
+
598
+ def _detect_llm_blocks(self) -> list[str]:
599
+ """Detect LLM blocks in the flow by checking for model-related attribute existence.
600
+
601
+ LLM blocks are identified by having model, api_base, or api_key attributes,
602
+ regardless of their values (they may be None until set_model_config() is called).
603
+
604
+ Returns
605
+ -------
606
+ List[str]
607
+ List of block names that have LLM-related attributes.
608
+ """
609
+ llm_blocks = []
610
+
611
+ for block in self.blocks:
612
+ block_type = block.__class__.__name__
613
+ block_name = block.block_name
614
+
615
+ # Check by attribute existence (not value) - LLM blocks have these attributes even if None
616
+ has_model_attr = hasattr(block, "model")
617
+ has_api_base_attr = hasattr(block, "api_base")
618
+ has_api_key_attr = hasattr(block, "api_key")
619
+
620
+ # A block is considered an LLM block if it has any LLM-related attributes
621
+ is_llm_block = has_model_attr or has_api_base_attr or has_api_key_attr
622
+
623
+ if is_llm_block:
624
+ llm_blocks.append(block_name)
625
+ logger.debug(
626
+ f"Detected LLM block '{block_name}' ({block_type}): "
627
+ f"has_model_attr={has_model_attr}, has_api_base_attr={has_api_base_attr}, has_api_key_attr={has_api_key_attr}"
628
+ )
629
+
630
+ return llm_blocks
631
+
632
+ def is_model_config_required(self) -> bool:
633
+ """Check if model configuration is required for this flow.
634
+
635
+ Returns
636
+ -------
637
+ bool
638
+ True if flow has LLM blocks and needs model configuration.
639
+ """
640
+ return len(self._detect_llm_blocks()) > 0
641
+
642
    def is_model_config_set(self) -> bool:
        """Check if model configuration has been set.

        Returns
        -------
        bool
            True if model configuration has been set or is not required
            (the flag is set to True in from_yaml() for flows without LLM blocks).
        """
        return self._model_config_set
651
+
652
+ def reset_model_config(self) -> None:
653
+ """Reset model configuration flag (useful for testing or reconfiguration).
654
+
655
+ After calling this, set_model_config() must be called again before generate().
656
+ """
657
+ if self.is_model_config_required():
658
+ self._model_config_set = False
659
+ logger.info(
660
+ "Model configuration flag reset - call set_model_config() before generate()"
661
+ )
662
+
663
+ def get_default_model(self) -> Optional[str]:
664
+ """Get the default recommended model for this flow.
665
+
666
+ Returns
667
+ -------
668
+ Optional[str]
669
+ Default model name, or None if no models specified.
670
+
671
+ Examples
672
+ --------
673
+ >>> flow = Flow.from_yaml("path/to/flow.yaml")
674
+ >>> default_model = flow.get_default_model()
675
+ >>> print(f"Default model: {default_model}")
676
+ """
677
+ if not self.metadata.recommended_models:
678
+ return None
679
+ return self.metadata.recommended_models.default
680
+
681
+ def get_model_recommendations(self) -> dict[str, Any]:
682
+ """Get a clean summary of model recommendations for this flow.
683
+
684
+ Returns
685
+ -------
686
+ Dict[str, Any]
687
+ Dictionary with model recommendations in user-friendly format.
688
+
689
+ Examples
690
+ --------
691
+ >>> flow = Flow.from_yaml("path/to/flow.yaml")
692
+ >>> recommendations = flow.get_model_recommendations()
693
+ >>> print("Model recommendations:")
694
+ >>> print(f" Default: {recommendations['default']}")
695
+ >>> print(f" Compatible: {recommendations['compatible']}")
696
+ >>> print(f" Experimental: {recommendations['experimental']}")
697
+ """
698
+ if not self.metadata.recommended_models:
699
+ return {
700
+ "default": None,
701
+ "compatible": [],
702
+ "experimental": [],
703
+ }
704
+
705
+ return {
706
+ "default": self.metadata.recommended_models.default,
707
+ "compatible": self.metadata.recommended_models.compatible,
708
+ "experimental": self.metadata.recommended_models.experimental,
709
+ }
710
+
711
+ def validate_dataset(self, dataset: Dataset) -> list[str]:
712
+ """Validate dataset against flow requirements."""
713
+ errors = []
714
+
715
+ if len(dataset) == 0:
716
+ errors.append("Dataset is empty")
717
+
718
+ if self.metadata.dataset_requirements:
719
+ errors.extend(
720
+ self.metadata.dataset_requirements.validate_dataset(
721
+ dataset.column_names, len(dataset)
722
+ )
723
+ )
724
+
725
+ return errors
726
+
727
+ def dry_run(
728
+ self,
729
+ dataset: Dataset,
730
+ sample_size: int = 2,
731
+ runtime_params: Optional[dict[str, dict[str, Any]]] = None,
732
+ ) -> dict[str, Any]:
733
+ """Perform a dry run of the flow with a subset of data.
734
+
735
+ Parameters
736
+ ----------
737
+ dataset : Dataset
738
+ Input dataset to test with.
739
+ sample_size : int, default=2
740
+ Number of samples to use for dry run testing.
741
+ runtime_params : Optional[Dict[str, Dict[str, Any]]], optional
742
+ Runtime parameters organized by block name.
743
+
744
+ Returns
745
+ -------
746
+ Dict[str, Any]
747
+ Dry run results with execution info and sample outputs.
748
+
749
+ Raises
750
+ ------
751
+ EmptyDatasetError
752
+ If input dataset is empty.
753
+ FlowValidationError
754
+ If any block fails during dry run execution.
755
+ """
756
+ # Validate preconditions
757
+ if not self.blocks:
758
+ raise FlowValidationError("Cannot dry run empty flow")
759
+
760
+ if len(dataset) == 0:
761
+ raise EmptyDatasetError("Input dataset is empty")
762
+
763
+ # Use smaller sample size if dataset is smaller
764
+ actual_sample_size = min(sample_size, len(dataset))
765
+
766
+ logger.info(
767
+ f"Starting dry run for flow '{self.metadata.name}' "
768
+ f"with {actual_sample_size} samples"
769
+ )
770
+
771
+ # Create subset dataset
772
+ sample_dataset = dataset.select(range(actual_sample_size))
773
+
774
+ # Initialize dry run results
775
+ dry_run_results = {
776
+ "flow_name": self.metadata.name,
777
+ "flow_version": self.metadata.version,
778
+ "sample_size": actual_sample_size,
779
+ "original_dataset_size": len(dataset),
780
+ "input_columns": dataset.column_names,
781
+ "blocks_executed": [],
782
+ "final_dataset": None,
783
+ "execution_successful": True,
784
+ "execution_time_seconds": 0,
785
+ }
786
+
787
+ # Standard
788
+ import time
789
+
790
+ start_time = time.time()
791
+
792
+ try:
793
+ # Execute the flow with sample data
794
+ current_dataset = sample_dataset
795
+ runtime_params = runtime_params or {}
796
+
797
+ for i, block in enumerate(self.blocks):
798
+ block_start_time = time.time()
799
+ input_rows = len(current_dataset)
800
+
801
+ logger.info(
802
+ f"Dry run executing block {i + 1}/{len(self.blocks)}: "
803
+ f"{block.block_name} ({block.__class__.__name__})"
804
+ )
805
+
806
+ # Prepare block execution parameters
807
+ block_kwargs = self._prepare_block_kwargs(block, runtime_params)
808
+
809
+ # Check if this is a deprecated block and skip validations
810
+ is_deprecated_block = (
811
+ hasattr(block, "__class__")
812
+ and hasattr(block.__class__, "__module__")
813
+ and "deprecated_blocks" in block.__class__.__module__
814
+ )
815
+
816
+ if is_deprecated_block:
817
+ logger.debug(
818
+ f"Dry run: Skipping validations for deprecated block: {block.block_name}"
819
+ )
820
+ # Call generate() directly to skip validations, but keep the runtime params
821
+ current_dataset = block.generate(current_dataset, **block_kwargs)
822
+ else:
823
+ # Execute block with validation and logging
824
+ current_dataset = block(current_dataset, **block_kwargs)
825
+
826
+ block_execution_time = time.time() - block_start_time
827
+
828
+ # Record block execution info
829
+ block_info = {
830
+ "block_name": block.block_name,
831
+ "block_type": block.__class__.__name__,
832
+ "execution_time_seconds": block_execution_time,
833
+ "input_rows": input_rows,
834
+ "output_rows": len(current_dataset),
835
+ "output_columns": current_dataset.column_names,
836
+ "parameters_used": block_kwargs,
837
+ }
838
+
839
+ dry_run_results["blocks_executed"].append(block_info)
840
+
841
+ logger.info(
842
+ f"Dry run block '{block.block_name}' completed: "
843
+ f"{len(current_dataset)} samples, "
844
+ f"{len(current_dataset.column_names)} columns, "
845
+ f"{block_execution_time:.2f}s"
846
+ )
847
+
848
+ # Store final results
849
+ dry_run_results["final_dataset"] = {
850
+ "rows": len(current_dataset),
851
+ "columns": current_dataset.column_names,
852
+ "sample_data": current_dataset.to_dict()
853
+ if len(current_dataset) > 0
854
+ else {},
855
+ }
856
+
857
+ execution_time = time.time() - start_time
858
+ dry_run_results["execution_time_seconds"] = execution_time
859
+
860
+ logger.info(
861
+ f"Dry run completed successfully for flow '{self.metadata.name}' "
862
+ f"in {execution_time:.2f}s"
863
+ )
864
+
865
+ return dry_run_results
866
+
867
+ except Exception as exc:
868
+ execution_time = time.time() - start_time
869
+ dry_run_results["execution_successful"] = False
870
+ dry_run_results["execution_time_seconds"] = execution_time
871
+ dry_run_results["error"] = str(exc)
872
+
873
+ logger.error(f"Dry run failed for flow '{self.metadata.name}': {exc}")
874
+
875
+ raise FlowValidationError(f"Dry run failed: {exc}") from exc
876
+
877
+ def add_block(self, block: BaseBlock) -> "Flow":
878
+ """Add a block to the flow, returning a new Flow instance.
879
+
880
+ Parameters
881
+ ----------
882
+ block : BaseBlock
883
+ Block to add to the flow.
884
+
885
+ Returns
886
+ -------
887
+ Flow
888
+ New Flow instance with the added block.
889
+
890
+ Raises
891
+ ------
892
+ ValueError
893
+ If the block is invalid or creates naming conflicts.
894
+ """
895
+ if not isinstance(block, BaseBlock):
896
+ raise ValueError(f"Block must be a BaseBlock instance, got: {type(block)}")
897
+
898
+ # Check for name conflicts
899
+ existing_names = {b.block_name for b in self.blocks}
900
+ if block.block_name in existing_names:
901
+ raise ValueError(
902
+ f"Block name '{block.block_name}' already exists in flow. "
903
+ f"Block names must be unique."
904
+ )
905
+
906
+ # Create new flow with added block
907
+ new_blocks = self.blocks + [block]
908
+
909
+ return Flow(
910
+ blocks=new_blocks, metadata=self.metadata, parameters=self.parameters
911
+ )
912
+
913
+ def get_info(self) -> dict[str, Any]:
914
+ """Get information about the flow."""
915
+ return {
916
+ "metadata": self.metadata.model_dump(),
917
+ "parameters": {
918
+ name: param.model_dump() for name, param in self.parameters.items()
919
+ },
920
+ "blocks": [
921
+ {
922
+ "block_type": block.__class__.__name__,
923
+ "block_name": block.block_name,
924
+ "input_cols": getattr(block, "input_cols", None),
925
+ "output_cols": getattr(block, "output_cols", None),
926
+ }
927
+ for block in self.blocks
928
+ ],
929
+ "total_blocks": len(self.blocks),
930
+ "block_names": [block.block_name for block in self.blocks],
931
+ }
932
+
933
+ def to_yaml(self, output_path: str) -> None:
934
+ """Save flow configuration to YAML file.
935
+
936
+ Note: This creates a basic YAML structure. For exact reproduction
937
+ of original YAML, save the original file separately.
938
+ """
939
+ config = {
940
+ "metadata": self.metadata.model_dump(),
941
+ "blocks": [
942
+ {
943
+ "block_type": block.__class__.__name__,
944
+ "block_config": block.model_dump(),
945
+ }
946
+ for block in self.blocks
947
+ ],
948
+ }
949
+
950
+ if self.parameters:
951
+ config["parameters"] = {
952
+ name: param.model_dump() for name, param in self.parameters.items()
953
+ }
954
+
955
+ with open(output_path, "w", encoding="utf-8") as f:
956
+ yaml.dump(config, f, default_flow_style=False, sort_keys=False)
957
+
958
+ logger.info(f"Flow configuration saved to: {output_path}")
959
+
960
+ def __len__(self) -> int:
961
+ """Number of blocks in the flow."""
962
+ return len(self.blocks)
963
+
964
+ def __repr__(self) -> str:
965
+ """String representation of the flow."""
966
+ return (
967
+ f"Flow(name='{self.metadata.name}', "
968
+ f"version='{self.metadata.version}', "
969
+ f"blocks={len(self.blocks)})"
970
+ )
971
+
972
+ def __str__(self) -> str:
973
+ """Human-readable string representation."""
974
+ block_names = [block.block_name for block in self.blocks]
975
+ return (
976
+ f"Flow '{self.metadata.name}' v{self.metadata.version}\n"
977
+ f"Blocks: {' -> '.join(block_names) if block_names else 'None'}\n"
978
+ f"Author: {self.metadata.author or 'Unknown'}\n"
979
+ f"Description: {self.metadata.description or 'No description'}"
980
+ )