data-designer 0.3.8rc2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166) hide show
  1. data_designer/cli/commands/__init__.py +1 -1
  2. data_designer/interface/__init__.py +21 -1
  3. data_designer/{_version.py → interface/_version.py} +2 -2
  4. data_designer/interface/data_designer.py +1 -7
  5. {data_designer-0.3.8rc2.dist-info → data_designer-0.4.0.dist-info}/METADATA +10 -42
  6. data_designer-0.4.0.dist-info/RECORD +39 -0
  7. data_designer/__init__.py +0 -17
  8. data_designer/config/__init__.py +0 -2
  9. data_designer/config/analysis/__init__.py +0 -2
  10. data_designer/config/analysis/column_profilers.py +0 -159
  11. data_designer/config/analysis/column_statistics.py +0 -421
  12. data_designer/config/analysis/dataset_profiler.py +0 -84
  13. data_designer/config/analysis/utils/errors.py +0 -10
  14. data_designer/config/analysis/utils/reporting.py +0 -192
  15. data_designer/config/base.py +0 -69
  16. data_designer/config/column_configs.py +0 -470
  17. data_designer/config/column_types.py +0 -141
  18. data_designer/config/config_builder.py +0 -595
  19. data_designer/config/data_designer_config.py +0 -40
  20. data_designer/config/dataset_builders.py +0 -13
  21. data_designer/config/dataset_metadata.py +0 -18
  22. data_designer/config/default_model_settings.py +0 -129
  23. data_designer/config/errors.py +0 -24
  24. data_designer/config/exports.py +0 -145
  25. data_designer/config/interface.py +0 -55
  26. data_designer/config/models.py +0 -455
  27. data_designer/config/preview_results.py +0 -41
  28. data_designer/config/processors.py +0 -148
  29. data_designer/config/run_config.py +0 -51
  30. data_designer/config/sampler_constraints.py +0 -52
  31. data_designer/config/sampler_params.py +0 -639
  32. data_designer/config/seed.py +0 -116
  33. data_designer/config/seed_source.py +0 -84
  34. data_designer/config/seed_source_types.py +0 -19
  35. data_designer/config/utils/code_lang.py +0 -82
  36. data_designer/config/utils/constants.py +0 -363
  37. data_designer/config/utils/errors.py +0 -21
  38. data_designer/config/utils/info.py +0 -94
  39. data_designer/config/utils/io_helpers.py +0 -258
  40. data_designer/config/utils/misc.py +0 -78
  41. data_designer/config/utils/numerical_helpers.py +0 -30
  42. data_designer/config/utils/type_helpers.py +0 -106
  43. data_designer/config/utils/visualization.py +0 -482
  44. data_designer/config/validator_params.py +0 -94
  45. data_designer/engine/__init__.py +0 -2
  46. data_designer/engine/analysis/column_profilers/base.py +0 -49
  47. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +0 -153
  48. data_designer/engine/analysis/column_profilers/registry.py +0 -22
  49. data_designer/engine/analysis/column_statistics.py +0 -145
  50. data_designer/engine/analysis/dataset_profiler.py +0 -149
  51. data_designer/engine/analysis/errors.py +0 -9
  52. data_designer/engine/analysis/utils/column_statistics_calculations.py +0 -234
  53. data_designer/engine/analysis/utils/judge_score_processing.py +0 -132
  54. data_designer/engine/column_generators/__init__.py +0 -2
  55. data_designer/engine/column_generators/generators/__init__.py +0 -2
  56. data_designer/engine/column_generators/generators/base.py +0 -122
  57. data_designer/engine/column_generators/generators/embedding.py +0 -35
  58. data_designer/engine/column_generators/generators/expression.py +0 -55
  59. data_designer/engine/column_generators/generators/llm_completion.py +0 -113
  60. data_designer/engine/column_generators/generators/samplers.py +0 -69
  61. data_designer/engine/column_generators/generators/seed_dataset.py +0 -144
  62. data_designer/engine/column_generators/generators/validation.py +0 -140
  63. data_designer/engine/column_generators/registry.py +0 -60
  64. data_designer/engine/column_generators/utils/errors.py +0 -15
  65. data_designer/engine/column_generators/utils/generator_classification.py +0 -43
  66. data_designer/engine/column_generators/utils/judge_score_factory.py +0 -58
  67. data_designer/engine/column_generators/utils/prompt_renderer.py +0 -100
  68. data_designer/engine/compiler.py +0 -97
  69. data_designer/engine/configurable_task.py +0 -71
  70. data_designer/engine/dataset_builders/artifact_storage.py +0 -283
  71. data_designer/engine/dataset_builders/column_wise_builder.py +0 -335
  72. data_designer/engine/dataset_builders/errors.py +0 -15
  73. data_designer/engine/dataset_builders/multi_column_configs.py +0 -46
  74. data_designer/engine/dataset_builders/utils/__init__.py +0 -2
  75. data_designer/engine/dataset_builders/utils/concurrency.py +0 -212
  76. data_designer/engine/dataset_builders/utils/config_compiler.py +0 -62
  77. data_designer/engine/dataset_builders/utils/dag.py +0 -62
  78. data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +0 -200
  79. data_designer/engine/dataset_builders/utils/errors.py +0 -15
  80. data_designer/engine/errors.py +0 -51
  81. data_designer/engine/model_provider.py +0 -77
  82. data_designer/engine/models/__init__.py +0 -2
  83. data_designer/engine/models/errors.py +0 -300
  84. data_designer/engine/models/facade.py +0 -287
  85. data_designer/engine/models/factory.py +0 -42
  86. data_designer/engine/models/litellm_overrides.py +0 -179
  87. data_designer/engine/models/parsers/__init__.py +0 -2
  88. data_designer/engine/models/parsers/errors.py +0 -34
  89. data_designer/engine/models/parsers/parser.py +0 -235
  90. data_designer/engine/models/parsers/postprocessors.py +0 -93
  91. data_designer/engine/models/parsers/tag_parsers.py +0 -62
  92. data_designer/engine/models/parsers/types.py +0 -84
  93. data_designer/engine/models/recipes/base.py +0 -81
  94. data_designer/engine/models/recipes/response_recipes.py +0 -293
  95. data_designer/engine/models/registry.py +0 -146
  96. data_designer/engine/models/telemetry.py +0 -359
  97. data_designer/engine/models/usage.py +0 -73
  98. data_designer/engine/models/utils.py +0 -38
  99. data_designer/engine/processing/ginja/__init__.py +0 -2
  100. data_designer/engine/processing/ginja/ast.py +0 -65
  101. data_designer/engine/processing/ginja/environment.py +0 -463
  102. data_designer/engine/processing/ginja/exceptions.py +0 -56
  103. data_designer/engine/processing/ginja/record.py +0 -32
  104. data_designer/engine/processing/gsonschema/__init__.py +0 -2
  105. data_designer/engine/processing/gsonschema/exceptions.py +0 -15
  106. data_designer/engine/processing/gsonschema/schema_transformers.py +0 -83
  107. data_designer/engine/processing/gsonschema/types.py +0 -10
  108. data_designer/engine/processing/gsonschema/validators.py +0 -202
  109. data_designer/engine/processing/processors/base.py +0 -13
  110. data_designer/engine/processing/processors/drop_columns.py +0 -42
  111. data_designer/engine/processing/processors/registry.py +0 -25
  112. data_designer/engine/processing/processors/schema_transform.py +0 -49
  113. data_designer/engine/processing/utils.py +0 -169
  114. data_designer/engine/registry/base.py +0 -99
  115. data_designer/engine/registry/data_designer_registry.py +0 -39
  116. data_designer/engine/registry/errors.py +0 -12
  117. data_designer/engine/resources/managed_dataset_generator.py +0 -39
  118. data_designer/engine/resources/managed_dataset_repository.py +0 -197
  119. data_designer/engine/resources/managed_storage.py +0 -65
  120. data_designer/engine/resources/resource_provider.py +0 -77
  121. data_designer/engine/resources/seed_reader.py +0 -154
  122. data_designer/engine/sampling_gen/column.py +0 -91
  123. data_designer/engine/sampling_gen/constraints.py +0 -100
  124. data_designer/engine/sampling_gen/data_sources/base.py +0 -217
  125. data_designer/engine/sampling_gen/data_sources/errors.py +0 -12
  126. data_designer/engine/sampling_gen/data_sources/sources.py +0 -347
  127. data_designer/engine/sampling_gen/entities/__init__.py +0 -2
  128. data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
  129. data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +0 -86
  130. data_designer/engine/sampling_gen/entities/email_address_utils.py +0 -171
  131. data_designer/engine/sampling_gen/entities/errors.py +0 -10
  132. data_designer/engine/sampling_gen/entities/national_id_utils.py +0 -102
  133. data_designer/engine/sampling_gen/entities/person.py +0 -144
  134. data_designer/engine/sampling_gen/entities/phone_number.py +0 -128
  135. data_designer/engine/sampling_gen/errors.py +0 -26
  136. data_designer/engine/sampling_gen/generator.py +0 -122
  137. data_designer/engine/sampling_gen/jinja_utils.py +0 -64
  138. data_designer/engine/sampling_gen/people_gen.py +0 -199
  139. data_designer/engine/sampling_gen/person_constants.py +0 -56
  140. data_designer/engine/sampling_gen/schema.py +0 -147
  141. data_designer/engine/sampling_gen/schema_builder.py +0 -61
  142. data_designer/engine/sampling_gen/utils.py +0 -46
  143. data_designer/engine/secret_resolver.py +0 -82
  144. data_designer/engine/validation.py +0 -367
  145. data_designer/engine/validators/__init__.py +0 -19
  146. data_designer/engine/validators/base.py +0 -38
  147. data_designer/engine/validators/local_callable.py +0 -39
  148. data_designer/engine/validators/python.py +0 -254
  149. data_designer/engine/validators/remote.py +0 -89
  150. data_designer/engine/validators/sql.py +0 -65
  151. data_designer/errors.py +0 -7
  152. data_designer/essentials/__init__.py +0 -33
  153. data_designer/lazy_heavy_imports.py +0 -54
  154. data_designer/logging.py +0 -163
  155. data_designer/plugin_manager.py +0 -78
  156. data_designer/plugins/__init__.py +0 -8
  157. data_designer/plugins/errors.py +0 -15
  158. data_designer/plugins/plugin.py +0 -141
  159. data_designer/plugins/registry.py +0 -88
  160. data_designer/plugins/testing/__init__.py +0 -10
  161. data_designer/plugins/testing/stubs.py +0 -116
  162. data_designer/plugins/testing/utils.py +0 -20
  163. data_designer-0.3.8rc2.dist-info/RECORD +0 -196
  164. data_designer-0.3.8rc2.dist-info/licenses/LICENSE +0 -201
  165. {data_designer-0.3.8rc2.dist-info → data_designer-0.4.0.dist-info}/WHEEL +0 -0
  166. {data_designer-0.3.8rc2.dist-info → data_designer-0.4.0.dist-info}/entry_points.txt +0 -0
