data-designer 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177) hide show
  1. data_designer/__init__.py +15 -0
  2. data_designer/_version.py +34 -0
  3. data_designer/cli/README.md +236 -0
  4. data_designer/cli/__init__.py +6 -0
  5. data_designer/cli/commands/__init__.py +2 -0
  6. data_designer/cli/commands/list.py +130 -0
  7. data_designer/cli/commands/models.py +10 -0
  8. data_designer/cli/commands/providers.py +11 -0
  9. data_designer/cli/commands/reset.py +100 -0
  10. data_designer/cli/controllers/__init__.py +7 -0
  11. data_designer/cli/controllers/model_controller.py +246 -0
  12. data_designer/cli/controllers/provider_controller.py +317 -0
  13. data_designer/cli/forms/__init__.py +20 -0
  14. data_designer/cli/forms/builder.py +51 -0
  15. data_designer/cli/forms/field.py +180 -0
  16. data_designer/cli/forms/form.py +59 -0
  17. data_designer/cli/forms/model_builder.py +125 -0
  18. data_designer/cli/forms/provider_builder.py +76 -0
  19. data_designer/cli/main.py +44 -0
  20. data_designer/cli/repositories/__init__.py +8 -0
  21. data_designer/cli/repositories/base.py +39 -0
  22. data_designer/cli/repositories/model_repository.py +42 -0
  23. data_designer/cli/repositories/provider_repository.py +43 -0
  24. data_designer/cli/services/__init__.py +7 -0
  25. data_designer/cli/services/model_service.py +116 -0
  26. data_designer/cli/services/provider_service.py +111 -0
  27. data_designer/cli/ui.py +448 -0
  28. data_designer/cli/utils.py +47 -0
  29. data_designer/config/__init__.py +2 -0
  30. data_designer/config/analysis/column_profilers.py +89 -0
  31. data_designer/config/analysis/column_statistics.py +274 -0
  32. data_designer/config/analysis/dataset_profiler.py +60 -0
  33. data_designer/config/analysis/utils/errors.py +8 -0
  34. data_designer/config/analysis/utils/reporting.py +188 -0
  35. data_designer/config/base.py +68 -0
  36. data_designer/config/column_configs.py +354 -0
  37. data_designer/config/column_types.py +168 -0
  38. data_designer/config/config_builder.py +660 -0
  39. data_designer/config/data_designer_config.py +40 -0
  40. data_designer/config/dataset_builders.py +11 -0
  41. data_designer/config/datastore.py +151 -0
  42. data_designer/config/default_model_settings.py +123 -0
  43. data_designer/config/errors.py +19 -0
  44. data_designer/config/interface.py +54 -0
  45. data_designer/config/models.py +231 -0
  46. data_designer/config/preview_results.py +32 -0
  47. data_designer/config/processors.py +41 -0
  48. data_designer/config/sampler_constraints.py +51 -0
  49. data_designer/config/sampler_params.py +604 -0
  50. data_designer/config/seed.py +145 -0
  51. data_designer/config/utils/code_lang.py +83 -0
  52. data_designer/config/utils/constants.py +313 -0
  53. data_designer/config/utils/errors.py +19 -0
  54. data_designer/config/utils/info.py +88 -0
  55. data_designer/config/utils/io_helpers.py +273 -0
  56. data_designer/config/utils/misc.py +81 -0
  57. data_designer/config/utils/numerical_helpers.py +28 -0
  58. data_designer/config/utils/type_helpers.py +100 -0
  59. data_designer/config/utils/validation.py +336 -0
  60. data_designer/config/utils/visualization.py +427 -0
  61. data_designer/config/validator_params.py +96 -0
  62. data_designer/engine/__init__.py +2 -0
  63. data_designer/engine/analysis/column_profilers/base.py +55 -0
  64. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +160 -0
  65. data_designer/engine/analysis/column_profilers/registry.py +20 -0
  66. data_designer/engine/analysis/column_statistics.py +142 -0
  67. data_designer/engine/analysis/dataset_profiler.py +125 -0
  68. data_designer/engine/analysis/errors.py +7 -0
  69. data_designer/engine/analysis/utils/column_statistics_calculations.py +209 -0
  70. data_designer/engine/analysis/utils/judge_score_processing.py +128 -0
  71. data_designer/engine/column_generators/__init__.py +2 -0
  72. data_designer/engine/column_generators/generators/__init__.py +2 -0
  73. data_designer/engine/column_generators/generators/base.py +61 -0
  74. data_designer/engine/column_generators/generators/expression.py +63 -0
  75. data_designer/engine/column_generators/generators/llm_generators.py +172 -0
  76. data_designer/engine/column_generators/generators/samplers.py +75 -0
  77. data_designer/engine/column_generators/generators/seed_dataset.py +149 -0
  78. data_designer/engine/column_generators/generators/validation.py +147 -0
  79. data_designer/engine/column_generators/registry.py +56 -0
  80. data_designer/engine/column_generators/utils/errors.py +13 -0
  81. data_designer/engine/column_generators/utils/judge_score_factory.py +57 -0
  82. data_designer/engine/column_generators/utils/prompt_renderer.py +98 -0
  83. data_designer/engine/configurable_task.py +82 -0
  84. data_designer/engine/dataset_builders/artifact_storage.py +181 -0
  85. data_designer/engine/dataset_builders/column_wise_builder.py +287 -0
  86. data_designer/engine/dataset_builders/errors.py +13 -0
  87. data_designer/engine/dataset_builders/multi_column_configs.py +44 -0
  88. data_designer/engine/dataset_builders/utils/__init__.py +2 -0
  89. data_designer/engine/dataset_builders/utils/concurrency.py +184 -0
  90. data_designer/engine/dataset_builders/utils/config_compiler.py +60 -0
  91. data_designer/engine/dataset_builders/utils/dag.py +56 -0
  92. data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +190 -0
  93. data_designer/engine/dataset_builders/utils/errors.py +13 -0
  94. data_designer/engine/errors.py +49 -0
  95. data_designer/engine/model_provider.py +75 -0
  96. data_designer/engine/models/__init__.py +2 -0
  97. data_designer/engine/models/errors.py +308 -0
  98. data_designer/engine/models/facade.py +225 -0
  99. data_designer/engine/models/litellm_overrides.py +162 -0
  100. data_designer/engine/models/parsers/__init__.py +2 -0
  101. data_designer/engine/models/parsers/errors.py +34 -0
  102. data_designer/engine/models/parsers/parser.py +236 -0
  103. data_designer/engine/models/parsers/postprocessors.py +93 -0
  104. data_designer/engine/models/parsers/tag_parsers.py +60 -0
  105. data_designer/engine/models/parsers/types.py +82 -0
  106. data_designer/engine/models/recipes/base.py +79 -0
  107. data_designer/engine/models/recipes/response_recipes.py +291 -0
  108. data_designer/engine/models/registry.py +118 -0
  109. data_designer/engine/models/usage.py +75 -0
  110. data_designer/engine/models/utils.py +38 -0
  111. data_designer/engine/processing/ginja/__init__.py +2 -0
  112. data_designer/engine/processing/ginja/ast.py +64 -0
  113. data_designer/engine/processing/ginja/environment.py +461 -0
  114. data_designer/engine/processing/ginja/exceptions.py +54 -0
  115. data_designer/engine/processing/ginja/record.py +30 -0
  116. data_designer/engine/processing/gsonschema/__init__.py +2 -0
  117. data_designer/engine/processing/gsonschema/exceptions.py +8 -0
  118. data_designer/engine/processing/gsonschema/schema_transformers.py +81 -0
  119. data_designer/engine/processing/gsonschema/types.py +8 -0
  120. data_designer/engine/processing/gsonschema/validators.py +143 -0
  121. data_designer/engine/processing/processors/base.py +15 -0
  122. data_designer/engine/processing/processors/drop_columns.py +46 -0
  123. data_designer/engine/processing/processors/registry.py +20 -0
  124. data_designer/engine/processing/utils.py +120 -0
  125. data_designer/engine/registry/base.py +97 -0
  126. data_designer/engine/registry/data_designer_registry.py +37 -0
  127. data_designer/engine/registry/errors.py +10 -0
  128. data_designer/engine/resources/managed_dataset_generator.py +35 -0
  129. data_designer/engine/resources/managed_dataset_repository.py +194 -0
  130. data_designer/engine/resources/managed_storage.py +63 -0
  131. data_designer/engine/resources/resource_provider.py +46 -0
  132. data_designer/engine/resources/seed_dataset_data_store.py +66 -0
  133. data_designer/engine/sampling_gen/column.py +89 -0
  134. data_designer/engine/sampling_gen/constraints.py +95 -0
  135. data_designer/engine/sampling_gen/data_sources/base.py +214 -0
  136. data_designer/engine/sampling_gen/data_sources/errors.py +10 -0
  137. data_designer/engine/sampling_gen/data_sources/sources.py +342 -0
  138. data_designer/engine/sampling_gen/entities/__init__.py +2 -0
  139. data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
  140. data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +64 -0
  141. data_designer/engine/sampling_gen/entities/email_address_utils.py +169 -0
  142. data_designer/engine/sampling_gen/entities/errors.py +8 -0
  143. data_designer/engine/sampling_gen/entities/national_id_utils.py +100 -0
  144. data_designer/engine/sampling_gen/entities/person.py +142 -0
  145. data_designer/engine/sampling_gen/entities/phone_number.py +122 -0
  146. data_designer/engine/sampling_gen/errors.py +24 -0
  147. data_designer/engine/sampling_gen/generator.py +121 -0
  148. data_designer/engine/sampling_gen/jinja_utils.py +60 -0
  149. data_designer/engine/sampling_gen/people_gen.py +203 -0
  150. data_designer/engine/sampling_gen/person_constants.py +54 -0
  151. data_designer/engine/sampling_gen/schema.py +143 -0
  152. data_designer/engine/sampling_gen/schema_builder.py +59 -0
  153. data_designer/engine/sampling_gen/utils.py +40 -0
  154. data_designer/engine/secret_resolver.py +80 -0
  155. data_designer/engine/validators/__init__.py +17 -0
  156. data_designer/engine/validators/base.py +36 -0
  157. data_designer/engine/validators/local_callable.py +34 -0
  158. data_designer/engine/validators/python.py +245 -0
  159. data_designer/engine/validators/remote.py +83 -0
  160. data_designer/engine/validators/sql.py +60 -0
  161. data_designer/errors.py +5 -0
  162. data_designer/essentials/__init__.py +137 -0
  163. data_designer/interface/__init__.py +2 -0
  164. data_designer/interface/data_designer.py +351 -0
  165. data_designer/interface/errors.py +16 -0
  166. data_designer/interface/results.py +55 -0
  167. data_designer/logging.py +161 -0
  168. data_designer/plugin_manager.py +83 -0
  169. data_designer/plugins/__init__.py +6 -0
  170. data_designer/plugins/errors.py +10 -0
  171. data_designer/plugins/plugin.py +69 -0
  172. data_designer/plugins/registry.py +86 -0
  173. data_designer-0.1.0.dist-info/METADATA +173 -0
  174. data_designer-0.1.0.dist-info/RECORD +177 -0
  175. data_designer-0.1.0.dist-info/WHEEL +4 -0
  176. data_designer-0.1.0.dist-info/entry_points.txt +2 -0
  177. data_designer-0.1.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,143 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from copy import deepcopy
