data-designer 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177)
  1. data_designer/__init__.py +15 -0
  2. data_designer/_version.py +34 -0
  3. data_designer/cli/README.md +236 -0
  4. data_designer/cli/__init__.py +6 -0
  5. data_designer/cli/commands/__init__.py +2 -0
  6. data_designer/cli/commands/list.py +130 -0
  7. data_designer/cli/commands/models.py +10 -0
  8. data_designer/cli/commands/providers.py +11 -0
  9. data_designer/cli/commands/reset.py +100 -0
  10. data_designer/cli/controllers/__init__.py +7 -0
  11. data_designer/cli/controllers/model_controller.py +246 -0
  12. data_designer/cli/controllers/provider_controller.py +317 -0
  13. data_designer/cli/forms/__init__.py +20 -0
  14. data_designer/cli/forms/builder.py +51 -0
  15. data_designer/cli/forms/field.py +180 -0
  16. data_designer/cli/forms/form.py +59 -0
  17. data_designer/cli/forms/model_builder.py +125 -0
  18. data_designer/cli/forms/provider_builder.py +76 -0
  19. data_designer/cli/main.py +44 -0
  20. data_designer/cli/repositories/__init__.py +8 -0
  21. data_designer/cli/repositories/base.py +39 -0
  22. data_designer/cli/repositories/model_repository.py +42 -0
  23. data_designer/cli/repositories/provider_repository.py +43 -0
  24. data_designer/cli/services/__init__.py +7 -0
  25. data_designer/cli/services/model_service.py +116 -0
  26. data_designer/cli/services/provider_service.py +111 -0
  27. data_designer/cli/ui.py +448 -0
  28. data_designer/cli/utils.py +47 -0
  29. data_designer/config/__init__.py +2 -0
  30. data_designer/config/analysis/column_profilers.py +89 -0
  31. data_designer/config/analysis/column_statistics.py +274 -0
  32. data_designer/config/analysis/dataset_profiler.py +60 -0
  33. data_designer/config/analysis/utils/errors.py +8 -0
  34. data_designer/config/analysis/utils/reporting.py +188 -0
  35. data_designer/config/base.py +68 -0
  36. data_designer/config/column_configs.py +354 -0
  37. data_designer/config/column_types.py +168 -0
  38. data_designer/config/config_builder.py +660 -0
  39. data_designer/config/data_designer_config.py +40 -0
  40. data_designer/config/dataset_builders.py +11 -0
  41. data_designer/config/datastore.py +151 -0
  42. data_designer/config/default_model_settings.py +123 -0
  43. data_designer/config/errors.py +19 -0
  44. data_designer/config/interface.py +54 -0
  45. data_designer/config/models.py +231 -0
  46. data_designer/config/preview_results.py +32 -0
  47. data_designer/config/processors.py +41 -0
  48. data_designer/config/sampler_constraints.py +51 -0
  49. data_designer/config/sampler_params.py +604 -0
  50. data_designer/config/seed.py +145 -0
  51. data_designer/config/utils/code_lang.py +83 -0
  52. data_designer/config/utils/constants.py +313 -0
  53. data_designer/config/utils/errors.py +19 -0
  54. data_designer/config/utils/info.py +88 -0
  55. data_designer/config/utils/io_helpers.py +273 -0
  56. data_designer/config/utils/misc.py +81 -0
  57. data_designer/config/utils/numerical_helpers.py +28 -0
  58. data_designer/config/utils/type_helpers.py +100 -0
  59. data_designer/config/utils/validation.py +336 -0
  60. data_designer/config/utils/visualization.py +427 -0
  61. data_designer/config/validator_params.py +96 -0
  62. data_designer/engine/__init__.py +2 -0
  63. data_designer/engine/analysis/column_profilers/base.py +55 -0
  64. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +160 -0
  65. data_designer/engine/analysis/column_profilers/registry.py +20 -0
  66. data_designer/engine/analysis/column_statistics.py +142 -0
  67. data_designer/engine/analysis/dataset_profiler.py +125 -0
  68. data_designer/engine/analysis/errors.py +7 -0
  69. data_designer/engine/analysis/utils/column_statistics_calculations.py +209 -0
  70. data_designer/engine/analysis/utils/judge_score_processing.py +128 -0
  71. data_designer/engine/column_generators/__init__.py +2 -0
  72. data_designer/engine/column_generators/generators/__init__.py +2 -0
  73. data_designer/engine/column_generators/generators/base.py +61 -0
  74. data_designer/engine/column_generators/generators/expression.py +63 -0
  75. data_designer/engine/column_generators/generators/llm_generators.py +172 -0
  76. data_designer/engine/column_generators/generators/samplers.py +75 -0
  77. data_designer/engine/column_generators/generators/seed_dataset.py +149 -0
  78. data_designer/engine/column_generators/generators/validation.py +147 -0
  79. data_designer/engine/column_generators/registry.py +56 -0
  80. data_designer/engine/column_generators/utils/errors.py +13 -0
  81. data_designer/engine/column_generators/utils/judge_score_factory.py +57 -0
  82. data_designer/engine/column_generators/utils/prompt_renderer.py +98 -0
  83. data_designer/engine/configurable_task.py +82 -0
  84. data_designer/engine/dataset_builders/artifact_storage.py +181 -0
  85. data_designer/engine/dataset_builders/column_wise_builder.py +287 -0
  86. data_designer/engine/dataset_builders/errors.py +13 -0
  87. data_designer/engine/dataset_builders/multi_column_configs.py +44 -0
  88. data_designer/engine/dataset_builders/utils/__init__.py +2 -0
  89. data_designer/engine/dataset_builders/utils/concurrency.py +184 -0
  90. data_designer/engine/dataset_builders/utils/config_compiler.py +60 -0
  91. data_designer/engine/dataset_builders/utils/dag.py +56 -0
  92. data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +190 -0
  93. data_designer/engine/dataset_builders/utils/errors.py +13 -0
  94. data_designer/engine/errors.py +49 -0
  95. data_designer/engine/model_provider.py +75 -0
  96. data_designer/engine/models/__init__.py +2 -0
  97. data_designer/engine/models/errors.py +308 -0
  98. data_designer/engine/models/facade.py +225 -0
  99. data_designer/engine/models/litellm_overrides.py +162 -0
  100. data_designer/engine/models/parsers/__init__.py +2 -0
  101. data_designer/engine/models/parsers/errors.py +34 -0
  102. data_designer/engine/models/parsers/parser.py +236 -0
  103. data_designer/engine/models/parsers/postprocessors.py +93 -0
  104. data_designer/engine/models/parsers/tag_parsers.py +60 -0
  105. data_designer/engine/models/parsers/types.py +82 -0
  106. data_designer/engine/models/recipes/base.py +79 -0
  107. data_designer/engine/models/recipes/response_recipes.py +291 -0
  108. data_designer/engine/models/registry.py +118 -0
  109. data_designer/engine/models/usage.py +75 -0
  110. data_designer/engine/models/utils.py +38 -0
  111. data_designer/engine/processing/ginja/__init__.py +2 -0
  112. data_designer/engine/processing/ginja/ast.py +64 -0
  113. data_designer/engine/processing/ginja/environment.py +461 -0
  114. data_designer/engine/processing/ginja/exceptions.py +54 -0
  115. data_designer/engine/processing/ginja/record.py +30 -0
  116. data_designer/engine/processing/gsonschema/__init__.py +2 -0
  117. data_designer/engine/processing/gsonschema/exceptions.py +8 -0
  118. data_designer/engine/processing/gsonschema/schema_transformers.py +81 -0
  119. data_designer/engine/processing/gsonschema/types.py +8 -0
  120. data_designer/engine/processing/gsonschema/validators.py +143 -0
  121. data_designer/engine/processing/processors/base.py +15 -0
  122. data_designer/engine/processing/processors/drop_columns.py +46 -0
  123. data_designer/engine/processing/processors/registry.py +20 -0
  124. data_designer/engine/processing/utils.py +120 -0
  125. data_designer/engine/registry/base.py +97 -0
  126. data_designer/engine/registry/data_designer_registry.py +37 -0
  127. data_designer/engine/registry/errors.py +10 -0
  128. data_designer/engine/resources/managed_dataset_generator.py +35 -0
  129. data_designer/engine/resources/managed_dataset_repository.py +194 -0
  130. data_designer/engine/resources/managed_storage.py +63 -0
  131. data_designer/engine/resources/resource_provider.py +46 -0
  132. data_designer/engine/resources/seed_dataset_data_store.py +66 -0
  133. data_designer/engine/sampling_gen/column.py +89 -0
  134. data_designer/engine/sampling_gen/constraints.py +95 -0
  135. data_designer/engine/sampling_gen/data_sources/base.py +214 -0
  136. data_designer/engine/sampling_gen/data_sources/errors.py +10 -0
  137. data_designer/engine/sampling_gen/data_sources/sources.py +342 -0
  138. data_designer/engine/sampling_gen/entities/__init__.py +2 -0
  139. data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
  140. data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +64 -0
  141. data_designer/engine/sampling_gen/entities/email_address_utils.py +169 -0
  142. data_designer/engine/sampling_gen/entities/errors.py +8 -0
  143. data_designer/engine/sampling_gen/entities/national_id_utils.py +100 -0
  144. data_designer/engine/sampling_gen/entities/person.py +142 -0
  145. data_designer/engine/sampling_gen/entities/phone_number.py +122 -0
  146. data_designer/engine/sampling_gen/errors.py +24 -0
  147. data_designer/engine/sampling_gen/generator.py +121 -0
  148. data_designer/engine/sampling_gen/jinja_utils.py +60 -0
  149. data_designer/engine/sampling_gen/people_gen.py +203 -0
  150. data_designer/engine/sampling_gen/person_constants.py +54 -0
  151. data_designer/engine/sampling_gen/schema.py +143 -0
  152. data_designer/engine/sampling_gen/schema_builder.py +59 -0
  153. data_designer/engine/sampling_gen/utils.py +40 -0
  154. data_designer/engine/secret_resolver.py +80 -0
  155. data_designer/engine/validators/__init__.py +17 -0
  156. data_designer/engine/validators/base.py +36 -0
  157. data_designer/engine/validators/local_callable.py +34 -0
  158. data_designer/engine/validators/python.py +245 -0
  159. data_designer/engine/validators/remote.py +83 -0
  160. data_designer/engine/validators/sql.py +60 -0
  161. data_designer/errors.py +5 -0
  162. data_designer/essentials/__init__.py +137 -0
  163. data_designer/interface/__init__.py +2 -0
  164. data_designer/interface/data_designer.py +351 -0
  165. data_designer/interface/errors.py +16 -0
  166. data_designer/interface/results.py +55 -0
  167. data_designer/logging.py +161 -0
  168. data_designer/plugin_manager.py +83 -0
  169. data_designer/plugins/__init__.py +6 -0
  170. data_designer/plugins/errors.py +10 -0
  171. data_designer/plugins/plugin.py +69 -0
  172. data_designer/plugins/registry.py +86 -0
  173. data_designer-0.1.0.dist-info/METADATA +173 -0
  174. data_designer-0.1.0.dist-info/RECORD +177 -0
  175. data_designer-0.1.0.dist-info/WHEEL +4 -0
  176. data_designer-0.1.0.dist-info/entry_points.txt +2 -0
  177. data_designer-0.1.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,214 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from abc import ABC, abstractmethod