@@ -1,122 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- from __future__ import annotations
5
-
6
- from collections.abc import Callable
7
- from typing import TYPE_CHECKING
8
-
9
- from data_designer.engine.sampling_gen.data_sources.base import RadomStateT
10
- from data_designer.engine.sampling_gen.errors import RejectionSamplingError
11
- from data_designer.engine.sampling_gen.jinja_utils import JinjaDataFrame
12
- from data_designer.engine.sampling_gen.people_gen import create_people_gen_resource
13
- from data_designer.engine.sampling_gen.schema import DataSchema
14
- from data_designer.engine.sampling_gen.utils import check_random_state
15
- from data_designer.lazy_heavy_imports import np, nx, pd
16
-
17
- if TYPE_CHECKING:
18
- import networkx as nx
19
- import numpy as np
20
- import pandas as pd
21
-
22
- from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig
23
- from data_designer.engine.resources.managed_dataset_generator import ManagedDatasetGenerator
24
- from data_designer.engine.sampling_gen.column import ConditionalDataColumn
25
-
26
-
27
class DatasetGenerator:
    """Generates synthetic datasets based on the given schema definition.

    This object generates synthetic data based on the schema using sampling-based
    methods (implemented as "data sources"), including handling conditional generation
    and enforcing constraints through rejection sampling.

    Args:
        sampler_columns: Sampler columns to generate. May be None only when
            ``schema`` is supplied.
        random_state: Random number generator or seed for reproducibility.
        person_generator_loader: A function that loads a person generator. If None,
            person generation will not be supported.
        schema: Pre-built schema. When given, it is used instead of building one
            from ``sampler_columns``.
        max_rejections_factor: Rejection budget multiplier. Only used together with
            ``schema``; otherwise the value stored on ``sampler_columns`` is used.

    Raises:
        ValueError: If neither ``sampler_columns`` nor ``schema`` is provided.

    Note:
        The generator leverages the schema's DAG to topologically sort the columns
        and uses rejection sampling to satisfy constraints. If constraints are too strict,
        generation may fail with a RejectionSamplingError.
    """

    def __init__(
        self,
        sampler_columns: SamplerMultiColumnConfig | None,
        random_state: RadomStateT | None = None,
        person_generator_loader: Callable[[bool], ManagedDatasetGenerator] | None = None,
        *,
        schema: DataSchema | None = None,
        max_rejections_factor: int = 5,
    ):
        # This is temporary while we need the legacy and refactored code to coexist.
        if schema is not None:
            self.schema = schema
            self.max_rejections_factor = max_rejections_factor
        elif sampler_columns is not None:
            self.schema = DataSchema(
                columns=[column.model_dump() for column in sampler_columns.columns],
                constraints=sampler_columns.constraints,
            )
            self.max_rejections_factor = sampler_columns.max_rejections_factor
        else:
            # Fail early with a clear message instead of an AttributeError on None.
            raise ValueError("Either `sampler_columns` or `schema` must be provided.")

        self.rng = check_random_state(random_state)
        self._dag = self.schema.dag.to_networkx()
        # Keyword arguments forwarded to every sampler this generator creates.
        self._shared_sampler_kwargs = {
            "random_state": self.rng,
            "people_gen_resource": create_people_gen_resource(self.schema, person_generator_loader),
        }

    def _round_if_needed(self, column: ConditionalDataColumn, df: pd.DataFrame) -> pd.DataFrame:
        """Round ``df[column.name]`` in place when the column params set ``decimal_places``."""
        decimal_places = getattr(column.params, "decimal_places", None)
        if decimal_places is not None:
            df[column.name] = df[column.name].round(decimal_places)
        return df

    def _run_rejection_sampling(self, df: pd.DataFrame, column: ConditionalDataColumn) -> pd.DataFrame:
        """Sample ``column`` into ``df``, resampling rows until all its constraints pass.

        Raises:
            RejectionSamplingError: If the number of sampling rounds exceeds
                ``max_rejections_factor * len(df)``.
        """
        name = column.name
        num_iterations = 0
        num_samples = len(df)
        max_iterations = self.max_rejections_factor * num_samples
        constraint_checkers = self.schema.get_constraint_checkers(name)
        # True for every row that still needs a (re)sampled value.
        needs_samples = np.ones(num_samples, dtype=bool)

        while needs_samples.any():
            for condition in column.conditions:
                # Only rows still pending are candidates for this condition's sampler.
                index = JinjaDataFrame(condition).select_index(df[needs_samples])
                src = column.get_sampler(condition, **self._shared_sampler_kwargs)
                df = src.inject_data_column(df, name, index)

            df[name] = column.get_default_sampler(**self._shared_sampler_kwargs).preproc(df[name], column.convert_to)

            # Check all constraints on the column.
            batch_mask = np.ones(num_samples, dtype=bool)
            for constraint in constraint_checkers:
                batch_mask &= constraint.check(df)
            needs_samples[batch_mask] = False
            num_iterations += 1

            if num_iterations > max_iterations:
                raise RejectionSamplingError(
                    "Exceeded the maximum number of rejections (max_rejections_factor * "
                    f"num_samples = {max_iterations}) while "
                    f"sampling `{column.name}`. Please consider adjusting the constraints "
                    "and/or column's generation configuration."
                )

        return df

    def generate(self, num_samples: int) -> pd.DataFrame:
        """Generate ``num_samples`` rows with the columns in schema order."""
        dataset = pd.DataFrame(index=range(num_samples))

        # Columns are sampled in dependency order.
        for column_name in nx.topological_sort(self._dag):
            column = self.schema.get_column(column_name)
            dataset = self._run_rejection_sampling(dataset, column)

        for column in self.schema.columns:
            dataset[column.name] = column.get_default_sampler(**self._shared_sampler_kwargs).postproc(
                dataset[column.name], column.convert_to
            )
            dataset = self._round_if_needed(column, dataset)

        return dataset[self.schema.column_names]
