data-designer 0.3.3__py3-none-any.whl → 0.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (176) hide show
  1. data_designer/__init__.py +2 -0
  2. data_designer/_version.py +2 -2
  3. data_designer/cli/__init__.py +2 -0
  4. data_designer/cli/commands/download.py +2 -0
  5. data_designer/cli/commands/list.py +2 -0
  6. data_designer/cli/commands/models.py +2 -0
  7. data_designer/cli/commands/providers.py +2 -0
  8. data_designer/cli/commands/reset.py +2 -0
  9. data_designer/cli/controllers/__init__.py +2 -0
  10. data_designer/cli/controllers/download_controller.py +2 -0
  11. data_designer/cli/controllers/model_controller.py +6 -1
  12. data_designer/cli/controllers/provider_controller.py +6 -1
  13. data_designer/cli/forms/__init__.py +2 -0
  14. data_designer/cli/forms/builder.py +2 -0
  15. data_designer/cli/forms/field.py +2 -0
  16. data_designer/cli/forms/form.py +2 -0
  17. data_designer/cli/forms/model_builder.py +2 -0
  18. data_designer/cli/forms/provider_builder.py +2 -0
  19. data_designer/cli/main.py +2 -0
  20. data_designer/cli/repositories/__init__.py +2 -0
  21. data_designer/cli/repositories/base.py +2 -0
  22. data_designer/cli/repositories/model_repository.py +2 -0
  23. data_designer/cli/repositories/persona_repository.py +2 -0
  24. data_designer/cli/repositories/provider_repository.py +2 -0
  25. data_designer/cli/services/__init__.py +2 -0
  26. data_designer/cli/services/download_service.py +2 -0
  27. data_designer/cli/services/model_service.py +2 -0
  28. data_designer/cli/services/provider_service.py +2 -0
  29. data_designer/cli/ui.py +2 -0
  30. data_designer/cli/utils.py +2 -0
  31. data_designer/config/analysis/column_profilers.py +2 -0
  32. data_designer/config/analysis/column_statistics.py +8 -5
  33. data_designer/config/analysis/dataset_profiler.py +9 -3
  34. data_designer/config/analysis/utils/errors.py +2 -0
  35. data_designer/config/analysis/utils/reporting.py +7 -3
  36. data_designer/config/base.py +1 -0
  37. data_designer/config/column_configs.py +77 -7
  38. data_designer/config/column_types.py +33 -36
  39. data_designer/config/dataset_builders.py +2 -0
  40. data_designer/config/dataset_metadata.py +18 -0
  41. data_designer/config/default_model_settings.py +1 -0
  42. data_designer/config/errors.py +2 -0
  43. data_designer/config/exports.py +2 -0
  44. data_designer/config/interface.py +3 -2
  45. data_designer/config/models.py +7 -2
  46. data_designer/config/preview_results.py +9 -1
  47. data_designer/config/processors.py +2 -0
  48. data_designer/config/run_config.py +19 -5
  49. data_designer/config/sampler_constraints.py +2 -0
  50. data_designer/config/sampler_params.py +7 -2
  51. data_designer/config/seed.py +2 -0
  52. data_designer/config/seed_source.py +9 -3
  53. data_designer/config/seed_source_types.py +2 -0
  54. data_designer/config/utils/constants.py +2 -0
  55. data_designer/config/utils/errors.py +2 -0
  56. data_designer/config/utils/info.py +2 -0
  57. data_designer/config/utils/io_helpers.py +8 -3
  58. data_designer/config/utils/misc.py +2 -2
  59. data_designer/config/utils/numerical_helpers.py +2 -0
  60. data_designer/config/utils/type_helpers.py +2 -0
  61. data_designer/config/utils/visualization.py +19 -11
  62. data_designer/config/validator_params.py +2 -0
  63. data_designer/engine/analysis/column_profilers/base.py +9 -8
  64. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +15 -19
  65. data_designer/engine/analysis/column_profilers/registry.py +2 -0
  66. data_designer/engine/analysis/column_statistics.py +5 -2
  67. data_designer/engine/analysis/dataset_profiler.py +12 -9
  68. data_designer/engine/analysis/errors.py +2 -0
  69. data_designer/engine/analysis/utils/column_statistics_calculations.py +7 -4
  70. data_designer/engine/analysis/utils/judge_score_processing.py +7 -3
  71. data_designer/engine/column_generators/generators/base.py +26 -14
  72. data_designer/engine/column_generators/generators/embedding.py +4 -11
  73. data_designer/engine/column_generators/generators/expression.py +7 -16
  74. data_designer/engine/column_generators/generators/llm_completion.py +13 -47
  75. data_designer/engine/column_generators/generators/samplers.py +8 -14
  76. data_designer/engine/column_generators/generators/seed_dataset.py +9 -15
  77. data_designer/engine/column_generators/generators/validation.py +9 -20
  78. data_designer/engine/column_generators/registry.py +2 -0
  79. data_designer/engine/column_generators/utils/errors.py +2 -0
  80. data_designer/engine/column_generators/utils/generator_classification.py +2 -0
  81. data_designer/engine/column_generators/utils/judge_score_factory.py +2 -0
  82. data_designer/engine/column_generators/utils/prompt_renderer.py +4 -2
  83. data_designer/engine/compiler.py +3 -6
  84. data_designer/engine/configurable_task.py +12 -13
  85. data_designer/engine/dataset_builders/artifact_storage.py +87 -8
  86. data_designer/engine/dataset_builders/column_wise_builder.py +34 -35
  87. data_designer/engine/dataset_builders/errors.py +2 -0
  88. data_designer/engine/dataset_builders/multi_column_configs.py +2 -0
  89. data_designer/engine/dataset_builders/utils/concurrency.py +13 -4
  90. data_designer/engine/dataset_builders/utils/config_compiler.py +2 -0
  91. data_designer/engine/dataset_builders/utils/dag.py +7 -2
  92. data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +35 -25
  93. data_designer/engine/dataset_builders/utils/errors.py +2 -0
  94. data_designer/engine/errors.py +2 -0
  95. data_designer/engine/model_provider.py +2 -0
  96. data_designer/engine/models/errors.py +23 -31
  97. data_designer/engine/models/facade.py +12 -9
  98. data_designer/engine/models/factory.py +42 -0
  99. data_designer/engine/models/litellm_overrides.py +16 -11
  100. data_designer/engine/models/parsers/errors.py +2 -0
  101. data_designer/engine/models/parsers/parser.py +2 -2
  102. data_designer/engine/models/parsers/postprocessors.py +1 -0
  103. data_designer/engine/models/parsers/tag_parsers.py +2 -0
  104. data_designer/engine/models/parsers/types.py +2 -0
  105. data_designer/engine/models/recipes/base.py +2 -0
  106. data_designer/engine/models/recipes/response_recipes.py +2 -0
  107. data_designer/engine/models/registry.py +11 -18
  108. data_designer/engine/models/telemetry.py +6 -2
  109. data_designer/engine/processing/ginja/ast.py +2 -0
  110. data_designer/engine/processing/ginja/environment.py +2 -0
  111. data_designer/engine/processing/ginja/exceptions.py +2 -0
  112. data_designer/engine/processing/ginja/record.py +2 -0
  113. data_designer/engine/processing/gsonschema/exceptions.py +9 -2
  114. data_designer/engine/processing/gsonschema/schema_transformers.py +2 -0
  115. data_designer/engine/processing/gsonschema/types.py +2 -0
  116. data_designer/engine/processing/gsonschema/validators.py +10 -6
  117. data_designer/engine/processing/processors/base.py +1 -5
  118. data_designer/engine/processing/processors/drop_columns.py +7 -10
  119. data_designer/engine/processing/processors/registry.py +2 -0
  120. data_designer/engine/processing/processors/schema_transform.py +7 -10
  121. data_designer/engine/processing/utils.py +7 -3
  122. data_designer/engine/registry/base.py +2 -0
  123. data_designer/engine/registry/data_designer_registry.py +2 -0
  124. data_designer/engine/registry/errors.py +2 -0
  125. data_designer/engine/resources/managed_dataset_generator.py +6 -2
  126. data_designer/engine/resources/managed_dataset_repository.py +8 -5
  127. data_designer/engine/resources/managed_storage.py +2 -0
  128. data_designer/engine/resources/resource_provider.py +20 -1
  129. data_designer/engine/resources/seed_reader.py +7 -2
  130. data_designer/engine/sampling_gen/column.py +2 -0
  131. data_designer/engine/sampling_gen/constraints.py +8 -2
  132. data_designer/engine/sampling_gen/data_sources/base.py +10 -7
  133. data_designer/engine/sampling_gen/data_sources/errors.py +2 -0
  134. data_designer/engine/sampling_gen/data_sources/sources.py +27 -22
  135. data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +2 -2
  136. data_designer/engine/sampling_gen/entities/email_address_utils.py +2 -0
  137. data_designer/engine/sampling_gen/entities/errors.py +2 -0
  138. data_designer/engine/sampling_gen/entities/national_id_utils.py +2 -0
  139. data_designer/engine/sampling_gen/entities/person.py +2 -0
  140. data_designer/engine/sampling_gen/entities/phone_number.py +8 -1
  141. data_designer/engine/sampling_gen/errors.py +2 -0
  142. data_designer/engine/sampling_gen/generator.py +5 -4
  143. data_designer/engine/sampling_gen/jinja_utils.py +7 -3
  144. data_designer/engine/sampling_gen/people_gen.py +7 -7
  145. data_designer/engine/sampling_gen/person_constants.py +2 -0
  146. data_designer/engine/sampling_gen/schema.py +5 -1
  147. data_designer/engine/sampling_gen/schema_builder.py +2 -0
  148. data_designer/engine/sampling_gen/utils.py +7 -1
  149. data_designer/engine/secret_resolver.py +2 -0
  150. data_designer/engine/validation.py +2 -2
  151. data_designer/engine/validators/__init__.py +2 -0
  152. data_designer/engine/validators/base.py +2 -0
  153. data_designer/engine/validators/local_callable.py +7 -2
  154. data_designer/engine/validators/python.py +7 -1
  155. data_designer/engine/validators/remote.py +7 -1
  156. data_designer/engine/validators/sql.py +8 -3
  157. data_designer/errors.py +2 -0
  158. data_designer/essentials/__init__.py +2 -0
  159. data_designer/interface/data_designer.py +36 -39
  160. data_designer/interface/errors.py +2 -0
  161. data_designer/interface/results.py +9 -2
  162. data_designer/lazy_heavy_imports.py +54 -0
  163. data_designer/logging.py +2 -0
  164. data_designer/plugins/__init__.py +2 -0
  165. data_designer/plugins/errors.py +2 -0
  166. data_designer/plugins/plugin.py +0 -1
  167. data_designer/plugins/registry.py +2 -0
  168. data_designer/plugins/testing/__init__.py +2 -0
  169. data_designer/plugins/testing/stubs.py +21 -43
  170. data_designer/plugins/testing/utils.py +2 -0
  171. {data_designer-0.3.3.dist-info → data_designer-0.3.5.dist-info}/METADATA +19 -4
  172. data_designer-0.3.5.dist-info/RECORD +196 -0
  173. data_designer-0.3.3.dist-info/RECORD +0 -193
  174. {data_designer-0.3.3.dist-info → data_designer-0.3.5.dist-info}/WHEEL +0 -0
  175. {data_designer-0.3.3.dist-info → data_designer-0.3.5.dist-info}/entry_points.txt +0 -0
  176. {data_designer-0.3.3.dist-info → data_designer-0.3.5.dist-info}/licenses/LICENSE +0 -0
