data-designer 0.3.8rc1__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. data_designer/cli/commands/__init__.py +1 -1
  2. data_designer/interface/__init__.py +21 -1
  3. data_designer/{_version.py → interface/_version.py} +2 -2
  4. data_designer/interface/data_designer.py +8 -11
  5. {data_designer-0.3.8rc1.dist-info → data_designer-0.4.0.dist-info}/METADATA +10 -42
  6. data_designer-0.4.0.dist-info/RECORD +39 -0
  7. data_designer/__init__.py +0 -17
  8. data_designer/config/__init__.py +0 -2
  9. data_designer/config/analysis/__init__.py +0 -2
  10. data_designer/config/analysis/column_profilers.py +0 -159
  11. data_designer/config/analysis/column_statistics.py +0 -421
  12. data_designer/config/analysis/dataset_profiler.py +0 -84
  13. data_designer/config/analysis/utils/errors.py +0 -10
  14. data_designer/config/analysis/utils/reporting.py +0 -192
  15. data_designer/config/base.py +0 -69
  16. data_designer/config/column_configs.py +0 -470
  17. data_designer/config/column_types.py +0 -141
  18. data_designer/config/config_builder.py +0 -595
  19. data_designer/config/data_designer_config.py +0 -40
  20. data_designer/config/dataset_builders.py +0 -13
  21. data_designer/config/dataset_metadata.py +0 -18
  22. data_designer/config/default_model_settings.py +0 -121
  23. data_designer/config/errors.py +0 -24
  24. data_designer/config/exports.py +0 -145
  25. data_designer/config/interface.py +0 -55
  26. data_designer/config/models.py +0 -455
  27. data_designer/config/preview_results.py +0 -41
  28. data_designer/config/processors.py +0 -148
  29. data_designer/config/run_config.py +0 -48
  30. data_designer/config/sampler_constraints.py +0 -52
  31. data_designer/config/sampler_params.py +0 -639
  32. data_designer/config/seed.py +0 -116
  33. data_designer/config/seed_source.py +0 -84
  34. data_designer/config/seed_source_types.py +0 -19
  35. data_designer/config/utils/code_lang.py +0 -82
  36. data_designer/config/utils/constants.py +0 -363
  37. data_designer/config/utils/errors.py +0 -21
  38. data_designer/config/utils/info.py +0 -94
  39. data_designer/config/utils/io_helpers.py +0 -258
  40. data_designer/config/utils/misc.py +0 -78
  41. data_designer/config/utils/numerical_helpers.py +0 -30
  42. data_designer/config/utils/type_helpers.py +0 -106
  43. data_designer/config/utils/visualization.py +0 -482
  44. data_designer/config/validator_params.py +0 -94
  45. data_designer/engine/__init__.py +0 -2
  46. data_designer/engine/analysis/column_profilers/base.py +0 -49
  47. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +0 -153
  48. data_designer/engine/analysis/column_profilers/registry.py +0 -22
  49. data_designer/engine/analysis/column_statistics.py +0 -145
  50. data_designer/engine/analysis/dataset_profiler.py +0 -149
  51. data_designer/engine/analysis/errors.py +0 -9
  52. data_designer/engine/analysis/utils/column_statistics_calculations.py +0 -234
  53. data_designer/engine/analysis/utils/judge_score_processing.py +0 -132
  54. data_designer/engine/column_generators/__init__.py +0 -2
  55. data_designer/engine/column_generators/generators/__init__.py +0 -2
  56. data_designer/engine/column_generators/generators/base.py +0 -122
  57. data_designer/engine/column_generators/generators/embedding.py +0 -35
  58. data_designer/engine/column_generators/generators/expression.py +0 -55
  59. data_designer/engine/column_generators/generators/llm_completion.py +0 -113
  60. data_designer/engine/column_generators/generators/samplers.py +0 -69
  61. data_designer/engine/column_generators/generators/seed_dataset.py +0 -144
  62. data_designer/engine/column_generators/generators/validation.py +0 -140
  63. data_designer/engine/column_generators/registry.py +0 -60
  64. data_designer/engine/column_generators/utils/errors.py +0 -15
  65. data_designer/engine/column_generators/utils/generator_classification.py +0 -43
  66. data_designer/engine/column_generators/utils/judge_score_factory.py +0 -58
  67. data_designer/engine/column_generators/utils/prompt_renderer.py +0 -100
  68. data_designer/engine/compiler.py +0 -97
  69. data_designer/engine/configurable_task.py +0 -71
  70. data_designer/engine/dataset_builders/artifact_storage.py +0 -283
  71. data_designer/engine/dataset_builders/column_wise_builder.py +0 -338
  72. data_designer/engine/dataset_builders/errors.py +0 -15
  73. data_designer/engine/dataset_builders/multi_column_configs.py +0 -46
  74. data_designer/engine/dataset_builders/utils/__init__.py +0 -2
  75. data_designer/engine/dataset_builders/utils/concurrency.py +0 -215
  76. data_designer/engine/dataset_builders/utils/config_compiler.py +0 -62
  77. data_designer/engine/dataset_builders/utils/dag.py +0 -62
  78. data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +0 -200
  79. data_designer/engine/dataset_builders/utils/errors.py +0 -15
  80. data_designer/engine/errors.py +0 -51
  81. data_designer/engine/model_provider.py +0 -77
  82. data_designer/engine/models/__init__.py +0 -2
  83. data_designer/engine/models/errors.py +0 -300
  84. data_designer/engine/models/facade.py +0 -287
  85. data_designer/engine/models/factory.py +0 -42
  86. data_designer/engine/models/litellm_overrides.py +0 -179
  87. data_designer/engine/models/parsers/__init__.py +0 -2
  88. data_designer/engine/models/parsers/errors.py +0 -34
  89. data_designer/engine/models/parsers/parser.py +0 -235
  90. data_designer/engine/models/parsers/postprocessors.py +0 -93
  91. data_designer/engine/models/parsers/tag_parsers.py +0 -62
  92. data_designer/engine/models/parsers/types.py +0 -84
  93. data_designer/engine/models/recipes/base.py +0 -81
  94. data_designer/engine/models/recipes/response_recipes.py +0 -293
  95. data_designer/engine/models/registry.py +0 -146
  96. data_designer/engine/models/telemetry.py +0 -359
  97. data_designer/engine/models/usage.py +0 -73
  98. data_designer/engine/models/utils.py +0 -38
  99. data_designer/engine/processing/ginja/__init__.py +0 -2
  100. data_designer/engine/processing/ginja/ast.py +0 -65
  101. data_designer/engine/processing/ginja/environment.py +0 -463
  102. data_designer/engine/processing/ginja/exceptions.py +0 -56
  103. data_designer/engine/processing/ginja/record.py +0 -32
  104. data_designer/engine/processing/gsonschema/__init__.py +0 -2
  105. data_designer/engine/processing/gsonschema/exceptions.py +0 -15
  106. data_designer/engine/processing/gsonschema/schema_transformers.py +0 -83
  107. data_designer/engine/processing/gsonschema/types.py +0 -10
  108. data_designer/engine/processing/gsonschema/validators.py +0 -202
  109. data_designer/engine/processing/processors/base.py +0 -13
  110. data_designer/engine/processing/processors/drop_columns.py +0 -42
  111. data_designer/engine/processing/processors/registry.py +0 -25
  112. data_designer/engine/processing/processors/schema_transform.py +0 -49
  113. data_designer/engine/processing/utils.py +0 -169
  114. data_designer/engine/registry/base.py +0 -99
  115. data_designer/engine/registry/data_designer_registry.py +0 -39
  116. data_designer/engine/registry/errors.py +0 -12
  117. data_designer/engine/resources/managed_dataset_generator.py +0 -39
  118. data_designer/engine/resources/managed_dataset_repository.py +0 -197
  119. data_designer/engine/resources/managed_storage.py +0 -65
  120. data_designer/engine/resources/resource_provider.py +0 -77
  121. data_designer/engine/resources/seed_reader.py +0 -154
  122. data_designer/engine/sampling_gen/column.py +0 -91
  123. data_designer/engine/sampling_gen/constraints.py +0 -100
  124. data_designer/engine/sampling_gen/data_sources/base.py +0 -217
  125. data_designer/engine/sampling_gen/data_sources/errors.py +0 -12
  126. data_designer/engine/sampling_gen/data_sources/sources.py +0 -347
  127. data_designer/engine/sampling_gen/entities/__init__.py +0 -2
  128. data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
  129. data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +0 -86
  130. data_designer/engine/sampling_gen/entities/email_address_utils.py +0 -171
  131. data_designer/engine/sampling_gen/entities/errors.py +0 -10
  132. data_designer/engine/sampling_gen/entities/national_id_utils.py +0 -102
  133. data_designer/engine/sampling_gen/entities/person.py +0 -144
  134. data_designer/engine/sampling_gen/entities/phone_number.py +0 -128
  135. data_designer/engine/sampling_gen/errors.py +0 -26
  136. data_designer/engine/sampling_gen/generator.py +0 -122
  137. data_designer/engine/sampling_gen/jinja_utils.py +0 -64
  138. data_designer/engine/sampling_gen/people_gen.py +0 -199
  139. data_designer/engine/sampling_gen/person_constants.py +0 -56
  140. data_designer/engine/sampling_gen/schema.py +0 -147
  141. data_designer/engine/sampling_gen/schema_builder.py +0 -61
  142. data_designer/engine/sampling_gen/utils.py +0 -46
  143. data_designer/engine/secret_resolver.py +0 -82
  144. data_designer/engine/validation.py +0 -367
  145. data_designer/engine/validators/__init__.py +0 -19
  146. data_designer/engine/validators/base.py +0 -38
  147. data_designer/engine/validators/local_callable.py +0 -39
  148. data_designer/engine/validators/python.py +0 -254
  149. data_designer/engine/validators/remote.py +0 -89
  150. data_designer/engine/validators/sql.py +0 -65
  151. data_designer/errors.py +0 -7
  152. data_designer/essentials/__init__.py +0 -33
  153. data_designer/lazy_heavy_imports.py +0 -54
  154. data_designer/logging.py +0 -163
  155. data_designer/plugin_manager.py +0 -78
  156. data_designer/plugins/__init__.py +0 -8
  157. data_designer/plugins/errors.py +0 -15
  158. data_designer/plugins/plugin.py +0 -141
  159. data_designer/plugins/registry.py +0 -88
  160. data_designer/plugins/testing/__init__.py +0 -10
  161. data_designer/plugins/testing/stubs.py +0 -116
  162. data_designer/plugins/testing/utils.py +0 -20
  163. data_designer-0.3.8rc1.dist-info/RECORD +0 -196
  164. data_designer-0.3.8rc1.dist-info/licenses/LICENSE +0 -201
  165. {data_designer-0.3.8rc1.dist-info → data_designer-0.4.0.dist-info}/WHEEL +0 -0
  166. {data_designer-0.3.8rc1.dist-info → data_designer-0.4.0.dist-info}/entry_points.txt +0 -0