5
+ import logging
6
+ from typing import Any, overload
7
+
8
+ from jsonschema import Draft202012Validator, ValidationError, validators
9
+
10
+ from data_designer.engine.processing.gsonschema.exceptions import JSONSchemaValidationError
11
+ from data_designer.engine.processing.gsonschema.schema_transformers import forbid_additional_properties
12
+ from data_designer.engine.processing.gsonschema.types import DataObjectT, JSONSchemaT, T_primitive
13
+
14
+ DEFAULT_JSONSCHEMA_VALIDATOR = Draft202012Validator
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
def prune_additional_properties(
    _, allow_additional_properties: bool, instance: DataObjectT, schema: JSONSchemaT
) -> None:
    """A jsonschema validator extension function that prunes additional properties.

    Operates on an individual schema in-place: any key of `instance` that is
    not declared under the schema's "properties" is removed.

    Args:
        allow_additional_properties (bool): The value of the
            `additionalProperties` field for this schema.
        instance (DataObjectT): The data object being validated.
        schema (JSONSchemaT): The schema for this object.

    Returns:
        Nothing (mutates `instance` in place).
    """
    # Pruning only applies to mapping instances when extras are disallowed.
    if allow_additional_properties or not isinstance(instance, dict):
        return

    allowed_keys = schema.get("properties", {}).keys()

    # Snapshot the keys first so the dict can be mutated safely.
    removed = [key for key in list(instance.keys()) if key not in allowed_keys]
    for key in removed:
        instance.pop(key)
        logger.info(f"Unspecified property removed from data object: {key}.")

    if len(removed) > 0:
        logger.info(f"{len(removed)} unspecified properties removed from data object.")