@@ -1,64 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- from __future__ import annotations
5
-
6
- import ast
7
- from typing import TYPE_CHECKING, Any
8
-
9
- from data_designer.engine.processing.ginja.environment import (
10
- UserTemplateSandboxEnvironment,
11
- WithJinja2UserTemplateRendering,
12
- )
13
- from data_designer.lazy_heavy_imports import pd
14
-
15
- if TYPE_CHECKING:
16
- import pandas as pd
17
-
18
-
19
class JinjaDataFrame(WithJinja2UserTemplateRendering):
    """Evaluates a single Jinja expression against the rows of a DataFrame.

    The expression is wrapped in ``{{ ... }}`` and rendered once per row via the
    rendering helpers this class calls on itself (presumably provided by
    ``WithJinja2UserTemplateRendering`` -- confirm in that mixin).
    """

    def __init__(self, expr: str):
        self.expr = expr

    def _jsonify(self, record) -> dict[str, Any]:
        """Convert ``pd.Timestamp`` values of ``record`` to ISO strings, in place."""
        timestamp_keys = [key for key, value in record.items() if isinstance(value, pd.Timestamp)]
        for key in timestamp_keys:
            record[key] = record[key].isoformat()
        return record

    @property
    def _expr(self) -> str:
        """The expression wrapped in Jinja output delimiters."""
        return " ".join(("{{", self.expr, "}}"))

    def select_index(self, dataframe: pd.DataFrame) -> pd.Index:
        """Return the index of the rows for which the expression renders as "True".

        An empty dataframe, or the literal expression "...", selects every row.
        """
        if not dataframe.empty and self.expr != "...":
            self.prepare_jinja2_template_renderer(self._expr, list(dataframe))

            def _row_matches(row) -> bool:
                return self.render_template(self._jsonify(row.to_dict())) == "True"

            where = dataframe.apply(_row_matches, axis=1).to_numpy()
            return dataframe[where].index

        return dataframe.index

    def to_column(self, dataframe: pd.DataFrame) -> list[Any]:
        """Render the expression for each row and return the values as a list.

        Rendered text that parses as a Python literal is converted to that literal;
        anything else is kept as the rendered string.
        """
        self.prepare_jinja2_template_renderer(self._expr, list(dataframe))

        def _literal_or_text(rendered: str) -> Any:
            try:
                # Non-string expressions are evaluated as literals.
                return ast.literal_eval(rendered)
            except (ValueError, SyntaxError):
                # Plain strings are not valid literals and are returned unchanged.
                return rendered

        return [
            _literal_or_text(self.render_template(self._jsonify(record)))
            for record in dataframe.to_dict(orient="records")
        ]
