data-designer 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177) hide show
  1. data_designer/__init__.py +15 -0
  2. data_designer/_version.py +34 -0
  3. data_designer/cli/README.md +236 -0
  4. data_designer/cli/__init__.py +6 -0
  5. data_designer/cli/commands/__init__.py +2 -0
  6. data_designer/cli/commands/list.py +130 -0
  7. data_designer/cli/commands/models.py +10 -0
  8. data_designer/cli/commands/providers.py +11 -0
  9. data_designer/cli/commands/reset.py +100 -0
  10. data_designer/cli/controllers/__init__.py +7 -0
  11. data_designer/cli/controllers/model_controller.py +246 -0
  12. data_designer/cli/controllers/provider_controller.py +317 -0
  13. data_designer/cli/forms/__init__.py +20 -0
  14. data_designer/cli/forms/builder.py +51 -0
  15. data_designer/cli/forms/field.py +180 -0
  16. data_designer/cli/forms/form.py +59 -0
  17. data_designer/cli/forms/model_builder.py +125 -0
  18. data_designer/cli/forms/provider_builder.py +76 -0
  19. data_designer/cli/main.py +44 -0
  20. data_designer/cli/repositories/__init__.py +8 -0
  21. data_designer/cli/repositories/base.py +39 -0
  22. data_designer/cli/repositories/model_repository.py +42 -0
  23. data_designer/cli/repositories/provider_repository.py +43 -0
  24. data_designer/cli/services/__init__.py +7 -0
  25. data_designer/cli/services/model_service.py +116 -0
  26. data_designer/cli/services/provider_service.py +111 -0
  27. data_designer/cli/ui.py +448 -0
  28. data_designer/cli/utils.py +47 -0
  29. data_designer/config/__init__.py +2 -0
  30. data_designer/config/analysis/column_profilers.py +89 -0
  31. data_designer/config/analysis/column_statistics.py +274 -0
  32. data_designer/config/analysis/dataset_profiler.py +60 -0
  33. data_designer/config/analysis/utils/errors.py +8 -0
  34. data_designer/config/analysis/utils/reporting.py +188 -0
  35. data_designer/config/base.py +68 -0
  36. data_designer/config/column_configs.py +354 -0
  37. data_designer/config/column_types.py +168 -0
  38. data_designer/config/config_builder.py +660 -0
  39. data_designer/config/data_designer_config.py +40 -0
  40. data_designer/config/dataset_builders.py +11 -0
  41. data_designer/config/datastore.py +151 -0
  42. data_designer/config/default_model_settings.py +123 -0
  43. data_designer/config/errors.py +19 -0
  44. data_designer/config/interface.py +54 -0
  45. data_designer/config/models.py +231 -0
  46. data_designer/config/preview_results.py +32 -0
  47. data_designer/config/processors.py +41 -0
  48. data_designer/config/sampler_constraints.py +51 -0
  49. data_designer/config/sampler_params.py +604 -0
  50. data_designer/config/seed.py +145 -0
  51. data_designer/config/utils/code_lang.py +83 -0
  52. data_designer/config/utils/constants.py +313 -0
  53. data_designer/config/utils/errors.py +19 -0
  54. data_designer/config/utils/info.py +88 -0
  55. data_designer/config/utils/io_helpers.py +273 -0
  56. data_designer/config/utils/misc.py +81 -0
  57. data_designer/config/utils/numerical_helpers.py +28 -0
  58. data_designer/config/utils/type_helpers.py +100 -0
  59. data_designer/config/utils/validation.py +336 -0
  60. data_designer/config/utils/visualization.py +427 -0
  61. data_designer/config/validator_params.py +96 -0
  62. data_designer/engine/__init__.py +2 -0
  63. data_designer/engine/analysis/column_profilers/base.py +55 -0
  64. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +160 -0
  65. data_designer/engine/analysis/column_profilers/registry.py +20 -0
  66. data_designer/engine/analysis/column_statistics.py +142 -0
  67. data_designer/engine/analysis/dataset_profiler.py +125 -0
  68. data_designer/engine/analysis/errors.py +7 -0
  69. data_designer/engine/analysis/utils/column_statistics_calculations.py +209 -0
  70. data_designer/engine/analysis/utils/judge_score_processing.py +128 -0
  71. data_designer/engine/column_generators/__init__.py +2 -0
  72. data_designer/engine/column_generators/generators/__init__.py +2 -0
  73. data_designer/engine/column_generators/generators/base.py +61 -0
  74. data_designer/engine/column_generators/generators/expression.py +63 -0
  75. data_designer/engine/column_generators/generators/llm_generators.py +172 -0
  76. data_designer/engine/column_generators/generators/samplers.py +75 -0
  77. data_designer/engine/column_generators/generators/seed_dataset.py +149 -0
  78. data_designer/engine/column_generators/generators/validation.py +147 -0
  79. data_designer/engine/column_generators/registry.py +56 -0
  80. data_designer/engine/column_generators/utils/errors.py +13 -0
  81. data_designer/engine/column_generators/utils/judge_score_factory.py +57 -0
  82. data_designer/engine/column_generators/utils/prompt_renderer.py +98 -0
  83. data_designer/engine/configurable_task.py +82 -0
  84. data_designer/engine/dataset_builders/artifact_storage.py +181 -0
  85. data_designer/engine/dataset_builders/column_wise_builder.py +287 -0
  86. data_designer/engine/dataset_builders/errors.py +13 -0
  87. data_designer/engine/dataset_builders/multi_column_configs.py +44 -0
  88. data_designer/engine/dataset_builders/utils/__init__.py +2 -0
  89. data_designer/engine/dataset_builders/utils/concurrency.py +184 -0
  90. data_designer/engine/dataset_builders/utils/config_compiler.py +60 -0
  91. data_designer/engine/dataset_builders/utils/dag.py +56 -0
  92. data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +190 -0
  93. data_designer/engine/dataset_builders/utils/errors.py +13 -0
  94. data_designer/engine/errors.py +49 -0
  95. data_designer/engine/model_provider.py +75 -0
  96. data_designer/engine/models/__init__.py +2 -0
  97. data_designer/engine/models/errors.py +308 -0
  98. data_designer/engine/models/facade.py +225 -0
  99. data_designer/engine/models/litellm_overrides.py +162 -0
  100. data_designer/engine/models/parsers/__init__.py +2 -0
  101. data_designer/engine/models/parsers/errors.py +34 -0
  102. data_designer/engine/models/parsers/parser.py +236 -0
  103. data_designer/engine/models/parsers/postprocessors.py +93 -0
  104. data_designer/engine/models/parsers/tag_parsers.py +60 -0
  105. data_designer/engine/models/parsers/types.py +82 -0
  106. data_designer/engine/models/recipes/base.py +79 -0
  107. data_designer/engine/models/recipes/response_recipes.py +291 -0
  108. data_designer/engine/models/registry.py +118 -0
  109. data_designer/engine/models/usage.py +75 -0
  110. data_designer/engine/models/utils.py +38 -0
  111. data_designer/engine/processing/ginja/__init__.py +2 -0
  112. data_designer/engine/processing/ginja/ast.py +64 -0
  113. data_designer/engine/processing/ginja/environment.py +461 -0
  114. data_designer/engine/processing/ginja/exceptions.py +54 -0
  115. data_designer/engine/processing/ginja/record.py +30 -0
  116. data_designer/engine/processing/gsonschema/__init__.py +2 -0
  117. data_designer/engine/processing/gsonschema/exceptions.py +8 -0
  118. data_designer/engine/processing/gsonschema/schema_transformers.py +81 -0
  119. data_designer/engine/processing/gsonschema/types.py +8 -0
  120. data_designer/engine/processing/gsonschema/validators.py +143 -0
  121. data_designer/engine/processing/processors/base.py +15 -0
  122. data_designer/engine/processing/processors/drop_columns.py +46 -0
  123. data_designer/engine/processing/processors/registry.py +20 -0
  124. data_designer/engine/processing/utils.py +120 -0
  125. data_designer/engine/registry/base.py +97 -0
  126. data_designer/engine/registry/data_designer_registry.py +37 -0
  127. data_designer/engine/registry/errors.py +10 -0
  128. data_designer/engine/resources/managed_dataset_generator.py +35 -0
  129. data_designer/engine/resources/managed_dataset_repository.py +194 -0
  130. data_designer/engine/resources/managed_storage.py +63 -0
  131. data_designer/engine/resources/resource_provider.py +46 -0
  132. data_designer/engine/resources/seed_dataset_data_store.py +66 -0
  133. data_designer/engine/sampling_gen/column.py +89 -0
  134. data_designer/engine/sampling_gen/constraints.py +95 -0
  135. data_designer/engine/sampling_gen/data_sources/base.py +214 -0
  136. data_designer/engine/sampling_gen/data_sources/errors.py +10 -0
  137. data_designer/engine/sampling_gen/data_sources/sources.py +342 -0
  138. data_designer/engine/sampling_gen/entities/__init__.py +2 -0
  139. data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
  140. data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +64 -0
  141. data_designer/engine/sampling_gen/entities/email_address_utils.py +169 -0
  142. data_designer/engine/sampling_gen/entities/errors.py +8 -0
  143. data_designer/engine/sampling_gen/entities/national_id_utils.py +100 -0
  144. data_designer/engine/sampling_gen/entities/person.py +142 -0
  145. data_designer/engine/sampling_gen/entities/phone_number.py +122 -0
  146. data_designer/engine/sampling_gen/errors.py +24 -0
  147. data_designer/engine/sampling_gen/generator.py +121 -0
  148. data_designer/engine/sampling_gen/jinja_utils.py +60 -0
  149. data_designer/engine/sampling_gen/people_gen.py +203 -0
  150. data_designer/engine/sampling_gen/person_constants.py +54 -0
  151. data_designer/engine/sampling_gen/schema.py +143 -0
  152. data_designer/engine/sampling_gen/schema_builder.py +59 -0
  153. data_designer/engine/sampling_gen/utils.py +40 -0
  154. data_designer/engine/secret_resolver.py +80 -0
  155. data_designer/engine/validators/__init__.py +17 -0
  156. data_designer/engine/validators/base.py +36 -0
  157. data_designer/engine/validators/local_callable.py +34 -0
  158. data_designer/engine/validators/python.py +245 -0
  159. data_designer/engine/validators/remote.py +83 -0
  160. data_designer/engine/validators/sql.py +60 -0
  161. data_designer/errors.py +5 -0
  162. data_designer/essentials/__init__.py +137 -0
  163. data_designer/interface/__init__.py +2 -0
  164. data_designer/interface/data_designer.py +351 -0
  165. data_designer/interface/errors.py +16 -0
  166. data_designer/interface/results.py +55 -0
  167. data_designer/logging.py +161 -0
  168. data_designer/plugin_manager.py +83 -0
  169. data_designer/plugins/__init__.py +6 -0
  170. data_designer/plugins/errors.py +10 -0
  171. data_designer/plugins/plugin.py +69 -0
  172. data_designer/plugins/registry.py +86 -0
  173. data_designer-0.1.0.dist-info/METADATA +173 -0
  174. data_designer-0.1.0.dist-info/RECORD +177 -0
  175. data_designer-0.1.0.dist-info/WHEEL +4 -0
  176. data_designer-0.1.0.dist-info/entry_points.txt +2 -0
  177. data_designer-0.1.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,60 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import ast