@@ -1,6 +1,8 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ from __future__ import annotations
5
+
4
6
  import json
5
7
  import logging
6
8
 
@@ -8,7 +10,7 @@ from data_designer.config.column_configs import SingleColumnConfig
8
10
  from data_designer.config.column_types import DataDesignerColumnType
9
11
  from data_designer.config.models import ModelConfig
10
12
  from data_designer.config.utils.code_lang import CodeLang
11
- from data_designer.config.utils.misc import get_prompt_template_keywords
13
+ from data_designer.config.utils.misc import extract_keywords_from_jinja2_template
12
14
  from data_designer.config.utils.type_helpers import StrEnum
13
15
  from data_designer.engine.column_generators.utils.errors import PromptTemplateRenderError
14
16
  from data_designer.engine.column_generators.utils.judge_score_factory import (
@@ -56,7 +58,7 @@ class RecordBasedPromptRenderer(WithJinja2UserTemplateRendering):
56
58
  dataset_variables=list(record.keys()),
57
59
  )
58
60
  except (UserTemplateUnsupportedFiltersError, UserTemplateError) as exc:
59
- template_variables = get_prompt_template_keywords(prompt_template)
61
+ template_variables = extract_keywords_from_jinja2_template(prompt_template)
60
62
  missing_columns = list(set(template_variables) - set(record.keys()))
