data-designer 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/__init__.py +15 -0
- data_designer/_version.py +34 -0
- data_designer/cli/README.md +236 -0
- data_designer/cli/__init__.py +6 -0
- data_designer/cli/commands/__init__.py +2 -0
- data_designer/cli/commands/list.py +130 -0
- data_designer/cli/commands/models.py +10 -0
- data_designer/cli/commands/providers.py +11 -0
- data_designer/cli/commands/reset.py +100 -0
- data_designer/cli/controllers/__init__.py +7 -0
- data_designer/cli/controllers/model_controller.py +246 -0
- data_designer/cli/controllers/provider_controller.py +317 -0
- data_designer/cli/forms/__init__.py +20 -0
- data_designer/cli/forms/builder.py +51 -0
- data_designer/cli/forms/field.py +180 -0
- data_designer/cli/forms/form.py +59 -0
- data_designer/cli/forms/model_builder.py +125 -0
- data_designer/cli/forms/provider_builder.py +76 -0
- data_designer/cli/main.py +44 -0
- data_designer/cli/repositories/__init__.py +8 -0
- data_designer/cli/repositories/base.py +39 -0
- data_designer/cli/repositories/model_repository.py +42 -0
- data_designer/cli/repositories/provider_repository.py +43 -0
- data_designer/cli/services/__init__.py +7 -0
- data_designer/cli/services/model_service.py +116 -0
- data_designer/cli/services/provider_service.py +111 -0
- data_designer/cli/ui.py +448 -0
- data_designer/cli/utils.py +47 -0
- data_designer/config/__init__.py +2 -0
- data_designer/config/analysis/column_profilers.py +89 -0
- data_designer/config/analysis/column_statistics.py +274 -0
- data_designer/config/analysis/dataset_profiler.py +60 -0
- data_designer/config/analysis/utils/errors.py +8 -0
- data_designer/config/analysis/utils/reporting.py +188 -0
- data_designer/config/base.py +68 -0
- data_designer/config/column_configs.py +354 -0
- data_designer/config/column_types.py +168 -0
- data_designer/config/config_builder.py +660 -0
- data_designer/config/data_designer_config.py +40 -0
- data_designer/config/dataset_builders.py +11 -0
- data_designer/config/datastore.py +151 -0
- data_designer/config/default_model_settings.py +123 -0
- data_designer/config/errors.py +19 -0
- data_designer/config/interface.py +54 -0
- data_designer/config/models.py +231 -0
- data_designer/config/preview_results.py +32 -0
- data_designer/config/processors.py +41 -0
- data_designer/config/sampler_constraints.py +51 -0
- data_designer/config/sampler_params.py +604 -0
- data_designer/config/seed.py +145 -0
- data_designer/config/utils/code_lang.py +83 -0
- data_designer/config/utils/constants.py +313 -0
- data_designer/config/utils/errors.py +19 -0
- data_designer/config/utils/info.py +88 -0
- data_designer/config/utils/io_helpers.py +273 -0
- data_designer/config/utils/misc.py +81 -0
- data_designer/config/utils/numerical_helpers.py +28 -0
- data_designer/config/utils/type_helpers.py +100 -0
- data_designer/config/utils/validation.py +336 -0
- data_designer/config/utils/visualization.py +427 -0
- data_designer/config/validator_params.py +96 -0
- data_designer/engine/__init__.py +2 -0
- data_designer/engine/analysis/column_profilers/base.py +55 -0
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +160 -0
- data_designer/engine/analysis/column_profilers/registry.py +20 -0
- data_designer/engine/analysis/column_statistics.py +142 -0
- data_designer/engine/analysis/dataset_profiler.py +125 -0
- data_designer/engine/analysis/errors.py +7 -0
- data_designer/engine/analysis/utils/column_statistics_calculations.py +209 -0
- data_designer/engine/analysis/utils/judge_score_processing.py +128 -0
- data_designer/engine/column_generators/__init__.py +2 -0
- data_designer/engine/column_generators/generators/__init__.py +2 -0
- data_designer/engine/column_generators/generators/base.py +61 -0
- data_designer/engine/column_generators/generators/expression.py +63 -0
- data_designer/engine/column_generators/generators/llm_generators.py +172 -0
- data_designer/engine/column_generators/generators/samplers.py +75 -0
- data_designer/engine/column_generators/generators/seed_dataset.py +149 -0
- data_designer/engine/column_generators/generators/validation.py +147 -0
- data_designer/engine/column_generators/registry.py +56 -0
- data_designer/engine/column_generators/utils/errors.py +13 -0
- data_designer/engine/column_generators/utils/judge_score_factory.py +57 -0
- data_designer/engine/column_generators/utils/prompt_renderer.py +98 -0
- data_designer/engine/configurable_task.py +82 -0
- data_designer/engine/dataset_builders/artifact_storage.py +181 -0
- data_designer/engine/dataset_builders/column_wise_builder.py +287 -0
- data_designer/engine/dataset_builders/errors.py +13 -0
- data_designer/engine/dataset_builders/multi_column_configs.py +44 -0
- data_designer/engine/dataset_builders/utils/__init__.py +2 -0
- data_designer/engine/dataset_builders/utils/concurrency.py +184 -0
- data_designer/engine/dataset_builders/utils/config_compiler.py +60 -0
- data_designer/engine/dataset_builders/utils/dag.py +56 -0
- data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +190 -0
- data_designer/engine/dataset_builders/utils/errors.py +13 -0
- data_designer/engine/errors.py +49 -0
- data_designer/engine/model_provider.py +75 -0
- data_designer/engine/models/__init__.py +2 -0
- data_designer/engine/models/errors.py +308 -0
- data_designer/engine/models/facade.py +225 -0
- data_designer/engine/models/litellm_overrides.py +162 -0
- data_designer/engine/models/parsers/__init__.py +2 -0
- data_designer/engine/models/parsers/errors.py +34 -0
- data_designer/engine/models/parsers/parser.py +236 -0
- data_designer/engine/models/parsers/postprocessors.py +93 -0
- data_designer/engine/models/parsers/tag_parsers.py +60 -0
- data_designer/engine/models/parsers/types.py +82 -0
- data_designer/engine/models/recipes/base.py +79 -0
- data_designer/engine/models/recipes/response_recipes.py +291 -0
- data_designer/engine/models/registry.py +118 -0
- data_designer/engine/models/usage.py +75 -0
- data_designer/engine/models/utils.py +38 -0
- data_designer/engine/processing/ginja/__init__.py +2 -0
- data_designer/engine/processing/ginja/ast.py +64 -0
- data_designer/engine/processing/ginja/environment.py +461 -0
- data_designer/engine/processing/ginja/exceptions.py +54 -0
- data_designer/engine/processing/ginja/record.py +30 -0
- data_designer/engine/processing/gsonschema/__init__.py +2 -0
- data_designer/engine/processing/gsonschema/exceptions.py +8 -0
- data_designer/engine/processing/gsonschema/schema_transformers.py +81 -0
- data_designer/engine/processing/gsonschema/types.py +8 -0
- data_designer/engine/processing/gsonschema/validators.py +143 -0
- data_designer/engine/processing/processors/base.py +15 -0
- data_designer/engine/processing/processors/drop_columns.py +46 -0
- data_designer/engine/processing/processors/registry.py +20 -0
- data_designer/engine/processing/utils.py +120 -0
- data_designer/engine/registry/base.py +97 -0
- data_designer/engine/registry/data_designer_registry.py +37 -0
- data_designer/engine/registry/errors.py +10 -0
- data_designer/engine/resources/managed_dataset_generator.py +35 -0
- data_designer/engine/resources/managed_dataset_repository.py +194 -0
- data_designer/engine/resources/managed_storage.py +63 -0
- data_designer/engine/resources/resource_provider.py +46 -0
- data_designer/engine/resources/seed_dataset_data_store.py +66 -0
- data_designer/engine/sampling_gen/column.py +89 -0
- data_designer/engine/sampling_gen/constraints.py +95 -0
- data_designer/engine/sampling_gen/data_sources/base.py +214 -0
- data_designer/engine/sampling_gen/data_sources/errors.py +10 -0
- data_designer/engine/sampling_gen/data_sources/sources.py +342 -0
- data_designer/engine/sampling_gen/entities/__init__.py +2 -0
- data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
- data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +64 -0
- data_designer/engine/sampling_gen/entities/email_address_utils.py +169 -0
- data_designer/engine/sampling_gen/entities/errors.py +8 -0
- data_designer/engine/sampling_gen/entities/national_id_utils.py +100 -0
- data_designer/engine/sampling_gen/entities/person.py +142 -0
- data_designer/engine/sampling_gen/entities/phone_number.py +122 -0
- data_designer/engine/sampling_gen/errors.py +24 -0
- data_designer/engine/sampling_gen/generator.py +121 -0
- data_designer/engine/sampling_gen/jinja_utils.py +60 -0
- data_designer/engine/sampling_gen/people_gen.py +203 -0
- data_designer/engine/sampling_gen/person_constants.py +54 -0
- data_designer/engine/sampling_gen/schema.py +143 -0
- data_designer/engine/sampling_gen/schema_builder.py +59 -0
- data_designer/engine/sampling_gen/utils.py +40 -0
- data_designer/engine/secret_resolver.py +80 -0
- data_designer/engine/validators/__init__.py +17 -0
- data_designer/engine/validators/base.py +36 -0
- data_designer/engine/validators/local_callable.py +34 -0
- data_designer/engine/validators/python.py +245 -0
- data_designer/engine/validators/remote.py +83 -0
- data_designer/engine/validators/sql.py +60 -0
- data_designer/errors.py +5 -0
- data_designer/essentials/__init__.py +137 -0
- data_designer/interface/__init__.py +2 -0
- data_designer/interface/data_designer.py +351 -0
- data_designer/interface/errors.py +16 -0
- data_designer/interface/results.py +55 -0
- data_designer/logging.py +161 -0
- data_designer/plugin_manager.py +83 -0
- data_designer/plugins/__init__.py +6 -0
- data_designer/plugins/errors.py +10 -0
- data_designer/plugins/plugin.py +69 -0
- data_designer/plugins/registry.py +86 -0
- data_designer-0.1.0.dist-info/METADATA +173 -0
- data_designer-0.1.0.dist-info/RECORD +177 -0
- data_designer-0.1.0.dist-info/WHEEL +4 -0
- data_designer-0.1.0.dist-info/entry_points.txt +2 -0
- data_designer-0.1.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from datetime import date
|
|
5
|
+
import random
|
|
6
|
+
import re
|
|
7
|
+
|
|
8
|
+
import anyascii
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def get_email_address(
    first_name: str,
    middle_name: str | None,
    last_name: str,
    age: int,
    birth_date: date,
) -> str:
    """Compose a synthetic email address from a person's attributes.

    The domain is sampled based on age, the basename on the person's
    names, and a (possibly empty) numeric suffix on the birth date.
    """
    # Sample components in a fixed order so seeded RNG stays reproducible.
    mail_domain = get_email_domain_by_age(age)
    basename = get_email_basename_by_name(first_name, middle_name, last_name)
    numeric_suffix = get_email_suffix_by_birth_date(birth_date)
    return f"{basename}{numeric_suffix}@{mail_domain}"
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def get_email_domain_by_age(age: int) -> str:
    """
    Get a free email domain heuristically dependent on
    overall number of subscribers and user age.
    """
    # Common free email domains
    # Source: https://www.sellcell.com/blog/most-popular-email-provider-by-number-of-users/
    # Split heuristically into age demographics
    # Also adjusted to maintain the approximate 38/27/35 split between these groups
    domains_under_30 = {
        "gmail.com": 710,  # gmail.com total: 1500
        "icloud.com": 300,  # icloud.com total: 850
        "outlook.com": 50,  # outlook.com total: 200
        "hotmail.com": 40,  # hotmail.com total: 200
        "yahoo.com": 35,  # yahoo.com total: 230
        "protonmail.com": 20,  # protonmail.com total: 50
        "zoho.com": 3,  # zoho.com total: 15
        "gmx.com": 3,  # gmx.com total: 11
        "aol.com": 0.1,  # aol.com total: 1.5
    }
    domains_30_to_50 = {
        "gmail.com": 360,
        "icloud.com": 270,
        "outlook.com": 60,
        "hotmail.com": 50,
        "yahoo.com": 60,
        "protonmail.com": 18,
        "zoho.com": 7,
        "gmx.com": 4,
        "aol.com": 0.3,
    }
    domains_over_50 = {
        "gmail.com": 430,
        "icloud.com": 280,
        "outlook.com": 90,
        "hotmail.com": 110,
        "yahoo.com": 135,
        "protonmail.com": 12,
        "zoho.com": 5,
        "gmx.com": 4,
        "aol.com": 1.1,
    }

    # Pick the demographic bucket first, then make a single weighted draw.
    if age < 30:
        bucket = domains_under_30
    elif age < 50:
        bucket = domains_30_to_50
    else:
        bucket = domains_over_50
    return random.choices(list(bucket), weights=list(bucket.values()), k=1)[0]
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def get_email_basename_by_name(first_name: str, middle_name: str | None, last_name: str) -> str:
    """
    Get an email address basename heuristically dependent on first and last name.

    Patterns include:
    - firstname.lastname
    - firstnamelastname
    - firstinitiallastname
    - firstname_lastname
    - lastnamefirstinitial
    - firstnamelastinitial
    - firstnamemiddlename
    - firstnamemiddleinitiallastname
    - firstnamemiddlenamelastname

    Raises:
        ValueError: If first or last name is empty after transliteration and
            removal of non-alphanumeric characters.
    """
    # Normalize names (lowercase, transliterate to ASCII, remove spaces and special chars)
    first = re.sub(r"[^a-z0-9]", "", anyascii.anyascii(first_name).lower())
    last = re.sub(r"[^a-z0-9]", "", anyascii.anyascii(last_name).lower())
    # Raise instead of assert: asserts are stripped when running `python -O`.
    if not first or not last:
        raise ValueError("Both first and last name must be non-empty, after removing non-alphanumeric.")
    first_initial = first[0]
    last_initial = last[0]

    # Generate username patterns
    username_patterns = [
        f"{first}.{last}",
        f"{first}{last}",
        f"{first_initial}{last}",
        f"{first}_{last}",
        f"{last}{first_initial}",
        f"{first}{last_initial}",
    ]
    # Higher probability for more common patterns
    pattern_weights = [0.3, 0.2, 0.15, 0.1, 0.15, 0.1]
    if middle_name:
        middle = re.sub(r"[^a-z0-9]", "", anyascii.anyascii(middle_name).lower())
        # Bug fix: a middle name consisting only of special characters
        # normalizes to "", and indexing middle[0] raised IndexError.
        if middle:
            middle_initial = middle[0]
            username_patterns.extend(
                [
                    f"{first}{middle}",
                    f"{first}{middle_initial}{last}",
                    f"{first}{middle}{last}",
                ]
            )
            pattern_weights = [0.25, 0.17, 0.12, 0.08, 0.12, 0.08, 0.06, 0.06, 0.06]

    return random.choices(username_patterns, weights=pattern_weights, k=1)[0]
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def get_email_suffix_by_birth_date(birth_date: date) -> str:
    """Pick a short (possibly empty) email-address suffix from the birth date.

    Candidate suffixes:
    - empty string
    - a random 1-2 digit number
    - the last two digits of the birth year
    - the full birth year
    - the day of the month of birth
    """
    # NOTE: the random 1-2 digit candidate is drawn eagerly while building the
    # list, so RNG state advances identically whichever suffix gets selected.
    two_digit_year = birth_date.year % 100
    candidates = [
        "",
        str(random.randint(1, 99)),
        f"{two_digit_year:02d}",
        str(birth_date.year),
        str(birth_date.day),
    ]
    candidate_weights = [0.4, 0.3, 0.1, 0.1, 0.1]
    return random.choices(candidates, weights=candidate_weights, k=1)[0]
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from data_designer.errors import DataDesignerError
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class MissingPersonFieldsError(DataDesignerError):
    """Raised when a person record is missing one or more required fields."""
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from datetime import date
|
|
5
|
+
import random
|
|
6
|
+
|
|
7
|
+
# Date after which the SSA stopped assigning area numbers geographically
# ("SSN randomization").
SSN_RANDOMIZATION_DATE = date(2011, 6, 25)

