data-designer 0.3.4__py3-none-any.whl → 0.3.5__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects the differences between package versions as they appear in their respective public registries.
Files changed (173) hide show
  1. data_designer/__init__.py +2 -0
  2. data_designer/_version.py +2 -2
  3. data_designer/cli/__init__.py +2 -0
  4. data_designer/cli/commands/download.py +2 -0
  5. data_designer/cli/commands/list.py +2 -0
  6. data_designer/cli/commands/models.py +2 -0
  7. data_designer/cli/commands/providers.py +2 -0
  8. data_designer/cli/commands/reset.py +2 -0
  9. data_designer/cli/controllers/__init__.py +2 -0
  10. data_designer/cli/controllers/download_controller.py +2 -0
  11. data_designer/cli/controllers/model_controller.py +6 -1
  12. data_designer/cli/controllers/provider_controller.py +6 -1
  13. data_designer/cli/forms/__init__.py +2 -0
  14. data_designer/cli/forms/builder.py +2 -0
  15. data_designer/cli/forms/field.py +2 -0
  16. data_designer/cli/forms/form.py +2 -0
  17. data_designer/cli/forms/model_builder.py +2 -0
  18. data_designer/cli/forms/provider_builder.py +2 -0
  19. data_designer/cli/main.py +2 -0
  20. data_designer/cli/repositories/__init__.py +2 -0
  21. data_designer/cli/repositories/base.py +2 -0
  22. data_designer/cli/repositories/model_repository.py +2 -0
  23. data_designer/cli/repositories/persona_repository.py +2 -0
  24. data_designer/cli/repositories/provider_repository.py +2 -0
  25. data_designer/cli/services/__init__.py +2 -0
  26. data_designer/cli/services/download_service.py +2 -0
  27. data_designer/cli/services/model_service.py +2 -0
  28. data_designer/cli/services/provider_service.py +2 -0
  29. data_designer/cli/ui.py +2 -0
  30. data_designer/cli/utils.py +2 -0
  31. data_designer/config/analysis/column_profilers.py +2 -0
  32. data_designer/config/analysis/column_statistics.py +8 -5
  33. data_designer/config/analysis/dataset_profiler.py +9 -3
  34. data_designer/config/analysis/utils/errors.py +2 -0
  35. data_designer/config/analysis/utils/reporting.py +7 -3
  36. data_designer/config/column_configs.py +77 -7
  37. data_designer/config/column_types.py +33 -36
  38. data_designer/config/dataset_builders.py +2 -0
  39. data_designer/config/default_model_settings.py +1 -0
  40. data_designer/config/errors.py +2 -0
  41. data_designer/config/exports.py +2 -0
  42. data_designer/config/interface.py +3 -2
  43. data_designer/config/models.py +7 -2
  44. data_designer/config/preview_results.py +7 -3
  45. data_designer/config/processors.py +2 -0
  46. data_designer/config/run_config.py +2 -0
  47. data_designer/config/sampler_constraints.py +2 -0
  48. data_designer/config/sampler_params.py +7 -2
  49. data_designer/config/seed.py +2 -0
  50. data_designer/config/seed_source.py +7 -2
  51. data_designer/config/seed_source_types.py +2 -0
  52. data_designer/config/utils/constants.py +2 -0
  53. data_designer/config/utils/errors.py +2 -0
  54. data_designer/config/utils/info.py +2 -0
  55. data_designer/config/utils/io_helpers.py +8 -3
  56. data_designer/config/utils/misc.py +2 -2
  57. data_designer/config/utils/numerical_helpers.py +2 -0
  58. data_designer/config/utils/type_helpers.py +2 -0
  59. data_designer/config/utils/visualization.py +8 -4
  60. data_designer/config/validator_params.py +2 -0
  61. data_designer/engine/analysis/column_profilers/base.py +9 -8
  62. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +15 -19
  63. data_designer/engine/analysis/column_profilers/registry.py +2 -0
  64. data_designer/engine/analysis/column_statistics.py +5 -2
  65. data_designer/engine/analysis/dataset_profiler.py +12 -9
  66. data_designer/engine/analysis/errors.py +2 -0
  67. data_designer/engine/analysis/utils/column_statistics_calculations.py +7 -4
  68. data_designer/engine/analysis/utils/judge_score_processing.py +7 -3
  69. data_designer/engine/column_generators/generators/base.py +26 -14
  70. data_designer/engine/column_generators/generators/embedding.py +4 -11
  71. data_designer/engine/column_generators/generators/expression.py +7 -16
  72. data_designer/engine/column_generators/generators/llm_completion.py +11 -37
  73. data_designer/engine/column_generators/generators/samplers.py +8 -14
  74. data_designer/engine/column_generators/generators/seed_dataset.py +9 -15
  75. data_designer/engine/column_generators/generators/validation.py +8 -20
  76. data_designer/engine/column_generators/registry.py +2 -0
  77. data_designer/engine/column_generators/utils/errors.py +2 -0
  78. data_designer/engine/column_generators/utils/generator_classification.py +2 -0
  79. data_designer/engine/column_generators/utils/judge_score_factory.py +2 -0
  80. data_designer/engine/column_generators/utils/prompt_renderer.py +4 -2
  81. data_designer/engine/compiler.py +3 -6
  82. data_designer/engine/configurable_task.py +12 -13
  83. data_designer/engine/dataset_builders/artifact_storage.py +87 -8
  84. data_designer/engine/dataset_builders/column_wise_builder.py +32 -34
  85. data_designer/engine/dataset_builders/errors.py +2 -0
  86. data_designer/engine/dataset_builders/multi_column_configs.py +2 -0
  87. data_designer/engine/dataset_builders/utils/config_compiler.py +2 -0
  88. data_designer/engine/dataset_builders/utils/dag.py +7 -2
  89. data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +9 -6
  90. data_designer/engine/dataset_builders/utils/errors.py +2 -0
  91. data_designer/engine/errors.py +2 -0
  92. data_designer/engine/model_provider.py +2 -0
  93. data_designer/engine/models/errors.py +23 -31
  94. data_designer/engine/models/facade.py +12 -9
  95. data_designer/engine/models/factory.py +42 -0
  96. data_designer/engine/models/litellm_overrides.py +16 -11
  97. data_designer/engine/models/parsers/errors.py +2 -0
  98. data_designer/engine/models/parsers/parser.py +2 -2
  99. data_designer/engine/models/parsers/postprocessors.py +1 -0
  100. data_designer/engine/models/parsers/tag_parsers.py +2 -0
  101. data_designer/engine/models/parsers/types.py +2 -0
  102. data_designer/engine/models/recipes/base.py +2 -0
  103. data_designer/engine/models/recipes/response_recipes.py +2 -0
  104. data_designer/engine/models/registry.py +11 -18
  105. data_designer/engine/models/telemetry.py +6 -2
  106. data_designer/engine/processing/ginja/ast.py +2 -0
  107. data_designer/engine/processing/ginja/environment.py +2 -0
  108. data_designer/engine/processing/ginja/exceptions.py +2 -0
  109. data_designer/engine/processing/ginja/record.py +2 -0
  110. data_designer/engine/processing/gsonschema/exceptions.py +9 -2
  111. data_designer/engine/processing/gsonschema/schema_transformers.py +2 -0
  112. data_designer/engine/processing/gsonschema/types.py +2 -0
  113. data_designer/engine/processing/gsonschema/validators.py +10 -6
  114. data_designer/engine/processing/processors/base.py +1 -5
  115. data_designer/engine/processing/processors/drop_columns.py +7 -10
  116. data_designer/engine/processing/processors/registry.py +2 -0
  117. data_designer/engine/processing/processors/schema_transform.py +7 -10
  118. data_designer/engine/processing/utils.py +7 -3
  119. data_designer/engine/registry/base.py +2 -0
  120. data_designer/engine/registry/data_designer_registry.py +2 -0
  121. data_designer/engine/registry/errors.py +2 -0
  122. data_designer/engine/resources/managed_dataset_generator.py +6 -2
  123. data_designer/engine/resources/managed_dataset_repository.py +8 -5
  124. data_designer/engine/resources/managed_storage.py +2 -0
  125. data_designer/engine/resources/resource_provider.py +8 -1
  126. data_designer/engine/resources/seed_reader.py +7 -2
  127. data_designer/engine/sampling_gen/column.py +2 -0
  128. data_designer/engine/sampling_gen/constraints.py +8 -2
  129. data_designer/engine/sampling_gen/data_sources/base.py +10 -7
  130. data_designer/engine/sampling_gen/data_sources/errors.py +2 -0
  131. data_designer/engine/sampling_gen/data_sources/sources.py +27 -22
  132. data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +2 -2
  133. data_designer/engine/sampling_gen/entities/email_address_utils.py +2 -0
  134. data_designer/engine/sampling_gen/entities/errors.py +2 -0
  135. data_designer/engine/sampling_gen/entities/national_id_utils.py +2 -0
  136. data_designer/engine/sampling_gen/entities/person.py +2 -0
  137. data_designer/engine/sampling_gen/entities/phone_number.py +8 -1
  138. data_designer/engine/sampling_gen/errors.py +2 -0
  139. data_designer/engine/sampling_gen/generator.py +5 -4
  140. data_designer/engine/sampling_gen/jinja_utils.py +7 -3
  141. data_designer/engine/sampling_gen/people_gen.py +7 -7
  142. data_designer/engine/sampling_gen/person_constants.py +2 -0
  143. data_designer/engine/sampling_gen/schema.py +5 -1
  144. data_designer/engine/sampling_gen/schema_builder.py +2 -0
  145. data_designer/engine/sampling_gen/utils.py +7 -1
  146. data_designer/engine/secret_resolver.py +2 -0
  147. data_designer/engine/validation.py +2 -2
  148. data_designer/engine/validators/__init__.py +2 -0
  149. data_designer/engine/validators/base.py +2 -0
  150. data_designer/engine/validators/local_callable.py +7 -2
  151. data_designer/engine/validators/python.py +7 -1
  152. data_designer/engine/validators/remote.py +7 -1
  153. data_designer/engine/validators/sql.py +8 -3
  154. data_designer/errors.py +2 -0
  155. data_designer/essentials/__init__.py +2 -0
  156. data_designer/interface/data_designer.py +23 -17
  157. data_designer/interface/errors.py +2 -0
  158. data_designer/interface/results.py +5 -2
  159. data_designer/lazy_heavy_imports.py +54 -0
  160. data_designer/logging.py +2 -0
  161. data_designer/plugins/__init__.py +2 -0
  162. data_designer/plugins/errors.py +2 -0
  163. data_designer/plugins/plugin.py +0 -1
  164. data_designer/plugins/registry.py +2 -0
  165. data_designer/plugins/testing/__init__.py +2 -0
  166. data_designer/plugins/testing/stubs.py +21 -43
  167. data_designer/plugins/testing/utils.py +2 -0
  168. {data_designer-0.3.4.dist-info → data_designer-0.3.5.dist-info}/METADATA +12 -5
  169. data_designer-0.3.5.dist-info/RECORD +196 -0
  170. data_designer-0.3.4.dist-info/RECORD +0 -194
  171. {data_designer-0.3.4.dist-info → data_designer-0.3.5.dist-info}/WHEEL +0 -0
  172. {data_designer-0.3.4.dist-info → data_designer-0.3.5.dist-info}/entry_points.txt +0 -0
  173. {data_designer-0.3.4.dist-info → data_designer-0.3.5.dist-info}/licenses/LICENSE +0 -0