61
63
 
62
64
  error_msg = (
@@ -1,10 +1,11 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ from __future__ import annotations
5
+
4
6
  import logging
5
7
 
6
8
  from data_designer.config.column_configs import SeedDatasetColumnConfig
7
- from data_designer.config.config_builder import DataDesignerConfigBuilder
8
9
  from data_designer.config.data_designer_config import DataDesignerConfig
9
10
  from data_designer.config.errors import InvalidConfigError
10
11
  from data_designer.engine.resources.resource_provider import ResourceProvider
@@ -14,13 +15,9 @@ from data_designer.engine.validation import ViolationLevel, rich_print_violation
14
15
  logger = logging.getLogger(__name__)
15
16
 
16
17
 
17
- def compile_data_designer_config(
18
- config_builder: DataDesignerConfigBuilder, resource_provider: ResourceProvider
19
- ) -> DataDesignerConfig:
20
- config = config_builder.build()
18
+ def compile_data_designer_config(config: DataDesignerConfig, resource_provider: ResourceProvider) -> DataDesignerConfig:
21
19
  _resolve_and_add_seed_columns(config, resource_provider.seed_reader)
22
20
  _validate(config)
23
-
24
21
  return config
25
22
 
26
23
 
@@ -1,25 +1,24 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
- from abc import ABC, abstractmethod
5
- from pathlib import Path
6
- from typing import Generic, TypeVar, get_origin
4
+ from __future__ import annotations
7
5
 
8
- import pandas as pd
6
+ from abc import ABC
7
+ from pathlib import Path
8
+ from typing import TYPE_CHECKING, Generic, TypeVar, get_origin
9
9
 
10
10
  from data_designer.config.base import ConfigBase
11
11
  from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
12
12
  from data_designer.engine.resources.resource_provider import ResourceProvider
13
+ from data_designer.lazy_heavy_imports import pd
14
+
15
+ if TYPE_CHECKING:
16
+ import pandas as pd
13
17
 
14
18
  DataT = TypeVar("DataT", dict, pd.DataFrame)
15
19
  TaskConfigT = TypeVar("ConfigT", bound=ConfigBase)
16
20
 
17
21
 
18
- class ConfigurableTaskMetadata(ConfigBase):
19
- name: str
20
- description: str
21
-
22
-
23
22
  class ConfigurableTask(ABC, Generic[TaskConfigT]):
24
23
  def __init__(self, config: TaskConfigT, resource_provider: ResourceProvider):
25
24
  self._config = self.get_config_type().model_validate(config)
@@ -57,14 +56,14 @@ class ConfigurableTask(ABC, Generic[TaskConfigT]):
57
56
  def config(self) -> TaskConfigT:
58
57
  return self._config
59
58
 
59
+ @property
60
+ def name(self) -> str:
61
+ return self.__class__.__name__
62
+
60
63
  @property
61
64
  def resource_provider(self) -> ResourceProvider:
62
65
  return self._resource_provider
63
66
 
64
- @staticmethod
65
- @abstractmethod
66
- def metadata() -> ConfigurableTaskMetadata: ...
67
-
68
67
  def _initialize(self) -> None:
69
68
  """An internal method for custom initialization logic, which will be called in the constructor."""
70
69
 
@@ -1,23 +1,30 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ from __future__ import annotations
5
+
4
6
  import json
5
7
  import logging
6
8
  import shutil
7
9
  from datetime import datetime
8
10
  from functools import cached_property
9
11
  from pathlib import Path
12
+ from typing import TYPE_CHECKING
10
13
 
11
- import pandas as pd
12
14
  from pydantic import BaseModel, field_validator, model_validator
13
15
 
14
16
  from data_designer.config.utils.io_helpers import read_parquet_dataset
15
17
  from data_designer.config.utils.type_helpers import StrEnum, resolve_string_enum
16
18
  from data_designer.engine.dataset_builders.errors import ArtifactStorageError
19
+ from data_designer.lazy_heavy_imports import pd
20
+
21
+ if TYPE_CHECKING:
22
+ import pandas as pd
17
23
 
18
24
  logger = logging.getLogger(__name__)
19
25
 
20
26
  BATCH_FILE_NAME_FORMAT = "batch_{batch_number:05d}.parquet"
27
+ SDG_CONFIG_FILENAME = "sdg.json"
21
28
 
22
29
 
23
30
  class BatchStage(StrEnum):
@@ -164,12 +171,6 @@ class ArtifactStorage(BaseModel):
164
171
  shutil.move(partial_result_path, final_file_path)
165
172
  return final_file_path
166
173
 
167
- def write_configs(self, json_file_name: str, configs: list[dict]) -> Path:
168
- self.mkdir_if_needed(self.base_dataset_path)
169
- with open(self.base_dataset_path / json_file_name, "w") as file:
170
- json.dump([c.model_dump(mode="json") for c in configs], file, indent=4)
171
- return self.base_dataset_path / json_file_name
172
-
173
174
  def write_batch_to_parquet_file(
174
175
  self,
175
176
  batch_number: int,
@@ -194,11 +195,89 @@ class ArtifactStorage(BaseModel):
194
195
  dataframe.to_parquet(file_path, index=False)
195
196
  return file_path
196
197
 
198
+ def get_parquet_file_paths(self) -> list[str]:
199
+ """Get list of parquet file paths relative to base_dataset_path.
200
+
201
+ Returns:
202
+ List of relative paths to parquet files in the final dataset folder.
203
+ """
204
+ return [str(f.relative_to(self.base_dataset_path)) for f in sorted(self.final_dataset_path.glob("*.parquet"))]
205
+
206
+ def get_processor_file_paths(self) -> dict[str, list[str]]:
207
+ """Get processor output files organized by processor name.
208
+
209
+ Returns:
210
+ Dictionary mapping processor names to lists of relative file paths.
211
+ """
212
+ processor_files: dict[str, list[str]] = {}
213
+ if self.processors_outputs_path.exists():
214
+ for processor_dir in sorted(self.processors_outputs_path.iterdir()):
215
+ if processor_dir.is_dir():
216
+ processor_name = processor_dir.name
217
+ processor_files[processor_name] = [
218
+ str(f.relative_to(self.base_dataset_path))
219
+ for f in sorted(processor_dir.rglob("*"))
220
+ if f.is_file()
221
+ ]
222
+ return processor_files
223
+
224
+ def get_file_paths(self) -> dict[str, list[str] | dict[str, list[str]]]:
225
+ """Get all file paths organized by type.
226
+
227
+ Returns:
228
+ Dictionary with 'parquet-files' and 'processor-files' keys.
229
+ """
230
+ file_paths = {
231
+ "parquet-files": self.get_parquet_file_paths(),
232
+ }
233
+ processor_file_paths = self.get_processor_file_paths()
234
+ if processor_file_paths:
235
+ file_paths["processor-files"] = processor_file_paths
236
+
237
+ return file_paths
238
+
239
+ def read_metadata(self) -> dict:
240
+ """Read metadata from the metadata.json file.
241
+
242
+ Returns:
243
+ Dictionary containing the metadata.
244
+
245
+ Raises:
246
+ FileNotFoundError: If metadata file doesn't exist.
247
+ """
248
+ with open(self.metadata_file_path, "r") as file:
249
+ return json.load(file)
250
+
197
251
  def write_metadata(self, metadata: dict) -> Path:
252
+ """Write metadata to the metadata.json file.
253
+
254
+ Args:
255
+ metadata: Dictionary containing metadata to write.
256
+
257
+ Returns:
258
+ Path to the written metadata file.
259
+ """
198
260
  self.mkdir_if_needed(self.base_dataset_path)
199
261
  with open(self.metadata_file_path, "w") as file:
200
- json.dump(metadata, file)
262
+ json.dump(metadata, file, indent=4, sort_keys=True)
201
263
  return self.metadata_file_path
202
264
 
265
+ def update_metadata(self, updates: dict) -> Path:
266
+ """Update existing metadata with new fields.
267
+
268
+ Args:
269
+ updates: Dictionary of fields to add/update in metadata.
270
+
271
+ Returns:
272
+ Path to the updated metadata file.
273
+ """
274
+ try:
275
+ existing_metadata = self.read_metadata()
276
+ except FileNotFoundError:
277
+ existing_metadata = {}
278
+
279
+ existing_metadata.update(updates)
280
+ return self.write_metadata(existing_metadata)
281
+
203
282
  def _get_stage_path(self, stage: BatchStage) -> Path:
204
283
  return getattr(self, resolve_string_enum(stage, BatchStage).value)
@@ -12,9 +12,9 @@ import uuid
12
12
  from pathlib import Path
13
13
  from typing import TYPE_CHECKING, Callable
14
14
 
15
- import pandas as pd
16
-
17
15
  from data_designer.config.column_types import ColumnConfigT
16
+ from data_designer.config.config_builder import BuilderConfig
17
+ from data_designer.config.data_designer_config import DataDesignerConfig
18
18
  from data_designer.config.dataset_builders import BuildStage
19
19
  from data_designer.config.processors import (
20
20
  DropColumnsProcessorConfig,
@@ -27,40 +27,38 @@ from data_designer.engine.column_generators.generators.base import (
27
27
  GenerationStrategy,
28
28
  )
29
29
  from data_designer.engine.column_generators.utils.generator_classification import column_type_is_model_generated
30
- from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
30
+ from data_designer.engine.compiler import compile_data_designer_config
31
+ from data_designer.engine.dataset_builders.artifact_storage import SDG_CONFIG_FILENAME, ArtifactStorage
31
32
  from data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError
32
- from data_designer.engine.dataset_builders.multi_column_configs import (
33
- DatasetBuilderColumnConfigT,
34
- MultiColumnConfig,
35
- )
33
+ from data_designer.engine.dataset_builders.multi_column_configs import MultiColumnConfig
36
34
  from data_designer.engine.dataset_builders.utils.concurrency import (
37
35
  MAX_CONCURRENCY_PER_NON_LLM_GENERATOR,
38
36
  ConcurrentThreadExecutor,
39
37
  )
40
- from data_designer.engine.dataset_builders.utils.dataset_batch_manager import (
41
- DatasetBatchManager,
42
- )
38
+ from data_designer.engine.dataset_builders.utils.config_compiler import compile_dataset_builder_column_configs
39
+ from data_designer.engine.dataset_builders.utils.dataset_batch_manager import DatasetBatchManager
43
40
  from data_designer.engine.models.telemetry import InferenceEvent, NemoSourceEnum, TaskStatusEnum, TelemetryHandler
44
41
  from data_designer.engine.processing.processors.base import Processor
45
42
  from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor
46
43
  from data_designer.engine.registry.data_designer_registry import DataDesignerRegistry
47
44
  from data_designer.engine.resources.resource_provider import ResourceProvider
45
+ from data_designer.lazy_heavy_imports import pd
48
46
 
49
47
  if TYPE_CHECKING:
48
+ import pandas as pd
49
+
50
50
  from data_designer.engine.column_generators.generators.base import ColumnGeneratorWithModelRegistry
51
51
  from data_designer.engine.models.usage import ModelUsageStats
52
52
 
53
53
  logger = logging.getLogger(__name__)
54
54
 
55
-
56
55
  _CLIENT_VERSION: str = importlib.metadata.version("data_designer")
57
56
 
58
57
 
59
58
  class ColumnWiseDatasetBuilder:
60
59
  def __init__(
61
60
  self,
62
- column_configs: list[DatasetBuilderColumnConfigT],
63
- processor_configs: list[ProcessorConfig],
61
+ data_designer_config: DataDesignerConfig,
64
62
  resource_provider: ResourceProvider,
65
63
  registry: DataDesignerRegistry | None = None,
66
64
  ):
@@ -68,8 +66,12 @@ class ColumnWiseDatasetBuilder:
68
66
  self._resource_provider = resource_provider
69
67
  self._records_to_drop: set[int] = set()
70
68
  self._registry = registry or DataDesignerRegistry()
71
- self._column_configs = column_configs
72
- self._processors: dict[BuildStage, list[Processor]] = self._initialize_processors(processor_configs)
69
+
70
+ self._data_designer_config = compile_data_designer_config(data_designer_config, resource_provider)
71
+ self._column_configs = compile_dataset_builder_column_configs(self._data_designer_config)
72
+ self._processors: dict[BuildStage, list[Processor]] = self._initialize_processors(
73
+ self._data_designer_config.processors or []
74
+ )
73
75
  self._validate_column_configs()
74
76
 
75
77
  @property
@@ -94,16 +96,15 @@ class ColumnWiseDatasetBuilder:
94
96
  self,
95
97
  *,
96
98
  num_records: int,
97
- buffer_size: int,
98
99
  on_batch_complete: Callable[[Path], None] | None = None,
99
100
  ) -> Path:
100
- self._write_configs()
101
101
  self._run_model_health_check_if_needed()
102
-
102
+ self._write_builder_config()
103
103
  generators = self._initialize_generators()
104
104
  start_time = time.perf_counter()
105
105
  group_id = uuid.uuid4().hex
106
106
 
107
+ buffer_size = self._resource_provider.run_config.buffer_size
107
108
  self.batch_manager.start(num_records=num_records, buffer_size=buffer_size)
108
109
  for batch_idx in range(self.batch_manager.num_batches):
109
110
  logger.info(f"⏳ Processing batch {batch_idx + 1} of {self.batch_manager.num_batches}")
@@ -157,6 +158,12 @@ class ColumnWiseDatasetBuilder:
157
158
  for config in self._column_configs
158
159
  ]
159
160
 
161
+ def _write_builder_config(self) -> None:
162
+ self.artifact_storage.mkdir_if_needed(self.artifact_storage.base_dataset_path)
163
+ BuilderConfig(data_designer=self._data_designer_config).to_json(
164
+ self.artifact_storage.base_dataset_path / SDG_CONFIG_FILENAME
165
+ )
166
+
160
167
  def _run_batch(
161
168
  self, generators: list[ColumnGenerator], *, batch_mode: str, save_partial_results: bool = True, group_id: str
162
169
  ) -> None:
@@ -164,15 +171,16 @@ class ColumnWiseDatasetBuilder:
164
171
  for generator in generators:
165
172
  generator.log_pre_generation()
166
173
  try:
174
+ generation_strategy = generator.get_generation_strategy()
167
175
  if generator.can_generate_from_scratch and self.batch_manager.buffer_is_empty:
168
176
  self._run_from_scratch_column_generator(generator)
169
- elif generator.generation_strategy == GenerationStrategy.CELL_BY_CELL:
177
+ elif generation_strategy == GenerationStrategy.CELL_BY_CELL:
170
178
  self._run_cell_by_cell_generator(generator)
171
- elif generator.generation_strategy == GenerationStrategy.FULL_COLUMN:
179
+ elif generation_strategy == GenerationStrategy.FULL_COLUMN:
172
180
  self._run_full_column_generator(generator)
173
181
  else:
174
- logger.error(f"❌ Unknown generation strategy: {generator.generation_strategy}")
175
- raise DatasetGenerationError(f"🛑 Unknown generation strategy: {generator.generation_strategy}")
182
+ logger.error(f"❌ Unknown generation strategy: {generation_strategy}")
183
+ raise DatasetGenerationError(f"🛑 Unknown generation strategy: {generation_strategy}")
176
184
  if save_partial_results:
177
185
  self.batch_manager.write()
178
186
  except Exception as e:
@@ -210,9 +218,9 @@ class ColumnWiseDatasetBuilder:
210
218
  )
