data-designer 0.3.8rc1__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/cli/commands/__init__.py +1 -1
- data_designer/interface/__init__.py +21 -1
- data_designer/{_version.py → interface/_version.py} +2 -2
- data_designer/interface/data_designer.py +8 -11
- {data_designer-0.3.8rc1.dist-info → data_designer-0.4.0.dist-info}/METADATA +10 -42
- data_designer-0.4.0.dist-info/RECORD +39 -0
- data_designer/__init__.py +0 -17
- data_designer/config/__init__.py +0 -2
- data_designer/config/analysis/__init__.py +0 -2
- data_designer/config/analysis/column_profilers.py +0 -159
- data_designer/config/analysis/column_statistics.py +0 -421
- data_designer/config/analysis/dataset_profiler.py +0 -84
- data_designer/config/analysis/utils/errors.py +0 -10
- data_designer/config/analysis/utils/reporting.py +0 -192
- data_designer/config/base.py +0 -69
- data_designer/config/column_configs.py +0 -470
- data_designer/config/column_types.py +0 -141
- data_designer/config/config_builder.py +0 -595
- data_designer/config/data_designer_config.py +0 -40
- data_designer/config/dataset_builders.py +0 -13
- data_designer/config/dataset_metadata.py +0 -18
- data_designer/config/default_model_settings.py +0 -121
- data_designer/config/errors.py +0 -24
- data_designer/config/exports.py +0 -145
- data_designer/config/interface.py +0 -55
- data_designer/config/models.py +0 -455
- data_designer/config/preview_results.py +0 -41
- data_designer/config/processors.py +0 -148
- data_designer/config/run_config.py +0 -48
- data_designer/config/sampler_constraints.py +0 -52
- data_designer/config/sampler_params.py +0 -639
- data_designer/config/seed.py +0 -116
- data_designer/config/seed_source.py +0 -84
- data_designer/config/seed_source_types.py +0 -19
- data_designer/config/utils/code_lang.py +0 -82
- data_designer/config/utils/constants.py +0 -363
- data_designer/config/utils/errors.py +0 -21
- data_designer/config/utils/info.py +0 -94
- data_designer/config/utils/io_helpers.py +0 -258
- data_designer/config/utils/misc.py +0 -78
- data_designer/config/utils/numerical_helpers.py +0 -30
- data_designer/config/utils/type_helpers.py +0 -106
- data_designer/config/utils/visualization.py +0 -482
- data_designer/config/validator_params.py +0 -94
- data_designer/engine/__init__.py +0 -2
- data_designer/engine/analysis/column_profilers/base.py +0 -49
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +0 -153
- data_designer/engine/analysis/column_profilers/registry.py +0 -22
- data_designer/engine/analysis/column_statistics.py +0 -145
- data_designer/engine/analysis/dataset_profiler.py +0 -149
- data_designer/engine/analysis/errors.py +0 -9
- data_designer/engine/analysis/utils/column_statistics_calculations.py +0 -234
- data_designer/engine/analysis/utils/judge_score_processing.py +0 -132
- data_designer/engine/column_generators/__init__.py +0 -2
- data_designer/engine/column_generators/generators/__init__.py +0 -2
- data_designer/engine/column_generators/generators/base.py +0 -122
- data_designer/engine/column_generators/generators/embedding.py +0 -35
- data_designer/engine/column_generators/generators/expression.py +0 -55
- data_designer/engine/column_generators/generators/llm_completion.py +0 -113
- data_designer/engine/column_generators/generators/samplers.py +0 -69
- data_designer/engine/column_generators/generators/seed_dataset.py +0 -144
- data_designer/engine/column_generators/generators/validation.py +0 -140
- data_designer/engine/column_generators/registry.py +0 -60
- data_designer/engine/column_generators/utils/errors.py +0 -15
- data_designer/engine/column_generators/utils/generator_classification.py +0 -43
- data_designer/engine/column_generators/utils/judge_score_factory.py +0 -58
- data_designer/engine/column_generators/utils/prompt_renderer.py +0 -100
- data_designer/engine/compiler.py +0 -97
- data_designer/engine/configurable_task.py +0 -71
- data_designer/engine/dataset_builders/artifact_storage.py +0 -283
- data_designer/engine/dataset_builders/column_wise_builder.py +0 -338
- data_designer/engine/dataset_builders/errors.py +0 -15
- data_designer/engine/dataset_builders/multi_column_configs.py +0 -46
- data_designer/engine/dataset_builders/utils/__init__.py +0 -2
- data_designer/engine/dataset_builders/utils/concurrency.py +0 -215
- data_designer/engine/dataset_builders/utils/config_compiler.py +0 -62
- data_designer/engine/dataset_builders/utils/dag.py +0 -62
- data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +0 -200
- data_designer/engine/dataset_builders/utils/errors.py +0 -15
- data_designer/engine/errors.py +0 -51
- data_designer/engine/model_provider.py +0 -77
- data_designer/engine/models/__init__.py +0 -2
- data_designer/engine/models/errors.py +0 -300
- data_designer/engine/models/facade.py +0 -287
- data_designer/engine/models/factory.py +0 -42
- data_designer/engine/models/litellm_overrides.py +0 -179
- data_designer/engine/models/parsers/__init__.py +0 -2
- data_designer/engine/models/parsers/errors.py +0 -34
- data_designer/engine/models/parsers/parser.py +0 -235
- data_designer/engine/models/parsers/postprocessors.py +0 -93
- data_designer/engine/models/parsers/tag_parsers.py +0 -62
- data_designer/engine/models/parsers/types.py +0 -84
- data_designer/engine/models/recipes/base.py +0 -81
- data_designer/engine/models/recipes/response_recipes.py +0 -293
- data_designer/engine/models/registry.py +0 -146
- data_designer/engine/models/telemetry.py +0 -359
- data_designer/engine/models/usage.py +0 -73
- data_designer/engine/models/utils.py +0 -38
- data_designer/engine/processing/ginja/__init__.py +0 -2
- data_designer/engine/processing/ginja/ast.py +0 -65
- data_designer/engine/processing/ginja/environment.py +0 -463
- data_designer/engine/processing/ginja/exceptions.py +0 -56
- data_designer/engine/processing/ginja/record.py +0 -32
- data_designer/engine/processing/gsonschema/__init__.py +0 -2
- data_designer/engine/processing/gsonschema/exceptions.py +0 -15
- data_designer/engine/processing/gsonschema/schema_transformers.py +0 -83
- data_designer/engine/processing/gsonschema/types.py +0 -10
- data_designer/engine/processing/gsonschema/validators.py +0 -202
- data_designer/engine/processing/processors/base.py +0 -13
- data_designer/engine/processing/processors/drop_columns.py +0 -42
- data_designer/engine/processing/processors/registry.py +0 -25
- data_designer/engine/processing/processors/schema_transform.py +0 -49
- data_designer/engine/processing/utils.py +0 -169
- data_designer/engine/registry/base.py +0 -99
- data_designer/engine/registry/data_designer_registry.py +0 -39
- data_designer/engine/registry/errors.py +0 -12
- data_designer/engine/resources/managed_dataset_generator.py +0 -39
- data_designer/engine/resources/managed_dataset_repository.py +0 -197
- data_designer/engine/resources/managed_storage.py +0 -65
- data_designer/engine/resources/resource_provider.py +0 -77
- data_designer/engine/resources/seed_reader.py +0 -154
- data_designer/engine/sampling_gen/column.py +0 -91
- data_designer/engine/sampling_gen/constraints.py +0 -100
- data_designer/engine/sampling_gen/data_sources/base.py +0 -217
- data_designer/engine/sampling_gen/data_sources/errors.py +0 -12
- data_designer/engine/sampling_gen/data_sources/sources.py +0 -347
- data_designer/engine/sampling_gen/entities/__init__.py +0 -2
- data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
- data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +0 -86
- data_designer/engine/sampling_gen/entities/email_address_utils.py +0 -171
- data_designer/engine/sampling_gen/entities/errors.py +0 -10
- data_designer/engine/sampling_gen/entities/national_id_utils.py +0 -102
- data_designer/engine/sampling_gen/entities/person.py +0 -144
- data_designer/engine/sampling_gen/entities/phone_number.py +0 -128
- data_designer/engine/sampling_gen/errors.py +0 -26
- data_designer/engine/sampling_gen/generator.py +0 -122
- data_designer/engine/sampling_gen/jinja_utils.py +0 -64
- data_designer/engine/sampling_gen/people_gen.py +0 -199
- data_designer/engine/sampling_gen/person_constants.py +0 -56
- data_designer/engine/sampling_gen/schema.py +0 -147
- data_designer/engine/sampling_gen/schema_builder.py +0 -61
- data_designer/engine/sampling_gen/utils.py +0 -46
- data_designer/engine/secret_resolver.py +0 -82
- data_designer/engine/validation.py +0 -367
- data_designer/engine/validators/__init__.py +0 -19
- data_designer/engine/validators/base.py +0 -38
- data_designer/engine/validators/local_callable.py +0 -39
- data_designer/engine/validators/python.py +0 -254
- data_designer/engine/validators/remote.py +0 -89
- data_designer/engine/validators/sql.py +0 -65
- data_designer/errors.py +0 -7
- data_designer/essentials/__init__.py +0 -33
- data_designer/lazy_heavy_imports.py +0 -54
- data_designer/logging.py +0 -163
- data_designer/plugin_manager.py +0 -78
- data_designer/plugins/__init__.py +0 -8
- data_designer/plugins/errors.py +0 -15
- data_designer/plugins/plugin.py +0 -141
- data_designer/plugins/registry.py +0 -88
- data_designer/plugins/testing/__init__.py +0 -10
- data_designer/plugins/testing/stubs.py +0 -116
- data_designer/plugins/testing/utils.py +0 -20
- data_designer-0.3.8rc1.dist-info/RECORD +0 -196
- data_designer-0.3.8rc1.dist-info/licenses/LICENSE +0 -201
- {data_designer-0.3.8rc1.dist-info → data_designer-0.4.0.dist-info}/WHEEL +0 -0
- {data_designer-0.3.8rc1.dist-info → data_designer-0.4.0.dist-info}/entry_points.txt +0 -0
|
@@ -1,100 +0,0 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
-
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
-
|
|
4
|
-
from __future__ import annotations
|
|
5
|
-
|
|
6
|
-
from abc import ABC, abstractmethod
|
|
7
|
-
from typing import TYPE_CHECKING
|
|
8
|
-
|
|
9
|
-
from numpy.typing import NDArray
|
|
10
|
-
|
|
11
|
-
from data_designer.config.base import ConfigBase
|
|
12
|
-
from data_designer.config.sampler_constraints import (
|
|
13
|
-
ColumnInequalityConstraint,
|
|
14
|
-
Constraint,
|
|
15
|
-
ConstraintType,
|
|
16
|
-
InequalityOperator,
|
|
17
|
-
ScalarInequalityConstraint,
|
|
18
|
-
)
|
|
19
|
-
from data_designer.lazy_heavy_imports import np, pd
|
|
20
|
-
|
|
21
|
-
if TYPE_CHECKING:
|
|
22
|
-
import numpy as np
|
|
23
|
-
import pandas as pd
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
class ConstraintChecker(ConfigBase, ABC):
    """Base class for objects that evaluate a sampling constraint row-wise.

    A checker holds a `Constraint` config and implements `check`, which
    returns a boolean mask marking the rows that satisfy the constraint.
    """

    # The constraint configuration this checker evaluates.
    constraint: Constraint

    def get_required_column_names(self) -> tuple[str, ...]:
        """Return the column names needed to evaluate the constraint."""
        return (self.constraint.target_column,)

    @abstractmethod
    def check(self, dataframe: pd.DataFrame) -> NDArray[np.bool_]:
        """Return a boolean mask of rows satisfying the constraint."""
        ...
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
class WithCompareMixin:
    """Mixin providing numpy-based inequality comparison for checkers."""

    @property
    def lhs(self) -> str:
        """Left-hand side of the comparison: the constrained column's name."""
        return self.constraint.target_column

    def compare(self, lhs: float | int | NDArray, rhs: float | int | NDArray) -> bool | NDArray[np.bool_]:
        """Apply the configured inequality operator element-wise to lhs and rhs."""
        # Map each operator enum member to the matching numpy ufunc, then
        # dispatch on the (normalized) operator of this checker's constraint.
        ufunc_by_operator = {
            InequalityOperator.LT: np.less,
            InequalityOperator.LE: np.less_equal,
            InequalityOperator.GT: np.greater,
            InequalityOperator.GE: np.greater_equal,
        }
        ufunc = ufunc_by_operator[InequalityOperator(self.constraint.operator)]
        return ufunc(lhs, rhs)
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
class ScalarInequalityChecker(ConstraintChecker, WithCompareMixin):
    """Check that a column satisfies an inequality against a fixed scalar.

    The constrained column (``constraint.target_column``) is used as the
    left-hand side of the comparison; ``constraint.rhs`` is the scalar
    right-hand side.
    """

    constraint: ScalarInequalityConstraint

    def check(self, dataframe: pd.DataFrame) -> NDArray[np.bool_]:
        """Return a boolean mask of rows where the inequality holds."""
        column_values = dataframe[self.lhs].values
        return self.compare(column_values, self.constraint.rhs)
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
class ColumnInequalityChecker(ConstraintChecker, WithCompareMixin):
    """Check an inequality between the values of two columns.

    The constrained column (``constraint.target_column``) is used as the
    left-hand side of the comparison; ``constraint.rhs`` names the column
    used as the right-hand side.
    """

    constraint: ColumnInequalityConstraint

    def get_required_column_names(self) -> tuple[str, ...]:
        """Return the names of columns required for the constraint.

        Note that order matters. Edges in the DAG are created as
        column_names[1], column_names[0].
        """
        return (self.lhs, self.constraint.rhs)

    def check(self, dataframe: pd.DataFrame) -> NDArray[np.bool_]:
        """Return a boolean mask of rows where the inequality holds."""
        lhs_values = dataframe[self.lhs].values
        rhs_values = dataframe[self.constraint.rhs].values
        return self.compare(lhs_values, rhs_values)
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
# Maps each supported constraint type to the checker class that evaluates it.
CONSTRAINT_TYPE_TO_CHECKER = {
    ConstraintType.SCALAR_INEQUALITY: ScalarInequalityChecker,
    ConstraintType.COLUMN_INEQUALITY: ColumnInequalityChecker,
}
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
def get_constraint_checker(constraint_type: ConstraintType) -> type[ConstraintChecker]:
    """Look up the checker class that evaluates the given constraint type.

    Accepts either a `ConstraintType` member or its raw value; raises
    ValueError (from the enum constructor) for unknown types.
    """
    normalized_type = ConstraintType(constraint_type)
    return CONSTRAINT_TYPE_TO_CHECKER[normalized_type]