5
+ from typing import Any, Generic, Optional, Type, TypeVar, Union
6
+
7
+ import numpy as np
8
+ from numpy.typing import NDArray
9
+ import pandas as pd
10
+ from scipy import stats
11
+
12
+ from data_designer.config.sampler_params import SamplerParamsT
13
+ from data_designer.engine.sampling_gen.utils import check_random_state
14
+
15
# 1-D numpy array of any dtype.
NumpyArray1dT = NDArray[Any]
# Historical (misspelled) alias kept for backward compatibility with existing imports.
RadomStateT = int | np.random.RandomState
# Correctly spelled alias; prefer this in new code.
RandomStateT = RadomStateT


# Type variable bound to the sampler params union, used to parameterize DataSource.
GenericParamsT = TypeVar("GenericParamsT", bound=SamplerParamsT)
20
+
21
+
22
+ ###########################################################
23
+ # Processing Mixins
24
+ # -----------------
25
+ # These mixins are used to apply pre and post processing
26
+ # to the data source output. At the moment, the only
27
+ # processing that is applied is an optional type/format
28
+ # conversion of the output data.
29
+ #
30
+ # Preprocessing: Applied *before* constraints are applied.
31
+ # Postprocessing: Applied at the end of dataset generation.
32
+ #
33
+ # IMPORTANT: These are only applied when the data are
34
+ # being injected into a DataFrame by the DatasetGenerator.
35
+ ###########################################################
36
+
37
+
38
class PassthroughMixin:
    """No-op processing mixin: data passes through pre/postprocessing unchanged."""

    @staticmethod
    def preproc(series: pd.Series, convert_to: Optional[str]) -> pd.Series:
        # No conversion before constraints are applied.
        return series

    @staticmethod
    def postproc(series: pd.Series, convert_to: Optional[str]) -> pd.Series:
        # No conversion at the end of dataset generation.
        return series

    @staticmethod
    def validate_data_conversion(convert_to: Optional[str]) -> None:
        # Every value of `convert_to` (including None) is accepted.
        pass