5
+ from typing import Any
6
+
7
+ import pandas as pd
8
+
9
+ from data_designer.engine.processing.ginja.environment import (
10
+ UserTemplateSandboxEnvironment,
11
+ WithJinja2UserTemplateRendering,
12
+ )
13
+
14
+
15
class JinjaDataFrame(WithJinja2UserTemplateRendering):
    """Evaluates a Jinja2 expression against the rows of a pandas DataFrame."""

    def __init__(self, expr: str):
        self.expr = expr

    def _jsonify(self, record) -> dict[str, Any]:
        # Timestamps are not template-friendly; render them as ISO-8601 strings.
        for key in record:
            if isinstance(record[key], pd.Timestamp):
                record[key] = record[key].isoformat()
        return record

    @property
    def _expr(self) -> str:
        # Wrap the bare expression in Jinja2 output delimiters.
        return "{{ " + self.expr + " }}"

    def select_index(self, dataframe: pd.DataFrame) -> pd.Index:
        """Return the index of rows for which the expression renders to "True"."""
        # "..." acts as a match-everything expression.
        if dataframe.empty or self.expr == "...":
            return dataframe.index

        self.prepare_jinja2_template_renderer(self._expr, list(dataframe))

        mask = dataframe.apply(
            lambda row: self.render_template(self._jsonify(row.to_dict())) == "True",
            axis=1,
        ).to_numpy()

        return dataframe[mask].index

    def to_column(self, dataframe: pd.DataFrame) -> list[Any]:
        """Render the expression for every record, returning one value per row."""
        self.prepare_jinja2_template_renderer(self._expr, list(dataframe))

        values: list[Any] = []
        for record in dataframe.to_dict(orient="records"):
            rendered = self.render_template(self._jsonify(record))
            try:
                # Non-string expressions are evaluated as literals.
                values.append(ast.literal_eval(rendered))
            except (ValueError, SyntaxError):
                # Strings throw an error and are appended directly.
                values.append(rendered)

        return values
