data-designer 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/__init__.py +15 -0
- data_designer/_version.py +34 -0
- data_designer/cli/README.md +236 -0
- data_designer/cli/__init__.py +6 -0
- data_designer/cli/commands/__init__.py +2 -0
- data_designer/cli/commands/list.py +130 -0
- data_designer/cli/commands/models.py +10 -0
- data_designer/cli/commands/providers.py +11 -0
- data_designer/cli/commands/reset.py +100 -0
- data_designer/cli/controllers/__init__.py +7 -0
- data_designer/cli/controllers/model_controller.py +246 -0
- data_designer/cli/controllers/provider_controller.py +317 -0
- data_designer/cli/forms/__init__.py +20 -0
- data_designer/cli/forms/builder.py +51 -0
- data_designer/cli/forms/field.py +180 -0
- data_designer/cli/forms/form.py +59 -0
- data_designer/cli/forms/model_builder.py +125 -0
- data_designer/cli/forms/provider_builder.py +76 -0
- data_designer/cli/main.py +44 -0
- data_designer/cli/repositories/__init__.py +8 -0
- data_designer/cli/repositories/base.py +39 -0
- data_designer/cli/repositories/model_repository.py +42 -0
- data_designer/cli/repositories/provider_repository.py +43 -0
- data_designer/cli/services/__init__.py +7 -0
- data_designer/cli/services/model_service.py +116 -0
- data_designer/cli/services/provider_service.py +111 -0
- data_designer/cli/ui.py +448 -0
- data_designer/cli/utils.py +47 -0
- data_designer/config/__init__.py +2 -0
- data_designer/config/analysis/column_profilers.py +89 -0
- data_designer/config/analysis/column_statistics.py +274 -0
- data_designer/config/analysis/dataset_profiler.py +60 -0
- data_designer/config/analysis/utils/errors.py +8 -0
- data_designer/config/analysis/utils/reporting.py +188 -0
- data_designer/config/base.py +68 -0
- data_designer/config/column_configs.py +354 -0
- data_designer/config/column_types.py +168 -0
- data_designer/config/config_builder.py +660 -0
- data_designer/config/data_designer_config.py +40 -0
- data_designer/config/dataset_builders.py +11 -0
- data_designer/config/datastore.py +151 -0
- data_designer/config/default_model_settings.py +123 -0
- data_designer/config/errors.py +19 -0
- data_designer/config/interface.py +54 -0
- data_designer/config/models.py +231 -0
- data_designer/config/preview_results.py +32 -0
- data_designer/config/processors.py +41 -0
- data_designer/config/sampler_constraints.py +51 -0
- data_designer/config/sampler_params.py +604 -0
- data_designer/config/seed.py +145 -0
- data_designer/config/utils/code_lang.py +83 -0
- data_designer/config/utils/constants.py +313 -0
- data_designer/config/utils/errors.py +19 -0
- data_designer/config/utils/info.py +88 -0
- data_designer/config/utils/io_helpers.py +273 -0
- data_designer/config/utils/misc.py +81 -0
- data_designer/config/utils/numerical_helpers.py +28 -0
- data_designer/config/utils/type_helpers.py +100 -0
- data_designer/config/utils/validation.py +336 -0
- data_designer/config/utils/visualization.py +427 -0
- data_designer/config/validator_params.py +96 -0
- data_designer/engine/__init__.py +2 -0
- data_designer/engine/analysis/column_profilers/base.py +55 -0
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +160 -0
- data_designer/engine/analysis/column_profilers/registry.py +20 -0
- data_designer/engine/analysis/column_statistics.py +142 -0
- data_designer/engine/analysis/dataset_profiler.py +125 -0
- data_designer/engine/analysis/errors.py +7 -0
- data_designer/engine/analysis/utils/column_statistics_calculations.py +209 -0
- data_designer/engine/analysis/utils/judge_score_processing.py +128 -0
- data_designer/engine/column_generators/__init__.py +2 -0
- data_designer/engine/column_generators/generators/__init__.py +2 -0
- data_designer/engine/column_generators/generators/base.py +61 -0
- data_designer/engine/column_generators/generators/expression.py +63 -0
- data_designer/engine/column_generators/generators/llm_generators.py +172 -0
- data_designer/engine/column_generators/generators/samplers.py +75 -0
- data_designer/engine/column_generators/generators/seed_dataset.py +149 -0
- data_designer/engine/column_generators/generators/validation.py +147 -0
- data_designer/engine/column_generators/registry.py +56 -0
- data_designer/engine/column_generators/utils/errors.py +13 -0
- data_designer/engine/column_generators/utils/judge_score_factory.py +57 -0
- data_designer/engine/column_generators/utils/prompt_renderer.py +98 -0
- data_designer/engine/configurable_task.py +82 -0
- data_designer/engine/dataset_builders/artifact_storage.py +181 -0
- data_designer/engine/dataset_builders/column_wise_builder.py +287 -0
- data_designer/engine/dataset_builders/errors.py +13 -0
- data_designer/engine/dataset_builders/multi_column_configs.py +44 -0
- data_designer/engine/dataset_builders/utils/__init__.py +2 -0
- data_designer/engine/dataset_builders/utils/concurrency.py +184 -0
- data_designer/engine/dataset_builders/utils/config_compiler.py +60 -0
- data_designer/engine/dataset_builders/utils/dag.py +56 -0
- data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +190 -0
- data_designer/engine/dataset_builders/utils/errors.py +13 -0
- data_designer/engine/errors.py +49 -0
- data_designer/engine/model_provider.py +75 -0
- data_designer/engine/models/__init__.py +2 -0
- data_designer/engine/models/errors.py +308 -0
- data_designer/engine/models/facade.py +225 -0
- data_designer/engine/models/litellm_overrides.py +162 -0
- data_designer/engine/models/parsers/__init__.py +2 -0
- data_designer/engine/models/parsers/errors.py +34 -0
- data_designer/engine/models/parsers/parser.py +236 -0
- data_designer/engine/models/parsers/postprocessors.py +93 -0
- data_designer/engine/models/parsers/tag_parsers.py +60 -0
- data_designer/engine/models/parsers/types.py +82 -0
- data_designer/engine/models/recipes/base.py +79 -0
- data_designer/engine/models/recipes/response_recipes.py +291 -0
- data_designer/engine/models/registry.py +118 -0
- data_designer/engine/models/usage.py +75 -0
- data_designer/engine/models/utils.py +38 -0
- data_designer/engine/processing/ginja/__init__.py +2 -0
- data_designer/engine/processing/ginja/ast.py +64 -0
- data_designer/engine/processing/ginja/environment.py +461 -0
- data_designer/engine/processing/ginja/exceptions.py +54 -0
- data_designer/engine/processing/ginja/record.py +30 -0
- data_designer/engine/processing/gsonschema/__init__.py +2 -0
- data_designer/engine/processing/gsonschema/exceptions.py +8 -0
- data_designer/engine/processing/gsonschema/schema_transformers.py +81 -0
- data_designer/engine/processing/gsonschema/types.py +8 -0
- data_designer/engine/processing/gsonschema/validators.py +143 -0
- data_designer/engine/processing/processors/base.py +15 -0
- data_designer/engine/processing/processors/drop_columns.py +46 -0
- data_designer/engine/processing/processors/registry.py +20 -0
- data_designer/engine/processing/utils.py +120 -0
- data_designer/engine/registry/base.py +97 -0
- data_designer/engine/registry/data_designer_registry.py +37 -0
- data_designer/engine/registry/errors.py +10 -0
- data_designer/engine/resources/managed_dataset_generator.py +35 -0
- data_designer/engine/resources/managed_dataset_repository.py +194 -0
- data_designer/engine/resources/managed_storage.py +63 -0
- data_designer/engine/resources/resource_provider.py +46 -0
- data_designer/engine/resources/seed_dataset_data_store.py +66 -0
- data_designer/engine/sampling_gen/column.py +89 -0
- data_designer/engine/sampling_gen/constraints.py +95 -0
- data_designer/engine/sampling_gen/data_sources/base.py +214 -0
- data_designer/engine/sampling_gen/data_sources/errors.py +10 -0
- data_designer/engine/sampling_gen/data_sources/sources.py +342 -0
- data_designer/engine/sampling_gen/entities/__init__.py +2 -0
- data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
- data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +64 -0
- data_designer/engine/sampling_gen/entities/email_address_utils.py +169 -0
- data_designer/engine/sampling_gen/entities/errors.py +8 -0
- data_designer/engine/sampling_gen/entities/national_id_utils.py +100 -0
- data_designer/engine/sampling_gen/entities/person.py +142 -0
- data_designer/engine/sampling_gen/entities/phone_number.py +122 -0
- data_designer/engine/sampling_gen/errors.py +24 -0
- data_designer/engine/sampling_gen/generator.py +121 -0
- data_designer/engine/sampling_gen/jinja_utils.py +60 -0
- data_designer/engine/sampling_gen/people_gen.py +203 -0
- data_designer/engine/sampling_gen/person_constants.py +54 -0
- data_designer/engine/sampling_gen/schema.py +143 -0
- data_designer/engine/sampling_gen/schema_builder.py +59 -0
- data_designer/engine/sampling_gen/utils.py +40 -0
- data_designer/engine/secret_resolver.py +80 -0
- data_designer/engine/validators/__init__.py +17 -0
- data_designer/engine/validators/base.py +36 -0
- data_designer/engine/validators/local_callable.py +34 -0
- data_designer/engine/validators/python.py +245 -0
- data_designer/engine/validators/remote.py +83 -0
- data_designer/engine/validators/sql.py +60 -0
- data_designer/errors.py +5 -0
- data_designer/essentials/__init__.py +137 -0
- data_designer/interface/__init__.py +2 -0
- data_designer/interface/data_designer.py +351 -0
- data_designer/interface/errors.py +16 -0
- data_designer/interface/results.py +55 -0
- data_designer/logging.py +161 -0
- data_designer/plugin_manager.py +83 -0
- data_designer/plugins/__init__.py +6 -0
- data_designer/plugins/errors.py +10 -0
- data_designer/plugins/plugin.py +69 -0
- data_designer/plugins/registry.py +86 -0
- data_designer-0.1.0.dist-info/METADATA +173 -0
- data_designer-0.1.0.dist-info/RECORD +177 -0
- data_designer-0.1.0.dist-info/WHEEL +4 -0
- data_designer-0.1.0.dist-info/entry_points.txt +2 -0
- data_designer-0.1.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
import ast
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import pandas as pd
|
|
8
|
+
|
|
9
|
+
from data_designer.engine.processing.ginja.environment import (
|
|
10
|
+
UserTemplateSandboxEnvironment,
|
|
11
|
+
WithJinja2UserTemplateRendering,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class JinjaDataFrame(WithJinja2UserTemplateRendering):
    """Evaluates a Jinja2 expression row-by-row against a pandas DataFrame."""

    def __init__(self, expr: str):
        self.expr = expr

    def _jsonify(self, record) -> dict[str, Any]:
        # Timestamps are not template-friendly; replace them in place with ISO-8601 strings.
        for key in record:
            if isinstance(record[key], pd.Timestamp):
                record[key] = record[key].isoformat()
        return record

    @property
    def _expr(self) -> str:
        # Wrap the raw expression so Jinja treats it as an output expression.
        return "{{ " + self.expr + " }}"

    def select_index(self, dataframe: pd.DataFrame) -> pd.Index:
        """Return the index of rows for which the expression renders to "True".

        An empty frame or the wildcard expression "..." selects every row.
        """
        if dataframe.empty or self.expr == "...":
            return dataframe.index

        self.prepare_jinja2_template_renderer(self._expr, list(dataframe))

        def row_matches(row) -> bool:
            return self.render_template(self._jsonify(row.to_dict())) == "True"

        mask = dataframe.apply(row_matches, axis=1).to_numpy()
        return dataframe[mask].index

    def to_column(self, dataframe: pd.DataFrame) -> list[Any]:
        """Render the expression for each row and return the results as a list.

        Rendered values that parse as Python literals are converted; everything
        else is kept as the rendered string.
        """
        self.prepare_jinja2_template_renderer(self._expr, list(dataframe))

        values: list[Any] = []
        for record in dataframe.to_dict(orient="records"):
            rendered = self.render_template(self._jsonify(record))
            try:
                # Non-string expressions are evaluated as literals.
                values.append(ast.literal_eval(rendered))
            except (ValueError, SyntaxError):
                # Plain strings fail literal evaluation and are kept verbatim.
                values.append(rendered)
        return values
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def extract_column_names_from_expression(expr: str) -> set[str]:
    """Return the set of column names referenced by the given expression."""
    wrapped = "{{ " + expr + " }}"
    return UserTemplateSandboxEnvironment().get_references(wrapped)
|
|
@@ -0,0 +1,203 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
from collections.abc import Callable
|
|
8
|
+
from copy import deepcopy
|
|
9
|
+
import random
|
|
10
|
+
from typing import TYPE_CHECKING, Any, Union
|
|
11
|
+
import uuid
|
|
12
|
+
|
|
13
|
+
from faker import Faker
|
|
14
|
+
import pandas as pd
|
|
15
|
+
|
|
16
|
+
from data_designer.config.utils.constants import AVAILABLE_LOCALES, DEFAULT_AGE_RANGE
|
|
17
|
+
from data_designer.engine.resources.managed_dataset_generator import ManagedDatasetGenerator
|
|
18
|
+
from data_designer.engine.sampling_gen.entities.dataset_based_person_fields import PERSONA_FIELDS, PII_FIELDS
|
|
19
|
+
from data_designer.engine.sampling_gen.entities.person import (
|
|
20
|
+
convert_age_to_birth_date,
|
|
21
|
+
generate_and_insert_derived_fields,
|
|
22
|
+
)
|
|
23
|
+
from data_designer.engine.sampling_gen.errors import ManagedDatasetGeneratorError
|
|
24
|
+
from data_designer.engine.sampling_gen.person_constants import faker_constants
|
|
25
|
+
|
|
26
|
+
if TYPE_CHECKING:
|
|
27
|
+
from data_designer.engine.sampling_gen.schema import DataSchema
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
EngineT = Union[Faker, ManagedDatasetGenerator]
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class PeopleGen(ABC):
    """Unified interface for generating people data.

    Subclasses wrap a concrete engine (a Faker instance or a managed dataset
    generator) and implement `generate`.

    Raises:
        ValueError: If `locale` is not one of AVAILABLE_LOCALES.
    """

    def __init__(self, engine: EngineT, locale: str):
        if locale not in AVAILABLE_LOCALES:
            # Fix: the original message was missing the space after the period
            # ("locale.Supported"), producing a garbled error string.
            raise ValueError(
                f"Locale {locale} is not a supported locale. Supported locales: {', '.join(AVAILABLE_LOCALES)}"
            )
        self.locale = locale
        self._engine = engine

    def set_engine(self, engine: EngineT) -> None:
        """Replace the underlying generation engine."""
        self._engine = engine

    @abstractmethod
    def generate(self, n: int, **kwargs) -> list[dict[str, Any]]:
        """Generate `n` person records as dictionaries."""
        ...
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class PeopleGenFaker(PeopleGen):
    """People generator backed by a Faker instance."""

    @property
    def _fake(self) -> Faker:
        # The engine for this subclass is always a Faker instance.
        return self._engine

    def try_fake_else_none(self, attr_name: str, none_fill: Any | None = None) -> Any | None:
        """Call the named Faker provider, returning `none_fill` if it does not exist.

        Fix: the original return annotation was `-> type`, which is wrong —
        this returns whatever the provider produces, or `none_fill`.
        """
        try:
            return getattr(self._fake, attr_name)()
        except AttributeError:
            # Locale-specific Faker providers may be missing for some locales.
            return none_fill

    def _generate_name_and_sex(self, **kwargs) -> dict[str, str]:
        """Pick a sex (optionally constrained via kwargs) and generate matching names."""
        options = faker_constants.sex
        requested = kwargs.get("sex")
        if isinstance(requested, str) and requested in options:
            sex = requested
        elif isinstance(requested, list) and requested and set(requested).issubset(options):
            # Generalized: the original only accepted single-element lists
            # (e.g. ["Male"]); any non-empty subset of the options now works.
            sex = random.choice(requested)
        else:
            sex = random.choice(options)

        return {
            "first_name": getattr(self._fake, f"first_name_{sex.lower()}")(),
            "last_name": getattr(self._fake, f"last_name_{sex.lower()}")(),
            "middle_name": None,
            "sex": sex,
        }

    def _generate_address_fields(self, **kwargs) -> dict[str, str]:
        """Generate street/location fields, honoring fixed values passed via kwargs."""
        address = {
            "street_number": self.try_fake_else_none(faker_constants.attr_map["street_number"]),
            "street_name": self.try_fake_else_none("street_name"),
        }

        # Location fields can be filtered using the fixed_kwargs.
        for attr in faker_constants.location:
            if attr in kwargs:
                address[attr] = random.choice(kwargs[attr]) if isinstance(kwargs[attr], list) else kwargs[attr]
            else:
                address[attr] = self.try_fake_else_none(attr)

        return address

    def _generate_age(self, **kwargs) -> int:
        """Draw a uniform age from kwargs["age_range"] or the default range."""
        return random.randint(*kwargs.get("age_range", DEFAULT_AGE_RANGE))

    def _generate_marital_status(self, **kwargs) -> str:
        return random.choice(faker_constants.marital_status)

    def _generate_bachelors_field(self, **kwargs) -> str:
        return random.choice(faker_constants.bachelors)

    def _generate_education_level(self, **kwargs) -> str:
        return random.choice(faker_constants.education_level)

    def make_person(self, **kwargs) -> dict[str, Any]:
        """Assemble a single person record from the individual field generators.

        Field insertion order is deliberate: `birth_date`, `phone_number`, and
        `bachelors_field` depend on earlier fields (`age`, `education_level`).
        """
        person = {"uuid": str(uuid.uuid4()), "locale": self.locale}
        person.update(self._generate_name_and_sex(**kwargs))
        person.update(self._generate_address_fields(**kwargs))
        person.update({"age": self._generate_age(**kwargs)})
        person.update({"birth_date": convert_age_to_birth_date(person["age"]).isoformat()})
        person.update({"country": self.try_fake_else_none("country")})
        person.update({"marital_status": self._generate_marital_status(**kwargs)})
        person.update({"education_level": self._generate_education_level(**kwargs)})
        person.update({"unit": ""})
        person.update({"occupation": self.try_fake_else_none(faker_constants.attr_map["occupation"])})
        # Minors do not get a phone number.
        person.update({"phone_number": (self.try_fake_else_none("phone_number") if person["age"] >= 18 else None)})
        if person["education_level"] in faker_constants.college_level:
            person.update({"bachelors_field": self._generate_bachelors_field(**kwargs)})
        else:
            person.update({"bachelors_field": "no_degree"})
        return person

    def generate(self, n: int, **kwargs) -> list[dict[str, Any]]:
        """Generate `n` person records."""
        return [self.make_person(**kwargs) for _ in range(n)]
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class PeopleGenFromDataset(PeopleGen):
    """People generator backed by a managed dataset engine."""

    def _get_ages(self, age_range: tuple[int, int]) -> list[int]:
        # Inclusive range of ages.
        lo, hi = age_range
        return list(range(lo, hi + 1))

    def _generate_from_dataset(self, n: int, **kwargs) -> pd.DataFrame:
        """Sample `n` rows from the dataset engine and keep only schema columns."""
        evidence = deepcopy(kwargs)
        include_personas = evidence.pop("with_synthetic_personas", False)
        evidence["age"] = self._get_ages(evidence.pop("age_range", DEFAULT_AGE_RANGE))

        # Generate samples and drop columns where all rows are null.
        df = self._engine.generate_samples(size=n, evidence=evidence).dropna(axis=1, how="all")

        # Derived-field generation downstream needs the locale on every row.
        df["locale"] = self.locale

        # Only keep columns that are listed in the schema, in schema order.
        wanted = list(PII_FIELDS)
        if include_personas:
            wanted += list(PERSONA_FIELDS)
        fields = [field for field in wanted if field in df.columns]

        return df[fields]

    def generate(self, n: int, **kwargs) -> list[dict[str, Any]]:
        """Generate `n` person records with derived fields inserted."""
        records = self._generate_from_dataset(n, **kwargs).to_dict(orient="records")
        return [generate_and_insert_derived_fields(record) for record in records]
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def create_people_gen_resource(
    schema: DataSchema,
    person_generator_loader: Callable[..., ManagedDatasetGenerator] | None = None,
) -> dict[str, PeopleGen]:
    """Creates resource of unique people generators needed to generate the dataset.

    The resource is a dictionary of person generators, where the keys are the following:
        - {locale} for dataset-based person generators
        - {locale}_with_personas for dataset-based person generators with synthetic personas
        - {locale}_faker for faker-based person generators

    Args:
        schema: Schema of the dataset that we will generate.
        person_generator_loader: Function that loads a managed dataset generator.
            Called as `person_generator_loader(locale=...)`. (Fix: the original
            annotation was `Callable[[bool], ...]`, which did not match the call.)

    Returns:
        Dictionary of unique people generators needed to generate the dataset.

    Raises:
        ManagedDatasetGeneratorError: If a dataset-based generator is needed but
            no loader was provided, or loading one fails.
    """
    people_gen_resource: dict[str, PeopleGen] = {}

    # ------------------------------------------------------------
    # Preload dataset-based person generators
    # ------------------------------------------------------------

    for column in schema.get_columns_by_sampler_type("person"):
        for params in [column.params, *list(column.conditional_params.values())]:
            if params.people_gen_key in people_gen_resource:
                continue
            # Fail with a clear message instead of a TypeError when the loader
            # is missing but dataset-based person columns are present.
            if person_generator_loader is None:
                raise ManagedDatasetGeneratorError(
                    f"🛑 A dataset-based person generator is required for locale {params.locale}, "
                    "but no person_generator_loader was provided."
                )
            try:
                engine = person_generator_loader(locale=params.locale)
                people_gen_resource[params.people_gen_key] = PeopleGenFromDataset(
                    engine=engine, locale=params.locale
                )
            except Exception as e:
                raise ManagedDatasetGeneratorError(
                    f"🛑 Failed to load dataset-based person generator for locale {params.locale}. "
                    "Please check if you have access to person data for this locale. "
                ) from e

    # ------------------------------------------------------------
    # Preload faker-based person generators
    # ------------------------------------------------------------

    for column in schema.get_columns_by_sampler_type("person_from_faker"):
        for params in [column.params, *list(column.conditional_params.values())]:
            if params.people_gen_key not in people_gen_resource:
                people_gen_resource[params.people_gen_key] = PeopleGenFaker(
                    engine=Faker(params.locale), locale=params.locale
                )

    return people_gen_resource
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from typing import NamedTuple
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class FakerPersonData(NamedTuple):
    """Constant option lists and attribute mappings for Faker-based person generation.

    Used as a read-only constants container via the module-level singleton
    `faker_constants`. NOTE: the list defaults on this NamedTuple are shared
    class-level objects — safe here only because the values are never mutated.
    """

    # Allowed values for the generated "sex" field.
    sex: list[str] = ["Male", "Female"]

    # Field names specific to the US locale.
    us_locale_only: list[str] = [
        "state",
        "county",
        "unit",
        "middle_name",
        "ethnic_background",
        "ssn",
    ]

    # Location fields; people_gen allows these to be fixed/filtered via kwargs.
    location: list[str] = ["city", "state", "postcode"]

    # Possible bachelors-degree field categories.
    bachelors: list[str] = [
        "stem",
        "business",
        "education",
        "arts_humanities",
        "stem_related",
    ]

    # Possible education levels.
    education_level: list[str] = [
        "secondary_education",
        "some_college",
        "bachelors",
        "associates",
        "graduate",
        "doctorate",
    ]

    # Possible marital statuses.
    marital_status: list[str] = [
        "married_present",
        "divorced",
        "never_married",
        "separated",
        "widowed",
    ]

    # Education levels that imply a bachelors_field other than "no_degree".
    college_level: list[str] = ["bachelors", "graduate", "doctorate"]

    # Maps person field names to the corresponding Faker provider method names.
    attr_map: dict[str, str] = {
        "street_number": "building_number",
        "occupation": "job",
    }


# Module-level singleton consumed by the people generators.
faker_constants = FakerPersonData()
|
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from functools import cached_property
|
|
7
|
+
|
|
8
|
+
import networkx as nx
|
|
9
|
+
from pydantic import BaseModel, Field, field_validator, model_validator
|
|
10
|
+
from typing_extensions import Self
|
|
11
|
+
|
|
12
|
+
from data_designer.config.base import ConfigBase
|
|
13
|
+
from data_designer.config.sampler_constraints import ColumnConstraintT
|
|
14
|
+
from data_designer.config.sampler_params import SamplerType
|
|
15
|
+
from data_designer.engine.sampling_gen.column import ConditionalDataColumn
|
|
16
|
+
from data_designer.engine.sampling_gen.constraints import ConstraintChecker, get_constraint_checker
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class Dag(BaseModel):
    """Directed acyclic graph over column names, validated for cycles on construction."""

    nodes: set[str]
    edges: set[tuple[str, str]]

    @model_validator(mode="after")
    def validate_is_dag(self) -> Self:
        # Reject circular column dependencies as soon as the model is built.
        graph = self.to_networkx()
        if nx.is_directed_acyclic_graph(graph):
            return self
        raise ValueError("There are circular dependencies in the definitions of your sampler columns.")

    def to_networkx(self) -> nx.DiGraph:
        """Materialize the node and edge sets as a networkx DiGraph."""
        graph = nx.DiGraph()
        graph.add_nodes_from(self.nodes)
        graph.add_edges_from(self.edges)
        return graph
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class DataSchema(ConfigBase):
    """Defines the data schema for synthetic data generation.

    A DataSchema represents a collection of columns and their relationships through
    conditional parameters and/or constraints. Upon initialization, the schema validates
    that column dependencies form a DAG and that all constraints reference valid columns.
    """

    columns: list[ConditionalDataColumn] = Field(..., min_length=1)
    constraints: list[ColumnConstraintT] = []

    @cached_property
    def constraint_checkers(self) -> list[ConstraintChecker]:
        # Built once per instance; each constraint is wrapped in its
        # type-specific checker via the registry lookup.
        return [get_constraint_checker(c.constraint_type)(constraint=c) for c in self.constraints]

    @property
    def column_names(self) -> list[str]:
        # Column names in declaration order.
        return [column.name for column in self.columns]

    @property
    def dag(self) -> Dag:
        """Build the dependency DAG; Dag's own validator raises on cycles."""
        nodes = set()
        edges = set()

        for column in self.columns:
            nodes.add(column.name)

            # Add edges for the conditional columns.
            for conditional_column in column.conditional_column_names:
                edges.add((conditional_column, column.name))

            # Add edges if the source has required columns.
            for condition in column.conditions:
                source = column.get_sampler(condition)
                for required_column in source.get_required_column_names():
                    edges.add((required_column, column.name))

        # Two-column constraints induce an ordering edge; presumably the first
        # required column is the constraint's target column — see
        # get_constraint_checkers below, which matches on constraint.target_column.
        for checker in self.constraint_checkers:
            column_names = checker.get_required_column_names()
            if len(column_names) == 2:
                edges.add((column_names[1], column_names[0]))
        return Dag(nodes=nodes, edges=edges)

    @field_validator("columns", mode="after")
    def check_unique_column_names(cls, columns: list[ConditionalDataColumn]) -> list[ConditionalDataColumn]:
        """Reject duplicate column names."""
        column_names = [column.name for column in columns]
        if len(column_names) != len(set(column_names)):
            raise ValueError("Column names must be unique")
        return columns

    @model_validator(mode="after")
    def validate_constraints(self) -> Self:
        """Ensure every column referenced by a constraint exists in the schema."""
        column_names = [column.name for column in self.columns]

        # Check if all columns required by constraints are present in the schema.
        for checker in self.constraint_checkers:
            constrained_column_names = checker.get_required_column_names()
            if not set(constrained_column_names).issubset(column_names):
                missing = set(constrained_column_names) - set(column_names)
                raise ValueError(
                    f"These constrained columns are missing in the definitions of your sampler columns: {missing}"
                )

        return self

    @model_validator(mode="after")
    def validate_dag(self) -> Self:
        # Accessing .dag constructs the Dag model, whose validator raises
        # if the column dependencies contain a cycle.
        self.dag
        return self

    @model_validator(mode="after")
    def validate_subcategory_columns_if_present(self) -> Self:
        """Check that each subcategory column fully covers its parent category's values."""
        for sub in self.get_columns_by_sampler_type(SamplerType.SUBCATEGORY):
            cat = self.get_column(sub.params.category)
            if cat.sampler_type != SamplerType.CATEGORY:
                raise ValueError(
                    f"The parent of subcategory column '{sub.name}' must be a category "
                    f"source type, but '{cat.name}' is of type '{cat.sampler_type}'."
                )
            # Collect every value the parent category can take, including
            # values introduced by its conditional parameter sets.
            cat_vals = set(cat.params.values)
            for params in cat.conditional_params.values():
                cat_vals.update(params.values)
            sub_vals = set(sub.params.values.keys())
            # The subcategory mapping must have exactly the parent's values as keys.
            if cat_vals.symmetric_difference(sub_vals):
                raise ValueError(
                    f"Subcategory column '{sub.name}' must have values for each value of "
                    f"its parent category '{sub.params.category}'. The following "
                    f"values need attention: {cat_vals.symmetric_difference(sub_vals)}"
                )
            if not all(len(v) > 0 for v in sub.params.values.values()):
                raise ValueError(
                    f"Subcategory column '{sub.name}' must have non-empty values for "
                    f"each value of its parent category '{sub.params.category}'."
                )
        return self

    def get_column(self, column_name: str) -> ConditionalDataColumn:
        """Return the column with the given name, raising if it is not in the schema."""
        if column_name not in self.column_names:
            raise ValueError(f"Column '{column_name}' not found in schema")
        return next(column for column in self.columns if column.name == column_name)

    def get_columns_by_sampler_type(self, sampler_type: SamplerType) -> list[ConditionalDataColumn]:
        """Return all columns whose sampler type matches."""
        return [c for c in self.columns if c.sampler_type == sampler_type]

    def get_constraint_checkers(self, column_name: str) -> list[ConstraintChecker]:
        """Return the checkers whose constraint targets the given column."""
        return [c for c in self.constraint_checkers if column_name == c.constraint.target_column]
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from copy import deepcopy
|
|
5
|
+
|
|
6
|
+
from data_designer.config.column_configs import SamplerColumnConfig
|
|
7
|
+
from data_designer.config.sampler_constraints import ColumnConstraintT
|
|
8
|
+
from data_designer.config.sampler_params import SamplerParamsT
|
|
9
|
+
from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig
|
|
10
|
+
from data_designer.engine.sampling_gen.column import ConditionalDataColumn
|
|
11
|
+
from data_designer.engine.sampling_gen.schema import DataSchema
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class SchemaBuilder:
    """Builder class for DataSchema objects.

    This class is meant to be a helper for internal usage and experimentation. It
    provides a simple interface for constructing a DataSchema object via `add_column`
    and `add_constraint` methods.
    """

    def __init__(
        self,
        columns: list[ConditionalDataColumn] | None = None,
        constraints: list[ColumnConstraintT] | None = None,
    ):
        # Fix: copy the inputs instead of aliasing them, so later
        # add_column/add_constraint calls do not mutate lists owned by the caller.
        self._columns = list(columns) if columns is not None else []
        self._constraints = list(constraints) if constraints is not None else []

    def add_column(
        self,
        name: str,
        sampler_type: str | None,
        params: dict | SamplerParamsT | None,
        conditional_params: dict[str, SamplerParamsT] | None = None,
        convert_to: str | None = None,
    ) -> None:
        """Append a new conditional data column to the builder."""
        self._columns.append(
            ConditionalDataColumn(
                name=name,
                sampler_type=sampler_type,
                params=params,
                conditional_params=conditional_params or {},
                convert_to=convert_to,
            )
        )

    def add_constraint(self, constraint: ColumnConstraintT) -> None:
        """Append a new column constraint to the builder."""
        self._constraints.append(constraint)

    def to_sampler_columns(self, max_rejections_factor: int = 5) -> SamplerMultiColumnConfig:
        """Convert the accumulated columns and constraints to a SamplerMultiColumnConfig."""
        return SamplerMultiColumnConfig(
            columns=[SamplerColumnConfig(**c.model_dump(mode="json")) for c in self._columns],
            constraints=self._constraints,
            max_rejections_factor=max_rejections_factor,
        )

    def build(self) -> DataSchema:
        """Build a DataSchema from deep copies of the accumulated state,
        so the builder can keep being modified afterwards."""
        return DataSchema(columns=deepcopy(self._columns), constraints=deepcopy(self._constraints))
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
import numbers
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def check_random_state(seed):
    """Turn seed into a np.random.RandomState instance.

    This function was taken from scikit-learn's utils module.
    Source GitHub: https://github.com/scikit-learn/scikit-learn

    Parameters
    ----------
    seed : None, int or instance of RandomState
        If seed is None, return the RandomState singleton used by np.random.
        If seed is an int, return a new RandomState instance seeded with seed.
        If seed is already a RandomState instance, return it.
        Otherwise raise ValueError.

    Returns
    -------
    :class:`numpy:numpy.random.RandomState`
        The random state object based on `seed` parameter.

    Examples
    --------
    >>> from data_designer.engine.sampling_gen.utils import check_random_state
    >>> check_random_state(42)
    RandomState(MT19937) at 0x...
    """
    # The module-level singleton backs np.random's free functions; reuse it when
    # the caller did not ask for an isolated stream.
    if seed is None or seed is np.random:
        return np.random.mtrand._rand
    # An existing RandomState is passed through untouched so callers can share
    # a stream across components.
    if isinstance(seed, np.random.RandomState):
        return seed
    # numbers.Integral also accepts numpy integer scalars, not just builtin int.
    if isinstance(seed, numbers.Integral):
        return np.random.RandomState(seed)
    raise ValueError("%r cannot be used to seed a numpy.random.RandomState instance" % seed)
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from collections.abc import Sequence
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
import os
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Protocol
|
|
10
|
+
|
|
11
|
+
from data_designer.engine.errors import SecretResolutionError
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class SecretResolver(Protocol):
    """Structural interface for secret resolvers.

    Implementations map a secret reference (e.g. a key name) to its resolved
    value, raising SecretResolutionError when resolution fails.
    """

    def resolve(self, secret: str) -> str: ...
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class SecretsFileResolver(SecretResolver):
    """Resolve secrets from a JSON file mapping secret names to values."""

    # Mapping of secret name -> secret value loaded from the JSON file.
    _secrets: dict[str, str]

    def __init__(self, filepath: Path):
        """Load secrets from `filepath`; a missing file yields an empty mapping.

        Args:
            filepath: Path to a JSON file containing a flat name -> value mapping.
        """
        if not filepath.exists():
            self._secrets = {}
        else:
            with open(filepath) as f:
                self._secrets = json.load(f)

    def resolve(self, secret: str) -> str:
        """Return the value stored under `secret` in the secrets file.

        Raises:
            SecretResolutionError: If no entry exists for `secret`.
        """
        try:
            return self._secrets[secret]
        except KeyError as err:
            # Chain explicitly so the traceback shows the KeyError as the cause
            # instead of the noisier implicit "during handling" context.
            raise SecretResolutionError(f"No secret found in secrets file with key {secret!r}") from err
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class EnvironmentResolver(SecretResolver):
    """Resolve secrets from process environment variables."""

    def resolve(self, secret: str) -> str:
        """Return the value of the environment variable named `secret`.

        Raises:
            SecretResolutionError: If the environment variable is not set.
        """
        try:
            return os.environ[secret]
        except KeyError as err:
            # Chain explicitly so the traceback shows the KeyError as the cause
            # instead of the noisier implicit "during handling" context.
            raise SecretResolutionError(
                f"Environment variable with name {secret!r} is required but not set. Please set it in your environment and try again."
            ) from err
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class PlaintextResolver(SecretResolver):
    """Resolver that treats the secret reference itself as the secret value."""

    def resolve(self, secret: str) -> str:
        # The input is already the secret value; return it unchanged.
        return secret
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class CompositeResolver(SecretResolver):
    """Try a sequence of resolvers in order, returning the first successful result."""

    # Ordered resolvers consulted during resolve(); earlier entries win.
    _resolvers: Sequence[SecretResolver]

    def __init__(self, resolvers: Sequence[SecretResolver]):
        # Guard clause: an empty composite could never resolve anything.
        if not resolvers:
            raise SecretResolutionError("Must provide at least one SecretResolver to CompositeResolver")
        self._resolvers = resolvers

    @property
    def resolvers(self) -> Sequence[SecretResolver]:
        """Get the sequence of resolvers in this composite resolver.

        Returns:
            Sequence of SecretResolver instances used to resolve secrets.
        """
        return self._resolvers

    def resolve(self, secret: str) -> str:
        """Return the first successful resolution of `secret`, trying each resolver in order.

        Raises:
            SecretResolutionError: If every configured resolver fails; the message
                aggregates each resolver's individual error.
        """
        failures: list[str] = []
        for candidate in self._resolvers:
            try:
                return candidate.resolve(secret)
            except SecretResolutionError as err:
                failures.append(str(err))

        raise SecretResolutionError(
            f"No configured resolvers were able to resolve secret {secret!r}: {', '.join(failures)}"
        )
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from data_designer.engine.validators.base import BaseValidator, ValidationResult
|
|
5
|
+
from data_designer.engine.validators.local_callable import LocalCallableValidator
|
|
6
|
+
from data_designer.engine.validators.python import PythonValidator
|
|
7
|
+
from data_designer.engine.validators.remote import RemoteValidator
|
|
8
|
+
from data_designer.engine.validators.sql import SQLValidator
|
|
9
|
+
|
|
10
|
+
__all__ = [
|
|
11
|
+
"BaseValidator",
|
|
12
|
+
"LocalCallableValidator",
|
|
13
|
+
"RemoteValidator",
|
|
14
|
+
"ValidationResult",
|
|
15
|
+
"PythonValidator",
|
|
16
|
+
"SQLValidator",
|
|
17
|
+
]
|