@@ -1,52 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- from __future__ import annotations
5
-
6
- from abc import ABC, abstractmethod
7
- from enum import Enum
8
-
9
- from typing_extensions import TypeAlias
10
-
11
- from data_designer.config.base import ConfigBase
12
-
13
-
14
class ConstraintType(str, Enum):
    """Enumerates the supported kinds of column constraints."""

    SCALAR_INEQUALITY = "scalar_inequality"
    COLUMN_INEQUALITY = "column_inequality"
17
-
18
-
19
class InequalityOperator(str, Enum):
    """Comparison operators available for inequality constraints."""

    LT = "lt"  # less than
    LE = "le"  # less than or equal
    GT = "gt"  # greater than
    GE = "ge"  # greater than or equal
24
-
25
-
26
class Constraint(ConfigBase, ABC):
    """Abstract base class for column constraints.

    Attributes:
        target_column: Name of the column the constraint is applied to.
    """

    # Column this constraint targets.
    target_column: str

    @property
    @abstractmethod
    def constraint_type(self) -> ConstraintType: ...
32
-
33
-
34
class ScalarInequalityConstraint(Constraint):
    """Constraint comparing the target column against a fixed scalar value.

    Attributes:
        rhs: Scalar right-hand side of the inequality.
        operator: Comparison operator relating the target column to `rhs`.
    """

    rhs: float
    operator: InequalityOperator

    @property
    def constraint_type(self) -> ConstraintType:
        return ConstraintType.SCALAR_INEQUALITY