@@ -1,23 +1,30 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ from __future__ import annotations
5
+
4
6
  import json
5
7
  import logging
6
8
  import shutil
7
9
  from datetime import datetime
8
10
  from functools import cached_property
9
11
  from pathlib import Path
12
+ from typing import TYPE_CHECKING
10
13
 
11
- import pandas as pd
12
14
  from pydantic import BaseModel, field_validator, model_validator
13
15
 
14
16
  from data_designer.config.utils.io_helpers import read_parquet_dataset
15
17
  from data_designer.config.utils.type_helpers import StrEnum, resolve_string_enum
16
18
  from data_designer.engine.dataset_builders.errors import ArtifactStorageError
19
+ from data_designer.lazy_heavy_imports import pd
20
+
21
+ if TYPE_CHECKING:
22
+ import pandas as pd
17
23
 
18
24
  logger = logging.getLogger(__name__)
19
25
 
20
26
  BATCH_FILE_NAME_FORMAT = "batch_{batch_number:05d}.parquet"
27
+ SDG_CONFIG_FILENAME = "sdg.json"
21
28
 
22
29
 
23
30
  class BatchStage(StrEnum):
@@ -164,12 +171,6 @@ class ArtifactStorage(BaseModel):
164
171
  shutil.move(partial_result_path, final_file_path)