211
219
 
212
220
  def _fan_out_with_threads(self, generator: ColumnGeneratorWithModelRegistry, max_workers: int) -> None:
213
- if generator.generation_strategy != GenerationStrategy.CELL_BY_CELL:
221
+ if generator.get_generation_strategy() != GenerationStrategy.CELL_BY_CELL:
214
222
  raise DatasetGenerationError(
215
- f"Generator {generator.metadata().name} is not a {GenerationStrategy.CELL_BY_CELL} "
223
+ f"Generator {generator.name} is not a {GenerationStrategy.CELL_BY_CELL} "
216
224
  "generator so concurrency through threads is not supported."
217
225
  )
218
226
 
@@ -228,6 +236,7 @@ class ColumnWiseDatasetBuilder:
228
236
  error_callback=self._worker_error_callback,
229
237
  shutdown_error_rate=settings.shutdown_error_rate,
230
238
  shutdown_error_window=settings.shutdown_error_window,
239
+ disable_early_shutdown=settings.disable_early_shutdown,
231
240
  ) as executor:
232
241
  for i, record in self.batch_manager.iter_current_batch():
233
242
  executor.submit(lambda record: generator.generate(record), record, context={"index": i})
@@ -291,7 +300,7 @@ class ColumnWiseDatasetBuilder:
291
300
  dataframe = processor.process(dataframe, current_batch_number=current_batch_number)
