data_designer-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177)
  1. data_designer/__init__.py +15 -0
  2. data_designer/_version.py +34 -0
  3. data_designer/cli/README.md +236 -0
  4. data_designer/cli/__init__.py +6 -0
  5. data_designer/cli/commands/__init__.py +2 -0
  6. data_designer/cli/commands/list.py +130 -0
  7. data_designer/cli/commands/models.py +10 -0
  8. data_designer/cli/commands/providers.py +11 -0
  9. data_designer/cli/commands/reset.py +100 -0
  10. data_designer/cli/controllers/__init__.py +7 -0
  11. data_designer/cli/controllers/model_controller.py +246 -0
  12. data_designer/cli/controllers/provider_controller.py +317 -0
  13. data_designer/cli/forms/__init__.py +20 -0
  14. data_designer/cli/forms/builder.py +51 -0
  15. data_designer/cli/forms/field.py +180 -0
  16. data_designer/cli/forms/form.py +59 -0
  17. data_designer/cli/forms/model_builder.py +125 -0
  18. data_designer/cli/forms/provider_builder.py +76 -0
  19. data_designer/cli/main.py +44 -0
  20. data_designer/cli/repositories/__init__.py +8 -0
  21. data_designer/cli/repositories/base.py +39 -0
  22. data_designer/cli/repositories/model_repository.py +42 -0
  23. data_designer/cli/repositories/provider_repository.py +43 -0
  24. data_designer/cli/services/__init__.py +7 -0
  25. data_designer/cli/services/model_service.py +116 -0
  26. data_designer/cli/services/provider_service.py +111 -0
  27. data_designer/cli/ui.py +448 -0
  28. data_designer/cli/utils.py +47 -0
  29. data_designer/config/__init__.py +2 -0
  30. data_designer/config/analysis/column_profilers.py +89 -0
  31. data_designer/config/analysis/column_statistics.py +274 -0
  32. data_designer/config/analysis/dataset_profiler.py +60 -0
  33. data_designer/config/analysis/utils/errors.py +8 -0
  34. data_designer/config/analysis/utils/reporting.py +188 -0
  35. data_designer/config/base.py +68 -0
  36. data_designer/config/column_configs.py +354 -0
  37. data_designer/config/column_types.py +168 -0
  38. data_designer/config/config_builder.py +660 -0
  39. data_designer/config/data_designer_config.py +40 -0
  40. data_designer/config/dataset_builders.py +11 -0
  41. data_designer/config/datastore.py +151 -0
  42. data_designer/config/default_model_settings.py +123 -0
  43. data_designer/config/errors.py +19 -0
  44. data_designer/config/interface.py +54 -0
  45. data_designer/config/models.py +231 -0
  46. data_designer/config/preview_results.py +32 -0
  47. data_designer/config/processors.py +41 -0
  48. data_designer/config/sampler_constraints.py +51 -0
  49. data_designer/config/sampler_params.py +604 -0
  50. data_designer/config/seed.py +145 -0
  51. data_designer/config/utils/code_lang.py +83 -0
  52. data_designer/config/utils/constants.py +313 -0
  53. data_designer/config/utils/errors.py +19 -0
  54. data_designer/config/utils/info.py +88 -0
  55. data_designer/config/utils/io_helpers.py +273 -0
  56. data_designer/config/utils/misc.py +81 -0
  57. data_designer/config/utils/numerical_helpers.py +28 -0
  58. data_designer/config/utils/type_helpers.py +100 -0
  59. data_designer/config/utils/validation.py +336 -0
  60. data_designer/config/utils/visualization.py +427 -0
  61. data_designer/config/validator_params.py +96 -0
  62. data_designer/engine/__init__.py +2 -0
  63. data_designer/engine/analysis/column_profilers/base.py +55 -0
  64. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +160 -0
  65. data_designer/engine/analysis/column_profilers/registry.py +20 -0
  66. data_designer/engine/analysis/column_statistics.py +142 -0
  67. data_designer/engine/analysis/dataset_profiler.py +125 -0
  68. data_designer/engine/analysis/errors.py +7 -0
  69. data_designer/engine/analysis/utils/column_statistics_calculations.py +209 -0
  70. data_designer/engine/analysis/utils/judge_score_processing.py +128 -0
  71. data_designer/engine/column_generators/__init__.py +2 -0
  72. data_designer/engine/column_generators/generators/__init__.py +2 -0
  73. data_designer/engine/column_generators/generators/base.py +61 -0
  74. data_designer/engine/column_generators/generators/expression.py +63 -0
  75. data_designer/engine/column_generators/generators/llm_generators.py +172 -0
  76. data_designer/engine/column_generators/generators/samplers.py +75 -0
  77. data_designer/engine/column_generators/generators/seed_dataset.py +149 -0
  78. data_designer/engine/column_generators/generators/validation.py +147 -0
  79. data_designer/engine/column_generators/registry.py +56 -0
  80. data_designer/engine/column_generators/utils/errors.py +13 -0
  81. data_designer/engine/column_generators/utils/judge_score_factory.py +57 -0
  82. data_designer/engine/column_generators/utils/prompt_renderer.py +98 -0
  83. data_designer/engine/configurable_task.py +82 -0
  84. data_designer/engine/dataset_builders/artifact_storage.py +181 -0
  85. data_designer/engine/dataset_builders/column_wise_builder.py +287 -0
  86. data_designer/engine/dataset_builders/errors.py +13 -0
  87. data_designer/engine/dataset_builders/multi_column_configs.py +44 -0
  88. data_designer/engine/dataset_builders/utils/__init__.py +2 -0
  89. data_designer/engine/dataset_builders/utils/concurrency.py +184 -0
  90. data_designer/engine/dataset_builders/utils/config_compiler.py +60 -0
  91. data_designer/engine/dataset_builders/utils/dag.py +56 -0
  92. data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +190 -0
  93. data_designer/engine/dataset_builders/utils/errors.py +13 -0
  94. data_designer/engine/errors.py +49 -0
  95. data_designer/engine/model_provider.py +75 -0
  96. data_designer/engine/models/__init__.py +2 -0
  97. data_designer/engine/models/errors.py +308 -0
  98. data_designer/engine/models/facade.py +225 -0
  99. data_designer/engine/models/litellm_overrides.py +162 -0
  100. data_designer/engine/models/parsers/__init__.py +2 -0
  101. data_designer/engine/models/parsers/errors.py +34 -0
  102. data_designer/engine/models/parsers/parser.py +236 -0
  103. data_designer/engine/models/parsers/postprocessors.py +93 -0
  104. data_designer/engine/models/parsers/tag_parsers.py +60 -0
  105. data_designer/engine/models/parsers/types.py +82 -0
  106. data_designer/engine/models/recipes/base.py +79 -0
  107. data_designer/engine/models/recipes/response_recipes.py +291 -0
  108. data_designer/engine/models/registry.py +118 -0
  109. data_designer/engine/models/usage.py +75 -0
  110. data_designer/engine/models/utils.py +38 -0
  111. data_designer/engine/processing/ginja/__init__.py +2 -0
  112. data_designer/engine/processing/ginja/ast.py +64 -0
  113. data_designer/engine/processing/ginja/environment.py +461 -0
  114. data_designer/engine/processing/ginja/exceptions.py +54 -0
  115. data_designer/engine/processing/ginja/record.py +30 -0
  116. data_designer/engine/processing/gsonschema/__init__.py +2 -0
  117. data_designer/engine/processing/gsonschema/exceptions.py +8 -0
  118. data_designer/engine/processing/gsonschema/schema_transformers.py +81 -0
  119. data_designer/engine/processing/gsonschema/types.py +8 -0
  120. data_designer/engine/processing/gsonschema/validators.py +143 -0
  121. data_designer/engine/processing/processors/base.py +15 -0
  122. data_designer/engine/processing/processors/drop_columns.py +46 -0
  123. data_designer/engine/processing/processors/registry.py +20 -0
  124. data_designer/engine/processing/utils.py +120 -0
  125. data_designer/engine/registry/base.py +97 -0
  126. data_designer/engine/registry/data_designer_registry.py +37 -0
  127. data_designer/engine/registry/errors.py +10 -0
  128. data_designer/engine/resources/managed_dataset_generator.py +35 -0
  129. data_designer/engine/resources/managed_dataset_repository.py +194 -0
  130. data_designer/engine/resources/managed_storage.py +63 -0
  131. data_designer/engine/resources/resource_provider.py +46 -0
  132. data_designer/engine/resources/seed_dataset_data_store.py +66 -0
  133. data_designer/engine/sampling_gen/column.py +89 -0
  134. data_designer/engine/sampling_gen/constraints.py +95 -0
  135. data_designer/engine/sampling_gen/data_sources/base.py +214 -0
  136. data_designer/engine/sampling_gen/data_sources/errors.py +10 -0
  137. data_designer/engine/sampling_gen/data_sources/sources.py +342 -0
  138. data_designer/engine/sampling_gen/entities/__init__.py +2 -0
  139. data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
  140. data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +64 -0
  141. data_designer/engine/sampling_gen/entities/email_address_utils.py +169 -0
  142. data_designer/engine/sampling_gen/entities/errors.py +8 -0
  143. data_designer/engine/sampling_gen/entities/national_id_utils.py +100 -0
  144. data_designer/engine/sampling_gen/entities/person.py +142 -0
  145. data_designer/engine/sampling_gen/entities/phone_number.py +122 -0
  146. data_designer/engine/sampling_gen/errors.py +24 -0
  147. data_designer/engine/sampling_gen/generator.py +121 -0
  148. data_designer/engine/sampling_gen/jinja_utils.py +60 -0
  149. data_designer/engine/sampling_gen/people_gen.py +203 -0
  150. data_designer/engine/sampling_gen/person_constants.py +54 -0
  151. data_designer/engine/sampling_gen/schema.py +143 -0
  152. data_designer/engine/sampling_gen/schema_builder.py +59 -0
  153. data_designer/engine/sampling_gen/utils.py +40 -0
  154. data_designer/engine/secret_resolver.py +80 -0
  155. data_designer/engine/validators/__init__.py +17 -0
  156. data_designer/engine/validators/base.py +36 -0
  157. data_designer/engine/validators/local_callable.py +34 -0
  158. data_designer/engine/validators/python.py +245 -0
  159. data_designer/engine/validators/remote.py +83 -0
  160. data_designer/engine/validators/sql.py +60 -0
  161. data_designer/errors.py +5 -0
  162. data_designer/essentials/__init__.py +137 -0
  163. data_designer/interface/__init__.py +2 -0
  164. data_designer/interface/data_designer.py +351 -0
  165. data_designer/interface/errors.py +16 -0
  166. data_designer/interface/results.py +55 -0
  167. data_designer/logging.py +161 -0
  168. data_designer/plugin_manager.py +83 -0
  169. data_designer/plugins/__init__.py +6 -0
  170. data_designer/plugins/errors.py +10 -0
  171. data_designer/plugins/plugin.py +69 -0
  172. data_designer/plugins/registry.py +86 -0
  173. data_designer-0.1.0.dist-info/METADATA +173 -0
  174. data_designer-0.1.0.dist-info/RECORD +177 -0
  175. data_designer-0.1.0.dist-info/WHEEL +4 -0
  176. data_designer-0.1.0.dist-info/entry_points.txt +2 -0
  177. data_designer-0.1.0.dist-info/licenses/LICENSE +201 -0
data_designer/engine/dataset_builders/column_wise_builder.py
@@ -0,0 +1,287 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ import functools
+ import json
+ import logging
+ from pathlib import Path
+ import time
+ from typing import Callable
+
+ import pandas as pd
+
+ from data_designer.config.column_types import ColumnConfigT, column_type_is_llm_generated
+ from data_designer.config.dataset_builders import BuildStage
+ from data_designer.config.processors import (
+     DropColumnsProcessorConfig,
+     ProcessorConfig,
+     ProcessorType,
+ )
+ from data_designer.engine.column_generators.generators.base import ColumnGenerator, GenerationStrategy
+ from data_designer.engine.column_generators.generators.llm_generators import WithLLMGeneration
+ from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
+ from data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError
+ from data_designer.engine.dataset_builders.multi_column_configs import (
+     DatasetBuilderColumnConfigT,
+     MultiColumnConfig,
+ )
+ from data_designer.engine.dataset_builders.utils.concurrency import (
+     MAX_CONCURRENCY_PER_NON_LLM_GENERATOR,
+     ConcurrentThreadExecutor,
+ )
+ from data_designer.engine.dataset_builders.utils.dataset_batch_manager import (
+     DatasetBatchManager,
+ )
+ from data_designer.engine.processing.processors.base import Processor
+ from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor
+ from data_designer.engine.registry.data_designer_registry import DataDesignerRegistry
+ from data_designer.engine.resources.resource_provider import ResourceProvider
+
+ logger = logging.getLogger(__name__)
+
+
+ class ColumnWiseDatasetBuilder:
+     def __init__(
+         self,
+         column_configs: list[DatasetBuilderColumnConfigT],
+         processor_configs: list[ProcessorConfig],
+         resource_provider: ResourceProvider,
+         registry: DataDesignerRegistry | None = None,
+     ):
+         self.batch_manager = DatasetBatchManager(resource_provider.artifact_storage)
+         self._resource_provider = resource_provider
+         self._records_to_drop: set[int] = set()
+         self._registry = registry or DataDesignerRegistry()
+         self._column_configs = column_configs
+         self._processors: dict[BuildStage, list[Processor]] = self._initialize_processors(processor_configs)
+         self._validate_column_configs()
+
+     @property
+     def artifact_storage(self) -> ArtifactStorage:
+         return self._resource_provider.artifact_storage
+
+     @functools.cached_property
+     def single_column_configs(self) -> list[ColumnConfigT]:
+         configs = []
+         for config in self._column_configs:
+             if isinstance(config, MultiColumnConfig):
+                 configs.extend(config.columns)
+             else:
+                 configs.append(config)
+         return configs
+
+     @functools.cached_property
+     def llm_generated_column_configs(self) -> list[ColumnConfigT]:
+         return [config for config in self.single_column_configs if column_type_is_llm_generated(config.column_type)]
+
+     def build(
+         self,
+         *,
+         num_records: int,
+         buffer_size: int,
+         on_batch_complete: Callable[[Path], None] | None = None,
+     ) -> Path:
+         self._write_configs()
+         self._run_model_health_check_if_needed()
+
+         generators = self._initialize_generators()
+         start_time = time.perf_counter()
+
+         self.batch_manager.start(num_records=num_records, buffer_size=buffer_size)
+         for batch_idx in range(1, self.batch_manager.num_batches + 1):
+             logger.info(f"⏳ Processing batch {batch_idx} of {self.batch_manager.num_batches}")
+             self._run_batch(generators)
+             df_batch = self._run_processors(
+                 stage=BuildStage.POST_BATCH,
+                 dataframe=self.batch_manager.get_current_batch(as_dataframe=True),
+                 current_batch_number=batch_idx,
+             )
+             self._write_processed_batch(df_batch)
+             self.batch_manager.finish_batch(on_batch_complete)
+         self.batch_manager.finish()
+
+         model_usage_stats = self._resource_provider.model_registry.get_model_usage_stats(
+             time.perf_counter() - start_time
+         )
+         logger.info(f"📊 Model usage summary:\n{json.dumps(model_usage_stats, indent=4)}")
+
+         return self.artifact_storage.final_dataset_path
+
+     def build_preview(self, *, num_records: int) -> pd.DataFrame:
+         self._run_model_health_check_if_needed()
+
+         generators = self._initialize_generators()
+
+         start_time = time.perf_counter()
+         self.batch_manager.start(num_records=num_records, buffer_size=num_records)
+         self._run_batch(generators, save_partial_results=False)
+         dataset = self.batch_manager.get_current_batch(as_dataframe=True)
+         self.batch_manager.reset()
+
+         model_usage_stats = self._resource_provider.model_registry.get_model_usage_stats(
+             time.perf_counter() - start_time
+         )
+         logger.info(f"📊 Model usage summary:\n{json.dumps(model_usage_stats, indent=4)}")
+
+         return dataset
+
+     def process_preview(self, dataset: pd.DataFrame) -> pd.DataFrame:
+         return self._run_processors(
+             stage=BuildStage.POST_BATCH,
+             dataframe=dataset.copy(),
+             current_batch_number=None,  # preview mode does not have a batch number
+         )
+
+     def _initialize_generators(self) -> list[ColumnGenerator]:
+         return [
+             self._registry.column_generators.get_for_config_type(type(config))(
+                 config=config, resource_provider=self._resource_provider
+             )
+             for config in self._column_configs
+         ]
+
+     def _run_batch(self, generators: list[ColumnGenerator], *, save_partial_results: bool = True) -> None:
+         for generator in generators:
+             generator.log_pre_generation()
+             try:
+                 if generator.can_generate_from_scratch and self.batch_manager.buffer_is_empty:
+                     self._run_from_scratch_column_generator(generator)
+                 elif generator.generation_strategy == GenerationStrategy.CELL_BY_CELL:
+                     self._run_cell_by_cell_generator(generator)
+                 elif generator.generation_strategy == GenerationStrategy.FULL_COLUMN:
+                     self._run_full_column_generator(generator)
+                 else:
+                     logger.error(f"❌ Unknown generation strategy: {generator.generation_strategy}")
+                     raise DatasetGenerationError(f"🛑 Unknown generation strategy: {generator.generation_strategy}")
+                 if save_partial_results:
+                     self.batch_manager.write()
+             except Exception as e:
+                 column_error_str = (
+                     f"columns {generator.config.column_names}"
+                     if hasattr(generator.config, "column_names")
+                     else f"column {generator.config.name!r}"
+                 )
+                 raise DatasetGenerationError(f"🛑 Failed to process {column_error_str}:\n{e}")
+
+     def _run_from_scratch_column_generator(self, generator: ColumnGenerator) -> None:
+         df = generator.generate_from_scratch(self.batch_manager.num_records_batch)
+         self.batch_manager.add_records(df.to_dict(orient="records"))
+
+     def _run_cell_by_cell_generator(self, generator: ColumnGenerator) -> None:
+         max_workers = MAX_CONCURRENCY_PER_NON_LLM_GENERATOR
+         if isinstance(generator, WithLLMGeneration):
+             max_workers = generator.inference_parameters.max_parallel_requests
+         self._fan_out_with_threads(generator, max_workers=max_workers)
+
+     def _run_full_column_generator(self, generator: ColumnGenerator) -> None:
+         df = generator.generate(self.batch_manager.get_current_batch(as_dataframe=True))
+         self.batch_manager.update_records(df.to_dict(orient="records"))
+
+     def _run_model_health_check_if_needed(self) -> None:
+         if any(column_type_is_llm_generated(config.column_type) for config in self.single_column_configs):
+             self._resource_provider.model_registry.run_health_check(
+                 {config.model_alias for config in self.llm_generated_column_configs}
+             )
+
+     def _fan_out_with_threads(self, generator: WithLLMGeneration, max_workers: int) -> None:
+         if generator.generation_strategy != GenerationStrategy.CELL_BY_CELL:
+             raise DatasetGenerationError(
+                 f"Generator {generator.metadata().name} is not a {GenerationStrategy.CELL_BY_CELL} "
+                 "generator, so concurrency through threads is not supported."
+             )
+
+         logger.info(
+             f"🐙 Processing {generator.config.column_type} column '{generator.config.name}' "
+             f"with {max_workers} concurrent workers"
+         )
+         with ConcurrentThreadExecutor(
+             max_workers=max_workers,
+             column_name=generator.config.name,
+             result_callback=self._worker_result_callback,
+             error_callback=self._worker_error_callback,
+         ) as executor:
+             for i, record in self.batch_manager.iter_current_batch():
+                 executor.submit(lambda record: generator.generate(record), record, context={"index": i})
+
+         if len(self._records_to_drop) > 0:
+             self.batch_manager.drop_records(self._records_to_drop)
+             self._records_to_drop.clear()
+
+     def _write_processed_batch(self, dataframe: pd.DataFrame) -> None:
+         self.batch_manager.update_records(dataframe.to_dict(orient="records"))
+         self.batch_manager.write()
+
+     def _validate_column_configs(self) -> None:
+         if len(self._column_configs) == 0:
+             raise DatasetGenerationError("🛑 No column configs provided.")
+
+         if not self._registry.column_generators.get_for_config_type(
+             type(self._column_configs[0])
+         ).can_generate_from_scratch:
+             raise DatasetGenerationError("🛑 The first column config must be a from-scratch column generator.")
+
+     def _initialize_processors(self, processor_configs: list[ProcessorConfig]) -> dict[BuildStage, list[Processor]]:
+         # Check columns marked for drop
+         columns_to_drop = [config.name for config in self.single_column_configs if config.drop]
+
+         processors: dict[BuildStage, list[Processor]] = {stage: [] for stage in BuildStage}
+         for config in processor_configs:
+             processors[config.build_stage].append(
+                 self._registry.processors.get_for_config_type(type(config))(
+                     config=config,
+                     resource_provider=self._resource_provider,
+                 )
+             )
+
+             # A manually included "drop columns" processor takes precedence (it can, e.g., pick stages other than post-batch)
+             if config.processor_type == ProcessorType.DROP_COLUMNS:
+                 for column in config.column_names:
+                     if column in columns_to_drop:
+                         columns_to_drop.remove(column)
+
+         # If there are still columns marked for drop, add a "drop columns" processor to drop them
+         if len(columns_to_drop) > 0:
+             processors[BuildStage.POST_BATCH].append(  # as post-batch by default
+                 DropColumnsProcessor(
+                     config=DropColumnsProcessorConfig(
+                         column_names=columns_to_drop,
+                         build_stage=BuildStage.POST_BATCH,
+                     ),
+                     resource_provider=self._resource_provider,
+                 )
+             )
+
+         return processors
+
+     def _run_processors(
+         self, stage: BuildStage, dataframe: pd.DataFrame, current_batch_number: int | None = None
+     ) -> pd.DataFrame:
+         for processor in self._processors[stage]:
+             try:
+                 dataframe = processor.process(dataframe, current_batch_number=current_batch_number)
+             except Exception as e:
+                 raise DatasetProcessingError(
+                     f"🛑 Failed to process dataset with processor {processor.metadata().name} in stage {stage}: {e}"
+                 ) from e
+         return dataframe
+
+     def _worker_error_callback(self, exc: Exception, *, context: dict | None = None) -> None:
+         """Handle a worker failure by marking the record for removal from the batch."""
+         logger.warning(
+             f"⚠️ Generation for record at index {context['index']} failed. "
+             f"Will omit this record from the dataset.\n{exc}"
+         )
+         self._records_to_drop.add(context["index"])
+
+     def _worker_result_callback(self, result: dict, *, context: dict | None = None) -> None:
+         self.batch_manager.update_record(context["index"], result)
+
+     def _write_configs(self) -> None:
+         self.artifact_storage.write_configs(
+             json_file_name="column_configs.json",
+             configs=self._column_configs,
+         )
+         self.artifact_storage.write_configs(
+             json_file_name="model_configs.json",
+             configs=self._resource_provider.model_registry.model_configs.values(),
+         )
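
The core of this file is the strategy dispatch in `_run_batch`: the first from-scratch generator seeds the empty buffer, and every later generator runs either record-by-record or over the whole batch. Below is a minimal, self-contained sketch of that control flow; `Strategy`, `SeedGen`, `DoubleGen`, and `run_batch` are illustrative toy stand-ins, not classes from the package.

from enum import Enum


class Strategy(Enum):
    CELL_BY_CELL = "cell_by_cell"  # one generate() call per record
    FULL_COLUMN = "full_column"    # one generate() call over the whole batch


class SeedGen:
    can_generate_from_scratch = True
    generation_strategy = Strategy.FULL_COLUMN

    def generate_from_scratch(self, num_records: int) -> list[dict]:
        return [{"id": i} for i in range(num_records)]


class DoubleGen:
    can_generate_from_scratch = False
    generation_strategy = Strategy.FULL_COLUMN

    def generate(self, records: list[dict]) -> list[dict]:
        return [{**r, "doubled": r["id"] * 2} for r in records]


def run_batch(generators: list, num_records: int) -> list[dict]:
    records: list[dict] = []
    for gen in generators:
        if gen.can_generate_from_scratch and not records:
            records = gen.generate_from_scratch(num_records)  # seed the empty buffer
        elif gen.generation_strategy is Strategy.CELL_BY_CELL:
            records = [gen.generate(r) for r in records]  # fanned out per record
        elif gen.generation_strategy is Strategy.FULL_COLUMN:
            records = gen.generate(records)  # transform the whole batch at once
        else:
            raise ValueError(f"Unknown strategy: {gen.generation_strategy}")
    return records


print(run_batch([SeedGen(), DoubleGen()], 3))
# [{'id': 0, 'doubled': 0}, {'id': 1, 'doubled': 2}, {'id': 2, 'doubled': 4}]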

data_designer/engine/dataset_builders/errors.py
@@ -0,0 +1,13 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from data_designer.engine.errors import DataDesignerError
+
+
+ class ArtifactStorageError(DataDesignerError): ...
+
+
+ class DatasetGenerationError(DataDesignerError): ...
+
+
+ class DatasetProcessingError(DataDesignerError): ...

data_designer/engine/dataset_builders/multi_column_configs.py
@@ -0,0 +1,44 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from abc import ABC
+ from typing import TypeAlias
+
+ from pydantic import Field, field_validator
+
+ from data_designer.config.base import ConfigBase
+ from data_designer.config.column_configs import SamplerColumnConfig, SeedDatasetColumnConfig, SingleColumnConfig
+ from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType
+ from data_designer.config.sampler_constraints import ColumnConstraintT
+ from data_designer.config.seed import SeedConfig
+
+
+ class MultiColumnConfig(ConfigBase, ABC):
+     columns: list[SingleColumnConfig] = Field(..., min_length=1)
+
+     @property
+     def column_names(self) -> list[str]:
+         return [c.name for c in self.columns]
+
+     @property
+     def column_type(self) -> DataDesignerColumnType:
+         return self.columns[0].column_type
+
+     @field_validator("columns", mode="after")
+     def validate_column_types_are_the_same(cls, v: list[SingleColumnConfig]) -> list[SingleColumnConfig]:
+         if len({c.column_type for c in v}) != 1:
+             raise ValueError("All columns must have the same column type")
+         return v
+
+
+ class SamplerMultiColumnConfig(MultiColumnConfig):
+     columns: list[SamplerColumnConfig]
+     constraints: list[ColumnConstraintT] = []
+     max_rejections_factor: int = 5
+
+
+ class SeedDatasetMultiColumnConfig(SeedConfig, MultiColumnConfig):
+     columns: list[SeedDatasetColumnConfig]
+
+
+ DatasetBuilderColumnConfigT: TypeAlias = ColumnConfigT | SeedDatasetMultiColumnConfig | SamplerMultiColumnConfig
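
`MultiColumnConfig` relies on a pydantic `field_validator` to reject mixed column types. The following self-contained sketch shows the same pattern in isolation; `ToyColumn` and `ToyMultiColumn` are illustrative stand-ins, not classes from the package.

from pydantic import BaseModel, Field, ValidationError, field_validator


class ToyColumn(BaseModel):
    name: str
    column_type: str


class ToyMultiColumn(BaseModel):
    columns: list[ToyColumn] = Field(..., min_length=1)

    @field_validator("columns", mode="after")
    @classmethod
    def all_same_type(cls, v: list[ToyColumn]) -> list[ToyColumn]:
        # Mirrors validate_column_types_are_the_same above
        if len({c.column_type for c in v}) != 1:
            raise ValueError("All columns must have the same column type")
        return v


ToyMultiColumn(columns=[ToyColumn(name="a", column_type="sampler"),
                        ToyColumn(name="b", column_type="sampler")])  # ok

try:
    ToyMultiColumn(columns=[ToyColumn(name="a", column_type="sampler"),
                            ToyColumn(name="b", column_type="llm-text")])
except ValidationError as e:
    print(e)  # pydantic surfaces the ValueError raised by the validator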

data_designer/engine/dataset_builders/utils/__init__.py
@@ -0,0 +1,2 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0

data_designer/engine/dataset_builders/utils/concurrency.py
@@ -0,0 +1,184 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from __future__ import annotations
+
+ from concurrent.futures import Future, ThreadPoolExecutor
+ import contextvars
+ import json
+ import logging
+ from threading import Lock, Semaphore
+ from typing import Any, Optional, Protocol
+
+ from pydantic import BaseModel, Field
+
+ from data_designer.engine.errors import DataDesignerRuntimeError, ErrorTrap
+
+ logger = logging.getLogger(__name__)
+
+ # Constants
+ MAX_CONCURRENCY_PER_NON_LLM_GENERATOR = 4
+
+
+ class ExecutorResults(BaseModel):
+     failure_threshold: float = 0.0  # Error rate threshold
+     completed_count: int = 0  # How many tasks/jobs completed
+     success_count: int = 0  # How many tasks/jobs were successful
+     early_shutdown: bool = False  # Did we shut down early due to errors?
+     error_trap: ErrorTrap = Field(default_factory=ErrorTrap)
+
+     @property
+     def summary(self) -> dict:
+         summary = self.model_dump(exclude={"error_trap"})
+         summary |= self.error_trap.model_dump()
+         return summary
+
+     def get_error_rate(self, window: int) -> float:
+         # We don't start tracking until the minimum window size is met
+         if self.completed_count < window:
+             return 0.0
+         return self.error_trap.error_count / max(1, self.completed_count)
+
+     def is_error_rate_exceeded(self, window: int) -> bool:
+         return self.get_error_rate(window) >= self.failure_threshold
+
+
+ class CallbackWithContext(Protocol):
+     """Executor callback functions must accept a `context` keyword argument."""
+
+     def __call__(self, result: Any, *, context: Optional[dict] = None) -> Any: ...
+
+
+ class ErrorCallbackWithContext(Protocol):
+     """Error callbacks take the Exception instance and context."""
+
+     def __call__(self, exc: Exception, *, context: Optional[dict] = None) -> Any: ...
+
+
+ class ConcurrentThreadExecutor:
+     """
+     Interface for executing multiple concurrent tasks with error rate monitoring.
+
+     This interface should be used exclusively as a context manager. New tasks
+     can be submitted to the executor using the `submit` method, which works
+     much like the `submit` method of a `ThreadPoolExecutor`.
+
+     The underlying queue of tasks is bounded by the `max_workers` parameter,
+     so at most `max_workers` tasks can be queued for execution at any time.
+     As tasks complete, errors are tracked and counted. If the configured
+     error rate is exceeded, the executor shuts down early; all queued and
+     running tasks are still allowed to complete.
+
+     The task queue is bounded to ensure that, when the error threshold is
+     met, there is not an unbounded number of tasks left to complete.
+     Generally, tasks should not sit in the queue for long, since the queue
+     size equals `max_workers`. The side effect is that `submit()` may block,
+     which does not matter in practice because upstream tasks must wait for
+     all jobs to complete before they can be considered done.
+
+     ContextVars from the main parent thread are automatically propagated
+     to all child threads.
+
+     When a task completes, the user-provided `result_callback` function is
+     called with the task result as its first argument.
+     """
+
+     def __init__(
+         self,
+         *,
+         max_workers: int,
+         column_name: str,
+         result_callback: Optional[CallbackWithContext] = None,
+         error_callback: Optional[ErrorCallbackWithContext] = None,
+         shutdown_error_rate: float = 0.50,
+         shutdown_error_window: int = 10,
+     ):
+         self._executor = None
+         self._column_name = column_name
+         self._max_workers = max_workers
+         self._lock = Lock()
+         self._semaphore = Semaphore(self._max_workers)
+         self._result_callback = result_callback
+         self._error_callback = error_callback
+         self._shutdown_error_rate = shutdown_error_rate
+         self._shutdown_window_size = shutdown_error_window
+         self._results = ExecutorResults(failure_threshold=shutdown_error_rate)
+
+     def __enter__(self) -> ConcurrentThreadExecutor:
+         self._executor = ThreadPoolExecutor(
+             max_workers=self._max_workers,
+             thread_name_prefix="ConcurrentThreadExecutor",
+             initializer=_set_worker_contextvars,
+             initargs=(contextvars.copy_context(),),
+         )
+         return self
+
+     def __exit__(self, exc_type, exc_value, traceback):
+         self._shutdown_executor()
+         if self._results.early_shutdown:
+             self._raise_task_error()
+
+     def _shutdown_executor(self) -> None:
+         if self._executor is not None:
+             self._executor.shutdown()
+
+     def _raise_task_error(self):
+         raise DataDesignerRuntimeError(
+             "\n".join(
+                 [
+                     " |-- Data generation was terminated early due to the error rate exceeding the threshold.",
+                     f" |-- The summary of encountered errors is: \n{json.dumps(self._results.summary, indent=4)}",
+                 ]
+             )
+         )
+
+     def submit(self, fn, *args, context: Optional[dict] = None, **kwargs) -> None:
+         if self._executor is None:
+             raise RuntimeError("Executor is not initialized; this class should be used as a context manager.")
+
+         if self._results.early_shutdown:
+             self._shutdown_executor()
+             self._raise_task_error()
+
+         def _handle_future(future: Future) -> None:
+             self._results.completed_count += 1
+             try:
+                 result = future.result()
+                 if self._result_callback is not None:
+                     self._result_callback(result, context=context)
+                 self._results.success_count += 1
+             except Exception as err:
+                 with self._lock:
+                     self._results.error_trap.handle_error(err)
+                     if self._results.is_error_rate_exceeded(self._shutdown_window_size):
+                         # Signal to shut down early on the next submission (if received).
+                         # We cannot trigger shutdown from within this thread, as it can
+                         # cause a deadlock.
+                         if not self._results.early_shutdown:
+                             self._results.early_shutdown = True
+                 if self._error_callback is not None:
+                     self._error_callback(err, context=context)
+             finally:
+                 self._semaphore.release()
+
+         try:
+             self._semaphore.acquire()
+             future = self._executor.submit(fn, *args, **kwargs)
+             future.add_done_callback(_handle_future)
+         except Exception as err:
+             # If we get here, the pool is shutting down (likely due to early termination
+             # from errors). Re-raise a custom error that can be handled at the call site;
+             # the summary can also be inspected.
+             self._semaphore.release()
+             if not isinstance(err, RuntimeError) and "after shutdown" not in str(err):
+                 raise err
+             self._raise_task_error()
+
+
+ def _set_worker_contextvars(context: contextvars.Context):
+     for var, value in context.items():
+         var.set(value)
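
A hypothetical usage sketch of the `ConcurrentThreadExecutor` defined above: submissions block once `max_workers` tasks are in flight (the semaphore provides backpressure), and per-task failures are routed to the error callback. The `flaky_task` function and the callback bodies are illustrative only.

import random

from data_designer.engine.dataset_builders.utils.concurrency import ConcurrentThreadExecutor

results: dict[int, str] = {}
failed: set[int] = set()


def on_result(result, *, context=None):
    results[context["index"]] = result


def on_error(exc, *, context=None):
    failed.add(context["index"])


def flaky_task(index: int) -> str:
    if random.random() < 0.1:  # ~10% failure rate, well below the default 50% shutdown threshold
        raise RuntimeError(f"task {index} failed")
    return f"record-{index}"


with ConcurrentThreadExecutor(
    max_workers=4,
    column_name="example_column",
    result_callback=on_result,
    error_callback=on_error,
) as executor:
    for i in range(20):
        executor.submit(flaky_task, i, context={"index": i})

print(f"{len(results)} succeeded, {len(failed)} dropped")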

data_designer/engine/dataset_builders/utils/config_compiler.py
@@ -0,0 +1,60 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from data_designer.config.column_types import DataDesignerColumnType
+ from data_designer.config.data_designer_config import DataDesignerConfig
+ from data_designer.config.processors import ProcessorConfig
+ from data_designer.engine.dataset_builders.multi_column_configs import (
+     DatasetBuilderColumnConfigT,
+     SamplerMultiColumnConfig,
+     SeedDatasetMultiColumnConfig,
+ )
+ from data_designer.engine.dataset_builders.utils.dag import topologically_sort_column_configs
+ from data_designer.engine.dataset_builders.utils.errors import ConfigCompilationError
+
+
+ def compile_dataset_builder_column_configs(config: DataDesignerConfig) -> list[DatasetBuilderColumnConfigT]:
+     seed_column_configs = []
+     sampler_column_configs = []
+     generated_column_configs = []
+
+     for column_config in topologically_sort_column_configs(config.columns):
+         if column_config.column_type == DataDesignerColumnType.SEED_DATASET:
+             seed_column_configs.append(column_config)
+         elif column_config.column_type == DataDesignerColumnType.SAMPLER:
+             sampler_column_configs.append(column_config)
+         else:
+             generated_column_configs.append(column_config)
+
+     compiled_column_configs = []
+
+     if len(seed_column_configs) > 0:
+         if config.seed_config is None:
+             raise ConfigCompilationError("🛑 Seed column configs require a seed configuration.")
+         compiled_column_configs.append(
+             SeedDatasetMultiColumnConfig(
+                 columns=seed_column_configs,
+                 dataset=config.seed_config.dataset,
+                 sampling_strategy=config.seed_config.sampling_strategy,
+                 selection_strategy=config.seed_config.selection_strategy,
+             )
+         )
+
+     if len(sampler_column_configs) > 0:
+         compiled_column_configs.append(
+             SamplerMultiColumnConfig(
+                 columns=sampler_column_configs,
+                 constraints=config.constraints or [],
+             )
+         )
+
+     if len(generated_column_configs) > 0:
+         compiled_column_configs.extend(generated_column_configs)
+
+     return compiled_column_configs
+
+
+ def compile_dataset_builder_processor_configs(
+     config: DataDesignerConfig,
+ ) -> list[ProcessorConfig]:
+     return config.processors or []

data_designer/engine/dataset_builders/utils/dag.py
@@ -0,0 +1,56 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ import logging
+
+ import networkx as nx
+
+ from data_designer.config.column_types import ColumnConfigT, column_type_used_in_execution_dag
+ from data_designer.engine.dataset_builders.utils.errors import DAGCircularDependencyError
+
+ logger = logging.getLogger(__name__)
+
+
+ def topologically_sort_column_configs(column_configs: list[ColumnConfigT]) -> list[ColumnConfigT]:
+     dag = nx.DiGraph()
+
+     non_dag_column_config_list = [
+         col for col in column_configs if not column_type_used_in_execution_dag(col.column_type)
+     ]
+     dag_column_config_dict = {
+         col.name: col for col in column_configs if column_type_used_in_execution_dag(col.column_type)
+     }
+
+     if len(dag_column_config_dict) == 0:
+         return non_dag_column_config_list
+
+     side_effect_dict = {n: list(c.side_effect_columns) for n, c in dag_column_config_dict.items()}
+
+     logger.info("⛓️ Sorting column configs into a Directed Acyclic Graph")
+     for name, col in dag_column_config_dict.items():
+         dag.add_node(name)
+         for req_col_name in col.required_columns:
+             if req_col_name in dag_column_config_dict:
+                 logger.debug(f" |-- 🔗 `{name}` depends on `{req_col_name}`")
+                 dag.add_edge(req_col_name, name)
+
+             # If the required column is a side effect of another column,
+             # add an edge from the parent column to the current column.
+             elif req_col_name in sum(side_effect_dict.values(), []):
+                 for parent, cols in side_effect_dict.items():
+                     if req_col_name in cols:
+                         logger.debug(f" |-- 🔗 `{name}` depends on `{parent}` via `{req_col_name}`")
+                         dag.add_edge(parent, name)
+                         break
+
+     if not nx.is_directed_acyclic_graph(dag):
+         raise DAGCircularDependencyError(
+             "🛑 The Data Designer column configurations contain cyclic dependencies. Please "
+             "inspect the column configurations and ensure they can be sorted without "
+             "circular references."
+         )
+
+     sorted_columns = non_dag_column_config_list
+     sorted_columns.extend(dag_column_config_dict[n] for n in nx.topological_sort(dag))
+
+     return sorted_columns
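
The side-effect handling above means a column that requires, say, `is_valid` gets an edge from the column that produces `is_valid`, not from `is_valid` itself. A toy demonstration with plain dicts in place of the package's column config classes (the column names are hypothetical):

import networkx as nx

# name -> (required_columns, side_effect_columns)
configs = {
    "topic": ([], []),
    "question": (["topic"], []),
    "validation": (["question"], ["is_valid"]),  # produces `is_valid` as a side effect
    "rewrite": (["is_valid"], []),               # depends on a side-effect column
}

dag = nx.DiGraph()
for name, (required, _) in configs.items():
    dag.add_node(name)
    for req in required:
        if req in configs:
            dag.add_edge(req, name)  # direct dependency
        else:
            # `req` is a side effect: depend on the column that produces it
            for parent, (_, side_effects) in configs.items():
                if req in side_effects:
                    dag.add_edge(parent, name)
                    break

assert nx.is_directed_acyclic_graph(dag)  # mirrors the DAGCircularDependencyError check
print(list(nx.topological_sort(dag)))     # ['topic', 'question', 'validation', 'rewrite']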