165
172
  return final_file_path
166
173
 
167
- def write_configs(self, json_file_name: str, configs: list[dict]) -> Path:
168
- self.mkdir_if_needed(self.base_dataset_path)
169
- with open(self.base_dataset_path / json_file_name, "w") as file:
170
- json.dump([c.model_dump(mode="json") for c in configs], file, indent=4)
171
- return self.base_dataset_path / json_file_name
172
-
173
174
  def write_batch_to_parquet_file(
174
175
  self,
175
176
  batch_number: int,
@@ -194,11 +195,89 @@ class ArtifactStorage(BaseModel):
194
195
  dataframe.to_parquet(file_path, index=False)
195
196
  return file_path
196
197
 
198
+ def get_parquet_file_paths(self) -> list[str]:
199
+ """Get list of parquet file paths relative to base_dataset_path.
200
+
201
+ Returns:
202
+ List of relative paths to parquet files in the final dataset folder.
203
+ """
204
+ return [str(f.relative_to(self.base_dataset_path)) for f in sorted(self.final_dataset_path.glob("*.parquet"))]
205
+
206
+ def get_processor_file_paths(self) -> dict[str, list[str]]:
207
+ """Get processor output files organized by processor name.
208
+
209
+ Returns:
210
+ Dictionary mapping processor names to lists of relative file paths.
211
+ """
212
+ processor_files: dict[str, list[str]] = {}
213
+ if self.processors_outputs_path.exists():
214
+ for processor_dir in sorted(self.processors_outputs_path.iterdir()):
215
+ if processor_dir.is_dir():
216
+ processor_name = processor_dir.name
217
+ processor_files[processor_name] = [
218
+ str(f.relative_to(self.base_dataset_path))
219
+ for f in sorted(processor_dir.rglob("*"))
220
+ if f.is_file()
221
+ ]
222
+ return processor_files
223
+
224
+ def get_file_paths(self) -> dict[str, list[str] | dict[str, list[str]]]:
225
+ """Get all file paths organized by type.
226
+
227
+ Returns:
228
+ Dictionary with 'parquet-files' and 'processor-files' keys.
229
+ """
230
+ file_paths = {
231
+ "parquet-files": self.get_parquet_file_paths(),
232
+ }
233
+ processor_file_paths = self.get_processor_file_paths()
234
+ if processor_file_paths:
235
+ file_paths["processor-files"] = processor_file_paths
236
+
237
+ return file_paths
238
+
239
+ def read_metadata(self) -> dict:
240
+ """Read metadata from the metadata.json file.
241
+
242
+ Returns:
243
+ Dictionary containing the metadata.
244
+
245
+ Raises:
246
+ FileNotFoundError: If metadata file doesn't exist.
247
+ """
248
+ with open(self.metadata_file_path, "r") as file:
249
+ return json.load(file)
250
+
197
251
  def write_metadata(self, metadata: dict) -> Path:
252
+ """Write metadata to the metadata.json file.
253
+
254
+ Args:
255
+ metadata: Dictionary containing metadata to write.
256
+
257
+ Returns:
258
+ Path to the written metadata file.
259
+ """
198
260
  self.mkdir_if_needed(self.base_dataset_path)
199
261
  with open(self.metadata_file_path, "w") as file:
200
- json.dump(metadata, file)
262
+ json.dump(metadata, file, indent=4, sort_keys=True)
201
263
  return self.metadata_file_path
202
264
 
265
+ def update_metadata(self, updates: dict) -> Path:
266
+ """Update existing metadata with new fields.
267
+
268
+ Args:
269
+ updates: Dictionary of fields to add/update in metadata.
270
+
271
+ Returns:
272
+ Path to the updated metadata file.
273
+ """
274
+ try:
275
+ existing_metadata = self.read_metadata()
276
+ except FileNotFoundError:
277
+ existing_metadata = {}
278
+
279
+ existing_metadata.update(updates)
280
+ return self.write_metadata(existing_metadata)
281
+
203
282
  def _get_stage_path(self, stage: BatchStage) -> Path:
204
283
  return getattr(self, resolve_string_enum(stage, BatchStage).value)
@@ -12,9 +12,9 @@ import uuid
12
12
  from pathlib import Path
13
13
  from typing import TYPE_CHECKING, Callable
14
14
 
15
- import pandas as pd
16
-
17
15
  from data_designer.config.column_types import ColumnConfigT
16
+ from data_designer.config.config_builder import BuilderConfig
17
+ from data_designer.config.data_designer_config import DataDesignerConfig
18
18
  from data_designer.config.dataset_builders import BuildStage
19
19
  from data_designer.config.processors import (
20
20
  DropColumnsProcessorConfig,
@@ -27,40 +27,38 @@ from data_designer.engine.column_generators.generators.base import (
27
27
  GenerationStrategy,
28
28
  )
29
29
  from data_designer.engine.column_generators.utils.generator_classification import column_type_is_model_generated
30
- from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
30
+ from data_designer.engine.compiler import compile_data_designer_config
31
+ from data_designer.engine.dataset_builders.artifact_storage import SDG_CONFIG_FILENAME, ArtifactStorage
31
32
  from data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError
32
- from data_designer.engine.dataset_builders.multi_column_configs import (
33
- DatasetBuilderColumnConfigT,
34
- MultiColumnConfig,
35
- )
33
+ from data_designer.engine.dataset_builders.multi_column_configs import MultiColumnConfig
36
34
  from data_designer.engine.dataset_builders.utils.concurrency import (
37
35
  MAX_CONCURRENCY_PER_NON_LLM_GENERATOR,
38
36
  ConcurrentThreadExecutor,
39
37
  )
40
- from data_designer.engine.dataset_builders.utils.dataset_batch_manager import (
41
- DatasetBatchManager,
42
- )
38
+ from data_designer.engine.dataset_builders.utils.config_compiler import compile_dataset_builder_column_configs
39
+ from data_designer.engine.dataset_builders.utils.dataset_batch_manager import DatasetBatchManager
43
40
  from data_designer.engine.models.telemetry import InferenceEvent, NemoSourceEnum, TaskStatusEnum, TelemetryHandler
44
41
  from data_designer.engine.processing.processors.base import Processor
45
42
  from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor
46
43
  from data_designer.engine.registry.data_designer_registry import DataDesignerRegistry
47
44
  from data_designer.engine.resources.resource_provider import ResourceProvider
45
+ from data_designer.lazy_heavy_imports import pd
48
46
 
49
47
  if TYPE_CHECKING:
48
+ import pandas as pd
49
+
50
50
  from data_designer.engine.column_generators.generators.base import ColumnGeneratorWithModelRegistry
51
51
  from data_designer.engine.models.usage import ModelUsageStats
52
52
 
53
53
  logger = logging.getLogger(__name__)
54
54
 
55
-
56
55
  _CLIENT_VERSION: str = importlib.metadata.version("data_designer")
57
56
 
58
57
 
59
58
  class ColumnWiseDatasetBuilder:
60
59
  def __init__(
61
60
  self,
62
- column_configs: list[DatasetBuilderColumnConfigT],
63
- processor_configs: list[ProcessorConfig],
61
+ data_designer_config: DataDesignerConfig,
64
62
  resource_provider: ResourceProvider,
65
63
  registry: DataDesignerRegistry | None = None,
66
64
  ):
@@ -68,8 +66,12 @@ class ColumnWiseDatasetBuilder:
68
66
  self._resource_provider = resource_provider
69
67
  self._records_to_drop: set[int] = set()
70
68
  self._registry = registry or DataDesignerRegistry()
71
- self._column_configs = column_configs
72
- self._processors: dict[BuildStage, list[Processor]] = self._initialize_processors(processor_configs)
69
+
70
+ self._data_designer_config = compile_data_designer_config(data_designer_config, resource_provider)
71
+ self._column_configs = compile_dataset_builder_column_configs(self._data_designer_config)
72
+ self._processors: dict[BuildStage, list[Processor]] = self._initialize_processors(
73
+ self._data_designer_config.processors or []
74
+ )
73
75
  self._validate_column_configs()
74
76
 
75
77
  @property
@@ -96,9 +98,8 @@ class ColumnWiseDatasetBuilder:
96
98
  num_records: int,
97
99
  on_batch_complete: Callable[[Path], None] | None = None,
98
100
  ) -> Path:
99
- self._write_configs()
100
101
  self._run_model_health_check_if_needed()
101
-
102
+ self._write_builder_config()
102
103
  generators = self._initialize_generators()
103
104
  start_time = time.perf_counter()
104
105
  group_id = uuid.uuid4().hex
@@ -157,6 +158,12 @@ class ColumnWiseDatasetBuilder:
157
158
  for config in self._column_configs
158
159
  ]