|
|
@@ -1,217 +0,0 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
-
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
-
|
|
4
|
-
from __future__ import annotations
|
|
5
|
-
|
|
6
|
-
from abc import ABC, abstractmethod
|
|
7
|
-
from typing import TYPE_CHECKING, Any, Generic, TypeVar
|
|
8
|
-
|
|
9
|
-
from numpy.typing import NDArray
|
|
10
|
-
|
|
11
|
-
from data_designer.config.sampler_params import SamplerParamsT
|
|
12
|
-
from data_designer.engine.sampling_gen.utils import check_random_state
|
|
13
|
-
from data_designer.lazy_heavy_imports import np, pd, scipy
|
|
14
|
-
|
|
15
|
-
if TYPE_CHECKING:
|
|
16
|
-
import numpy as np
|
|
17
|
-
import pandas as pd
|
|
18
|
-
import scipy
|
|
19
|
-
|
|
20
|
-
# 1-D numpy array of arbitrary dtype, used for sampler outputs.
NumpyArray1dT = NDArray[Any]
# Seed or fully-constructed numpy RandomState accepted by data sources.
# NOTE(review): name contains a typo ("Radom"); kept because it is a
# module-level public name other code may import.
RadomStateT = int | np.random.RandomState

# Type variable bound to the sampler params union; parameterizes DataSource.
GenericParamsT = TypeVar("GenericParamsT", bound=SamplerParamsT)
|
|
24
|
-
|
|
25
|
-
###########################################################
|
|
26
|
-
# Processing Mixins
|
|
27
|
-
# -----------------
|
|
28
|
-
# These mixins are used to apply pre and post processing
|
|
29
|
-
# to the data source output. At the moment, the only
|
|
30
|
-
# processing that is applied is an optional type/format
|
|
31
|
-
# conversion of the output data.
|
|
32
|
-
#
|
|
33
|
-
# Preprocessing: Applied *before* constraints are applied.
|
|
34
|
-
# Postprocessing: Applied at the end of dataset generation.
|
|
35
|
-
#
|
|
36
|
-
# IMPORTANT: These are only applied when the data are
|
|
37
|
-
# being injected into a DataFrame by the DatasetGenerator.
|
|
38
|
-
###########################################################
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
class PassthroughMixin:
    """No-op pre/post processing: data passes through unchanged."""

    @staticmethod
    def preproc(series: pd.Series, convert_to: str) -> pd.Series:
        """Return the series unchanged (no preprocessing)."""
        return series

    @staticmethod
    def postproc(series: pd.Series, convert_to: str) -> pd.Series:
        """Return the series unchanged (no postprocessing)."""
        return series

    @staticmethod
    def validate_data_conversion(convert_to: str | None) -> None:
        """Accept any conversion value; nothing to validate."""
        return None
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
class TypeConversionMixin:
    """Converts the data type of the output data.

    This mixin applies the same conversion to both the pre and post
    processing steps. The preprocessing is needed to ensure constraints
    are applied to the correct data type. The postprocessing is needed
    to ensure the final dtype is correct. For example, if the user wants an
    `int`, we need to convert to `int` before applying constraints, but
    the ints will be converted back to floats when injected into the
    dataframe (assuming some rows are non-int values). We therefore need
    to convert back to `int` after all constraints have been applied.
    """

    @staticmethod
    def _convert(series: pd.Series, convert_to: str | None) -> pd.Series:
        # Shared conversion used by both pre- and post-processing (the two
        # original bodies were identical). Rounding before an int cast avoids
        # truncating float samples.
        if convert_to is None:
            return series
        if convert_to == "int":
            series = series.round()
        return series.astype(convert_to)

    @staticmethod
    def preproc(series: pd.Series, convert_to: str | None) -> pd.Series:
        """Convert the series dtype before constraints are applied.

        Fixed: annotation was `convert_to: str` although None is handled.
        """
        return TypeConversionMixin._convert(series, convert_to)

    @staticmethod
    def postproc(series: pd.Series, convert_to: str | None) -> pd.Series:
        """Convert the series dtype at the end of dataset generation."""
        return TypeConversionMixin._convert(series, convert_to)

    @staticmethod
    def validate_data_conversion(convert_to: str | None) -> None:
        """Raise ValueError unless `convert_to` is None, "float", "int" or "str"."""
        if convert_to is not None and convert_to not in ["float", "int", "str"]:
            raise ValueError(f"Invalid `convert_to` value: {convert_to}. Must be one of: [float, int, str]")
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
class DatetimeFormatMixin:
    """Formats datetime output as strings, inferring a format when none is given."""

    @staticmethod
    def preproc(series: pd.Series, convert_to: str | None) -> pd.Series:
        """No conversion before constraints; formatting happens only at the end."""
        return series

    @staticmethod
    def postproc(series: pd.Series, convert_to: str | None) -> pd.Series:
        """Render datetimes as strings using `convert_to`, else an inferred format.

        Inference heuristics, applied in order over the whole series:
        - a single distinct month -> year only
        - a single distinct day -> "%Y-%m"
        - any nonzero hour or minute -> full ISO format
        - all-zero seconds -> date only
        - otherwise -> full ISO format
        """
        if convert_to is not None:
            return series.dt.strftime(convert_to)
        if series.dt.month.nunique() == 1:
            return series.apply(lambda dt: dt.year).astype(str)
        if series.dt.day.nunique() == 1:
            return series.apply(lambda dt: dt.strftime("%Y-%m"))
        if series.dt.hour.sum() > 0 or series.dt.minute.sum() > 0:
            return series.apply(lambda dt: dt.isoformat()).astype(str)
        if series.dt.second.sum() == 0:
            return series.apply(lambda dt: dt.date()).astype(str)
        return series.apply(lambda dt: dt.isoformat()).astype(str)

    @staticmethod
    def validate_data_conversion(convert_to: str | None) -> None:
        """Raise ValueError if `convert_to` is not a round-trippable datetime format."""
        if convert_to is not None:
            try:
                # Round-trip a fixed date through the format to prove it can
                # be rendered and parsed back.
                pd.to_datetime(pd.to_datetime("2012-12-21").strftime(convert_to))
            except Exception as e:
                # Fixed: chain the original exception so the root cause is
                # preserved in the traceback.
                raise ValueError(f"Invalid datetime format: {convert_to}. {e}") from e
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
###########################################################
|
|
119
|
-
# Base Data Source Classes
|
|
120
|
-
###########################################################
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
class DataSource(ABC, Generic[GenericParamsT]):
    """Base class for all sampling data sources.

    A data source validates its params against the concrete params model of
    the subclass, seeds an RNG, and knows how to inject a column of generated
    data into a dataframe.
    """

    def __init__(
        self,
        params: GenericParamsT,
        random_state: RadomStateT | None = None,
        **kwargs,
    ):
        # Normalize the seed/state into a usable RandomState.
        self.rng = check_random_state(random_state)
        # Validate against the subclass's concrete params model.
        self.params = self.get_param_type().model_validate(params)
        self._setup(**kwargs)
        self._validate()

    @classmethod
    def get_param_type(cls) -> type[GenericParamsT]:
        """Return the concrete params type bound in the generic subclass.

        Reads the type argument off the last generic base, e.g.
        `Sampler[CategorySamplerParams]` -> `CategorySamplerParams`.
        """
        return cls.__orig_bases__[-1].__args__[0]

    @abstractmethod
    def inject_data_column(
        self,
        dataframe: pd.DataFrame,
        column_name: str,
        index: list[int] | None = None,
    ) -> pd.DataFrame:
        """Write generated values into `column_name` for the given row index."""
        ...

    @staticmethod
    @abstractmethod
    def preproc(series: pd.Series) -> pd.Series:
        """Convert the series before constraints are applied."""
        ...

    @staticmethod
    @abstractmethod
    def postproc(series: pd.Series, convert_to: str | None) -> pd.Series:
        """Convert the series at the end of dataset generation."""
        ...

    @staticmethod
    @abstractmethod
    def validate_data_conversion(convert_to: str | None) -> None:
        """Validate the requested output conversion, raising on bad values."""
        ...

    def get_required_column_names(self) -> tuple[str, ...]:
        """Columns this source depends on; none by default."""
        return tuple()

    def _setup(self, **kwargs) -> None:
        """Hook for subclasses to consume extra resources; no-op by default."""
        pass

    def _validate(self) -> None:
        """Hook for subclasses to validate state after setup; no-op by default."""
        pass
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
class Sampler(DataSource[GenericParamsT], ABC):
    """Data source that draws values independently of existing columns."""

    def _recast_types_if_needed(
        self,
        index: list[int] | slice,
        column_name: str,
        sample: NumpyArray1dT,
        dataframe: pd.DataFrame,
    ) -> pd.DataFrame:
        """Align the column's dtype with the sample's dtype when they diverge.

        Type may be different if the column has mixed types / NaNs.
        """
        if column_name in dataframe.columns:
            dtype = sample.dtype
            if dtype != dataframe.loc[index, column_name].dtype:
                dataframe = dataframe.astype({column_name: dtype}, errors="ignore")
        return dataframe

    def inject_data_column(
        self,
        dataframe: pd.DataFrame,
        column_name: str,
        index: list[int] | None = None,
    ) -> pd.DataFrame:
        """Sample values and write them into `column_name` at `index` rows.

        When `index` is None the whole column is (re)written.
        """
        # Bug fix: the original normalized a missing index to `slice(None)`
        # and then called `len(index)` / `self.sample(len(index))`, which
        # raises TypeError because slices have no len(). Compute the row
        # count before normalizing.
        num_rows = len(dataframe) if index is None else len(index)
        index = slice(None) if index is None else index

        if num_rows == 0:
            return dataframe

        sample = self.sample(num_rows)

        # Try recasting before assigning the sample to the dataframe, since setting an item
        # of incompatible dtype is deprecated and will raise an error in future versions.
        dataframe = self._recast_types_if_needed(index, column_name, sample, dataframe)
        dataframe.loc[index, column_name] = sample

        # Recast again in case the assignment led to inconsistencies (e.g., funny business from NaNs).
        dataframe = self._recast_types_if_needed(index, column_name, sample, dataframe)

        return dataframe

    @abstractmethod
    def sample(self, num_samples: int) -> NumpyArray1dT:
        """Return `num_samples` generated values as a 1-D array."""
        ...
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
class ScipyStatsSampler(Sampler[GenericParamsT], ABC):
    """Sampler backed by a scipy.stats distribution object."""

    @property
    @abstractmethod
    def distribution(self) -> scipy.stats.rv_continuous | scipy.stats.rv_discrete:
        """The scipy distribution to draw variates from."""
        ...

    def sample(self, num_samples: int) -> NumpyArray1dT:
        """Draw `num_samples` variates, seeded by this sampler's RNG."""
        dist = self.distribution
        return dist.rvs(size=num_samples, random_state=self.rng)