50
+
51
+
52
class TypeConversionMixin:
    """Converts the data type of the output data.

    This mixin applies the same conversion to both the pre and post
    processing steps. The preprocessing is needed to ensure constraints
    are applied to the correct data type. The postprocessing is needed
    to ensure the final dtype is correct. For example, if the user wants an
    `int`, we need to convert to `int` before applying constraints, but
    the ints will be converted back to floats when injected into the
    dataframe (assuming some rows are non-int values). We therefore need
    to convert back to `int` after all constraints have been applied.
    """

    @staticmethod
    def _convert(series: pd.Series, convert_to: Optional[str]) -> pd.Series:
        # Shared implementation for preproc/postproc (they were duplicated verbatim).
        if convert_to is None:
            return series
        if convert_to == "int":
            # Round first so float -> int truncation does not bias values downward.
            series = series.round()
        return series.astype(convert_to)

    @staticmethod
    def preproc(series: pd.Series, convert_to: Optional[str]) -> pd.Series:
        """Convert `series` to `convert_to` (if given) before constraints run."""
        return TypeConversionMixin._convert(series, convert_to)

    @staticmethod
    def postproc(series: pd.Series, convert_to: Optional[str]) -> pd.Series:
        """Re-apply the dtype conversion after all constraints have run."""
        return TypeConversionMixin._convert(series, convert_to)

    @staticmethod
    def validate_data_conversion(convert_to: Optional[str]) -> None:
        """Raise ValueError unless `convert_to` is None or a supported dtype name."""
        if convert_to is not None and convert_to not in ["float", "int", "str"]:
            raise ValueError(f"Invalid `convert_to` value: {convert_to}. Must be one of: [float, int, str]")