159
160
 
161
+ def _write_builder_config(self) -> None:
162
+ self.artifact_storage.mkdir_if_needed(self.artifact_storage.base_dataset_path)
163
+ BuilderConfig(data_designer=self._data_designer_config).to_json(
164
+ self.artifact_storage.base_dataset_path / SDG_CONFIG_FILENAME
165
+ )
166
+
160
167
  def _run_batch(
161
168
  self, generators: list[ColumnGenerator], *, batch_mode: str, save_partial_results: bool = True, group_id: str
162
169
  ) -> None:
@@ -164,15 +171,16 @@ class ColumnWiseDatasetBuilder:
164
171
  for generator in generators:
165
172
  generator.log_pre_generation()
166
173
  try:
174
+ generation_strategy = generator.get_generation_strategy()
167
175
  if generator.can_generate_from_scratch and self.batch_manager.buffer_is_empty:
168
176
  self._run_from_scratch_column_generator(generator)
169
- elif generator.generation_strategy == GenerationStrategy.CELL_BY_CELL:
177
+ elif generation_strategy == GenerationStrategy.CELL_BY_CELL:
170
178
  self._run_cell_by_cell_generator(generator)
171
- elif generator.generation_strategy == GenerationStrategy.FULL_COLUMN:
179
+ elif generation_strategy == GenerationStrategy.FULL_COLUMN:
172
180
  self._run_full_column_generator(generator)