# Area number mapping by state code (pre-2011)
STATE_TO_AREA_SSN = {
    "NH": [1, 3],
    "ME": [4, 7],
    "VT": [8, 9],
    "MA": [10, 34],
    "RI": [35, 39],
    "CT": [40, 49],
    "NY": [50, 134],
    "NJ": [135, 158],
    "PA": [159, 211],
    "MD": [212, 220],
    "DE": [221, 222],
    "VA": [223, 231],
    "WV": [232, 236],
    "NC": [237, 246],
    "SC": [247, 251],
    "GA": [252, 260],
    "FL": [261, 267],
    "OH": [268, 302],
    "IN": [303, 317],
    "IL": [318, 361],
    "MI": [362, 386],
    "WI": [387, 399],
    "KY": [400, 407],
    "TN": [408, 415],
    "AL": [416, 424],
    "MS": [425, 428],
    "AR": [429, 432],
    "LA": [433, 439],
    "OK": [440, 448],
    "TX": [449, 467],
    "MN": [468, 477],
    "IA": [478, 485],
    "MO": [486, 500],
    "ND": [501, 502],
    "SD": [503, 504],
    "NE": [505, 508],
    "KS": [509, 515],
    "MT": [516, 517],
    "ID": [518, 519],
    "WY": [520, 520],
    "CO": [521, 524],
    "NM": [525, 527],
    "AZ": [526, 527],
    "UT": [528, 529],
    "NV": [530, 530],
    "WA": [531, 539],
    "OR": [540, 544],
    "CA": [545, 573],
    "AK": [574, 574],
    "HI": [575, 576],
    "DC": [577, 579],
    "VI": [580, 580],
    "PR": [580, 599],
    "GU": [586, 586],
    "AS": [586, 586],
}