85
+
86
+
87
class DatetimeFormatMixin:
    """Formats datetime output as strings during postprocessing."""

    @staticmethod
    def preproc(series: pd.Series, convert_to: Optional[str]) -> pd.Series:
        # Constraints operate on raw datetimes, so no preprocessing is needed.
        return series

    @staticmethod
    def postproc(series: pd.Series, convert_to: Optional[str]) -> pd.Series:
        # Explicit format requested: use it directly.
        if convert_to is not None:
            return series.dt.strftime(convert_to)
        # Heuristics below pick the coarsest informative string representation.
        # A single distinct month presumably indicates year-granularity sampling — TODO confirm.
        if series.dt.month.nunique() == 1:
            return series.apply(lambda dt: dt.year).astype(str)
        # A single distinct day presumably indicates month-granularity sampling.
        if series.dt.day.nunique() == 1:
            return series.apply(lambda dt: dt.strftime("%Y-%m"))
        # Any nonzero hour/minute component: keep the full ISO timestamp.
        if series.dt.hour.sum() > 0 or series.dt.minute.sum() > 0:
            return series.apply(lambda dt: dt.isoformat()).astype(str)
        # All values at midnight: the date alone is enough.
        if series.dt.second.sum() == 0:
            return series.apply(lambda dt: dt.date()).astype(str)
        return series.apply(lambda dt: dt.isoformat()).astype(str)

    @staticmethod
    def validate_data_conversion(convert_to: Optional[str]) -> None:
        # Round-trip a fixed date through the format: proves `convert_to` is a
        # strftime format whose output pandas can parse back.
        if convert_to is not None:
            try:
                pd.to_datetime(pd.to_datetime("2012-12-21").strftime(convert_to))
            except Exception as e:
                raise ValueError(f"Invalid datetime format: {convert_to}. {e}")
113
+
114
+
115
+ ###########################################################
116
+ # Base Data Source Classes
117
+ ###########################################################
118
+
119
+
120
class DataSource(ABC, Generic[GenericParamsT]):
    """Abstract base for all data sources.

    Construction validates `params` against the concrete params type declared
    in the subclass's generic base, normalizes the random state into an RNG,
    then runs the `_setup` and `_validate` hooks.
    """

    def __init__(
        self,
        params: GenericParamsT,
        random_state: Optional[RadomStateT] = None,  # int seed or np.random.RandomState
        **kwargs,
    ):
        self.rng = check_random_state(random_state)
        # model_validate re-validates params; presumably accepts dicts as well
        # as model instances (pydantic-style) — confirm against SamplerParamsT.
        self.params = self.get_param_type().model_validate(params)
        self._setup(**kwargs)
        self._validate()

    @classmethod
    def get_param_type(cls) -> Type[GenericParamsT]:
        # Recover the concrete params class from the generic subscript of the
        # last declared base, e.g. `Sampler[FooParams]` -> FooParams.
        return cls.__orig_bases__[-1].__args__[0]

    @abstractmethod
    def inject_data_column(
        self,
        dataframe: pd.DataFrame,
        column_name: str,
        index: Optional[list[int]] = None,
    ) -> pd.DataFrame: ...

    # NOTE(review): mixin implementations take an extra `convert_to` argument —
    # signature mismatch with this abstract declaration; confirm intended.
    @staticmethod
    @abstractmethod
    def preproc(series: pd.Series) -> pd.Series: ...

    @staticmethod
    @abstractmethod
    def postproc(series: pd.Series, convert_to: Optional[str]) -> pd.Series: ...

    @staticmethod
    @abstractmethod
    def validate_data_conversion(convert_to: Optional[str]) -> None: ...

    def get_required_column_names(self) -> tuple[str, ...]:
        # Columns that must already exist in the dataframe; none by default.
        return tuple()

    def _setup(self, **kwargs) -> None:
        # Optional subclass hook, run before _validate().
        pass

    def _validate(self) -> None:
        # Optional subclass hook for cross-field params validation.
        pass
164
+
165
+
166
class Sampler(DataSource[GenericParamsT], ABC):
    """Base for samplers that draw values independently of other columns and
    inject them into a DataFrame."""

    def _recast_types_if_needed(
        self,
        index: list[int] | slice,
        column_name: str,
        sample: NumpyArray1dT,
        dataframe: pd.DataFrame,
    ) -> pd.DataFrame:
        """Align the column dtype with the sample dtype when they diverge."""
        # Type may be different if the column has mixed types / NaNs.
        if column_name in dataframe.columns:
            dtype = sample.dtype
            if dtype != dataframe.loc[index, column_name].dtype:
                dataframe = dataframe.astype({column_name: dtype}, errors="ignore")
        return dataframe

    def inject_data_column(
        self,
        dataframe: pd.DataFrame,
        column_name: str,
        index: Optional[list[int]] = None,
    ) -> pd.DataFrame:
        """Sample values and write them into `column_name` for the given rows.

        When `index` is None, all rows of `dataframe` are targeted.
        """
        # Bug fix: `len(slice(None))` raises TypeError, so the row count must be
        # derived from the dataframe when no explicit index is given.
        if index is None:
            index = slice(None)
            num_rows = len(dataframe)
        else:
            num_rows = len(index)

        if num_rows == 0:
            return dataframe

        sample = self.sample(num_rows)

        # Try recasting before assigning the sample to the dataframe, since setting an item
        # of incompatible dtype is deprecated and will raise an error in future versions.
        dataframe = self._recast_types_if_needed(index, column_name, sample, dataframe)
        dataframe.loc[index, column_name] = sample

        # Recast again in case the assignment led to inconsistencies (e.g., funny business from NaNs).
        dataframe = self._recast_types_if_needed(index, column_name, sample, dataframe)

        return dataframe

    @abstractmethod
    def sample(self, num_samples: int) -> NumpyArray1dT:
        """Return a 1-D array of `num_samples` sampled values."""
        ...