60
-
61
-
62
def extract_column_names_from_expression(expr: str) -> set[str]:
    """Extract valid column names from the given expression.

    The expression is wrapped in ``{{ ... }}`` and passed to a fresh
    ``UserTemplateSandboxEnvironment`` whose ``get_references`` result is returned.
    """
    sandbox = UserTemplateSandboxEnvironment()
    template = "{{ " + expr + " }}"
    return sandbox.get_references(template)
@@ -1,199 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- from __future__ import annotations
5
-
6
- import random
7
- import uuid
8
- from abc import ABC, abstractmethod
9
- from collections.abc import Callable
10
- from copy import deepcopy
11
- from typing import TYPE_CHECKING, Any, TypeAlias
12
-
13
- from data_designer.config.utils.constants import DEFAULT_AGE_RANGE
14
- from data_designer.engine.resources.managed_dataset_generator import ManagedDatasetGenerator
15
- from data_designer.engine.sampling_gen.entities.dataset_based_person_fields import PERSONA_FIELDS, PII_FIELDS
16
- from data_designer.engine.sampling_gen.entities.person import (
17
- convert_age_to_birth_date,
18
- generate_and_insert_derived_fields,
19
- )
20
- from data_designer.engine.sampling_gen.errors import ManagedDatasetGeneratorError
21
- from data_designer.engine.sampling_gen.person_constants import faker_constants
22
- from data_designer.lazy_heavy_imports import faker, pd
23
-
24
- if TYPE_CHECKING:
25
- import faker
26
- import pandas as pd
27
-
28
- from data_designer.engine.sampling_gen.schema import DataSchema
29
-
30
- EngineT: TypeAlias = faker.Faker | ManagedDatasetGenerator
31
-
32
-
33
class PeopleGen(ABC):
    """Unified interface for generating people data.

    Holds the locale and the underlying engine; subclasses implement ``generate``.
    """

    def __init__(self, engine: EngineT, locale: str):
        self._engine = engine
        self.locale = locale

    def set_engine(self, engine: EngineT) -> None:
        """Replace the engine used by this generator."""
        self._engine = engine

    @abstractmethod
    def generate(self, n: int, **kwargs) -> list[dict[str, Any]]:
        """Return ``n`` generated person records; ``kwargs`` are subclass-specific."""