52
+
53
+
54
def extend_jsonschema_validator_with_pruning(validator):
    """Modify behavior of a jsonschema.Validator to use pruning.

    Validators extended using this function will prune extra
    fields, rather than raising a ValidationError, when encountering
    extra, unspecified fields when `additionalProperties: False` is
    set in the validating schema.

    Args:
        validator (Type[jsonschema.Validator]): A validator class
            to extend with pruning behavior.

    Returns:
        Type[jsonschema.Validator]: A validator class that will
            prune extra fields.
    """
    return validators.extend(validator, {"additionalProperties": prune_additional_properties})
71
+
72
+
73
## We don't expect the outer data type (e.g. dict, list, or const) to be
## modified by the pruning action.
@overload
def validate(
    obj: dict[str, Any],
    schema: JSONSchemaT,
    pruning: bool = False,
    no_extra_properties: bool = False,
) -> dict[str, Any]: ...


@overload
def validate(
    obj: list[Any],
    schema: JSONSchemaT,
    pruning: bool = False,
    no_extra_properties: bool = False,
) -> list[Any]: ...


@overload
def validate(
    obj: T_primitive,
    schema: JSONSchemaT,
    pruning: bool = False,
    no_extra_properties: bool = False,
) -> T_primitive: ...


def validate(
    obj: DataObjectT,
    schema: JSONSchemaT,
    pruning: bool = False,
    no_extra_properties: bool = False,
) -> DataObjectT:
    """Validate a data object against a JSONSchema.

    The input object is deep-copied before validation, so pruning never
    mutates the caller's object.

    Args:
        obj (DataObjectT): A data structure to validate against the
            schema.
        schema (JSONSchemaT): A valid JSONSchema to use to validate
            the provided data object.
        pruning (bool): If set to `True`, then the default behavior for
            `additionalProperties: False` is set to prune non-specified
            properties instead of raising a ValidationError.
            Default: `False`.
        no_extra_properties (bool): If set to `True`, then
            `additionalProperties: False` is set on the schema and
            all of its sub-schemas. This operation overrides any
            existing settings of `additionalProperties` within the
            schema. Default: `False`.

    Returns:
        DataObjectT: A (possibly pruned) deep copy of `obj` that
            conforms to the schema.

    Raises:
        JSONSchemaValidationError: Raised when the data object does
            not conform to the provided schema.
    """
    final_object = deepcopy(obj)
    validator = DEFAULT_JSONSCHEMA_VALIDATOR
    if pruning:
        validator = extend_jsonschema_validator_with_pruning(validator)

    if no_extra_properties:
        schema = forbid_additional_properties(schema)

    try:
        validator(schema).validate(final_object)
    except ValidationError as exc:
        raise JSONSchemaValidationError(str(exc)) from exc

    return final_object
