data-designer-engine 0.4.0 (data_designer_engine-0.4.0-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/engine/__init__.py +2 -0
- data_designer/engine/_version.py +34 -0
- data_designer/engine/analysis/column_profilers/base.py +49 -0
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +153 -0
- data_designer/engine/analysis/column_profilers/registry.py +22 -0
- data_designer/engine/analysis/column_statistics.py +145 -0
- data_designer/engine/analysis/dataset_profiler.py +149 -0
- data_designer/engine/analysis/errors.py +9 -0
- data_designer/engine/analysis/utils/column_statistics_calculations.py +234 -0
- data_designer/engine/analysis/utils/judge_score_processing.py +132 -0
- data_designer/engine/column_generators/__init__.py +2 -0
- data_designer/engine/column_generators/generators/__init__.py +2 -0
- data_designer/engine/column_generators/generators/base.py +122 -0
- data_designer/engine/column_generators/generators/embedding.py +35 -0
- data_designer/engine/column_generators/generators/expression.py +55 -0
- data_designer/engine/column_generators/generators/llm_completion.py +116 -0
- data_designer/engine/column_generators/generators/samplers.py +69 -0
- data_designer/engine/column_generators/generators/seed_dataset.py +144 -0
- data_designer/engine/column_generators/generators/validation.py +140 -0
- data_designer/engine/column_generators/registry.py +60 -0
- data_designer/engine/column_generators/utils/errors.py +15 -0
- data_designer/engine/column_generators/utils/generator_classification.py +43 -0
- data_designer/engine/column_generators/utils/judge_score_factory.py +58 -0
- data_designer/engine/column_generators/utils/prompt_renderer.py +100 -0
- data_designer/engine/compiler.py +97 -0
- data_designer/engine/configurable_task.py +71 -0
- data_designer/engine/dataset_builders/artifact_storage.py +283 -0
- data_designer/engine/dataset_builders/column_wise_builder.py +354 -0
- data_designer/engine/dataset_builders/errors.py +15 -0
- data_designer/engine/dataset_builders/multi_column_configs.py +46 -0
- data_designer/engine/dataset_builders/utils/__init__.py +2 -0
- data_designer/engine/dataset_builders/utils/concurrency.py +212 -0
- data_designer/engine/dataset_builders/utils/config_compiler.py +62 -0
- data_designer/engine/dataset_builders/utils/dag.py +62 -0
- data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +200 -0
- data_designer/engine/dataset_builders/utils/errors.py +15 -0
- data_designer/engine/dataset_builders/utils/progress_tracker.py +122 -0
- data_designer/engine/errors.py +51 -0
- data_designer/engine/model_provider.py +77 -0
- data_designer/engine/models/__init__.py +2 -0
- data_designer/engine/models/errors.py +300 -0
- data_designer/engine/models/facade.py +284 -0
- data_designer/engine/models/factory.py +42 -0
- data_designer/engine/models/litellm_overrides.py +179 -0
- data_designer/engine/models/parsers/__init__.py +2 -0
- data_designer/engine/models/parsers/errors.py +34 -0
- data_designer/engine/models/parsers/parser.py +235 -0
- data_designer/engine/models/parsers/postprocessors.py +93 -0
- data_designer/engine/models/parsers/tag_parsers.py +62 -0
- data_designer/engine/models/parsers/types.py +84 -0
- data_designer/engine/models/recipes/base.py +81 -0
- data_designer/engine/models/recipes/response_recipes.py +293 -0
- data_designer/engine/models/registry.py +151 -0
- data_designer/engine/models/telemetry.py +362 -0
- data_designer/engine/models/usage.py +73 -0
- data_designer/engine/models/utils.py +101 -0
- data_designer/engine/processing/ginja/__init__.py +2 -0
- data_designer/engine/processing/ginja/ast.py +65 -0
- data_designer/engine/processing/ginja/environment.py +463 -0
- data_designer/engine/processing/ginja/exceptions.py +56 -0
- data_designer/engine/processing/ginja/record.py +32 -0
- data_designer/engine/processing/gsonschema/__init__.py +2 -0
- data_designer/engine/processing/gsonschema/exceptions.py +15 -0
- data_designer/engine/processing/gsonschema/schema_transformers.py +83 -0
- data_designer/engine/processing/gsonschema/types.py +10 -0
- data_designer/engine/processing/gsonschema/validators.py +202 -0
- data_designer/engine/processing/processors/base.py +13 -0
- data_designer/engine/processing/processors/drop_columns.py +42 -0
- data_designer/engine/processing/processors/registry.py +25 -0
- data_designer/engine/processing/processors/schema_transform.py +71 -0
- data_designer/engine/processing/utils.py +169 -0
- data_designer/engine/registry/base.py +99 -0
- data_designer/engine/registry/data_designer_registry.py +39 -0
- data_designer/engine/registry/errors.py +12 -0
- data_designer/engine/resources/managed_dataset_generator.py +39 -0
- data_designer/engine/resources/managed_dataset_repository.py +197 -0
- data_designer/engine/resources/managed_storage.py +65 -0
- data_designer/engine/resources/resource_provider.py +77 -0
- data_designer/engine/resources/seed_reader.py +154 -0
- data_designer/engine/sampling_gen/column.py +91 -0
- data_designer/engine/sampling_gen/constraints.py +100 -0
- data_designer/engine/sampling_gen/data_sources/base.py +217 -0
- data_designer/engine/sampling_gen/data_sources/errors.py +12 -0
- data_designer/engine/sampling_gen/data_sources/sources.py +347 -0
- data_designer/engine/sampling_gen/entities/__init__.py +2 -0
- data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
- data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +90 -0
- data_designer/engine/sampling_gen/entities/email_address_utils.py +171 -0
- data_designer/engine/sampling_gen/entities/errors.py +10 -0
- data_designer/engine/sampling_gen/entities/national_id_utils.py +102 -0
- data_designer/engine/sampling_gen/entities/person.py +144 -0
- data_designer/engine/sampling_gen/entities/phone_number.py +128 -0
- data_designer/engine/sampling_gen/errors.py +26 -0
- data_designer/engine/sampling_gen/generator.py +122 -0
- data_designer/engine/sampling_gen/jinja_utils.py +64 -0
- data_designer/engine/sampling_gen/people_gen.py +199 -0
- data_designer/engine/sampling_gen/person_constants.py +56 -0
- data_designer/engine/sampling_gen/schema.py +147 -0
- data_designer/engine/sampling_gen/schema_builder.py +61 -0
- data_designer/engine/sampling_gen/utils.py +46 -0
- data_designer/engine/secret_resolver.py +82 -0
- data_designer/engine/testing/__init__.py +12 -0
- data_designer/engine/testing/stubs.py +133 -0
- data_designer/engine/testing/utils.py +20 -0
- data_designer/engine/validation.py +367 -0
- data_designer/engine/validators/__init__.py +19 -0
- data_designer/engine/validators/base.py +38 -0
- data_designer/engine/validators/local_callable.py +39 -0
- data_designer/engine/validators/python.py +254 -0
- data_designer/engine/validators/remote.py +89 -0
- data_designer/engine/validators/sql.py +65 -0
- data_designer_engine-0.4.0.dist-info/METADATA +50 -0
- data_designer_engine-0.4.0.dist-info/RECORD +114 -0
- data_designer_engine-0.4.0.dist-info/WHEEL +4 -0
data_designer/engine/dataset_builders/utils/config_compiler.py
@@ -0,0 +1,62 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from data_designer.config.column_types import DataDesignerColumnType
+from data_designer.config.data_designer_config import DataDesignerConfig
+from data_designer.config.processors import ProcessorConfig
+from data_designer.engine.dataset_builders.multi_column_configs import (
+    DatasetBuilderColumnConfigT,
+    SamplerMultiColumnConfig,
+    SeedDatasetMultiColumnConfig,
+)
+from data_designer.engine.dataset_builders.utils.dag import topologically_sort_column_configs
+from data_designer.engine.dataset_builders.utils.errors import ConfigCompilationError
+
+
+def compile_dataset_builder_column_configs(config: DataDesignerConfig) -> list[DatasetBuilderColumnConfigT]:
+    seed_column_configs = []
+    sampler_column_configs = []
+    generated_column_configs = []
+
+    for column_config in topologically_sort_column_configs(config.columns):
+        if column_config.column_type == DataDesignerColumnType.SEED_DATASET:
+            seed_column_configs.append(column_config)
+        elif column_config.column_type == DataDesignerColumnType.SAMPLER:
+            sampler_column_configs.append(column_config)
+        else:
+            generated_column_configs.append(column_config)
+
+    compiled_column_configs = []
+
+    if len(seed_column_configs) > 0:
+        if config.seed_config is None:
+            raise ConfigCompilationError("🛑 Seed column configs require a seed configuration.")
+        compiled_column_configs.append(
+            SeedDatasetMultiColumnConfig(
+                columns=seed_column_configs,
+                source=config.seed_config.source,
+                sampling_strategy=config.seed_config.sampling_strategy,
+                selection_strategy=config.seed_config.selection_strategy,
+            )
+        )
+
+    if len(sampler_column_configs) > 0:
+        compiled_column_configs.append(
+            SamplerMultiColumnConfig(
+                columns=sampler_column_configs,
+                constraints=config.constraints or [],
+            )
+        )
+
+    if len(generated_column_configs) > 0:
+        compiled_column_configs.extend(generated_column_configs)
+
+    return compiled_column_configs
+
+
+def compile_dataset_builder_processor_configs(
+    config: DataDesignerConfig,
+) -> list[ProcessorConfig]:
+    return config.processors or []
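For orientation, here is a minimal, self-contained sketch of the grouping rule implemented above. The (name, column_type) pairs are hypothetical stand-ins, not the real column config classes:

    # Hypothetical column records standing in for real column configs.
    columns = [
        ("city", "seed_dataset"),
        ("age", "sampler"),
        ("bio", "llm_text"),
        ("summary", "llm_text"),
    ]

    seed = [c for c in columns if c[1] == "seed_dataset"]
    samplers = [c for c in columns if c[1] == "sampler"]
    generated = [c for c in columns if c[1] not in ("seed_dataset", "sampler")]

    # Mirrors the compiled output: one multi-column config per group, then the
    # generated columns appended individually, preserving topological order.
    compiled = [("seed_multi", seed), ("sampler_multi", samplers), *generated]
    print(compiled)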
data_designer/engine/dataset_builders/utils/dag.py
@@ -0,0 +1,62 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING
+
+from data_designer.config.column_types import ColumnConfigT
+from data_designer.engine.column_generators.utils.generator_classification import column_type_used_in_execution_dag
+from data_designer.engine.dataset_builders.utils.errors import DAGCircularDependencyError
+from data_designer.lazy_heavy_imports import nx
+
+if TYPE_CHECKING:
+    import networkx as nx
+
+logger = logging.getLogger(__name__)
+
+
+def topologically_sort_column_configs(column_configs: list[ColumnConfigT]) -> list[ColumnConfigT]:
+    dag = nx.DiGraph()
+
+    non_dag_column_config_list = [
+        col for col in column_configs if not column_type_used_in_execution_dag(col.column_type)
+    ]
+    dag_column_config_dict = {
+        col.name: col for col in column_configs if column_type_used_in_execution_dag(col.column_type)
+    }
+
+    if len(dag_column_config_dict) == 0:
+        return non_dag_column_config_list
+
+    side_effect_dict = {n: list(c.side_effect_columns) for n, c in dag_column_config_dict.items()}
+
+    logger.info("⛓️ Sorting column configs into a Directed Acyclic Graph")
+    for name, col in dag_column_config_dict.items():
+        dag.add_node(name)
+        for req_col_name in col.required_columns:
+            if req_col_name in list(dag_column_config_dict.keys()):
+                logger.debug(f" |-- 🔗 `{name}` depends on `{req_col_name}`")
+                dag.add_edge(req_col_name, name)
+
+            # If the required column is a side effect of another column,
+            # add an edge from the parent column to the current column.
+            elif req_col_name in sum(side_effect_dict.values(), []):
+                for parent, cols in side_effect_dict.items():
+                    if req_col_name in cols:
+                        logger.debug(f" |-- 🔗 `{name}` depends on `{parent}` via `{req_col_name}`")
+                        dag.add_edge(parent, name)
+                        break
+
+    if not nx.is_directed_acyclic_graph(dag):
+        raise DAGCircularDependencyError(
+            "🛑 The Data Designer column configurations contain cyclic dependencies. Please "
+            "inspect the column configurations and ensure they can be sorted without "
+            "circular references."
+        )
+
+    sorted_columns = non_dag_column_config_list
+    sorted_columns.extend([dag_column_config_dict[n] for n in list(nx.topological_sort(dag))])
+
+    return sorted_columns
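To make the side-effect edge rule concrete, here is a minimal, self-contained sketch using networkx directly. The column names are hypothetical: `score` requires `reasoning`, which is a side-effect column produced by `judgement`, so the edge runs from the parent column to the dependent one and the parent sorts first:

    import networkx as nx

    side_effects = {"judgement": ["reasoning"]}
    required = {"score": ["reasoning"], "judgement": []}

    dag = nx.DiGraph()
    for name, reqs in required.items():
        dag.add_node(name)
        for req in reqs:
            for parent, cols in side_effects.items():
                if req in cols:
                    dag.add_edge(parent, name)  # parent must be generated first

    print(list(nx.topological_sort(dag)))  # ['judgement', 'score']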
data_designer/engine/dataset_builders/utils/dataset_batch_manager.py
@@ -0,0 +1,200 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import logging
+import shutil
+from pathlib import Path
+from typing import TYPE_CHECKING, Callable, Container, Iterator
+
+from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage, BatchStage
+from data_designer.engine.dataset_builders.utils.errors import DatasetBatchManagementError
+from data_designer.lazy_heavy_imports import pd, pq
+
+if TYPE_CHECKING:
+    import pandas as pd
+    import pyarrow.parquet as pq
+
+logger = logging.getLogger(__name__)
+
+
+class DatasetBatchManager:
+    def __init__(self, artifact_storage: ArtifactStorage):
+        self._buffer: list[dict] = []
+        self._current_batch_number = 0
+        self._num_records_list: list[int] | None = None
+        self._buffer_size: int | None = None
+        self.artifact_storage = artifact_storage
+
+    @property
+    def num_batches(self) -> int:
+        if self._num_records_list is None:
+            return 0
+        return len(self._num_records_list)
+
+    @property
+    def num_records_batch(self) -> int:
+        if self._num_records_list is None or self._current_batch_number >= len(self._num_records_list):
+            raise DatasetBatchManagementError("🛑 Invalid batch number or num_records_list not set.")
+        return self._num_records_list[self._current_batch_number]
+
+    @property
+    def num_records_list(self) -> list[int]:
+        if self._num_records_list is None:
+            raise DatasetBatchManagementError("🛑 `num_records_list` is not set. Call start() first.")
+        return self._num_records_list
+
+    @property
+    def num_records_in_buffer(self) -> int:
+        return len(self._buffer)
+
+    @property
+    def buffer_is_empty(self) -> bool:
+        return len(self._buffer) == 0
+
+    @property
+    def buffer_size(self) -> int:
+        if self._buffer_size is None:
+            raise DatasetBatchManagementError("🛑 `buffer_size` is not set. Call start() first.")
+        return self._buffer_size
+
+    def add_record(self, record: dict) -> None:
+        self.add_records([record])
+
+    def add_records(self, records: list[dict]) -> None:
+        self._buffer.extend(records)
+        if len(self._buffer) > self.buffer_size:
+            raise DatasetBatchManagementError(
+                f"🛑 Buffer size exceeded. Current: {len(self._buffer)}, Max: {self.buffer_size}. "
+                "Flush the batch before adding more records."
+            )
+
+    def drop_records(self, index: Container[int]) -> None:
+        self._buffer = [record for i, record in enumerate(self._buffer) if i not in index]
+
+    def finish_batch(self, on_complete: Callable[[Path], None] | None = None) -> Path | None:
+        """Finish the batch by moving the results from the partial results path to the final parquet folder.
+
+        Returns:
+            The path to the written parquet file.
+        """
+        if self._current_batch_number >= self.num_batches:
+            raise DatasetBatchManagementError("🛑 All batches have been processed.")
+
+        if self.write() is not None:
+            final_file_path = self.artifact_storage.move_partial_result_to_final_file_path(self._current_batch_number)
+
+            self.artifact_storage.write_metadata(
+                {
+                    "target_num_records": sum(self.num_records_list),
+                    "total_num_batches": self.num_batches,
+                    "buffer_size": self._buffer_size,
+                    "schema": {field.name: str(field.type) for field in pq.read_schema(final_file_path)},
+                    "file_paths": self.artifact_storage.get_file_paths(),
+                    "num_completed_batches": self._current_batch_number + 1,
+                    "dataset_name": self.artifact_storage.dataset_name,
+                }
+            )
+
+            if on_complete:
+                on_complete(final_file_path)
+        else:
+            final_file_path = None
+
+            logger.warning(
+                f"⚠️ Batch {self._current_batch_number + 1} finished without any results to write. "
+                "A partial dataset containing the currently available columns has been written to the partial results "
+                f"directory: {self.artifact_storage.partial_results_path}"
+            )
+
+        self._current_batch_number += 1
+        self._buffer: list[dict] = []
+
+        return final_file_path
+
+    def finish(self) -> None:
+        """Finish the dataset writing process by deleting the partial results path if it exists and is empty."""
+
+        # If the partial results path is empty, delete it.
+        if not any(self.artifact_storage.partial_results_path.iterdir()):
+            self.artifact_storage.partial_results_path.rmdir()
+
+        # Otherwise, log a warning, since existing partial results means the dataset is not complete.
+        else:
+            logger.warning("⚠️ Dataset writing finished with partial results.")
+
+        self.reset()
+
+    def get_current_batch_number(self) -> int:
+        return self._current_batch_number
+
+    def get_current_batch(self, *, as_dataframe: bool = False) -> pd.DataFrame | list[dict]:
+        if as_dataframe:
+            return pd.DataFrame(self._buffer)
+        return self._buffer
+
+    def iter_current_batch(self) -> Iterator[tuple[int, dict]]:
+        for i, record in enumerate(self._buffer):
+            yield i, record
+
+    def reset(self, delete_files: bool = False) -> None:
+        self._current_batch_number = 0
+        self._buffer: list[dict] = []
+        if delete_files:
+            for dir_path in [
+                self.artifact_storage.final_dataset_path,
+                self.artifact_storage.partial_results_path,
+                self.artifact_storage.dropped_columns_dataset_path,
+                self.artifact_storage.base_dataset_path,
+            ]:
+                if dir_path.exists():
+                    try:
+                        shutil.rmtree(dir_path)
+                    except OSError as e:
+                        raise DatasetBatchManagementError(f"🛑 Failed to delete directory {dir_path}: {e}")
+
+    def start(self, *, num_records: int, buffer_size: int) -> None:
+        if num_records <= 0:
+            raise DatasetBatchManagementError("🛑 num_records must be positive.")
+        if buffer_size <= 0:
+            raise DatasetBatchManagementError("🛑 buffer_size must be positive.")
+
+        self._buffer_size = buffer_size
+        self._num_records_list = [buffer_size] * (num_records // buffer_size)
+        if remaining_records := num_records % buffer_size:
+            self._num_records_list.append(remaining_records)
+        self.reset()
+
+    def write(self) -> Path | None:
+        """Write the current batch to a parquet file.
+
+        This method always writes results to the partial results path.
+
+        Returns:
+            The path to the written parquet file. If the buffer is empty, returns None.
+        """
+        if len(self._buffer) == 0:
+            return None
+        try:
+            file_path = self.artifact_storage.write_batch_to_parquet_file(
+                batch_number=self._current_batch_number,
+                dataframe=pd.DataFrame(self._buffer),
+                batch_stage=BatchStage.PARTIAL_RESULT,
+            )
+            return file_path
+        except Exception as e:
+            raise DatasetBatchManagementError(f"🛑 Failed to write batch {self._current_batch_number}: {e}")
+
+    def update_record(self, index: int, record: dict) -> None:
+        if index < 0 or index >= len(self._buffer):
+            raise IndexError(f"🛑 Index {index} is out of bounds for buffer of size {len(self._buffer)}.")
+        self._buffer[index] = record
+
+    def update_records(self, records: list[dict]) -> None:
+        if len(records) != len(self._buffer):
+            raise DatasetBatchManagementError(
+                f"🛑 Number of records to update ({len(records)}) must match "
+                f"the number of records in the buffer ({len(self._buffer)})."
+            )
+        self._buffer = records
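As a quick check of the batch plan computed by start(): the partitioning is plain integer division plus a remainder batch. A standalone sketch of the arithmetic (the numbers are arbitrary):

    num_records, buffer_size = 25, 10
    num_records_list = [buffer_size] * (num_records // buffer_size)
    if remaining_records := num_records % buffer_size:
        num_records_list.append(remaining_records)
    print(num_records_list)  # [10, 10, 5] -> three batches, each at most buffer_size records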
data_designer/engine/dataset_builders/utils/errors.py
@@ -0,0 +1,15 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from data_designer.engine.errors import DataDesignerError
+
+
+class DatasetBatchManagementError(DataDesignerError): ...
+
+
+class ConfigCompilationError(DataDesignerError): ...
+
+
+class DAGCircularDependencyError(DataDesignerError): ...
data_designer/engine/dataset_builders/utils/progress_tracker.py
@@ -0,0 +1,122 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import logging
+import time
+from threading import Lock
+
+from data_designer.logging import RandomEmoji
+
+logger = logging.getLogger(__name__)
+
+
+class ProgressTracker:
+    """
+    Thread-safe progress tracker for monitoring concurrent task completion.
+
+    Tracks completed, successful, and failed task counts and logs progress
+    at configurable intervals. Designed for use with ConcurrentThreadExecutor
+    to provide visibility into long-running batch operations.
+
+    Example usage:
+        tracker = ProgressTracker(total_records=100, label="LLM_TEXT column 'response'")
+        tracker.log_start(max_workers=8)
+
+        # In callbacks from ConcurrentThreadExecutor:
+        tracker.record_success()  # or tracker.record_failure()
+
+        # After executor completes:
+        tracker.log_final()
+    """
+
+    def __init__(self, total_records: int, label: str, log_interval_percent: int = 10):
+        """
+        Initialize the progress tracker.
+
+        Args:
+            total_records: Total number of records to process.
+            label: Human-readable label for log messages (e.g., "LLM_TEXT column 'response'").
+            log_interval_percent: How often to log progress as a percentage (default 10%).
+        """
+        self.total_records = total_records
+        self.label = label
+
+        self.completed = 0
+        self.success = 0
+        self.failed = 0
+
+        interval_fraction = max(1, log_interval_percent) / 100.0
+        self.log_interval = max(1, int(total_records * interval_fraction)) if total_records > 0 else 1
+        self.next_log_at = self.log_interval
+
+        self.start_time = time.perf_counter()
+        self.lock = Lock()
+        self._random_emoji = RandomEmoji()
+
+    def log_start(self, max_workers: int) -> None:
+        """Log the start of processing with worker count and interval information."""
+        logger.info(
+            "🐙 Processing %s with %d concurrent workers",
+            self.label,
+            max_workers,
+        )
+        logger.info(
+            "🧭 %s will report progress every %d record(s).",
+            self.label,
+            self.log_interval,
+        )
+
+    def record_success(self) -> None:
+        """Record a successful task completion and log progress if at interval."""
+        self._record_completion(success=True)
+
+    def record_failure(self) -> None:
+        """Record a failed task completion and log progress if at interval."""
+        self._record_completion(success=False)
+
+    def log_final(self) -> None:
+        """Log final progress summary."""
+        with self.lock:
+            if self.completed > 0:
+                self._log_progress_unlocked()
+
+    def _record_completion(self, *, success: bool) -> None:
+        should_log = False
+        with self.lock:
+            self.completed += 1
+            if success:
+                self.success += 1
+            else:
+                self.failed += 1
+
+            if self.completed >= self.next_log_at and self.completed < self.total_records:
+                should_log = True
+                while self.next_log_at <= self.completed:
+                    self.next_log_at += self.log_interval
+
+        if should_log:
+            with self.lock:
+                self._log_progress_unlocked()
+
+    def _log_progress_unlocked(self) -> None:
+        """Log current progress. Must be called while holding the lock."""
+        elapsed = time.perf_counter() - self.start_time
+        rate = self.completed / elapsed if elapsed > 0 else 0.0
+        remaining = max(0, self.total_records - self.completed)
+        eta = f"{(remaining / rate):.1f}s" if rate > 0 else "unknown"
+        percent = (self.completed / self.total_records) * 100 if self.total_records else 100.0
+
+        logger.info(
+            " |-- %s %s progress: %d/%d (%.0f%%) complete, %d ok, %d failed, %.2f rec/s, eta %s",
+            self._random_emoji.progress(percent),
+            self.label,
+            self.completed,
+            self.total_records,
+            percent,
+            self.success,
+            self.failed,
+            rate,
+            eta,
+        )
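The logging cadence follows from the interval arithmetic in __init__ and _record_completion. A standalone sketch of which completions trigger a progress line (the totals are arbitrary; the final record is covered by log_final() rather than the interval check):

    total_records, log_interval_percent = 100, 10
    log_interval = max(1, int(total_records * max(1, log_interval_percent) / 100.0))
    next_log_at = log_interval

    logged_at = []
    for completed in range(1, total_records + 1):
        if completed >= next_log_at and completed < total_records:
            logged_at.append(completed)
            while next_log_at <= completed:
                next_log_at += log_interval
    print(logged_at)  # [10, 20, 30, 40, 50, 60, 70, 80, 90]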
data_designer/engine/errors.py
@@ -0,0 +1,51 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from pydantic import BaseModel, Field
+
+from data_designer.errors import DataDesignerError
+
+
+class DataDesignerRuntimeError(DataDesignerError): ...
+
+
+class UnknownModelAliasError(DataDesignerError): ...
+
+
+class UnknownProviderError(DataDesignerError): ...
+
+
+class NoModelProvidersError(DataDesignerError): ...
+
+
+class SecretResolutionError(DataDesignerError): ...
+
+
+class RemoteValidationSchemaError(DataDesignerError): ...
+
+
+class LocalCallableValidationError(DataDesignerError): ...
+
+
+class ErrorTrap(BaseModel):
+    error_count: int = 0
+    task_errors: dict[str, int] = Field(default_factory=dict)
+
+    def _track_error(self, error: DataDesignerError) -> None:
+        """
+        Track a specific error type.
+        """
+        error_type = type(error).__name__
+        if error_type not in self.task_errors:
+            self.task_errors[error_type] = 0
+        self.task_errors[error_type] += 1
+
+    def handle_error(self, error: Exception) -> None:
+        self.error_count += 1
+
+        if not isinstance(error, DataDesignerError):
+            error = DataDesignerError(str(error))
+
+        self._track_error(error)
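ErrorTrap tallies failures by exception class name, wrapping anything that is not already a DataDesignerError before counting. A self-contained sketch of that behavior, using a local stand-in exception so the snippet runs on its own:

    class DataDesignerError(Exception): ...  # stand-in for data_designer.errors.DataDesignerError

    task_errors: dict[str, int] = {}
    for err in (ValueError("bad input"), DataDesignerError("boom")):
        if not isinstance(err, DataDesignerError):
            err = DataDesignerError(str(err))  # non-package errors are wrapped first
        task_errors[type(err).__name__] = task_errors.get(type(err).__name__, 0) + 1

    print(task_errors)  # {'DataDesignerError': 2} -- both land under the wrapping type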
data_designer/engine/model_provider.py
@@ -0,0 +1,77 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from functools import cached_property
+
+from pydantic import BaseModel, field_validator, model_validator
+from typing_extensions import Self
+
+from data_designer.config.models import ModelProvider
+from data_designer.engine.errors import NoModelProvidersError, UnknownProviderError
+
+
+class ModelProviderRegistry(BaseModel):
+    providers: list[ModelProvider]
+    default: str | None = None
+
+    @field_validator("providers", mode="after")
+    @classmethod
+    def validate_providers_not_empty(cls, v: list[ModelProvider]) -> list[ModelProvider]:
+        if len(v) == 0:
+            raise ValueError("At least one model provider must be defined")
+        return v
+
+    @field_validator("providers", mode="after")
+    @classmethod
+    def validate_providers_have_unique_names(cls, v: list[ModelProvider]) -> list[ModelProvider]:
+        names = set()
+        dupes = set()
+        for provider in v:
+            if provider.name in names:
+                dupes.add(provider.name)
+            names.add(provider.name)
+
+        if len(dupes) > 0:
+            raise ValueError(f"Model providers must have unique names, found duplicates: {dupes}")
+        return v
+
+    @model_validator(mode="after")
+    def check_implicit_default(self) -> Self:
+        if self.default is None and len(self.providers) != 1:
+            raise ValueError("A default provider must be specified if multiple model providers are defined")
+        return self
+
+    @model_validator(mode="after")
+    def check_default_exists(self) -> Self:
+        if self.default and self.default not in self._providers_dict:
+            raise ValueError(f"Specified default {self.default!r} not found in providers list")
+        return self
+
+    def get_default_provider_name(self) -> str:
+        return self.default or self.providers[0].name
+
+    @cached_property
+    def _providers_dict(self) -> dict[str, ModelProvider]:
+        return {p.name: p for p in self.providers}
+
+    def get_provider(self, name: str | None) -> ModelProvider:
+        if name is None:
+            name = self.get_default_provider_name()
+
+        try:
+            return self._providers_dict[name]
+        except KeyError:
+            raise UnknownProviderError(f"No provider named {name!r} registered")
+
+
+def resolve_model_provider_registry(
+    model_providers: list[ModelProvider], default_provider_name: str | None = None
+) -> ModelProviderRegistry:
+    if len(model_providers) == 0:
+        raise NoModelProvidersError("At least one model provider must be defined")
+    return ModelProviderRegistry(
+        providers=model_providers,
+        default=default_provider_name or model_providers[0].name,
+    )
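Default-provider resolution reduces to: an explicit default wins, otherwise the first provider's name is used. A minimal sketch with a hypothetical provider type (the real ModelProvider comes from data_designer.config.models):

    from pydantic import BaseModel

    class Provider(BaseModel):  # hypothetical stand-in for the real ModelProvider
        name: str

    providers = {p.name: p for p in [Provider(name="nim"), Provider(name="openai")]}
    default = "nim"  # resolve_model_provider_registry falls back to the first provider's name

    def get_provider(name: str | None) -> Provider:
        return providers[name if name is not None else default]

    print(get_provider(None).name)      # "nim"
    print(get_provider("openai").name)  # "openai"

Note that resolve_model_provider_registry always supplies a concrete default, so the check_implicit_default validator can only fire when ModelProviderRegistry is constructed directly with multiple providers and no default.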