45
-
46
-
47
class PeopleGenFaker(PeopleGen):
    """Person generator that builds each record from a ``faker.Faker`` engine.

    Fields the engine cannot provide are filled with None (see
    ``try_fake_else_none``). Random choices use the module-level ``random``
    functions.
    """

    @property
    def _fake(self) -> faker.Faker:
        return self._engine

    def try_fake_else_none(self, attr_name: str, none_fill: Any | None = None) -> Any:
        """Call the engine method ``attr_name``; return ``none_fill`` on AttributeError."""
        try:
            return getattr(self._fake, attr_name)()
        except AttributeError:
            return none_fill

    def _generate_name_and_sex(self, **kwargs) -> dict[str, str | None]:
        """Pick a sex (honoring a valid ``sex`` kwarg) and generate matching names."""
        options = faker_constants.sex
        # A single option is accepted either bare or wrapped in a one-element list.
        accepted = [*options, *[[o] for o in options]]
        if "sex" in kwargs and kwargs["sex"] in accepted:
            sex = random.choice(kwargs["sex"]) if isinstance(kwargs["sex"], list) else kwargs["sex"]
        else:
            sex = random.choice(options)

        suffix = sex.lower()
        return {
            "first_name": getattr(self._fake, f"first_name_{suffix}")(),
            "last_name": getattr(self._fake, f"last_name_{suffix}")(),
            "middle_name": None,
            "sex": sex,
        }

    def _generate_address_fields(self, **kwargs) -> dict[str, Any]:
        """Generate street and location fields; location values may be fixed via kwargs."""
        address = {
            "street_number": self.try_fake_else_none(faker_constants.attr_map["street_number"]),
            "street_name": self.try_fake_else_none("street_name"),
        }

        # Location fields can be filtered using the fixed_kwargs.
        for attr in faker_constants.location:
            if attr in kwargs:
                fixed = kwargs[attr]
                address[attr] = random.choice(fixed) if isinstance(fixed, list) else fixed
            else:
                address[attr] = self.try_fake_else_none(attr)

        return address

    def _generate_age(self, **kwargs) -> int:
        """Draw an age from the inclusive ``age_range`` kwarg (default ``DEFAULT_AGE_RANGE``)."""
        return random.randint(*kwargs.get("age_range", DEFAULT_AGE_RANGE))

    def _generate_marital_status(self, **kwargs) -> str:
        return random.choice(faker_constants.marital_status)

    def _generate_bachelors_field(self, **kwargs) -> str:
        return random.choice(faker_constants.bachelors)

    def _generate_education_level(self, **kwargs) -> str:
        return random.choice(faker_constants.education_level)

    def make_person(self, **kwargs) -> dict[str, Any]:
        """Build one person record as a dict.

        The statement order below fixes both the key order of the record and the
        order in which random values are drawn.
        """
        person = {"uuid": str(uuid.uuid4()), "locale": self.locale}
        person.update(self._generate_name_and_sex(**kwargs))
        person.update(self._generate_address_fields(**kwargs))
        person["age"] = self._generate_age(**kwargs)
        person["birth_date"] = convert_age_to_birth_date(person["age"]).isoformat()
        person["country"] = self.try_fake_else_none("country")
        person["marital_status"] = self._generate_marital_status(**kwargs)
        person["education_level"] = self._generate_education_level(**kwargs)
        person["unit"] = ""
        person["occupation"] = self.try_fake_else_none(faker_constants.attr_map["occupation"])
        # A phone number is only generated when the age is at least 18.
        person["phone_number"] = self.try_fake_else_none("phone_number") if person["age"] >= 18 else None
        if person["education_level"] in faker_constants.college_level:
            person["bachelors_field"] = self._generate_bachelors_field(**kwargs)
        else:
            person["bachelors_field"] = "no_degree"
        return person

    def generate(self, n: int, **kwargs) -> list[dict[str, Any]]:
        """Return ``n`` person records, each built by ``make_person(**kwargs)``."""
        return [self.make_person(**kwargs) for _ in range(n)]