def generate_ssn(state: str, birth_date: date) -> str:
    """
    Generate a synthetic SSN based on state and birth date.

    For people born before June 25, 2011 (SSN randomization), the area number
    (first three digits) is derived from the given state with a 70% chance;
    with the remaining 30% chance it is drawn from a randomly chosen state's
    range (e.g. born in a different state). People born on or after the
    randomization date get a fully random area number.

    Args:
        state (str): Two-letter state code (e.g., "NY", "CA")
        birth_date (date): Date of birth

    Returns:
        str: A formatted synthetic SSN in the format "XXX-XX-XXXX"

    """
    if birth_date < SSN_RANDOMIZATION_DATE:
        if random.random() < 0.3:
            # Maybe born in a different state: pick a random state's range.
            area_range = random.choice(list(STATE_TO_AREA_SSN.values()))
        else:
            # Bug fix: this assignment previously ran unconditionally,
            # clobbering the random-state choice made above.
            area_range = STATE_TO_AREA_SSN.get(state, [1, 899])
    else:
        area_range = [1, 899]
    area = 666
    while area == 666:
        # 666 is never assigned as an area number; redraw until valid.
        area = random.randint(area_range[0], area_range[1])
    # Group number
    group = random.randint(1, 99)
    # Serial number
    serial = random.randint(1, 9999)
    return f"{area:03d}-{group:02d}-{serial:04d}"