56
+
57
+
58
def extract_column_names_from_expression(expr: str) -> set[str]:
    """Extract valid column names from the given expression."""
    wrapped = "{{ " + expr + " }}"
    return UserTemplateSandboxEnvironment().get_references(wrapped)
@@ -0,0 +1,203 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+
6
+ from abc import ABC, abstractmethod
7
+ from collections.abc import Callable
8
+ from copy import deepcopy
9
+ import random
10
+ from typing import TYPE_CHECKING, Any, Union
11
+ import uuid
12
+
13
+ from faker import Faker
14
+ import pandas as pd
15
+
16
+ from data_designer.config.utils.constants import AVAILABLE_LOCALES, DEFAULT_AGE_RANGE
17
+ from data_designer.engine.resources.managed_dataset_generator import ManagedDatasetGenerator
18
+ from data_designer.engine.sampling_gen.entities.dataset_based_person_fields import PERSONA_FIELDS, PII_FIELDS
19
+ from data_designer.engine.sampling_gen.entities.person import (
20
+ convert_age_to_birth_date,
21
+ generate_and_insert_derived_fields,
22
+ )
23
+ from data_designer.engine.sampling_gen.errors import ManagedDatasetGeneratorError
24
+ from data_designer.engine.sampling_gen.person_constants import faker_constants
25
+
26
+ if TYPE_CHECKING:
27
+ from data_designer.engine.sampling_gen.schema import DataSchema
28
+
29
+
30
+ EngineT = Union[Faker, ManagedDatasetGenerator]
31
+
32
+
33
class PeopleGen(ABC):
    """Unified interface for generating people data.

    Concrete subclasses implement `generate`, producing `n` person records
    backed by a specific engine (Faker or a managed dataset generator).

    Args:
        engine: Backend used to produce person data.
        locale: Locale code; must be one of AVAILABLE_LOCALES.

    Raises:
        ValueError: If `locale` is not a supported locale.
    """

    def __init__(self, engine: EngineT, locale: str):
        if locale not in AVAILABLE_LOCALES:
            # Fixed: the original message was missing the space between
            # sentences ("...supported locale.Supported locales...").
            raise ValueError(
                f"Locale {locale} is not a supported locale. Supported locales: {', '.join(AVAILABLE_LOCALES)}"
            )
        self.locale = locale
        self._engine = engine

    def set_engine(self, engine: EngineT) -> None:
        """Swap out the underlying generation engine."""
        self._engine = engine

    @abstractmethod
    def generate(self, n: int, **kwargs) -> list[dict[str, Any]]: ...
