data-designer 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177) hide show
  1. data_designer/__init__.py +15 -0
  2. data_designer/_version.py +34 -0
  3. data_designer/cli/README.md +236 -0
  4. data_designer/cli/__init__.py +6 -0
  5. data_designer/cli/commands/__init__.py +2 -0
  6. data_designer/cli/commands/list.py +130 -0
  7. data_designer/cli/commands/models.py +10 -0
  8. data_designer/cli/commands/providers.py +11 -0
  9. data_designer/cli/commands/reset.py +100 -0
  10. data_designer/cli/controllers/__init__.py +7 -0
  11. data_designer/cli/controllers/model_controller.py +246 -0
  12. data_designer/cli/controllers/provider_controller.py +317 -0
  13. data_designer/cli/forms/__init__.py +20 -0
  14. data_designer/cli/forms/builder.py +51 -0
  15. data_designer/cli/forms/field.py +180 -0
  16. data_designer/cli/forms/form.py +59 -0
  17. data_designer/cli/forms/model_builder.py +125 -0
  18. data_designer/cli/forms/provider_builder.py +76 -0
  19. data_designer/cli/main.py +44 -0
  20. data_designer/cli/repositories/__init__.py +8 -0
  21. data_designer/cli/repositories/base.py +39 -0
  22. data_designer/cli/repositories/model_repository.py +42 -0
  23. data_designer/cli/repositories/provider_repository.py +43 -0
  24. data_designer/cli/services/__init__.py +7 -0
  25. data_designer/cli/services/model_service.py +116 -0
  26. data_designer/cli/services/provider_service.py +111 -0
  27. data_designer/cli/ui.py +448 -0
  28. data_designer/cli/utils.py +47 -0
  29. data_designer/config/__init__.py +2 -0
  30. data_designer/config/analysis/column_profilers.py +89 -0
  31. data_designer/config/analysis/column_statistics.py +274 -0
  32. data_designer/config/analysis/dataset_profiler.py +60 -0
  33. data_designer/config/analysis/utils/errors.py +8 -0
  34. data_designer/config/analysis/utils/reporting.py +188 -0
  35. data_designer/config/base.py +68 -0
  36. data_designer/config/column_configs.py +354 -0
  37. data_designer/config/column_types.py +168 -0
  38. data_designer/config/config_builder.py +660 -0
  39. data_designer/config/data_designer_config.py +40 -0
  40. data_designer/config/dataset_builders.py +11 -0
  41. data_designer/config/datastore.py +151 -0
  42. data_designer/config/default_model_settings.py +123 -0
  43. data_designer/config/errors.py +19 -0
  44. data_designer/config/interface.py +54 -0
  45. data_designer/config/models.py +231 -0
  46. data_designer/config/preview_results.py +32 -0
  47. data_designer/config/processors.py +41 -0
  48. data_designer/config/sampler_constraints.py +51 -0
  49. data_designer/config/sampler_params.py +604 -0
  50. data_designer/config/seed.py +145 -0
  51. data_designer/config/utils/code_lang.py +83 -0
  52. data_designer/config/utils/constants.py +313 -0
  53. data_designer/config/utils/errors.py +19 -0
  54. data_designer/config/utils/info.py +88 -0
  55. data_designer/config/utils/io_helpers.py +273 -0
  56. data_designer/config/utils/misc.py +81 -0
  57. data_designer/config/utils/numerical_helpers.py +28 -0
  58. data_designer/config/utils/type_helpers.py +100 -0
  59. data_designer/config/utils/validation.py +336 -0
  60. data_designer/config/utils/visualization.py +427 -0
  61. data_designer/config/validator_params.py +96 -0
  62. data_designer/engine/__init__.py +2 -0
  63. data_designer/engine/analysis/column_profilers/base.py +55 -0
  64. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +160 -0
  65. data_designer/engine/analysis/column_profilers/registry.py +20 -0
  66. data_designer/engine/analysis/column_statistics.py +142 -0
  67. data_designer/engine/analysis/dataset_profiler.py +125 -0
  68. data_designer/engine/analysis/errors.py +7 -0
  69. data_designer/engine/analysis/utils/column_statistics_calculations.py +209 -0
  70. data_designer/engine/analysis/utils/judge_score_processing.py +128 -0
  71. data_designer/engine/column_generators/__init__.py +2 -0
  72. data_designer/engine/column_generators/generators/__init__.py +2 -0
  73. data_designer/engine/column_generators/generators/base.py +61 -0
  74. data_designer/engine/column_generators/generators/expression.py +63 -0
  75. data_designer/engine/column_generators/generators/llm_generators.py +172 -0
  76. data_designer/engine/column_generators/generators/samplers.py +75 -0
  77. data_designer/engine/column_generators/generators/seed_dataset.py +149 -0
  78. data_designer/engine/column_generators/generators/validation.py +147 -0
  79. data_designer/engine/column_generators/registry.py +56 -0
  80. data_designer/engine/column_generators/utils/errors.py +13 -0
  81. data_designer/engine/column_generators/utils/judge_score_factory.py +57 -0
  82. data_designer/engine/column_generators/utils/prompt_renderer.py +98 -0
  83. data_designer/engine/configurable_task.py +82 -0
  84. data_designer/engine/dataset_builders/artifact_storage.py +181 -0
  85. data_designer/engine/dataset_builders/column_wise_builder.py +287 -0
  86. data_designer/engine/dataset_builders/errors.py +13 -0
  87. data_designer/engine/dataset_builders/multi_column_configs.py +44 -0
  88. data_designer/engine/dataset_builders/utils/__init__.py +2 -0
  89. data_designer/engine/dataset_builders/utils/concurrency.py +184 -0
  90. data_designer/engine/dataset_builders/utils/config_compiler.py +60 -0
  91. data_designer/engine/dataset_builders/utils/dag.py +56 -0
  92. data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +190 -0
  93. data_designer/engine/dataset_builders/utils/errors.py +13 -0
  94. data_designer/engine/errors.py +49 -0
  95. data_designer/engine/model_provider.py +75 -0
  96. data_designer/engine/models/__init__.py +2 -0
  97. data_designer/engine/models/errors.py +308 -0
  98. data_designer/engine/models/facade.py +225 -0
  99. data_designer/engine/models/litellm_overrides.py +162 -0
  100. data_designer/engine/models/parsers/__init__.py +2 -0
  101. data_designer/engine/models/parsers/errors.py +34 -0
  102. data_designer/engine/models/parsers/parser.py +236 -0
  103. data_designer/engine/models/parsers/postprocessors.py +93 -0
  104. data_designer/engine/models/parsers/tag_parsers.py +60 -0
  105. data_designer/engine/models/parsers/types.py +82 -0
  106. data_designer/engine/models/recipes/base.py +79 -0
  107. data_designer/engine/models/recipes/response_recipes.py +291 -0
  108. data_designer/engine/models/registry.py +118 -0
  109. data_designer/engine/models/usage.py +75 -0
  110. data_designer/engine/models/utils.py +38 -0
  111. data_designer/engine/processing/ginja/__init__.py +2 -0
  112. data_designer/engine/processing/ginja/ast.py +64 -0
  113. data_designer/engine/processing/ginja/environment.py +461 -0
  114. data_designer/engine/processing/ginja/exceptions.py +54 -0
  115. data_designer/engine/processing/ginja/record.py +30 -0
  116. data_designer/engine/processing/gsonschema/__init__.py +2 -0
  117. data_designer/engine/processing/gsonschema/exceptions.py +8 -0
  118. data_designer/engine/processing/gsonschema/schema_transformers.py +81 -0
  119. data_designer/engine/processing/gsonschema/types.py +8 -0
  120. data_designer/engine/processing/gsonschema/validators.py +143 -0
  121. data_designer/engine/processing/processors/base.py +15 -0
  122. data_designer/engine/processing/processors/drop_columns.py +46 -0
  123. data_designer/engine/processing/processors/registry.py +20 -0
  124. data_designer/engine/processing/utils.py +120 -0
  125. data_designer/engine/registry/base.py +97 -0
  126. data_designer/engine/registry/data_designer_registry.py +37 -0
  127. data_designer/engine/registry/errors.py +10 -0
  128. data_designer/engine/resources/managed_dataset_generator.py +35 -0
  129. data_designer/engine/resources/managed_dataset_repository.py +194 -0
  130. data_designer/engine/resources/managed_storage.py +63 -0
  131. data_designer/engine/resources/resource_provider.py +46 -0
  132. data_designer/engine/resources/seed_dataset_data_store.py +66 -0
  133. data_designer/engine/sampling_gen/column.py +89 -0
  134. data_designer/engine/sampling_gen/constraints.py +95 -0
  135. data_designer/engine/sampling_gen/data_sources/base.py +214 -0
  136. data_designer/engine/sampling_gen/data_sources/errors.py +10 -0
  137. data_designer/engine/sampling_gen/data_sources/sources.py +342 -0
  138. data_designer/engine/sampling_gen/entities/__init__.py +2 -0
  139. data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
  140. data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +64 -0
  141. data_designer/engine/sampling_gen/entities/email_address_utils.py +169 -0
  142. data_designer/engine/sampling_gen/entities/errors.py +8 -0
  143. data_designer/engine/sampling_gen/entities/national_id_utils.py +100 -0
  144. data_designer/engine/sampling_gen/entities/person.py +142 -0
  145. data_designer/engine/sampling_gen/entities/phone_number.py +122 -0
  146. data_designer/engine/sampling_gen/errors.py +24 -0
  147. data_designer/engine/sampling_gen/generator.py +121 -0
  148. data_designer/engine/sampling_gen/jinja_utils.py +60 -0
  149. data_designer/engine/sampling_gen/people_gen.py +203 -0
  150. data_designer/engine/sampling_gen/person_constants.py +54 -0
  151. data_designer/engine/sampling_gen/schema.py +143 -0
  152. data_designer/engine/sampling_gen/schema_builder.py +59 -0
  153. data_designer/engine/sampling_gen/utils.py +40 -0
  154. data_designer/engine/secret_resolver.py +80 -0
  155. data_designer/engine/validators/__init__.py +17 -0
  156. data_designer/engine/validators/base.py +36 -0
  157. data_designer/engine/validators/local_callable.py +34 -0
  158. data_designer/engine/validators/python.py +245 -0
  159. data_designer/engine/validators/remote.py +83 -0
  160. data_designer/engine/validators/sql.py +60 -0
  161. data_designer/errors.py +5 -0
  162. data_designer/essentials/__init__.py +137 -0
  163. data_designer/interface/__init__.py +2 -0
  164. data_designer/interface/data_designer.py +351 -0
  165. data_designer/interface/errors.py +16 -0
  166. data_designer/interface/results.py +55 -0
  167. data_designer/logging.py +161 -0
  168. data_designer/plugin_manager.py +83 -0
  169. data_designer/plugins/__init__.py +6 -0
  170. data_designer/plugins/errors.py +10 -0
  171. data_designer/plugins/plugin.py +69 -0
  172. data_designer/plugins/registry.py +86 -0
  173. data_designer-0.1.0.dist-info/METADATA +173 -0
  174. data_designer-0.1.0.dist-info/RECORD +177 -0
  175. data_designer-0.1.0.dist-info/WHEEL +4 -0
  176. data_designer-0.1.0.dist-info/entry_points.txt +2 -0
  177. data_designer-0.1.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,41 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from abc import ABC
