data-designer 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177)
  1. data_designer/__init__.py +15 -0
  2. data_designer/_version.py +34 -0
  3. data_designer/cli/README.md +236 -0
  4. data_designer/cli/__init__.py +6 -0
  5. data_designer/cli/commands/__init__.py +2 -0
  6. data_designer/cli/commands/list.py +130 -0
  7. data_designer/cli/commands/models.py +10 -0
  8. data_designer/cli/commands/providers.py +11 -0
  9. data_designer/cli/commands/reset.py +100 -0
  10. data_designer/cli/controllers/__init__.py +7 -0
  11. data_designer/cli/controllers/model_controller.py +246 -0
  12. data_designer/cli/controllers/provider_controller.py +317 -0
  13. data_designer/cli/forms/__init__.py +20 -0
  14. data_designer/cli/forms/builder.py +51 -0
  15. data_designer/cli/forms/field.py +180 -0
  16. data_designer/cli/forms/form.py +59 -0
  17. data_designer/cli/forms/model_builder.py +125 -0
  18. data_designer/cli/forms/provider_builder.py +76 -0
  19. data_designer/cli/main.py +44 -0
  20. data_designer/cli/repositories/__init__.py +8 -0
  21. data_designer/cli/repositories/base.py +39 -0
  22. data_designer/cli/repositories/model_repository.py +42 -0
  23. data_designer/cli/repositories/provider_repository.py +43 -0
  24. data_designer/cli/services/__init__.py +7 -0
  25. data_designer/cli/services/model_service.py +116 -0
  26. data_designer/cli/services/provider_service.py +111 -0
  27. data_designer/cli/ui.py +448 -0
  28. data_designer/cli/utils.py +47 -0
  29. data_designer/config/__init__.py +2 -0
  30. data_designer/config/analysis/column_profilers.py +89 -0
  31. data_designer/config/analysis/column_statistics.py +274 -0
  32. data_designer/config/analysis/dataset_profiler.py +60 -0
  33. data_designer/config/analysis/utils/errors.py +8 -0
  34. data_designer/config/analysis/utils/reporting.py +188 -0
  35. data_designer/config/base.py +68 -0
  36. data_designer/config/column_configs.py +354 -0
  37. data_designer/config/column_types.py +168 -0
  38. data_designer/config/config_builder.py +660 -0
  39. data_designer/config/data_designer_config.py +40 -0
  40. data_designer/config/dataset_builders.py +11 -0
  41. data_designer/config/datastore.py +151 -0
  42. data_designer/config/default_model_settings.py +123 -0
  43. data_designer/config/errors.py +19 -0
  44. data_designer/config/interface.py +54 -0
  45. data_designer/config/models.py +231 -0
  46. data_designer/config/preview_results.py +32 -0
  47. data_designer/config/processors.py +41 -0
  48. data_designer/config/sampler_constraints.py +51 -0
  49. data_designer/config/sampler_params.py +604 -0
  50. data_designer/config/seed.py +145 -0
  51. data_designer/config/utils/code_lang.py +83 -0
  52. data_designer/config/utils/constants.py +313 -0
  53. data_designer/config/utils/errors.py +19 -0
  54. data_designer/config/utils/info.py +88 -0
  55. data_designer/config/utils/io_helpers.py +273 -0
  56. data_designer/config/utils/misc.py +81 -0
  57. data_designer/config/utils/numerical_helpers.py +28 -0
  58. data_designer/config/utils/type_helpers.py +100 -0
  59. data_designer/config/utils/validation.py +336 -0
  60. data_designer/config/utils/visualization.py +427 -0
  61. data_designer/config/validator_params.py +96 -0
  62. data_designer/engine/__init__.py +2 -0
  63. data_designer/engine/analysis/column_profilers/base.py +55 -0
  64. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +160 -0
  65. data_designer/engine/analysis/column_profilers/registry.py +20 -0
  66. data_designer/engine/analysis/column_statistics.py +142 -0
  67. data_designer/engine/analysis/dataset_profiler.py +125 -0
  68. data_designer/engine/analysis/errors.py +7 -0
  69. data_designer/engine/analysis/utils/column_statistics_calculations.py +209 -0
  70. data_designer/engine/analysis/utils/judge_score_processing.py +128 -0
  71. data_designer/engine/column_generators/__init__.py +2 -0
  72. data_designer/engine/column_generators/generators/__init__.py +2 -0
  73. data_designer/engine/column_generators/generators/base.py +61 -0
  74. data_designer/engine/column_generators/generators/expression.py +63 -0
  75. data_designer/engine/column_generators/generators/llm_generators.py +172 -0
  76. data_designer/engine/column_generators/generators/samplers.py +75 -0
  77. data_designer/engine/column_generators/generators/seed_dataset.py +149 -0
  78. data_designer/engine/column_generators/generators/validation.py +147 -0
  79. data_designer/engine/column_generators/registry.py +56 -0
  80. data_designer/engine/column_generators/utils/errors.py +13 -0
  81. data_designer/engine/column_generators/utils/judge_score_factory.py +57 -0
  82. data_designer/engine/column_generators/utils/prompt_renderer.py +98 -0
  83. data_designer/engine/configurable_task.py +82 -0
  84. data_designer/engine/dataset_builders/artifact_storage.py +181 -0
  85. data_designer/engine/dataset_builders/column_wise_builder.py +287 -0
  86. data_designer/engine/dataset_builders/errors.py +13 -0
  87. data_designer/engine/dataset_builders/multi_column_configs.py +44 -0
  88. data_designer/engine/dataset_builders/utils/__init__.py +2 -0
  89. data_designer/engine/dataset_builders/utils/concurrency.py +184 -0
  90. data_designer/engine/dataset_builders/utils/config_compiler.py +60 -0
  91. data_designer/engine/dataset_builders/utils/dag.py +56 -0
  92. data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +190 -0
  93. data_designer/engine/dataset_builders/utils/errors.py +13 -0
  94. data_designer/engine/errors.py +49 -0
  95. data_designer/engine/model_provider.py +75 -0
  96. data_designer/engine/models/__init__.py +2 -0
  97. data_designer/engine/models/errors.py +308 -0
  98. data_designer/engine/models/facade.py +225 -0
  99. data_designer/engine/models/litellm_overrides.py +162 -0
  100. data_designer/engine/models/parsers/__init__.py +2 -0
  101. data_designer/engine/models/parsers/errors.py +34 -0
  102. data_designer/engine/models/parsers/parser.py +236 -0
  103. data_designer/engine/models/parsers/postprocessors.py +93 -0
  104. data_designer/engine/models/parsers/tag_parsers.py +60 -0
  105. data_designer/engine/models/parsers/types.py +82 -0
  106. data_designer/engine/models/recipes/base.py +79 -0
  107. data_designer/engine/models/recipes/response_recipes.py +291 -0
  108. data_designer/engine/models/registry.py +118 -0
  109. data_designer/engine/models/usage.py +75 -0
  110. data_designer/engine/models/utils.py +38 -0
  111. data_designer/engine/processing/ginja/__init__.py +2 -0
  112. data_designer/engine/processing/ginja/ast.py +64 -0
  113. data_designer/engine/processing/ginja/environment.py +461 -0
  114. data_designer/engine/processing/ginja/exceptions.py +54 -0
  115. data_designer/engine/processing/ginja/record.py +30 -0
  116. data_designer/engine/processing/gsonschema/__init__.py +2 -0
  117. data_designer/engine/processing/gsonschema/exceptions.py +8 -0
  118. data_designer/engine/processing/gsonschema/schema_transformers.py +81 -0
  119. data_designer/engine/processing/gsonschema/types.py +8 -0
  120. data_designer/engine/processing/gsonschema/validators.py +143 -0
  121. data_designer/engine/processing/processors/base.py +15 -0
  122. data_designer/engine/processing/processors/drop_columns.py +46 -0
  123. data_designer/engine/processing/processors/registry.py +20 -0
  124. data_designer/engine/processing/utils.py +120 -0
  125. data_designer/engine/registry/base.py +97 -0
  126. data_designer/engine/registry/data_designer_registry.py +37 -0
  127. data_designer/engine/registry/errors.py +10 -0
  128. data_designer/engine/resources/managed_dataset_generator.py +35 -0
  129. data_designer/engine/resources/managed_dataset_repository.py +194 -0
  130. data_designer/engine/resources/managed_storage.py +63 -0
  131. data_designer/engine/resources/resource_provider.py +46 -0
  132. data_designer/engine/resources/seed_dataset_data_store.py +66 -0
  133. data_designer/engine/sampling_gen/column.py +89 -0
  134. data_designer/engine/sampling_gen/constraints.py +95 -0
  135. data_designer/engine/sampling_gen/data_sources/base.py +214 -0
  136. data_designer/engine/sampling_gen/data_sources/errors.py +10 -0
  137. data_designer/engine/sampling_gen/data_sources/sources.py +342 -0
  138. data_designer/engine/sampling_gen/entities/__init__.py +2 -0
  139. data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
  140. data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +64 -0
  141. data_designer/engine/sampling_gen/entities/email_address_utils.py +169 -0
  142. data_designer/engine/sampling_gen/entities/errors.py +8 -0
  143. data_designer/engine/sampling_gen/entities/national_id_utils.py +100 -0
  144. data_designer/engine/sampling_gen/entities/person.py +142 -0
  145. data_designer/engine/sampling_gen/entities/phone_number.py +122 -0
  146. data_designer/engine/sampling_gen/errors.py +24 -0
  147. data_designer/engine/sampling_gen/generator.py +121 -0
  148. data_designer/engine/sampling_gen/jinja_utils.py +60 -0
  149. data_designer/engine/sampling_gen/people_gen.py +203 -0
  150. data_designer/engine/sampling_gen/person_constants.py +54 -0
  151. data_designer/engine/sampling_gen/schema.py +143 -0
  152. data_designer/engine/sampling_gen/schema_builder.py +59 -0
  153. data_designer/engine/sampling_gen/utils.py +40 -0
  154. data_designer/engine/secret_resolver.py +80 -0
  155. data_designer/engine/validators/__init__.py +17 -0
  156. data_designer/engine/validators/base.py +36 -0
  157. data_designer/engine/validators/local_callable.py +34 -0
  158. data_designer/engine/validators/python.py +245 -0
  159. data_designer/engine/validators/remote.py +83 -0
  160. data_designer/engine/validators/sql.py +60 -0
  161. data_designer/errors.py +5 -0
  162. data_designer/essentials/__init__.py +137 -0
  163. data_designer/interface/__init__.py +2 -0
  164. data_designer/interface/data_designer.py +351 -0
  165. data_designer/interface/errors.py +16 -0
  166. data_designer/interface/results.py +55 -0
  167. data_designer/logging.py +161 -0
  168. data_designer/plugin_manager.py +83 -0
  169. data_designer/plugins/__init__.py +6 -0
  170. data_designer/plugins/errors.py +10 -0
  171. data_designer/plugins/plugin.py +69 -0
  172. data_designer/plugins/registry.py +86 -0
  173. data_designer-0.1.0.dist-info/METADATA +173 -0
  174. data_designer-0.1.0.dist-info/RECORD +177 -0
  175. data_designer-0.1.0.dist-info/WHEEL +4 -0
  176. data_designer-0.1.0.dist-info/entry_points.txt +2 -0
  177. data_designer-0.1.0.dist-info/licenses/LICENSE +201 -0
data_designer/config/data_designer_config.py
@@ -0,0 +1,40 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from __future__ import annotations
+
+ from typing import Annotated, Optional
+
+ from pydantic import Field
+
+ from .analysis.column_profilers import ColumnProfilerConfigT
+ from .base import ExportableConfigBase
+ from .column_types import ColumnConfigT
+ from .models import ModelConfig
+ from .processors import ProcessorConfig
+ from .sampler_constraints import ColumnConstraintT
+ from .seed import SeedConfig
+
+
+ class DataDesignerConfig(ExportableConfigBase):
+     """Configuration for NeMo Data Designer.
+
+     This class defines the main configuration structure for NeMo Data Designer,
+     which orchestrates the generation of synthetic data.
+
+     Attributes:
+         columns: Required list of column configurations defining how each column
+             should be generated. Must contain at least one column.
+         model_configs: Optional list of model configurations for LLM-based generation.
+             Each model config defines the model, provider, and inference parameters.
+         seed_config: Optional seed dataset settings to use for generation.
+         constraints: Optional list of column constraints.
+         profilers: Optional list of column profilers for analyzing generated data characteristics.
+         processors: Optional list of processor configurations applied to the generated dataset.
+     """
+
+     columns: list[Annotated[ColumnConfigT, Field(discriminator="column_type")]] = Field(min_length=1)
+     model_configs: Optional[list[ModelConfig]] = None
+     seed_config: Optional[SeedConfig] = None
+     constraints: Optional[list[ColumnConstraintT]] = None
+     profilers: Optional[list[ColumnProfilerConfigT]] = None
+     processors: Optional[list[ProcessorConfig]] = None
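
For orientation, a minimal sketch of how this config might be instantiated from a plain dict. The `column_type` discriminator value and the column fields shown here are hypothetical; the real column config types and their discriminator values live in `data_designer/config/column_types.py` and `column_configs.py`.

```python
# Hypothetical sketch -- the "llm-text" discriminator and its fields are
# illustrative, not taken from this package.
from data_designer.config.data_designer_config import DataDesignerConfig

config = DataDesignerConfig.model_validate(
    {
        "columns": [
            {
                "column_type": "llm-text",  # discriminator value (assumed)
                "name": "review",
                "prompt": "Write a short product review.",
                "model_alias": "text",
            }
        ],
        "model_configs": [
            {"alias": "text", "model": "meta/llama-3.1-8b-instruct", "provider": "nvidia"}
        ],
    }
)
```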
data_designer/config/dataset_builders.py
@@ -0,0 +1,11 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from enum import Enum
+
+
+ class BuildStage(str, Enum):
+     PRE_BATCH = "pre_batch"
+     POST_BATCH = "post_batch"
+     PRE_GENERATION = "pre_generation"
+     POST_GENERATION = "post_generation"
data_designer/config/datastore.py
@@ -0,0 +1,151 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from __future__ import annotations
+
+ import logging
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Optional, Union
+
+ from huggingface_hub import HfApi, HfFileSystem
+ import pandas as pd
+ import pyarrow.parquet as pq
+ from pydantic import BaseModel, Field
+
+ from .errors import InvalidConfigError, InvalidFileFormatError, InvalidFilePathError
+ from .utils.io_helpers import VALID_DATASET_FILE_EXTENSIONS, validate_path_contains_files_of_type
+
+ if TYPE_CHECKING:
+     from .seed import SeedDatasetReference
+
+ logger = logging.getLogger(__name__)
+
+
+ class DatastoreSettings(BaseModel):
+     """Configuration for interacting with a datastore."""
+
+     endpoint: str = Field(
+         ...,
+         description="Datastore endpoint. Use 'https://huggingface.co' for the Hugging Face Hub.",
+     )
+     token: Optional[str] = Field(default=None, description="If needed, token to use for authentication.")
+
+
+ def get_file_column_names(file_path: Union[str, Path], file_type: str) -> list[str]:
+     """Extract column names based on file type. Supports glob patterns like '../path/*.parquet'."""
+     file_path = Path(file_path)
+     if "*" in str(file_path):
+         matching_files = sorted(file_path.parent.glob(file_path.name))
+         if not matching_files:
+             raise InvalidFilePathError(f"🛑 No files found matching pattern: {str(file_path)!r}")
+         logger.debug(f"0️⃣ Using the first matching file in {str(file_path)!r} to determine column names in seed dataset")
+         file_path = matching_files[0]
+
+     if file_type == "parquet":
+         try:
+             schema = pq.read_schema(file_path)
+             if hasattr(schema, "names"):
+                 return schema.names
+             else:
+                 return [field.name for field in schema]
+         except Exception as e:
+             logger.warning(f"Failed to process parquet file {file_path}: {e}")
+             return []
+     elif file_type in ["json", "jsonl"]:
+         return pd.read_json(file_path, orient="records", lines=True, nrows=1).columns.tolist()
+     elif file_type == "csv":
+         try:
+             df = pd.read_csv(file_path, nrows=1)
+             return df.columns.tolist()
+         except (pd.errors.EmptyDataError, pd.errors.ParserError) as e:
+             logger.warning(f"Failed to process CSV file {file_path}: {e}")
+             return []
+     else:
+         raise InvalidFilePathError(f"🛑 Unsupported file type: {file_type!r}")
+
+
+ def fetch_seed_dataset_column_names(seed_dataset_reference: SeedDatasetReference) -> list[str]:
+     if hasattr(seed_dataset_reference, "datastore_settings"):
+         return _fetch_seed_dataset_column_names_from_datastore(
+             seed_dataset_reference.repo_id,
+             seed_dataset_reference.filename,
+             seed_dataset_reference.datastore_settings,
+         )
+     return _fetch_seed_dataset_column_names_from_local_file(seed_dataset_reference.dataset)
+
+
+ def resolve_datastore_settings(datastore_settings: DatastoreSettings | dict | None) -> DatastoreSettings:
+     if datastore_settings is None:
+         raise InvalidConfigError("🛑 Datastore settings are required in order to upload datasets to the datastore.")
+     if isinstance(datastore_settings, DatastoreSettings):
+         return datastore_settings
+     elif isinstance(datastore_settings, dict):
+         return DatastoreSettings.model_validate(datastore_settings)
+     else:
+         raise InvalidConfigError(
+             "🛑 Invalid datastore settings format. Must be DatastoreSettings object or dictionary."
+         )
+
+
+ def upload_to_hf_hub(
+     dataset_path: Union[str, Path],
+     filename: str,
+     repo_id: str,
+     datastore_settings: DatastoreSettings,
+     **kwargs,
+ ) -> str:
+     datastore_settings = resolve_datastore_settings(datastore_settings)
+     dataset_path = _validate_dataset_path(dataset_path)
+     filename_ext = filename.split(".")[-1].lower()
+     if dataset_path.suffix.lower()[1:] != filename_ext:
+         raise InvalidFileFormatError(
+             f"🛑 Dataset file extension {dataset_path.suffix!r} does not match `filename` extension '.{filename_ext}'"
+         )
+
+     hfapi = HfApi(endpoint=datastore_settings.endpoint, token=datastore_settings.token)
+     hfapi.create_repo(repo_id, exist_ok=True, repo_type="dataset")
+     hfapi.upload_file(
+         path_or_fileobj=dataset_path,
+         path_in_repo=filename,
+         repo_id=repo_id,
+         repo_type="dataset",
+         **kwargs,
+     )
+     return f"{repo_id}/{filename}"
+
+
+ def _fetch_seed_dataset_column_names_from_datastore(
+     repo_id: str,
+     filename: str,
+     datastore_settings: Optional[Union[DatastoreSettings, dict]] = None,
+ ) -> list[str]:
+     file_type = filename.split(".")[-1]
+     if f".{file_type}" not in VALID_DATASET_FILE_EXTENSIONS:
+         raise InvalidFileFormatError(f"🛑 Unsupported file type: {filename!r}")
+
+     datastore_settings = resolve_datastore_settings(datastore_settings)
+     fs = HfFileSystem(endpoint=datastore_settings.endpoint, token=datastore_settings.token)
+
+     with fs.open(f"datasets/{repo_id}/{filename}") as f:
+         return get_file_column_names(f, file_type)
+
+
+ def _fetch_seed_dataset_column_names_from_local_file(dataset_path: str | Path) -> list[str]:
+     dataset_path = _validate_dataset_path(dataset_path, allow_glob_pattern=True)
+     return get_file_column_names(dataset_path, str(dataset_path).split(".")[-1])
+
+
+ def _validate_dataset_path(dataset_path: Union[str, Path], allow_glob_pattern: bool = False) -> Path:
+     if allow_glob_pattern and "*" in str(dataset_path):
+         parts = str(dataset_path).split("*.")
+         file_path = parts[0]
+         file_extension = parts[-1]
+         validate_path_contains_files_of_type(file_path, file_extension)
+         return Path(dataset_path)
+     if not Path(dataset_path).is_file():
+         raise InvalidFilePathError("🛑 To upload a dataset to the datastore, you must provide a valid file path.")
+     if not Path(dataset_path).name.endswith(tuple(VALID_DATASET_FILE_EXTENSIONS)):
+         raise InvalidFileFormatError(
+             "🛑 Dataset files must be in `parquet`, `csv`, or `json` (orient='records', lines=True) format."
+         )
+     return Path(dataset_path)
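
A usage sketch for the upload path above, assuming a local parquet file; the repo ID and token are illustrative placeholders:

```python
# Minimal sketch: upload a local seed dataset to a Hugging Face-style datastore.
from data_designer.config.datastore import DatastoreSettings, upload_to_hf_hub

settings = DatastoreSettings(endpoint="https://huggingface.co", token="hf_...")
path_in_repo = upload_to_hf_hub(
    dataset_path="seed.parquet",   # must exist locally, with a valid extension
    filename="seed.parquet",       # extension must match dataset_path's suffix
    repo_id="my-org/my-seed-data",
    datastore_settings=settings,
)
print(path_in_repo)  # "my-org/my-seed-data/seed.parquet"
```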
data_designer/config/default_model_settings.py
@@ -0,0 +1,123 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+
+ from functools import lru_cache
+ import logging
+ from pathlib import Path
+ from typing import Any, Literal, Optional
+
+ from .models import InferenceParameters, ModelConfig, ModelProvider
+ from .utils.constants import (
+     MANAGED_ASSETS_PATH,
+     MODEL_CONFIGS_FILE_PATH,
+     MODEL_PROVIDERS_FILE_PATH,
+     PREDEFINED_PROVIDERS,
+     PREDEFINED_PROVIDERS_MODEL_MAP,
+ )
+ from .utils.info import ConfigBuilderInfo, InfoType, InterfaceInfo
+ from .utils.io_helpers import load_config_file, save_config_file
+
+ logger = logging.getLogger(__name__)
+
+
+ def get_default_text_alias_inference_parameters() -> InferenceParameters:
+     return InferenceParameters(
+         temperature=0.85,
+         top_p=0.95,
+     )
+
+
+ def get_default_reasoning_alias_inference_parameters() -> InferenceParameters:
+     return InferenceParameters(
+         temperature=0.35,
+         top_p=0.95,
+     )
+
+
+ def get_default_vision_alias_inference_parameters() -> InferenceParameters:
+     return InferenceParameters(
+         temperature=0.85,
+         top_p=0.95,
+     )
+
+
+ def get_default_inference_parameters(model_alias: Literal["text", "reasoning", "vision"]) -> InferenceParameters:
+     if model_alias == "reasoning":
+         return get_default_reasoning_alias_inference_parameters()
+     elif model_alias == "vision":
+         return get_default_vision_alias_inference_parameters()
+     else:
+         return get_default_text_alias_inference_parameters()
+
+
+ def get_builtin_model_configs() -> list[ModelConfig]:
+     model_configs = []
+     for provider, model_alias_map in PREDEFINED_PROVIDERS_MODEL_MAP.items():
+         for model_alias, model_id in model_alias_map.items():
+             model_configs.append(
+                 ModelConfig(
+                     alias=f"{provider}-{model_alias}",
+                     model=model_id,
+                     provider=provider,
+                     inference_parameters=get_default_inference_parameters(model_alias),
+                 )
+             )
+     return model_configs
+
+
+ def get_builtin_model_providers() -> list[ModelProvider]:
+     return [ModelProvider.model_validate(provider) for provider in PREDEFINED_PROVIDERS]
+
+
+ def get_default_model_configs() -> list[ModelConfig]:
+     if MODEL_CONFIGS_FILE_PATH.exists():
+         config_dict = load_config_file(MODEL_CONFIGS_FILE_PATH)
+         if "model_configs" in config_dict:
+             return [ModelConfig.model_validate(mc) for mc in config_dict["model_configs"]]
+     raise FileNotFoundError(f"Default model configs file not found at {str(MODEL_CONFIGS_FILE_PATH)!r}")
+
+
+ def get_default_providers() -> list[ModelProvider]:
+     config_dict = _get_default_providers_file_content(MODEL_PROVIDERS_FILE_PATH)
+     if "providers" in config_dict:
+         return [ModelProvider.model_validate(p) for p in config_dict["providers"]]
+     return []
+
+
+ def get_default_provider_name() -> Optional[str]:
+     return _get_default_providers_file_content(MODEL_PROVIDERS_FILE_PATH).get("default")
+
+
+ def resolve_seed_default_model_settings() -> None:
+     if not MODEL_CONFIGS_FILE_PATH.exists():
+         logger.info(
+             f"🍾 Default model configs were not found, so writing the following to {str(MODEL_CONFIGS_FILE_PATH)!r}"
+         )
+         config_builder_info = ConfigBuilderInfo(model_configs=get_builtin_model_configs())
+         config_builder_info.display(info_type=InfoType.MODEL_CONFIGS)
+         save_config_file(
+             MODEL_CONFIGS_FILE_PATH, {"model_configs": [mc.model_dump() for mc in get_builtin_model_configs()]}
+         )
+
+     if not MODEL_PROVIDERS_FILE_PATH.exists():
+         logger.info(
+             f"🪄 Default model providers were not found, so writing the following to {str(MODEL_PROVIDERS_FILE_PATH)!r}"
+         )
+         interface_info = InterfaceInfo(model_providers=get_builtin_model_providers())
+         interface_info.display(info_type=InfoType.MODEL_PROVIDERS)
+         save_config_file(
+             MODEL_PROVIDERS_FILE_PATH, {"providers": [p.model_dump() for p in get_builtin_model_providers()]}
+         )
+
+     if not MANAGED_ASSETS_PATH.exists():
+         logger.debug(f"🏗️ Default managed assets path was not found, so creating it at {str(MANAGED_ASSETS_PATH)!r}")
+         MANAGED_ASSETS_PATH.mkdir(parents=True, exist_ok=True)
+
+
+ @lru_cache(maxsize=1)
+ def _get_default_providers_file_content(file_path: Path) -> dict[str, Any]:
+     """Load and cache the default providers file content."""
+     if file_path.exists():
+         return load_config_file(file_path)
+     raise FileNotFoundError(f"Default model providers file not found at {str(file_path)!r}")
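
A quick illustration of the alias dispatch above; the asserted values simply mirror the defaults defined in this file:

```python
# Sketch: alias-specific defaults resolve per the dispatch in
# get_default_inference_parameters.
from data_designer.config.default_model_settings import get_default_inference_parameters

params = get_default_inference_parameters("reasoning")
assert params.temperature == 0.35 and params.top_p == 0.95

# Any alias other than "reasoning" or "vision" falls through to the "text" defaults.
assert get_default_inference_parameters("text").temperature == 0.85
```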
data_designer/config/errors.py
@@ -0,0 +1,19 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from ..errors import DataDesignerError
+
+
+ class BuilderConfigurationError(DataDesignerError): ...
+
+
+ class InvalidColumnTypeError(DataDesignerError): ...
+
+
+ class InvalidConfigError(DataDesignerError): ...
+
+
+ class InvalidFilePathError(DataDesignerError): ...
+
+
+ class InvalidFileFormatError(DataDesignerError): ...
data_designer/config/interface.py
@@ -0,0 +1,54 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from __future__ import annotations
+
+ from abc import ABC, abstractmethod
+ from typing import TYPE_CHECKING, Generic, Protocol, TypeVar
+
+ import pandas as pd
+
+ from .models import ModelConfig, ModelProvider
+ from .utils.constants import DEFAULT_NUM_RECORDS
+ from .utils.info import InterfaceInfo
+
+ if TYPE_CHECKING:
+     from .analysis.dataset_profiler import DatasetProfilerResults
+     from .config_builder import DataDesignerConfigBuilder
+     from .preview_results import PreviewResults
+
+
+ class ResultsProtocol(Protocol):
+     def load_analysis(self) -> DatasetProfilerResults: ...
+     def load_dataset(self) -> pd.DataFrame: ...
+
+
+ ResultsT = TypeVar("ResultsT", bound=ResultsProtocol)
+
+
+ class DataDesignerInterface(ABC, Generic[ResultsT]):
+     @abstractmethod
+     def create(
+         self,
+         config_builder: DataDesignerConfigBuilder,
+         *,
+         num_records: int = DEFAULT_NUM_RECORDS,
+     ) -> ResultsT: ...
+
+     @abstractmethod
+     def preview(
+         self,
+         config_builder: DataDesignerConfigBuilder,
+         *,
+         num_records: int = DEFAULT_NUM_RECORDS,
+     ) -> PreviewResults: ...
+
+     @abstractmethod
+     def get_default_model_configs(self) -> list[ModelConfig]: ...
+
+     @abstractmethod
+     def get_default_model_providers(self) -> list[ModelProvider]: ...
+
+     @property
+     @abstractmethod
+     def info(self) -> InterfaceInfo: ...
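
Because `ResultsProtocol` is a `typing.Protocol`, any class with matching `load_analysis`/`load_dataset` methods satisfies the `ResultsT` bound structurally, with no inheritance required. A minimal, illustrative sketch (not part of the package):

```python
# Hypothetical results holder that structurally satisfies ResultsProtocol.
import pandas as pd

class InMemoryResults:
    def __init__(self, df: pd.DataFrame, analysis) -> None:
        self._df = df
        self._analysis = analysis

    def load_analysis(self):  # would return a DatasetProfilerResults in real usage
        return self._analysis

    def load_dataset(self) -> pd.DataFrame:
        return self._df

# A concrete backend could then parameterize the ABC, e.g.:
# class LocalDataDesigner(DataDesignerInterface[InMemoryResults]): ...
```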
data_designer/config/models.py
@@ -0,0 +1,231 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from abc import ABC, abstractmethod
+ from enum import Enum
+ import logging
+ from pathlib import Path
+ from typing import Any, Generic, List, Optional, TypeVar, Union
+
+ import numpy as np
+ from pydantic import BaseModel, Field, model_validator
+ from typing_extensions import Self, TypeAlias
+
+ from .base import ConfigBase
+ from .errors import InvalidConfigError
+ from .utils.constants import (
+     MAX_TEMPERATURE,
+     MAX_TOP_P,
+     MIN_TEMPERATURE,
+     MIN_TOP_P,
+ )
+ from .utils.io_helpers import smart_load_yaml
+
+ logger = logging.getLogger(__name__)
+
+
+ class Modality(str, Enum):
+     IMAGE = "image"
+
+
+ class ModalityDataType(str, Enum):
+     URL = "url"
+     BASE64 = "base64"
+
+
+ class ImageFormat(str, Enum):
+     PNG = "png"
+     JPG = "jpg"
+     JPEG = "jpeg"
+     GIF = "gif"
+     WEBP = "webp"
+
+
+ class DistributionType(str, Enum):
+     UNIFORM = "uniform"
+     MANUAL = "manual"
+
+
+ class ModalityContext(ABC, BaseModel):
+     modality: Modality
+     column_name: str
+     data_type: ModalityDataType
+
+     @abstractmethod
+     def get_context(self, record: dict) -> dict[str, Any]: ...
+
+
+ class ImageContext(ModalityContext):
+     modality: Modality = Modality.IMAGE
+     image_format: Optional[ImageFormat] = None
+
+     def get_context(self, record: dict) -> dict[str, Any]:
+         context = dict(type="image_url")
+         context_value = record[self.column_name]
+         if self.data_type == ModalityDataType.URL:
+             context["image_url"] = context_value
+         else:
+             context["image_url"] = {
+                 "url": f"data:image/{self.image_format.value};base64,{context_value}",
+                 "format": self.image_format.value,
+             }
+         return context
+
+     @model_validator(mode="after")
+     def _validate_image_format(self) -> Self:
+         if self.data_type == ModalityDataType.BASE64 and self.image_format is None:
+             raise ValueError(f"image_format is required when data_type is {self.data_type.value}")
+         return self
+
+
+ DistributionParamsT = TypeVar("DistributionParamsT", bound=ConfigBase)
+
+
+ class Distribution(ABC, ConfigBase, Generic[DistributionParamsT]):
+     distribution_type: DistributionType
+     params: DistributionParamsT
+
+     @abstractmethod
+     def sample(self) -> float: ...
+
+
+ class ManualDistributionParams(ConfigBase):
+     values: List[float] = Field(min_length=1)
+     weights: Optional[List[float]] = None
+
+     @model_validator(mode="after")
+     def _normalize_weights(self) -> Self:
+         if self.weights is not None:
+             self.weights = [w / sum(self.weights) for w in self.weights]
+         return self
+
+     @model_validator(mode="after")
+     def _validate_equal_lengths(self) -> Self:
+         if self.weights and len(self.values) != len(self.weights):
+             raise ValueError("`values` and `weights` must have the same length")
+         return self
+
+
+ class ManualDistribution(Distribution[ManualDistributionParams]):
+     distribution_type: Optional[DistributionType] = "manual"
+     params: ManualDistributionParams
+
+     def sample(self) -> float:
+         return float(np.random.choice(self.params.values, p=self.params.weights))
+
+
+ class UniformDistributionParams(ConfigBase):
+     low: float
+     high: float
+
+     @model_validator(mode="after")
+     def _validate_low_lt_high(self) -> Self:
+         if self.low >= self.high:
+             raise ValueError("`low` must be less than `high`")
+         return self
+
+
+ class UniformDistribution(Distribution[UniformDistributionParams]):
+     distribution_type: Optional[DistributionType] = "uniform"
+     params: UniformDistributionParams
+
+     def sample(self) -> float:
+         return float(np.random.uniform(low=self.params.low, high=self.params.high, size=1)[0])
+
+
+ DistributionT: TypeAlias = Union[UniformDistribution, ManualDistribution]
+
+
+ class InferenceParameters(ConfigBase):
+     temperature: Optional[Union[float, DistributionT]] = None
+     top_p: Optional[Union[float, DistributionT]] = None
+     max_tokens: Optional[int] = Field(default=None, ge=1)
+     max_parallel_requests: int = Field(default=4, ge=1)
+     timeout: Optional[int] = Field(default=None, ge=1)
+     extra_body: Optional[dict[str, Any]] = None
+
+     @property
+     def generate_kwargs(self) -> dict[str, Union[float, int]]:
+         result = {}
+         if self.temperature is not None:
+             result["temperature"] = (
+                 self.temperature.sample() if hasattr(self.temperature, "sample") else self.temperature
+             )
+         if self.top_p is not None:
+             result["top_p"] = self.top_p.sample() if hasattr(self.top_p, "sample") else self.top_p
+         if self.max_tokens is not None:
+             result["max_tokens"] = self.max_tokens
+         if self.timeout is not None:
+             result["timeout"] = self.timeout
+         if self.extra_body is not None and self.extra_body != {}:
+             result["extra_body"] = self.extra_body
+         return result
+
+     @model_validator(mode="after")
+     def _validate_temperature(self) -> Self:
+         return self._run_validation(
+             value=self.temperature,
+             param_name="temperature",
+             min_value=MIN_TEMPERATURE,
+             max_value=MAX_TEMPERATURE,
+         )
+
+     @model_validator(mode="after")
+     def _validate_top_p(self) -> Self:
+         return self._run_validation(
+             value=self.top_p,
+             param_name="top_p",
+             min_value=MIN_TOP_P,
+             max_value=MAX_TOP_P,
+         )
+
+     def _run_validation(
+         self,
+         value: Union[float, DistributionT, None],
+         param_name: str,
+         min_value: float,
+         max_value: float,
+     ) -> Self:
+         if value is None:
+             return self
+         value_err = ValueError(f"{param_name} defined in model config must be between {min_value} and {max_value}")
+         if isinstance(value, Distribution):
+             if value.distribution_type == DistributionType.UNIFORM:
+                 if value.params.low < min_value or value.params.high > max_value:
+                     raise value_err
+             elif value.distribution_type == DistributionType.MANUAL:
+                 if any(not self._is_value_in_range(v, min_value, max_value) for v in value.params.values):
+                     raise value_err
+         else:
+             if not self._is_value_in_range(value, min_value, max_value):
+                 raise value_err
+         return self
+
+     def _is_value_in_range(self, value: float, min_value: float, max_value: float) -> bool:
+         return min_value <= value <= max_value
+
+
+ class ModelConfig(ConfigBase):
+     alias: str
+     model: str
+     inference_parameters: InferenceParameters = Field(default_factory=InferenceParameters)
+     provider: Optional[str] = None
+
+
+ class ModelProvider(ConfigBase):
+     name: str
+     endpoint: str
+     provider_type: str = "openai"
+     api_key: Optional[str] = None
+     extra_body: Optional[dict[str, Any]] = None
+
+
+ def load_model_configs(model_configs: Union[list[ModelConfig], str, Path]) -> list[ModelConfig]:
+     if isinstance(model_configs, list) and all(isinstance(mc, ModelConfig) for mc in model_configs):
+         return model_configs
+     json_config = smart_load_yaml(model_configs)
+     if "model_configs" not in json_config:
+         raise InvalidConfigError(
+             "The list of model configs must be provided under model_configs in the configuration file."
+         )
+     return [ModelConfig.model_validate(mc) for mc in json_config["model_configs"]]
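
A sketch of the distribution-valued inference parameters defined above; the field values are illustrative, and the validation bounds come from `utils/constants.py`:

```python
# Sketch: temperature drawn per call from a uniform distribution, top_p fixed.
from data_designer.config.models import (
    InferenceParameters,
    UniformDistribution,
    UniformDistributionParams,
)

params = InferenceParameters(
    temperature=UniformDistribution(params=UniformDistributionParams(low=0.5, high=0.9)),
    top_p=0.95,
    max_tokens=1024,
)

# `generate_kwargs` calls sample() on distribution-valued fields each time it is
# accessed, so every request can receive a slightly different temperature.
print(params.generate_kwargs)  # e.g. {'temperature': 0.73, 'top_p': 0.95, 'max_tokens': 1024}
```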
data_designer/config/preview_results.py
@@ -0,0 +1,32 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from __future__ import annotations
+
+ from typing import Optional
+
+ import pandas as pd
+
+ from .analysis.dataset_profiler import DatasetProfilerResults
+ from .config_builder import DataDesignerConfigBuilder
+ from .utils.visualization import WithRecordSamplerMixin
+
+
+ class PreviewResults(WithRecordSamplerMixin):
+     def __init__(
+         self,
+         *,
+         config_builder: DataDesignerConfigBuilder,
+         dataset: Optional[pd.DataFrame] = None,
+         analysis: Optional[DatasetProfilerResults] = None,
+     ):
+         """Creates a new instance with results from a Data Designer preview run.
+
+         Args:
+             config_builder: Data Designer configuration builder.
+             dataset: Dataset of the preview run.
+             analysis: Analysis of the preview run.
+         """
+         self.dataset: pd.DataFrame | None = dataset
+         self.analysis: DatasetProfilerResults | None = analysis
+         self._config_builder = config_builder