49
+
50
+
51
class PeopleGenFaker(PeopleGen):
    """Generates synthetic person records using the Faker library."""

    @property
    def _fake(self) -> Faker:
        # For this generator the engine is a Faker instance.
        return self._engine

    def try_fake_else_none(self, attr_name: str, none_fill: Any | None = None) -> Any | None:
        """Call the Faker provider named `attr_name`, falling back to `none_fill`.

        Not every Faker locale implements every provider, so a missing
        attribute means "no value" rather than an error.

        Note: the original return annotation was `-> type`, which is
        incorrect — the method returns the provider's value or `none_fill`.
        """
        try:
            return getattr(self._fake, attr_name)()
        except AttributeError:
            return none_fill

    def _generate_name_and_sex(self, **kwargs) -> dict[str, str]:
        options = faker_constants.sex
        # Accept either a bare sex value or a single-element list of one.
        if "sex" in kwargs and kwargs["sex"] in [*options, *[[o] for o in options]]:
            sex = random.choice(kwargs["sex"]) if isinstance(kwargs["sex"], list) else kwargs["sex"]
        else:
            sex = random.choice(options)

        return {
            "first_name": getattr(self._fake, f"first_name_{sex.lower()}")(),
            "last_name": getattr(self._fake, f"last_name_{sex.lower()}")(),
            "middle_name": None,
            "sex": sex,
        }

    def _generate_address_fields(self, **kwargs) -> dict[str, str]:
        address = {
            "street_number": self.try_fake_else_none(faker_constants.attr_map["street_number"]),
            "street_name": self.try_fake_else_none("street_name"),
        }

        # Location fields can be filtered using the fixed_kwargs.
        for attr in faker_constants.location:
            if attr in kwargs:
                address[attr] = random.choice(kwargs[attr]) if isinstance(kwargs[attr], list) else kwargs[attr]
            else:
                address[attr] = self.try_fake_else_none(attr)

        return address

    def _generate_age(self, **kwargs) -> int:
        # Uniform age within the requested (or default) inclusive range.
        return random.randint(*kwargs.get("age_range", DEFAULT_AGE_RANGE))

    def _generate_marital_status(self, **kwargs) -> str:
        return random.choice(faker_constants.marital_status)

    def _generate_bachelors_field(self, **kwargs) -> str:
        return random.choice(faker_constants.bachelors)

    def _generate_education_level(self, **kwargs) -> str:
        return random.choice(faker_constants.education_level)

    def make_person(self, **kwargs) -> dict[str, Any]:
        """Assemble one complete person record, honoring fixed kwargs."""
        person = {"uuid": str(uuid.uuid4()), "locale": self.locale}
        person.update(self._generate_name_and_sex(**kwargs))
        person.update(self._generate_address_fields(**kwargs))
        person.update({"age": self._generate_age(**kwargs)})
        person.update({"birth_date": convert_age_to_birth_date(person["age"]).isoformat()})
        person.update({"country": self.try_fake_else_none("country")})
        person.update({"marital_status": self._generate_marital_status(**kwargs)})
        person.update({"education_level": self._generate_education_level(**kwargs)})
        person.update({"unit": ""})
        person.update({"occupation": self.try_fake_else_none(faker_constants.attr_map["occupation"])})
        # Minors are not assigned a phone number.
        person.update({"phone_number": (self.try_fake_else_none("phone_number") if person["age"] >= 18 else None)})
        if person["education_level"] in faker_constants.college_level:
            person.update({"bachelors_field": self._generate_bachelors_field(**kwargs)})
        else:
            person.update({"bachelors_field": "no_degree"})
        return person

    def generate(self, n: int, **kwargs) -> list[dict[str, Any]]:
        return [self.make_person(**kwargs) for _ in range(n)]