119
-
120
-
121
class PeopleGenFromDataset(PeopleGen):
    """Person generator that samples records from the engine's ``generate_samples``."""

    def _get_ages(self, age_range: tuple[int, int]) -> list[int]:
        """Expand an inclusive ``(low, high)`` range into the list of ages."""
        low = age_range[0]
        high = age_range[1]
        return list(range(low, high + 1))

    def _generate_from_dataset(self, n: int, **kwargs) -> pd.DataFrame:
        """Sample ``n`` rows from the engine and keep only the known person fields."""
        # Work on a copy so the caller's kwargs are never modified.
        evidence = deepcopy(kwargs)
        include_personas = evidence.pop("with_synthetic_personas", False)
        evidence["age"] = self._get_ages(evidence.pop("age_range", DEFAULT_AGE_RANGE))

        # Generate samples and drop columns where all rows are null.
        samples = self._engine.generate_samples(size=n, evidence=evidence)
        df = samples.dropna(axis=1, how="all")

        # We need this for derived fields.
        df["locale"] = self.locale

        # Only keep columns that are listed in the schema.
        keep = [name for name in PII_FIELDS if name in df.columns]
        if include_personas:
            keep += [name for name in PERSONA_FIELDS if name in df.columns]

        return df[keep]

    def generate(self, n: int, **kwargs) -> list[dict[str, Any]]:
        """Return ``n`` sampled person records with derived fields inserted."""
        records = self._generate_from_dataset(n, **kwargs).to_dict(orient="records")
        return list(map(generate_and_insert_derived_fields, records))
148
-
149
-
150
def create_people_gen_resource(
    schema: DataSchema,
    person_generator_loader: Callable[[bool], ManagedDatasetGenerator] | None = None,
) -> dict[str, PeopleGen]:
    """Creates resource of unique people generators needed to generate the dataset.

    The resource is a dictionary of person generators, where the keys are the following:
        - {locale} for dataset-based person generators
        - {locale}_with_personas for dataset-based person generators with synthetic personas
        - {locale}_faker for faker-based person generators

    Args:
        schema: Schema of the dataset that we will generate.
        person_generator_loader: Function that loads a managed dataset generator.
            Required only when the schema contains "person" sampler columns.

    Returns:
        Dictionary of unique people generators needed to generate the dataset.

    Raises:
        ManagedDatasetGeneratorError: If a "person" column is present but no loader
            was provided, or if the loader fails for a locale.
    """
    people_gen_resource: dict[str, PeopleGen] = {}

    # ------------------------------------------------------------
    # Preload dataset-based person generators
    # ------------------------------------------------------------

    for column in schema.get_columns_by_sampler_type("person"):
        for params in [column.params, *column.conditional_params.values()]:
            if params.people_gen_key in people_gen_resource:
                continue

            if person_generator_loader is None:
                # Without this guard, calling None would surface as a misleading
                # "check if you have access" error below.
                raise ManagedDatasetGeneratorError(
                    f"🛑 Cannot create dataset-based person generator for locale {params.locale}: "
                    "no person generator loader was provided."
                )

            try:
                # NOTE(review): the loader is called with `locale=...`, which does not
                # match the `Callable[[bool], ...]` annotation -- confirm the intended signature.
                engine = person_generator_loader(locale=params.locale)
                people_gen_resource[params.people_gen_key] = PeopleGenFromDataset(
                    engine=engine, locale=params.locale
                )
            except Exception as e:
                raise ManagedDatasetGeneratorError(
                    f"🛑 Failed to load dataset-based person generator for locale {params.locale}. "
                    "Please check if you have access to person data for this locale. "
                ) from e

    # ------------------------------------------------------------
    # Preload faker-based person generators
    # ------------------------------------------------------------

    for column in schema.get_columns_by_sampler_type("person_from_faker"):
        for params in [column.params, *column.conditional_params.values()]:
            if params.people_gen_key not in people_gen_resource:
                people_gen_resource[params.people_gen_key] = PeopleGenFaker(
                    engine=faker.Faker(params.locale), locale=params.locale
                )

    return people_gen_resource