5
+ from enum import Enum
6
+ from typing import Literal
7
+
8
+ from pydantic import Field, field_validator
9
+
10
+ from .base import ConfigBase
11
+ from .dataset_builders import BuildStage
12
+
13
# Processors are currently only supported after each batch completes; any other
# BuildStage value is rejected by ProcessorConfig's build_stage validator.
SUPPORTED_STAGES = [BuildStage.POST_BATCH]
14
+
15
+
16
class ProcessorType(str, Enum):
    """Identifiers for the available dataset processors."""

    DROP_COLUMNS = "drop_columns"
18
+
19
+
20
class ProcessorConfig(ConfigBase, ABC):
    """Base configuration shared by all dataset processors.

    Concrete processor configs (e.g. ``DropColumnsProcessorConfig``) inherit
    from this class and add their own parameters.

    Attributes:
        build_stage: Stage of the dataset build at which the processor runs.
            Only stages listed in ``SUPPORTED_STAGES`` are accepted.
    """

    # NOTE(review): joining SUPPORTED_STAGES assumes BuildStage is a str-enum —
    # confirm against dataset_builders.BuildStage.
    build_stage: BuildStage = Field(
        ..., description=f"The stage at which the processor will run. Supported stages: {', '.join(SUPPORTED_STAGES)}"
    )

    @field_validator("build_stage")
    @classmethod  # added for consistency with the other field validators in this package
    def validate_build_stage(cls, v: BuildStage) -> BuildStage:
        # Reject any stage outside the explicitly supported set.
        if v not in SUPPORTED_STAGES:
            raise ValueError(
                f"Invalid dataset builder stage: {v}. Only these stages are supported: {', '.join(SUPPORTED_STAGES)}"
            )
        return v