123
+
124
+
125
class PeopleGenFromDataset(PeopleGen):
    """Generates person records by sampling from a managed dataset."""

    def _get_ages(self, age_range: tuple[int, int]) -> list[int]:
        # Inclusive list of candidate ages.
        low, high = age_range
        return list(range(low, high + 1))

    def _generate_from_dataset(self, n: int, **kwargs) -> pd.DataFrame:
        evidence = deepcopy(kwargs)
        include_personas = evidence.pop("with_synthetic_personas", False)
        evidence["age"] = self._get_ages(evidence.pop("age_range", DEFAULT_AGE_RANGE))

        # Generate samples and drop columns where all rows are null.
        frame = self._engine.generate_samples(size=n, evidence=evidence).dropna(axis=1, how="all")

        # Derived-field generation downstream needs the locale on every row.
        frame["locale"] = self.locale

        # Restrict output to schema-listed fields that actually exist.
        keep = [field for field in PII_FIELDS if field in frame.columns]
        if include_personas:
            keep += [field for field in PERSONA_FIELDS if field in frame.columns]

        return frame[keep]

    def generate(self, n: int, **kwargs) -> list[dict[str, Any]]:
        records = self._generate_from_dataset(n, **kwargs).to_dict(orient="records")
        return [generate_and_insert_derived_fields(record) for record in records]
152
+
153
+
154
def create_people_gen_resource(
    schema: DataSchema,
    person_generator_loader: Callable[[bool], ManagedDatasetGenerator] | None = None,
) -> dict[str, PeopleGen]:
    """Creates resource of unique people generators needed to generate the dataset.

    The resource is a dictionary of person generators, where the keys are the following:
    - {locale} for dataset-based person generators
    - {locale}_with_personas for dataset-based person generators with synthetic personas
    - {locale}_faker for faker-based person generators

    Args:
        schema: Schema of the dataset that we will generate.
        person_generator_loader: Function that loads a managed dataset generator.
            Required whenever the schema contains "person" sampler columns.

    Returns:
        Dictionary of unique people generators needed to generate the dataset.

    Raises:
        ManagedDatasetGeneratorError: If a dataset-based generator is needed but
            no loader was provided, or the loader fails for a locale.
    """
    people_gen_resource = {}

    # ------------------------------------------------------------
    # Preload dataset-based person generators
    # ------------------------------------------------------------

    for column in schema.get_columns_by_sampler_type("person"):
        for params in [column.params, *list(column.conditional_params.values())]:
            if params.people_gen_key in people_gen_resource:
                continue
            # Fixed: previously a None loader was called unconditionally,
            # producing a confusing TypeError; fail fast with a clear message.
            if person_generator_loader is None:
                raise ManagedDatasetGeneratorError(
                    f"🛑 Failed to load dataset-based person generator for locale {params.locale}. "
                    "Please check if you have access to person data for this locale. "
                )
            try:
                engine = person_generator_loader(locale=params.locale)
                people_gen_resource[params.people_gen_key] = PeopleGenFromDataset(
                    engine=engine, locale=params.locale
                )
            except Exception as e:
                raise ManagedDatasetGeneratorError(
                    f"🛑 Failed to load dataset-based person generator for locale {params.locale}. "
                    "Please check if you have access to person data for this locale. "
                ) from e

    # ------------------------------------------------------------
    # Preload faker-based person generators
    # ------------------------------------------------------------

    for column in schema.get_columns_by_sampler_type("person_from_faker"):
        for params in [column.params, *list(column.conditional_params.values())]:
            if params.people_gen_key not in people_gen_resource:
                people_gen_resource[params.people_gen_key] = PeopleGenFaker(
                    engine=Faker(params.locale), locale=params.locale
                )

    return people_gen_resource
