data_designer-0.1.0-py3-none-any.whl

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in the public registry.
Files changed (177)
  1. data_designer/__init__.py +15 -0
  2. data_designer/_version.py +34 -0
  3. data_designer/cli/README.md +236 -0
  4. data_designer/cli/__init__.py +6 -0
  5. data_designer/cli/commands/__init__.py +2 -0
  6. data_designer/cli/commands/list.py +130 -0
  7. data_designer/cli/commands/models.py +10 -0
  8. data_designer/cli/commands/providers.py +11 -0
  9. data_designer/cli/commands/reset.py +100 -0
  10. data_designer/cli/controllers/__init__.py +7 -0
  11. data_designer/cli/controllers/model_controller.py +246 -0
  12. data_designer/cli/controllers/provider_controller.py +317 -0
  13. data_designer/cli/forms/__init__.py +20 -0
  14. data_designer/cli/forms/builder.py +51 -0
  15. data_designer/cli/forms/field.py +180 -0
  16. data_designer/cli/forms/form.py +59 -0
  17. data_designer/cli/forms/model_builder.py +125 -0
  18. data_designer/cli/forms/provider_builder.py +76 -0
  19. data_designer/cli/main.py +44 -0
  20. data_designer/cli/repositories/__init__.py +8 -0
  21. data_designer/cli/repositories/base.py +39 -0
  22. data_designer/cli/repositories/model_repository.py +42 -0
  23. data_designer/cli/repositories/provider_repository.py +43 -0
  24. data_designer/cli/services/__init__.py +7 -0
  25. data_designer/cli/services/model_service.py +116 -0
  26. data_designer/cli/services/provider_service.py +111 -0
  27. data_designer/cli/ui.py +448 -0
  28. data_designer/cli/utils.py +47 -0
  29. data_designer/config/__init__.py +2 -0
  30. data_designer/config/analysis/column_profilers.py +89 -0
  31. data_designer/config/analysis/column_statistics.py +274 -0
  32. data_designer/config/analysis/dataset_profiler.py +60 -0
  33. data_designer/config/analysis/utils/errors.py +8 -0
  34. data_designer/config/analysis/utils/reporting.py +188 -0
  35. data_designer/config/base.py +68 -0
  36. data_designer/config/column_configs.py +354 -0
  37. data_designer/config/column_types.py +168 -0
  38. data_designer/config/config_builder.py +660 -0
  39. data_designer/config/data_designer_config.py +40 -0
  40. data_designer/config/dataset_builders.py +11 -0
  41. data_designer/config/datastore.py +151 -0
  42. data_designer/config/default_model_settings.py +123 -0
  43. data_designer/config/errors.py +19 -0
  44. data_designer/config/interface.py +54 -0
  45. data_designer/config/models.py +231 -0
  46. data_designer/config/preview_results.py +32 -0
  47. data_designer/config/processors.py +41 -0
  48. data_designer/config/sampler_constraints.py +51 -0
  49. data_designer/config/sampler_params.py +604 -0
  50. data_designer/config/seed.py +145 -0
  51. data_designer/config/utils/code_lang.py +83 -0
  52. data_designer/config/utils/constants.py +313 -0
  53. data_designer/config/utils/errors.py +19 -0
  54. data_designer/config/utils/info.py +88 -0
  55. data_designer/config/utils/io_helpers.py +273 -0
  56. data_designer/config/utils/misc.py +81 -0
  57. data_designer/config/utils/numerical_helpers.py +28 -0
  58. data_designer/config/utils/type_helpers.py +100 -0
  59. data_designer/config/utils/validation.py +336 -0
  60. data_designer/config/utils/visualization.py +427 -0
  61. data_designer/config/validator_params.py +96 -0
  62. data_designer/engine/__init__.py +2 -0
  63. data_designer/engine/analysis/column_profilers/base.py +55 -0
  64. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +160 -0
  65. data_designer/engine/analysis/column_profilers/registry.py +20 -0
  66. data_designer/engine/analysis/column_statistics.py +142 -0
  67. data_designer/engine/analysis/dataset_profiler.py +125 -0
  68. data_designer/engine/analysis/errors.py +7 -0
  69. data_designer/engine/analysis/utils/column_statistics_calculations.py +209 -0
  70. data_designer/engine/analysis/utils/judge_score_processing.py +128 -0
  71. data_designer/engine/column_generators/__init__.py +2 -0
  72. data_designer/engine/column_generators/generators/__init__.py +2 -0
  73. data_designer/engine/column_generators/generators/base.py +61 -0
  74. data_designer/engine/column_generators/generators/expression.py +63 -0
  75. data_designer/engine/column_generators/generators/llm_generators.py +172 -0
  76. data_designer/engine/column_generators/generators/samplers.py +75 -0
  77. data_designer/engine/column_generators/generators/seed_dataset.py +149 -0
  78. data_designer/engine/column_generators/generators/validation.py +147 -0
  79. data_designer/engine/column_generators/registry.py +56 -0
  80. data_designer/engine/column_generators/utils/errors.py +13 -0
  81. data_designer/engine/column_generators/utils/judge_score_factory.py +57 -0
  82. data_designer/engine/column_generators/utils/prompt_renderer.py +98 -0
  83. data_designer/engine/configurable_task.py +82 -0
  84. data_designer/engine/dataset_builders/artifact_storage.py +181 -0
  85. data_designer/engine/dataset_builders/column_wise_builder.py +287 -0
  86. data_designer/engine/dataset_builders/errors.py +13 -0
  87. data_designer/engine/dataset_builders/multi_column_configs.py +44 -0
  88. data_designer/engine/dataset_builders/utils/__init__.py +2 -0
  89. data_designer/engine/dataset_builders/utils/concurrency.py +184 -0
  90. data_designer/engine/dataset_builders/utils/config_compiler.py +60 -0
  91. data_designer/engine/dataset_builders/utils/dag.py +56 -0
  92. data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +190 -0
  93. data_designer/engine/dataset_builders/utils/errors.py +13 -0
  94. data_designer/engine/errors.py +49 -0
  95. data_designer/engine/model_provider.py +75 -0
  96. data_designer/engine/models/__init__.py +2 -0
  97. data_designer/engine/models/errors.py +308 -0
  98. data_designer/engine/models/facade.py +225 -0
  99. data_designer/engine/models/litellm_overrides.py +162 -0
  100. data_designer/engine/models/parsers/__init__.py +2 -0
  101. data_designer/engine/models/parsers/errors.py +34 -0
  102. data_designer/engine/models/parsers/parser.py +236 -0
  103. data_designer/engine/models/parsers/postprocessors.py +93 -0
  104. data_designer/engine/models/parsers/tag_parsers.py +60 -0
  105. data_designer/engine/models/parsers/types.py +82 -0
  106. data_designer/engine/models/recipes/base.py +79 -0
  107. data_designer/engine/models/recipes/response_recipes.py +291 -0
  108. data_designer/engine/models/registry.py +118 -0
  109. data_designer/engine/models/usage.py +75 -0
  110. data_designer/engine/models/utils.py +38 -0
  111. data_designer/engine/processing/ginja/__init__.py +2 -0
  112. data_designer/engine/processing/ginja/ast.py +64 -0
  113. data_designer/engine/processing/ginja/environment.py +461 -0
  114. data_designer/engine/processing/ginja/exceptions.py +54 -0
  115. data_designer/engine/processing/ginja/record.py +30 -0
  116. data_designer/engine/processing/gsonschema/__init__.py +2 -0
  117. data_designer/engine/processing/gsonschema/exceptions.py +8 -0
  118. data_designer/engine/processing/gsonschema/schema_transformers.py +81 -0
  119. data_designer/engine/processing/gsonschema/types.py +8 -0
  120. data_designer/engine/processing/gsonschema/validators.py +143 -0
  121. data_designer/engine/processing/processors/base.py +15 -0
  122. data_designer/engine/processing/processors/drop_columns.py +46 -0
  123. data_designer/engine/processing/processors/registry.py +20 -0
  124. data_designer/engine/processing/utils.py +120 -0
  125. data_designer/engine/registry/base.py +97 -0
  126. data_designer/engine/registry/data_designer_registry.py +37 -0
  127. data_designer/engine/registry/errors.py +10 -0
  128. data_designer/engine/resources/managed_dataset_generator.py +35 -0
  129. data_designer/engine/resources/managed_dataset_repository.py +194 -0
  130. data_designer/engine/resources/managed_storage.py +63 -0
  131. data_designer/engine/resources/resource_provider.py +46 -0
  132. data_designer/engine/resources/seed_dataset_data_store.py +66 -0
  133. data_designer/engine/sampling_gen/column.py +89 -0
  134. data_designer/engine/sampling_gen/constraints.py +95 -0
  135. data_designer/engine/sampling_gen/data_sources/base.py +214 -0
  136. data_designer/engine/sampling_gen/data_sources/errors.py +10 -0
  137. data_designer/engine/sampling_gen/data_sources/sources.py +342 -0
  138. data_designer/engine/sampling_gen/entities/__init__.py +2 -0
  139. data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
  140. data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +64 -0
  141. data_designer/engine/sampling_gen/entities/email_address_utils.py +169 -0
  142. data_designer/engine/sampling_gen/entities/errors.py +8 -0
  143. data_designer/engine/sampling_gen/entities/national_id_utils.py +100 -0
  144. data_designer/engine/sampling_gen/entities/person.py +142 -0
  145. data_designer/engine/sampling_gen/entities/phone_number.py +122 -0
  146. data_designer/engine/sampling_gen/errors.py +24 -0
  147. data_designer/engine/sampling_gen/generator.py +121 -0
  148. data_designer/engine/sampling_gen/jinja_utils.py +60 -0
  149. data_designer/engine/sampling_gen/people_gen.py +203 -0
  150. data_designer/engine/sampling_gen/person_constants.py +54 -0
  151. data_designer/engine/sampling_gen/schema.py +143 -0
  152. data_designer/engine/sampling_gen/schema_builder.py +59 -0
  153. data_designer/engine/sampling_gen/utils.py +40 -0
  154. data_designer/engine/secret_resolver.py +80 -0
  155. data_designer/engine/validators/__init__.py +17 -0
  156. data_designer/engine/validators/base.py +36 -0
  157. data_designer/engine/validators/local_callable.py +34 -0
  158. data_designer/engine/validators/python.py +245 -0
  159. data_designer/engine/validators/remote.py +83 -0
  160. data_designer/engine/validators/sql.py +60 -0
  161. data_designer/errors.py +5 -0
  162. data_designer/essentials/__init__.py +137 -0
  163. data_designer/interface/__init__.py +2 -0
  164. data_designer/interface/data_designer.py +351 -0
  165. data_designer/interface/errors.py +16 -0
  166. data_designer/interface/results.py +55 -0
  167. data_designer/logging.py +161 -0
  168. data_designer/plugin_manager.py +83 -0
  169. data_designer/plugins/__init__.py +6 -0
  170. data_designer/plugins/errors.py +10 -0
  171. data_designer/plugins/plugin.py +69 -0
  172. data_designer/plugins/registry.py +86 -0
  173. data_designer-0.1.0.dist-info/METADATA +173 -0
  174. data_designer-0.1.0.dist-info/RECORD +177 -0
  175. data_designer-0.1.0.dist-info/WHEEL +4 -0
  176. data_designer-0.1.0.dist-info/entry_points.txt +2 -0
  177. data_designer-0.1.0.dist-info/licenses/LICENSE +201 -0
data_designer/config/utils/io_helpers.py
@@ -0,0 +1,273 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from datetime import date, datetime, timedelta
+ from decimal import Decimal
+ import json
+ import logging
+ from numbers import Number
+ import os
+ from pathlib import Path
+ from typing import Any, Union
+
+ import numpy as np
+ import pandas as pd
+ import yaml
+
+ from ..errors import InvalidFileFormatError, InvalidFilePathError
+
+ logger = logging.getLogger(__name__)
+
+ VALID_DATASET_FILE_EXTENSIONS = {".parquet", ".csv", ".json", ".jsonl"}
+
+
+ def ensure_config_dir_exists(config_dir: Path) -> None:
+     """Create the configuration directory if it doesn't exist.
+
+     Args:
+         config_dir: Directory path to create
+     """
+     config_dir.mkdir(parents=True, exist_ok=True)
+
+
+ def load_config_file(file_path: Path) -> dict:
+     """Load a YAML configuration file.
+
+     Args:
+         file_path: Path to the YAML file
+
+     Returns:
+         Parsed YAML content as a dictionary
+
+     Raises:
+         InvalidFilePathError: If the file doesn't exist
+         InvalidFileFormatError: If the YAML is malformed
+         InvalidConfigError: If the file is empty
+     """
+     from ..errors import InvalidConfigError
+
+     if not file_path.exists():
+         raise InvalidFilePathError(f"Configuration file not found: {file_path}")
+
+     try:
+         with open(file_path) as f:
+             content = yaml.safe_load(f)
+
+         if content is None:
+             raise InvalidConfigError(f"Configuration file is empty: {file_path}")
+
+         return content
+
+     except yaml.YAMLError as e:
+         raise InvalidFileFormatError(f"Invalid YAML format in {file_path}: {e}") from e
+
+
+ def save_config_file(file_path: Path, config: dict) -> None:
+     """Save configuration to a YAML file.
+
+     Args:
+         file_path: Path where to save the file
+         config: Configuration dictionary to save
+
+     Raises:
+         IOError: If the file cannot be written
+     """
+     # Ensure parent directory exists
+     file_path.parent.mkdir(parents=True, exist_ok=True)
+
+     with open(file_path, "w") as f:
+         yaml.safe_dump(
+             config,
+             f,
+             default_flow_style=False,
+             sort_keys=False,
+             indent=2,
+             allow_unicode=True,
+         )
+
+
+ def read_parquet_dataset(path: Path) -> pd.DataFrame:
+     """Read a parquet dataset from a path.
+
+     Args:
+         path: The path to the parquet dataset; can be either a file or a directory.
+
+     Returns:
+         The parquet dataset as a pandas DataFrame.
+     """
+     try:
+         return pd.read_parquet(path, dtype_backend="pyarrow")
+     except Exception as e:
+         if path.is_dir() and "Unsupported cast" in str(e):
+             logger.warning("Failed to read parquet files as a folder; falling back to individual files")
+             return pd.concat(
+                 [pd.read_parquet(file, dtype_backend="pyarrow") for file in sorted(path.glob("*.parquet"))],
+                 ignore_index=True,
+             )
+         else:
+             raise
+
+
+ def write_seed_dataset(dataframe: pd.DataFrame, file_path: Path) -> None:
+     """Write a seed dataset to a file in the specified format.
+
+     Supported file extensions: .parquet, .csv, .json, .jsonl
+
+     Args:
+         dataframe: The pandas DataFrame to write.
+         file_path: The path where the dataset should be saved.
+             Format is inferred from the file extension.
+     """
+     file_path = validate_dataset_file_path(file_path, should_exist=False)
+     logger.info(f"💾 Saving seed dataset to {file_path}")
+     if file_path.suffix.lower() == ".parquet":
+         dataframe.to_parquet(file_path, index=False)
+     elif file_path.suffix.lower() == ".csv":
+         dataframe.to_csv(file_path, index=False)
+     elif file_path.suffix.lower() in {".json", ".jsonl"}:
+         dataframe.to_json(file_path, orient="records", lines=True)
+
+
+ def validate_dataset_file_path(file_path: Union[str, Path], should_exist: bool = True) -> Path:
+     """Validate that a dataset file path has a valid extension and optionally exists.
+
+     Args:
+         file_path: The path to validate, either as a string or Path object.
+         should_exist: If True, verify that the file exists. Defaults to True.
+     Returns:
+         The validated path as a Path object.
+     Raises:
+         InvalidFilePathError: If the path is not a file.
+         InvalidFileFormatError: If the path does not have a valid extension.
+     """
+     file_path = Path(file_path)
+     if should_exist and not file_path.is_file():
+         raise InvalidFilePathError(f"🛑 Path {file_path} is not a file.")
+     if not file_path.name.lower().endswith(tuple(VALID_DATASET_FILE_EXTENSIONS)):
+         raise InvalidFileFormatError(
+             "🛑 Dataset files must be in parquet, csv, or jsonl/json (orient='records', lines=True) format."
+         )
+     return file_path
+
+
+ def validate_path_contains_files_of_type(path: str | Path, file_extension: str) -> None:
+     """Validate that a path contains files of a specific type.
+
+     Args:
+         path: The path to validate. Can contain wildcards like `*.parquet`.
+         file_extension: The extension of the files to validate (without the dot, e.g., "parquet").
+     Returns:
+         None if the path contains files of the specified type, raises an error otherwise.
+     Raises:
+         InvalidFilePathError: If the path does not contain files of the specified type.
+     """
+     if not any(Path(path).glob(f"*.{file_extension}")):
+         raise InvalidFilePathError(f"🛑 Path {path!r} does not contain files of type {file_extension!r}.")
+
+
+ def smart_load_dataframe(dataframe: Union[str, Path, pd.DataFrame]) -> pd.DataFrame:
+     """Load a dataframe from file if a path is given, otherwise return the dataframe.
+
+     Args:
+         dataframe: A path to a file or a pandas DataFrame object.
+
+     Returns:
+         A pandas DataFrame object.
+     """
+     if isinstance(dataframe, pd.DataFrame):
+         return dataframe
+
+     # Get the file extension (without the leading dot).
+     if isinstance(dataframe, str) and dataframe.startswith("http"):
+         ext = dataframe.split(".")[-1].lower()
+     else:
+         dataframe = Path(dataframe)
+         ext = dataframe.suffix.lower().lstrip(".")
+         if not dataframe.exists():
+             raise FileNotFoundError(f"File not found: {dataframe}")
+
+     # Load the dataframe based on the file extension.
+     if ext == "csv":
+         return pd.read_csv(dataframe)
+     elif ext in {"json", "jsonl"}:
+         return pd.read_json(dataframe, lines=True)
+     elif ext == "parquet":
+         return pd.read_parquet(dataframe)
+     else:
+         raise ValueError(f"Unsupported file format: {dataframe}")
+
+
+ def smart_load_yaml(yaml_in: Union[str, Path, dict]) -> dict:
+     """Return the yaml config as a dict given flexible input types.
+
+     Args:
+         yaml_in: The config as a dict, yaml string, or yaml file path.
+
+     Returns:
+         The config as a dict.
+     """
+     if isinstance(yaml_in, dict):
+         yaml_out = yaml_in
+     elif isinstance(yaml_in, Path) or (isinstance(yaml_in, str) and os.path.isfile(yaml_in)):
+         with open(yaml_in) as file:
+             yaml_out = yaml.safe_load(file)
+     elif isinstance(yaml_in, str):
+         if yaml_in.endswith((".yaml", ".yml")) and not os.path.isfile(yaml_in):
+             raise FileNotFoundError(f"File not found: {yaml_in}")
+         else:
+             yaml_out = yaml.safe_load(yaml_in)
+     else:
+         raise ValueError(
+             f"'{yaml_in}' is an invalid yaml config format. Valid options are: dict, yaml string, or yaml file path."
+         )
+
+     if not isinstance(yaml_out, dict):
+         raise ValueError(f"Loaded yaml must be a dict. Got {yaml_out}, which is of type {type(yaml_out)}.")
+
+     return yaml_out
+
+
+ def serialize_data(data: Union[dict, list, str, Number], **kwargs) -> str:
+     if isinstance(data, dict):
+         return json.dumps(data, ensure_ascii=False, default=_convert_to_serializable, **kwargs)
+     elif isinstance(data, list):
+         return json.dumps(data, ensure_ascii=False, default=_convert_to_serializable, **kwargs)
+     elif isinstance(data, str):
+         return data
+     elif isinstance(data, Number):
+         return str(data)
+     else:
+         raise ValueError(f"Invalid data type: {type(data)}")
+
+
+ def _convert_to_serializable(obj: Any) -> Any:
+     """Convert non-JSON-serializable objects to JSON-serializable Python-native types.
+
+     Raises:
+         TypeError: If the object type is not supported for serialization.
+     """
+     if isinstance(obj, (set, list)):
+         return list(obj)
+     if isinstance(obj, (pd.Series, np.ndarray)):
+         return obj.tolist()
+
+     if pd.isna(obj):
+         return None
+
+     if isinstance(obj, (datetime, date, pd.Timestamp)):
+         return obj.isoformat()
+     if isinstance(obj, timedelta):
+         return obj.total_seconds()
+     if isinstance(obj, (np.datetime64, np.timedelta64)):
+         return str(obj)
+
+     if isinstance(obj, Decimal):
+         return float(obj)
+     if isinstance(obj, (np.integer, np.floating, np.bool_)):
+         return obj.item()
+
+     if isinstance(obj, bytes):
+         return obj.decode("utf-8", errors="replace")
+
+     # Unsupported type
+     raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
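A minimal usage sketch for the I/O helpers above (not part of the wheel itself; the import path is assumed from the file layout listed in this diff):

# Illustrative only: assumes the wheel is installed and these helpers are
# importable from data_designer.config.utils.io_helpers.
from pathlib import Path

import numpy as np
import pandas as pd

from data_designer.config.utils.io_helpers import (
    serialize_data,
    smart_load_dataframe,
    smart_load_yaml,
    write_seed_dataset,
)

# smart_load_yaml accepts a dict, a YAML string, or a YAML file path.
assert smart_load_yaml({"a": 1}) == {"a": 1}
assert smart_load_yaml("a: 1\nb: 2") == {"a": 1, "b": 2}

# write_seed_dataset infers the format from the extension; .json/.jsonl are
# written line-delimited, which smart_load_dataframe reads back the same way.
df = pd.DataFrame({"x": [1, 2], "y": ["a", "b"]})
write_seed_dataset(df, Path("seed.json"))
df_roundtrip = smart_load_dataframe("seed.json")

# serialize_data delegates numpy/pandas scalars to _convert_to_serializable.
print(serialize_data({"ts": pd.Timestamp("2025-01-01"), "n": np.int64(7)}))
# -> {"ts": "2025-01-01T00:00:00", "n": 7}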
data_designer/config/utils/misc.py
@@ -0,0 +1,81 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from __future__ import annotations
+
+ from contextlib import contextmanager
+ import json
+ from typing import Optional, Union
+
+ from jinja2 import TemplateSyntaxError, meta
+ from jinja2.sandbox import ImmutableSandboxedEnvironment
+
+ from .errors import UserJinjaTemplateSyntaxError
+
+ REPR_LIST_LENGTH_USE_JSON = 4
+
+
+ def kebab_to_snake(s: str) -> str:
+     return s.replace("-", "_")
+
+
+ @contextmanager
+ def template_error_handler():
+     try:
+         yield
+     except TemplateSyntaxError as exception:
+         exception_string = (
+             f"Encountered a syntax error in the provided Jinja2 template:\n{str(exception)}\n"
+             "For more information on writing Jinja2 templates, "
+             "refer to https://jinja.palletsprojects.com/en/stable/templates"
+         )
+         raise UserJinjaTemplateSyntaxError(exception_string)
+     except Exception:
+         raise
+
+
+ def assert_valid_jinja2_template(template: str) -> None:
+     """Raises an error if the template cannot be parsed."""
+     with template_error_handler():
+         meta.find_undeclared_variables(ImmutableSandboxedEnvironment().parse(template))
+
+
+ def can_run_data_designer_locally() -> bool:
+     """Returns True if Data Designer can be run locally, False otherwise."""
+     try:
+         from ... import engine  # noqa: F401
+     except ImportError:
+         return False
+     return True
+
+
+ def get_prompt_template_keywords(template: str) -> set[str]:
+     """Extract all keywords from a valid string template."""
+     with template_error_handler():
+         ast = ImmutableSandboxedEnvironment().parse(template)
+         keywords = set(meta.find_undeclared_variables(ast))
+
+     return keywords
+
+
+ def json_indent_list_of_strings(
+     column_names: list[str], *, indent: Optional[Union[int, str]] = None
+ ) -> Optional[Union[list[str], str]]:
+     """Convert a list of column names to a JSON string if the list is long.
+
+     This function helps keep Data Designer's __repr__ output clean and readable.
+
+     Args:
+         column_names: List of column names.
+         indent: Indentation for the JSON string.
+
+     Returns:
+         A list of column names or a JSON string if the list is long.
+     """
+     return (
+         None
+         if len(column_names) == 0
+         else (
+             column_names if len(column_names) < REPR_LIST_LENGTH_USE_JSON else json.dumps(column_names, indent=indent)
+         )
+     )
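A short sketch of the template helpers above (illustrative; the import path mirrors the wheel layout and is an assumption):

# Illustrative only: assumes data_designer.config.utils.misc is importable.
from data_designer.config.utils.misc import (
    get_prompt_template_keywords,
    json_indent_list_of_strings,
)

# Undeclared Jinja2 variables become the template's required keywords.
keywords = get_prompt_template_keywords("Write a {{ tone }} email about {{ topic }}.")
assert keywords == {"tone", "topic"}

# Lists shorter than REPR_LIST_LENGTH_USE_JSON (4) pass through unchanged;
# longer lists are rendered as an indented JSON string for __repr__ output.
assert json_indent_list_of_strings(["a", "b"]) == ["a", "b"]
print(json_indent_list_of_strings(["a", "b", "c", "d"], indent=2))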
data_designer/config/utils/numerical_helpers.py
@@ -0,0 +1,28 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ import numbers
+ from numbers import Number
+ from typing import Any, Type
+
+ from .constants import REPORTING_PRECISION
+
+
+ def is_int(val: Any) -> bool:
+     return isinstance(val, numbers.Integral)
+
+
+ def is_float(val: Any) -> bool:
+     return isinstance(val, numbers.Real) and not isinstance(val, numbers.Integral)
+
+
+ def prepare_number_for_reporting(
+     value: Number,
+     target_type: Type[Number],
+     precision: int = REPORTING_PRECISION,
+ ) -> Number:
+     """Ensure native python types and round to `precision` decimal digits."""
+     value = target_type(value)
+     if is_float(value):
+         return round(value, precision)
+     return value
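Illustrative use of the numeric helpers above. REPORTING_PRECISION comes from .constants and its value is not shown in this diff, so precision is passed explicitly here:

# Illustrative only: numpy scalar types register with the numbers ABCs,
# so is_int/is_float classify them alongside native Python numbers.
import numpy as np

from data_designer.config.utils.numerical_helpers import (
    is_float,
    is_int,
    prepare_number_for_reporting,
)

assert is_int(np.int32(3)) and not is_float(np.int32(3))

# Casting to the target type first means the result is a plain float,
# not np.float64, before rounding for report output.
value = prepare_number_for_reporting(np.float64(2.718281828), float, precision=3)
assert value == 2.718 and type(value) is float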
data_designer/config/utils/type_helpers.py
@@ -0,0 +1,100 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from enum import Enum
+ import inspect
+ from typing import Any, Literal, Type, get_args, get_origin
+
+ from pydantic import BaseModel
+
+ from .. import sampler_params
+ from .errors import InvalidDiscriminatorFieldError, InvalidEnumValueError, InvalidTypeUnionError
+
+
+ class StrEnum(str, Enum):
+     pass
+
+
+ def create_str_enum_from_discriminated_type_union(
+     enum_name: str,
+     type_union: type,
+     discriminator_field_name: str,
+ ) -> StrEnum:
+     """Create a string enum from a type union.
+
+     The type union is assumed to be a union of configs (Pydantic models) that have a discriminator field,
+     which must be a Literal string type, e.g., Literal["expression"].
+
+     Args:
+         enum_name: Name of the StrEnum.
+         type_union: Type union of configs (Pydantic models).
+         discriminator_field_name: Name of the discriminator field.
+
+     Returns:
+         StrEnum with values being the discriminator field values of the configs in the type union.
+
+     Example:
+         DataDesignerColumnType = create_str_enum_from_discriminated_type_union(
+             enum_name="DataDesignerColumnType",
+             type_union=ColumnConfigT,
+             discriminator_field_name="column_type",
+         )
+     """
+     discriminator_field_values = []
+     for model in type_union.__args__:
+         if not issubclass(model, BaseModel):
+             raise InvalidTypeUnionError(f"🛑 {model} must be a subclass of pydantic.BaseModel.")
+         if discriminator_field_name not in model.model_fields:
+             raise InvalidDiscriminatorFieldError(f"🛑 '{discriminator_field_name}' is not a field of {model}.")
+         if get_origin(model.model_fields[discriminator_field_name].annotation) is not Literal:
+             raise InvalidDiscriminatorFieldError(f"🛑 '{discriminator_field_name}' must be a Literal type.")
+         discriminator_field_values.extend(get_args(model.model_fields[discriminator_field_name].annotation))
+     return StrEnum(enum_name, {v.replace("-", "_").upper(): v for v in set(discriminator_field_values)})
+
+
+ def get_sampler_params() -> dict[str, Type[BaseModel]]:
+     """Returns a dictionary of sampler parameter classes."""
+     params_cls_list = [
+         params_cls
+         for cls_name, params_cls in inspect.getmembers(sampler_params, inspect.isclass)
+         if cls_name.endswith("SamplerParams")
+     ]
+
+     params_cls_dict = {}
+
+     for source in sampler_params.SamplerType:
+         source_name = source.value.replace("_", "")
+         # Iterate in reverse order so the shortest match is first.
+         # This is necessary for params that start with the same name.
+         # For example, "bernoulli" and "bernoulli_mixture".
+         params_cls_dict[source.value] = [
+             params_cls
+             for params_cls in reversed(params_cls_list)
+             # Match param type string with parameter class.
+             # For example, "gaussian" -> "GaussianSamplerParams"
+             if source_name == params_cls.__name__.lower()[: len(source_name)]
+             # Take the first match.
+         ][0]
+
+     return params_cls_dict
+
+
+ def resolve_string_enum(enum_instance: Any, enum_type: Type[Enum]) -> Enum:
+     if not issubclass(enum_type, Enum):
+         raise InvalidEnumValueError(f"🛑 `enum_type` must be a subclass of Enum. You provided: {enum_type}")
+     invalid_enum_value_error = InvalidEnumValueError(
+         f"🛑 '{enum_instance}' is not a valid string enum of type {enum_type}. "
+         f"Valid options are: {[option.value for option in enum_type]}"
+     )
+     if isinstance(enum_instance, enum_type):
+         return enum_instance
+     elif isinstance(enum_instance, str):
+         try:
+             return enum_type(enum_instance)
+         except ValueError:
+             raise invalid_enum_value_error
+     else:
+         raise invalid_enum_value_error
+
+
+ SAMPLER_PARAMS = get_sampler_params()
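A self-contained sketch of the enum helpers above, using a hypothetical toy discriminated union rather than Data Designer's real column configs (which are not shown in this hunk):

# Illustrative only: ExpressionConfig/LLMTextConfig are made-up stand-ins
# for the package's discriminated config models (pydantic v2 API assumed,
# matching the model_fields usage above).
from typing import Literal, Union

from pydantic import BaseModel

from data_designer.config.utils.type_helpers import (
    create_str_enum_from_discriminated_type_union,
    resolve_string_enum,
)

class ExpressionConfig(BaseModel):
    column_type: Literal["expression"] = "expression"

class LLMTextConfig(BaseModel):
    column_type: Literal["llm-text"] = "llm-text"

ToyColumnType = create_str_enum_from_discriminated_type_union(
    enum_name="ToyColumnType",
    type_union=Union[ExpressionConfig, LLMTextConfig],
    discriminator_field_name="column_type",
)

# "llm-text" becomes the member LLM_TEXT with value "llm-text".
assert ToyColumnType.LLM_TEXT.value == "llm-text"

# resolve_string_enum accepts either an enum member or its string value.
assert resolve_string_enum("expression", ToyColumnType) is ToyColumnType.EXPRESSION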