32
+
33
+
34
def get_processor_config_from_kwargs(processor_type: ProcessorType, **kwargs) -> ProcessorConfig:
    """Build the concrete processor config for the given processor type.

    Args:
        processor_type: Which processor to configure.
        **kwargs: Keyword arguments forwarded to the concrete config constructor.

    Returns:
        The constructed processor config.

    Raises:
        ValueError: If ``processor_type`` is not a recognized processor type.
    """
    if processor_type == ProcessorType.DROP_COLUMNS:
        return DropColumnsProcessorConfig(**kwargs)
    # Previously this fell through and implicitly returned None, which would
    # surface as a confusing AttributeError at the call site.
    raise ValueError(f"Unsupported processor type: {processor_type}")
37
+
38
+
39
class DropColumnsProcessorConfig(ProcessorConfig):
    """Configuration for the processor that drops columns from the dataset.

    Attributes:
        column_names: Names of the columns to drop.
        processor_type: Discriminator fixed to ``ProcessorType.DROP_COLUMNS``.
    """

    # Names of the columns to remove from the generated dataset.
    column_names: list[str]
    # Literal discriminator so tagged unions can dispatch on processor_type.
    processor_type: Literal[ProcessorType.DROP_COLUMNS] = ProcessorType.DROP_COLUMNS
@@ -0,0 +1,51 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from abc import ABC, abstractmethod
5
+ from enum import Enum
6
+ from typing import Union
7
+
8
+ from typing_extensions import TypeAlias
9
+
10
+ from .base import ConfigBase
11
+
12
+
13
class ConstraintType(str, Enum):
    """Kinds of column constraints supported by the sampling generator."""

    SCALAR_INEQUALITY = "scalar_inequality"
    COLUMN_INEQUALITY = "column_inequality"
16
+
17
+
18
class InequalityOperator(str, Enum):
    """Comparison operators usable in inequality constraints."""

    LT = "lt"  # strictly less than
    LE = "le"  # less than or equal
    GT = "gt"  # strictly greater than
    GE = "ge"  # greater than or equal
23
+
24
+
25
class Constraint(ConfigBase, ABC):
    """Abstract base class for constraints applied to a sampled column.

    Attributes:
        target_column: Name of the column the constraint applies to.
    """

    target_column: str

    # Discriminator implemented by each concrete constraint subclass.
    @property
    @abstractmethod
    def constraint_type(self) -> ConstraintType: ...
31
+
32
+
33
class ScalarInequalityConstraint(Constraint):
    """Constrain a column's values relative to a fixed scalar (e.g. column < 5)."""

    # Scalar right-hand side of the inequality.
    rhs: float
    # Comparison applied as: target_column <operator> rhs.
    operator: InequalityOperator

    @property
    def constraint_type(self) -> ConstraintType:
        return ConstraintType.SCALAR_INEQUALITY