@@ -0,0 +1,54 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import NamedTuple
5
+
6
+
7
class FakerPersonData(NamedTuple):
    """Immutable constants used by the Faker-based person generator."""

    # Sexes used to pick the gendered Faker name providers.
    sex: list[str] = ["Male", "Female"]

    # Fields that are only populated for the en_US locale.
    us_locale_only: list[str] = ["state", "county", "unit", "middle_name", "ethnic_background", "ssn"]

    # Address fields that may be constrained via fixed kwargs.
    location: list[str] = ["city", "state", "postcode"]

    # Broad fields of study for bachelor's degrees.
    bachelors: list[str] = ["stem", "business", "education", "arts_humanities", "stem_related"]

    # Highest completed education levels.
    education_level: list[str] = [
        "secondary_education",
        "some_college",
        "bachelors",
        "associates",
        "graduate",
        "doctorate",
    ]

    # Marital status categories.
    marital_status: list[str] = ["married_present", "divorced", "never_married", "separated", "widowed"]

    # Education levels that imply a college degree.
    college_level: list[str] = ["bachelors", "graduate", "doctorate"]

    # Maps schema field names to the corresponding Faker provider names.
    attr_map: dict[str, str] = {"street_number": "building_number", "occupation": "job"}


faker_constants = FakerPersonData()
@@ -0,0 +1,143 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+
6
+ from functools import cached_property
7
+
8
+ import networkx as nx
9
+ from pydantic import BaseModel, Field, field_validator, model_validator
10
+ from typing_extensions import Self
11
+
12
+ from data_designer.config.base import ConfigBase
13
+ from data_designer.config.sampler_constraints import ColumnConstraintT
14
+ from data_designer.config.sampler_params import SamplerType
15
+ from data_designer.engine.sampling_gen.column import ConditionalDataColumn
16
+ from data_designer.engine.sampling_gen.constraints import ConstraintChecker, get_constraint_checker
17
+
18
+
19
class Dag(BaseModel):
    """A directed graph of column dependencies, validated to be acyclic."""

    nodes: set[str]
    edges: set[tuple[str, str]]

    @model_validator(mode="after")
    def validate_is_dag(self) -> Self:
        # Reject any cycle among the sampler-column dependencies.
        if not nx.is_directed_acyclic_graph(self.to_networkx()):
            raise ValueError("There are circular dependencies in the definitions of your sampler columns.")
        return self

    def to_networkx(self) -> nx.DiGraph:
        """Materialize this graph as a networkx DiGraph."""
        graph = nx.DiGraph()
        graph.add_nodes_from(self.nodes)
        graph.add_edges_from(self.edges)
        return graph
36
+
37
+
38
class DataSchema(ConfigBase):
    """Defines the data schema for synthetic data generation.

    A DataSchema represents a collection of columns and their relationships through
    conditional parameters and/or constraints. Upon initialization, the schema validates
    that column dependencies form a DAG and that all constraints reference valid columns.
    """

    # At least one column is required; constraints are optional.
    columns: list[ConditionalDataColumn] = Field(..., min_length=1)
    constraints: list[ColumnConstraintT] = []

    @cached_property
    def constraint_checkers(self) -> list[ConstraintChecker]:
        # Cached: checkers are constructed once per schema instance.
        return [get_constraint_checker(c.constraint_type)(constraint=c) for c in self.constraints]

    @property
    def column_names(self) -> list[str]:
        return [column.name for column in self.columns]

    @property
    def dag(self) -> Dag:
        # NOTE: recomputed on every access. Constructing the Dag also
        # validates acyclicity (see Dag.validate_is_dag).
        nodes = set()
        edges = set()

        for column in self.columns:
            nodes.add(column.name)

            # Add edges for the conditional columns.
            for conditional_column in column.conditional_column_names:
                edges.add((conditional_column, column.name))

            # Add edges if the source has required columns.
            for condition in column.conditions:
                source = column.get_sampler(condition)
                for required_column in source.get_required_column_names():
                    edges.add((required_column, column.name))

        for checker in self.constraint_checkers:
            column_names = checker.get_required_column_names()
            if len(column_names) == 2:
                # Two-column constraint: the second column must be generated
                # before the first (target) column.
                edges.add((column_names[1], column_names[0]))
        return Dag(nodes=nodes, edges=edges)

    @field_validator("columns", mode="after")
    def check_unique_column_names(cls, columns: list[ConditionalDataColumn]) -> list[ConditionalDataColumn]:
        # Duplicate names would make dependency edges ambiguous.
        column_names = [column.name for column in columns]
        if len(column_names) != len(set(column_names)):
            raise ValueError("Column names must be unique")
        return columns

    @model_validator(mode="after")
    def validate_constraints(self) -> Self:
        column_names = [column.name for column in self.columns]

        # Check if all columns required by constraints are present in the schema.
        for checker in self.constraint_checkers:
            constrained_column_names = checker.get_required_column_names()
            if not set(constrained_column_names).issubset(column_names):
                missing = set(constrained_column_names) - set(column_names)
                raise ValueError(
                    f"These constrained columns are missing in the definitions of your sampler columns: {missing}"
                )

        return self

    @model_validator(mode="after")
    def validate_dag(self) -> Self:
        # Building the DAG triggers acyclicity validation; the value itself
        # is discarded here.
        self.dag
        return self

    @model_validator(mode="after")
    def validate_subcategory_columns_if_present(self) -> Self:
        # Every subcategory column must point at a category-typed parent whose
        # value set exactly matches the subcategory's keys.
        for sub in self.get_columns_by_sampler_type(SamplerType.SUBCATEGORY):
            cat = self.get_column(sub.params.category)
            if cat.sampler_type != SamplerType.CATEGORY:
                raise ValueError(
                    f"The parent of subcategory column '{sub.name}' must be a category "
                    f"source type, but '{cat.name}' is of type '{cat.sampler_type}'."
                )
            # Parent values span both base params and all conditional params.
            cat_vals = set(cat.params.values)
            for params in cat.conditional_params.values():
                cat_vals.update(params.values)
            sub_vals = set(sub.params.values.keys())
            if cat_vals.symmetric_difference(sub_vals):
                raise ValueError(
                    f"Subcategory column '{sub.name}' must have values for each value of "
                    f"its parent category '{sub.params.category}'. The following "
                    f"values need attention: {cat_vals.symmetric_difference(sub_vals)}"
                )
            if not all(len(v) > 0 for v in sub.params.values.values()):
                raise ValueError(
                    f"Subcategory column '{sub.name}' must have non-empty values for "
                    f"each value of its parent category '{sub.params.category}'."
                )
        return self

    def get_column(self, column_name: str) -> ConditionalDataColumn:
        if column_name not in self.column_names:
            raise ValueError(f"Column '{column_name}' not found in schema")
        return next(column for column in self.columns if column.name == column_name)

    def get_columns_by_sampler_type(self, sampler_type: SamplerType) -> list[ConditionalDataColumn]:
        return [c for c in self.columns if c.sampler_type == sampler_type]

    def get_constraint_checkers(self, column_name: str) -> list[ConstraintChecker]:
        # Only checkers whose target column is `column_name`.
        return [c for c in self.constraint_checkers if column_name == c.constraint.target_column]