@@ -0,0 +1,15 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from abc import ABC, abstractmethod
5
+
6
+ from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata, DataT, TaskConfigT
7
+
8
+
9
class Processor(ConfigurableTask[TaskConfigT], ABC):
    """Abstract base class for dataset post-processing tasks.

    Concrete processors declare their task metadata and implement `process`,
    which transforms a batch of data and returns the result.
    """

    @staticmethod
    @abstractmethod
    def metadata() -> ConfigurableTaskMetadata: ...

    @abstractmethod
    def process(self, data: DataT, *, current_batch_number: int | None = None) -> DataT: ...
@@ -0,0 +1,46 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import logging
5
+
6
+ import pandas as pd
7
+
8
+ from data_designer.config.processors import DropColumnsProcessorConfig
9
+ from data_designer.engine.configurable_task import ConfigurableTaskMetadata
10
+ from data_designer.engine.dataset_builders.artifact_storage import BatchStage
11
+ from data_designer.engine.processing.processors.base import Processor
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
class DropColumnsProcessor(Processor[DropColumnsProcessorConfig]):
    """Processor that removes the configured columns from a dataset.

    When running in batch mode (a batch number is provided), the dropped
    columns are first written to the dropped-columns artifact directory so
    they remain recoverable.
    """

    @staticmethod
    def metadata() -> ConfigurableTaskMetadata:
        return ConfigurableTaskMetadata(
            name="drop_columns",
            description="Drop columns from the input dataset.",
            required_resources=None,
        )

    def process(self, data: pd.DataFrame, *, current_batch_number: int | None = None) -> pd.DataFrame:
        """Drop the configured columns from `data` in place and return it.

        Columns not present in the dataset are skipped with a warning.

        Args:
            data: The dataset to modify.
            current_batch_number: Batch index when running in batch mode;
                `None` means preview mode (dropped columns are not persisted).

        Returns:
            The same dataframe with the configured columns removed.
        """
        logger.info(f"🙈 Dropping columns: {self.config.column_names}")
        if current_batch_number is not None:  # not in preview mode
            self._save_dropped_columns_if_needed(data, current_batch_number)
        for column in self.config.column_names:
            if column in data.columns:
                data.drop(columns=[column], inplace=True)
            else:
                logger.warning(f"⚠️ Cannot drop column: `{column}` not found in the dataset.")
        return data

    def _save_dropped_columns_if_needed(self, data: pd.DataFrame, current_batch_number: int) -> None:
        """Persist the to-be-dropped columns as a parquet artifact for this batch.

        Only columns actually present in `data` are saved. Previously a
        configured-but-missing column raised a KeyError here (via
        `data[self.config.column_names]`) even though `process` tolerates
        missing columns with a warning.
        """
        columns_to_save = [col for col in self.config.column_names if col in data.columns]
        if not columns_to_save:
            # Nothing to persist; `process` will emit warnings for the missing columns.
            return
        logger.debug("📦 Saving dropped columns to dropped-columns directory")
        dropped_column_parquet_file_name = self.artifact_storage.create_batch_file_path(
            batch_number=current_batch_number,
            batch_stage=BatchStage.DROPPED_COLUMNS,
        ).name
        self.artifact_storage.write_parquet_file(
            parquet_file_name=dropped_column_parquet_file_name,
            dataframe=data[columns_to_save],
            batch_stage=BatchStage.DROPPED_COLUMNS,
        )