|
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from datetime import date, timedelta
|
|
5
|
+
import random
|
|
6
|
+
from typing import Any, Literal, TypeAlias
|
|
7
|
+
|
|
8
|
+
from data_designer.config.utils.constants import LOCALES_WITH_MANAGED_DATASETS
|
|
9
|
+
from data_designer.engine.resources.managed_dataset_generator import ManagedDatasetGenerator
|
|
10
|
+
from data_designer.engine.resources.managed_dataset_repository import load_managed_dataset_repository
|
|
11
|
+
from data_designer.engine.resources.managed_storage import ManagedBlobStorage
|
|
12
|
+
from data_designer.engine.sampling_gen.entities.dataset_based_person_fields import (
|
|
13
|
+
PERSONA_FIELDS,
|
|
14
|
+
PII_FIELDS,
|
|
15
|
+
REQUIRED_FIELDS,
|
|
16
|
+
)
|
|
17
|
+
from data_designer.engine.sampling_gen.entities.email_address_utils import get_email_address
|
|
18
|
+
from data_designer.engine.sampling_gen.entities.errors import MissingPersonFieldsError
|
|
19
|
+
from data_designer.engine.sampling_gen.entities.national_id_utils import generate_ssn
|
|
20
|
+
from data_designer.engine.sampling_gen.entities.phone_number import PhoneNumber
|
|
21
|
+
from data_designer.engine.sampling_gen.errors import DatasetNotAvailableForLocaleError
|
|
22
|
+
|
|
23
|
+
# Closed set of string labels accepted for a person's sex.
SexT: TypeAlias = Literal["Male", "Female"]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _shift_years_back(day: date, years_back: int) -> date:
    """Return *day* moved *years_back* years into the past.

    Feb 29 is mapped to Feb 28 when the target year is not a leap year
    (plain ``date.replace`` would raise ValueError).
    """
    try:
        return day.replace(year=day.year - years_back)
    except ValueError:
        return day.replace(year=day.year - years_back, day=28)


