data_designer_config-0.4.0-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
Files changed (50)
  1. data_designer/config/__init__.py +149 -0
  2. data_designer/config/_version.py +34 -0
  3. data_designer/config/analysis/__init__.py +2 -0
  4. data_designer/config/analysis/column_profilers.py +159 -0
  5. data_designer/config/analysis/column_statistics.py +421 -0
  6. data_designer/config/analysis/dataset_profiler.py +84 -0
  7. data_designer/config/analysis/utils/errors.py +10 -0
  8. data_designer/config/analysis/utils/reporting.py +192 -0
  9. data_designer/config/base.py +69 -0
  10. data_designer/config/column_configs.py +476 -0
  11. data_designer/config/column_types.py +141 -0
  12. data_designer/config/config_builder.py +595 -0
  13. data_designer/config/data_designer_config.py +40 -0
  14. data_designer/config/dataset_builders.py +13 -0
  15. data_designer/config/dataset_metadata.py +18 -0
  16. data_designer/config/default_model_settings.py +129 -0
  17. data_designer/config/errors.py +24 -0
  18. data_designer/config/interface.py +55 -0
  19. data_designer/config/models.py +486 -0
  20. data_designer/config/preview_results.py +41 -0
  21. data_designer/config/processors.py +148 -0
  22. data_designer/config/run_config.py +56 -0
  23. data_designer/config/sampler_constraints.py +52 -0
  24. data_designer/config/sampler_params.py +639 -0
  25. data_designer/config/seed.py +116 -0
  26. data_designer/config/seed_source.py +84 -0
  27. data_designer/config/seed_source_types.py +19 -0
  28. data_designer/config/testing/__init__.py +6 -0
  29. data_designer/config/testing/fixtures.py +308 -0
  30. data_designer/config/utils/code_lang.py +93 -0
  31. data_designer/config/utils/constants.py +365 -0
  32. data_designer/config/utils/errors.py +21 -0
  33. data_designer/config/utils/info.py +94 -0
  34. data_designer/config/utils/io_helpers.py +258 -0
  35. data_designer/config/utils/misc.py +78 -0
  36. data_designer/config/utils/numerical_helpers.py +30 -0
  37. data_designer/config/utils/type_helpers.py +106 -0
  38. data_designer/config/utils/visualization.py +482 -0
  39. data_designer/config/validator_params.py +94 -0
  40. data_designer/errors.py +7 -0
  41. data_designer/lazy_heavy_imports.py +56 -0
  42. data_designer/logging.py +180 -0
  43. data_designer/plugin_manager.py +78 -0
  44. data_designer/plugins/__init__.py +8 -0
  45. data_designer/plugins/errors.py +15 -0
  46. data_designer/plugins/plugin.py +141 -0
  47. data_designer/plugins/registry.py +88 -0
  48. data_designer_config-0.4.0.dist-info/METADATA +75 -0
  49. data_designer_config-0.4.0.dist-info/RECORD +50 -0
  50. data_designer_config-0.4.0.dist-info/WHEEL +4 -0
data_designer/config/utils/io_helpers.py
@@ -0,0 +1,258 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from __future__ import annotations
+
+ import json
+ import logging
+ import os
+ from datetime import date, datetime, timedelta
+ from decimal import Decimal
+ from numbers import Number
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Any
+
+ import yaml
+
+ from data_designer.config.errors import InvalidFileFormatError, InvalidFilePathError
+ from data_designer.lazy_heavy_imports import np, pd
+
+ if TYPE_CHECKING:
+     import numpy as np
+     import pandas as pd
+
+ logger = logging.getLogger(__name__)
+
+ VALID_DATASET_FILE_EXTENSIONS = {".parquet", ".csv", ".json", ".jsonl"}
+
+
+ def ensure_config_dir_exists(config_dir: Path) -> None:
+     """Create the configuration directory if it doesn't exist.
+
+     Args:
+         config_dir: Directory path to create.
+     """
+     config_dir.mkdir(parents=True, exist_ok=True)
+
+
+ def load_config_file(file_path: Path) -> dict:
+     """Load a YAML configuration file.
+
+     Args:
+         file_path: Path to the YAML file.
+
+     Returns:
+         Parsed YAML content as a dictionary.
+
+     Raises:
+         InvalidFilePathError: If the file doesn't exist.
+         InvalidFileFormatError: If the YAML is malformed.
+         InvalidConfigError: If the file is empty.
+     """
+     from data_designer.config.errors import InvalidConfigError
+
+     if not file_path.exists():
+         raise InvalidFilePathError(f"Configuration file not found: {file_path}")
+
+     try:
+         with open(file_path) as f:
+             content = yaml.safe_load(f)
+
+         if content is None:
+             raise InvalidConfigError(f"Configuration file is empty: {file_path}")
+
+         return content
+
+     except yaml.YAMLError as e:
+         raise InvalidFileFormatError(f"Invalid YAML format in {file_path}: {e}") from e
+
+
+ def save_config_file(file_path: Path, config: dict) -> None:
+     """Save a configuration to a YAML file.
+
+     Args:
+         file_path: Path where the file will be saved.
+         config: Configuration dictionary to save.
+
+     Raises:
+         IOError: If the file cannot be written.
+     """
+     # Ensure the parent directory exists.
+     file_path.parent.mkdir(parents=True, exist_ok=True)
+
+     with open(file_path, "w") as f:
+         yaml.safe_dump(
+             config,
+             f,
+             default_flow_style=False,
+             sort_keys=False,
+             indent=2,
+             allow_unicode=True,
+         )
+
+
+ def read_parquet_dataset(path: Path) -> pd.DataFrame:
+     """Read a parquet dataset from a path.
+
+     Args:
+         path: The path to the parquet dataset; can be either a file or a directory.
+
+     Returns:
+         The parquet dataset as a pandas DataFrame.
+     """
+     try:
+         return pd.read_parquet(path, dtype_backend="pyarrow")
+     except Exception as e:
+         if path.is_dir() and "Unsupported cast" in str(e):
+             logger.warning("Failed to read parquets as folder, falling back to individual files")
+             return pd.concat(
+                 [pd.read_parquet(file, dtype_backend="pyarrow") for file in sorted(path.glob("*.parquet"))],
+                 ignore_index=True,
+             )
+         else:
+             raise
+
+
+ def validate_dataset_file_path(file_path: str | Path, should_exist: bool = True) -> Path:
+     """Validate that a dataset file path has a valid extension and optionally exists.
+
+     Args:
+         file_path: The path to validate, either as a string or Path object.
+         should_exist: If True, verify that the file exists. Defaults to True.
+
+     Returns:
+         The validated path as a Path object.
+
+     Raises:
+         InvalidFilePathError: If the path is not a file.
+         InvalidFileFormatError: If the path does not have a valid extension.
+     """
+     file_path = Path(file_path)
+     if should_exist and not file_path.is_file():
+         raise InvalidFilePathError(f"🛑 Path {file_path} is not a file.")
+     if not file_path.name.lower().endswith(tuple(VALID_DATASET_FILE_EXTENSIONS)):
+         raise InvalidFileFormatError(
+             "🛑 Dataset files must be in parquet, csv, or jsonl/json (orient='records', lines=True) format."
+         )
+     return file_path
+
+
+ def validate_path_contains_files_of_type(path: str | Path, file_extension: str) -> None:
+     """Validate that a path contains files of a specific type.
+
+     Args:
+         path: The path to validate. Can contain wildcards like `*.parquet`.
+         file_extension: The extension of the files to validate (without the dot, e.g., "parquet").
+
+     Raises:
+         InvalidFilePathError: If the path does not contain files of the specified type.
+     """
+     if not any(Path(path).glob(f"*.{file_extension}")):
+         raise InvalidFilePathError(f"🛑 Path {path!r} does not contain files of type {file_extension!r}.")
+
+
+ def smart_load_dataframe(dataframe: str | Path | pd.DataFrame) -> pd.DataFrame:
+     """Load a dataframe from file if a path is given, otherwise return the dataframe.
+
+     Args:
+         dataframe: A path to a file or a pandas DataFrame object.
+
+     Returns:
+         A pandas DataFrame object.
+     """
+     if isinstance(dataframe, pd.DataFrame):
+         return dataframe
+
+     # Get the file extension, normalized without the leading dot so that URL inputs
+     # ("data.csv" -> "csv") and local paths (Path.suffix gives ".csv") compare equally.
+     if isinstance(dataframe, str) and dataframe.startswith("http"):
+         ext = dataframe.split(".")[-1].lower()
+     else:
+         dataframe = Path(dataframe)
+         ext = dataframe.suffix.lower().lstrip(".")
+         if not dataframe.exists():
+             raise FileNotFoundError(f"File not found: {dataframe}")
+
+     # Load the dataframe based on the file extension.
+     if ext == "csv":
+         return pd.read_csv(dataframe)
+     elif ext in ("json", "jsonl"):
+         return pd.read_json(dataframe, lines=True)
+     elif ext == "parquet":
+         return pd.read_parquet(dataframe)
+     else:
+         raise ValueError(f"Unsupported file format: {dataframe}")
+
+
+ def smart_load_yaml(yaml_in: str | Path | dict) -> dict:
+     """Return the yaml config as a dict given flexible input types.
+
+     Args:
+         yaml_in: The config as a dict, yaml string, or yaml file path.
+
+     Returns:
+         The config as a dict.
+     """
+     if isinstance(yaml_in, dict):
+         yaml_out = yaml_in
+     elif isinstance(yaml_in, Path) or (isinstance(yaml_in, str) and os.path.isfile(yaml_in)):
+         with open(yaml_in) as file:
+             yaml_out = yaml.safe_load(file)
+     elif isinstance(yaml_in, str):
+         if yaml_in.endswith((".yaml", ".yml")) and not os.path.isfile(yaml_in):
+             raise FileNotFoundError(f"File not found: {yaml_in}")
+         else:
+             yaml_out = yaml.safe_load(yaml_in)
+     else:
+         raise ValueError(
+             f"'{yaml_in}' is an invalid yaml config format. Valid options are: dict, yaml string, or yaml file path."
+         )
+
+     if not isinstance(yaml_out, dict):
+         raise ValueError(f"Loaded yaml must be a dict. Got {yaml_out}, which is of type {type(yaml_out)}.")
+
+     return yaml_out
+
+
+ def serialize_data(data: dict | list | str | Number, **kwargs) -> str:
+     """Serialize data to a string, converting non-JSON-serializable objects along the way."""
+     if isinstance(data, (dict, list)):
+         return json.dumps(data, ensure_ascii=False, default=_convert_to_serializable, **kwargs)
+     elif isinstance(data, str):
+         return data
+     elif isinstance(data, Number):
+         return str(data)
+     else:
+         raise ValueError(f"Invalid data type: {type(data)}")
+
+
+ def _convert_to_serializable(obj: Any) -> Any:
+     """Convert non-JSON-serializable objects to JSON-serializable Python-native types.
+
+     Raises:
+         TypeError: If the object type is not supported for serialization.
+     """
+     if isinstance(obj, (set, list)):
+         return list(obj)
+     if isinstance(obj, (pd.Series, np.ndarray)):
+         return obj.tolist()
+
+     if pd.isna(obj):
+         return None
+
+     if isinstance(obj, (datetime, date, pd.Timestamp)):
+         return obj.isoformat()
+     if isinstance(obj, timedelta):
+         return obj.total_seconds()
+     if isinstance(obj, (np.datetime64, np.timedelta64)):
+         return str(obj)
+
+     if isinstance(obj, Decimal):
+         return float(obj)
+     if isinstance(obj, (np.integer, np.floating, np.bool_)):
+         return obj.item()
+
+     if isinstance(obj, bytes):
+         return obj.decode("utf-8", errors="replace")
+
+     # Unsupported type
+     raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
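A minimal usage sketch of the I/O helpers above, assuming local files config.yaml and data.parquet exist (hypothetical names, for illustration only):

    from pathlib import Path

    from data_designer.config.utils.io_helpers import (
        load_config_file,
        serialize_data,
        smart_load_dataframe,
    )

    config = load_config_file(Path("config.yaml"))  # parsed dict; raises InvalidConfigError if the file is empty
    df = smart_load_dataframe("data.parquet")       # accepts a path, a URL, or an existing DataFrame
    print(serialize_data({"num_rows": len(df)}))    # JSON string; numpy/pandas values coerced via _convert_to_serializable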
data_designer/config/utils/misc.py
@@ -0,0 +1,78 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from __future__ import annotations
+
+ import json
+ from contextlib import contextmanager
+
+ from jinja2 import TemplateSyntaxError, meta
+ from jinja2.sandbox import ImmutableSandboxedEnvironment
+
+ from data_designer.config.utils.errors import UserJinjaTemplateSyntaxError
+
+ REPR_LIST_LENGTH_USE_JSON = 4
+
+
+ def kebab_to_snake(s: str) -> str:
+     return s.replace("-", "_")
+
+
+ @contextmanager
+ def template_error_handler():
+     try:
+         yield
+     except TemplateSyntaxError as exception:
+         exception_string = (
+             f"Encountered a syntax error in the provided Jinja2 template:\n{str(exception)}\n"
+             "For more information on writing Jinja2 templates, "
+             "refer to https://jinja.palletsprojects.com/en/stable/templates"
+         )
+         raise UserJinjaTemplateSyntaxError(exception_string) from exception
+
+
+ def assert_valid_jinja2_template(template: str) -> None:
+     """Raises an error if the template cannot be parsed."""
+     with template_error_handler():
+         meta.find_undeclared_variables(ImmutableSandboxedEnvironment().parse(template))
+
+
+ def can_run_data_designer_locally() -> bool:
+     """Returns True if Data Designer can be run locally, False otherwise."""
+     try:
+         from ... import engine  # noqa: F401, TID252
+     except ImportError:
+         return False
+     return True
+
+
+ def extract_keywords_from_jinja2_template(template: str) -> set[str]:
+     """Extract all keywords from a valid Jinja2 template."""
+     with template_error_handler():
+         ast = ImmutableSandboxedEnvironment().parse(template)
+         keywords = set(meta.find_undeclared_variables(ast))
+
+     return keywords
+
+
+ def json_indent_list_of_strings(column_names: list[str], *, indent: int | str | None = None) -> list[str] | str | None:
+     """Convert a list of column names to a JSON string if the list is long.
+
+     This function helps keep Data Designer's __repr__ output clean and readable.
+
+     Args:
+         column_names: List of column names.
+         indent: Indentation for the JSON string.
+
+     Returns:
+         A list of column names, a JSON string if the list is long, or None if the list is empty.
+     """
+     return (
+         None
+         if len(column_names) == 0
+         else (
+             column_names if len(column_names) < REPR_LIST_LENGTH_USE_JSON else json.dumps(column_names, indent=indent)
+         )
+     )
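A short sketch of the Jinja2 helpers above, with a made-up template for illustration:

    from data_designer.config.utils.misc import (
        assert_valid_jinja2_template,
        extract_keywords_from_jinja2_template,
    )

    template = "Write a {{ tone }} review of {{ product_name }}."
    assert_valid_jinja2_template(template)  # no-op on valid syntax; raises UserJinjaTemplateSyntaxError otherwise
    print(extract_keywords_from_jinja2_template(template))  # {'tone', 'product_name'}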
data_designer/config/utils/numerical_helpers.py
@@ -0,0 +1,30 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from __future__ import annotations
+
+ import numbers
+ from numbers import Number
+ from typing import Any
+
+ from data_designer.config.utils.constants import REPORTING_PRECISION
+
+
+ def is_int(val: Any) -> bool:
+     return isinstance(val, numbers.Integral)
+
+
+ def is_float(val: Any) -> bool:
+     return isinstance(val, numbers.Real) and not isinstance(val, numbers.Integral)
+
+
+ def prepare_number_for_reporting(
+     value: Number,
+     target_type: type[Number],
+     precision: int = REPORTING_PRECISION,
+ ) -> Number:
+     """Ensure native Python types and round to `precision` decimal digits."""
+     value = target_type(value)
+     if is_float(value):
+         return round(value, precision)
+     return value
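For orientation, how prepare_number_for_reporting behaves on arbitrary example values (REPORTING_PRECISION comes from constants.py):

    from data_designer.config.utils.numerical_helpers import prepare_number_for_reporting

    prepare_number_for_reporting(3.14159265, float, precision=4)  # 3.1416
    prepare_number_for_reporting(7.0, int)                        # 7 (integral values pass through unrounded)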
data_designer/config/utils/type_helpers.py
@@ -0,0 +1,106 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from __future__ import annotations
+
+ import inspect
+ from enum import Enum
+ from typing import Any, Literal, get_args, get_origin
+
+ from pydantic import BaseModel
+
+ from data_designer.config import sampler_params
+ from data_designer.config.utils.errors import (
+     InvalidDiscriminatorFieldError,
+     InvalidEnumValueError,
+     InvalidTypeUnionError,
+ )
+
+
+ class StrEnum(str, Enum):
+     pass
+
+
+ def create_str_enum_from_discriminated_type_union(
+     enum_name: str,
+     type_union: type,
+     discriminator_field_name: str,
+ ) -> StrEnum:
+     """Create a string enum from a type union.
+
+     The type union is assumed to be a union of configs (Pydantic models) that have a discriminator field,
+     which must be a Literal string type - e.g., Literal["expression"].
+
+     Args:
+         enum_name: Name of the StrEnum.
+         type_union: Type union of configs (Pydantic models).
+         discriminator_field_name: Name of the discriminator field.
+
+     Returns:
+         StrEnum with values being the discriminator field values of the configs in the type union.
+
+     Example:
+         DataDesignerColumnType = create_str_enum_from_discriminated_type_union(
+             enum_name="DataDesignerColumnType",
+             type_union=ColumnConfigT,
+             discriminator_field_name="column_type",
+         )
+     """
+     discriminator_field_values = []
+     for model in type_union.__args__:
+         if not issubclass(model, BaseModel):
+             raise InvalidTypeUnionError(f"🛑 {model} must be a subclass of pydantic.BaseModel.")
+         if discriminator_field_name not in model.model_fields:
+             raise InvalidDiscriminatorFieldError(f"🛑 '{discriminator_field_name}' is not a field of {model}.")
+         if get_origin(model.model_fields[discriminator_field_name].annotation) is not Literal:
+             raise InvalidDiscriminatorFieldError(f"🛑 '{discriminator_field_name}' must be a Literal type.")
+         discriminator_field_values.extend(get_args(model.model_fields[discriminator_field_name].annotation))
+     return StrEnum(enum_name, {v.replace("-", "_").upper(): v for v in set(discriminator_field_values)})
+
+
+ def get_sampler_params() -> dict[str, type[BaseModel]]:
+     """Returns a dictionary of sampler parameter classes."""
+     params_cls_list = [
+         params_cls
+         for cls_name, params_cls in inspect.getmembers(sampler_params, inspect.isclass)
+         if cls_name.endswith("SamplerParams")
+     ]
+
+     params_cls_dict = {}
+
+     for source in sampler_params.SamplerType:
+         source_name = source.value.replace("_", "")
+         # Iterate in reverse order so the shortest match is first.
+         # This is necessary for params that start with the same name.
+         # For example, "bernoulli" and "bernoulli_mixture".
+         params_cls_dict[source.value] = [
+             params_cls
+             for params_cls in reversed(params_cls_list)
+             # Match param type string with parameter class.
+             # For example, "gaussian" -> "GaussianSamplerParams"
+             if source_name == params_cls.__name__.lower()[: len(source_name)]
+             # Take the first match.
+         ][0]
+
+     return params_cls_dict
+
+
88
+ def resolve_string_enum(enum_instance: Any, enum_type: type[Enum]) -> Enum:
89
+ if not issubclass(enum_type, Enum):
90
+ raise InvalidEnumValueError(f"🛑 `enum_type` must be a subclass of Enum. You provided: {enum_type}")
91
+ invalid_enum_value_error = InvalidEnumValueError(
92
+ f"🛑 '{enum_instance}' is not a valid string enum of type {type(enum_type)}. "
93
+ f"Valid options are: {[option.value for option in enum_type]}"
94
+ )
95
+ if isinstance(enum_instance, enum_type):
96
+ return enum_instance
97
+ elif isinstance(enum_instance, str):
98
+ try:
99
+ return enum_type(enum_instance)
100
+ except ValueError:
101
+ raise invalid_enum_value_error
102
+ else:
103
+ raise invalid_enum_value_error
104
+
105
+
106
+ SAMPLER_PARAMS = get_sampler_params()
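A sketch of create_str_enum_from_discriminated_type_union in action; the two column models below are hypothetical stand-ins for the package's real column configs:

    from typing import Literal, Union

    from pydantic import BaseModel

    from data_designer.config.utils.type_helpers import create_str_enum_from_discriminated_type_union

    class ExpressionColumn(BaseModel):
        column_type: Literal["expression"] = "expression"

    class SamplerColumn(BaseModel):
        column_type: Literal["sampler"] = "sampler"

    ColumnType = create_str_enum_from_discriminated_type_union(
        enum_name="ColumnType",
        type_union=Union[ExpressionColumn, SamplerColumn],
        discriminator_field_name="column_type",
    )
    print(sorted(m.value for m in ColumnType))  # ['expression', 'sampler']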