data-designer 0.3.8rc1__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. data_designer/cli/commands/__init__.py +1 -1
  2. data_designer/interface/__init__.py +21 -1
  3. data_designer/{_version.py → interface/_version.py} +2 -2
  4. data_designer/interface/data_designer.py +8 -11
  5. {data_designer-0.3.8rc1.dist-info → data_designer-0.4.0.dist-info}/METADATA +10 -42
  6. data_designer-0.4.0.dist-info/RECORD +39 -0
  7. data_designer/__init__.py +0 -17
  8. data_designer/config/__init__.py +0 -2
  9. data_designer/config/analysis/__init__.py +0 -2
  10. data_designer/config/analysis/column_profilers.py +0 -159
  11. data_designer/config/analysis/column_statistics.py +0 -421
  12. data_designer/config/analysis/dataset_profiler.py +0 -84
  13. data_designer/config/analysis/utils/errors.py +0 -10
  14. data_designer/config/analysis/utils/reporting.py +0 -192
  15. data_designer/config/base.py +0 -69
  16. data_designer/config/column_configs.py +0 -470
  17. data_designer/config/column_types.py +0 -141
  18. data_designer/config/config_builder.py +0 -595
  19. data_designer/config/data_designer_config.py +0 -40
  20. data_designer/config/dataset_builders.py +0 -13
  21. data_designer/config/dataset_metadata.py +0 -18
  22. data_designer/config/default_model_settings.py +0 -121
  23. data_designer/config/errors.py +0 -24
  24. data_designer/config/exports.py +0 -145
  25. data_designer/config/interface.py +0 -55
  26. data_designer/config/models.py +0 -455
  27. data_designer/config/preview_results.py +0 -41
  28. data_designer/config/processors.py +0 -148
  29. data_designer/config/run_config.py +0 -48
  30. data_designer/config/sampler_constraints.py +0 -52
  31. data_designer/config/sampler_params.py +0 -639
  32. data_designer/config/seed.py +0 -116
  33. data_designer/config/seed_source.py +0 -84
  34. data_designer/config/seed_source_types.py +0 -19
  35. data_designer/config/utils/code_lang.py +0 -82
  36. data_designer/config/utils/constants.py +0 -363
  37. data_designer/config/utils/errors.py +0 -21
  38. data_designer/config/utils/info.py +0 -94
  39. data_designer/config/utils/io_helpers.py +0 -258
  40. data_designer/config/utils/misc.py +0 -78
  41. data_designer/config/utils/numerical_helpers.py +0 -30
  42. data_designer/config/utils/type_helpers.py +0 -106
  43. data_designer/config/utils/visualization.py +0 -482
  44. data_designer/config/validator_params.py +0 -94
  45. data_designer/engine/__init__.py +0 -2
  46. data_designer/engine/analysis/column_profilers/base.py +0 -49
  47. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +0 -153
  48. data_designer/engine/analysis/column_profilers/registry.py +0 -22
  49. data_designer/engine/analysis/column_statistics.py +0 -145
  50. data_designer/engine/analysis/dataset_profiler.py +0 -149
  51. data_designer/engine/analysis/errors.py +0 -9
  52. data_designer/engine/analysis/utils/column_statistics_calculations.py +0 -234
  53. data_designer/engine/analysis/utils/judge_score_processing.py +0 -132
  54. data_designer/engine/column_generators/__init__.py +0 -2
  55. data_designer/engine/column_generators/generators/__init__.py +0 -2
  56. data_designer/engine/column_generators/generators/base.py +0 -122
  57. data_designer/engine/column_generators/generators/embedding.py +0 -35
  58. data_designer/engine/column_generators/generators/expression.py +0 -55
  59. data_designer/engine/column_generators/generators/llm_completion.py +0 -113
  60. data_designer/engine/column_generators/generators/samplers.py +0 -69
  61. data_designer/engine/column_generators/generators/seed_dataset.py +0 -144
  62. data_designer/engine/column_generators/generators/validation.py +0 -140
  63. data_designer/engine/column_generators/registry.py +0 -60
  64. data_designer/engine/column_generators/utils/errors.py +0 -15
  65. data_designer/engine/column_generators/utils/generator_classification.py +0 -43
  66. data_designer/engine/column_generators/utils/judge_score_factory.py +0 -58
  67. data_designer/engine/column_generators/utils/prompt_renderer.py +0 -100
  68. data_designer/engine/compiler.py +0 -97
  69. data_designer/engine/configurable_task.py +0 -71
  70. data_designer/engine/dataset_builders/artifact_storage.py +0 -283
  71. data_designer/engine/dataset_builders/column_wise_builder.py +0 -338
  72. data_designer/engine/dataset_builders/errors.py +0 -15
  73. data_designer/engine/dataset_builders/multi_column_configs.py +0 -46
  74. data_designer/engine/dataset_builders/utils/__init__.py +0 -2
  75. data_designer/engine/dataset_builders/utils/concurrency.py +0 -215
  76. data_designer/engine/dataset_builders/utils/config_compiler.py +0 -62
  77. data_designer/engine/dataset_builders/utils/dag.py +0 -62
  78. data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +0 -200
  79. data_designer/engine/dataset_builders/utils/errors.py +0 -15
  80. data_designer/engine/errors.py +0 -51
  81. data_designer/engine/model_provider.py +0 -77
  82. data_designer/engine/models/__init__.py +0 -2
  83. data_designer/engine/models/errors.py +0 -300
  84. data_designer/engine/models/facade.py +0 -287
  85. data_designer/engine/models/factory.py +0 -42
  86. data_designer/engine/models/litellm_overrides.py +0 -179
  87. data_designer/engine/models/parsers/__init__.py +0 -2
  88. data_designer/engine/models/parsers/errors.py +0 -34
  89. data_designer/engine/models/parsers/parser.py +0 -235
  90. data_designer/engine/models/parsers/postprocessors.py +0 -93
  91. data_designer/engine/models/parsers/tag_parsers.py +0 -62
  92. data_designer/engine/models/parsers/types.py +0 -84
  93. data_designer/engine/models/recipes/base.py +0 -81
  94. data_designer/engine/models/recipes/response_recipes.py +0 -293
  95. data_designer/engine/models/registry.py +0 -146
  96. data_designer/engine/models/telemetry.py +0 -359
  97. data_designer/engine/models/usage.py +0 -73
  98. data_designer/engine/models/utils.py +0 -38
  99. data_designer/engine/processing/ginja/__init__.py +0 -2
  100. data_designer/engine/processing/ginja/ast.py +0 -65
  101. data_designer/engine/processing/ginja/environment.py +0 -463
  102. data_designer/engine/processing/ginja/exceptions.py +0 -56
  103. data_designer/engine/processing/ginja/record.py +0 -32
  104. data_designer/engine/processing/gsonschema/__init__.py +0 -2
  105. data_designer/engine/processing/gsonschema/exceptions.py +0 -15
  106. data_designer/engine/processing/gsonschema/schema_transformers.py +0 -83
  107. data_designer/engine/processing/gsonschema/types.py +0 -10
  108. data_designer/engine/processing/gsonschema/validators.py +0 -202
  109. data_designer/engine/processing/processors/base.py +0 -13
  110. data_designer/engine/processing/processors/drop_columns.py +0 -42
  111. data_designer/engine/processing/processors/registry.py +0 -25
  112. data_designer/engine/processing/processors/schema_transform.py +0 -49
  113. data_designer/engine/processing/utils.py +0 -169
  114. data_designer/engine/registry/base.py +0 -99
  115. data_designer/engine/registry/data_designer_registry.py +0 -39
  116. data_designer/engine/registry/errors.py +0 -12
  117. data_designer/engine/resources/managed_dataset_generator.py +0 -39
  118. data_designer/engine/resources/managed_dataset_repository.py +0 -197
  119. data_designer/engine/resources/managed_storage.py +0 -65
  120. data_designer/engine/resources/resource_provider.py +0 -77
  121. data_designer/engine/resources/seed_reader.py +0 -154
  122. data_designer/engine/sampling_gen/column.py +0 -91
  123. data_designer/engine/sampling_gen/constraints.py +0 -100
  124. data_designer/engine/sampling_gen/data_sources/base.py +0 -217
  125. data_designer/engine/sampling_gen/data_sources/errors.py +0 -12
  126. data_designer/engine/sampling_gen/data_sources/sources.py +0 -347
  127. data_designer/engine/sampling_gen/entities/__init__.py +0 -2
  128. data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
  129. data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +0 -86
  130. data_designer/engine/sampling_gen/entities/email_address_utils.py +0 -171
  131. data_designer/engine/sampling_gen/entities/errors.py +0 -10
  132. data_designer/engine/sampling_gen/entities/national_id_utils.py +0 -102
  133. data_designer/engine/sampling_gen/entities/person.py +0 -144
  134. data_designer/engine/sampling_gen/entities/phone_number.py +0 -128
  135. data_designer/engine/sampling_gen/errors.py +0 -26
  136. data_designer/engine/sampling_gen/generator.py +0 -122
  137. data_designer/engine/sampling_gen/jinja_utils.py +0 -64
  138. data_designer/engine/sampling_gen/people_gen.py +0 -199
  139. data_designer/engine/sampling_gen/person_constants.py +0 -56
  140. data_designer/engine/sampling_gen/schema.py +0 -147
  141. data_designer/engine/sampling_gen/schema_builder.py +0 -61
  142. data_designer/engine/sampling_gen/utils.py +0 -46
  143. data_designer/engine/secret_resolver.py +0 -82
  144. data_designer/engine/validation.py +0 -367
  145. data_designer/engine/validators/__init__.py +0 -19
  146. data_designer/engine/validators/base.py +0 -38
  147. data_designer/engine/validators/local_callable.py +0 -39
  148. data_designer/engine/validators/python.py +0 -254
  149. data_designer/engine/validators/remote.py +0 -89
  150. data_designer/engine/validators/sql.py +0 -65
  151. data_designer/errors.py +0 -7
  152. data_designer/essentials/__init__.py +0 -33
  153. data_designer/lazy_heavy_imports.py +0 -54
  154. data_designer/logging.py +0 -163
  155. data_designer/plugin_manager.py +0 -78
  156. data_designer/plugins/__init__.py +0 -8
  157. data_designer/plugins/errors.py +0 -15
  158. data_designer/plugins/plugin.py +0 -141
  159. data_designer/plugins/registry.py +0 -88
  160. data_designer/plugins/testing/__init__.py +0 -10
  161. data_designer/plugins/testing/stubs.py +0 -116
  162. data_designer/plugins/testing/utils.py +0 -20
  163. data_designer-0.3.8rc1.dist-info/RECORD +0 -196
  164. data_designer-0.3.8rc1.dist-info/licenses/LICENSE +0 -201
  165. {data_designer-0.3.8rc1.dist-info → data_designer-0.4.0.dist-info}/WHEEL +0 -0
  166. {data_designer-0.3.8rc1.dist-info → data_designer-0.4.0.dist-info}/entry_points.txt +0 -0

data_designer/engine/sampling_gen/constraints.py
@@ -1,100 +0,0 @@
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- # SPDX-License-Identifier: Apache-2.0
-
- from __future__ import annotations
-
- from abc import ABC, abstractmethod
- from typing import TYPE_CHECKING
-
- from numpy.typing import NDArray
-
- from data_designer.config.base import ConfigBase
- from data_designer.config.sampler_constraints import (
-     ColumnInequalityConstraint,
-     Constraint,
-     ConstraintType,
-     InequalityOperator,
-     ScalarInequalityConstraint,
- )
- from data_designer.lazy_heavy_imports import np, pd
-
- if TYPE_CHECKING:
-     import numpy as np
-     import pandas as pd
-
-
- class ConstraintChecker(ConfigBase, ABC):
-     constraint: Constraint
-
-     def get_required_column_names(self) -> tuple[str, ...]:
-         return (self.constraint.target_column,)
-
-     @abstractmethod
-     def check(self, dataframe: pd.DataFrame) -> NDArray[np.bool_]: ...
-
-
- class WithCompareMixin:
-     @property
-     def lhs(self) -> str:
-         return self.constraint.target_column
-
-     def compare(self, lhs: float | int | NDArray, rhs: float | int | NDArray) -> bool | NDArray[np.bool_]:
-         operator = {
-             InequalityOperator.LT: np.less,
-             InequalityOperator.LE: np.less_equal,
-             InequalityOperator.GT: np.greater,
-             InequalityOperator.GE: np.greater_equal,
-         }[InequalityOperator(self.constraint.operator)]
-         return operator(lhs, rhs)
-
-
- class ScalarInequalityChecker(ConstraintChecker, WithCompareMixin):
-     """Compare a column to a scalar value.
-
-     Args:
-         column_name: Name of the constrained column. Will be
-             used as the left-hand side (lhs) of the comparison.
-         operator: Comparison operator.
-         rhs: Scalar value to compare against.
-     """
-
-     constraint: ScalarInequalityConstraint
-
-     def check(self, dataframe: pd.DataFrame) -> NDArray[np.bool_]:
-         return self.compare(dataframe[self.lhs].values, self.constraint.rhs)
-
-
- class ColumnInequalityChecker(ConstraintChecker, WithCompareMixin):
-     """Compare the values of two columns.
-
-     Args:
-         column_name: Name of the constrained column. Will be
-             used as the left-hand side (lhs) of the comparison.
-         operator: Comparison operator.
-         rhs: Name of the column to compare against.
-     """
-
-     constraint: ColumnInequalityConstraint
-
-     def get_required_column_names(self) -> tuple[str, ...]:
-         """Return the names of columns required for the constraint.
-
-         Note that order matters. Edges in the DAG are created as column_names[1], column_names[0].
-         """
-         return (self.lhs, self.constraint.rhs)
-
-     def check(self, dataframe: pd.DataFrame) -> NDArray[np.bool_]:
-         return self.compare(
-             dataframe[self.lhs].values,
-             dataframe[self.constraint.rhs].values,
-         )
-
-
- CONSTRAINT_TYPE_TO_CHECKER = {
-     ConstraintType.SCALAR_INEQUALITY: ScalarInequalityChecker,
-     ConstraintType.COLUMN_INEQUALITY: ColumnInequalityChecker,
- }
-
-
- def get_constraint_checker(constraint_type: ConstraintType) -> type[ConstraintChecker]:
-     return CONSTRAINT_TYPE_TO_CHECKER[ConstraintType(constraint_type)]
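
Aside: the removed `WithCompareMixin.compare` maps an `InequalityOperator` enum value onto the corresponding vectorized NumPy ufunc, so an entire column is checked against a constraint in one call. A minimal, self-contained sketch of that dispatch pattern (`Op` and `check_scalar_inequality` are illustrative names, not part of the data_designer API):

from enum import Enum

import numpy as np
import pandas as pd


class Op(str, Enum):  # stands in for InequalityOperator
    LT = "lt"
    LE = "le"
    GT = "gt"
    GE = "ge"


# Each enum member dispatches to a vectorized NumPy comparison ufunc.
_UFUNCS = {Op.LT: np.less, Op.LE: np.less_equal, Op.GT: np.greater, Op.GE: np.greater_equal}


def check_scalar_inequality(df: pd.DataFrame, column: str, op: Op, rhs: float) -> np.ndarray:
    # Returns a boolean mask: True where `df[column] <op> rhs` holds row-wise.
    return _UFUNCS[op](df[column].to_numpy(), rhs)


df = pd.DataFrame({"age": [15, 30, 42]})
print(check_scalar_inequality(df, "age", Op.GE, 18))  # [False  True  True]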

data_designer/engine/sampling_gen/data_sources/base.py
@@ -1,217 +0,0 @@
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- # SPDX-License-Identifier: Apache-2.0
-
- from __future__ import annotations
-
- from abc import ABC, abstractmethod
- from typing import TYPE_CHECKING, Any, Generic, TypeVar
-
- from numpy.typing import NDArray
-
- from data_designer.config.sampler_params import SamplerParamsT
- from data_designer.engine.sampling_gen.utils import check_random_state
- from data_designer.lazy_heavy_imports import np, pd, scipy
-
- if TYPE_CHECKING:
-     import numpy as np
-     import pandas as pd
-     import scipy
-
- NumpyArray1dT = NDArray[Any]
- RadomStateT = int | np.random.RandomState
-
- GenericParamsT = TypeVar("GenericParamsT", bound=SamplerParamsT)
-
- ###########################################################
- # Processing Mixins
- # -----------------
- # These mixins are used to apply pre and post processing
- # to the data source output. At the moment, the only
- # processing that is applied is an optional type/format
- # conversion of the output data.
- #
- # Preprocessing: Applied *before* constraints are applied.
- # Postprocessing: Applied at the end of dataset generation.
- #
- # IMPORTANT: These are only applied when the data are
- # being injected into a DataFrame by the DatasetGenerator.
- ###########################################################
-
-
- class PassthroughMixin:
-     @staticmethod
-     def preproc(series: pd.Series, convert_to: str) -> pd.Series:
-         return series
-
-     @staticmethod
-     def postproc(series: pd.Series, convert_to: str) -> pd.Series:
-         return series
-
-     @staticmethod
-     def validate_data_conversion(convert_to: str | None) -> None:
-         pass
-
-
- class TypeConversionMixin:
-     """Converts the data type of the output data.
-
-     This mixin applies the same conversion to both the pre and post
-     processing steps. The preprocessing is needed to ensure constraints
-     are applied to the correct data type. The postprocessing is needed
-     to ensure the final dtype is correct. For example, if the user wants an
-     `int`, we need to convert to `int` before applying constraints, but
-     the ints will be converted back to floats when injected into the
-     dataframe (assuming some rows are non-int values). We therefore need
-     to convert back to `int` after all constraints have been applied.
-     """
-
-     @staticmethod
-     def preproc(series: pd.Series, convert_to: str) -> pd.Series:
-         if convert_to is not None:
-             if convert_to == "int":
-                 series = series.round()
-             return series.astype(convert_to)
-         return series
-
-     @staticmethod
-     def postproc(series: pd.Series, convert_to: str | None) -> pd.Series:
-         if convert_to is not None:
-             if convert_to == "int":
-                 series = series.round()
-             return series.astype(convert_to)
-         return series
-
-     @staticmethod
-     def validate_data_conversion(convert_to: str | None) -> None:
-         if convert_to is not None and convert_to not in ["float", "int", "str"]:
-             raise ValueError(f"Invalid `convert_to` value: {convert_to}. Must be one of: [float, int, str]")
-
-
- class DatetimeFormatMixin:
-     @staticmethod
-     def preproc(series: pd.Series, convert_to: str | None) -> pd.Series:
-         return series
-
-     @staticmethod
-     def postproc(series: pd.Series, convert_to: str | None) -> pd.Series:
-         if convert_to is not None:
-             return series.dt.strftime(convert_to)
-         if series.dt.month.nunique() == 1:
-             return series.apply(lambda dt: dt.year).astype(str)
-         if series.dt.day.nunique() == 1:
-             return series.apply(lambda dt: dt.strftime("%Y-%m"))
-         if series.dt.hour.sum() > 0 or series.dt.minute.sum() > 0:
-             return series.apply(lambda dt: dt.isoformat()).astype(str)
-         if series.dt.second.sum() == 0:
-             return series.apply(lambda dt: dt.date()).astype(str)
-         return series.apply(lambda dt: dt.isoformat()).astype(str)
-
-     @staticmethod
-     def validate_data_conversion(convert_to: str | None) -> None:
-         if convert_to is not None:
-             try:
-                 pd.to_datetime(pd.to_datetime("2012-12-21").strftime(convert_to))
-             except Exception as e:
-                 raise ValueError(f"Invalid datetime format: {convert_to}. {e}")
-
-
- ###########################################################
- # Base Data Source Classes
- ###########################################################
-
-
- class DataSource(ABC, Generic[GenericParamsT]):
-     def __init__(
-         self,
-         params: GenericParamsT,
-         random_state: RadomStateT | None = None,
-         **kwargs,
-     ):
-         self.rng = check_random_state(random_state)
-         self.params = self.get_param_type().model_validate(params)
-         self._setup(**kwargs)
-         self._validate()
-
-     @classmethod
-     def get_param_type(cls) -> type[GenericParamsT]:
-         return cls.__orig_bases__[-1].__args__[0]
-
-     @abstractmethod
-     def inject_data_column(
-         self,
-         dataframe: pd.DataFrame,
-         column_name: str,
-         index: list[int] | None = None,
-     ) -> pd.DataFrame: ...
-
-     @staticmethod
-     @abstractmethod
-     def preproc(series: pd.Series) -> pd.Series: ...
-
-     @staticmethod
-     @abstractmethod
-     def postproc(series: pd.Series, convert_to: str | None) -> pd.Series: ...
-
-     @staticmethod
-     @abstractmethod
-     def validate_data_conversion(convert_to: str | None) -> None: ...
-
-     def get_required_column_names(self) -> tuple[str, ...]:
-         return tuple()
-
-     def _setup(self, **kwargs) -> None:
-         pass
-
-     def _validate(self) -> None:
-         pass
-
-
- class Sampler(DataSource[GenericParamsT], ABC):
-     def _recast_types_if_needed(
-         self,
-         index: list[int] | slice,
-         column_name: str,
-         sample: NumpyArray1dT,
-         dataframe: pd.DataFrame,
-     ) -> pd.DataFrame:
-         # Type may be different if the column has mixed types / NaNs.
-         if column_name in dataframe.columns:
-             dtype = sample.dtype
-             if dtype != dataframe.loc[index, column_name].dtype:
-                 dataframe = dataframe.astype({column_name: dtype}, errors="ignore")
-         return dataframe
-
-     def inject_data_column(
-         self,
-         dataframe: pd.DataFrame,
-         column_name: str,
-         index: list[int] | None = None,
-     ) -> pd.DataFrame:
-         index = slice(None) if index is None else index
-
-         if len(index) == 0:
-             return dataframe
-
-         sample = self.sample(len(index))
-
-         # Try recasting before assigning the sample to the dataframe, since setting an item
-         # of incompatible dtype is deprecated and will raise an error in future versions.
-         dataframe = self._recast_types_if_needed(index, column_name, sample, dataframe)
-         dataframe.loc[index, column_name] = sample
-
-         # Recast again in case the assignment led to inconsistencies (e.g., funny business from NaNs).
-         dataframe = self._recast_types_if_needed(index, column_name, sample, dataframe)
-
-         return dataframe
-
-     @abstractmethod
-     def sample(self, num_samples: int) -> NumpyArray1dT: ...
-
-
- class ScipyStatsSampler(Sampler[GenericParamsT], ABC):
-     @property
-     @abstractmethod
-     def distribution(self) -> scipy.stats.rv_continuous | scipy.stats.rv_discrete: ...
-
-     def sample(self, num_samples: int) -> NumpyArray1dT:
-         return self.distribution.rvs(size=num_samples, random_state=self.rng)
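
Aside: `DataSource.get_param_type` above recovers the concrete params class from the subclass's generic subscript via `__orig_bases__`, which is what lets the registry in sources.py key samplers by params type without an explicit registration argument. A standalone sketch of that introspection trick (the names `Source`, `GaussianParams`, and `GaussianSource` are illustrative stand-ins, not the removed API):

from typing import Generic, TypeVar

ParamsT = TypeVar("ParamsT")


class Source(Generic[ParamsT]):
    @classmethod
    def get_param_type(cls) -> type:
        # For a subclass declared as `class GaussianSource(Source[GaussianParams])`,
        # the last entry of __orig_bases__ is the parameterized generic
        # Source[GaussianParams], and __args__[0] is the concrete params class.
        return cls.__orig_bases__[-1].__args__[0]


class GaussianParams: ...


class GaussianSource(Source[GaussianParams]): ...


assert GaussianSource.get_param_type() is GaussianParams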

data_designer/engine/sampling_gen/data_sources/errors.py
@@ -1,12 +0,0 @@
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- # SPDX-License-Identifier: Apache-2.0
-
- from __future__ import annotations
-
- from data_designer.engine.sampling_gen.errors import SamplingGenError
-
-
- class InvalidSamplerParamsError(SamplingGenError): ...
-
-
- class PersonSamplerConstraintsError(SamplingGenError): ...

data_designer/engine/sampling_gen/data_sources/sources.py
@@ -1,347 +0,0 @@
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- # SPDX-License-Identifier: Apache-2.0
-
- from __future__ import annotations
-
- import uuid
- from typing import TYPE_CHECKING
-
- from data_designer.config.sampler_params import (
-     BernoulliMixtureSamplerParams,
-     BernoulliSamplerParams,
-     BinomialSamplerParams,
-     CategorySamplerParams,
-     DatetimeSamplerParams,
-     GaussianSamplerParams,
-     PersonFromFakerSamplerParams,
-     PersonSamplerParams,
-     PoissonSamplerParams,
-     SamplerParamsT,
-     SamplerType,
-     ScipySamplerParams,
-     SubcategorySamplerParams,
-     TimeDeltaSamplerParams,
-     UniformSamplerParams,
-     UUIDSamplerParams,
- )
- from data_designer.engine.sampling_gen.data_sources.base import (
-     DataSource,
-     DatetimeFormatMixin,
-     NumpyArray1dT,
-     PassthroughMixin,
-     Sampler,
-     ScipyStatsSampler,
-     TypeConversionMixin,
- )
- from data_designer.engine.sampling_gen.data_sources.errors import (
-     InvalidSamplerParamsError,
-     PersonSamplerConstraintsError,
- )
- from data_designer.engine.sampling_gen.entities.dataset_based_person_fields import PERSONA_FIELDS, PII_FIELDS
- from data_designer.engine.sampling_gen.people_gen import PeopleGen
- from data_designer.lazy_heavy_imports import np, pd, scipy
-
- if TYPE_CHECKING:
-     import numpy as np
-     import pandas as pd
-     import scipy
-
- ONE_BILLION = 10**9
-
-
- class SamplerRegistry:
-     _registry: dict[str, type] = {}
-     _reverse_registry: dict[type, str] = {}
-     _params_registry: dict[type, type] = {}
-
-     @classmethod
-     def register(cls, alias: str):
-         def decorator(wrapped_class: type[DataSource[SamplerParamsT]]) -> type:
-             cls._registry[alias] = wrapped_class
-             cls._reverse_registry[wrapped_class] = alias
-             cls._params_registry[wrapped_class.get_param_type()] = wrapped_class
-             return wrapped_class
-
-         return decorator
-
-     @classmethod
-     def get_sampler(cls, alias: str) -> type[DataSource[SamplerParamsT]]:
-         return cls._registry[alias.lower()]
-
-     @classmethod
-     def get_sampler_for_params(cls, params_type: SamplerParamsT) -> type[DataSource[SamplerParamsT]]:
-         return cls._params_registry[type(params_type)]
-
-     @classmethod
-     def get_sampler_alias_for_params(cls, params_type: SamplerParamsT) -> str:
-         return cls._reverse_registry[cls._params_registry[type(params_type)]]
-
-     @classmethod
-     def is_registered(cls, alias: str) -> bool:
-         return alias in cls._registry
-
-     @classmethod
-     def validate_sampler_type(
-         cls, sampler_type: str | type[DataSource[SamplerParamsT]]
-     ) -> type[DataSource[SamplerParamsT]]:
-         if isinstance(sampler_type, str):
-             if sampler_type not in cls._registry:
-                 raise ValueError(
-                     f"Sampler type `{sampler_type}` not found in the registry. "
-                     f"Available samplers: {list(cls._registry.keys())}"
-                 )
-             sampler_type = cls.get_sampler(sampler_type)
-         if not issubclass(sampler_type, DataSource):
-             raise ValueError(f"Sampler type `{sampler_type}` is not a subclass of `DataSource`")
-         return sampler_type
-
-
- #########################################
- # Data Source Subclasses
- #########################################
-
-
- @SamplerRegistry.register(SamplerType.SUBCATEGORY)
- class SubcategorySampler(TypeConversionMixin, DataSource[SubcategorySamplerParams]):
-     def get_required_column_names(self) -> tuple[str, ...]:
-         return (self.params.category,)
-
-     def inject_data_column(
-         self,
-         dataframe: pd.DataFrame,
-         column_name: str,
-         index: list[int] | None = None,
-     ) -> pd.DataFrame:
-         index = slice(None) if index is None else index
-
-         if len(index) == 0:
-             return dataframe
-
-         dataframe.loc[index, column_name] = dataframe.loc[index, self.params.category].apply(
-             lambda cat_value: self.rng.choice(self.params.values[cat_value])
-         )
-
-         return dataframe
-
-
- #########################################
- # Sampler Subclasses
- #########################################
-
-
- @SamplerRegistry.register(SamplerType.CATEGORY)
- class CategorySampler(TypeConversionMixin, Sampler[CategorySamplerParams]):
-     def sample(self, num_samples: int) -> NumpyArray1dT:
-         return self.rng.choice(self.params.values, size=num_samples, p=self.params.weights)
-
-
- @SamplerRegistry.register(SamplerType.DATETIME)
- class DatetimeSampler(DatetimeFormatMixin, Sampler[DatetimeSamplerParams]):
-     def sample(self, num_samples: int) -> NumpyArray1dT:
-         # Convert nanoseconds to seconds.
-         start_sec = pd.to_datetime(self.params.start).value // ONE_BILLION
-         end_sec = pd.to_datetime(self.params.end).value // ONE_BILLION
-
-         random_ns = (ONE_BILLION * self.rng.randint(start_sec, end_sec, num_samples, dtype=np.int64)).view(
-             "datetime64[ns]"
-         )
-
-         return np.array(random_ns, dtype=f"datetime64[{self.params.unit}]")
-
-
- @SamplerRegistry.register(SamplerType.PERSON)
- class PersonSampler(PassthroughMixin, Sampler[PersonSamplerParams]):
-     def _setup(self, **kwargs) -> None:
-         self._generator = None
-         self._fixed_kwargs = {}
-         for field in self.params.generator_kwargs:
-             if getattr(self.params, field) is not None:
-                 attr = getattr(self.params, field)
-                 if field == "select_field_values":
-                     for key, value in attr.items():
-                         if key == "state" and self.params.locale == "en_US":
-                             key = "region"  # This is the field name in the census-based person dataset.
-                         if key not in PII_FIELDS + PERSONA_FIELDS:
-                             raise ValueError(f"Invalid field name: {key}")
-                         self._fixed_kwargs[key] = value
-                 else:
-                     self._fixed_kwargs[field] = attr
-         if people_gen_resource := kwargs.get("people_gen_resource"):
-             if self.params.people_gen_key not in people_gen_resource:
-                 raise ValueError(f"Person generator with key {self.params.people_gen_key} not found.")
-             self.set_generator(people_gen_resource[self.params.people_gen_key])
-
-     def set_generator(self, generator: PeopleGen) -> None:
-         self._generator = generator
-
-     def sample(self, num_samples: int) -> NumpyArray1dT:
-         if self._generator is None:
-             raise ValueError("Generator not set. Please setup generator before sampling.")
-
-         samples = np.array(self._generator.generate(num_samples, **self._fixed_kwargs))
-         if len(samples) < num_samples:
-             raise PersonSamplerConstraintsError(
-                 f"🛑 Only {len(samples)} samples could be generated with the given settings: {self._fixed_kwargs!r}. "
-                 "This is likely because the filter values are too strict. Person sampling does not support "
-                 "rare combinations of field values. Please loosen the constraints and try again."
-             )
-         return samples
-
-
- @SamplerRegistry.register(SamplerType.PERSON_FROM_FAKER)
- class PersonFromFakerSampler(PassthroughMixin, Sampler[PersonFromFakerSamplerParams]):
-     def _setup(self, **kwargs) -> None:
-         self._generator = None
-         self._fixed_kwargs = {}
-         for field in self.params.generator_kwargs:
-             if getattr(self.params, field) is not None:
-                 self._fixed_kwargs[field] = getattr(self.params, field)
-         if people_gen_resource := kwargs.get("people_gen_resource"):
-             if self.params.people_gen_key not in people_gen_resource:
-                 raise ValueError(f"Person generator with key {self.params.people_gen_key} not found.")
-             self.set_generator(people_gen_resource[self.params.people_gen_key])
-
-     def set_generator(self, generator: PeopleGen) -> None:
-         self._generator = generator
-
-     def sample(self, num_samples: int) -> NumpyArray1dT:
-         if self._generator is None:
-             raise ValueError("Generator not set. Please setup generator before sampling.")
-
-         samples = np.array(self._generator.generate(num_samples, **self._fixed_kwargs))
-         if len(samples) < num_samples:
-             raise ValueError(f"Only {len(samples)} samples could be generated given constraints {self._fixed_kwargs}.")
-         return samples
-
-
- @SamplerRegistry.register(SamplerType.TIMEDELTA)
- class TimeDeltaSampler(DatetimeFormatMixin, Sampler[TimeDeltaSamplerParams]):
-     def get_required_column_names(self) -> tuple[str, ...]:
-         return (self.params.reference_column_name,)
-
-     def inject_data_column(
-         self,
-         dataframe: pd.DataFrame,
-         column_name: str,
-         index: list[int] | None = None,
-     ) -> pd.DataFrame:
-         index = slice(None) if index is None else index
-
-         if self.params.reference_column_name not in list(dataframe):
-             raise ValueError(f"Columns `{self.params.reference_column_name}` not found in dataset")
-
-         dataframe.loc[index, column_name] = pd.to_datetime(
-             dataframe.loc[index, self.params.reference_column_name]
-         ) + pd.to_timedelta(self.sample(len(index)), unit=self.params.unit)
-
-         return dataframe
-
-     def sample(self, num_samples: int) -> NumpyArray1dT:
-         deltas = self.rng.randint(self.params.dt_min, self.params.dt_max, num_samples)
-         return np.array(deltas, dtype=f"timedelta64[{self.params.unit}]")
-
-
- @SamplerRegistry.register(SamplerType.UUID)
- class UUIDSampler(PassthroughMixin, Sampler[UUIDSamplerParams]):
-     def sample(self, num_samples: int) -> NumpyArray1dT:
-         prefix = self.params.prefix or ""
-
-         uid_list = []
-         while len(uid_list) < num_samples:
-             uid = (
-                 f"{prefix}{uuid.uuid4().hex[: self.params.last_index].upper()}"
-                 if self.params.uppercase
-                 else f"{prefix}{uuid.uuid4().hex[: self.params.last_index]}"
-             )
-             if uid not in uid_list:
-                 uid_list.append(uid)
-
-         return np.array(uid_list)
-
-
- #########################################
- # Scipy Samplers
- #########################################
-
-
- @SamplerRegistry.register(SamplerType.SCIPY)
- class ScipySampler(TypeConversionMixin, ScipyStatsSampler[ScipySamplerParams]):
-     """Escape hatch sampler to give users access to any scipy.stats distribution."""
-
-     @property
-     def distribution(self) -> scipy.stats.rv_continuous | scipy.stats.rv_discrete:
-         return getattr(scipy.stats, self.params.dist_name)(**self.params.dist_params)
-
-     def _validate(self) -> None:
-         _validate_scipy_distribution(self.params.dist_name, self.params.dist_params)
-
-
- @SamplerRegistry.register(SamplerType.BERNOULLI)
- class BernoulliSampler(TypeConversionMixin, ScipyStatsSampler[BernoulliSamplerParams]):
-     @property
-     def distribution(self) -> scipy.stats.rv_discrete:
-         return scipy.stats.bernoulli(p=self.params.p)
-
-
- @SamplerRegistry.register(SamplerType.BERNOULLI_MIXTURE)
- class BernoulliMixtureSampler(TypeConversionMixin, Sampler[BernoulliMixtureSamplerParams]):
-     def sample(self, num_samples: int) -> NumpyArray1dT:
-         return scipy.stats.bernoulli(p=self.params.p).rvs(size=num_samples) * getattr(
-             scipy.stats, self.params.dist_name
-         )(**self.params.dist_params).rvs(size=num_samples)
-
-     def _validate(self) -> None:
-         _validate_scipy_distribution(self.params.dist_name, self.params.dist_params)
-
-
- @SamplerRegistry.register(SamplerType.BINOMIAL)
- class BinomialSampler(TypeConversionMixin, ScipyStatsSampler[BinomialSamplerParams]):
-     @property
-     def distribution(self) -> scipy.stats.rv_discrete:
-         return scipy.stats.binom(n=self.params.n, p=self.params.p)
-
-
- @SamplerRegistry.register(SamplerType.GAUSSIAN)
- class GaussianSampler(TypeConversionMixin, ScipyStatsSampler[GaussianSamplerParams]):
-     @property
-     def distribution(self) -> scipy.stats.rv_continuous:
-         return scipy.stats.norm(loc=self.params.mean, scale=self.params.stddev)
-
-
- @SamplerRegistry.register(SamplerType.POISSON)
- class PoissonSampler(TypeConversionMixin, ScipyStatsSampler[PoissonSamplerParams]):
-     @property
-     def distribution(self) -> scipy.stats.rv_discrete:
-         return scipy.stats.poisson(mu=self.params.mean)
-
-
- @SamplerRegistry.register(SamplerType.UNIFORM)
- class UniformSampler(TypeConversionMixin, ScipyStatsSampler[UniformSamplerParams]):
-     @property
-     def distribution(self) -> scipy.stats.rv_continuous:
-         return scipy.stats.uniform(loc=self.params.low, scale=self.params.high - self.params.low)
-
-
- ###################################################
- # Helper functions for loading sources in isolation
- ###################################################
-
-
- def load_sampler(sampler_type: SamplerType, **params) -> DataSource:
-     """Load a data source from a source type and parameters."""
-     return SamplerRegistry.validate_sampler_type(sampler_type)(params=params)
-
-
- def _validate_scipy_distribution(dist_name: str, dist_params: dict) -> None:
-     if not hasattr(scipy.stats, dist_name):
-         raise InvalidSamplerParamsError(f"Distribution {dist_name} not found in scipy.stats")
-     if not hasattr(getattr(scipy.stats, dist_name), "rvs"):
-         raise InvalidSamplerParamsError(
-             f"Distribution {dist_name} does not have a `rvs` method, which is required for sampling."
-         )
-     try:
-         getattr(scipy.stats, dist_name)(**dist_params)
-     except Exception:
-         raise InvalidSamplerParamsError(
-             f"Distribution parameters {dist_params} are not a valid for distribution '{dist_name}'"
-         )
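
Aside: `SamplerRegistry.register` above is a class decorator that populates three lookup tables at import time, so a sampler class can be resolved by string alias, by class, or by the type of its params object. A reduced sketch of the pattern (`Registry`, `UniformParams`, and `UniformSampler` are illustrative names; the real code derives the params type via `get_param_type()` rather than taking it as a decorator argument):

class Registry:
    _by_alias: dict[str, type] = {}
    _by_params: dict[type, type] = {}

    @classmethod
    def register(cls, alias: str, params_type: type):
        # Returns a class decorator that records the sampler under both keys.
        def decorator(sampler_cls: type) -> type:
            cls._by_alias[alias] = sampler_cls
            cls._by_params[params_type] = sampler_cls
            return sampler_cls

        return decorator


class UniformParams: ...


@Registry.register("uniform", UniformParams)
class UniformSampler: ...


assert Registry._by_alias["uniform"] is UniformSampler
assert Registry._by_params[UniformParams] is UniformSampler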

data_designer/engine/sampling_gen/entities/__init__.py
@@ -1,2 +0,0 @@
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- # SPDX-License-Identifier: Apache-2.0