41
-
42
-
43
class ColumnInequalityConstraint(Constraint):
    """Constraint comparing the target column against another column.

    Attributes:
        rhs: Right-hand side of the inequality; a string here (unlike the
            scalar variant), naming the other column per the class name.
        operator: Comparison operator relating the target column to `rhs`.
    """

    rhs: str
    operator: InequalityOperator

    @property
    def constraint_type(self) -> ConstraintType:
        return ConstraintType.COLUMN_INEQUALITY
50
-
51
-
52
# Union of the concrete constraint types.
ColumnConstraintT: TypeAlias = ScalarInequalityConstraint | ColumnInequalityConstraint
@@ -1,639 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- from __future__ import annotations
5
-
6
- from enum import Enum
7
- from typing import TYPE_CHECKING, Literal
8
-
9
- from pydantic import Field, field_validator, model_validator
10
- from typing_extensions import Self, TypeAlias
11
-
12
- from data_designer.config.base import ConfigBase
13
- from data_designer.config.utils.constants import (
14
- AVAILABLE_LOCALES,
15
- DEFAULT_AGE_RANGE,
16
- LOCALES_WITH_MANAGED_DATASETS,
17
- MAX_AGE,
18
- MIN_AGE,
19
- )
20
- from data_designer.lazy_heavy_imports import pd
21
-
22
- if TYPE_CHECKING:
23
- import pandas as pd
24
-
25
-
26
class SamplerType(str, Enum):
    """Identifiers for the available sampler kinds.

    Each value doubles as the `sampler_type` discriminator on the
    corresponding `*SamplerParams` config class.
    """

    BERNOULLI = "bernoulli"
    BERNOULLI_MIXTURE = "bernoulli_mixture"
    BINOMIAL = "binomial"
    CATEGORY = "category"
    DATETIME = "datetime"
    GAUSSIAN = "gaussian"
    PERSON = "person"
    PERSON_FROM_FAKER = "person_from_faker"
    POISSON = "poisson"
    SCIPY = "scipy"
    SUBCATEGORY = "subcategory"
    TIMEDELTA = "timedelta"
    UNIFORM = "uniform"
    UUID = "uuid"