@@ -0,0 +1,59 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from copy import deepcopy
5
+
6
+ from data_designer.config.column_configs import SamplerColumnConfig
7
+ from data_designer.config.sampler_constraints import ColumnConstraintT
8
+ from data_designer.config.sampler_params import SamplerParamsT
9
+ from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig
10
+ from data_designer.engine.sampling_gen.column import ConditionalDataColumn
11
+ from data_designer.engine.sampling_gen.schema import DataSchema
12
+
13
+
14
class SchemaBuilder:
    """Incrementally constructs DataSchema objects.

    This class is meant to be a helper for internal usage and experimentation.
    Columns and constraints are accumulated via `add_column` and
    `add_constraint`, then materialized with `build` or `to_sampler_columns`.
    """

    def __init__(
        self,
        columns: list[ConditionalDataColumn] | None = None,
        constraints: list[ColumnConstraintT] | None = None,
    ):
        self._columns = columns or []
        self._constraints = constraints or []

    def add_column(
        self,
        name: str,
        sampler_type: str | None,
        params: dict | SamplerParamsT | None,
        conditional_params: dict[str, SamplerParamsT] | None = None,
        convert_to: str | None = None,
    ) -> None:
        """Append a new conditional data column to the schema under construction."""
        column = ConditionalDataColumn(
            name=name,
            sampler_type=sampler_type,
            params=params,
            conditional_params=conditional_params or {},
            convert_to=convert_to,
        )
        self._columns.append(column)

    def add_constraint(self, constraint: ColumnConstraintT) -> None:
        """Register a column constraint."""
        self._constraints.append(constraint)

    def to_sampler_columns(self, max_rejections_factor: int = 5) -> SamplerMultiColumnConfig:
        """Convert the accumulated columns/constraints into a sampler multi-column config."""
        column_configs = [SamplerColumnConfig(**c.model_dump(mode="json")) for c in self._columns]
        return SamplerMultiColumnConfig(
            columns=column_configs,
            constraints=self._constraints,
            max_rejections_factor=max_rejections_factor,
        )

    def build(self) -> DataSchema:
        """Materialize a DataSchema from deep copies of the accumulated parts."""
        return DataSchema(columns=deepcopy(self._columns), constraints=deepcopy(self._constraints))