@@ -1,56 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- from __future__ import annotations
5
-
6
- from typing import NamedTuple
7
-
8
-
9
class FakerPersonData(NamedTuple):
    """Constant value lists and attribute-name mappings for faker-based person generation."""

    # Allowed values for a person's sex.
    sex: list[str] = [
        "Male",
        "Female",
    ]

    us_locale_only: list[str] = ["state", "county", "unit", "middle_name", "ethnic_background", "ssn"]

    # Address fields that can be fixed by the caller or generated.
    location: list[str] = [
        "city",
        "state",
        "postcode",
    ]

    bachelors: list[str] = ["stem", "business", "education", "arts_humanities", "stem_related"]

    education_level: list[str] = ["secondary_education", "some_college", "bachelors", "associates", "graduate", "doctorate"]

    marital_status: list[str] = ["married_present", "divorced", "never_married", "separated", "widowed"]

    # Subset of the education levels listed above.
    college_level: list[str] = [
        "bachelors",
        "graduate",
        "doctorate",
    ]

    # Maps a person field name to another attribute name.
    attr_map: dict[str, str] = {"street_number": "building_number", "occupation": "job"}


# Module-level instance holding the default values.
faker_constants = FakerPersonData()
@@ -1,147 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- from __future__ import annotations
5
-
6
- from functools import cached_property
7
- from typing import TYPE_CHECKING
8
-
9
- from pydantic import BaseModel, Field, field_validator, model_validator
10
- from typing_extensions import Self
11
-
12
- from data_designer.config.base import ConfigBase
13
- from data_designer.config.sampler_constraints import ColumnConstraintT
14
- from data_designer.config.sampler_params import SamplerType
15
- from data_designer.engine.sampling_gen.column import ConditionalDataColumn
16
- from data_designer.engine.sampling_gen.constraints import ConstraintChecker, get_constraint_checker
17
- from data_designer.lazy_heavy_imports import nx
18
-
19
- if TYPE_CHECKING:
20
- import networkx as nx
21
-
22
-
23
class Dag(BaseModel):
    """Node and edge sets of a directed graph, validated to be acyclic."""

    nodes: set[str]
    edges: set[tuple[str, str]]

    @model_validator(mode="after")
    def validate_is_dag(self) -> Self:
        """Reject node/edge sets whose graph contains a cycle."""
        graph = self.to_networkx()
        if nx.is_directed_acyclic_graph(graph):
            return self
        raise ValueError("There are circular dependencies in the definitions of your sampler columns.")

    def to_networkx(self) -> nx.DiGraph:
        """Build a ``networkx.DiGraph`` with this model's nodes and edges."""
        graph = nx.DiGraph()
        graph.add_nodes_from(self.nodes)
        graph.add_edges_from(self.edges)
        return graph