41
-
42
-
43
- #########################################
44
- # Sampler Parameters
45
- #########################################
46
-
47
-
48
class CategorySamplerParams(ConfigBase):
    """Parameters for categorical sampling with optional probability weighting.

    Samples values from a discrete set of categories. When weights are provided, values are
    sampled according to their assigned probabilities. Without weights, uniform sampling is used.

    Attributes:
        values: List of possible categorical values to sample from. Can contain strings, integers,
            or floats. Must contain at least one value.
        weights: Optional unnormalized probability weights for each value. If provided, must be
            the same length as `values` and sum to a positive total. Weights are automatically
            normalized to sum to 1.0. Larger weights result in higher sampling probability for
            the corresponding value.
    """

    values: list[str | int | float] = Field(
        ...,
        min_length=1,
        description="List of possible categorical values that can be sampled from.",
    )
    weights: list[float] | None = Field(
        default=None,
        description=(
            "List of unnormalized probability weights to assign to each value, in order. "
            "Larger values will be sampled with higher probability."
        ),
    )
    sampler_type: Literal[SamplerType.CATEGORY] = SamplerType.CATEGORY

    # NOTE: pydantic runs mode="after" model validators in definition order,
    # so lengths are checked before normalization is attempted.
    @model_validator(mode="after")
    def _validate_equal_lengths(self) -> Self:
        if self.weights and len(self.values) != len(self.weights):
            # Fix: the message previously referred to a nonexistent 'categories' field.
            raise ValueError("'values' and 'weights' must have the same length")
        return self

    @model_validator(mode="after")
    def _normalize_weights_if_needed(self) -> Self:
        # Truthiness check deliberately skips both None and [] (an empty list
        # was previously left untouched; preserve that behavior).
        if self.weights:
            # Hoist the total: it was recomputed once per element, and a zero
            # (or negative) total previously surfaced as a raw ZeroDivisionError
            # instead of a validation error.
            total = sum(self.weights)
            if total <= 0:
                raise ValueError("'weights' must sum to a positive value")
            self.weights = [w / total for w in self.weights]
        return self
87
-
88
-
89
class DatetimeSamplerParams(ConfigBase):
    """Parameters for uniform datetime sampling within a specified range.

    Samples datetime values uniformly between a start and end date with a specified granularity.
    The sampling unit determines the smallest possible time interval between consecutive samples.

    Attributes:
        start: Earliest possible datetime for the sampling range (inclusive). Must be a valid
            datetime string parseable by pandas.to_datetime().
        end: Latest possible datetime for the sampling range (inclusive). Must be a valid
            datetime string parseable by pandas.to_datetime().
        unit: Time unit for sampling granularity. Options:
            - "Y": Years
            - "M": Months
            - "D": Days (default)
            - "h": Hours
            - "m": Minutes
            - "s": Seconds
    """

    start: str = Field(..., description="Earliest possible datetime for sampling range, inclusive.")
    end: str = Field(..., description="Latest possible datetime for sampling range, inclusive.")
    unit: Literal["Y", "M", "D", "h", "m", "s"] = Field(
        default="D",
        description="Sampling units, e.g. the smallest possible time interval between samples.",
    )
    sampler_type: Literal[SamplerType.DATETIME] = SamplerType.DATETIME

    @field_validator("start", "end")
    @classmethod
    def _validate_param_is_datetime(cls, value: str) -> str:
        try:
            pd.to_datetime(value)
        except ValueError as err:
            # Fix: chain the underlying parse error so the original cause
            # is preserved in the traceback instead of being discarded.
            raise ValueError(f"Invalid datetime format: {value}") from err
        return value
125
-
126
-
127
class SubcategorySamplerParams(ConfigBase):
    """Parameters for subcategory sampling conditioned on a parent category column.

    Samples subcategory values based on the value of a parent category column. Each parent
    category value maps to its own list of possible subcategory values, enabling hierarchical
    or conditional sampling patterns.

    Attributes:
        category: Name of the parent category column that this subcategory depends on.
            The parent column must be generated before this subcategory column.
        values: Mapping from each parent category value to a list of possible subcategory values.
            Each key must correspond to a value that appears in the parent category column.
    """

    category: str = Field(..., description="Name of parent category to this subcategory.")
    values: dict[str, list[str | int | float]] = Field(
        ...,
        description="Mapping from each value of parent category to a list of subcategory values.",
    )
    # Discriminator selecting this params class in the sampler tagged union.
    sampler_type: Literal[SamplerType.SUBCATEGORY] = SamplerType.SUBCATEGORY
147
-
148
-
149
class TimeDeltaSamplerParams(ConfigBase):
    """Parameters for sampling time deltas relative to a reference datetime column.

    Samples time offsets within a specified range and adds them to values from a reference
    datetime column. This is useful for generating related datetime columns like order dates
    and delivery dates, or event start times and end times.

    Note:
        Years and months are not supported as timedelta units because they have variable lengths.
        See: [pandas timedelta documentation](https://pandas.pydata.org/docs/user_guide/timedeltas.html)

    Attributes:
        dt_min: Minimum time-delta value (inclusive). Must be non-negative and less than `dt_max`.
            Specified in units defined by the `unit` parameter.
        dt_max: Maximum time-delta value (exclusive). Must be positive and greater than `dt_min`.
            Specified in units defined by the `unit` parameter.
        reference_column_name: Name of an existing datetime column to add the time-delta to.
            This column must be generated before the timedelta column.
        unit: Time unit for the delta values. Options:
            - "D": Days (default)
            - "h": Hours
            - "m": Minutes
            - "s": Seconds
    """

    dt_min: int = Field(
        ...,
        ge=0,
        description=("Minimum possible time-delta for sampling range, inclusive. Must be less than `dt_max`."),
    )
    dt_max: int = Field(
        ...,
        gt=0,
        description=("Maximum possible time-delta for sampling range, exclusive. Must be greater than `dt_min`."),
    )

    reference_column_name: str = Field(
        ...,
        description="Name of an existing datetime column to condition time-delta sampling on.",
    )

    # NOTE: pandas does not support years or months as timedelta units
    # since they are ambiguous. We will need to update the implementation
    # if we need to support these units.
    # see: https://pandas.pydata.org/docs/user_guide/timedeltas.html.
    unit: Literal["D", "h", "m", "s"] = Field(
        default="D",
        description="Sampling units, e.g. the smallest possible time interval between samples.",
    )
    sampler_type: Literal[SamplerType.TIMEDELTA] = SamplerType.TIMEDELTA

    # Reject empty or inverted sampling ranges up front.
    @model_validator(mode="after")
    def _validate_min_less_than_max(self) -> Self:
        if self.dt_min >= self.dt_max:
            raise ValueError("'dt_min' must be less than 'dt_max'")
        return self