40
+
41
+
42
class ColumnInequalityConstraint(Constraint):
    """Constrain a column's values relative to another column (e.g. start < end)."""

    # Name of the column whose values form the right-hand side of the inequality.
    rhs: str
    # Comparison applied as: target_column <operator> rhs column.
    operator: InequalityOperator

    @property
    def constraint_type(self) -> ConstraintType:
        return ConstraintType.COLUMN_INEQUALITY
49
+
50
+
51
# Union of all concrete column-constraint configs.
ColumnConstraintT: TypeAlias = Union[ScalarInequalityConstraint, ColumnInequalityConstraint]
@@ -0,0 +1,604 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from enum import Enum
5
+ from typing import Literal, Optional, Union
6
+
7
+ import pandas as pd
8
+ from pydantic import Field, field_validator, model_validator
9
+ from typing_extensions import Self, TypeAlias
10
+
11
+ from .base import ConfigBase
12
+ from .utils.constants import (
13
+ AVAILABLE_LOCALES,
14
+ DEFAULT_AGE_RANGE,
15
+ LOCALES_WITH_MANAGED_DATASETS,
16
+ MAX_AGE,
17
+ MIN_AGE,
18
+ )
19
+
20
+
21
class SamplerType(str, Enum):
    """Identifiers for all built-in sampler column generators."""

    BERNOULLI = "bernoulli"
    BERNOULLI_MIXTURE = "bernoulli_mixture"
    BINOMIAL = "binomial"
    CATEGORY = "category"
    DATETIME = "datetime"
    GAUSSIAN = "gaussian"
    PERSON = "person"
    PERSON_FROM_FAKER = "person_from_faker"
    POISSON = "poisson"
    SCIPY = "scipy"
    SUBCATEGORY = "subcategory"
    TIMEDELTA = "timedelta"
    UNIFORM = "uniform"
    UUID = "uuid"
36
+
37
+
38
+ #########################################
39
+ # Sampler Parameters
40
+ #########################################
41
+
42
+
43
class CategorySamplerParams(ConfigBase):
    """Parameters for categorical sampling with optional probability weighting.

    Samples values from a discrete set of categories. When weights are provided, values are
    sampled according to their assigned probabilities. Without weights, uniform sampling is used.

    Attributes:
        values: List of possible categorical values to sample from. Can contain strings, integers,
            or floats. Must contain at least one value.
        weights: Optional unnormalized probability weights for each value. If provided, must be
            the same length as `values` and must sum to a positive value. Weights are
            automatically normalized to sum to 1.0. Larger weights result in higher sampling
            probability for the corresponding value.
    """

    values: list[Union[str, int, float]] = Field(
        ...,
        min_length=1,
        description="List of possible categorical values that can be sampled from.",
    )
    weights: Optional[list[float]] = Field(
        default=None,
        description=(
            # Fixed grammar: was "weights to assigned to each value".
            "List of unnormalized probability weights assigned to each value, in order. "
            "Larger values will be sampled with higher probability."
        ),
    )

    @model_validator(mode="after")
    def _validate_equal_lengths(self) -> Self:
        # Defined before normalization so a length mismatch is reported on the
        # raw input. Fixed error message: the field is 'values', not 'categories'.
        if self.weights and len(self.values) != len(self.weights):
            raise ValueError("'values' and 'weights' must have the same length")
        return self

    @model_validator(mode="after")
    def _normalize_weights_if_needed(self) -> Self:
        # An empty weights list is treated like None (uniform sampling), matching
        # the truthiness check in the length validator above.
        if self.weights:
            total = sum(self.weights)  # hoisted: was re-summed for every element
            if total <= 0:
                # Previously a zero sum surfaced as a raw ZeroDivisionError.
                raise ValueError("'weights' must sum to a positive value")
            self.weights = [w / total for w in self.weights]
        return self
81
+
82
+
83
class DatetimeSamplerParams(ConfigBase):
    """Parameters for uniform datetime sampling within a specified range.

    Samples datetime values uniformly between a start and end date with a specified granularity.
    The sampling unit determines the smallest possible time interval between consecutive samples.

    Attributes:
        start: Earliest possible datetime for the sampling range (inclusive). Must be a valid
            datetime string parseable by pandas.to_datetime().
        end: Latest possible datetime for the sampling range (inclusive). Must be a valid
            datetime string parseable by pandas.to_datetime().
        unit: Time unit for sampling granularity. Options:
            - "Y": Years
            - "M": Months
            - "D": Days (default)
            - "h": Hours
            - "m": Minutes
            - "s": Seconds
    """

    start: str = Field(..., description="Earliest possible datetime for sampling range, inclusive.")
    end: str = Field(..., description="Latest possible datetime for sampling range, inclusive.")
    unit: Literal["Y", "M", "D", "h", "m", "s"] = Field(
        default="D",
        description="Sampling units, e.g. the smallest possible time interval between samples.",
    )

    @field_validator("start", "end")
    @classmethod
    def _validate_param_is_datetime(cls, value: str) -> str:
        try:
            pd.to_datetime(value)
        except ValueError as exc:
            # Chain the original parsing error so the root cause stays visible.
            raise ValueError(f"Invalid datetime format: {value}") from exc
        return value