|
|
@@ -1,12 +0,0 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
-
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
-
|
|
4
|
-
from __future__ import annotations
|
|
5
|
-
|
|
6
|
-
from data_designer.engine.sampling_gen.errors import SamplingGenError
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
class InvalidSamplerParamsError(SamplingGenError):
    """Raised when sampler parameters fail validation."""


class PersonSamplerConstraintsError(SamplingGenError):
    """Raised when person sampling cannot satisfy the requested constraints."""
|
|
@@ -1,347 +0,0 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
-
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
-
|
|
4
|
-
from __future__ import annotations
|
|
5
|
-
|
|
6
|
-
import uuid
|
|
7
|
-
from typing import TYPE_CHECKING
|
|
8
|
-
|
|
9
|
-
from data_designer.config.sampler_params import (
|
|
10
|
-
BernoulliMixtureSamplerParams,
|
|
11
|
-
BernoulliSamplerParams,
|
|
12
|
-
BinomialSamplerParams,
|
|
13
|
-
CategorySamplerParams,
|
|
14
|
-
DatetimeSamplerParams,
|
|
15
|
-
GaussianSamplerParams,
|
|
16
|
-
PersonFromFakerSamplerParams,
|
|
17
|
-
PersonSamplerParams,
|
|
18
|
-
PoissonSamplerParams,
|
|
19
|
-
SamplerParamsT,
|
|
20
|
-
SamplerType,
|
|
21
|
-
ScipySamplerParams,
|
|
22
|
-
SubcategorySamplerParams,
|
|
23
|
-
TimeDeltaSamplerParams,
|
|
24
|
-
UniformSamplerParams,
|
|
25
|
-
UUIDSamplerParams,
|
|
26
|
-
)
|
|
27
|
-
from data_designer.engine.sampling_gen.data_sources.base import (
|
|
28
|
-
DataSource,
|
|
29
|
-
DatetimeFormatMixin,
|
|
30
|
-
NumpyArray1dT,
|
|
31
|
-
PassthroughMixin,
|
|
32
|
-
Sampler,
|
|
33
|
-
ScipyStatsSampler,
|
|
34
|
-
TypeConversionMixin,
|
|
35
|
-
)
|
|
36
|
-
from data_designer.engine.sampling_gen.data_sources.errors import (
|
|
37
|
-
InvalidSamplerParamsError,
|
|
38
|
-
PersonSamplerConstraintsError,
|
|
39
|
-
)
|
|
40
|
-
from data_designer.engine.sampling_gen.entities.dataset_based_person_fields import PERSONA_FIELDS, PII_FIELDS
|
|
41
|
-
from data_designer.engine.sampling_gen.people_gen import PeopleGen
|
|
42
|
-
from data_designer.lazy_heavy_imports import np, pd, scipy
|
|
43
|
-
|
|
44
|
-
if TYPE_CHECKING:
|
|
45
|
-
import numpy as np
|
|
46
|
-
import pandas as pd
|
|
47
|
-
import scipy
|
|
48
|
-
|
|
49
|
-
# Nanoseconds per second; used to convert between ns and s timestamps.
ONE_BILLION = 10**9
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
class SamplerRegistry:
    """Registry mapping sampler aliases and params types to sampler classes."""

    # alias -> sampler class
    _registry: dict[str, type] = {}
    # sampler class -> alias
    _reverse_registry: dict[type, str] = {}
    # params type -> sampler class
    _params_registry: dict[type, type] = {}

    @classmethod
    def register(cls, alias: str):
        """Class decorator that registers a sampler class under `alias`."""

        def decorator(wrapped_class: type[DataSource[SamplerParamsT]]) -> type:
            cls._registry[alias] = wrapped_class
            cls._reverse_registry[wrapped_class] = alias
            cls._params_registry[wrapped_class.get_param_type()] = wrapped_class
            return wrapped_class

        return decorator

    @classmethod
    def get_sampler(cls, alias: str) -> type[DataSource[SamplerParamsT]]:
        """Return the sampler class registered under `alias`.

        NOTE(review): this lookup lowercases the alias while `is_registered`
        and `validate_sampler_type` do not — confirm registered aliases are
        always lowercase.
        """
        return cls._registry[alias.lower()]

    @classmethod
    def get_sampler_for_params(cls, params_type: SamplerParamsT) -> type[DataSource[SamplerParamsT]]:
        """Return the sampler class that consumes the given params instance."""
        return cls._params_registry[type(params_type)]

    @classmethod
    def get_sampler_alias_for_params(cls, params_type: SamplerParamsT) -> str:
        """Return the alias of the sampler that consumes the given params instance."""
        sampler_class = cls._params_registry[type(params_type)]
        return cls._reverse_registry[sampler_class]

    @classmethod
    def is_registered(cls, alias: str) -> bool:
        """Return True if `alias` has a registered sampler."""
        return alias in cls._registry

    @classmethod
    def validate_sampler_type(
        cls, sampler_type: str | type[DataSource[SamplerParamsT]]
    ) -> type[DataSource[SamplerParamsT]]:
        """Resolve an alias or class to a registered DataSource subclass.

        Raises ValueError for unknown aliases and for classes that are not
        DataSource subclasses.
        """
        if isinstance(sampler_type, str):
            if sampler_type not in cls._registry:
                raise ValueError(
                    f"Sampler type `{sampler_type}` not found in the registry. "
                    f"Available samplers: {list(cls._registry.keys())}"
                )
            sampler_type = cls.get_sampler(sampler_type)
        if not issubclass(sampler_type, DataSource):
            raise ValueError(f"Sampler type `{sampler_type}` is not a subclass of `DataSource`")
        return sampler_type
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
#########################################
|
|
100
|
-
# Data Source Subclasses
|
|
101
|
-
#########################################
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
@SamplerRegistry.register(SamplerType.SUBCATEGORY)
class SubcategorySampler(TypeConversionMixin, DataSource[SubcategorySamplerParams]):
    """Samples a subcategory value conditioned on a parent category column."""

    def get_required_column_names(self) -> tuple[str, ...]:
        """The parent category column must be generated before this one."""
        return (self.params.category,)

    def inject_data_column(
        self,
        dataframe: pd.DataFrame,
        column_name: str,
        index: list[int] | None = None,
    ) -> pd.DataFrame:
        """Fill `column_name` by sampling a subcategory for each row's category.

        When `index` is None the whole column is (re)written.
        """
        # Bug fix: the original normalized a missing index to `slice(None)`
        # and then called `len(index)`, which raises TypeError because slices
        # have no len(). Compute the row count before normalizing.
        num_rows = len(dataframe) if index is None else len(index)
        index = slice(None) if index is None else index

        if num_rows == 0:
            return dataframe

        # For each row, draw uniformly among the subcategory values configured
        # for that row's parent category value.
        dataframe.loc[index, column_name] = dataframe.loc[index, self.params.category].apply(
            lambda cat_value: self.rng.choice(self.params.values[cat_value])
        )

        return dataframe
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
#########################################
|
|
128
|
-
# Sampler Subclasses
|
|
129
|
-
#########################################
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
@SamplerRegistry.register(SamplerType.CATEGORY)
class CategorySampler(TypeConversionMixin, Sampler[CategorySamplerParams]):
    """Draws values from a fixed set of categories, optionally weighted."""

    def sample(self, num_samples: int) -> NumpyArray1dT:
        """Draw `num_samples` category values according to the configured weights."""
        return self.rng.choice(self.params.values, size=num_samples, p=self.params.weights)
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
@SamplerRegistry.register(SamplerType.DATETIME)
class DatetimeSampler(DatetimeFormatMixin, Sampler[DatetimeSamplerParams]):
    """Draws datetimes uniformly between the configured start and end."""

    def sample(self, num_samples: int) -> NumpyArray1dT:
        """Return `num_samples` random datetimes in [start, end)."""
        # Convert nanoseconds to seconds.
        start_sec = pd.to_datetime(self.params.start).value // ONE_BILLION
        end_sec = pd.to_datetime(self.params.end).value // ONE_BILLION

        # Draw whole seconds, then reinterpret the scaled int64 values as
        # nanosecond timestamps.
        random_seconds = self.rng.randint(start_sec, end_sec, num_samples, dtype=np.int64)
        random_ns = (ONE_BILLION * random_seconds).view("datetime64[ns]")

        return np.array(random_ns, dtype=f"datetime64[{self.params.unit}]")
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
@SamplerRegistry.register(SamplerType.PERSON)
class PersonSampler(PassthroughMixin, Sampler[PersonSamplerParams]):
    """Samples person records from a managed (census-based) people generator."""

    def _setup(self, **kwargs) -> None:
        """Collect fixed generator kwargs from params and attach the generator."""
        self._generator = None
        self._fixed_kwargs = {}
        for field in self.params.generator_kwargs:
            attr = getattr(self.params, field)
            if attr is None:
                continue
            if field == "select_field_values":
                for key, value in attr.items():
                    if key == "state" and self.params.locale == "en_US":
                        key = "region"  # This is the field name in the census-based person dataset.
                    if key not in PII_FIELDS + PERSONA_FIELDS:
                        raise ValueError(f"Invalid field name: {key}")
                    self._fixed_kwargs[key] = value
            else:
                self._fixed_kwargs[field] = attr
        if people_gen_resource := kwargs.get("people_gen_resource"):
            if self.params.people_gen_key not in people_gen_resource:
                raise ValueError(f"Person generator with key {self.params.people_gen_key} not found.")
            self.set_generator(people_gen_resource[self.params.people_gen_key])

    def set_generator(self, generator: PeopleGen) -> None:
        """Attach the people generator used by `sample`."""
        self._generator = generator

    def sample(self, num_samples: int) -> NumpyArray1dT:
        """Generate `num_samples` people, raising if constraints are too strict."""
        if self._generator is None:
            raise ValueError("Generator not set. Please setup generator before sampling.")

        samples = np.array(self._generator.generate(num_samples, **self._fixed_kwargs))
        if len(samples) < num_samples:
            raise PersonSamplerConstraintsError(
                f"🛑 Only {len(samples)} samples could be generated with the given settings: {self._fixed_kwargs!r}. "
                "This is likely because the filter values are too strict. Person sampling does not support "
                "rare combinations of field values. Please loosen the constraints and try again."
            )
        return samples
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
@SamplerRegistry.register(SamplerType.PERSON_FROM_FAKER)
class PersonFromFakerSampler(PassthroughMixin, Sampler[PersonFromFakerSamplerParams]):
    """Samples person records from a faker-backed people generator."""

    def _setup(self, **kwargs) -> None:
        """Collect fixed generator kwargs from params and attach the generator."""
        self._generator = None
        # Keep only the generator kwargs that were actually set on the params.
        self._fixed_kwargs = {
            field: getattr(self.params, field)
            for field in self.params.generator_kwargs
            if getattr(self.params, field) is not None
        }
        if people_gen_resource := kwargs.get("people_gen_resource"):
            if self.params.people_gen_key not in people_gen_resource:
                raise ValueError(f"Person generator with key {self.params.people_gen_key} not found.")
            self.set_generator(people_gen_resource[self.params.people_gen_key])

    def set_generator(self, generator: PeopleGen) -> None:
        """Attach the people generator used by `sample`."""
        self._generator = generator

    def sample(self, num_samples: int) -> NumpyArray1dT:
        """Generate `num_samples` people, raising if constraints are too strict."""
        if self._generator is None:
            raise ValueError("Generator not set. Please setup generator before sampling.")

        samples = np.array(self._generator.generate(num_samples, **self._fixed_kwargs))
        if len(samples) < num_samples:
            raise ValueError(f"Only {len(samples)} samples could be generated given constraints {self._fixed_kwargs}.")
        return samples
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
@SamplerRegistry.register(SamplerType.TIMEDELTA)
class TimeDeltaSampler(DatetimeFormatMixin, Sampler[TimeDeltaSamplerParams]):
    """Sampler that offsets a reference datetime column by random timedeltas."""

    def get_required_column_names(self) -> tuple[str, ...]:
        """The reference column must already exist in the target dataframe."""
        return (self.params.reference_column_name,)

    def inject_data_column(
        self,
        dataframe: pd.DataFrame,
        column_name: str,
        index: list[int] | None = None,
    ) -> pd.DataFrame:
        """Set ``column_name`` = reference column + sampled timedelta.

        Args:
            dataframe: Target dataframe; mutated in place and returned.
            column_name: Name of the column to create or overwrite.
            index: Row positions to fill; ``None`` means all rows.

        Raises:
            ValueError: if the reference column is missing from the dataframe.
        """
        # BUG FIX: the previous implementation normalized ``index`` to
        # ``slice(None)`` and then called ``len(index)`` on it, which raises
        # TypeError whenever ``index`` is None. Take the row count from the
        # dataframe in that case instead.
        num_rows = len(dataframe) if index is None else len(index)
        loc_index = slice(None) if index is None else index

        if self.params.reference_column_name not in list(dataframe):
            raise ValueError(f"Column `{self.params.reference_column_name}` not found in dataset")

        dataframe.loc[loc_index, column_name] = pd.to_datetime(
            dataframe.loc[loc_index, self.params.reference_column_name]
        ) + pd.to_timedelta(self.sample(num_rows), unit=self.params.unit)

        return dataframe

    def sample(self, num_samples: int) -> NumpyArray1dT:
        """Draw ``num_samples`` integer deltas between ``dt_min`` and ``dt_max``
        as ``timedelta64`` values in the configured unit.

        NOTE(review): assumes ``self.rng.randint`` follows numpy's half-open
        ``[low, high)`` convention — confirm against the rng used.
        """
        deltas = self.rng.randint(self.params.dt_min, self.params.dt_max, num_samples)
        return np.array(deltas, dtype=f"timedelta64[{self.params.unit}]")
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
@SamplerRegistry.register(SamplerType.UUID)
class UUIDSampler(PassthroughMixin, Sampler[UUIDSamplerParams]):
    """Sampler producing distinct, optionally prefixed/uppercased truncated UUIDs."""

    def sample(self, num_samples: int) -> NumpyArray1dT:
        """Return ``num_samples`` distinct identifiers.

        Each id is the first ``last_index`` hex characters of a fresh
        ``uuid4`` (uppercased when ``params.uppercase`` is set), prepended
        with ``params.prefix`` when given.

        NOTE(review): as in the previous implementation, if ``last_index``
        is small enough that fewer than ``num_samples`` distinct ids exist,
        this loop never terminates — confirm params validation bounds it.
        """
        prefix = self.params.prefix or ""

        # PERF: track seen ids in a set for O(1) membership checks; the
        # previous ``uid not in uid_list`` scanned the list each time,
        # making generation O(n^2) overall.
        seen: set[str] = set()
        uid_list: list[str] = []
        while len(uid_list) < num_samples:
            hex_part = uuid.uuid4().hex[: self.params.last_index]
            if self.params.uppercase:
                hex_part = hex_part.upper()
            uid = f"{prefix}{hex_part}"
            if uid not in seen:
                seen.add(uid)
                uid_list.append(uid)

        return np.array(uid_list)
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
#########################################
|
|
263
|
-
# Scipy Samplers
|
|
264
|
-
#########################################
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
@SamplerRegistry.register(SamplerType.SCIPY)
class ScipySampler(TypeConversionMixin, ScipyStatsSampler[ScipySamplerParams]):
    """Escape hatch sampler to give users access to any scipy.stats distribution."""

    @property
    def distribution(self) -> scipy.stats.rv_continuous | scipy.stats.rv_discrete:
        """Frozen scipy.stats distribution built from the configured name and params."""
        dist_cls = getattr(scipy.stats, self.params.dist_name)
        return dist_cls(**self.params.dist_params)

    def _validate(self) -> None:
        # Fail fast if the configured distribution name/params are unusable.
        _validate_scipy_distribution(self.params.dist_name, self.params.dist_params)
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
@SamplerRegistry.register(SamplerType.BERNOULLI)
class BernoulliSampler(TypeConversionMixin, ScipyStatsSampler[BernoulliSamplerParams]):
    """Bernoulli sampler backed by scipy.stats."""

    @property
    def distribution(self) -> scipy.stats.rv_discrete:
        """Frozen Bernoulli distribution with success probability ``params.p``."""
        success_prob = self.params.p
        return scipy.stats.bernoulli(p=success_prob)
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
@SamplerRegistry.register(SamplerType.BERNOULLI_MIXTURE)
class BernoulliMixtureSampler(TypeConversionMixin, Sampler[BernoulliMixtureSamplerParams]):
    """Sampler mixing a Bernoulli(p) gate with an arbitrary scipy.stats distribution."""

    def sample(self, num_samples: int) -> NumpyArray1dT:
        """Draw ``num_samples`` values: Bernoulli mask times draws from the
        configured distribution (zeros wherever the mask is 0).

        NOTE(review): draws here do not pass a random_state/rng — presumably
        global scipy randomness is intended; confirm reproducibility needs.
        """
        mask = scipy.stats.bernoulli(p=self.params.p).rvs(size=num_samples)
        dist = getattr(scipy.stats, self.params.dist_name)(**self.params.dist_params)
        return mask * dist.rvs(size=num_samples)

    def _validate(self) -> None:
        # Fail fast if the configured distribution name/params are unusable.
        _validate_scipy_distribution(self.params.dist_name, self.params.dist_params)
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
@SamplerRegistry.register(SamplerType.BINOMIAL)
class BinomialSampler(TypeConversionMixin, ScipyStatsSampler[BinomialSamplerParams]):
    """Binomial sampler backed by scipy.stats."""

    @property
    def distribution(self) -> scipy.stats.rv_discrete:
        """Frozen Binomial distribution with ``params.n`` trials and success
        probability ``params.p``."""
        trials, success_prob = self.params.n, self.params.p
        return scipy.stats.binom(n=trials, p=success_prob)
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
@SamplerRegistry.register(SamplerType.GAUSSIAN)
class GaussianSampler(TypeConversionMixin, ScipyStatsSampler[GaussianSamplerParams]):
    """Gaussian (normal) sampler backed by scipy.stats."""

    @property
    def distribution(self) -> scipy.stats.rv_continuous:
        """Frozen normal distribution with the configured mean and stddev."""
        center, spread = self.params.mean, self.params.stddev
        return scipy.stats.norm(loc=center, scale=spread)
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
@SamplerRegistry.register(SamplerType.POISSON)
class PoissonSampler(TypeConversionMixin, ScipyStatsSampler[PoissonSamplerParams]):
    """Poisson sampler backed by scipy.stats."""

    @property
    def distribution(self) -> scipy.stats.rv_discrete:
        """Frozen Poisson distribution with rate ``params.mean``."""
        rate = self.params.mean
        return scipy.stats.poisson(mu=rate)
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
@SamplerRegistry.register(SamplerType.UNIFORM)
class UniformSampler(TypeConversionMixin, ScipyStatsSampler[UniformSamplerParams]):
    """Uniform sampler backed by scipy.stats."""

    @property
    def distribution(self) -> scipy.stats.rv_continuous:
        """Frozen uniform distribution over [low, high]; scipy parameterizes
        this as loc=low, scale=high-low."""
        low, high = self.params.low, self.params.high
        return scipy.stats.uniform(loc=low, scale=high - low)
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
###################################################
|
|
326
|
-
# Helper functions for loading sources in isolation
|
|
327
|
-
###################################################
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
def load_sampler(sampler_type: SamplerType, **params) -> DataSource:
    """Load a data source from a source type and parameters.

    Resolves the sampler class registered for ``sampler_type`` and
    instantiates it with ``params`` as its params payload.
    """
    sampler_cls = SamplerRegistry.validate_sampler_type(sampler_type)
    return sampler_cls(params=params)
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
def _validate_scipy_distribution(dist_name: str, dist_params: dict) -> None:
|
|
336
|
-
if not hasattr(scipy.stats, dist_name):
|
|
337
|
-
raise InvalidSamplerParamsError(f"Distribution {dist_name} not found in scipy.stats")
|
|
338
|
-
if not hasattr(getattr(scipy.stats, dist_name), "rvs"):
|
|
339
|
-
raise InvalidSamplerParamsError(
|
|
340
|
-
f"Distribution {dist_name} does not have a `rvs` method, which is required for sampling."
|
|
341
|
-
)
|
|
342
|
-
try:
|
|
343
|
-
getattr(scipy.stats, dist_name)(**dist_params)
|
|
344
|
-
except Exception:
|
|
345
|
-
raise InvalidSamplerParamsError(
|
|
346
|
-
f"Distribution parameters {dist_params} are not a valid for distribution '{dist_name}'"
|
|
347
|
-
)
|
|
Binary file
|