205
-
206
-
207
class UUIDSamplerParams(ConfigBase):
    """Parameters for generating UUID (Universally Unique Identifier) values.

    Generates UUID4 (random) identifiers with optional formatting options. UUIDs are useful
    for creating unique identifiers for records, entities, or transactions.

    Attributes:
        prefix: Optional string to prepend to each UUID. Useful for creating namespaced or
            typed identifiers (e.g., "user-", "order-", "txn-").
        short_form: If True, truncates UUIDs to 8 characters (first segment only). Default is False
            for full 32-character UUIDs (excluding hyphens).
        uppercase: If True, converts all hexadecimal letters to uppercase. Default is False for
            lowercase UUIDs.
    """

    prefix: str | None = Field(default=None, description="String prepended to the front of the UUID.")
    short_form: bool = Field(
        default=False,
        description="If true, all UUIDs sampled will be truncated at 8 characters.",
    )
    uppercase: bool = Field(
        default=False,
        description="If true, all letters in the UUID will be capitalized.",
    )
    sampler_type: Literal[SamplerType.UUID] = SamplerType.UUID

    @property
    def last_index(self) -> int:
        # Slice endpoint for the generated UUID string: 8 hex chars in short
        # form, otherwise 32. NOTE(review): 32 implies hyphens are stripped
        # before slicing (a hyphenated UUID is 36 chars) — confirm against the
        # generator implementation.
        return 8 if self.short_form else 32
236
-
237
-
238
- #########################################
239
- # Scipy Sampler Parameters
240
- #########################################
241
-
242
-
243
class ScipySamplerParams(ConfigBase):
    """Parameters for sampling from any scipy.stats continuous or discrete distribution.

    Provides a flexible interface to sample from the wide range of probability distributions
    available in scipy.stats. This enables advanced statistical sampling beyond the built-in
    distribution types (Gaussian, Uniform, etc.).

    See: [scipy.stats documentation](https://docs.scipy.org/doc/scipy/reference/stats.html)

    Attributes:
        dist_name: Name of the scipy.stats distribution to sample from (e.g., "beta", "gamma",
            "lognorm", "expon"). Must be a valid distribution name from scipy.stats.
        dist_params: Dictionary of parameters for the specified distribution. Parameter names
            and values must match the scipy.stats distribution specification (e.g., {"a": 2, "b": 5}
            for beta distribution, {"scale": 1.5} for exponential).
        decimal_places: Optional number of decimal places to round sampled values to. If None,
            values are not rounded.
    """

    # NOTE(review): dist_name is not validated against scipy.stats here, so an
    # unknown name will only fail at sampling time — confirm intended.
    dist_name: str = Field(..., description="Name of a scipy.stats distribution.")
    dist_params: dict = Field(
        ...,
        description="Parameters of the scipy.stats distribution given in `dist_name`.",
    )
    decimal_places: int | None = Field(
        default=None, description="Number of decimal places to round the sampled values to."
    )
    sampler_type: Literal[SamplerType.SCIPY] = SamplerType.SCIPY
271
-
272
-
273
class BinomialSamplerParams(ConfigBase):
    """Parameters for sampling from a Binomial distribution.

    Samples integer values representing the number of successes in a fixed number of independent
    Bernoulli trials, each with the same probability of success. Commonly used to model the number
    of successful outcomes in repeated experiments.

    Attributes:
        n: Number of independent trials. Must be a positive integer.
        p: Probability of success on each trial. Must be between 0.0 and 1.0 (inclusive).
    """

    # NOTE(review): docstring says n must be positive, but no gt=0 constraint
    # is declared on the field — confirm whether enforcement happens elsewhere.
    n: int = Field(..., description="Number of trials.")
    p: float = Field(..., description="Probability of success on each trial.", ge=0.0, le=1.0)
    sampler_type: Literal[SamplerType.BINOMIAL] = SamplerType.BINOMIAL
288
-
289
-
290
class BernoulliSamplerParams(ConfigBase):
    """Parameters for sampling from a Bernoulli distribution.

    Samples binary values (0 or 1) representing the outcome of a single trial with a fixed
    probability of success. This is the simplest discrete probability distribution, useful for
    modeling binary outcomes like success/failure, yes/no, or true/false.

    Attributes:
        p: Probability of success (sampling 1). Must be between 0.0 and 1.0 (inclusive).
            The probability of failure (sampling 0) is automatically 1 - p.
    """

    p: float = Field(..., description="Probability of success.", ge=0.0, le=1.0)
    # Discriminator selecting this params class in the sampler tagged union.
    sampler_type: Literal[SamplerType.BERNOULLI] = SamplerType.BERNOULLI