292
301
  except Exception as e:
293
302
  raise DatasetProcessingError(
294
- f"🛑 Failed to process dataset with processor {processor.metadata().name} in stage {stage}: {e}"
303
+ f"🛑 Failed to process dataset with processor {processor.name} in stage {stage}: {e}"
295
304
  ) from e
296
305
  return dataframe
297
306
 
@@ -306,16 +315,6 @@ class ColumnWiseDatasetBuilder:
306
315
  def _worker_result_callback(self, result: dict, *, context: dict | None = None) -> None:
307
316
  self.batch_manager.update_record(context["index"], result)
308
317
 
309
- def _write_configs(self) -> None:
310
- self.artifact_storage.write_configs(
311
- json_file_name="column_configs.json",
312
- configs=self._column_configs,
313
- )
314
- self.artifact_storage.write_configs(
315
- json_file_name="model_configs.json",
316
- configs=self._resource_provider.model_registry.model_configs.values(),
317
- )
318
-
319
318
  def _emit_batch_inference_events(
320
319
  self, batch_mode: str, usage_deltas: dict[str, ModelUsageStats], group_id: str
321
320
  ) -> None:
@@ -1,6 +1,8 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ from __future__ import annotations
5
+
4
6
  from data_designer.engine.errors import DataDesignerError
