data_designer-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/__init__.py +15 -0
- data_designer/_version.py +34 -0
- data_designer/cli/README.md +236 -0
- data_designer/cli/__init__.py +6 -0
- data_designer/cli/commands/__init__.py +2 -0
- data_designer/cli/commands/list.py +130 -0
- data_designer/cli/commands/models.py +10 -0
- data_designer/cli/commands/providers.py +11 -0
- data_designer/cli/commands/reset.py +100 -0
- data_designer/cli/controllers/__init__.py +7 -0
- data_designer/cli/controllers/model_controller.py +246 -0
- data_designer/cli/controllers/provider_controller.py +317 -0
- data_designer/cli/forms/__init__.py +20 -0
- data_designer/cli/forms/builder.py +51 -0
- data_designer/cli/forms/field.py +180 -0
- data_designer/cli/forms/form.py +59 -0
- data_designer/cli/forms/model_builder.py +125 -0
- data_designer/cli/forms/provider_builder.py +76 -0
- data_designer/cli/main.py +44 -0
- data_designer/cli/repositories/__init__.py +8 -0
- data_designer/cli/repositories/base.py +39 -0
- data_designer/cli/repositories/model_repository.py +42 -0
- data_designer/cli/repositories/provider_repository.py +43 -0
- data_designer/cli/services/__init__.py +7 -0
- data_designer/cli/services/model_service.py +116 -0
- data_designer/cli/services/provider_service.py +111 -0
- data_designer/cli/ui.py +448 -0
- data_designer/cli/utils.py +47 -0
- data_designer/config/__init__.py +2 -0
- data_designer/config/analysis/column_profilers.py +89 -0
- data_designer/config/analysis/column_statistics.py +274 -0
- data_designer/config/analysis/dataset_profiler.py +60 -0
- data_designer/config/analysis/utils/errors.py +8 -0
- data_designer/config/analysis/utils/reporting.py +188 -0
- data_designer/config/base.py +68 -0
- data_designer/config/column_configs.py +354 -0
- data_designer/config/column_types.py +168 -0
- data_designer/config/config_builder.py +660 -0
- data_designer/config/data_designer_config.py +40 -0
- data_designer/config/dataset_builders.py +11 -0
- data_designer/config/datastore.py +151 -0
- data_designer/config/default_model_settings.py +123 -0
- data_designer/config/errors.py +19 -0
- data_designer/config/interface.py +54 -0
- data_designer/config/models.py +231 -0
- data_designer/config/preview_results.py +32 -0
- data_designer/config/processors.py +41 -0
- data_designer/config/sampler_constraints.py +51 -0
- data_designer/config/sampler_params.py +604 -0
- data_designer/config/seed.py +145 -0
- data_designer/config/utils/code_lang.py +83 -0
- data_designer/config/utils/constants.py +313 -0
- data_designer/config/utils/errors.py +19 -0
- data_designer/config/utils/info.py +88 -0
- data_designer/config/utils/io_helpers.py +273 -0
- data_designer/config/utils/misc.py +81 -0
- data_designer/config/utils/numerical_helpers.py +28 -0
- data_designer/config/utils/type_helpers.py +100 -0
- data_designer/config/utils/validation.py +336 -0
- data_designer/config/utils/visualization.py +427 -0
- data_designer/config/validator_params.py +96 -0
- data_designer/engine/__init__.py +2 -0
- data_designer/engine/analysis/column_profilers/base.py +55 -0
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +160 -0
- data_designer/engine/analysis/column_profilers/registry.py +20 -0
- data_designer/engine/analysis/column_statistics.py +142 -0
- data_designer/engine/analysis/dataset_profiler.py +125 -0
- data_designer/engine/analysis/errors.py +7 -0
- data_designer/engine/analysis/utils/column_statistics_calculations.py +209 -0
- data_designer/engine/analysis/utils/judge_score_processing.py +128 -0
- data_designer/engine/column_generators/__init__.py +2 -0
- data_designer/engine/column_generators/generators/__init__.py +2 -0
- data_designer/engine/column_generators/generators/base.py +61 -0
- data_designer/engine/column_generators/generators/expression.py +63 -0
- data_designer/engine/column_generators/generators/llm_generators.py +172 -0
- data_designer/engine/column_generators/generators/samplers.py +75 -0
- data_designer/engine/column_generators/generators/seed_dataset.py +149 -0
- data_designer/engine/column_generators/generators/validation.py +147 -0
- data_designer/engine/column_generators/registry.py +56 -0
- data_designer/engine/column_generators/utils/errors.py +13 -0
- data_designer/engine/column_generators/utils/judge_score_factory.py +57 -0
- data_designer/engine/column_generators/utils/prompt_renderer.py +98 -0
- data_designer/engine/configurable_task.py +82 -0
- data_designer/engine/dataset_builders/artifact_storage.py +181 -0
- data_designer/engine/dataset_builders/column_wise_builder.py +287 -0
- data_designer/engine/dataset_builders/errors.py +13 -0
- data_designer/engine/dataset_builders/multi_column_configs.py +44 -0
- data_designer/engine/dataset_builders/utils/__init__.py +2 -0
- data_designer/engine/dataset_builders/utils/concurrency.py +184 -0
- data_designer/engine/dataset_builders/utils/config_compiler.py +60 -0
- data_designer/engine/dataset_builders/utils/dag.py +56 -0
- data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +190 -0
- data_designer/engine/dataset_builders/utils/errors.py +13 -0
- data_designer/engine/errors.py +49 -0
- data_designer/engine/model_provider.py +75 -0
- data_designer/engine/models/__init__.py +2 -0
- data_designer/engine/models/errors.py +308 -0
- data_designer/engine/models/facade.py +225 -0
- data_designer/engine/models/litellm_overrides.py +162 -0
- data_designer/engine/models/parsers/__init__.py +2 -0
- data_designer/engine/models/parsers/errors.py +34 -0
- data_designer/engine/models/parsers/parser.py +236 -0
- data_designer/engine/models/parsers/postprocessors.py +93 -0
- data_designer/engine/models/parsers/tag_parsers.py +60 -0
- data_designer/engine/models/parsers/types.py +82 -0
- data_designer/engine/models/recipes/base.py +79 -0
- data_designer/engine/models/recipes/response_recipes.py +291 -0
- data_designer/engine/models/registry.py +118 -0
- data_designer/engine/models/usage.py +75 -0
- data_designer/engine/models/utils.py +38 -0
- data_designer/engine/processing/ginja/__init__.py +2 -0
- data_designer/engine/processing/ginja/ast.py +64 -0
- data_designer/engine/processing/ginja/environment.py +461 -0
- data_designer/engine/processing/ginja/exceptions.py +54 -0
- data_designer/engine/processing/ginja/record.py +30 -0
- data_designer/engine/processing/gsonschema/__init__.py +2 -0
- data_designer/engine/processing/gsonschema/exceptions.py +8 -0
- data_designer/engine/processing/gsonschema/schema_transformers.py +81 -0
- data_designer/engine/processing/gsonschema/types.py +8 -0
- data_designer/engine/processing/gsonschema/validators.py +143 -0
- data_designer/engine/processing/processors/base.py +15 -0
- data_designer/engine/processing/processors/drop_columns.py +46 -0
- data_designer/engine/processing/processors/registry.py +20 -0
- data_designer/engine/processing/utils.py +120 -0
- data_designer/engine/registry/base.py +97 -0
- data_designer/engine/registry/data_designer_registry.py +37 -0
- data_designer/engine/registry/errors.py +10 -0
- data_designer/engine/resources/managed_dataset_generator.py +35 -0
- data_designer/engine/resources/managed_dataset_repository.py +194 -0
- data_designer/engine/resources/managed_storage.py +63 -0
- data_designer/engine/resources/resource_provider.py +46 -0
- data_designer/engine/resources/seed_dataset_data_store.py +66 -0
- data_designer/engine/sampling_gen/column.py +89 -0
- data_designer/engine/sampling_gen/constraints.py +95 -0
- data_designer/engine/sampling_gen/data_sources/base.py +214 -0
- data_designer/engine/sampling_gen/data_sources/errors.py +10 -0
- data_designer/engine/sampling_gen/data_sources/sources.py +342 -0
- data_designer/engine/sampling_gen/entities/__init__.py +2 -0
- data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
- data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +64 -0
- data_designer/engine/sampling_gen/entities/email_address_utils.py +169 -0
- data_designer/engine/sampling_gen/entities/errors.py +8 -0
- data_designer/engine/sampling_gen/entities/national_id_utils.py +100 -0
- data_designer/engine/sampling_gen/entities/person.py +142 -0
- data_designer/engine/sampling_gen/entities/phone_number.py +122 -0
- data_designer/engine/sampling_gen/errors.py +24 -0
- data_designer/engine/sampling_gen/generator.py +121 -0
- data_designer/engine/sampling_gen/jinja_utils.py +60 -0
- data_designer/engine/sampling_gen/people_gen.py +203 -0
- data_designer/engine/sampling_gen/person_constants.py +54 -0
- data_designer/engine/sampling_gen/schema.py +143 -0
- data_designer/engine/sampling_gen/schema_builder.py +59 -0
- data_designer/engine/sampling_gen/utils.py +40 -0
- data_designer/engine/secret_resolver.py +80 -0
- data_designer/engine/validators/__init__.py +17 -0
- data_designer/engine/validators/base.py +36 -0
- data_designer/engine/validators/local_callable.py +34 -0
- data_designer/engine/validators/python.py +245 -0
- data_designer/engine/validators/remote.py +83 -0
- data_designer/engine/validators/sql.py +60 -0
- data_designer/errors.py +5 -0
- data_designer/essentials/__init__.py +137 -0
- data_designer/interface/__init__.py +2 -0
- data_designer/interface/data_designer.py +351 -0
- data_designer/interface/errors.py +16 -0
- data_designer/interface/results.py +55 -0
- data_designer/logging.py +161 -0
- data_designer/plugin_manager.py +83 -0
- data_designer/plugins/__init__.py +6 -0
- data_designer/plugins/errors.py +10 -0
- data_designer/plugins/plugin.py +69 -0
- data_designer/plugins/registry.py +86 -0
- data_designer-0.1.0.dist-info/METADATA +173 -0
- data_designer-0.1.0.dist-info/RECORD +177 -0
- data_designer-0.1.0.dist-info/WHEEL +4 -0
- data_designer-0.1.0.dist-info/entry_points.txt +2 -0
- data_designer-0.1.0.dist-info/licenses/LICENSE +201 -0
data_designer/engine/dataset_builders/column_wise_builder.py
@@ -0,0 +1,287 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import functools
import json
import logging
import time
from pathlib import Path
from typing import Callable

import pandas as pd

from data_designer.config.column_types import ColumnConfigT, column_type_is_llm_generated
from data_designer.config.dataset_builders import BuildStage
from data_designer.config.processors import (
    DropColumnsProcessorConfig,
    ProcessorConfig,
    ProcessorType,
)
from data_designer.engine.column_generators.generators.base import ColumnGenerator, GenerationStrategy
from data_designer.engine.column_generators.generators.llm_generators import WithLLMGeneration
from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
from data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError
from data_designer.engine.dataset_builders.multi_column_configs import (
    DatasetBuilderColumnConfigT,
    MultiColumnConfig,
)
from data_designer.engine.dataset_builders.utils.concurrency import (
    MAX_CONCURRENCY_PER_NON_LLM_GENERATOR,
    ConcurrentThreadExecutor,
)
from data_designer.engine.dataset_builders.utils.dataset_batch_manager import (
    DatasetBatchManager,
)
from data_designer.engine.processing.processors.base import Processor
from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor
from data_designer.engine.registry.data_designer_registry import DataDesignerRegistry
from data_designer.engine.resources.resource_provider import ResourceProvider

logger = logging.getLogger(__name__)


class ColumnWiseDatasetBuilder:
    def __init__(
        self,
        column_configs: list[DatasetBuilderColumnConfigT],
        processor_configs: list[ProcessorConfig],
        resource_provider: ResourceProvider,
        registry: DataDesignerRegistry | None = None,
    ):
        self.batch_manager = DatasetBatchManager(resource_provider.artifact_storage)
        self._resource_provider = resource_provider
        self._records_to_drop: set[int] = set()
        self._registry = registry or DataDesignerRegistry()
        self._column_configs = column_configs
        self._processors: dict[BuildStage, list[Processor]] = self._initialize_processors(processor_configs)
        self._validate_column_configs()

    @property
    def artifact_storage(self) -> ArtifactStorage:
        return self._resource_provider.artifact_storage

    @functools.cached_property
    def single_column_configs(self) -> list[ColumnConfigT]:
        configs = []
        for config in self._column_configs:
            if isinstance(config, MultiColumnConfig):
                configs.extend(config.columns)
            else:
                configs.append(config)
        return configs

    @functools.cached_property
    def llm_generated_column_configs(self) -> list[ColumnConfigT]:
        return [config for config in self.single_column_configs if column_type_is_llm_generated(config.column_type)]

    def build(
        self,
        *,
        num_records: int,
        buffer_size: int,
        on_batch_complete: Callable[[Path], None] | None = None,
    ) -> Path:
        self._write_configs()
        self._run_model_health_check_if_needed()

        generators = self._initialize_generators()
        start_time = time.perf_counter()

        self.batch_manager.start(num_records=num_records, buffer_size=buffer_size)
        for batch_idx in range(1, self.batch_manager.num_batches + 1):
            logger.info(f"⏳ Processing batch {batch_idx} of {self.batch_manager.num_batches}")
            self._run_batch(generators)
            df_batch = self._run_processors(
                stage=BuildStage.POST_BATCH,
                dataframe=self.batch_manager.get_current_batch(as_dataframe=True),
                current_batch_number=batch_idx,
            )
            self._write_processed_batch(df_batch)
            self.batch_manager.finish_batch(on_batch_complete)
        self.batch_manager.finish()

        model_usage_stats = self._resource_provider.model_registry.get_model_usage_stats(
            time.perf_counter() - start_time
        )
        logger.info(f"📊 Model usage summary:\n{json.dumps(model_usage_stats, indent=4)}")

        return self.artifact_storage.final_dataset_path

    def build_preview(self, *, num_records: int) -> pd.DataFrame:
        self._run_model_health_check_if_needed()

        generators = self._initialize_generators()

        start_time = time.perf_counter()
        self.batch_manager.start(num_records=num_records, buffer_size=num_records)
        self._run_batch(generators, save_partial_results=False)
        dataset = self.batch_manager.get_current_batch(as_dataframe=True)
        self.batch_manager.reset()

        model_usage_stats = self._resource_provider.model_registry.get_model_usage_stats(
            time.perf_counter() - start_time
        )
        logger.info(f"📊 Model usage summary:\n{json.dumps(model_usage_stats, indent=4)}")

        return dataset

    def process_preview(self, dataset: pd.DataFrame) -> pd.DataFrame:
        return self._run_processors(
            stage=BuildStage.POST_BATCH,
            dataframe=dataset.copy(),
            current_batch_number=None,  # preview mode does not have a batch number
        )

    def _initialize_generators(self) -> list[ColumnGenerator]:
        return [
            self._registry.column_generators.get_for_config_type(type(config))(
                config=config, resource_provider=self._resource_provider
            )
            for config in self._column_configs
        ]

    def _run_batch(self, generators: list[ColumnGenerator], *, save_partial_results: bool = True) -> None:
        for generator in generators:
            generator.log_pre_generation()
            try:
                if generator.can_generate_from_scratch and self.batch_manager.buffer_is_empty:
                    self._run_from_scratch_column_generator(generator)
                elif generator.generation_strategy == GenerationStrategy.CELL_BY_CELL:
                    self._run_cell_by_cell_generator(generator)
                elif generator.generation_strategy == GenerationStrategy.FULL_COLUMN:
                    self._run_full_column_generator(generator)
                else:
                    logger.error(f"❌ Unknown generation strategy: {generator.generation_strategy}")
                    raise DatasetGenerationError(f"🛑 Unknown generation strategy: {generator.generation_strategy}")
                if save_partial_results:
                    self.batch_manager.write()
            except Exception as e:
                column_error_str = (
                    f"columns {generator.config.column_names}"
                    if hasattr(generator.config, "column_names")
                    else f"column {generator.config.name!r}"
                )
                raise DatasetGenerationError(f"🛑 Failed to process {column_error_str}:\n{e}") from e

    def _run_from_scratch_column_generator(self, generator: ColumnGenerator) -> None:
        df = generator.generate_from_scratch(self.batch_manager.num_records_batch)
        self.batch_manager.add_records(df.to_dict(orient="records"))

    def _run_cell_by_cell_generator(self, generator: ColumnGenerator) -> None:
        max_workers = MAX_CONCURRENCY_PER_NON_LLM_GENERATOR
        if isinstance(generator, WithLLMGeneration):
            max_workers = generator.inference_parameters.max_parallel_requests
        self._fan_out_with_threads(generator, max_workers=max_workers)

    def _run_full_column_generator(self, generator: ColumnGenerator) -> None:
        df = generator.generate(self.batch_manager.get_current_batch(as_dataframe=True))
        self.batch_manager.update_records(df.to_dict(orient="records"))

    def _run_model_health_check_if_needed(self) -> None:
        if any(column_type_is_llm_generated(config.column_type) for config in self.single_column_configs):
            self._resource_provider.model_registry.run_health_check(
                set(config.model_alias for config in self.llm_generated_column_configs)
            )

    def _fan_out_with_threads(self, generator: WithLLMGeneration, max_workers: int) -> None:
        if generator.generation_strategy != GenerationStrategy.CELL_BY_CELL:
            raise DatasetGenerationError(
                f"Generator {generator.metadata().name} is not a {GenerationStrategy.CELL_BY_CELL} "
                "generator, so concurrency through threads is not supported."
            )

        logger.info(
            f"🐙 Processing {generator.config.column_type} column '{generator.config.name}' "
            f"with {max_workers} concurrent workers"
        )
        with ConcurrentThreadExecutor(
            max_workers=max_workers,
            column_name=generator.config.name,
            result_callback=self._worker_result_callback,
            error_callback=self._worker_error_callback,
        ) as executor:
            for i, record in self.batch_manager.iter_current_batch():
                executor.submit(lambda record: generator.generate(record), record, context={"index": i})

        if len(self._records_to_drop) > 0:
            self.batch_manager.drop_records(self._records_to_drop)
            self._records_to_drop.clear()

    def _write_processed_batch(self, dataframe: pd.DataFrame) -> None:
        self.batch_manager.update_records(dataframe.to_dict(orient="records"))
        self.batch_manager.write()

    def _validate_column_configs(self) -> None:
        if len(self._column_configs) == 0:
            raise DatasetGenerationError("🛑 No column configs provided.")

        if not self._registry.column_generators.get_for_config_type(
            type(self._column_configs[0])
        ).can_generate_from_scratch:
            raise DatasetGenerationError("🛑 The first column config must be a from-scratch column generator.")

    def _initialize_processors(self, processor_configs: list[ProcessorConfig]) -> dict[BuildStage, list[Processor]]:
        # Check columns marked for drop
        columns_to_drop = [config.name for config in self.single_column_configs if config.drop]

        processors: dict[BuildStage, list[Processor]] = {stage: [] for stage in BuildStage}
        for config in processor_configs:
            processors[config.build_stage].append(
                self._registry.processors.get_for_config_type(type(config))(
                    config=config,
                    resource_provider=self._resource_provider,
                )
            )

            # A manually included "drop columns" processor takes precedence (it can, e.g., target stages other than post-batch)
            if config.processor_type == ProcessorType.DROP_COLUMNS:
                for column in config.column_names:
                    if column in columns_to_drop:
                        columns_to_drop.remove(column)

        # If there are still columns marked for drop, add a "drop columns" processor to drop them
        if len(columns_to_drop) > 0:
            processors[BuildStage.POST_BATCH].append(  # post-batch by default
                DropColumnsProcessor(
                    config=DropColumnsProcessorConfig(
                        column_names=columns_to_drop,
                        build_stage=BuildStage.POST_BATCH,
                    ),
                    resource_provider=self._resource_provider,
                )
            )

        return processors

    def _run_processors(
        self, stage: BuildStage, dataframe: pd.DataFrame, current_batch_number: int | None = None
    ) -> pd.DataFrame:
        for processor in self._processors[stage]:
            try:
                dataframe = processor.process(dataframe, current_batch_number=current_batch_number)
            except Exception as e:
                raise DatasetProcessingError(
                    f"🛑 Failed to process dataset with processor {processor.metadata().name} in stage {stage}: {e}"
                ) from e
        return dataframe

    def _worker_error_callback(self, exc: Exception, *, context: dict | None = None) -> None:
        """If a worker fails, we can handle the exception here."""
        logger.warning(
            f"⚠️ Generation for record at index {context['index']} failed. "
            f"Will omit this record from the dataset.\n{exc}"
        )
        self._records_to_drop.add(context["index"])

    def _worker_result_callback(self, result: dict, *, context: dict | None = None) -> None:
        self.batch_manager.update_record(context["index"], result)

    def _write_configs(self) -> None:
        self.artifact_storage.write_configs(
            json_file_name="column_configs.json",
            configs=self._column_configs,
        )
        self.artifact_storage.write_configs(
            json_file_name="model_configs.json",
            configs=self._resource_provider.model_registry.model_configs.values(),
        )
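
For orientation, a hedged sketch of how this builder might be driven. Only the constructor and the `build`/`build_preview` signatures are taken from the diff above; `column_configs` and `resource_provider` are hypothetical placeholders assembled elsewhere in the package.

from data_designer.engine.dataset_builders.column_wise_builder import ColumnWiseDatasetBuilder

builder = ColumnWiseDatasetBuilder(
    column_configs=column_configs,        # hypothetical list[DatasetBuilderColumnConfigT]; the first must generate from scratch
    processor_configs=[],                 # list[ProcessorConfig]
    resource_provider=resource_provider,  # hypothetical ResourceProvider (artifact storage + model registry)
)

# Preview: one in-memory batch, no partial writes.
preview_df = builder.build_preview(num_records=10)

# Full build: generates in batches of `buffer_size` records, runs post-batch
# processors, and returns the final dataset path.
final_path = builder.build(num_records=1_000, buffer_size=100)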
data_designer/engine/dataset_builders/errors.py
@@ -0,0 +1,13 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from data_designer.engine.errors import DataDesignerError


class ArtifactStorageError(DataDesignerError): ...


class DatasetGenerationError(DataDesignerError): ...


class DatasetProcessingError(DataDesignerError): ...
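
Because all three exceptions above subclass `DataDesignerError`, callers can catch the shared base class to handle any of them; a minimal sketch:

from data_designer.engine.dataset_builders.errors import DatasetGenerationError
from data_designer.engine.errors import DataDesignerError

try:
    raise DatasetGenerationError("no column configs provided")
except DataDesignerError as err:
    # Catches DatasetGenerationError, DatasetProcessingError, etc.
    print(f"caught: {err}")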
data_designer/engine/dataset_builders/multi_column_configs.py
@@ -0,0 +1,44 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from abc import ABC
from typing import TypeAlias

from pydantic import Field, field_validator

from data_designer.config.base import ConfigBase
from data_designer.config.column_configs import SamplerColumnConfig, SeedDatasetColumnConfig, SingleColumnConfig
from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType
from data_designer.config.sampler_constraints import ColumnConstraintT
from data_designer.config.seed import SeedConfig


class MultiColumnConfig(ConfigBase, ABC):
    columns: list[SingleColumnConfig] = Field(..., min_length=1)

    @property
    def column_names(self) -> list[str]:
        return [c.name for c in self.columns]

    @property
    def column_type(self) -> DataDesignerColumnType:
        return self.columns[0].column_type

    @field_validator("columns", mode="after")
    def validate_column_types_are_the_same(cls, v: list[SingleColumnConfig]) -> list[SingleColumnConfig]:
        if len({c.column_type for c in v}) != 1:
            raise ValueError("All columns must have the same column type")
        return v


class SamplerMultiColumnConfig(MultiColumnConfig):
    columns: list[SamplerColumnConfig]
    constraints: list[ColumnConstraintT] = []
    max_rejections_factor: int = 5


class SeedDatasetMultiColumnConfig(SeedConfig, MultiColumnConfig):
    columns: list[SeedDatasetColumnConfig]


DatasetBuilderColumnConfigT: TypeAlias = ColumnConfigT | SeedDatasetMultiColumnConfig | SamplerMultiColumnConfig
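
To illustrate the validator pattern above with something runnable, a self-contained sketch using stand-in models (the real `SingleColumnConfig` fields are defined elsewhere in the package):

from pydantic import BaseModel, Field, ValidationError, field_validator

class Col(BaseModel):  # stand-in for SingleColumnConfig
    name: str
    column_type: str

class Multi(BaseModel):  # stand-in for MultiColumnConfig
    columns: list[Col] = Field(..., min_length=1)

    @field_validator("columns", mode="after")
    def validate_column_types_are_the_same(cls, v: list[Col]) -> list[Col]:
        if len({c.column_type for c in v}) != 1:
            raise ValueError("All columns must have the same column type")
        return v

Multi(columns=[Col(name="a", column_type="sampler"), Col(name="b", column_type="sampler")])  # OK

try:
    Multi(columns=[Col(name="a", column_type="sampler"), Col(name="b", column_type="llm-text")])
except ValidationError as err:
    print(err)  # rejected: mixed column types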
data_designer/engine/dataset_builders/utils/concurrency.py
@@ -0,0 +1,184 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import contextvars
import json
import logging
from concurrent.futures import Future, ThreadPoolExecutor
from threading import Lock, Semaphore
from typing import Any, Optional, Protocol

from pydantic import BaseModel, Field

from data_designer.engine.errors import DataDesignerRuntimeError, ErrorTrap

logger = logging.getLogger(__name__)

# Constants
MAX_CONCURRENCY_PER_NON_LLM_GENERATOR = 4


class ExecutorResults(BaseModel):
    failure_threshold: float = 0.0  # Error rate threshold
    completed_count: int = 0  # How many tasks/jobs completed
    success_count: int = 0  # How many tasks/jobs were successful
    early_shutdown: bool = False  # Did we shut down early due to errors?
    error_trap: ErrorTrap = Field(default_factory=ErrorTrap)

    @property
    def summary(self) -> dict:
        summary = self.model_dump(exclude={"error_trap"})
        summary |= self.error_trap.model_dump()
        return summary

    def get_error_rate(self, window: int) -> float:
        # We don't actually start tracking until the minimum window size is met
        if self.completed_count < window:
            return 0.0
        return self.error_trap.error_count / max(1, self.completed_count)

    def is_error_rate_exceeded(self, window: int) -> bool:
        return self.get_error_rate(window) >= self.failure_threshold


class CallbackWithContext(Protocol):
    """Executor callback functions must accept a `context` keyword argument."""

    def __call__(self, result: Any, *, context: Optional[dict] = None) -> Any: ...


class ErrorCallbackWithContext(Protocol):
    """Error callbacks take the Exception instance and the context."""

    def __call__(self, exc: Exception, *, context: Optional[dict] = None) -> Any: ...


class ConcurrentThreadExecutor:
    """
    Interface for executing multiple concurrent tasks with error rate monitoring.

    This interface should be used exclusively as a context manager. New tasks can be
    submitted to the executor using the `submit` method, which functions similarly to
    the submit method of a ThreadPoolExecutor.

    The underlying task queue is bounded by the `max_workers` parameter, so only
    `max_workers` tasks can be queued up for execution at a time. As tasks complete,
    any errors are trapped and counted. If a certain error rate is exceeded, the
    executor shuts down early; all queued and running tasks still complete.

    The task queue is bounded to ensure that, once the error threshold is met, there
    is not an unbounded number of tasks left to complete. Generally speaking, tasks
    should not sit in the queue for long, since the queue size equals `max_workers`.
    The side effect is that `submit()` may block, but this should not matter because
    upstream tasks need to wait for all jobs to complete before they can be
    considered complete themselves.

    ContextVars from the main parent thread are automatically propagated to all
    child threads.

    When a task completes, the user-provided `result_callback` function is called
    with the task result and the submission's `context` keyword argument.
    """

    def __init__(
        self,
        *,
        max_workers: int,
        column_name: str,
        result_callback: Optional[CallbackWithContext] = None,
        error_callback: Optional[ErrorCallbackWithContext] = None,
        shutdown_error_rate: float = 0.50,
        shutdown_error_window: int = 10,
    ):
        self._executor = None
        self._column_name = column_name
        self._max_workers = max_workers
        self._lock = Lock()
        self._semaphore = Semaphore(self._max_workers)
        self._result_callback = result_callback
        self._error_callback = error_callback
        self._shutdown_error_rate = shutdown_error_rate
        self._shutdown_window_size = shutdown_error_window
        self._results = ExecutorResults(failure_threshold=shutdown_error_rate)

    def __enter__(self) -> ConcurrentThreadExecutor:
        self._executor = ThreadPoolExecutor(
            max_workers=self._max_workers,
            thread_name_prefix="ConcurrentThreadExecutor",
            initializer=_set_worker_contextvars,
            initargs=(contextvars.copy_context(),),
        )
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._shutdown_executor()
        if self._results.early_shutdown is True:
            self._raise_task_error()

    def _shutdown_executor(self) -> None:
        if self._executor is not None:
            self._executor.shutdown()

    def _raise_task_error(self):
        raise DataDesignerRuntimeError(
            "\n".join(
                [
                    " |-- Data generation was terminated early due to the error rate exceeding the threshold.",
                    f" |-- The summary of encountered errors is: \n{json.dumps(self._results.summary, indent=4)}",
                ]
            )
        )

    def submit(self, fn, *args, context: Optional[dict] = None, **kwargs) -> None:
        if self._executor is None:
            raise RuntimeError("Executor is not initialized; this class should be used as a context manager.")

        if self._results.early_shutdown:
            self._shutdown_executor()
            self._raise_task_error()

        def _handle_future(future: Future) -> None:
            self._results.completed_count += 1
            try:
                result = future.result()
                if self._result_callback is not None:
                    self._result_callback(result, context=context)
                self._results.success_count += 1
            except Exception as err:
                with self._lock:
                    self._results.error_trap.handle_error(err)
                    if self._results.is_error_rate_exceeded(self._shutdown_window_size):
                        # Signal to shut down early on the next submission (if received).
                        # We cannot trigger shutdown from within this thread, as it can
                        # cause a deadlock.
                        if not self._results.early_shutdown:
                            self._results.early_shutdown = True
                if self._error_callback is not None:
                    self._error_callback(err, context=context)
            finally:
                self._semaphore.release()

        try:
            self._semaphore.acquire()
            future = self._executor.submit(fn, *args, **kwargs)
            future.add_done_callback(_handle_future)
        except Exception as err:
            # If we get here, the pool is shutting down (likely due to early termination
            # from errors). We re-raise a custom error that can be handled at the call
            # site, where the summary can also be inspected.
            self._semaphore.release()
            if not isinstance(err, RuntimeError) and "after shutdown" not in str(err):
                raise err
            self._raise_task_error()


def _set_worker_contextvars(context: contextvars.Context):
    for var, value in context.items():
        var.set(value)
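
To make the contract above concrete, a minimal usage sketch. The constructor arguments and `submit` signature are taken from the class above; the task function and records are invented for illustration.

from data_designer.engine.dataset_builders.utils.concurrency import ConcurrentThreadExecutor

results: dict[int, str] = {}

def on_result(result, *, context=None):
    # Invoked on a worker thread whenever a submitted task succeeds.
    results[context["index"]] = result

def on_error(exc, *, context=None):
    # Invoked whenever a task raises; failures count toward the shutdown error rate.
    print(f"record {context['index']} failed: {exc}")

def generate(record: str) -> str:  # stand-in for a per-cell generation task
    return record.upper()

with ConcurrentThreadExecutor(
    max_workers=4,
    column_name="example_column",
    result_callback=on_result,
    error_callback=on_error,
) as executor:
    for i, record in enumerate(["a", "b", "c"]):
        # submit() blocks once max_workers tasks are queued, keeping the queue bounded.
        executor.submit(generate, record, context={"index": i})

print(results)  # {0: 'A', 1: 'B', 2: 'C'} (callback order may vary)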
data_designer/engine/dataset_builders/utils/config_compiler.py
@@ -0,0 +1,60 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from data_designer.config.column_types import DataDesignerColumnType
from data_designer.config.data_designer_config import DataDesignerConfig
from data_designer.config.processors import ProcessorConfig
from data_designer.engine.dataset_builders.multi_column_configs import (
    DatasetBuilderColumnConfigT,
    SamplerMultiColumnConfig,
    SeedDatasetMultiColumnConfig,
)
from data_designer.engine.dataset_builders.utils.dag import topologically_sort_column_configs
from data_designer.engine.dataset_builders.utils.errors import ConfigCompilationError


def compile_dataset_builder_column_configs(config: DataDesignerConfig) -> list[DatasetBuilderColumnConfigT]:
    seed_column_configs = []
    sampler_column_configs = []
    generated_column_configs = []

    for column_config in topologically_sort_column_configs(config.columns):
        if column_config.column_type == DataDesignerColumnType.SEED_DATASET:
            seed_column_configs.append(column_config)
        elif column_config.column_type == DataDesignerColumnType.SAMPLER:
            sampler_column_configs.append(column_config)
        else:
            generated_column_configs.append(column_config)

    compiled_column_configs = []

    if len(seed_column_configs) > 0:
        if config.seed_config is None:
            raise ConfigCompilationError("🛑 Seed column configs require a seed configuration.")
        compiled_column_configs.append(
            SeedDatasetMultiColumnConfig(
                columns=seed_column_configs,
                dataset=config.seed_config.dataset,
                sampling_strategy=config.seed_config.sampling_strategy,
                selection_strategy=config.seed_config.selection_strategy,
            )
        )

    if len(sampler_column_configs) > 0:
        compiled_column_configs.append(
            SamplerMultiColumnConfig(
                columns=sampler_column_configs,
                constraints=config.constraints or [],
            )
        )

    if len(generated_column_configs) > 0:
        compiled_column_configs.extend(generated_column_configs)

    return compiled_column_configs


def compile_dataset_builder_processor_configs(
    config: DataDesignerConfig,
) -> list[ProcessorConfig]:
    return config.processors or []
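
A standalone sketch of the grouping the compiler performs, using stand-in column records rather than the real pydantic config classes: seed and sampler columns are bundled into multi-column configs, while everything else passes through in sorted order.

columns = [
    {"name": "city", "type": "sampler"},
    {"name": "age", "type": "sampler"},
    {"name": "bio", "type": "llm_text"},
]

samplers = [c for c in columns if c["type"] == "sampler"]
generated = [c for c in columns if c["type"] not in ("seed_dataset", "sampler")]

compiled = []
if samplers:
    # Stand-in for SamplerMultiColumnConfig: samplers are generated together.
    compiled.append({"sampler_multi_column": [c["name"] for c in samplers]})
compiled.extend(generated)
print(compiled)  # [{'sampler_multi_column': ['city', 'age']}, {'name': 'bio', 'type': 'llm_text'}]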
data_designer/engine/dataset_builders/utils/dag.py
@@ -0,0 +1,56 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import logging

import networkx as nx

from data_designer.config.column_types import ColumnConfigT, column_type_used_in_execution_dag
from data_designer.engine.dataset_builders.utils.errors import DAGCircularDependencyError

logger = logging.getLogger(__name__)


def topologically_sort_column_configs(column_configs: list[ColumnConfigT]) -> list[ColumnConfigT]:
    dag = nx.DiGraph()

    non_dag_column_config_list = [
        col for col in column_configs if not column_type_used_in_execution_dag(col.column_type)
    ]
    dag_column_config_dict = {
        col.name: col for col in column_configs if column_type_used_in_execution_dag(col.column_type)
    }

    if len(dag_column_config_dict) == 0:
        return non_dag_column_config_list

    side_effect_dict = {n: list(c.side_effect_columns) for n, c in dag_column_config_dict.items()}

    logger.info("⛓️ Sorting column configs into a Directed Acyclic Graph")
    for name, col in dag_column_config_dict.items():
        dag.add_node(name)
        for req_col_name in col.required_columns:
            if req_col_name in dag_column_config_dict:
                logger.debug(f" |-- 🔗 `{name}` depends on `{req_col_name}`")
                dag.add_edge(req_col_name, name)

            # If the required column is a side effect of another column,
            # add an edge from the parent column to the current column.
            elif req_col_name in sum(side_effect_dict.values(), []):
                for parent, cols in side_effect_dict.items():
                    if req_col_name in cols:
                        logger.debug(f" |-- 🔗 `{name}` depends on `{parent}` via `{req_col_name}`")
                        dag.add_edge(parent, name)
                        break

    if not nx.is_directed_acyclic_graph(dag):
        raise DAGCircularDependencyError(
            "🛑 The Data Designer column configurations contain cyclic dependencies. Please "
            "inspect the column configurations and ensure they can be sorted without "
            "circular references."
        )

    sorted_columns = non_dag_column_config_list
    sorted_columns.extend([dag_column_config_dict[n] for n in nx.topological_sort(dag)])

    return sorted_columns
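
A tiny standalone illustration of the edge direction used above (from a required column to the column that depends on it), so the topological order yields dependencies first; the column names are invented.

import networkx as nx

dag = nx.DiGraph()
dag.add_edge("topic", "question")   # `question` requires `topic`
dag.add_edge("question", "answer")  # `answer` requires `question`

assert nx.is_directed_acyclic_graph(dag)
print(list(nx.topological_sort(dag)))  # ['topic', 'question', 'answer']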