118
+
119
+
120
class SubcategorySamplerParams(ConfigBase):
    """Parameters for subcategory sampling conditioned on a parent category column.

    Samples subcategory values based on the value of a parent category column. Each parent
    category value maps to its own list of possible subcategory values, enabling hierarchical
    or conditional sampling patterns.

    Attributes:
        category: Name of the parent category column that this subcategory depends on.
            The parent column must be generated before this subcategory column.
        values: Mapping from each parent category value to a list of possible subcategory values.
            Each key must correspond to a value that appears in the parent category column.
    """

    category: str = Field(..., description="Name of parent category to this subcategory.")
    # NOTE(review): keys are strings, so parent category values are presumably
    # looked up by their string form at sample time — confirm in the sampler.
    values: dict[str, list[Union[str, int, float]]] = Field(
        ...,
        description="Mapping from each value of parent category to a list of subcategory values.",
    )
139
+
140
+
141
class TimeDeltaSamplerParams(ConfigBase):
    """Parameters for sampling time deltas relative to a reference datetime column.

    Samples time offsets within a specified range and adds them to values from a reference
    datetime column. This is useful for generating related datetime columns like order dates
    and delivery dates, or event start times and end times.

    Note:
        Years and months are not supported as timedelta units because they have variable lengths.
        See: [pandas timedelta documentation](https://pandas.pydata.org/docs/user_guide/timedeltas.html)

    Attributes:
        dt_min: Minimum time-delta value (inclusive). Must be non-negative and less than `dt_max`.
            Specified in units defined by the `unit` parameter.
        dt_max: Maximum time-delta value (exclusive). Must be positive and greater than `dt_min`.
            Specified in units defined by the `unit` parameter.
        reference_column_name: Name of an existing datetime column to add the time-delta to.
            This column must be generated before the timedelta column.
        unit: Time unit for the delta values. Options:
            - "D": Days (default)
            - "h": Hours
            - "m": Minutes
            - "s": Seconds
    """

    dt_min: int = Field(
        ...,
        ge=0,
        description=("Minimum possible time-delta for sampling range, inclusive. Must be less than `dt_max`."),
    )
    # gt=0 here is already implied by dt_min >= 0 together with the
    # dt_min < dt_max check below; kept for clearer per-field error messages.
    dt_max: int = Field(
        ...,
        gt=0,
        description=("Maximum possible time-delta for sampling range, exclusive. Must be greater than `dt_min`."),
    )

    reference_column_name: str = Field(
        ...,
        description="Name of an existing datetime column to condition time-delta sampling on.",
    )

    # NOTE: pandas does not support years or months as timedelta units
    # since they are ambiguous. We will need to update the implementation
    # if we need to support these units.
    # see: https://pandas.pydata.org/docs/user_guide/timedeltas.html.
    unit: Literal["D", "h", "m", "s"] = Field(
        default="D",
        description="Sampling units, e.g. the smallest possible time interval between samples.",
    )

    @model_validator(mode="after")
    def _validate_min_less_than_max(self) -> Self:
        # Enforce a non-empty sampling range.
        if self.dt_min >= self.dt_max:
            raise ValueError("'dt_min' must be less than 'dt_max'")
        return self
196
+
197
+
198
class UUIDSamplerParams(ConfigBase):
    """Parameters for generating UUID (Universally Unique Identifier) values.

    Generates UUID4 (random) identifiers with optional formatting options. UUIDs are useful
    for creating unique identifiers for records, entities, or transactions.

    Attributes:
        prefix: Optional string to prepend to each UUID. Useful for creating namespaced or
            typed identifiers (e.g., "user-", "order-", "txn-").
        short_form: If True, truncates UUIDs to 8 characters (first segment only). Default is False
            for full 32-character UUIDs (excluding hyphens).
        uppercase: If True, converts all hexadecimal letters to uppercase. Default is False for
            lowercase UUIDs.
    """

    prefix: Optional[str] = Field(default=None, description="String prepended to the front of the UUID.")
    short_form: bool = Field(
        default=False,
        description="If true, all UUIDs sampled will be truncated at 8 characters.",
    )
    uppercase: bool = Field(
        default=False,
        description="If true, all letters in the UUID will be capitalized.",
    )

    @property
    def last_index(self) -> int:
        # Truncation index into the UUID's hex digits: 8 for the short form,
        # 32 (full hex length, hyphens excluded) otherwise.
        # NOTE(review): assumes the sampler slices a hyphen-stripped hex string —
        # confirm in the UUID sampler implementation.
        return 8 if self.short_form else 32