304
-
305
-
306
class BernoulliMixtureSamplerParams(ConfigBase):
    """Parameters for sampling from a Bernoulli mixture distribution.

    Combines a Bernoulli distribution with another continuous distribution, creating a mixture
    where values are either 0 (with probability 1-p) or sampled from the specified distribution
    (with probability p). This is useful for modeling scenarios with many zero values mixed with
    a continuous distribution of non-zero values.

    Common use cases include modeling sparse events, zero-inflated data, or situations where
    an outcome either doesn't occur (0) or follows a specific distribution when it does occur.

    Attributes:
        p: Probability of sampling from the mixture distribution (non-zero outcome).
            Must be between 0.0 and 1.0 (inclusive). With probability 1-p, the sample is 0.
        dist_name: Name of the scipy.stats distribution to sample from when outcome is non-zero.
            Must be a valid scipy.stats distribution name (e.g., "norm", "gamma", "expon").
        dist_params: Parameters for the specified scipy.stats distribution.
    """

    p: float = Field(
        ...,
        description="Bernoulli distribution probability of success.",
        ge=0.0,
        le=1.0,
    )
    dist_name: str = Field(
        ...,
        description=(
            "Mixture distribution name. Samples will be equal to the "
            "distribution sample with probability `p`, otherwise equal to 0. "
            "Must be a valid scipy.stats distribution name."
        ),
    )
    dist_params: dict = Field(
        ...,
        description="Parameters of the scipy.stats distribution given in `dist_name`.",
    )
    sampler_type: Literal[SamplerType.BERNOULLI_MIXTURE] = SamplerType.BERNOULLI_MIXTURE
344
-
345
-
346
class GaussianSamplerParams(ConfigBase):
    """Parameters for sampling from a Gaussian (Normal) distribution.

    Samples continuous values from a normal distribution characterized by its mean and standard
    deviation. The Gaussian distribution is one of the most commonly used probability distributions,
    appearing naturally in many real-world phenomena due to the Central Limit Theorem.

    Attributes:
        mean: Mean (center) of the Gaussian distribution. This is the expected value and the
            location of the distribution's peak.
        stddev: Standard deviation of the Gaussian distribution. Controls the spread or width
            of the distribution. Must be positive.
        decimal_places: Optional number of decimal places to round sampled values to. If None,
            values are not rounded.
    """

    mean: float = Field(..., description="Mean of the Gaussian distribution")
    # NOTE(review): docstring says stddev must be positive, but no gt=0
    # constraint is declared on the field — confirm whether that is enforced
    # downstream by the sampler.
    stddev: float = Field(..., description="Standard deviation of the Gaussian distribution")
    decimal_places: int | None = Field(
        default=None, description="Number of decimal places to round the sampled values to."
    )
    sampler_type: Literal[SamplerType.GAUSSIAN] = SamplerType.GAUSSIAN
368
-
369
-
370
class PoissonSamplerParams(ConfigBase):
    """Parameters for sampling from a Poisson distribution.

    Samples non-negative integer values representing the number of events occurring in a fixed
    interval of time or space. The Poisson distribution is commonly used to model count data
    like the number of arrivals, occurrences, or events per time period.

    The distribution is characterized by a single parameter (mean/rate), and both the mean and
    variance equal this parameter value.

    Attributes:
        mean: Mean number of events in the fixed interval (also called rate parameter λ).
            Must be positive. This represents both the expected value and the variance of the
            distribution.
    """

    # NOTE(review): docstring says mean must be positive, but no gt=0
    # constraint is declared on the field — confirm enforcement downstream.
    mean: float = Field(..., description="Mean number of events in a fixed interval.")
    sampler_type: Literal[SamplerType.POISSON] = SamplerType.POISSON
388
-
389
-
390
class UniformSamplerParams(ConfigBase):
    """Parameters for sampling from a continuous Uniform distribution.

    Samples continuous values uniformly from a specified range, where every value in the range
    has equal probability of being sampled. This is useful when all values within a range are
    equally likely, such as random percentages, proportions, or unbiased measurements.

    Attributes:
        low: Lower bound of the uniform distribution (inclusive). Can be any real number.
        high: Upper bound of the uniform distribution (inclusive). Must be greater than `low`.
        decimal_places: Optional number of decimal places to round sampled values to. If None,
            values are not rounded and may have many decimal places.
    """

    # NOTE(review): docstring requires high > low, but unlike TimeDeltaSamplerParams
    # no model validator enforces it here — confirm whether that is intentional.
    low: float = Field(..., description="Lower bound of the uniform distribution, inclusive.")
    high: float = Field(..., description="Upper bound of the uniform distribution, inclusive.")
    decimal_places: int | None = Field(
        default=None, description="Number of decimal places to round the sampled values to."
    )
    sampler_type: Literal[SamplerType.UNIFORM] = SamplerType.UNIFORM