@@ -0,0 +1,20 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from data_designer.config.base import ConfigBase
5
+ from data_designer.config.processors import (
6
+ DropColumnsProcessorConfig,
7
+ ProcessorType,
8
+ )
9
+ from data_designer.engine.processing.processors.base import Processor
10
+ from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor
11
+ from data_designer.engine.registry.base import TaskRegistry
12
+
13
+
14
class ProcessorRegistry(TaskRegistry[str, Processor, ConfigBase]): ...


def create_default_processor_registry() -> ProcessorRegistry:
    """Build a ProcessorRegistry pre-populated with the built-in processors."""
    default_registry = ProcessorRegistry()
    default_registry.register(
        ProcessorType.DROP_COLUMNS,
        DropColumnsProcessor,
        DropColumnsProcessorConfig,
        raise_on_collision=False,
    )
    return default_registry
@@ -0,0 +1,120 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import json
5
+ import logging
6
+ from typing import Any, TypeVar, Union, overload
7
+
8
+ import pandas as pd
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ T = TypeVar("T")
13
+ K = TypeVar("K")
14
+ V = TypeVar("V")
15
+
16
+
17
def concat_datasets(datasets: list[pd.DataFrame]) -> pd.DataFrame:
    """Horizontally concatenate datasets with unique columns and equal lengths.

    Args:
        datasets: Dataframes with mutually-unique column names and the same
            number of rows.

    Returns:
        pd.DataFrame: A single dataframe containing the columns of every
            input, in input order.

    Raises:
        ValueError: If column names overlap across datasets or the datasets
            have different row counts.
    """
    _verify_columns_are_unique(datasets)
    _verify_dataset_lengths_are_equal(datasets)
    emoji = " + ".join(["💾"] * len(datasets))
    logger.info(f"({emoji}) Concatenating {len(datasets)} datasets")
    # pd.concat accepts the list directly; copying it first ([df for df in ...])
    # was a needless O(n) allocation (ruff PERF402).
    return pd.concat(datasets, axis=1)