5
7
 
6
8
 
@@ -1,6 +1,8 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ from __future__ import annotations
5
+
4
6
  from abc import ABC
5
7
  from typing import TypeAlias
6
8
 
@@ -96,6 +96,7 @@ class ConcurrentThreadExecutor:
96
96
  error_callback: ErrorCallbackWithContext | None = None,
97
97
  shutdown_error_rate: float = 0.50,
98
98
  shutdown_error_window: int = 10,
99
+ disable_early_shutdown: bool = False,
99
100
  ):
100
101
  self._executor = None
101
102
  self._column_name = column_name
@@ -106,6 +107,7 @@ class ConcurrentThreadExecutor:
106
107
  self._error_callback = error_callback
107
108
  self._shutdown_error_rate = shutdown_error_rate
108
109
  self._shutdown_window_size = shutdown_error_window
110
+ self._disable_early_shutdown = disable_early_shutdown
109
111
  self._results = ExecutorResults(failure_threshold=shutdown_error_rate)
110
112
 
111
113
  @property
@@ -139,7 +141,7 @@ class ConcurrentThreadExecutor:
139
141
 
140
142
  def __exit__(self, exc_type, exc_value, traceback):
141
143
  self._shutdown_executor()
142
- if self._results.early_shutdown is True:
144
+ if not self._disable_early_shutdown and self._results.early_shutdown is True:
143
145
  self._raise_task_error()