173
181
  else:
174
- logger.error(f"❌ Unknown generation strategy: {generator.generation_strategy}")
175
- raise DatasetGenerationError(f"🛑 Unknown generation strategy: {generator.generation_strategy}")
182
+ logger.error(f"❌ Unknown generation strategy: {generation_strategy}")
183
+ raise DatasetGenerationError(f"🛑 Unknown generation strategy: {generation_strategy}")
176
184
  if save_partial_results:
177
185
  self.batch_manager.write()
178
186
  except Exception as e:
@@ -210,9 +218,9 @@ class ColumnWiseDatasetBuilder:
210
218
  )
211
219
 
212
220
  def _fan_out_with_threads(self, generator: ColumnGeneratorWithModelRegistry, max_workers: int) -> None:
213
- if generator.generation_strategy != GenerationStrategy.CELL_BY_CELL:
221
+ if generator.get_generation_strategy() != GenerationStrategy.CELL_BY_CELL:
214
222
  raise DatasetGenerationError(
215
- f"Generator {generator.metadata().name} is not a {GenerationStrategy.CELL_BY_CELL} "
223
+ f"Generator {generator.name} is not a {GenerationStrategy.CELL_BY_CELL} "
216
224
  "generator so concurrency through threads is not supported."
217
225
  )
218
226
 
@@ -292,7 +300,7 @@ class ColumnWiseDatasetBuilder:
292
300
  dataframe = processor.process(dataframe, current_batch_number=current_batch_number)
293
301
  except Exception as e:
294
302
  raise DatasetProcessingError(
295
- f"🛑 Failed to process dataset with processor {processor.metadata().name} in stage {stage}: {e}"
303
+ f"🛑 Failed to process dataset with processor {processor.name} in stage {stage}: {e}"
296
304
  ) from e
297
305
  return dataframe
298
306
 
@@ -307,16 +315,6 @@ class ColumnWiseDatasetBuilder:
307
315
  def _worker_result_callback(self, result: dict, *, context: dict | None = None) -> None:
308
316
  self.batch_manager.update_record(context["index"], result)
309
317
 
310
- def _write_configs(self) -> None:
311
- self.artifact_storage.write_configs(
312
- json_file_name="column_configs.json",
313
- configs=self._column_configs,
314
- )
315
- self.artifact_storage.write_configs(
316
- json_file_name="model_configs.json",
317
- configs=self._resource_provider.model_registry.model_configs.values(),
318
- )
319
-
320
318
  def _emit_batch_inference_events(
321
319
  self, batch_mode: str, usage_deltas: dict[str, ModelUsageStats], group_id: str
322
320
  ) -> None:
@@ -1,6 +1,8 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ from __future__ import annotations
5
+
4
6
  from data_designer.engine.errors import DataDesignerError
5
7
 
6
8
 
@@ -1,6 +1,8 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ from __future__ import annotations
5
+
4
6
  from abc import ABC
5
7
  from typing import TypeAlias
6
8
 