206
+
207
+
208
class ScipyStatsSampler(Sampler[GenericParamsT], ABC):
    """Sampler backed by a frozen scipy.stats distribution."""

    @property
    @abstractmethod
    def distribution(self) -> Union[stats.rv_continuous, stats.rv_discrete]:
        # Subclasses return a frozen distribution (parameters already bound).
        ...

    def sample(self, num_samples: int) -> NumpyArray1dT:
        # random_state=self.rng keeps draws reproducible under the instance seed.
        return self.distribution.rvs(size=num_samples, random_state=self.rng)
@@ -0,0 +1,10 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from data_designer.engine.sampling_gen.errors import SamplingGenError
5
+
6
+
7
class InvalidSamplerParamsError(SamplingGenError):
    """Raised when sampler parameters are malformed or reference an unknown distribution."""

    ...
8
+
9
+
10
class PersonSamplerConstraintsError(SamplingGenError):
    """Raised when person-sampling constraints are too strict to produce enough samples."""

    ...
@@ -0,0 +1,342 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import uuid
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ from scipy import stats
9
+
10
+ from data_designer.config.sampler_params import (
11
+ BernoulliMixtureSamplerParams,
12
+ BernoulliSamplerParams,
13
+ BinomialSamplerParams,
14
+ CategorySamplerParams,
15
+ DatetimeSamplerParams,
16
+ GaussianSamplerParams,
17
+ PersonFromFakerSamplerParams,
18
+ PersonSamplerParams,
19
+ PoissonSamplerParams,
20
+ SamplerParamsT,
21
+ SamplerType,
22
+ ScipySamplerParams,
23
+ SubcategorySamplerParams,
24
+ TimeDeltaSamplerParams,
25
+ UniformSamplerParams,
26
+ UUIDSamplerParams,
27
+ )
28
+ from data_designer.engine.sampling_gen.data_sources.base import (
29
+ DataSource,
30
+ DatetimeFormatMixin,
31
+ NumpyArray1dT,
32
+ PassthroughMixin,
33
+ Sampler,
34
+ ScipyStatsSampler,
35
+ TypeConversionMixin,
36
+ )
37
+ from data_designer.engine.sampling_gen.data_sources.errors import (
38
+ InvalidSamplerParamsError,
39
+ PersonSamplerConstraintsError,
40
+ )
41
+ from data_designer.engine.sampling_gen.entities.dataset_based_person_fields import PERSONA_FIELDS, PII_FIELDS
42
+ from data_designer.engine.sampling_gen.people_gen import PeopleGen
43
+
44
+ ONE_BILLION = 10**9
45
+
46
+
47
class SamplerRegistry:
    """Class-level registry mapping alias <-> sampler class <-> params type."""

    # alias -> sampler class
    _registry: dict[str, type] = {}
    # sampler class -> alias
    _reverse_registry: dict[type, str] = {}
    # params class -> sampler class
    _params_registry: dict[type, type] = {}

    @classmethod
    def register(cls, alias: str):
        """Class decorator registering a DataSource subclass under `alias`."""

        def decorator(wrapped_class: type[DataSource[SamplerParamsT]]) -> type:
            cls._registry[alias] = wrapped_class
            cls._reverse_registry[wrapped_class] = alias
            # get_param_type() introspects the generic base to find the params class.
            cls._params_registry[wrapped_class.get_param_type()] = wrapped_class
            return wrapped_class

        return decorator

    @classmethod
    def get_sampler(cls, alias: str) -> type[DataSource[SamplerParamsT]]:
        # NOTE(review): lookup lowercases the alias, but `register` and
        # `is_registered` use it verbatim — confirm aliases are always lowercase.
        return cls._registry[alias.lower()]

    @classmethod
    def get_sampler_for_params(cls, params_type: SamplerParamsT) -> type[DataSource[SamplerParamsT]]:
        # Despite the name, `params_type` is a params *instance*; keyed by its class.
        return cls._params_registry[type(params_type)]

    @classmethod
    def get_sampler_alias_for_params(cls, params_type: SamplerParamsT) -> str:
        return cls._reverse_registry[cls._params_registry[type(params_type)]]

    @classmethod
    def is_registered(cls, alias: str) -> bool:
        return alias in cls._registry

    @classmethod
    def validate_sampler_type(
        cls, sampler_type: str | type[DataSource[SamplerParamsT]]
    ) -> type[DataSource[SamplerParamsT]]:
        """Resolve a string alias to its class and verify it is a DataSource subclass."""
        if isinstance(sampler_type, str):
            if sampler_type not in cls._registry:
                raise ValueError(
                    f"Sampler type `{sampler_type}` not found in the registry. "
                    f"Available samplers: {list(cls._registry.keys())}"
                )
            sampler_type = cls.get_sampler(sampler_type)
        if not issubclass(sampler_type, DataSource):
            raise ValueError(f"Sampler type `{sampler_type}` is not a subclass of `DataSource`")
        return sampler_type