226
+
227
+
228
+ #########################################
229
+ # Scipy Sampler Parameters
230
+ #########################################
231
+
232
+
233
class ScipySamplerParams(ConfigBase):
    """Parameters for sampling from any scipy.stats continuous or discrete distribution.

    Provides a flexible interface to sample from the wide range of probability distributions
    available in scipy.stats. This enables advanced statistical sampling beyond the built-in
    distribution types (Gaussian, Uniform, etc.).

    See: [scipy.stats documentation](https://docs.scipy.org/doc/scipy/reference/stats.html)

    Attributes:
        dist_name: Name of the scipy.stats distribution to sample from (e.g., "beta", "gamma",
            "lognorm", "expon"). Must be a valid distribution name from scipy.stats.
        dist_params: Dictionary of parameters for the specified distribution. Parameter names
            and values must match the scipy.stats distribution specification (e.g., {"a": 2, "b": 5}
            for beta distribution, {"scale": 1.5} for exponential).
        decimal_places: Optional number of decimal places to round sampled values to. If None,
            values are not rounded.
    """

    # NOTE(review): dist_name is not validated against scipy.stats here —
    # presumably an invalid name fails at sampling time; confirm downstream.
    dist_name: str = Field(..., description="Name of a scipy.stats distribution.")
    dist_params: dict = Field(
        ...,
        description="Parameters of the scipy.stats distribution given in `dist_name`.",
    )
    decimal_places: Optional[int] = Field(
        default=None, description="Number of decimal places to round the sampled values to."
    )
260
+
261
+
262
class BinomialSamplerParams(ConfigBase):
    """Parameters for sampling from a Binomial distribution.

    Samples integer values representing the number of successes in a fixed number of independent
    Bernoulli trials, each with the same probability of success. Commonly used to model the number
    of successful outcomes in repeated experiments.

    Attributes:
        n: Number of independent trials. Must be a positive integer.
        p: Probability of success on each trial. Must be between 0.0 and 1.0 (inclusive).
    """

    # NOTE(review): docstring says n must be positive, but no ge/gt constraint
    # enforces it here — confirm whether validation happens downstream.
    n: int = Field(..., description="Number of trials.")
    p: float = Field(..., description="Probability of success on each trial.", ge=0.0, le=1.0)
276
+
277
+
278
class BernoulliSamplerParams(ConfigBase):
    """Parameters for sampling from a Bernoulli distribution.

    Samples binary values (0 or 1) representing the outcome of a single trial with a fixed
    probability of success. This is the simplest discrete probability distribution, useful for
    modeling binary outcomes like success/failure, yes/no, or true/false.

    Attributes:
        p: Probability of success (sampling 1). Must be between 0.0 and 1.0 (inclusive).
            The probability of failure (sampling 0) is automatically 1 - p.
    """

    # Bounds enforced by pydantic: 0.0 <= p <= 1.0.
    p: float = Field(..., description="Probability of success.", ge=0.0, le=1.0)
291
+
292
+
293
class BernoulliMixtureSamplerParams(ConfigBase):
    """Parameters for sampling from a Bernoulli mixture distribution.

    Combines a Bernoulli distribution with another continuous distribution, creating a mixture
    where values are either 0 (with probability 1-p) or sampled from the specified distribution
    (with probability p). This is useful for modeling scenarios with many zero values mixed with
    a continuous distribution of non-zero values.

    Common use cases include modeling sparse events, zero-inflated data, or situations where
    an outcome either doesn't occur (0) or follows a specific distribution when it does occur.

    Attributes:
        p: Probability of sampling from the mixture distribution (non-zero outcome).
            Must be between 0.0 and 1.0 (inclusive). With probability 1-p, the sample is 0.
        dist_name: Name of the scipy.stats distribution to sample from when outcome is non-zero.
            Must be a valid scipy.stats distribution name (e.g., "norm", "gamma", "expon").
        dist_params: Parameters for the specified scipy.stats distribution.
    """

    p: float = Field(
        ...,
        description="Bernoulli distribution probability of success.",
        ge=0.0,
        le=1.0,
    )
    # NOTE(review): like ScipySamplerParams, dist_name is not validated against
    # scipy.stats at config time — confirm failure mode for invalid names.
    dist_name: str = Field(
        ...,
        description=(
            "Mixture distribution name. Samples will be equal to the "
            "distribution sample with probability `p`, otherwise equal to 0. "
            "Must be a valid scipy.stats distribution name."
        ),
    )
    dist_params: dict = Field(
        ...,
        description="Parameters of the scipy.stats distribution given in `dist_name`.",
    )
330
+
331
+
332
class GaussianSamplerParams(ConfigBase):
    """Parameters for sampling from a Gaussian (Normal) distribution.

    Samples continuous values from a normal distribution characterized by its mean and standard
    deviation. The Gaussian distribution is one of the most commonly used probability distributions,
    appearing naturally in many real-world phenomena due to the Central Limit Theorem.

    Attributes:
        mean: Mean (center) of the Gaussian distribution. This is the expected value and the
            location of the distribution's peak.
        stddev: Standard deviation of the Gaussian distribution. Controls the spread or width
            of the distribution. Must be positive.
        decimal_places: Optional number of decimal places to round sampled values to. If None,
            values are not rounded.
    """

    mean: float = Field(..., description="Mean of the Gaussian distribution")
    # NOTE(review): docstring says stddev must be positive, but no gt=0
    # constraint enforces it here — confirm whether validation happens downstream.
    stddev: float = Field(..., description="Standard deviation of the Gaussian distribution")
    decimal_places: Optional[int] = Field(
        default=None, description="Number of decimal places to round the sampled values to."
    )