410
-
411
-
412
- #########################################
413
- # Person Sampler Parameters
414
- #########################################
415
-
416
# Allowed values for the `sex` filter on person samplers.
SexT: TypeAlias = Literal["Male", "Female"]
417
-
418
-
419
class PersonSamplerParams(ConfigBase):
    """Parameters for sampling synthetic person data with demographic attributes.

    Generates realistic synthetic person data including names, addresses, phone numbers, and other
    demographic information. Data can be sampled from managed datasets (when available) or generated
    using Faker. The sampler supports filtering by locale, sex, age, geographic location, and can
    optionally include synthetic persona descriptions.

    Attributes:
        locale: Locale string determining the language and geographic region for synthetic people.
            Must be a locale supported by a managed Nemotron Personas dataset. The dataset must
            be downloaded and available in the managed assets directory.
        sex: If specified, filters to only sample people of the specified sex. Options: "Male" or
            "Female". If None, samples both sexes.
        city: If specified, filters to only sample people from the specified city or cities. Can be
            a single city name (string) or a list of city names.
        age_range: Two-element list [min_age, max_age] specifying the age range to sample from
            (inclusive). Defaults to a standard age range. Both values must be between minimum and
            maximum allowed ages.
        select_field_values: Optional mapping of dataset field name to the list of allowed values,
            restricting sampling to that subset of the managed dataset. Rare combinations of field
            values are not supported; prefer the `sex`, `city`, and `age_range` filters when possible.
        with_synthetic_personas: If True, appends additional synthetic persona columns including
            personality traits, interests, and background descriptions. Only supported for certain
            locales with managed datasets.
        sample_dataset_when_available: If True, samples from curated managed datasets when available
            for the specified locale. If False or unavailable, falls back to Faker-generated data.
            Managed datasets typically provide more realistic and diverse synthetic people.
        sampler_type: Discriminator for the sampler type. Must be `SamplerType.PERSON`.
    """

    locale: str = Field(
        default="en_US",
        description=(
            "Locale that determines the language and geographic location "
            "that a synthetic person will be sampled from. Must be a locale supported by "
            "a managed Nemotron Personas dataset. Managed datasets exist for the following locales: "
            f"{', '.join(LOCALES_WITH_MANAGED_DATASETS)}."
        ),
    )
    sex: SexT | None = Field(
        default=None,
        description="If specified, then only synthetic people of the specified sex will be sampled.",
    )
    city: str | list[str] | None = Field(
        default=None,
        description="If specified, then only synthetic people from these cities will be sampled.",
    )
    age_range: list[int] = Field(
        default=DEFAULT_AGE_RANGE,
        description="If specified, then only synthetic people within this age range will be sampled.",
        min_length=2,
        max_length=2,
    )
    select_field_values: dict[str, list[str]] | None = Field(
        default=None,
        description=(
            "Sample synthetic people with the specified field values. This is meant to be a flexible argument for "
            "selecting a subset of the population from the managed dataset. Note that this sampler does not support "
            "rare combinations of field values and will likely fail if your desired subset is not well-represented "
            "in the managed Nemotron Personas dataset. We generally recommend using the `sex`, `city`, and `age_range` "
            "arguments to filter the population when possible."
        ),
        examples=[
            {"state": ["NY", "CA", "OH", "TX", "NV"], "education_level": ["high_school", "some_college", "bachelors"]}
        ],
    )

    with_synthetic_personas: bool = Field(
        default=False,
        description="If True, then append synthetic persona columns to each generated person.",
    )
    sampler_type: Literal[SamplerType.PERSON] = SamplerType.PERSON

    @property
    def generator_kwargs(self) -> list[str]:
        """Keyword arguments to pass to the person generator."""
        # Every declared model field except `locale` and `sampler_type` is forwarded.
        return [f for f in list(PersonSamplerParams.model_fields) if f not in ("locale", "sampler_type")]

    @property
    def people_gen_key(self) -> str:
        """Key identifying the people generator for this locale, with a suffix when personas are enabled."""
        return f"{self.locale}_with_personas" if self.with_synthetic_personas else self.locale

    @field_validator("age_range")
    @classmethod
    def _validate_age_range(cls, value: list[int]) -> list[int]:
        """Validate that age_range is [min, max] with MIN_AGE <= min < max <= MAX_AGE."""
        # NOTE(review): min == max is rejected by the strict `>=` check below, even though the
        # class docstring describes the range as inclusive — confirm this is intended.
        msg_prefix = "'age_range' must be a list of two integers, representing the min and max age."
        if value[0] < MIN_AGE:
            raise ValueError(
                f"{msg_prefix} The first integer (min age) must be greater than or equal to {MIN_AGE}, "
                f"but the first integer provided was {value[0]}."
            )
        if value[1] > MAX_AGE:
            raise ValueError(
                f"{msg_prefix} The second integer (max age) must be less than or equal to {MAX_AGE}, "
                f"but the second integer provided was {value[1]}."
            )
        if value[0] >= value[1]:
            raise ValueError(
                f"{msg_prefix} The first integer (min age) must be less than the second integer (max age), "
                f"but the first integer provided was {value[0]} and the second integer provided was {value[1]}."
            )
        return value

    @model_validator(mode="after")
    def _validate_locale_with_managed_datasets(self) -> Self:
        """Reject locales for which no managed Nemotron Personas dataset exists."""
        if self.locale not in LOCALES_WITH_MANAGED_DATASETS:
            raise ValueError(
                "Person sampling from managed datasets is only supported for the following "
                f"locales: {', '.join(LOCALES_WITH_MANAGED_DATASETS)}."
            )
        return self