92
+
93
+
94
+ #########################################
95
+ # Data Source Subclasses
96
+ #########################################
97
+
98
+
99
@SamplerRegistry.register(SamplerType.SUBCATEGORY)
class SubcategorySampler(TypeConversionMixin, DataSource[SubcategorySamplerParams]):
    """Samples a subcategory value conditioned on an existing category column."""

    def get_required_column_names(self) -> tuple[str, ...]:
        # The parent category column must already be populated.
        return (self.params.category,)

    def inject_data_column(
        self,
        dataframe: pd.DataFrame,
        column_name: str,
        index: list[int] | None = None,
    ) -> pd.DataFrame:
        """Fill `column_name` by sampling from the value list of each row's category."""
        # Bug fix: `len(slice(None))` raises TypeError — only an explicit index
        # list can be empty, so check before normalizing None to slice(None).
        if index is not None and len(index) == 0:
            return dataframe
        row_index = slice(None) if index is None else index

        # For each targeted row, draw uniformly from the subcategory values
        # registered for that row's category.
        dataframe.loc[row_index, column_name] = dataframe.loc[row_index, self.params.category].apply(
            lambda cat_value: self.rng.choice(self.params.values[cat_value])
        )

        return dataframe
120
+
121
+
122
+ #########################################
123
+ # Sampler Subclasses
124
+ #########################################
125
+
126
+
127
@SamplerRegistry.register(SamplerType.CATEGORY)
class CategorySampler(TypeConversionMixin, Sampler[CategorySamplerParams]):
    """Samples categorical values, optionally weighted."""

    def sample(self, num_samples: int) -> NumpyArray1dT:
        params = self.params
        # Draws with replacement; `p=None` falls back to a uniform distribution.
        return self.rng.choice(params.values, size=num_samples, p=params.weights)
131
+
132
+
133
@SamplerRegistry.register(SamplerType.DATETIME)
class DatetimeSampler(DatetimeFormatMixin, Sampler[DatetimeSamplerParams]):
    """Samples datetimes uniformly (at second resolution) between params.start and params.end."""

    def sample(self, num_samples: int) -> NumpyArray1dT:
        # Convert nanoseconds to seconds.
        start_sec = pd.to_datetime(self.params.start).value // ONE_BILLION
        end_sec = pd.to_datetime(self.params.end).value // ONE_BILLION

        # NOTE(review): randint's upper bound is exclusive, so `end` itself is
        # never drawn — confirm this is intended.
        random_ns = (ONE_BILLION * self.rng.randint(start_sec, end_sec, num_samples, dtype=np.int64)).view(
            "datetime64[ns]"
        )

        # Truncate to the configured datetime unit (e.g. "s", "D").
        return np.array(random_ns, dtype=f"datetime64[{self.params.unit}]")
145
+
146
+
147
@SamplerRegistry.register(SamplerType.PERSON)
class PersonSampler(PassthroughMixin, Sampler[PersonSamplerParams]):
    """Samples person records from a managed dataset-backed PeopleGen generator."""

    def _setup(self, **kwargs) -> None:
        # Generator is attached later via set_generator (or via the
        # people_gen_resource entry in kwargs below).
        self._generator = None
        # Non-None params listed in generator_kwargs are forwarded to every generate() call.
        self._fixed_kwargs = {}
        for field in self.params.generator_kwargs:
            if getattr(self.params, field) is not None:
                attr = getattr(self.params, field)
                if field == "select_field_values":
                    # select_field_values is a mapping of field-name -> required value.
                    for key, value in attr.items():
                        if key == "state" and self.params.locale == "en_US":
                            key = "region"  # This is the field name in the census-based person dataset.
                        if key not in PII_FIELDS + PERSONA_FIELDS:
                            raise ValueError(f"Invalid field name: {key}")
                        self._fixed_kwargs[key] = value
                else:
                    self._fixed_kwargs[field] = attr
        if people_gen_resource := kwargs.get("people_gen_resource"):
            if self.params.people_gen_key not in people_gen_resource:
                raise ValueError(f"Person generator with key {self.params.people_gen_key} not found.")
            self.set_generator(people_gen_resource[self.params.people_gen_key])

    def set_generator(self, generator: PeopleGen) -> None:
        # Late binding: the generator is a shared resource injected by the runtime.
        self._generator = generator

    def sample(self, num_samples: int) -> NumpyArray1dT:
        if self._generator is None:
            raise ValueError("Generator not set. Please setup generator before sampling.")

        samples = np.array(self._generator.generate(num_samples, **self._fixed_kwargs))
        # generate() may return fewer rows than requested when the fixed field
        # values filter out too many candidates.
        if len(samples) < num_samples:
            raise PersonSamplerConstraintsError(
                f"🛑 Only {len(samples)} samples could be generated with the given settings: {self._fixed_kwargs!r}. "
                "This is likely because the filter values are too strict. Person sampling does not support "
                "rare combinations of field values. Please loosen the constraints and try again."
            )
        return samples
