sdg-hub 0.4.1__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sdg_hub/_version.py +2 -2
- sdg_hub/core/blocks/__init__.py +0 -22
- sdg_hub/core/blocks/transform/rename_columns.py +19 -0
- sdg_hub/core/flow/base.py +146 -81
- sdg_hub/core/utils/__init__.py +11 -3
- sdg_hub/core/utils/flow_metrics.py +116 -0
- sdg_hub/core/utils/time_estimator.py +344 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/detailed_summary/flow.yaml +5 -1
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/extractive_summary/flow.yaml +5 -1
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/key_facts/flow.yaml +5 -1
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/flow.yaml +6 -1
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/multilingual/japanese/flow.yaml +16 -10
- {sdg_hub-0.4.1.dist-info → sdg_hub-0.5.0.dist-info}/METADATA +2 -2
- {sdg_hub-0.4.1.dist-info → sdg_hub-0.5.0.dist-info}/RECORD +17 -27
- sdg_hub/core/blocks/deprecated_blocks/__init__.py +0 -29
- sdg_hub/core/blocks/deprecated_blocks/combine_columns.py +0 -93
- sdg_hub/core/blocks/deprecated_blocks/duplicate_columns.py +0 -88
- sdg_hub/core/blocks/deprecated_blocks/filter_by_value.py +0 -103
- sdg_hub/core/blocks/deprecated_blocks/flatten_columns.py +0 -94
- sdg_hub/core/blocks/deprecated_blocks/llmblock.py +0 -479
- sdg_hub/core/blocks/deprecated_blocks/rename_columns.py +0 -88
- sdg_hub/core/blocks/deprecated_blocks/sample_populator.py +0 -58
- sdg_hub/core/blocks/deprecated_blocks/selector.py +0 -97
- sdg_hub/core/blocks/deprecated_blocks/set_to_majority_value.py +0 -88
- sdg_hub/core/flow/migration.py +0 -198
- {sdg_hub-0.4.1.dist-info → sdg_hub-0.5.0.dist-info}/WHEEL +0 -0
- {sdg_hub-0.4.1.dist-info → sdg_hub-0.5.0.dist-info}/licenses/LICENSE +0 -0
- {sdg_hub-0.4.1.dist-info → sdg_hub-0.5.0.dist-info}/top_level.txt +0 -0
sdg_hub/_version.py
CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.4.1'
-__version_tuple__ = version_tuple = (0, 4, 1)
+__version__ = version = '0.5.0'
+__version_tuple__ = version_tuple = (0, 5, 0)
 
 __commit_id__ = commit_id = None
sdg_hub/core/blocks/__init__.py
CHANGED
@@ -5,17 +5,6 @@ This package provides various block implementations for data generation, process
 
 # Local
 from .base import BaseBlock
-from .deprecated_blocks import (
-    CombineColumnsBlock,
-    DuplicateColumns,
-    FilterByValueBlock,
-    FlattenColumnsBlock,
-    LLMBlock,
-    RenameColumns,
-    SamplePopulatorBlock,
-    SelectorBlock,
-    SetToMajorityValue,
-)
 from .filtering import ColumnValueFilterBlock
 from .llm import LLMChatBlock, LLMParserBlock, PromptBuilderBlock, TextParserBlock
 from .registry import BlockRegistry
@@ -28,8 +17,6 @@ from .transform import (
     UniformColumnValueSetter,
 )
 
-# All blocks moved to deprecated_blocks or transform modules
-
 __all__ = [
     "BaseBlock",
     "BlockRegistry",
@@ -40,15 +27,6 @@ __all__ = [
     "RenameColumnsBlock",
     "TextConcatBlock",
     "UniformColumnValueSetter",
-    "CombineColumnsBlock",  # Deprecated
-    "DuplicateColumns",  # Deprecated
-    "FilterByValueBlock",  # Deprecated
-    "FlattenColumnsBlock",  # Deprecated
-    "RenameColumns",  # Deprecated
-    "SamplePopulatorBlock",  # Deprecated
-    "SelectorBlock",  # Deprecated
-    "SetToMajorityValue",  # Deprecated
-    "LLMBlock",  # Deprecated
     "LLMChatBlock",
     "LLMParserBlock",
     "TextParserBlock",
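
Note for upgraders: the deprecated aliases removed above are no longer importable from sdg_hub.core.blocks, so 0.4.1-era imports fail at import time. A minimal sketch of the migration, assuming only the names visible in this diff:

    # 0.4.1 re-exported deprecated aliases:
    # from sdg_hub.core.blocks import RenameColumns, LLMBlock  # ImportError in 0.5.0

    # 0.5.0: import the maintained replacements instead
    from sdg_hub.core.blocks import RenameColumnsBlock, LLMChatBlock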
sdg_hub/core/blocks/transform/rename_columns.py
CHANGED
@@ -64,6 +64,25 @@ class RenameColumnsBlock(BaseBlock):
         -------
         Dataset
             Dataset with renamed columns.
+
+        Raises
+        ------
+        ValueError
+            If attempting to rename to a column name that already exists.
         """
+        # Check for column name collisions
+        # Strict validation: no target column name can be an existing column name
+        # This prevents chained/circular renames which can be confusing
+        existing_cols = set(samples.column_names)
+        target_cols = set(self.input_cols.values())
+
+        collision = target_cols & existing_cols
+        if collision:
+            raise ValueError(
+                f"Cannot rename to existing column names: {sorted(collision)}. "
+                "Target column names must not already exist in the dataset. "
+                "Chained renames are not supported."
+            )
+
         # Rename columns using HuggingFace datasets method
         return samples.rename_columns(self.input_cols)
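
The guard added above can be mirrored with only the datasets library; this sketch reproduces the collision check without assuming anything about the block's constructor:

    from datasets import Dataset

    samples = Dataset.from_dict({"text": ["a", "b"], "label": [0, 1]})
    input_cols = {"text": "label"}  # old name -> new name

    # Same check RenameColumnsBlock.generate() now performs before renaming:
    collision = set(input_cols.values()) & set(samples.column_names)
    assert collision == {"label"}  # 0.5.0 raises ValueError for this mapping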
sdg_hub/core/flow/base.py
CHANGED
@@ -30,13 +30,17 @@ from ..blocks.base import BaseBlock
 from ..blocks.registry import BlockRegistry
 from ..utils.datautils import safe_concatenate_with_validation, validate_no_duplicates
 from ..utils.error_handling import EmptyDatasetError, FlowValidationError
-from ..utils.flow_metrics import display_metrics_summary, save_metrics_to_json
+from ..utils.flow_metrics import (
+    display_metrics_summary,
+    display_time_estimation_summary,
+    save_metrics_to_json,
+)
 from ..utils.logger_config import setup_logger
 from ..utils.path_resolution import resolve_path
+from ..utils.time_estimator import estimate_execution_time
 from ..utils.yaml_utils import save_flow_yaml
 from .checkpointer import FlowCheckpointer
 from .metadata import DatasetRequirements, FlowMetadata
-from .migration import FlowMigration
 from .validation import FlowValidator
 
 logger = setup_logger(__name__)
@@ -68,8 +72,6 @@ class Flow(BaseModel):
     model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
 
     # Private attributes (not serialized)
-    _migrated_runtime_params: dict[str, dict[str, Any]] = {}
-    _llm_client: Any = None  # Only used for backward compatibility with old YAMLs
     _model_config_set: bool = False  # Track if model configuration has been set
     _block_metrics: list[dict[str, Any]] = PrivateAttr(
         default_factory=list
@@ -108,16 +110,13 @@
         return self
 
     @classmethod
-    def from_yaml(cls, yaml_path: str, client: Any = None) -> "Flow":
+    def from_yaml(cls, yaml_path: str) -> "Flow":
         """Load flow from YAML configuration file.
 
         Parameters
         ----------
         yaml_path : str
             Path to the YAML flow configuration file.
-        client : Any, optional
-            LLM client instance. Required for backward compatibility with old format YAMLs
-            that use deprecated LLMBlocks. Ignored for new format YAMLs.
 
         Returns
         -------
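
With the migration path removed, from_yaml now takes only the path. A sketch of the 0.5.0 call (the YAML path is hypothetical):

    from sdg_hub.core.flow.base import Flow

    flow = Flow.from_yaml("flows/my_flow.yaml")  # hypothetical path
    # 0.4.1 accepted Flow.from_yaml(path, client=...) to migrate old-format YAMLs;
    # in 0.5.0 the keyword no longer exists, so passing it raises TypeError.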
@@ -148,21 +147,6 @@
         except yaml.YAMLError as exc:
             raise FlowValidationError(f"Invalid YAML in {yaml_path}: {exc}") from exc
 
-        # Check if this is an old format flow and migrate if necessary
-        migrated_runtime_params = None
-        is_old_format = FlowMigration.is_old_format(flow_config)
-        if is_old_format:
-            logger.info(f"Detected old format flow, migrating: {yaml_path}")
-            if client is None:
-                logger.warning(
-                    "Old format YAML detected but no client provided. LLMBlocks may fail."
-                )
-            flow_config, migrated_runtime_params = FlowMigration.migrate_to_new_format(
-                flow_config, yaml_path
-            )
-            # Save migrated config back to YAML to persist id
-            save_flow_yaml(yaml_path, flow_config, "migrated to new format")
-
         # Validate YAML structure
         validator = FlowValidator()
         validation_errors = validator.validate_yaml_structure(flow_config)
@@ -189,19 +173,6 @@
 
         for i, block_config in enumerate(block_configs):
             try:
-                # Inject client for deprecated LLMBlocks if this is an old format flow
-                if (
-                    is_old_format
-                    and block_config.get("block_type") == "LLMBlock"
-                    and client is not None
-                ):
-                    if "block_config" not in block_config:
-                        block_config["block_config"] = {}
-                    block_config["block_config"]["client"] = client
-                    logger.debug(
-                        f"Injected client for deprecated LLMBlock: {block_config['block_config'].get('block_name')}"
-                    )
-
                 block = cls._create_block_from_config(block_config, yaml_dir)
                 blocks.append(block)
             except Exception as exc:
@@ -223,12 +194,6 @@
             )
         else:
             logger.debug(f"Flow already had id: {flow.metadata.id}")
-        # Store migrated runtime params and client for backward compatibility
-        if migrated_runtime_params:
-            flow._migrated_runtime_params = migrated_runtime_params
-        if is_old_format and client is not None:
-            flow._llm_client = client
-
         # Check if this is a flow without LLM blocks
         llm_blocks = flow._detect_llm_blocks()
         if not llm_blocks:
@@ -479,12 +444,6 @@
         self._block_metrics = []
         run_start = time.perf_counter()
 
-        # Merge migrated runtime params with provided ones (provided ones take precedence)
-        merged_runtime_params = self._migrated_runtime_params.copy()
-        if runtime_params:
-            merged_runtime_params.update(runtime_params)
-        runtime_params = merged_runtime_params
-
         # Execute flow with metrics capture, ensuring metrics are always displayed/saved
         final_dataset = None
         execution_successful = False
@@ -642,22 +601,8 @@
             input_cols = set(current_dataset.column_names)
 
             try:
-                # Check if this is a deprecated block
-                is_deprecated_block = (
-                    hasattr(block, "__class__")
-                    and hasattr(block.__class__, "__module__")
-                    and "deprecated_blocks" in block.__class__.__module__
-                )
-
-                if is_deprecated_block:
-                    exec_logger.debug(
-                        f"Skipping validations for deprecated block: {block.block_name}"
-                    )
-                    # Call generate() directly to skip validations, but keep the runtime params
-                    current_dataset = block.generate(current_dataset, **block_kwargs)
-                else:
-                    # Execute block with validation and logging
-                    current_dataset = block(current_dataset, **block_kwargs)
+                # Execute block with validation and logging
+                current_dataset = block(current_dataset, **block_kwargs)
 
                 # Validate output
                 if len(current_dataset) == 0:
@@ -719,9 +664,11 @@
         return current_dataset
 
     def _prepare_block_kwargs(
-        self, block: BaseBlock, runtime_params: dict[str, dict[str, Any]]
+        self, block: BaseBlock, runtime_params: Optional[dict[str, dict[str, Any]]]
     ) -> dict[str, Any]:
         """Prepare execution parameters for a block."""
+        if runtime_params is None:
+            return {}
         return runtime_params.get(block.block_name, {})
 
     def set_model_config(
@@ -1006,6 +953,8 @@
         dataset: Dataset,
         sample_size: int = 2,
         runtime_params: Optional[dict[str, dict[str, Any]]] = None,
+        max_concurrency: Optional[int] = None,
+        enable_time_estimation: bool = False,
     ) -> dict[str, Any]:
         """Perform a dry run of the flow with a subset of data.
 
@@ -1017,11 +966,18 @@
             Number of samples to use for dry run testing.
         runtime_params : Optional[Dict[str, Dict[str, Any]]], optional
             Runtime parameters organized by block name.
+        max_concurrency : Optional[int], optional
+            Maximum concurrent requests for LLM blocks. If None, no limit is applied.
+        enable_time_estimation : bool, default=False
+            If True, estimates execution time for the full dataset and displays it
+            in a Rich table. Automatically runs a second dry run if needed for
+            accurate scaling analysis.
 
         Returns
         -------
         Dict[str, Any]
             Dry run results with execution info and sample outputs.
+            Time estimation is displayed in a table but not included in return value.
 
         Raises
         ------
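
Taken together, the new parameters make a capped, timed dry run look like the following sketch (flow and dataset are hypothetical placeholders):

    results = flow.dry_run(
        dataset,
        sample_size=1,
        max_concurrency=8,            # caps concurrent LLM requests during the probe
        enable_time_estimation=True,  # prints the Rich estimation tables as a side effect
    )
    print(results["max_concurrency"])  # the cap is echoed in the results dict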
@@ -1039,6 +995,19 @@
 
         validate_no_duplicates(dataset)
 
+        # Validate max_concurrency parameter
+        if max_concurrency is not None:
+            if isinstance(max_concurrency, bool) or not isinstance(
+                max_concurrency, int
+            ):
+                raise FlowValidationError(
+                    f"max_concurrency must be an int, got {type(max_concurrency).__name__}"
+                )
+            if max_concurrency <= 0:
+                raise FlowValidationError(
+                    f"max_concurrency must be greater than 0, got {max_concurrency}"
+                )
+
         # Use smaller sample size if dataset is smaller
         actual_sample_size = min(sample_size, len(dataset))
 
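
One subtlety in the validation above: isinstance(max_concurrency, bool) is tested first because bool is a subclass of int in Python, so a bare isinstance(x, int) would accept True. For illustration:

    isinstance(True, int)  # True -- hence the explicit bool rejection
    # dry_run(..., max_concurrency=True)  -> FlowValidationError (bool is not a valid count)
    # dry_run(..., max_concurrency=0)     -> FlowValidationError (must be > 0)
    # dry_run(..., max_concurrency=4)     -> accepted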
@@ -1056,6 +1025,7 @@
             "flow_version": self.metadata.version,
             "sample_size": actual_sample_size,
             "original_dataset_size": len(dataset),
+            "max_concurrency": max_concurrency,
             "input_columns": dataset.column_names,
             "blocks_executed": [],
             "final_dataset": None,
@@ -1082,24 +1052,16 @@
                 # Prepare block execution parameters
                 block_kwargs = self._prepare_block_kwargs(block, runtime_params)
 
-                # Check if this is a deprecated block
-                is_deprecated_block = (
-                    hasattr(block, "__class__")
-                    and hasattr(block.__class__, "__module__")
-                    and "deprecated_blocks" in block.__class__.__module__
-                )
+                # Add max_concurrency to block kwargs if provided
+                if max_concurrency is not None:
+                    block_kwargs["_flow_max_concurrency"] = max_concurrency
 
-                if is_deprecated_block:
-                    logger.debug(
-                        f"Dry run: Skipping validations for deprecated block: {block.block_name}"
-                    )
-                    # Call generate() directly to skip validations, but keep the runtime params
-                    current_dataset = block.generate(current_dataset, **block_kwargs)
-                else:
-                    # Execute block with validation and logging
-                    current_dataset = block(current_dataset, **block_kwargs)
+                # Execute block with validation and logging
+                current_dataset = block(current_dataset, **block_kwargs)
 
-                block_execution_time = time.time() - block_start_time
+                block_execution_time = (
+                    time.perf_counter() - block_start_time
+                )  # Fixed: use perf_counter consistently
 
                 # Record block execution info
                 block_info = {
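
The inline "Fixed" comment refers to a real bug class: time.time() (wall clock) and time.perf_counter() (monotonic, arbitrary epoch) must never be mixed in one subtraction. A minimal illustration:

    import time

    start = time.perf_counter()
    # ... block executes ...
    bad = time.time() - start           # meaningless: the two clocks share no epoch
    good = time.perf_counter() - start  # correct: same monotonic clock on both sides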
@@ -1138,6 +1100,12 @@
                 f"in {execution_time:.2f}s"
             )
 
+            # Perform time estimation if requested (displays table but doesn't store in results)
+            if enable_time_estimation:
+                self._estimate_total_time(
+                    dry_run_results, dataset, runtime_params, max_concurrency
+                )
+
             return dry_run_results
 
         except Exception as exc:
@@ -1150,6 +1118,103 @@
 
             raise FlowValidationError(f"Dry run failed: {exc}") from exc
 
+    def _estimate_total_time(
+        self,
+        first_run_results: dict[str, Any],
+        dataset: Dataset,
+        runtime_params: Optional[dict[str, dict[str, Any]]],
+        max_concurrency: Optional[int],
+    ) -> dict[str, Any]:
+        """Estimate execution time using 2 dry runs (private method).
+
+        This method contains all the estimation logic. It determines if a second
+        dry run is needed, executes it, and calls estimate_execution_time.
+
+        Parameters
+        ----------
+        first_run_results : dict
+            Results from the first dry run.
+        dataset : Dataset
+            Full dataset for estimation.
+        runtime_params : Optional[dict]
+            Runtime parameters.
+        max_concurrency : Optional[int]
+            Maximum concurrency.
+
+        Returns
+        -------
+        dict
+            Estimation results with estimated_time_seconds, total_estimated_requests, etc.
+        """
+        first_sample_size = first_run_results["sample_size"]
+
+        # Check if we need a second dry run
+        has_async_blocks = any(
+            getattr(block, "async_mode", False) for block in self.blocks
+        )
+
+        # For sequential or no async blocks, single run is sufficient
+        if max_concurrency == 1 or not has_async_blocks:
+            estimation = estimate_execution_time(
+                dry_run_1=first_run_results,
+                dry_run_2=None,
+                total_dataset_size=len(dataset),
+                max_concurrency=max_concurrency,
+            )
+        else:
+            # Need second measurement - always use canonical (1, 5) pair
+            if first_sample_size == 1:
+                # Already have 1, need 5
+                logger.info("Running second dry run with 5 samples for time estimation")
+                second_run = self.dry_run(
+                    dataset,
+                    5,
+                    runtime_params,
+                    max_concurrency,
+                    enable_time_estimation=False,
+                )
+                dry_run_1, dry_run_2 = first_run_results, second_run
+            elif first_sample_size == 5:
+                # Already have 5, need 1
+                logger.info("Running second dry run with 1 sample for time estimation")
+                second_run = self.dry_run(
+                    dataset,
+                    1,
+                    runtime_params,
+                    max_concurrency,
+                    enable_time_estimation=False,
+                )
+                dry_run_1, dry_run_2 = second_run, first_run_results
+            else:
+                # For other sizes: run both 1 and 5 for canonical pair
+                logger.info("Running dry runs with 1 and 5 samples for time estimation")
+                dry_run_1 = self.dry_run(
+                    dataset,
+                    1,
+                    runtime_params,
+                    max_concurrency,
+                    enable_time_estimation=False,
+                )
+                dry_run_2 = self.dry_run(
+                    dataset,
+                    5,
+                    runtime_params,
+                    max_concurrency,
+                    enable_time_estimation=False,
+                )
+
+            estimation = estimate_execution_time(
+                dry_run_1=dry_run_1,
+                dry_run_2=dry_run_2,
+                total_dataset_size=len(dataset),
+                max_concurrency=max_concurrency,
+            )
+
+        # Display estimation summary
+        display_time_estimation_summary(estimation, len(dataset), max_concurrency)
+
+        return estimation
+
     def add_block(self, block: BaseBlock) -> "Flow":
         """Add a block to the flow, returning a new Flow instance.
 
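
The reason for the canonical (1, 5) pair: with async LLM blocks, wall time does not scale linearly with sample count, so two measurements at different sizes presumably let estimate_execution_time separate fixed overhead from per-sample cost before extrapolating (the estimator itself lives in the new time_estimator.py). Driving the probe is a single call:

    # Triggers up to two dry runs (1 and 5 samples) plus the estimation tables:
    flow.dry_run(dataset, sample_size=1, max_concurrency=8, enable_time_estimation=True)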
sdg_hub/core/utils/__init__.py
CHANGED
@@ -1,8 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 
 # Local
-from .flow_identifier import get_flow_identifier
-from .path_resolution import resolve_path
+from .flow_identifier import get_flow_identifier as get_flow_identifier
+from .path_resolution import resolve_path as resolve_path
+from .time_estimator import estimate_execution_time as estimate_execution_time
+from .time_estimator import is_llm_using_block as is_llm_using_block
 
 
 # This is part of the public API, and used by instructlab
@@ -10,4 +12,10 @@ class GenerateError(Exception):
     """An exception raised during generate step."""
 
 
-__all__ = ["GenerateError", "resolve_path", "get_flow_identifier"]
+__all__ = [
+    "GenerateError",
+    "resolve_path",
+    "get_flow_identifier",
+    "estimate_execution_time",
+    "is_llm_using_block",
+]
sdg_hub/core/utils/flow_metrics.py
CHANGED
@@ -188,6 +188,122 @@ def display_metrics_summary(
     console.print()
 
 
+def display_time_estimation_summary(
+    time_estimation: dict[str, Any],
+    dataset_size: int,
+    max_concurrency: Optional[int] = None,
+) -> None:
+    """Display a rich table summarizing time estimation results.
+
+    Parameters
+    ----------
+    time_estimation : dict[str, Any]
+        Time estimation results from estimate_total_time().
+    dataset_size : int
+        Total number of samples in the dataset.
+    max_concurrency : Optional[int], optional
+        Maximum concurrency used for estimation.
+    """
+    console = Console()
+
+    # Create main summary table
+    summary_table = Table(
+        show_header=False,
+        box=None,
+        padding=(0, 1),
+    )
+    summary_table.add_column("Metric", style="bright_cyan")
+    summary_table.add_column("Value", style="bright_white")
+
+    # Format time
+    est_seconds = time_estimation["estimated_time_seconds"]
+    if est_seconds < 60:
+        time_str = f"{est_seconds:.1f} seconds"
+    elif est_seconds < 3600:
+        time_str = f"{est_seconds / 60:.1f} minutes ({est_seconds / 3600:.2f} hours)"
+    else:
+        time_str = f"{est_seconds / 3600:.2f} hours ({est_seconds / 60:.0f} minutes)"
+
+    summary_table.add_row("Estimated Time:", time_str)
+    summary_table.add_row(
+        "Total LLM Requests:", f"{time_estimation.get('total_estimated_requests', 0):,}"
+    )
+
+    if time_estimation.get("total_estimated_requests", 0) > 0:
+        requests_per_sample = time_estimation["total_estimated_requests"] / dataset_size
+        summary_table.add_row("Requests per Sample:", f"{requests_per_sample:.1f}")
+
+    if max_concurrency is not None:
+        summary_table.add_row("Max Concurrency:", str(max_concurrency))
+
+    # Display summary panel
+    console.print()
+    console.print(
+        Panel(
+            summary_table,
+            title=f"[bold bright_white]Time Estimation for {dataset_size:,} Samples[/bold bright_white]",
+            border_style="bright_blue",
+        )
+    )
+
+    # Display per-block breakdown if available
+    block_estimates = time_estimation.get("block_estimates", [])
+    if block_estimates:
+        console.print()
+
+        # Create per-block table
+        block_table = Table(
+            show_header=True,
+            header_style="bold bright_white",
+        )
+        block_table.add_column("Block Name", style="bright_cyan", width=20)
+        block_table.add_column("Time", justify="right", style="bright_yellow", width=10)
+        block_table.add_column(
+            "Requests", justify="right", style="bright_green", width=10
+        )
+        block_table.add_column(
+            "Throughput", justify="right", style="bright_blue", width=12
+        )
+        block_table.add_column(
+            "Amplif.", justify="right", style="bright_magenta", width=10
+        )
+
+        for block in block_estimates:
+            # Format time
+            block_seconds = block["estimated_time"]
+            if block_seconds < 60:
+                time_str = f"{block_seconds:.1f}s"
+            else:
+                time_str = f"{block_seconds / 60:.1f}min"
+
+            # Format requests
+            requests_str = f"{block['estimated_requests']:,.0f}"
+
+            # Format throughput
+            throughput_str = f"{block['throughput']:.2f}/s"
+
+            # Format amplification
+            amplif_str = f"{block['amplification']:.1f}x"
+
+            block_table.add_row(
+                block["block"],
+                time_str,
+                requests_str,
+                throughput_str,
+                amplif_str,
+            )
+
+        console.print(
+            Panel(
+                block_table,
+                title="[bold bright_white]Per-Block Breakdown[/bold bright_white]",
+                border_style="bright_blue",
+            )
+        )
+
+    console.print()
+
+
 def save_metrics_to_json(
     block_metrics: list[dict[str, Any]],
     flow_name: str,