@@ -1,6 +1,8 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ from __future__ import annotations
5
+
4
6
  from data_designer.config.column_types import DataDesignerColumnType
5
7
  from data_designer.config.data_designer_config import DataDesignerConfig
6
8
  from data_designer.config.processors import ProcessorConfig
@@ -1,13 +1,18 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
- import logging
4
+ from __future__ import annotations
5
5
 
6
- import networkx as nx
6
+ import logging
7
+ from typing import TYPE_CHECKING
7
8
 
8
9
  from data_designer.config.column_types import ColumnConfigT
9
10
  from data_designer.engine.column_generators.utils.generator_classification import column_type_used_in_execution_dag
10
11
  from data_designer.engine.dataset_builders.utils.errors import DAGCircularDependencyError
12
+ from data_designer.lazy_heavy_imports import nx
13
+
14
+ if TYPE_CHECKING:
15
+ import networkx as nx
11
16
 
12
17
  logger = logging.getLogger(__name__)
13
18
 
@@ -1,16 +1,20 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ from __future__ import annotations
5
+
4
6
  import logging
5
7
  import shutil
6
8
  from pathlib import Path
7
- from typing import Callable, Container, Iterator
8
-
9
- import pandas as pd
10
- import pyarrow.parquet as pq
9
+ from typing import TYPE_CHECKING, Callable, Container, Iterator
11
10
 
12
11
  from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage, BatchStage
13
12
  from data_designer.engine.dataset_builders.utils.errors import DatasetBatchManagementError
13
+ from data_designer.lazy_heavy_imports import pd, pq
14
+
15
+ if TYPE_CHECKING:
16
+ import pandas as pd
17
+ import pyarrow.parquet as pq
14
18
 
15
19
  logger = logging.getLogger(__name__)
16
20
 
@@ -87,8 +91,7 @@ class DatasetBatchManager:
87
91
  "total_num_batches": self.num_batches,
88
92
  "buffer_size": self._buffer_size,
89
93
  "schema": {field.name: str(field.type) for field in pq.read_schema(final_file_path)},
90
- "file_paths": [str(f) for f in sorted(self.artifact_storage.final_dataset_path.glob("*.parquet"))],
91
- "num_records": self.num_records_list[: self._current_batch_number + 1],
94
+ "file_paths": self.artifact_storage.get_file_paths(),
92
95
  "num_completed_batches": self._current_batch_number + 1,
93
96
  "dataset_name": self.artifact_storage.dataset_name,
94
97
  }
@@ -1,6 +1,8 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ from __future__ import annotations
5
+
4
6
  from data_designer.engine.errors import DataDesignerError
5
7
 
6
8
 
@@ -1,6 +1,8 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ from __future__ import annotations
5
+
4
6
  from pydantic import BaseModel, Field
5
7
 
6
8
  from data_designer.errors import DataDesignerError
@@ -1,6 +1,8 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ from __future__ import annotations
5
+
4
6
  from functools import cached_property
5
7
 
6
8
  from pydantic import BaseModel, field_validator, model_validator
@@ -6,25 +6,15 @@ from __future__ import annotations
6
6
  import logging
7
7
  from collections.abc import Callable
8
8
  from functools import wraps
9
- from typing import Any
10
-
11
- from litellm.exceptions import (
12
- APIConnectionError,
13
- APIError,
14
- AuthenticationError,
15
- BadRequestError,
16
- ContextWindowExceededError,
17
- InternalServerError,
18
- NotFoundError,
19
- PermissionDeniedError,
20
- RateLimitError,
21
- Timeout,
22
- UnprocessableEntityError,
23
- UnsupportedParamsError,
24
- )
9
+ from typing import TYPE_CHECKING, Any
10
+
25
11
  from pydantic import BaseModel
26
12
 
27
13
  from data_designer.engine.errors import DataDesignerError
14
+ from data_designer.lazy_heavy_imports import litellm
15
+
16
+ if TYPE_CHECKING:
17
+ import litellm
28
18
 
29
19
  logger = logging.getLogger(__name__)
30
20
 