184
+
185
+
186
@SamplerRegistry.register(SamplerType.PERSON_FROM_FAKER)
class PersonFromFakerSampler(PassthroughMixin, Sampler[PersonFromFakerSamplerParams]):
    """Samples person records from a Faker-backed PeopleGen generator.

    NOTE(review): near-duplicate of PersonSampler minus the select_field_values
    handling — consider sharing a base once behavior is confirmed.
    """

    def _setup(self, **kwargs) -> None:
        # Generator is attached later (or via the people_gen_resource entry in kwargs).
        self._generator = None
        # Non-None params listed in generator_kwargs are forwarded to every generate() call.
        self._fixed_kwargs = {}
        for field in self.params.generator_kwargs:
            if getattr(self.params, field) is not None:
                self._fixed_kwargs[field] = getattr(self.params, field)
        if people_gen_resource := kwargs.get("people_gen_resource"):
            if self.params.people_gen_key not in people_gen_resource:
                raise ValueError(f"Person generator with key {self.params.people_gen_key} not found.")
            self.set_generator(people_gen_resource[self.params.people_gen_key])

    def set_generator(self, generator: PeopleGen) -> None:
        # Late binding: the generator is a shared resource injected by the runtime.
        self._generator = generator

    def sample(self, num_samples: int) -> NumpyArray1dT:
        if self._generator is None:
            raise ValueError("Generator not set. Please setup generator before sampling.")

        samples = np.array(self._generator.generate(num_samples, **self._fixed_kwargs))
        # generate() may return fewer rows than requested when constraints filter out candidates.
        if len(samples) < num_samples:
            raise ValueError(f"Only {len(samples)} samples could be generated given constraints {self._fixed_kwargs}.")
        return samples
210
+
211
+
212
@SamplerRegistry.register(SamplerType.TIMEDELTA)
class TimeDeltaSampler(DatetimeFormatMixin, Sampler[TimeDeltaSamplerParams]):
    """Adds a random time delta to an existing reference datetime column."""

    def get_required_column_names(self) -> tuple[str, ...]:
        # The reference datetime column must already be populated.
        return (self.params.reference_column_name,)

    def inject_data_column(
        self,
        dataframe: pd.DataFrame,
        column_name: str,
        index: list[int] | None = None,
    ) -> pd.DataFrame:
        """Fill `column_name` with reference-column values shifted by random deltas."""
        # Bug fix: `len(slice(None))` raises TypeError, so compute the row count
        # from the dataframe when no explicit index is given.
        if index is None:
            index = slice(None)
            num_rows = len(dataframe)
        else:
            num_rows = len(index)

        if self.params.reference_column_name not in list(dataframe):
            # Also fixes the error message grammar ("Columns" -> "Column").
            raise ValueError(f"Column `{self.params.reference_column_name}` not found in dataset")

        dataframe.loc[index, column_name] = pd.to_datetime(
            dataframe.loc[index, self.params.reference_column_name]
        ) + pd.to_timedelta(self.sample(num_rows), unit=self.params.unit)

        return dataframe

    def sample(self, num_samples: int) -> NumpyArray1dT:
        # Uniform integer deltas in [dt_min, dt_max) expressed in the configured unit.
        deltas = self.rng.randint(self.params.dt_min, self.params.dt_max, num_samples)
        return np.array(deltas, dtype=f"timedelta64[{self.params.unit}]")
237
+
238
+
239
@SamplerRegistry.register(SamplerType.UUID)
class UUIDSampler(PassthroughMixin, Sampler[UUIDSamplerParams]):
    """Generates unique, optionally prefixed/uppercased/truncated UUID4 hex strings.

    NOTE: uniqueness is enforced by rejection sampling; a very small
    `last_index` makes collisions frequent and slows the loop down.
    """

    def sample(self, num_samples: int) -> NumpyArray1dT:
        prefix = self.params.prefix or ""

        # Keep a set alongside the list: O(1) membership instead of the O(n)
        # list scan per candidate (O(n^2) overall) in the original.
        seen: set[str] = set()
        uid_list: list[str] = []
        while len(uid_list) < num_samples:
            uid = uuid.uuid4().hex[: self.params.last_index]
            if self.params.uppercase:
                uid = uid.upper()
            uid = f"{prefix}{uid}"
            if uid not in seen:
                seen.add(uid)
                uid_list.append(uid)

        return np.array(uid_list)
255
+
256
+
257
+ #########################################
258
+ # Scipy Samplers
259
+ #########################################
260
+
261
+
262
@SamplerRegistry.register(SamplerType.SCIPY)
class ScipySampler(TypeConversionMixin, ScipyStatsSampler[ScipySamplerParams]):
    """Escape hatch sampler to give users access to any scipy.stats distribution."""

    @property
    def distribution(self) -> stats.rv_continuous | stats.rv_discrete:
        # Look up the distribution family by name and freeze it with the user params.
        dist_factory = getattr(stats, self.params.dist_name)
        return dist_factory(**self.params.dist_params)

    def _validate(self) -> None:
        # Delegate to the shared scipy-distribution validator.
        _validate_scipy_distribution(self.params.dist_name, self.params.dist_params)
272
+
273
+
274
@SamplerRegistry.register(SamplerType.BERNOULLI)
class BernoulliSampler(TypeConversionMixin, ScipyStatsSampler[BernoulliSamplerParams]):
    """Draws 0/1 outcomes with success probability `p`."""

    @property
    def distribution(self) -> stats.rv_discrete:
        success_prob = self.params.p
        return stats.bernoulli(p=success_prob)