353
+
354
+
355
class PoissonSamplerParams(ConfigBase):
    """Parameters for sampling from a Poisson distribution.

    Samples non-negative integer values representing the number of events occurring in a fixed
    interval of time or space. The Poisson distribution is commonly used to model count data
    like the number of arrivals, occurrences, or events per time period.

    The distribution is characterized by a single parameter (mean/rate), and both the mean and
    variance equal this parameter value.

    Attributes:
        mean: Mean number of events in the fixed interval (also called rate parameter λ).
            Must be positive. This represents both the expected value and the variance of the
            distribution.
    """

    # NOTE(review): docstring says mean must be positive, but no gt=0
    # constraint enforces it here — confirm whether validation happens downstream.
    mean: float = Field(..., description="Mean number of events in a fixed interval.")
372
+
373
+
374
class UniformSamplerParams(ConfigBase):
    """Parameters for sampling from a continuous Uniform distribution.

    Samples continuous values uniformly from a specified range, where every value in the range
    has equal probability of being sampled. This is useful when all values within a range are
    equally likely, such as random percentages, proportions, or unbiased measurements.

    Attributes:
        low: Lower bound of the uniform distribution (inclusive). Can be any real number.
        high: Upper bound of the uniform distribution (inclusive). Must be greater than `low`.
        decimal_places: Optional number of decimal places to round sampled values to. If None,
            values are not rounded and may have many decimal places.
    """

    low: float = Field(..., description="Lower bound of the uniform distribution, inclusive.")
    # NOTE(review): docstring says high must be greater than low, but unlike
    # TimeDeltaSamplerParams no model validator enforces it — confirm downstream.
    high: float = Field(..., description="Upper bound of the uniform distribution, inclusive.")
    decimal_places: Optional[int] = Field(
        default=None, description="Number of decimal places to round the sampled values to."
    )
393
+
394
+
395
+ #########################################
396
+ # Person Sampler Parameters
397
+ #########################################
398
+
399
# Allowed values for PersonSamplerParams.sex.
SexT: TypeAlias = Literal["Male", "Female"]
400
+
401
+
402
class PersonSamplerParams(ConfigBase):
    """Parameters for sampling synthetic person data with demographic attributes.

    Generates realistic synthetic person data including names, addresses, and other
    demographic information, sampled from a managed Nemotron Personas dataset for the
    configured locale. The sampler supports filtering by sex, city, age range, and
    arbitrary dataset field values, and can optionally append synthetic persona columns.

    Attributes:
        locale: Locale string determining the language and geographic region for synthetic
            people. Format: language_COUNTRY (e.g., "en_US"). Must be a locale with a
            managed Nemotron Personas dataset; other locales are rejected at validation
            time. Defaults to "en_US".
        sex: If specified, filters to only sample people of the specified sex. Options:
            "Male" or "Female". If None, samples both sexes.
        city: If specified, filters to only sample people from the specified city or
            cities. Can be a single city name (string) or a list of city names.
        age_range: Two-element list [min_age, max_age] specifying the age range to sample
            from (inclusive). Both values must be within the minimum and maximum allowed
            ages, and min_age must be strictly less than max_age.
        select_field_values: Optional mapping from dataset field names to lists of allowed
            values, used to select a subset of the population from the managed dataset.
            Rare combinations of field values may not be well-represented in the dataset
            and can cause sampling to fail; prefer the `sex`, `city`, and `age_range`
            filters when possible.
        with_synthetic_personas: If True, appends additional synthetic persona columns to
            each generated person.
    """

    locale: str = Field(
        default="en_US",
        description=(
            "Locale that determines the language and geographic location "
            "that a synthetic person will be sampled from. Must be a locale supported by "
            "a managed Nemotron Personas dataset. Managed datasets exist for the following locales: "
            f"{', '.join(LOCALES_WITH_MANAGED_DATASETS)}."
        ),
    )
    sex: Optional[SexT] = Field(
        default=None,
        description="If specified, then only synthetic people of the specified sex will be sampled.",
    )
    city: Optional[Union[str, list[str]]] = Field(
        default=None,
        description="If specified, then only synthetic people from these cities will be sampled.",
    )
    age_range: list[int] = Field(
        default=DEFAULT_AGE_RANGE,
        description="If specified, then only synthetic people within this age range will be sampled.",
        min_length=2,
        max_length=2,
    )
    select_field_values: Optional[dict[str, list[str]]] = Field(
        default=None,
        description=(
            "Sample synthetic people with the specified field values. This is meant to be a flexible argument for "
            "selecting a subset of the population from the managed dataset. Note that this sampler does not support "
            "rare combinations of field values and will likely fail if your desired subset is not well-represented "
            "in the managed Nemotron Personas dataset. We generally recommend using the `sex`, `city`, and `age_range` "
            "arguments to filter the population when possible."
        ),
        examples=[
            {"state": ["NY", "CA", "OH", "TX", "NV"], "education_level": ["high_school", "some_college", "bachelors"]}
        ],
    )

    with_synthetic_personas: bool = Field(
        default=False,
        description="If True, then append synthetic persona columns to each generated person.",
    )

    @property
    def generator_kwargs(self) -> list[str]:
        """Keyword arguments to pass to the person generator.

        Every model field except `locale` is forwarded; `locale` selects the generator
        itself (see `people_gen_key`) rather than parameterizing it.
        """
        return [f for f in PersonSamplerParams.model_fields if f != "locale"]

    @property
    def people_gen_key(self) -> str:
        """Cache key identifying the person generator for this locale/persona combination."""
        return f"{self.locale}_with_personas" if self.with_synthetic_personas else self.locale

    @field_validator("age_range")
    @classmethod
    def _validate_age_range(cls, value: list[int]) -> list[int]:
        """Ensure [min_age, max_age] lies within allowed bounds and is strictly increasing."""
        msg_prefix = "'age_range' must be a list of two integers, representing the min and max age."
        if value[0] < MIN_AGE:
            raise ValueError(
                f"{msg_prefix} The first integer (min age) must be greater than or equal to {MIN_AGE}, "
                f"but the first integer provided was {value[0]}."
            )
        if value[1] > MAX_AGE:
            raise ValueError(
                f"{msg_prefix} The second integer (max age) must be less than or equal to {MAX_AGE}, "
                f"but the second integer provided was {value[1]}."
            )
        if value[0] >= value[1]:
            raise ValueError(
                f"{msg_prefix} The first integer (min age) must be less than the second integer (max age), "
                f"but the first integer provided was {value[0]} and the second integer provided was {value[1]}."
            )
        return value

    @model_validator(mode="after")
    def _validate_locale_with_managed_datasets(self) -> Self:
        """Reject locales without a managed dataset — this sampler has no Faker fallback."""
        if self.locale not in LOCALES_WITH_MANAGED_DATASETS:
            raise ValueError(
                "Person sampling from managed datasets is only supported for the following "
                f"locales: {', '.join(LOCALES_WITH_MANAGED_DATASETS)}."
            )
        return self