@@ -0,0 +1,40 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import numbers
5
+
6
+ import numpy as np
7
+
8
+
9
def check_random_state(seed):
    """Coerce *seed* into a ``numpy.random.RandomState`` instance.

    Adapted from scikit-learn's ``sklearn.utils.check_random_state``
    (https://github.com/scikit-learn/scikit-learn).

    Parameters
    ----------
    seed : None, int or instance of RandomState
        ``None`` (or the ``np.random`` module itself) yields the global
        RandomState singleton used by ``np.random``; an int seeds a fresh
        instance; an existing RandomState instance is returned unchanged.
        Anything else raises ValueError.

    Returns
    -------
    :class:`numpy:numpy.random.RandomState`
        The random state object based on `seed` parameter.

    Examples
    --------
    >>> from data_designer.engine.sampling_gen.utils import check_random_state
    >>> check_random_state(42)
    RandomState(MT19937) at 0x...
    """
    if seed is None or seed is np.random:
        # The module-level singleton backing the np.random.* functions.
        return np.random.mtrand._rand
    if isinstance(seed, np.random.RandomState):
        return seed
    if isinstance(seed, numbers.Integral):
        return np.random.RandomState(seed)
    raise ValueError("%r cannot be used to seed a numpy.random.RandomState instance" % seed)
@@ -0,0 +1,80 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from collections.abc import Sequence
5
+ import json
6
+ import logging
7
+ import os
8
+ from pathlib import Path
9
+ from typing import Protocol
10
+
11
+ from data_designer.engine.errors import SecretResolutionError
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
class SecretResolver(Protocol):
    """Structural interface for mapping a secret reference to its value."""

    def resolve(self, secret: str) -> str: ...
18
+
19
+
20
class SecretsFileResolver(SecretResolver):
    """Resolves secrets from a JSON file mapping secret names to values.

    A missing file is treated as an empty secret store rather than an error.
    """

    _secrets: dict[str, str]

    def __init__(self, filepath: Path):
        if not filepath.exists():
            self._secrets = {}
        else:
            # JSON is UTF-8 by spec; don't rely on the platform default encoding.
            with open(filepath, encoding="utf-8") as f:
                self._secrets = json.load(f)

    def resolve(self, secret: str) -> str:
        """Return the stored value for *secret*.

        Raises:
            SecretResolutionError: If the key is absent from the secrets file.
        """
        try:
            return self._secrets[secret]
        except KeyError as err:
            # Chain the KeyError so the original lookup failure stays visible.
            raise SecretResolutionError(f"No secret found in secrets file with key {secret!r}") from err
35
+
36
+
37
class EnvironmentResolver(SecretResolver):
    """Resolves secrets from environment variables of the same name."""

    def resolve(self, secret: str) -> str:
        """Return the value of the environment variable named *secret*.

        Raises:
            SecretResolutionError: If the variable is not set.
        """
        try:
            return os.environ[secret]
        except KeyError as err:
            # Chain the KeyError so the original lookup failure stays visible.
            raise SecretResolutionError(
                f"Environment variable with name {secret!r} is required but not set. Please set it in your environment and try again."
            ) from err
45
+
46
+
47
class PlaintextResolver(SecretResolver):
    """Treats the secret reference itself as the secret value (no lookup)."""

    def resolve(self, secret: str) -> str:
        return secret
50
+
51
+
52
class CompositeResolver(SecretResolver):
    """Tries a sequence of resolvers in order, returning the first success."""

    _resolvers: Sequence[SecretResolver]

    def __init__(self, resolvers: Sequence[SecretResolver]):
        # An empty composite could never resolve anything; reject it up front.
        if len(resolvers) == 0:
            raise SecretResolutionError("Must provide at least one SecretResolver to CompositeResolver")
        self._resolvers = resolvers

    @property
    def resolvers(self) -> Sequence[SecretResolver]:
        """Get the sequence of resolvers in this composite resolver.

        Returns:
            Sequence of SecretResolver instances used to resolve secrets.
        """
        return self._resolvers

    def resolve(self, secret: str) -> str:
        """Return the first successful resolution of *secret*.

        Raises:
            SecretResolutionError: If every configured resolver fails; the
                message aggregates each resolver's failure.
        """
        failures: list[str] = []
        for candidate in self._resolvers:
            try:
                return candidate.resolve(secret)
            except SecretResolutionError as err:
                failures.append(str(err))

        errors = failures
        raise SecretResolutionError(
            f"No configured resolvers were able to resolve secret {secret!r}: {', '.join(errors)}"
        )
@@ -0,0 +1,17 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from data_designer.engine.validators.base import BaseValidator, ValidationResult
5
+ from data_designer.engine.validators.local_callable import LocalCallableValidator
6
+ from data_designer.engine.validators.python import PythonValidator
7
+ from data_designer.engine.validators.remote import RemoteValidator
8
+ from data_designer.engine.validators.sql import SQLValidator
9
+
10
+ __all__ = [
11
+ "BaseValidator",
12
+ "LocalCallableValidator",
13
+ "RemoteValidator",
14
+ "ValidationResult",
15
+ "PythonValidator",
16
+ "SQLValidator",
17
+ ]