23
+
24
+
25
# Overloads to help static type checker better understand
# the input/output types of the deserialize_json_values function.
# A JSON string may decode to any value; containers keep their container
# type (with deserialized elements); anything else passes through unchanged.
@overload
def deserialize_json_values(data: str) -> Union[dict[str, Any], list[Any], Any]: ...


@overload
def deserialize_json_values(data: list[T]) -> list[Any]: ...


@overload
def deserialize_json_values(data: dict[K, V]) -> dict[K, Any]: ...


@overload
def deserialize_json_values(data: T) -> T: ...
41
+
42
+
43
def deserialize_json_values(data):
    """Recursively de-serialize JSON strings found in the input.

    Args:
        data: One of:
            - A string (parsed as JSON if possible, otherwise returned as-is)
            - A list (each element is deserialized recursively)
            - A dictionary (each value is deserialized recursively)
            - Any other object (returned unchanged)

    Returns:
        The input with every JSON string replaced by its parsed value:
        a parsed value for a string input, a list for a list input, a
        dictionary for a dictionary input, or the original object when
        there is nothing to deserialize.
    """
    # Strings: attempt a JSON parse; fall back to the raw string.
    if isinstance(data, str):
        try:
            return json.loads(data)
        except json.JSONDecodeError:
            return data

    # Containers: recurse into every element/value. Recursing on a string
    # performs exactly the parse-or-keep logic above, and recursing on any
    # other scalar returns it unchanged, so a single recursive call per
    # element covers all cases.
    if isinstance(data, list):
        return [deserialize_json_values(element) for element in data]

    if isinstance(data, dict):
        return {key: deserialize_json_values(value) for key, value in data.items()}

    # Anything else has no JSON payload to decode.
    return data
101
+
102
+
103
def _verify_columns_are_unique(datasets: list[pd.DataFrame]) -> None:
    """Raise ValueError if any column name appears in more than one dataset."""
    seen_columns: set = set()
    for frame in datasets:
        frame_columns = set(frame.columns)
        overlapping_columns = seen_columns & frame_columns
        if overlapping_columns:
            raise ValueError(
                f"🛑 Input datasets have overlapping columns: {overlapping_columns} "
                "Please ensure that the column names are unique."
            )
        seen_columns |= frame_columns
114
+
115
+
116
def _verify_dataset_lengths_are_equal(datasets: list[pd.DataFrame]) -> None:
    """Raise ValueError unless every dataset has the same number of rows."""
    distinct_lengths = {len(frame) for frame in datasets}
    if len(distinct_lengths) > 1:
        raise ValueError(
            "🛑 Input datasets have different lengths. Please ensure that the datasets have the same number of rows."
        )