40
-
41
-
42
class DataSchema(ConfigBase):
    """Defines the data schema for synthetic data generation.

    A schema is a list of columns plus optional constraints. The validators below
    check that column names are unique, that constraints only reference existing
    columns, that column dependencies form a DAG, and that subcategory columns
    agree with their parent category columns.
    """

    columns: list[ConditionalDataColumn] = Field(..., min_length=1)
    constraints: list[ColumnConstraintT] = []

    @cached_property
    def constraint_checkers(self) -> list[ConstraintChecker]:
        """One checker instance per configured constraint, in constraint order."""
        checkers = []
        for constraint in self.constraints:
            checker_cls = get_constraint_checker(constraint.constraint_type)
            checkers.append(checker_cls(constraint=constraint))
        return checkers

    @property
    def column_names(self) -> list[str]:
        """Names of the columns, in declaration order."""
        return [col.name for col in self.columns]

    @property
    def dag(self) -> Dag:
        """Dependency graph of the columns; constructing the ``Dag`` validates acyclicity."""
        node_names = set()
        dependencies = set()

        for col in self.columns:
            node_names.add(col.name)

            # Add edges for the conditional columns.
            dependencies.update((parent, col.name) for parent in col.conditional_column_names)

            # Add edges if the source has required columns.
            for condition in col.conditions:
                sampler = col.get_sampler(condition)
                dependencies.update((required, col.name) for required in sampler.get_required_column_names())

        # A checker that requires exactly two columns adds an edge from the
        # second required column to the first.
        for checker in self.constraint_checkers:
            required_names = checker.get_required_column_names()
            if len(required_names) == 2:
                dependencies.add((required_names[1], required_names[0]))

        return Dag(nodes=node_names, edges=dependencies)

    @field_validator("columns", mode="after")
    def check_unique_column_names(cls, columns: list[ConditionalDataColumn]) -> list[ConditionalDataColumn]:
        """Reject column lists that contain a repeated name."""
        names = [col.name for col in columns]
        if len(set(names)) != len(names):
            raise ValueError("Column names must be unique")
        return columns

    @model_validator(mode="after")
    def validate_constraints(self) -> Self:
        """Check if all columns required by constraints are present in the schema."""
        known_names = {col.name for col in self.columns}

        for checker in self.constraint_checkers:
            missing = set(checker.get_required_column_names()) - known_names
            if missing:
                raise ValueError(
                    f"These constrained columns are missing in the definitions of your sampler columns: {missing}"
                )

        return self

    @model_validator(mode="after")
    def validate_dag(self) -> Self:
        """Build the DAG once so that a cyclic dependency fails schema validation."""
        _ = self.dag
        return self

    @model_validator(mode="after")
    def validate_subcategory_columns_if_present(self) -> Self:
        """Check each subcategory column against the values of its parent category column."""
        for sub in self.get_columns_by_sampler_type(SamplerType.SUBCATEGORY):
            cat = self.get_column(sub.params.category)
            if cat.sampler_type != SamplerType.CATEGORY:
                raise ValueError(
                    f"The parent of subcategory column '{sub.name}' must be a category "
                    f"source type, but '{cat.name}' is of type '{cat.sampler_type}'."
                )

            # Parent values include those of every conditional parameter set.
            cat_vals = set(cat.params.values)
            for conditional in cat.conditional_params.values():
                cat_vals.update(conditional.values)

            mismatched = cat_vals.symmetric_difference(set(sub.params.values.keys()))
            if mismatched:
                raise ValueError(
                    f"Subcategory column '{sub.name}' must have values for each value of "
                    f"its parent category '{sub.params.category}'. The following "
                    f"values need attention: {mismatched}"
                )

            if any(len(v) <= 0 for v in sub.params.values.values()):
                raise ValueError(
                    f"Subcategory column '{sub.name}' must have non-empty values for "
                    f"each value of its parent category '{sub.params.category}'."
                )
        return self

    def get_column(self, column_name: str) -> ConditionalDataColumn:
        """Return the column named ``column_name``; raise ValueError if it does not exist."""
        for col in self.columns:
            if col.name == column_name:
                return col
        raise ValueError(f"Column '{column_name}' not found in schema")

    def get_columns_by_sampler_type(self, sampler_type: SamplerType) -> list[ConditionalDataColumn]:
        """Return the columns whose ``sampler_type`` equals ``sampler_type``."""
        return [col for col in self.columns if col.sampler_type == sampler_type]

    def get_constraint_checkers(self, column_name: str) -> list[ConstraintChecker]:
        """Return the checkers whose constraint targets ``column_name``."""
        return [checker for checker in self.constraint_checkers if column_name == checker.constraint.target_column]
@@ -1,61 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- from __future__ import annotations
5
-
6
- from copy import deepcopy
7
-
8
- from data_designer.config.column_configs import SamplerColumnConfig
9
- from data_designer.config.sampler_constraints import ColumnConstraintT
10
- from data_designer.config.sampler_params import SamplerParamsT
11
- from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig
12
- from data_designer.engine.sampling_gen.column import ConditionalDataColumn
13
- from data_designer.engine.sampling_gen.schema import DataSchema
14
-
15
-
16
class SchemaBuilder:
    """Builder class for DataSchema objects.

    This class is meant to be a helper for internal usage and experimentation. It
    provides a simple interface for constructing a DataSchema object via the
    `add_column` and `add_constraint` methods.

    Args:
        columns: Initial columns. The list is copied, so later `add_column` calls
            do not modify the caller's list.
        constraints: Initial constraints. The list is copied for the same reason.
    """

    def __init__(
        self,
        columns: list[ConditionalDataColumn] | None = None,
        constraints: list[ColumnConstraintT] | None = None,
    ):
        # Shallow-copy so that the builder never mutates lists owned by the caller.
        self._columns = list(columns) if columns else []
        self._constraints = list(constraints) if constraints else []

    def add_column(
        self,
        name: str,
        sampler_type: str | None,
        params: dict | SamplerParamsT | None,
        conditional_params: dict[str, SamplerParamsT] | None = None,
        convert_to: str | None = None,
    ) -> None:
        """Append a new `ConditionalDataColumn` built from the given arguments."""
        self._columns.append(
            ConditionalDataColumn(
                name=name,
                sampler_type=sampler_type,
                params=params,
                conditional_params=conditional_params or {},
                convert_to=convert_to,
            )
        )

    def add_constraint(self, constraint: ColumnConstraintT) -> None:
        """Append a constraint to the builder."""
        self._constraints.append(constraint)

    def to_sampler_columns(self, max_rejections_factor: int = 5) -> SamplerMultiColumnConfig:
        """Convert the collected columns and constraints to a `SamplerMultiColumnConfig`."""
        return SamplerMultiColumnConfig(
            columns=[SamplerColumnConfig(**c.model_dump(mode="json")) for c in self._columns],
            constraints=self._constraints,
            max_rejections_factor=max_rejections_factor,
        )

    def build(self) -> DataSchema:
        """Build a `DataSchema` from deep copies of the collected columns and constraints."""
        return DataSchema(columns=deepcopy(self._columns), constraints=deepcopy(self._constraints))