512
+
513
+
514
class PersonFromFakerSamplerParams(ConfigBase):
    """Parameters for sampling synthetic person data generated with Faker.

    Supports filtering by sex, city, and age range. The locale must be one of the
    available Faker locales; unsupported locales are rejected at validation time.
    """

    locale: str = Field(
        default="en_US",
        description=(
            "Locale string, determines the language and geographic locale "
            "that a synthetic person will be sampled from. E.g, en_US, en_GB, fr_FR, ..."
        ),
    )
    sex: Optional[SexT] = Field(
        default=None,
        description="If specified, then only synthetic people of the specified sex will be sampled.",
    )
    city: Optional[Union[str, list[str]]] = Field(
        default=None,
        description="If specified, then only synthetic people from these cities will be sampled.",
    )
    age_range: list[int] = Field(
        default=DEFAULT_AGE_RANGE,
        description="If specified, then only synthetic people within this age range will be sampled.",
        min_length=2,
        max_length=2,
    )

    @property
    def generator_kwargs(self) -> list[str]:
        """Keyword arguments to pass to the person generator."""
        return [name for name in PersonFromFakerSamplerParams.model_fields if name != "locale"]

    @property
    def people_gen_key(self) -> str:
        """Cache key identifying the Faker-backed generator for this locale."""
        return f"{self.locale}_faker"

    @field_validator("age_range")
    @classmethod
    def _validate_age_range(cls, value: list[int]) -> list[int]:
        """Ensure [min_age, max_age] lies within allowed bounds and is strictly increasing."""
        min_age, max_age = value[0], value[1]
        msg_prefix = "'age_range' must be a list of two integers, representing the min and max age."
        if min_age < MIN_AGE:
            raise ValueError(
                f"{msg_prefix} The first integer (min age) must be greater than or equal to {MIN_AGE}, "
                f"but the first integer provided was {min_age}."
            )
        if max_age > MAX_AGE:
            raise ValueError(
                f"{msg_prefix} The second integer (max age) must be less than or equal to {MAX_AGE}, "
                f"but the second integer provided was {max_age}."
            )
        if not min_age < max_age:
            raise ValueError(
                f"{msg_prefix} The first integer (min age) must be less than the second integer (max age), "
                f"but the first integer provided was {min_age} and the second integer provided was {max_age}."
            )
        return value

    @field_validator("locale")
    @classmethod
    def _validate_locale(cls, value: str) -> str:
        """Accept only locales Faker can generate people for."""
        if value in AVAILABLE_LOCALES:
            return value
        raise ValueError(
            f"Locale {value!r} is not a supported locale. Supported locales: {', '.join(AVAILABLE_LOCALES)}"
        )
575
+
576
+
577
# Union of every sampler-parameter model defined for the samplers in this module;
# used wherever a sampler configuration accepts any parameter type.
SamplerParamsT: TypeAlias = Union[
    SubcategorySamplerParams,
    CategorySamplerParams,
    DatetimeSamplerParams,
    PersonSamplerParams,
    PersonFromFakerSamplerParams,
    TimeDeltaSamplerParams,
    UUIDSamplerParams,
    BernoulliSamplerParams,
    BernoulliMixtureSamplerParams,
    BinomialSamplerParams,
    GaussianSamplerParams,
    PoissonSamplerParams,
    UniformSamplerParams,
    ScipySamplerParams,
]
593
+
594
+
595
def is_numerical_sampler_type(sampler_type: SamplerType) -> bool:
    """Return True if `sampler_type` is one of the numerical-distribution samplers."""
    numerical_types = {
        SamplerType.BERNOULLI,
        SamplerType.BERNOULLI_MIXTURE,
        SamplerType.BINOMIAL,
        SamplerType.GAUSSIAN,
        SamplerType.POISSON,
        SamplerType.SCIPY,
        SamplerType.UNIFORM,
    }
    # Coerce through SamplerType so plain string values are also accepted.
    return SamplerType(sampler_type) in numerical_types