144
146
 
145
147
  def _shutdown_executor(self) -> None:
@@ -160,7 +162,7 @@ class ConcurrentThreadExecutor:
160
162
  if self._executor is None:
161
163
  raise RuntimeError("Executor is not initialized, this class should be used as a context manager.")
162
164
 
163
- if self._results.early_shutdown:
165
+ if not self._disable_early_shutdown and self._results.early_shutdown:
164
166
  self._shutdown_executor()
165
167
  self._raise_task_error()
166
168
 
@@ -176,7 +178,9 @@ class ConcurrentThreadExecutor:
176
178
  with self._lock:
177
179
  self._results.completed_count += 1
178
180
  self._results.error_trap.handle_error(err)
179
- if self._results.is_error_rate_exceeded(self._shutdown_window_size):
181
+ if not self._disable_early_shutdown and self._results.is_error_rate_exceeded(
182
+ self._shutdown_window_size
183
+ ):
180
184
  # Signal to shutdown early on the next submission (if received).
181
185
  # We cannot trigger shutdown from within this thread as it can
182
186
  # cause a deadlock.
@@ -196,7 +200,12 @@ class ConcurrentThreadExecutor:
196
200
  # We'll re-raise a custom error that can be handled at the call-site and the summary
197
201
  # can also be inspected.
198
202
  self._semaphore.release()
199
- if not isinstance(err, RuntimeError) and "after shutdown" not in str(err):
203
+ is_shutdown_error = isinstance(err, RuntimeError) and (
204
+ "after shutdown" in str(err) or "Pool shutdown" in str(err)
205
+ )
206
+ if not is_shutdown_error:
207
+ raise err
208
+ if self._disable_early_shutdown:
200
209
  raise err
201
210
  self._raise_task_error()
202
211
 
@@ -1,6 +1,8 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ from __future__ import annotations
5
+
4
6
  from data_designer.config.column_types import DataDesignerColumnType
5
7
  from data_designer.config.data_designer_config import DataDesignerConfig
6
8
  from data_designer.config.processors import ProcessorConfig
@@ -1,13 +1,18 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
- import logging
4
+ from __future__ import annotations
5
5
 