def convert_age_to_birth_date(age: int) -> date:
    """Sample a uniformly random birth date consistent with being *age* today.

    A person who is N years old today was born in the half-open interval
    (today - (N+1) years, today - N years]: the lower bound is excluded
    because someone born exactly N+1 years ago is N+1 today.
    """
    today = date.today()
    start_date = _shift_years_back(today, age + 1)  # exclusive lower bound
    end_date = _shift_years_back(today, age)  # inclusive upper bound
    days_between = (end_date - start_date).days
    # Bug fix: randint(0, ...) could return start_date itself, producing a
    # person whose actual age is age + 1; the lower bound must be 1.
    random_days = random.randint(1, days_between)
    birthdate = start_date + timedelta(days=random_days)
    return birthdate
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def generate_email_address(
|
|
37
|
+
first_name: str,
|
|
38
|
+
middle_name: str | None,
|
|
39
|
+
last_name: str,
|
|
40
|
+
age: int,
|
|
41
|
+
birth_date: date,
|
|
42
|
+
) -> str | None:
|
|
43
|
+
"""
|
|
44
|
+
Generate an email address based on the person's attributes.
|
|
45
|
+
Email address is None for children. Uses common free email domains.
|
|
46
|
+
"""
|
|
47
|
+
if age < 18:
|
|
48
|
+
return None
|
|
49
|
+
return get_email_address(
|
|
50
|
+
first_name=first_name,
|
|
51
|
+
middle_name=middle_name,
|
|
52
|
+
last_name=last_name,
|
|
53
|
+
age=age,
|
|
54
|
+
birth_date=birth_date,
|
|
55
|
+
)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def get_national_id(locale: str | None, region: str | None, birth_date: date) -> str | None:
|
|
59
|
+
if locale != "en_US":
|
|
60
|
+
return None
|
|
61
|
+
if region is None:
|
|
62
|
+
return None
|
|
63
|
+
return generate_ssn(state=region, birth_date=birth_date)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def generate_phone_number(locale: str, age: int, postcode: str | None, style: str = "dash") -> str | None:
|
|
67
|
+
"""
|
|
68
|
+
Generate a phone number correlated with location (postcode).
|
|
69
|
+
Phone number is None for children.
|
|
70
|
+
"""
|
|
71
|
+
if locale != "en_US":
|
|
72
|
+
return None
|
|
73
|
+
if age < 18:
|
|
74
|
+
return None
|
|
75
|
+
if postcode is None:
|
|
76
|
+
return None
|
|
77
|
+
locality_var = random.random()
|
|
78
|
+
if locality_var < 0.6:
|
|
79
|
+
# Exact match to postcode 60% of the time
|
|
80
|
+
return PhoneNumber.from_zip_prefix(postcode).format(style=style)
|
|
81
|
+
elif locality_var < 0.8:
|
|
82
|
+
# Nearby postcodes 20% of the time
|
|
83
|
+
return PhoneNumber.from_zip_prefix(postcode[:4]).format(style=style)
|
|
84
|
+
elif locality_var < 0.9:
|
|
85
|
+
# More distant postcodes 10% of the time
|
|
86
|
+
return PhoneNumber.from_zip_prefix(postcode[:3]).format(style=style)
|
|
87
|
+
# Random (population-weighted) area code 10% of the time
|
|
88
|
+
return PhoneNumber.generate().format(style=style)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def generate_and_insert_derived_fields(person_record: dict[str, Any]) -> dict[str, str | None]:
    """Derive birth date, phone number, email address and national ID for a
    person record, insert them, and return a filtered view of the record.

    NOTE(review): mutates ``person_record`` in place (``update``/``pop``)
    in addition to returning a new dict — callers that need the original
    record unchanged should pass a copy.
    """
    _verify_required_fields(person_record)
    # Assumes "age" is among the fields verified above — TODO confirm
    # REQUIRED_FIELDS includes it; otherwise convert_age_to_birth_date(None)
    # would raise.
    birth_date = convert_age_to_birth_date(person_record.get("age"))
    person_record.update(
        {
            # Note: All data must be serializable to JSON.
            "birth_date": birth_date.isoformat(),
            "phone_number": generate_phone_number(
                locale=person_record.get("locale"),
                age=person_record.get("age"),
                postcode=person_record.get("postcode"),
            ),
            "email_address": generate_email_address(
                first_name=person_record.get("first_name"),
                middle_name=person_record.get("middle_name"),
                last_name=person_record.get("last_name"),
                age=person_record.get("age"),
                birth_date=birth_date,
            ),
            "national_id": get_national_id(
                locale=person_record.get("locale"),
                region=person_record.get("region"),
                birth_date=birth_date,
            ),
        }
    )
    # Rename "region" -> "state" for US records. This must stay AFTER the
    # update above: get_national_id reads the "region" key.
    if person_record.get("locale") == "en_US" and "region" in person_record and "state" not in person_record:
        state = person_record.pop("region")
        person_record.update({"state": state})

    # Return only PII fields, the derived fields, and persona fields —
    # in that key order.
    return {
        **{k: v for k, v in person_record.items() if k in PII_FIELDS},
        **{k: v for k, v in person_record.items() if k in ["state", "phone_number", "email_address", "national_id"]},
        **{k: v for k, v in person_record.items() if k in PERSONA_FIELDS},
    }
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def load_person_data_sampler(blob_storage: ManagedBlobStorage, locale: str) -> ManagedDatasetGenerator:
    """Build a managed-dataset generator for person data in *locale*.

    Raises:
        DatasetNotAvailableForLocaleError: If no managed dataset exists for
            the requested locale.
    """
    if locale in LOCALES_WITH_MANAGED_DATASETS:
        repository = load_managed_dataset_repository(blob_storage, [locale])
        return ManagedDatasetGenerator(managed_datasets=repository, dataset_name=locale)
    raise DatasetNotAvailableForLocaleError(f"Locale {locale} is not supported by the managed dataset generator.")
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def _verify_required_fields(person_record: dict[str, Any]) -> None:
    """Raise if the person record lacks any of the required fields."""
    missing_fields = REQUIRED_FIELDS.difference(person_record)
    if missing_fields:
        raise MissingPersonFieldsError(f"Person data is missing the following required fields: {missing_fields}")
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
import random
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
import pandas as pd
|
|
9
|
+
from pydantic import BaseModel, Field, field_validator
|
|
10
|
+
|
|
11
|
+
# ZIP-code lookup table, loaded once at import time from the bundled
# parquet asset; columns: "zipcode", "area_code", "count".
ZIP_AREA_CODE_DATA = pd.read_parquet(Path(__file__).parent / "assets" / "zip_area_code_map.parquet")
# zipcode -> telephone area code
ZIPCODE_AREA_CODE_MAP = dict(zip(ZIP_AREA_CODE_DATA["zipcode"], ZIP_AREA_CODE_DATA["area_code"]))
# zipcode -> "count" used as a sampling weight (presumably population — TODO confirm)
ZIPCODE_POPULATION_MAP = dict(zip(ZIP_AREA_CODE_DATA["zipcode"], ZIP_AREA_CODE_DATA["count"]))
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def get_area_code(zip_prefix: Optional[str] = None) -> str:
    """
    Sample an area code for the given ZIP code prefix, population-weighted.

    Args:
        zip_prefix: The prefix of a ZIP code, 5 digits or fewer. If None, sample from all ZIP codes.

    Returns:
        A sampled area code matching the prefix, population-weighted.

    Raises:
        ValueError: If a full 5-digit ZIP code is unknown, or no ZIP code
            starts with the given prefix.
    """
    if zip_prefix is None:
        zipcodes, weights = zip(*ZIPCODE_POPULATION_MAP.items())
        zipcode = random.choices(zipcodes, weights=weights, k=1)[0]
        return str(ZIPCODE_AREA_CODE_MAP[zipcode])
    if len(zip_prefix) == 5:
        try:
            return str(ZIPCODE_AREA_CODE_MAP[zip_prefix])
        except KeyError as err:
            # Chain the cause so the original KeyError stays visible.
            raise ValueError(f"ZIP code {zip_prefix} not found.") from err
    matching_zipcodes = [[z, c] for z, c in ZIPCODE_POPULATION_MAP.items() if z.startswith(zip_prefix)]
    # Bug fix: the empty check must happen BEFORE unpacking — zip(*[]) raised
    # "not enough values to unpack" instead of the intended ValueError.
    if not matching_zipcodes:
        raise ValueError(f"No ZIP codes found with prefix {zip_prefix}.")
    zipcodes, weights = zip(*matching_zipcodes)
    zipcode = random.choices(zipcodes, weights=weights, k=1)[0]
    return str(ZIPCODE_AREA_CODE_MAP[zipcode])
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class PhoneNumber(BaseModel):
    """
    A phone number object that supports various formatting styles
    """

    country_code: str = Field(default="1")  # default is US/Canada
    area_code: str
    prefix: str  # First part of the local number
    line_number: str  # Second part of the local number

    @field_validator("country_code", "area_code", "prefix", "line_number")
    @classmethod
    def validate_digits(cls, v):
        """Every phone-number component must be a digit-only string."""
        if not v.isdigit():
            raise ValueError("Must contain only digits")
        return v

    @field_validator("country_code")
    @classmethod
    def validate_country_code_length(cls, v):
        """Country codes are at most 3 digits long."""
        max_length = 3
        if len(v) > max_length:
            raise ValueError(f"Country code {v} is longer than {max_length} digits")
        return v

    def format(self, style: str = "dash") -> str:
        """
        Format the phone number according to the specified style.

        Args:
            style: One of "dash", "parentheses", "dot", "space", "no_separation",
                "international_plus", "international"

        Returns:
            Formatted phone number string

        Raises:
            ValueError: If *style* is not one of the supported styles.
        """
        if style == "dash":
            formatted = f"{self.area_code}-{self.prefix}-{self.line_number}"
        elif style == "parentheses":
            formatted = f"({self.area_code}) {self.prefix}-{self.line_number}"
        elif style == "dot":
            formatted = f"{self.area_code}.{self.prefix}.{self.line_number}"
        elif style == "space":
            formatted = f"{self.area_code} {self.prefix} {self.line_number}"
        elif style == "no_separation":
            formatted = f"{self.area_code}{self.prefix}{self.line_number}"
        elif style == "international_plus":
            cc = self.country_code or "1"  # Default to US/Canada
            formatted = f"+{cc} {self.area_code} {self.prefix} {self.line_number}"
        elif style == "international":
            cc = int(self.country_code or 1)  # Default to US/Canada
            formatted = f"{cc:03d} {self.area_code} {self.prefix} {self.line_number}"
        else:
            raise ValueError(f"Unsupported format style: {style}")

        return formatted

    @classmethod
    def from_area_code(cls, area_code: str) -> "PhoneNumber":
        """Create a phone number with a random prefix and line number."""
        # Bug fix: random.randint's upper bound is INCLUSIVE, so the previous
        # bounds (1000 / 10000) could produce a 4-digit prefix ("1000") and a
        # 5-digit line number ("10000").
        prefix = str(random.randint(200, 999))
        line_number = str(random.randint(0, 9999)).zfill(4)
        return PhoneNumber(area_code=area_code, prefix=prefix, line_number=line_number)

    @classmethod
    def from_zip_prefix(cls, zip_prefix: str) -> "PhoneNumber":
        """Create a phone number from the given ZIP code prefix."""
        area_code = get_area_code(zip_prefix)
        return cls.from_area_code(area_code)

    @classmethod
    def generate(cls) -> "PhoneNumber":
        """Create a random valid US phone number."""
        area_code = get_area_code()
        return cls.from_area_code(area_code)

    def __str__(self) -> str:
        return self.format("dash")

    def __repr__(self) -> str:
        return f"PhoneNumber({str(self)})"
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from data_designer.engine.errors import DataDesignerError
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
# Root of the sampling_gen exception hierarchy; catch this one type to handle
# any failure raised by the sampling_gen library in a single place.
class SamplingGenError(DataDesignerError):
    """Base exception for all errors in the sampling_gen library."""
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# Raised when rejection sampling cannot satisfy the configured constraints
# (e.g. the iteration budget is exhausted before all rows pass their checks).
class RejectionSamplingError(SamplingGenError):
    """Exception for all errors related to rejection sampling."""
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# Raised when sampled values cannot be converted to the requested output type.
class DataConversionError(SamplingGenError):
    """Exception for all errors related to data conversion."""
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Raised when a managed dataset has no variant for the requested locale.
class DatasetNotAvailableForLocaleError(SamplingGenError):
    """Exception for all errors related to the dataset not being available for a given locale."""
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# Raised for failures inside the managed dataset generator itself.
class ManagedDatasetGeneratorError(SamplingGenError):
    """Exception for all errors related to the managed dataset generator."""
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from collections.abc import Callable
|
|
7
|
+
from typing import TYPE_CHECKING
|
|
8
|
+
|
|
9
|
+
import networkx as nx
|
|
10
|
+
import numpy as np
|
|
11
|
+
import pandas as pd
|
|
12
|
+
|
|
13
|
+
from data_designer.engine.sampling_gen.data_sources.base import RadomStateT
|
|
14
|
+
from data_designer.engine.sampling_gen.errors import RejectionSamplingError
|
|
15
|
+
from data_designer.engine.sampling_gen.jinja_utils import JinjaDataFrame
|
|
16
|
+
from data_designer.engine.sampling_gen.people_gen import create_people_gen_resource
|
|
17
|
+
from data_designer.engine.sampling_gen.schema import DataSchema
|
|
18
|
+
from data_designer.engine.sampling_gen.utils import check_random_state
|
|
19
|
+
|
|
20
|
+
if TYPE_CHECKING:
|
|
21
|
+
from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig
|
|
22
|
+
from data_designer.engine.resources.managed_dataset_generator import ManagedDatasetGenerator
|
|
23
|
+
from data_designer.engine.sampling_gen.column import ConditionalDataColumn
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class DatasetGenerator:
    """Generates synthetic datasets based on the given schema definition.

    This object generates synthetic data based on the schema using sampling-based
    methods (implemented as "data sources"), including handling conditional generation
    and enforcing constraints through rejection sampling.

    Args:
        sampler_columns: Sampler columns to generate.
        random_state: Random number generator or seed for reproducibility.
        person_generator_loader: A function that loads a person generator. If None,
            person generation will not be supported.

    Note:
        The generator leverages the schema's DAG to topologically sort the columns
        and uses rejection sampling to satisfy constraints. If constraints are too strict,
        generation may fail with a RejectionSamplingError.
    """

    def __init__(
        self,
        sampler_columns: SamplerMultiColumnConfig | None,
        random_state: RadomStateT | None = None,
        person_generator_loader: Callable[[bool], ManagedDatasetGenerator] | None = None,
        *,
        schema: DataSchema | None = None,
        max_rejections_factor: int = 5,
    ):
        # This is temporary while we need the legacy and refactored code to coexist.
        # When `schema` is given it takes precedence and `sampler_columns` is ignored;
        # otherwise a DataSchema is derived from the legacy sampler-column config.
        if schema is not None:
            self.schema = schema
            self.max_rejections_factor = max_rejections_factor
        else:
            self.schema = DataSchema(
                columns=[column.model_dump() for column in sampler_columns.columns],
                constraints=sampler_columns.constraints,
            )
            self.max_rejections_factor = sampler_columns.max_rejections_factor

        self.rng = check_random_state(random_state)
        # networkx DAG of column dependencies; drives generation order in generate().
        self._dag = self.schema.dag.to_networkx()
        # Keyword args passed to every sampler so they share one RNG and one
        # people-generation resource.
        self._shared_sampler_kwargs = {
            "random_state": self.rng,
            "people_gen_resource": create_people_gen_resource(self.schema, person_generator_loader),
        }

    def _round_if_needed(self, column: ConditionalDataColumn, df: pd.DataFrame) -> pd.DataFrame:
        """Round the column in-place to `decimal_places` if its params define one."""
        # hasattr guard: not every column's params type carries `decimal_places`.
        if hasattr(column.params, "decimal_places") and column.params.decimal_places is not None:
            df[column.name] = df[column.name].round(column.params.decimal_places)
        return df

    def _run_rejection_sampling(self, df: pd.DataFrame, column: ConditionalDataColumn) -> pd.DataFrame:
        """Sample one column, re-sampling rows until all constraint checkers pass.

        Raises:
            RejectionSamplingError: if more than max_rejections_factor * num_samples
                resampling iterations elapse without satisfying the constraints.
        """
        name = column.name
        num_iterations = 0
        num_samples = len(df)
        # Boolean mask of rows that still need a (re)sampled value; all rows start dirty.
        needs_samples = np.ones(num_samples, dtype=bool)

        while needs_samples.any():
            for condition in column.conditions:
                # Select the subset of still-dirty rows matching this Jinja condition,
                # then let the condition's sampler fill values at those row positions.
                index = JinjaDataFrame(condition).select_index(df[needs_samples])
                src = column.get_sampler(condition, **self._shared_sampler_kwargs)
                df = src.inject_data_column(df, name, index)

            df[name] = column.get_default_sampler(**self._shared_sampler_kwargs).preproc(df[name], column.convert_to)

            # Check all constraints on the column.
            batch_mask = np.ones(num_samples, dtype=bool)
            for constraint in self.schema.get_constraint_checkers(name):
                batch_mask &= constraint.check(df)
            # Rows passing every checker are done; the rest get resampled next pass.
            needs_samples[batch_mask] = False
            num_iterations += 1

            if num_iterations > self.max_rejections_factor * num_samples:
                raise RejectionSamplingError(
                    "Exceeded the maximum number of rejections (max_rejections_factor * "
                    f"num_samples = {self.max_rejections_factor * num_samples}) while "
                    f"sampling `{column.name}`. Please consider adjusting the constraints "
                    "and/or column's generation configuration."
                )

        return df

    def generate(self, num_samples: int) -> pd.DataFrame:
        """Generate `num_samples` rows for every column defined in the schema.

        Columns are filled in topological (dependency) order, then post-processed
        and rounded; the result is restricted to the schema's declared columns.
        """
        dataset = pd.DataFrame(index=range(num_samples))

        # Dependency order ensures a column's parents exist before it is sampled.
        for column_name in nx.topological_sort(self._dag):
            column = self.schema.get_column(column_name)
            dataset = self._run_rejection_sampling(dataset, column)

        # Second pass: type post-processing and optional rounding per column.
        for column in self.schema.columns:
            dataset[column.name] = column.get_default_sampler(**self._shared_sampler_kwargs).postproc(
                dataset[column.name], column.convert_to
            )
            dataset = self._round_if_needed(column, dataset)

        # Reorder/limit output to exactly the schema's column names.
        return dataset[self.schema.column_names]
|