class PersonFromFakerSamplerParams(ConfigBase):
    """Parameters for sampling synthetic person data with demographic attributes from Faker.

    Uses the Faker library to generate random personal information. The data is basic and not demographically
    accurate, but is useful for quick testing, prototyping, or when realistic demographic distributions are not
    relevant for your use case. For demographically accurate person data, use the `PersonSamplerParams` sampler.

    Attributes:
        locale: Locale string determining the language and geographic region for synthetic people.
            Can be any locale supported by Faker.
        sex: If specified, filters to only sample people of the specified sex. Options: "Male" or
            "Female". If None, samples both sexes.
        city: If specified, filters to only sample people from the specified city or cities. Can be
            a single city name (string) or a list of city names.
        age_range: Two-element list [min_age, max_age] specifying the age range to sample from
            (inclusive). Defaults to a standard age range. Both values must be between the minimum and
            maximum allowed ages.
        sampler_type: Discriminator for the sampler type. Must be `SamplerType.PERSON_FROM_FAKER`.
    """

    locale: str = Field(
        default="en_US",
        description=(
            "Locale string, determines the language and geographic locale "
            "that a synthetic person will be sampled from. E.g, en_US, en_GB, fr_FR, ..."
        ),
    )
    sex: SexT | None = Field(
        default=None,
        description="If specified, then only synthetic people of the specified sex will be sampled.",
    )
    city: str | list[str] | None = Field(
        default=None,
        description="If specified, then only synthetic people from these cities will be sampled.",
    )
    age_range: list[int] = Field(
        default=DEFAULT_AGE_RANGE,
        description="If specified, then only synthetic people within this age range will be sampled.",
        min_length=2,
        max_length=2,
    )
    sampler_type: Literal[SamplerType.PERSON_FROM_FAKER] = SamplerType.PERSON_FROM_FAKER

    @property
    def generator_kwargs(self) -> list[str]:
        """Keyword arguments to pass to the person generator."""
        # Every declared model field except `locale` and `sampler_type` is forwarded.
        return [f for f in list(PersonFromFakerSamplerParams.model_fields) if f not in ("locale", "sampler_type")]

    @property
    def people_gen_key(self) -> str:
        """Key identifying the Faker-backed people generator for this locale."""
        return f"{self.locale}_faker"

    @field_validator("age_range")
    @classmethod
    def _validate_age_range(cls, value: list[int]) -> list[int]:
        """Validate that age_range is [min, max] with MIN_AGE <= min < max <= MAX_AGE."""
        # NOTE(review): min == max is rejected by the strict `>=` check below, even though the
        # class docstring describes the range as inclusive — confirm this is intended.
        msg_prefix = "'age_range' must be a list of two integers, representing the min and max age."
        if value[0] < MIN_AGE:
            raise ValueError(
                f"{msg_prefix} The first integer (min age) must be greater than or equal to {MIN_AGE}, "
                f"but the first integer provided was {value[0]}."
            )
        if value[1] > MAX_AGE:
            raise ValueError(
                f"{msg_prefix} The second integer (max age) must be less than or equal to {MAX_AGE}, "
                f"but the second integer provided was {value[1]}."
            )
        if value[0] >= value[1]:
            raise ValueError(
                f"{msg_prefix} The first integer (min age) must be less than the second integer (max age), "
                f"but the first integer provided was {value[0]} and the second integer provided was {value[1]}."
            )
        return value

    @field_validator("locale")
    @classmethod
    def _validate_locale(cls, value: str) -> str:
        """Reject locales that are not in the set of available Faker locales."""
        if value not in AVAILABLE_LOCALES:
            raise ValueError(
                f"Locale {value!r} is not a supported locale. Supported locales: {', '.join(AVAILABLE_LOCALES)}"
            )
        return value
# Union of all sampler parameter models. Each member declares a distinct
# `sampler_type` literal, so the union can be narrowed by that field
# (presumably used as a discriminated union when deserializing — confirm at the usage site).
SamplerParamsT: TypeAlias = (
    SubcategorySamplerParams
    | CategorySamplerParams
    | DatetimeSamplerParams
    | PersonSamplerParams
    | PersonFromFakerSamplerParams
    | TimeDeltaSamplerParams
    | UUIDSamplerParams
    | BernoulliSamplerParams
    | BernoulliMixtureSamplerParams
    | BinomialSamplerParams
    | GaussianSamplerParams
    | PoissonSamplerParams
    | UniformSamplerParams
    | ScipySamplerParams
)
def is_numerical_sampler_type(sampler_type: SamplerType) -> bool:
    """Return True if the given sampler type draws from a numerical distribution."""
    # Coerce first so raw enum values are accepted as well as enum members.
    numerical_types = (
        SamplerType.BERNOULLI,
        SamplerType.BERNOULLI_MIXTURE,
        SamplerType.BINOMIAL,
        SamplerType.GAUSSIAN,
        SamplerType.POISSON,
        SamplerType.SCIPY,
        SamplerType.UNIFORM,
    )
    return SamplerType(sampler_type) in numerical_types