@@ -0,0 +1,97 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import threading
5
+ from typing import Any, Generic, Type, TypeVar
6
+
7
+ from data_designer.config.base import ConfigBase
8
+ from data_designer.config.utils.type_helpers import StrEnum
9
+ from data_designer.engine.configurable_task import ConfigurableTask
10
+ from data_designer.engine.registry.errors import NotFoundInRegistryError, RegistryItemNotTypeError
11
+
12
+ EnumNameT = TypeVar("EnumNameT", bound=StrEnum)
13
+ TaskT = TypeVar("TaskT", bound=ConfigurableTask)
14
+ TaskConfigT = TypeVar("TaskConfigT", bound=ConfigBase)
15
+
16
+
17
class TaskRegistry(Generic[EnumNameT, TaskT, TaskConfigT]):
    """Thread-safe registry mapping registered names <-> task types <-> config types.

    Each registry class is a singleton (see `__new__`). State lives in
    class-level dicts; `__init_subclass__` gives every subclass its OWN
    copies so that distinct registries (e.g. processors vs. column
    generators) do not share — and cannot collide in — a single set of
    mappings. Previously the dicts were inherited from this base class and
    therefore shared by all subclasses.
    """

    # registered type name -> type
    _registry: dict[EnumNameT, Type[TaskT]] = {}
    # type -> registered type name
    _reverse_registry: dict[Type[TaskT], EnumNameT] = {}

    # registered type name -> config type
    _config_registry: dict[EnumNameT, Type[TaskConfigT]] = {}
    # config type -> registered type name
    _reverse_config_registry: dict[Type[TaskConfigT], EnumNameT] = {}

    # all registries are singletons (one instance per registry class)
    _instance = None
    _lock = threading.Lock()

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Give each subclass independent registry state and its own singleton."""
        super().__init_subclass__(**kwargs)
        cls._registry = {}
        cls._reverse_registry = {}
        cls._config_registry = {}
        cls._reverse_config_registry = {}
        cls._instance = None
        cls._lock = threading.Lock()

    @classmethod
    def register(
        cls,
        name: EnumNameT,
        task: Type[TaskT],
        config: Type[TaskConfigT],
        raise_on_collision: bool = False,
    ) -> None:
        """Register `task` and `config` under `name`.

        Args:
            name: Registered name for the task type.
            task: The task class to register.
            config: The task's config class.
            raise_on_collision: If True, raise ValueError when `name` is
                already registered; otherwise keep the existing entry silently.

        Raises:
            ValueError: If `name` is already registered and `raise_on_collision` is set.
            RegistryItemNotTypeError: If `task` or `config` is not a class.
        """
        # Check-and-insert under the lock so two threads registering the same
        # name cannot interleave (previously the existence check was unlocked).
        with cls._lock:
            if cls._has_been_registered(name):
                if not raise_on_collision:
                    return
                raise ValueError(f"{name} has already been registered!")

            cls._raise_if_not_type(task)
            cls._raise_if_not_type(config)

            cls._registry[name] = task
            cls._reverse_registry[task] = name
            cls._config_registry[name] = config
            cls._reverse_config_registry[config] = name

    @classmethod
    def get_task_type(cls, name: EnumNameT) -> Type[TaskT]:
        """Return the task class registered under `name`."""
        cls._raise_if_not_registered(name, cls._registry)
        return cls._registry[name]

    @classmethod
    def get_config_type(cls, name: EnumNameT) -> Type[TaskConfigT]:
        """Return the config class registered under `name`."""
        cls._raise_if_not_registered(name, cls._config_registry)
        return cls._config_registry[name]

    @classmethod
    def get_registered_name(cls, task: Type[TaskT]) -> EnumNameT:
        """Return the name under which `task` was registered."""
        cls._raise_if_not_registered(task, cls._reverse_registry)
        return cls._reverse_registry[task]

    @classmethod
    def get_for_config_type(cls, config: Type[TaskConfigT]) -> Type[TaskT]:
        """Return the task class associated with the given config class."""
        cls._raise_if_not_registered(config, cls._reverse_config_registry)
        name = cls._reverse_config_registry[config]
        return cls.get_task_type(name)

    @classmethod
    def _has_been_registered(cls, name: EnumNameT) -> bool:
        return name in cls._registry

    @classmethod
    def _raise_if_not_registered(cls, key: EnumNameT | Type[TaskT] | Type[TaskConfigT], mapping: dict) -> None:
        # Keys may be string-like names or classes; only classes need the type check.
        if not (isinstance(key, StrEnum) or isinstance(key, str)):
            cls._raise_if_not_type(key)
        if key not in mapping:
            raise NotFoundInRegistryError(f"{key} not found in registry")

    @classmethod
    def _raise_if_not_type(cls, obj: Any) -> None:
        if not isinstance(obj, type):
            raise RegistryItemNotTypeError(f"{obj} is not a class!")

    def __new__(cls, *args, **kwargs):
        """Registry is a singleton (double-checked locking, one instance per class)."""
        if not cls._instance:
            with cls._lock:
                if not cls._instance:
                    cls._instance = super().__new__(cls)
        return cls._instance
@@ -0,0 +1,37 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from data_designer.engine.analysis.column_profilers.registry import (
5
+ ColumnProfilerRegistry,
6
+ create_default_column_profiler_registry,
7
+ )
8
+ from data_designer.engine.column_generators.registry import (
9
+ ColumnGeneratorRegistry,
10
+ create_default_column_generator_registry,
11
+ )
12
+ from data_designer.engine.processing.processors.registry import ProcessorRegistry, create_default_processor_registry
13
+
14
+
15
class DataDesignerRegistry:
    """Bundle of the column-generator, column-profiler, and processor registries.

    Any registry not supplied at construction time falls back to the
    corresponding default factory.
    """

    def __init__(
        self,
        *,
        column_generator_registry: ColumnGeneratorRegistry | None = None,
        column_profiler_registry: ColumnProfilerRegistry | None = None,
        processor_registry: ProcessorRegistry | None = None,
    ):
        self._column_generator_registry = (
            column_generator_registry if column_generator_registry else create_default_column_generator_registry()
        )
        self._column_profiler_registry = (
            column_profiler_registry if column_profiler_registry else create_default_column_profiler_registry()
        )
        self._processor_registry = processor_registry if processor_registry else create_default_processor_registry()

    @property
    def column_generators(self) -> ColumnGeneratorRegistry:
        """Registry of column generator tasks."""
        return self._column_generator_registry

    @property
    def column_profilers(self) -> ColumnProfilerRegistry:
        """Registry of column profiler tasks."""
        return self._column_profiler_registry

    @property
    def processors(self) -> ProcessorRegistry:
        """Registry of dataset processor tasks."""
        return self._processor_registry
@@ -0,0 +1,10 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from data_designer.engine.errors import DataDesignerError
5
+
6
+
7
+ class NotFoundInRegistryError(DataDesignerError): ...
8
+
9
+
10
+ class RegistryItemNotTypeError(DataDesignerError): ...
@@ -0,0 +1,35 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Any
5
+
6
+ import pandas as pd
7
+
8
+ from data_designer.engine.resources.managed_dataset_repository import ManagedDatasetRepository
9
+
10
+
11
class ManagedDatasetGenerator:
    """Samples random rows from a managed dataset, optionally conditioned on column values."""

    def __init__(self, managed_datasets: ManagedDatasetRepository, dataset_name: str):
        self.managed_datasets = managed_datasets
        self.dataset_name = dataset_name

    def generate_samples(
        self,
        size: int = 1,
        evidence: dict[str, Any | list[Any]] | None = None,
    ) -> pd.DataFrame:
        """Return up to `size` randomly ordered rows matching the given evidence.

        Args:
            size: Maximum number of rows to sample. Default: 1.
            evidence: Mapping of column name -> value (or list of values)
                that sampled rows must match; entries with falsy values are
                ignored. Values are bound as query parameters, but the
                dataset name and column names are interpolated directly into
                the SQL, so they must come from trusted configuration.
                Default: no filtering.

        Returns:
            pd.DataFrame: The sampled rows.
        """
        # NOTE: `evidence` previously used a mutable default ({}), which is
        # shared across calls in Python; `None` + normalization avoids that.
        evidence = evidence or {}
        parameters: list[Any] = []
        where_conditions: list[str] = []
        query = f"select * from {self.dataset_name}"
        for column, values in evidence.items():
            if values:
                values = values if isinstance(values, list) else [values]
                # One "?" placeholder per value; the values go into `parameters`.
                placeholders = ", ".join(["?"] * len(values))
                where_conditions.append(f"{column} IN ({placeholders})")
                parameters.extend(values)
        if where_conditions:
            query += " where " + " and ".join(where_conditions)
        query += f" order by random() limit {size}"
        return self.managed_datasets.query(query, parameters)