6
- import networkx as nx
6
+ import logging
7
+ from typing import TYPE_CHECKING
7
8
 
8
9
  from data_designer.config.column_types import ColumnConfigT
9
10
  from data_designer.engine.column_generators.utils.generator_classification import column_type_used_in_execution_dag
10
11
  from data_designer.engine.dataset_builders.utils.errors import DAGCircularDependencyError
12
+ from data_designer.lazy_heavy_imports import nx
13
+
14
+ if TYPE_CHECKING:
15
+ import networkx as nx
11
16
 
12
17
  logger = logging.getLogger(__name__)
13
18
 
@@ -1,16 +1,20 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ from __future__ import annotations
5
+
4
6
  import logging
5
7
  import shutil
6
8
  from pathlib import Path
7
- from typing import Callable, Container, Iterator
8
-
9
- import pandas as pd
10
- import pyarrow.parquet as pq
9
+ from typing import TYPE_CHECKING, Callable, Container, Iterator
11
10
 
12
11
  from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage, BatchStage
13
12
  from data_designer.engine.dataset_builders.utils.errors import DatasetBatchManagementError
13
+ from data_designer.lazy_heavy_imports import pd, pq
14
+
15
+ if TYPE_CHECKING:
16
+ import pandas as pd
17
+ import pyarrow.parquet as pq
14
18
 
15
19
  logger = logging.getLogger(__name__)
16
20
 
@@ -69,7 +73,7 @@ class DatasetBatchManager:
69
73
  def drop_records(self, index: Container[int]) -> None:
70
74
  self._buffer = [record for i, record in enumerate(self._buffer) if i not in index]
71
75
 
72
- def finish_batch(self, on_complete: Callable[[Path], None] | None = None) -> Path:
76
+ def finish_batch(self, on_complete: Callable[[Path], None] | None = None) -> Path | None:
73
77
  """Finish the batch by moving the results from the partial results path to the final parquet folder.
74
78
 
75
79
  Returns:
@@ -78,29 +82,35 @@ class DatasetBatchManager:
78
82
  if self._current_batch_number >= self.num_batches:
79
83
  raise DatasetBatchManagementError("🛑 All batches have been processed.")
80
84
 
81
- if not self.write():
82
- raise DatasetBatchManagementError("🛑 Batch finished without any results to write.")
83
-
84
- final_file_path = self.artifact_storage.move_partial_result_to_final_file_path(self._current_batch_number)
85
-
86
- self.artifact_storage.write_metadata(
87
- {
88
- "target_num_records": sum(self.num_records_list),
89
- "total_num_batches": self.num_batches,
90
- "buffer_size": self._buffer_size,
91
- "schema": {field.name: str(field.type) for field in pq.read_schema(final_file_path)},
92
- "file_paths": [str(f) for f in sorted(self.artifact_storage.final_dataset_path.glob("*.parquet"))],
93
- "num_records": self.num_records_list[: self._current_batch_number + 1],
94
- "num_completed_batches": self._current_batch_number + 1,
95
- "dataset_name": self.artifact_storage.dataset_name,
96
- }
97
- )
85
+ if self.write() is not None:
86
+ final_file_path = self.artifact_storage.move_partial_result_to_final_file_path(self._current_batch_number)
87
+
88
+ self.artifact_storage.write_metadata(
89
+ {
90
+ "target_num_records": sum(self.num_records_list),
91
+ "total_num_batches": self.num_batches,
92
+ "buffer_size": self._buffer_size,
93
+ "schema": {field.name: str(field.type) for field in pq.read_schema(final_file_path)},
94
+ "file_paths": self.artifact_storage.get_file_paths(),
95
+ "num_completed_batches": self._current_batch_number + 1,
96
+ "dataset_name": self.artifact_storage.dataset_name,
97
+ }
98
+ )
99
+
100
+ if on_complete:
101
+ on_complete(final_file_path)
102
+ else:
103
+ final_file_path = None
104
+
105
+ logger.warning(
106
+ f"⚠️ Batch {self._current_batch_number + 1} finished without any results to write. "
107
+ "A partial dataset containing the currently available columns has been written to the partial results "
108
+ f"directory: {self.artifact_storage.partial_results_path}"
109
+ )
110
+
98
111
  self._current_batch_number += 1
99
112
  self._buffer: list[dict] = []
100
113
 
101
- if on_complete:
102
- on_complete(final_file_path)
103
-
104
114
  return final_file_path
105
115
 
106
116
  def finish(self) -> None:
@@ -1,6 +1,8 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ from __future__ import annotations
5
+
4
6
  from data_designer.engine.errors import DataDesignerError
5
7
 
6
8
 
@@ -1,6 +1,8 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ from __future__ import annotations
5
+
4
6
  from pydantic import BaseModel, Field
5
7
 
6
8
  from data_designer.errors import DataDesignerError
@@ -1,6 +1,8 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ from __future__ import annotations
5
+
4
6
  from functools import cached_property
5
7
 
6
8
  from pydantic import BaseModel, field_validator, model_validator