@@ -132,10 +122,10 @@ def handle_llm_exceptions(
132
122
  err_msg_parser = DownstreamLLMExceptionMessageParser(model_name, model_provider_name, purpose)
133
123
  match exception:
134
124
  # Common errors that can come from LiteLLM
135
- case APIError():
125
+ case litellm.exceptions.APIError():
136
126
  raise err_msg_parser.parse_api_error(exception, authentication_error) from None
137
127
 
138
- case APIConnectionError():
128
+ case litellm.exceptions.APIConnectionError():
139
129
  raise ModelAPIConnectionError(
140
130
  FormattedLLMErrorMessage(
141
131
  cause=f"Connection to model {model_name!r} hosted on model provider {model_provider_name!r} failed while {purpose}.",
@@ -143,13 +133,13 @@ def handle_llm_exceptions(
143
133
  )
144
134
  ) from None
145
135
 
146
- case AuthenticationError():
136
+ case litellm.exceptions.AuthenticationError():
147
137
  raise ModelAuthenticationError(authentication_error) from None
148
138
 
149
- case ContextWindowExceededError():
139
+ case litellm.exceptions.ContextWindowExceededError():
150
140
  raise err_msg_parser.parse_context_window_exceeded_error(exception) from None
151
141
 
152
- case UnsupportedParamsError():
142
+ case litellm.exceptions.UnsupportedParamsError():
153
143
  raise ModelUnsupportedParamsError(
154
144
  FormattedLLMErrorMessage(
155
145
  cause=f"One or more of the parameters you provided were found to be unsupported by model {model_name!r} while {purpose}.",
@@ -157,10 +147,10 @@ def handle_llm_exceptions(
157
147
  )
158
148
  ) from None
159
149
 
160
- case BadRequestError():
150
+ case litellm.exceptions.BadRequestError():
161
151
  raise err_msg_parser.parse_bad_request_error(exception) from None
162
152
 
163
- case InternalServerError():
153
+ case litellm.exceptions.InternalServerError():
164
154
  raise ModelInternalServerError(
165
155
  FormattedLLMErrorMessage(
166
156
  cause=f"Model {model_name!r} is currently experiencing internal server issues while {purpose}.",
@@ -168,7 +158,7 @@ def handle_llm_exceptions(
168
158
  )
169
159
  ) from None
170
160
 
171
- case NotFoundError():
161
+ case litellm.exceptions.NotFoundError():
172
162
  raise ModelNotFoundError(
173
163
  FormattedLLMErrorMessage(
174
164
  cause=f"The specified model {model_name!r} could not be found while {purpose}.",
@@ -176,7 +166,7 @@ def handle_llm_exceptions(
176
166
  )
177
167
  ) from None
178
168
 
179
- case PermissionDeniedError():
169
+ case litellm.exceptions.PermissionDeniedError():
180
170
  raise ModelPermissionDeniedError(
181
171
  FormattedLLMErrorMessage(
182
172
  cause=f"Your API key was found to lack the necessary permissions to use model {model_name!r} while {purpose}.",
@@ -184,7 +174,7 @@ def handle_llm_exceptions(
184
174
  )
185
175
  ) from None
186
176
 
187
- case RateLimitError():
177
+ case litellm.exceptions.RateLimitError():
188
178
  raise ModelRateLimitError(
189
179
  FormattedLLMErrorMessage(
190
180
  cause=f"You have exceeded the rate limit for model {model_name!r} while {purpose}.",
@@ -192,7 +182,7 @@ def handle_llm_exceptions(
192
182
  )
193
183
  ) from None
194
184
 
195
- case Timeout():
185
+ case litellm.exceptions.Timeout():
196
186
  raise ModelTimeoutError(
197
187
  FormattedLLMErrorMessage(
198
188
  cause=f"The request to model {model_name!r} timed out while {purpose}.",
@@ -200,7 +190,7 @@ def handle_llm_exceptions(
200
190
  )
201
191
  ) from None
202
192
 
203
- case UnprocessableEntityError():
193
+ case litellm.exceptions.UnprocessableEntityError():
204
194
  raise ModelUnprocessableEntityError(
205
195
  FormattedLLMErrorMessage(
206
196
  cause=f"The request to model {model_name!r} failed despite correct request format while {purpose}.",
@@ -264,7 +254,7 @@ class DownstreamLLMExceptionMessageParser:
264
254
  self.model_provider_name = model_provider_name
265
255
  self.purpose = purpose
266
256
 
267
- def parse_bad_request_error(self, exception: BadRequestError) -> DataDesignerError:
257
+ def parse_bad_request_error(self, exception: litellm.exceptions.BadRequestError) -> DataDesignerError:
268
258
  err_msg = FormattedLLMErrorMessage(
269
259
  cause=f"The request for model {self.model_name!r} was found to be malformed or missing required parameters while {self.purpose}.",
270
260
  solution="Check your request parameters and try again.",
@@ -276,7 +266,9 @@ class DownstreamLLMExceptionMessageParser:
276
266
  )
277
267
  return ModelBadRequestError(err_msg)
278
268
 
279
- def parse_context_window_exceeded_error(self, exception: ContextWindowExceededError) -> DataDesignerError:
269
+ def parse_context_window_exceeded_error(
270
+ self, exception: litellm.exceptions.ContextWindowExceededError
271
+ ) -> DataDesignerError:
280
272
  cause = f"The input data for model '{self.model_name}' was found to exceed its supported context width while {self.purpose}."
281
273
  try:
282
274
  if "OpenAIException - This model's maximum context length is " in str(exception):
@@ -295,7 +287,7 @@ class DownstreamLLMExceptionMessageParser:
295
287
  )
296
288
 
297
289
  def parse_api_error(
298
- self, exception: InternalServerError, auth_error_msg: FormattedLLMErrorMessage
290
+ self, exception: litellm.exceptions.InternalServerError, auth_error_msg: FormattedLLMErrorMessage
299
291
  ) -> DataDesignerError:
300
292
  if "Error code: 403" in str(exception):
301
293
  return ModelAuthenticationError(auth_error_msg)
@@ -6,10 +6,7 @@ from __future__ import annotations
6
6
  import logging
7
7
  from collections.abc import Callable
8
8
  from copy import deepcopy
9
- from typing import Any
10
-
11
- from litellm.types.router import DeploymentTypedDict, LiteLLM_Params
12
- from litellm.types.utils import EmbeddingResponse, ModelResponse
9
+ from typing import TYPE_CHECKING, Any
13
10
 
14
11
  from data_designer.config.models import GenerationType, ModelConfig, ModelProvider
15
12
  from data_designer.engine.model_provider import ModelProviderRegistry
@@ -23,6 +20,10 @@ from data_designer.engine.models.parsers.errors import ParserException
23
20
  from data_designer.engine.models.usage import ModelUsageStats, RequestUsageStats, TokenUsageStats
24
21
  from data_designer.engine.models.utils import prompt_to_messages, str_to_message
25
22
  from data_designer.engine.secret_resolver import SecretResolver
23
+ from data_designer.lazy_heavy_imports import litellm
24
+
25
+ if TYPE_CHECKING:
26
+ import litellm
26
27
 
27
28
  logger = logging.getLogger(__name__)
28
29
 
@@ -65,7 +66,9 @@ class ModelFacade:
65
66
  def usage_stats(self) -> ModelUsageStats:
66
67
  return self._usage_stats
67
68
 
68
- def completion(self, messages: list[dict[str, str]], skip_usage_tracking: bool = False, **kwargs) -> ModelResponse:
69
+ def completion(
70
+ self, messages: list[dict[str, str]], skip_usage_tracking: bool = False, **kwargs
71
+ ) -> litellm.ModelResponse:
69
72
  logger.debug(
70
73
  f"Prompting model {self.model_name!r}...",
71
74
  extra={"model": self.model_name, "messages": messages},
@@ -236,14 +239,14 @@ class ModelFacade:
236
239
  ) from exc
237
240
  return output_obj, reasoning_trace
238
241
 
239
- def _get_litellm_deployment(self, model_config: ModelConfig) -> DeploymentTypedDict:
242
+ def _get_litellm_deployment(self, model_config: ModelConfig) -> litellm.DeploymentTypedDict:
240
243
  provider = self._model_provider_registry.get_provider(model_config.provider)
241
244
  api_key = None
242
245
  if provider.api_key:
243
246
  api_key = self._secret_resolver.resolve(provider.api_key)
244
247
  api_key = api_key or "not-used-but-required"
245
248
 
246
- litellm_params = LiteLLM_Params(
249
+ litellm_params = litellm.LiteLLM_Params(
247
250
  model=f"{provider.provider_type}/{model_config.model}",
248
251
  api_base=provider.endpoint,
249
252
  api_key=api_key,
@@ -253,7 +256,7 @@ class ModelFacade:
253
256
  "litellm_params": litellm_params.model_dump(),
254
257
  }
255
258
 
256
- def _track_usage(self, response: ModelResponse | None) -> None:
259
+ def _track_usage(self, response: litellm.types.utils.ModelResponse | None) -> None:
257
260
  if response is None:
258
261
  self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1))
259
262
  return
@@ -270,7 +273,7 @@ class ModelFacade:
270
273
  request_usage=RequestUsageStats(successful_requests=1, failed_requests=0),
271
274
  )
272
275
 
273
- def _track_usage_from_embedding(self, response: EmbeddingResponse | None) -> None:
276
+ def _track_usage_from_embedding(self, response: litellm.types.utils.EmbeddingResponse | None) -> None:
274
277
  if response is None:
275
278
  self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1))
276
279
  return
@@ -0,0 +1,42 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import TYPE_CHECKING
7
+
8
+ from data_designer.config.models import ModelConfig
9
+ from data_designer.engine.model_provider import ModelProviderRegistry
10
+ from data_designer.engine.secret_resolver import SecretResolver
11
+
12
+ if TYPE_CHECKING:
13
+ from data_designer.engine.models.registry import ModelRegistry
14
+
15
+
16
def create_model_registry(
    *,
    model_configs: list[ModelConfig] | None = None,
    secret_resolver: SecretResolver,
    model_provider_registry: ModelProviderRegistry,
) -> ModelRegistry:
    """Build a ``ModelRegistry`` that creates its facades via ``ModelFacade``.

    The facade, litellm-override, and registry modules are imported inside this
    function instead of at module level, so the heavy dependencies (litellm,
    httpx) are only loaded once a registry is actually requested.

    Args:
        model_configs: Optional model configurations forwarded to the registry.
        secret_resolver: Secret resolver forwarded to the registry.
        model_provider_registry: Provider registry forwarded to the registry.

    Returns:
        The newly constructed ``ModelRegistry``.
    """
    # Deferred imports: keep these local so importing this module stays cheap.
    from data_designer.engine.models.facade import ModelFacade
    from data_designer.engine.models.litellm_overrides import apply_litellm_patches
    from data_designer.engine.models.registry import ModelRegistry

    def model_facade_factory(
        model_config: ModelConfig,
        secret_resolver: SecretResolver,
        model_provider_registry: ModelProviderRegistry,
    ) -> ModelFacade:
        # Uses the arguments it is called with, not the enclosing function's
        # values (the parameter names intentionally mirror the outer ones).
        return ModelFacade(model_config, secret_resolver, model_provider_registry)

    # Patch litellm before the registry (and any facade) is constructed.
    apply_litellm_patches()

    registry = ModelRegistry(
        model_configs=model_configs,
        secret_resolver=secret_resolver,
        model_provider_registry=model_provider_registry,
        model_facade_factory=model_facade_factory,
    )
    return registry