data-designer-engine 0.4.0 (data_designer_engine-0.4.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (114)
  1. data_designer/engine/__init__.py +2 -0
  2. data_designer/engine/_version.py +34 -0
  3. data_designer/engine/analysis/column_profilers/base.py +49 -0
  4. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +153 -0
  5. data_designer/engine/analysis/column_profilers/registry.py +22 -0
  6. data_designer/engine/analysis/column_statistics.py +145 -0
  7. data_designer/engine/analysis/dataset_profiler.py +149 -0
  8. data_designer/engine/analysis/errors.py +9 -0
  9. data_designer/engine/analysis/utils/column_statistics_calculations.py +234 -0
  10. data_designer/engine/analysis/utils/judge_score_processing.py +132 -0
  11. data_designer/engine/column_generators/__init__.py +2 -0
  12. data_designer/engine/column_generators/generators/__init__.py +2 -0
  13. data_designer/engine/column_generators/generators/base.py +122 -0
  14. data_designer/engine/column_generators/generators/embedding.py +35 -0
  15. data_designer/engine/column_generators/generators/expression.py +55 -0
  16. data_designer/engine/column_generators/generators/llm_completion.py +116 -0
  17. data_designer/engine/column_generators/generators/samplers.py +69 -0
  18. data_designer/engine/column_generators/generators/seed_dataset.py +144 -0
  19. data_designer/engine/column_generators/generators/validation.py +140 -0
  20. data_designer/engine/column_generators/registry.py +60 -0
  21. data_designer/engine/column_generators/utils/errors.py +15 -0
  22. data_designer/engine/column_generators/utils/generator_classification.py +43 -0
  23. data_designer/engine/column_generators/utils/judge_score_factory.py +58 -0
  24. data_designer/engine/column_generators/utils/prompt_renderer.py +100 -0
  25. data_designer/engine/compiler.py +97 -0
  26. data_designer/engine/configurable_task.py +71 -0
  27. data_designer/engine/dataset_builders/artifact_storage.py +283 -0
  28. data_designer/engine/dataset_builders/column_wise_builder.py +354 -0
  29. data_designer/engine/dataset_builders/errors.py +15 -0
  30. data_designer/engine/dataset_builders/multi_column_configs.py +46 -0
  31. data_designer/engine/dataset_builders/utils/__init__.py +2 -0
  32. data_designer/engine/dataset_builders/utils/concurrency.py +212 -0
  33. data_designer/engine/dataset_builders/utils/config_compiler.py +62 -0
  34. data_designer/engine/dataset_builders/utils/dag.py +62 -0
  35. data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +200 -0
  36. data_designer/engine/dataset_builders/utils/errors.py +15 -0
  37. data_designer/engine/dataset_builders/utils/progress_tracker.py +122 -0
  38. data_designer/engine/errors.py +51 -0
  39. data_designer/engine/model_provider.py +77 -0
  40. data_designer/engine/models/__init__.py +2 -0
  41. data_designer/engine/models/errors.py +300 -0
  42. data_designer/engine/models/facade.py +284 -0
  43. data_designer/engine/models/factory.py +42 -0
  44. data_designer/engine/models/litellm_overrides.py +179 -0
  45. data_designer/engine/models/parsers/__init__.py +2 -0
  46. data_designer/engine/models/parsers/errors.py +34 -0
  47. data_designer/engine/models/parsers/parser.py +235 -0
  48. data_designer/engine/models/parsers/postprocessors.py +93 -0
  49. data_designer/engine/models/parsers/tag_parsers.py +62 -0
  50. data_designer/engine/models/parsers/types.py +84 -0
  51. data_designer/engine/models/recipes/base.py +81 -0
  52. data_designer/engine/models/recipes/response_recipes.py +293 -0
  53. data_designer/engine/models/registry.py +151 -0
  54. data_designer/engine/models/telemetry.py +362 -0
  55. data_designer/engine/models/usage.py +73 -0
  56. data_designer/engine/models/utils.py +101 -0
  57. data_designer/engine/processing/ginja/__init__.py +2 -0
  58. data_designer/engine/processing/ginja/ast.py +65 -0
  59. data_designer/engine/processing/ginja/environment.py +463 -0
  60. data_designer/engine/processing/ginja/exceptions.py +56 -0
  61. data_designer/engine/processing/ginja/record.py +32 -0
  62. data_designer/engine/processing/gsonschema/__init__.py +2 -0
  63. data_designer/engine/processing/gsonschema/exceptions.py +15 -0
  64. data_designer/engine/processing/gsonschema/schema_transformers.py +83 -0
  65. data_designer/engine/processing/gsonschema/types.py +10 -0
  66. data_designer/engine/processing/gsonschema/validators.py +202 -0
  67. data_designer/engine/processing/processors/base.py +13 -0
  68. data_designer/engine/processing/processors/drop_columns.py +42 -0
  69. data_designer/engine/processing/processors/registry.py +25 -0
  70. data_designer/engine/processing/processors/schema_transform.py +71 -0
  71. data_designer/engine/processing/utils.py +169 -0
  72. data_designer/engine/registry/base.py +99 -0
  73. data_designer/engine/registry/data_designer_registry.py +39 -0
  74. data_designer/engine/registry/errors.py +12 -0
  75. data_designer/engine/resources/managed_dataset_generator.py +39 -0
  76. data_designer/engine/resources/managed_dataset_repository.py +197 -0
  77. data_designer/engine/resources/managed_storage.py +65 -0
  78. data_designer/engine/resources/resource_provider.py +77 -0
  79. data_designer/engine/resources/seed_reader.py +154 -0
  80. data_designer/engine/sampling_gen/column.py +91 -0
  81. data_designer/engine/sampling_gen/constraints.py +100 -0
  82. data_designer/engine/sampling_gen/data_sources/base.py +217 -0
  83. data_designer/engine/sampling_gen/data_sources/errors.py +12 -0
  84. data_designer/engine/sampling_gen/data_sources/sources.py +347 -0
  85. data_designer/engine/sampling_gen/entities/__init__.py +2 -0
  86. data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
  87. data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +90 -0
  88. data_designer/engine/sampling_gen/entities/email_address_utils.py +171 -0
  89. data_designer/engine/sampling_gen/entities/errors.py +10 -0
  90. data_designer/engine/sampling_gen/entities/national_id_utils.py +102 -0
  91. data_designer/engine/sampling_gen/entities/person.py +144 -0
  92. data_designer/engine/sampling_gen/entities/phone_number.py +128 -0
  93. data_designer/engine/sampling_gen/errors.py +26 -0
  94. data_designer/engine/sampling_gen/generator.py +122 -0
  95. data_designer/engine/sampling_gen/jinja_utils.py +64 -0
  96. data_designer/engine/sampling_gen/people_gen.py +199 -0
  97. data_designer/engine/sampling_gen/person_constants.py +56 -0
  98. data_designer/engine/sampling_gen/schema.py +147 -0
  99. data_designer/engine/sampling_gen/schema_builder.py +61 -0
  100. data_designer/engine/sampling_gen/utils.py +46 -0
  101. data_designer/engine/secret_resolver.py +82 -0
  102. data_designer/engine/testing/__init__.py +12 -0
  103. data_designer/engine/testing/stubs.py +133 -0
  104. data_designer/engine/testing/utils.py +20 -0
  105. data_designer/engine/validation.py +367 -0
  106. data_designer/engine/validators/__init__.py +19 -0
  107. data_designer/engine/validators/base.py +38 -0
  108. data_designer/engine/validators/local_callable.py +39 -0
  109. data_designer/engine/validators/python.py +254 -0
  110. data_designer/engine/validators/remote.py +89 -0
  111. data_designer/engine/validators/sql.py +65 -0
  112. data_designer_engine-0.4.0.dist-info/METADATA +50 -0
  113. data_designer_engine-0.4.0.dist-info/RECORD +114 -0
  114. data_designer_engine-0.4.0.dist-info/WHEEL +4 -0
data_designer/engine/dataset_builders/column_wise_builder.py
@@ -0,0 +1,354 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import functools
+import importlib.metadata
+import json
+import logging
+import time
+import uuid
+from pathlib import Path
+from typing import TYPE_CHECKING, Callable
+
+from data_designer.config.column_types import ColumnConfigT
+from data_designer.config.config_builder import BuilderConfig
+from data_designer.config.data_designer_config import DataDesignerConfig
+from data_designer.config.dataset_builders import BuildStage
+from data_designer.config.processors import (
+    DropColumnsProcessorConfig,
+    ProcessorConfig,
+    ProcessorType,
+)
+from data_designer.engine.column_generators.generators.base import (
+    ColumnGenerator,
+    ColumnGeneratorWithModel,
+    GenerationStrategy,
+)
+from data_designer.engine.column_generators.utils.generator_classification import column_type_is_model_generated
+from data_designer.engine.compiler import compile_data_designer_config
+from data_designer.engine.dataset_builders.artifact_storage import SDG_CONFIG_FILENAME, ArtifactStorage
+from data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError
+from data_designer.engine.dataset_builders.multi_column_configs import MultiColumnConfig
+from data_designer.engine.dataset_builders.utils.concurrency import ConcurrentThreadExecutor
+from data_designer.engine.dataset_builders.utils.config_compiler import compile_dataset_builder_column_configs
+from data_designer.engine.dataset_builders.utils.dataset_batch_manager import DatasetBatchManager
+from data_designer.engine.dataset_builders.utils.progress_tracker import ProgressTracker
+from data_designer.engine.models.telemetry import InferenceEvent, NemoSourceEnum, TaskStatusEnum, TelemetryHandler
+from data_designer.engine.processing.processors.base import Processor
+from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor
+from data_designer.engine.registry.data_designer_registry import DataDesignerRegistry
+from data_designer.engine.resources.resource_provider import ResourceProvider
+from data_designer.lazy_heavy_imports import pd
+
+if TYPE_CHECKING:
+    import pandas as pd
+
+    from data_designer.engine.column_generators.generators.base import ColumnGeneratorWithModelRegistry
+    from data_designer.engine.models.usage import ModelUsageStats
+
+logger = logging.getLogger(__name__)
+
+_CLIENT_VERSION: str = importlib.metadata.version("data-designer-engine")
+
+
+class ColumnWiseDatasetBuilder:
+    def __init__(
+        self,
+        data_designer_config: DataDesignerConfig,
+        resource_provider: ResourceProvider,
+        registry: DataDesignerRegistry | None = None,
+    ):
+        self.batch_manager = DatasetBatchManager(resource_provider.artifact_storage)
+        self._resource_provider = resource_provider
+        self._records_to_drop: set[int] = set()
+        self._registry = registry or DataDesignerRegistry()
+
+        self._data_designer_config = compile_data_designer_config(data_designer_config, resource_provider)
+        self._column_configs = compile_dataset_builder_column_configs(self._data_designer_config)
+        self._processors: dict[BuildStage, list[Processor]] = self._initialize_processors(
+            self._data_designer_config.processors or []
+        )
+        self._validate_column_configs()
+
+    @property
+    def artifact_storage(self) -> ArtifactStorage:
+        return self._resource_provider.artifact_storage
+
+    @functools.cached_property
+    def single_column_configs(self) -> list[ColumnConfigT]:
+        configs = []
+        for config in self._column_configs:
+            if isinstance(config, MultiColumnConfig):
+                configs.extend(config.columns)
+            else:
+                configs.append(config)
+        return configs
+
+    @functools.cached_property
+    def llm_generated_column_configs(self) -> list[ColumnConfigT]:
+        return [config for config in self.single_column_configs if column_type_is_model_generated(config.column_type)]
+
+    def build(
+        self,
+        *,
+        num_records: int,
+        on_batch_complete: Callable[[Path], None] | None = None,
+    ) -> Path:
+        self._run_model_health_check_if_needed()
+        self._write_builder_config()
+        generators = self._initialize_generators()
+        start_time = time.perf_counter()
+        group_id = uuid.uuid4().hex
+
+        buffer_size = self._resource_provider.run_config.buffer_size
+        self.batch_manager.start(num_records=num_records, buffer_size=buffer_size)
+        for batch_idx in range(self.batch_manager.num_batches):
+            logger.info(f"⏳ Processing batch {batch_idx + 1} of {self.batch_manager.num_batches}")
+            self._run_batch(generators, batch_mode="batch", group_id=group_id)
+            df_batch = self._run_processors(
+                stage=BuildStage.POST_BATCH,
+                dataframe=self.batch_manager.get_current_batch(as_dataframe=True),
+                current_batch_number=batch_idx,
+            )
+            self._write_processed_batch(df_batch)
+            self.batch_manager.finish_batch(on_batch_complete)
+        self.batch_manager.finish()
+
+        model_usage_stats = self._resource_provider.model_registry.get_model_usage_stats(
+            time.perf_counter() - start_time
+        )
+        logger.info(f"📊 Model usage summary:\n{json.dumps(model_usage_stats, indent=4)}")
+
+        return self.artifact_storage.final_dataset_path
+
+    def build_preview(self, *, num_records: int) -> pd.DataFrame:
+        self._run_model_health_check_if_needed()
+
+        generators = self._initialize_generators()
+        group_id = uuid.uuid4().hex
+        start_time = time.perf_counter()
+        self.batch_manager.start(num_records=num_records, buffer_size=num_records)
+        self._run_batch(generators, batch_mode="preview", save_partial_results=False, group_id=group_id)
+        dataset = self.batch_manager.get_current_batch(as_dataframe=True)
+        self.batch_manager.reset()
+
+        model_usage_stats = self._resource_provider.model_registry.get_model_usage_stats(
+            time.perf_counter() - start_time
+        )
+        logger.info(f"📊 Model usage summary:\n{json.dumps(model_usage_stats, indent=4)}")
+
+        return dataset
+
+    def process_preview(self, dataset: pd.DataFrame) -> pd.DataFrame:
+        return self._run_processors(
+            stage=BuildStage.POST_BATCH,
+            dataframe=dataset.copy(),
+            current_batch_number=None,  # preview mode does not have a batch number
+        )
+
+    def _initialize_generators(self) -> list[ColumnGenerator]:
+        return [
+            self._registry.column_generators.get_for_config_type(type(config))(
+                config=config, resource_provider=self._resource_provider
+            )
+            for config in self._column_configs
+        ]
+
+    def _write_builder_config(self) -> None:
+        self.artifact_storage.mkdir_if_needed(self.artifact_storage.base_dataset_path)
+        BuilderConfig(data_designer=self._data_designer_config).to_json(
+            self.artifact_storage.base_dataset_path / SDG_CONFIG_FILENAME
+        )
+
+    def _run_batch(
+        self, generators: list[ColumnGenerator], *, batch_mode: str, save_partial_results: bool = True, group_id: str
+    ) -> None:
+        pre_batch_snapshot = self._resource_provider.model_registry.get_model_usage_snapshot()
+        for generator in generators:
+            generator.log_pre_generation()
+            try:
+                generation_strategy = generator.get_generation_strategy()
+                if generator.can_generate_from_scratch and self.batch_manager.buffer_is_empty:
+                    self._run_from_scratch_column_generator(generator)
+                elif generation_strategy == GenerationStrategy.CELL_BY_CELL:
+                    self._run_cell_by_cell_generator(generator)
+                elif generation_strategy == GenerationStrategy.FULL_COLUMN:
+                    self._run_full_column_generator(generator)
+                else:
+                    logger.error(f"❌ Unknown generation strategy: {generation_strategy}")
+                    raise DatasetGenerationError(f"🛑 Unknown generation strategy: {generation_strategy}")
+                if save_partial_results:
+                    self.batch_manager.write()
+            except Exception as e:
+                column_error_str = (
+                    f"columns {generator.config.column_names}"
+                    if hasattr(generator.config, "column_names")
+                    else f"column {generator.config.name!r}"
+                )
+                raise DatasetGenerationError(f"🛑 Failed to process {column_error_str}:\n{e}") from e
+
+        try:
+            usage_deltas = self._resource_provider.model_registry.get_usage_deltas(pre_batch_snapshot)
+            self._emit_batch_inference_events(batch_mode, usage_deltas, group_id)
+        except Exception:
+            pass  # telemetry is best-effort and must never fail the batch
+
+    def _run_from_scratch_column_generator(self, generator: ColumnGenerator) -> None:
+        df = generator.generate_from_scratch(self.batch_manager.num_records_batch)
+        self.batch_manager.add_records(df.to_dict(orient="records"))
+
+    def _run_cell_by_cell_generator(self, generator: ColumnGenerator) -> None:
+        max_workers = self._resource_provider.run_config.non_inference_max_parallel_workers
+        if isinstance(generator, ColumnGeneratorWithModel):
+            max_workers = generator.inference_parameters.max_parallel_requests
+        self._fan_out_with_threads(generator, max_workers=max_workers)
+
+    def _run_full_column_generator(self, generator: ColumnGenerator) -> None:
+        df = generator.generate(self.batch_manager.get_current_batch(as_dataframe=True))
+        self.batch_manager.update_records(df.to_dict(orient="records"))
+
+    def _run_model_health_check_if_needed(self) -> None:
+        if any(column_type_is_model_generated(config.column_type) for config in self.single_column_configs):
+            self._resource_provider.model_registry.run_health_check(
+                list({config.model_alias for config in self.llm_generated_column_configs})
+            )
+
+    def _fan_out_with_threads(self, generator: ColumnGeneratorWithModelRegistry, max_workers: int) -> None:
+        if generator.get_generation_strategy() != GenerationStrategy.CELL_BY_CELL:
+            raise DatasetGenerationError(
+                f"Generator {generator.name} is not a {GenerationStrategy.CELL_BY_CELL} "
+                "generator, so concurrency through threads is not supported."
+            )
+
+        progress_tracker = ProgressTracker(
+            total_records=self.batch_manager.num_records_batch,
+            label=f"{generator.config.column_type} column '{generator.config.name}'",
+        )
+        progress_tracker.log_start(max_workers)
+
+        settings = self._resource_provider.run_config
+        with ConcurrentThreadExecutor(
+            max_workers=max_workers,
+            column_name=generator.config.name,
+            result_callback=self._make_result_callback(progress_tracker),
+            error_callback=self._make_error_callback(progress_tracker),
+            shutdown_error_rate=settings.shutdown_error_rate,
+            shutdown_error_window=settings.shutdown_error_window,
+            disable_early_shutdown=settings.disable_early_shutdown,
+        ) as executor:
+            for i, record in self.batch_manager.iter_current_batch():
+                executor.submit(lambda record: generator.generate(record), record, context={"index": i})
+
+        progress_tracker.log_final()
+
+        if len(self._records_to_drop) > 0:
+            self.batch_manager.drop_records(self._records_to_drop)
+            self._records_to_drop.clear()
+
+    def _make_result_callback(self, progress_tracker: ProgressTracker) -> Callable[[dict], None]:
+        def callback(result: dict, *, context: dict | None = None) -> None:
+            self._worker_result_callback(result, context=context)
+            progress_tracker.record_success()
+
+        return callback
+
+    def _make_error_callback(self, progress_tracker: ProgressTracker) -> Callable[[Exception], None]:
+        def callback(exc: Exception, *, context: dict | None = None) -> None:
+            self._worker_error_callback(exc, context=context)
+            progress_tracker.record_failure()
+
+        return callback
+
+    def _write_processed_batch(self, dataframe: pd.DataFrame) -> None:
+        self.batch_manager.update_records(dataframe.to_dict(orient="records"))
+        self.batch_manager.write()
+
+    def _validate_column_configs(self) -> None:
+        if len(self._column_configs) == 0:
+            raise DatasetGenerationError("🛑 No column configs provided.")
+
+        if not self._registry.column_generators.get_for_config_type(
+            type(self._column_configs[0])
+        ).can_generate_from_scratch:
+            raise DatasetGenerationError("🛑 The first column config must be a from-scratch column generator.")
+
+    def _initialize_processors(self, processor_configs: list[ProcessorConfig]) -> dict[BuildStage, list[Processor]]:
+        # Collect the columns marked for drop in their column configs
+        columns_to_drop = [config.name for config in self.single_column_configs if config.drop]
+
+        processors: dict[BuildStage, list[Processor]] = {stage: [] for stage in BuildStage}
+        for config in processor_configs:
+            processors[config.build_stage].append(
+                self._registry.processors.get_for_config_type(type(config))(
+                    config=config,
+                    resource_provider=self._resource_provider,
+                )
+            )

+            # A manually configured "drop columns" processor takes precedence
+            # (it can, e.g., pick stages other than post-batch).
+            if config.processor_type == ProcessorType.DROP_COLUMNS:
+                for column in config.column_names:
+                    if column in columns_to_drop:
+                        columns_to_drop.remove(column)
+
+        # If there are still columns marked for drop, add a "drop columns" processor to drop them
+        if len(columns_to_drop) > 0:
+            processors[BuildStage.POST_BATCH].append(  # post-batch by default
+                DropColumnsProcessor(
+                    config=DropColumnsProcessorConfig(
+                        name="default_drop_columns_processor",
+                        column_names=columns_to_drop,
+                        build_stage=BuildStage.POST_BATCH,
+                    ),
+                    resource_provider=self._resource_provider,
+                )
+            )
+
+        return processors
+
+    def _run_processors(
+        self, stage: BuildStage, dataframe: pd.DataFrame, current_batch_number: int | None = None
+    ) -> pd.DataFrame:
+        for processor in self._processors[stage]:
+            try:
+                dataframe = processor.process(dataframe, current_batch_number=current_batch_number)
+            except Exception as e:
+                raise DatasetProcessingError(
+                    f"🛑 Failed to process dataset with processor {processor.name} in stage {stage}: {e}"
+                ) from e
+        return dataframe
+
+    def _worker_error_callback(self, exc: Exception, *, context: dict | None = None) -> None:
+        """If a worker fails, log the failure and mark the record for removal."""
+        logger.warning(
+            f"⚠️ Generation for record at index {context['index']} failed. "
+            f"Will omit this record from the dataset.\n{exc}"
+        )
+        self._records_to_drop.add(context["index"])
+
+    def _worker_result_callback(self, result: dict, *, context: dict | None = None) -> None:
+        self.batch_manager.update_record(context["index"], result)
+
+    def _emit_batch_inference_events(
+        self, batch_mode: str, usage_deltas: dict[str, ModelUsageStats], group_id: str
+    ) -> None:
+        if not usage_deltas:
+            return
+
+        events = [
+            InferenceEvent(
+                nemo_source=NemoSourceEnum.DATADESIGNER,
+                task=batch_mode,
+                task_status=TaskStatusEnum.SUCCESS,
+                model=model_name,
+                input_tokens=delta.token_usage.input_tokens,
+                output_tokens=delta.token_usage.output_tokens,
+            )
+            for model_name, delta in usage_deltas.items()
+        ]
+
+        with TelemetryHandler(source_client_version=_CLIENT_VERSION, session_id=group_id) as telemetry_handler:
+            for event in events:
+                telemetry_handler.enqueue(event)
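For orientation, here is a minimal usage sketch of the builder above. It is illustrative only, not part of the package, and it assumes a DataDesignerConfig and a fully configured ResourceProvider have already been constructed, which this diff does not show how to do:

# Hypothetical driver for ColumnWiseDatasetBuilder; `config` and `resources`
# are assumed to be built elsewhere.
from pathlib import Path

from data_designer.config.data_designer_config import DataDesignerConfig
from data_designer.engine.dataset_builders.column_wise_builder import ColumnWiseDatasetBuilder
from data_designer.engine.resources.resource_provider import ResourceProvider


def generate(config: DataDesignerConfig, resources: ResourceProvider) -> Path:
    builder = ColumnWiseDatasetBuilder(config, resources)

    def on_batch_complete(batch_path: Path) -> None:
        # Called by build() after each batch is flushed to artifact storage.
        print(f"Finished batch: {batch_path}")

    # Quick in-memory preview first (returns a DataFrame; nothing is persisted)...
    preview_df = builder.build_preview(num_records=10)
    print(builder.process_preview(preview_df).head())

    # ...then the full batched run, which returns the final dataset path.
    return builder.build(num_records=1_000, on_batch_complete=on_batch_complete)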
data_designer/engine/dataset_builders/errors.py
@@ -0,0 +1,15 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from data_designer.engine.errors import DataDesignerError
+
+
+class ArtifactStorageError(DataDesignerError): ...
+
+
+class DatasetGenerationError(DataDesignerError): ...
+
+
+class DatasetProcessingError(DataDesignerError): ...
data_designer/engine/dataset_builders/multi_column_configs.py
@@ -0,0 +1,46 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from abc import ABC
+from typing import TypeAlias
+
+from pydantic import Field, field_validator
+
+from data_designer.config.base import ConfigBase
+from data_designer.config.column_configs import SamplerColumnConfig, SeedDatasetColumnConfig, SingleColumnConfig
+from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType
+from data_designer.config.sampler_constraints import ColumnConstraintT
+from data_designer.config.seed import SeedConfig
+
+
+class MultiColumnConfig(ConfigBase, ABC):
+    columns: list[SingleColumnConfig] = Field(..., min_length=1)
+
+    @property
+    def column_names(self) -> list[str]:
+        return [c.name for c in self.columns]
+
+    @property
+    def column_type(self) -> DataDesignerColumnType:
+        return self.columns[0].column_type
+
+    @field_validator("columns", mode="after")
+    def validate_column_types_are_the_same(cls, v: list[SingleColumnConfig]) -> list[SingleColumnConfig]:
+        if len({c.column_type for c in v}) != 1:
+            raise ValueError("All columns must have the same column type")
+        return v
+
+
+class SamplerMultiColumnConfig(MultiColumnConfig):
+    columns: list[SamplerColumnConfig]
+    constraints: list[ColumnConstraintT] = []
+    max_rejections_factor: int = 5
+
+
+class SeedDatasetMultiColumnConfig(SeedConfig, MultiColumnConfig):
+    columns: list[SeedDatasetColumnConfig]
+
+
+DatasetBuilderColumnConfigT: TypeAlias = ColumnConfigT | SeedDatasetMultiColumnConfig | SamplerMultiColumnConfig
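The validate_column_types_are_the_same check above is a standard pydantic field-validator pattern. The following self-contained sketch uses stand-in models (not the package's config classes, which carry more fields) to show the behavior it enforces:

# Stand-in models for illustration; the validator pattern matches the one above.
from pydantic import BaseModel, Field, ValidationError, field_validator


class Column(BaseModel):
    name: str
    column_type: str


class MultiColumn(BaseModel):
    columns: list[Column] = Field(..., min_length=1)

    @field_validator("columns", mode="after")
    def validate_column_types_are_the_same(cls, v: list[Column]) -> list[Column]:
        if len({c.column_type for c in v}) != 1:
            raise ValueError("All columns must have the same column type")
        return v


MultiColumn(columns=[Column(name="a", column_type="sampler"),
                     Column(name="b", column_type="sampler")])  # OK

try:
    MultiColumn(columns=[Column(name="a", column_type="sampler"),
                         Column(name="b", column_type="expression")])
except ValidationError as e:
    print(e)  # mixed column types are rejected at construction time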
data_designer/engine/dataset_builders/utils/__init__.py
@@ -0,0 +1,2 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
data_designer/engine/dataset_builders/utils/concurrency.py
@@ -0,0 +1,212 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import contextvars
+import json
+import logging
+from concurrent.futures import Future, ThreadPoolExecutor
+from threading import Lock, Semaphore
+from typing import Any, Protocol
+
+from pydantic import BaseModel, Field
+
+from data_designer.engine.errors import DataDesignerRuntimeError, ErrorTrap
+
+logger = logging.getLogger(__name__)
+
+
+class ExecutorResults(BaseModel):
+    failure_threshold: float = 0.0  # Error rate threshold
+    completed_count: int = 0  # How many tasks/jobs completed
+    success_count: int = 0  # How many tasks/jobs were successful
+    early_shutdown: bool = False  # Did we shut down early due to errors?
+    error_trap: ErrorTrap = Field(default_factory=ErrorTrap)
+
+    @property
+    def summary(self) -> dict:
+        summary = self.model_dump(exclude={"error_trap"})
+        summary |= self.error_trap.model_dump()
+        return summary
+
+    def get_error_rate(self, window: int) -> float:
+        # We don't actually start tracking until the minimum window size is met
+        if self.completed_count < window:
+            return 0.0
+        return self.error_trap.error_count / max(1, self.completed_count)
+
+    def is_error_rate_exceeded(self, window: int) -> bool:
+        return self.get_error_rate(window) >= self.failure_threshold
+
+
+class CallbackWithContext(Protocol):
+    """Executor callback functions must accept a `context` keyword argument."""
+
+    def __call__(self, result: Any, *, context: dict | None = None) -> Any: ...
+
+
+class ErrorCallbackWithContext(Protocol):
+    """Error callbacks take the Exception instance and context."""
+
+    def __call__(self, exc: Exception, *, context: dict | None = None) -> Any: ...
+
+
+class ConcurrentThreadExecutor:
+    """Interface for executing multiple concurrent tasks with error rate monitoring.
+
+    This interface should be used exclusively as a context manager. New tasks are
+    submitted with the `submit` method, which behaves much like the `submit` method
+    of a `ThreadPoolExecutor`.
+
+    The underlying task queue is bounded by the `max_workers` parameter, so at most
+    `max_workers` tasks can be queued for execution at a time. As tasks complete,
+    any errors are tracked and counted. If the configured error rate is exceeded,
+    the executor shuts down early; all queued and running tasks still run to
+    completion.
+
+    The task queue is bounded so that, once the error threshold is met, there is
+    not an unbounded number of tasks left to drain. Since the queue size equals
+    `max_workers`, tasks generally do not sit in the queue for long. The side
+    effect is that `submit()` may block, which is acceptable because upstream
+    tasks must wait for all jobs to finish before they can be considered complete.
+
+    ContextVars from the main parent thread are automatically propagated to all
+    child threads.
+
+    When a task completes, the user-provided `result_callback` function is called
+    with the task result as the only positional argument.
+    """
+
+    def __init__(
+        self,
+        *,
+        max_workers: int,
+        column_name: str,
+        result_callback: CallbackWithContext | None = None,
+        error_callback: ErrorCallbackWithContext | None = None,
+        shutdown_error_rate: float = 0.50,
+        shutdown_error_window: int = 10,
+        disable_early_shutdown: bool = False,
+    ):
+        self._executor = None
+        self._column_name = column_name
+        self._max_workers = max_workers
+        self._lock = Lock()
+        self._semaphore = Semaphore(self._max_workers)
+        self._result_callback = result_callback
+        self._error_callback = error_callback
+        self._shutdown_error_rate = shutdown_error_rate
+        self._shutdown_window_size = shutdown_error_window
+        self._disable_early_shutdown = disable_early_shutdown
+        self._results = ExecutorResults(failure_threshold=shutdown_error_rate)
+
+    @property
+    def results(self) -> ExecutorResults:
+        return self._results
+
+    @property
+    def max_workers(self) -> int:
+        return self._max_workers
+
+    @property
+    def shutdown_error_rate(self) -> float:
+        return self._shutdown_error_rate
+
+    @property
+    def shutdown_window_size(self) -> int:
+        return self._shutdown_window_size
+
+    @property
+    def semaphore(self) -> Semaphore:
+        return self._semaphore
+
+    def __enter__(self) -> ConcurrentThreadExecutor:
+        self._executor = ThreadPoolExecutor(
+            max_workers=self._max_workers,
+            thread_name_prefix="ConcurrentThreadExecutor",
+            initializer=_set_worker_contextvars,
+            initargs=(contextvars.copy_context(),),
+        )
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self._shutdown_executor()
+        if not self._disable_early_shutdown and self._results.early_shutdown is True:
+            self._raise_task_error()
+
+    def _shutdown_executor(self) -> None:
+        if self._executor is not None:
+            self._executor.shutdown()
+
+    def _raise_task_error(self):
+        raise DataDesignerRuntimeError(
+            "\n".join(
+                [
+                    " |-- Data generation was terminated early because the error rate exceeded the threshold.",
+                    f" |-- The summary of encountered errors is: \n{json.dumps(self._results.summary, indent=4)}",
+                ]
+            )
+        )
+
+    def submit(self, fn, *args, context: dict | None = None, **kwargs) -> None:
+        if self._executor is None:
+            raise RuntimeError("Executor is not initialized; this class should be used as a context manager.")
+
+        if not self._disable_early_shutdown and self._results.early_shutdown:
+            self._shutdown_executor()
+            self._raise_task_error()
+
+        def _handle_future(future: Future) -> None:
+            try:
+                result = future.result()
+                if self._result_callback is not None:
+                    self._result_callback(result, context=context)
+                with self._lock:
+                    self._results.completed_count += 1
+                    self._results.success_count += 1
+            except Exception as err:
+                with self._lock:
+                    self._results.completed_count += 1
+                    self._results.error_trap.handle_error(err)
+                    if not self._disable_early_shutdown and self._results.is_error_rate_exceeded(
+                        self._shutdown_window_size
+                    ):
+                        # Signal to shut down early on the next submission (if received).
+                        # We cannot trigger shutdown from within this thread, as doing so
+                        # can cause a deadlock.
+                        if not self._results.early_shutdown:
+                            self._results.early_shutdown = True
+                if self._error_callback is not None:
+                    self._error_callback(err, context=context)
+            finally:
+                self._semaphore.release()
+
+        try:
+            self._semaphore.acquire()
+            future = self._executor.submit(fn, *args, **kwargs)
+            future.add_done_callback(_handle_future)
+        except Exception as err:
+            # If we get here, the pool is shutting down (likely due to early termination
+            # from errors). Re-raise a custom error that can be handled at the call site;
+            # the summary can also be inspected.
+            self._semaphore.release()
+            is_shutdown_error = isinstance(err, RuntimeError) and (
+                "after shutdown" in str(err) or "Pool shutdown" in str(err)
+            )
+            if not is_shutdown_error:
+                raise err
+            if self._disable_early_shutdown:
+                raise err
+            self._raise_task_error()
+
+
+def _set_worker_contextvars(context: contextvars.Context):
+    for var, value in context.items():
+        var.set(value)
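A minimal usage sketch of the executor above (illustrative only; the work function and callbacks are made up here, not part of the package): submissions block once max_workers tasks are in flight, callbacks fire as futures complete, and the aggregate counts land in results.summary.

# Hypothetical example exercising ConcurrentThreadExecutor; assumes the
# package above is installed and importable.
from data_designer.engine.dataset_builders.utils.concurrency import ConcurrentThreadExecutor


def flaky_square(x: int) -> int:
    if x % 7 == 0:  # simulate an occasional failure (~15% error rate)
        raise ValueError(f"cannot handle {x}")
    return x * x


results: dict[int, int] = {}


def on_result(result, *, context=None):
    results[context["index"]] = result


def on_error(exc, *, context=None):
    print(f"index {context['index']} failed: {exc}")


with ConcurrentThreadExecutor(
    max_workers=8,
    column_name="squares",
    result_callback=on_result,
    error_callback=on_error,
    shutdown_error_rate=0.50,  # shut down early if >= 50% of completions fail...
    shutdown_error_window=10,  # ...once at least 10 tasks have completed
) as executor:
    for i in range(100):
        # submit() blocks while max_workers tasks are already in flight
        executor.submit(flaky_square, i, context={"index": i})

print(executor.results.summary)  # completed/success counts plus trapped errors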