279
+
280
+
281
@SamplerRegistry.register(SamplerType.BERNOULLI_MIXTURE)
class BernoulliMixtureSampler(TypeConversionMixin, Sampler[BernoulliMixtureSamplerParams]):
    """Zero-inflated mixture: a Bernoulli(p) gate multiplied by a draw from an
    arbitrary scipy.stats distribution."""

    def sample(self, num_samples: int) -> NumpyArray1dT:
        # Fix: pass random_state=self.rng so draws are reproducible under the
        # instance seed, consistent with ScipyStatsSampler.sample (previously
        # both rvs calls used the global numpy RNG).
        gate = stats.bernoulli(p=self.params.p).rvs(size=num_samples, random_state=self.rng)
        values = getattr(stats, self.params.dist_name)(**self.params.dist_params).rvs(
            size=num_samples, random_state=self.rng
        )
        return gate * values

    def _validate(self) -> None:
        _validate_scipy_distribution(self.params.dist_name, self.params.dist_params)
290
+
291
+
292
@SamplerRegistry.register(SamplerType.BINOMIAL)
class BinomialSampler(TypeConversionMixin, ScipyStatsSampler[BinomialSamplerParams]):
    """Draws success counts from a Binomial(n, p) distribution."""

    @property
    def distribution(self) -> stats.rv_discrete:
        params = self.params
        return stats.binom(n=params.n, p=params.p)
297
+
298
+
299
@SamplerRegistry.register(SamplerType.GAUSSIAN)
class GaussianSampler(TypeConversionMixin, ScipyStatsSampler[GaussianSamplerParams]):
    """Draws real values from a normal distribution with the configured mean/stddev."""

    @property
    def distribution(self) -> stats.rv_continuous:
        params = self.params
        return stats.norm(loc=params.mean, scale=params.stddev)
304
+
305
+
306
@SamplerRegistry.register(SamplerType.POISSON)
class PoissonSampler(TypeConversionMixin, ScipyStatsSampler[PoissonSamplerParams]):
    """Draws counts from a Poisson distribution with the configured rate."""

    @property
    def distribution(self) -> stats.rv_discrete:
        rate = self.params.mean
        return stats.poisson(mu=rate)
311
+
312
+
313
@SamplerRegistry.register(SamplerType.UNIFORM)
class UniformSampler(TypeConversionMixin, ScipyStatsSampler[UniformSamplerParams]):
    """Draws real values uniformly from [low, high)."""

    @property
    def distribution(self) -> stats.rv_continuous:
        low, high = self.params.low, self.params.high
        # scipy's uniform is parameterized as (loc, scale) = (low, width).
        return stats.uniform(loc=low, scale=high - low)
318
+
319
+
320
+ ###################################################
321
+ # Helper functions for loading sources in isolation
322
+ ###################################################
323
+
324
+
325
def load_sampler(sampler_type: SamplerType, **params) -> DataSource:
    """Load a data source from a source type and parameters."""
    # Resolve (and sanity-check) the sampler class, then instantiate it with
    # the raw keyword params; the DataSource constructor validates them.
    sampler_cls = SamplerRegistry.validate_sampler_type(sampler_type)
    return sampler_cls(params=params)
328
+
329
+
330
+ def _validate_scipy_distribution(dist_name: str, dist_params: dict) -> None:
331
+ if not hasattr(stats, dist_name):
332
+ raise InvalidSamplerParamsError(f"Distribution {dist_name} not found in scipy.stats")
333
+ if not hasattr(getattr(stats, dist_name), "rvs"):
334
+ raise InvalidSamplerParamsError(
335
+ f"Distribution {dist_name} does not have a `rvs` method, which is required for sampling."
336
+ )
337
+ try:
338
+ getattr(stats, dist_name)(**dist_params)
339
+ except Exception:
340
+ raise InvalidSamplerParamsError(
341
+ f"Distribution parameters {dist_params} are not a valid for distribution '{dist_name}'"
342
+ )
@@ -0,0 +1,2 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
@@ -0,0 +1,64 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """
5
+ This file contains all possible fields that:
6
+
7
+ 1. Exist in a managed PII + persona dataset
8
+ 2. Are included in the final generated dataset
9
+
10
+ Do not add any other code or logic in this file.
11
+ """
12
+
13
# Fields that must be present in every generated person record.
REQUIRED_FIELDS = {"first_name", "last_name", "age", "locale"}


# PII-style fields available in the managed person dataset.
PII_FIELDS = [
    "uuid",
    "first_name",
    "middle_name",
    "last_name",
    "sex",
    "age",
    "birth_date",
    "marital_status",
    "street_name",
    "street_number",
    "unit",
    "postcode",
    "region",
    "city",
    "district",
    "country",
    "area",
    "zone",
    "bachelors_field",
    "education_degree",
    "education_level",
    "occupation",
    "locale",
]


# Persona / personality-trait fields available in the managed person dataset.
PERSONA_FIELDS = [
    "persona",
    "career_goals_and_ambitions",
    "arts_persona",
    "culinary_persona",
    "cultural_background",
    "detailed_persona",
    "finance_persona",
    "healthcare_persona",
    "hobbies_and_interests_list",
    "hobbies_and_interests",
    "professional_persona",
    "skills_and_expertise_list",
    "skills_and_expertise",
    "sports_persona",
    "travel_persona",
    "openness",
    "conscientiousness",
    "extraversion",
    "agreeableness",
    "neuroticism",
]