data-designer 0.3.3__py3-none-any.whl → 0.3.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/__init__.py +2 -0
- data_designer/_version.py +2 -2
- data_designer/cli/__init__.py +2 -0
- data_designer/cli/commands/download.py +2 -0
- data_designer/cli/commands/list.py +2 -0
- data_designer/cli/commands/models.py +2 -0
- data_designer/cli/commands/providers.py +2 -0
- data_designer/cli/commands/reset.py +2 -0
- data_designer/cli/controllers/__init__.py +2 -0
- data_designer/cli/controllers/download_controller.py +2 -0
- data_designer/cli/controllers/model_controller.py +6 -1
- data_designer/cli/controllers/provider_controller.py +6 -1
- data_designer/cli/forms/__init__.py +2 -0
- data_designer/cli/forms/builder.py +2 -0
- data_designer/cli/forms/field.py +2 -0
- data_designer/cli/forms/form.py +2 -0
- data_designer/cli/forms/model_builder.py +2 -0
- data_designer/cli/forms/provider_builder.py +2 -0
- data_designer/cli/main.py +2 -0
- data_designer/cli/repositories/__init__.py +2 -0
- data_designer/cli/repositories/base.py +2 -0
- data_designer/cli/repositories/model_repository.py +2 -0
- data_designer/cli/repositories/persona_repository.py +2 -0
- data_designer/cli/repositories/provider_repository.py +2 -0
- data_designer/cli/services/__init__.py +2 -0
- data_designer/cli/services/download_service.py +2 -0
- data_designer/cli/services/model_service.py +2 -0
- data_designer/cli/services/provider_service.py +2 -0
- data_designer/cli/ui.py +2 -0
- data_designer/cli/utils.py +2 -0
- data_designer/config/analysis/column_profilers.py +2 -0
- data_designer/config/analysis/column_statistics.py +8 -5
- data_designer/config/analysis/dataset_profiler.py +9 -3
- data_designer/config/analysis/utils/errors.py +2 -0
- data_designer/config/analysis/utils/reporting.py +7 -3
- data_designer/config/base.py +1 -0
- data_designer/config/column_configs.py +77 -7
- data_designer/config/column_types.py +33 -36
- data_designer/config/dataset_builders.py +2 -0
- data_designer/config/dataset_metadata.py +18 -0
- data_designer/config/default_model_settings.py +1 -0
- data_designer/config/errors.py +2 -0
- data_designer/config/exports.py +2 -0
- data_designer/config/interface.py +3 -2
- data_designer/config/models.py +7 -2
- data_designer/config/preview_results.py +9 -1
- data_designer/config/processors.py +2 -0
- data_designer/config/run_config.py +19 -5
- data_designer/config/sampler_constraints.py +2 -0
- data_designer/config/sampler_params.py +7 -2
- data_designer/config/seed.py +2 -0
- data_designer/config/seed_source.py +9 -3
- data_designer/config/seed_source_types.py +2 -0
- data_designer/config/utils/constants.py +2 -0
- data_designer/config/utils/errors.py +2 -0
- data_designer/config/utils/info.py +2 -0
- data_designer/config/utils/io_helpers.py +8 -3
- data_designer/config/utils/misc.py +2 -2
- data_designer/config/utils/numerical_helpers.py +2 -0
- data_designer/config/utils/type_helpers.py +2 -0
- data_designer/config/utils/visualization.py +19 -11
- data_designer/config/validator_params.py +2 -0
- data_designer/engine/analysis/column_profilers/base.py +9 -8
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +15 -19
- data_designer/engine/analysis/column_profilers/registry.py +2 -0
- data_designer/engine/analysis/column_statistics.py +5 -2
- data_designer/engine/analysis/dataset_profiler.py +12 -9
- data_designer/engine/analysis/errors.py +2 -0
- data_designer/engine/analysis/utils/column_statistics_calculations.py +7 -4
- data_designer/engine/analysis/utils/judge_score_processing.py +7 -3
- data_designer/engine/column_generators/generators/base.py +26 -14
- data_designer/engine/column_generators/generators/embedding.py +4 -11
- data_designer/engine/column_generators/generators/expression.py +7 -16
- data_designer/engine/column_generators/generators/llm_completion.py +13 -47
- data_designer/engine/column_generators/generators/samplers.py +8 -14
- data_designer/engine/column_generators/generators/seed_dataset.py +9 -15
- data_designer/engine/column_generators/generators/validation.py +9 -20
- data_designer/engine/column_generators/registry.py +2 -0
- data_designer/engine/column_generators/utils/errors.py +2 -0
- data_designer/engine/column_generators/utils/generator_classification.py +2 -0
- data_designer/engine/column_generators/utils/judge_score_factory.py +2 -0
- data_designer/engine/column_generators/utils/prompt_renderer.py +4 -2
- data_designer/engine/compiler.py +3 -6
- data_designer/engine/configurable_task.py +12 -13
- data_designer/engine/dataset_builders/artifact_storage.py +87 -8
- data_designer/engine/dataset_builders/column_wise_builder.py +34 -35
- data_designer/engine/dataset_builders/errors.py +2 -0
- data_designer/engine/dataset_builders/multi_column_configs.py +2 -0
- data_designer/engine/dataset_builders/utils/concurrency.py +13 -4
- data_designer/engine/dataset_builders/utils/config_compiler.py +2 -0
- data_designer/engine/dataset_builders/utils/dag.py +7 -2
- data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +35 -25
- data_designer/engine/dataset_builders/utils/errors.py +2 -0
- data_designer/engine/errors.py +2 -0
- data_designer/engine/model_provider.py +2 -0
- data_designer/engine/models/errors.py +23 -31
- data_designer/engine/models/facade.py +12 -9
- data_designer/engine/models/factory.py +42 -0
- data_designer/engine/models/litellm_overrides.py +16 -11
- data_designer/engine/models/parsers/errors.py +2 -0
- data_designer/engine/models/parsers/parser.py +2 -2
- data_designer/engine/models/parsers/postprocessors.py +1 -0
- data_designer/engine/models/parsers/tag_parsers.py +2 -0
- data_designer/engine/models/parsers/types.py +2 -0
- data_designer/engine/models/recipes/base.py +2 -0
- data_designer/engine/models/recipes/response_recipes.py +2 -0
- data_designer/engine/models/registry.py +11 -18
- data_designer/engine/models/telemetry.py +6 -2
- data_designer/engine/processing/ginja/ast.py +2 -0
- data_designer/engine/processing/ginja/environment.py +2 -0
- data_designer/engine/processing/ginja/exceptions.py +2 -0
- data_designer/engine/processing/ginja/record.py +2 -0
- data_designer/engine/processing/gsonschema/exceptions.py +9 -2
- data_designer/engine/processing/gsonschema/schema_transformers.py +2 -0
- data_designer/engine/processing/gsonschema/types.py +2 -0
- data_designer/engine/processing/gsonschema/validators.py +10 -6
- data_designer/engine/processing/processors/base.py +1 -5
- data_designer/engine/processing/processors/drop_columns.py +7 -10
- data_designer/engine/processing/processors/registry.py +2 -0
- data_designer/engine/processing/processors/schema_transform.py +7 -10
- data_designer/engine/processing/utils.py +7 -3
- data_designer/engine/registry/base.py +2 -0
- data_designer/engine/registry/data_designer_registry.py +2 -0
- data_designer/engine/registry/errors.py +2 -0
- data_designer/engine/resources/managed_dataset_generator.py +6 -2
- data_designer/engine/resources/managed_dataset_repository.py +8 -5
- data_designer/engine/resources/managed_storage.py +2 -0
- data_designer/engine/resources/resource_provider.py +20 -1
- data_designer/engine/resources/seed_reader.py +7 -2
- data_designer/engine/sampling_gen/column.py +2 -0
- data_designer/engine/sampling_gen/constraints.py +8 -2
- data_designer/engine/sampling_gen/data_sources/base.py +10 -7
- data_designer/engine/sampling_gen/data_sources/errors.py +2 -0
- data_designer/engine/sampling_gen/data_sources/sources.py +27 -22
- data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +2 -2
- data_designer/engine/sampling_gen/entities/email_address_utils.py +2 -0
- data_designer/engine/sampling_gen/entities/errors.py +2 -0
- data_designer/engine/sampling_gen/entities/national_id_utils.py +2 -0
- data_designer/engine/sampling_gen/entities/person.py +2 -0
- data_designer/engine/sampling_gen/entities/phone_number.py +8 -1
- data_designer/engine/sampling_gen/errors.py +2 -0
- data_designer/engine/sampling_gen/generator.py +5 -4
- data_designer/engine/sampling_gen/jinja_utils.py +7 -3
- data_designer/engine/sampling_gen/people_gen.py +7 -7
- data_designer/engine/sampling_gen/person_constants.py +2 -0
- data_designer/engine/sampling_gen/schema.py +5 -1
- data_designer/engine/sampling_gen/schema_builder.py +2 -0
- data_designer/engine/sampling_gen/utils.py +7 -1
- data_designer/engine/secret_resolver.py +2 -0
- data_designer/engine/validation.py +2 -2
- data_designer/engine/validators/__init__.py +2 -0
- data_designer/engine/validators/base.py +2 -0
- data_designer/engine/validators/local_callable.py +7 -2
- data_designer/engine/validators/python.py +7 -1
- data_designer/engine/validators/remote.py +7 -1
- data_designer/engine/validators/sql.py +8 -3
- data_designer/errors.py +2 -0
- data_designer/essentials/__init__.py +2 -0
- data_designer/interface/data_designer.py +36 -39
- data_designer/interface/errors.py +2 -0
- data_designer/interface/results.py +9 -2
- data_designer/lazy_heavy_imports.py +54 -0
- data_designer/logging.py +2 -0
- data_designer/plugins/__init__.py +2 -0
- data_designer/plugins/errors.py +2 -0
- data_designer/plugins/plugin.py +0 -1
- data_designer/plugins/registry.py +2 -0
- data_designer/plugins/testing/__init__.py +2 -0
- data_designer/plugins/testing/stubs.py +21 -43
- data_designer/plugins/testing/utils.py +2 -0
- {data_designer-0.3.3.dist-info → data_designer-0.3.5.dist-info}/METADATA +19 -4
- data_designer-0.3.5.dist-info/RECORD +196 -0
- data_designer-0.3.3.dist-info/RECORD +0 -193
- {data_designer-0.3.3.dist-info → data_designer-0.3.5.dist-info}/WHEEL +0 -0
- {data_designer-0.3.3.dist-info → data_designer-0.3.5.dist-info}/entry_points.txt +0 -0
- {data_designer-0.3.3.dist-info → data_designer-0.3.5.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,29 +1,26 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
4
6
|
import json
|
|
5
7
|
import logging
|
|
6
|
-
|
|
7
|
-
import pandas as pd
|
|
8
|
+
from typing import TYPE_CHECKING
|
|
8
9
|
|
|
9
10
|
from data_designer.config.processors import SchemaTransformProcessorConfig
|
|
10
|
-
from data_designer.engine.configurable_task import ConfigurableTaskMetadata
|
|
11
11
|
from data_designer.engine.dataset_builders.artifact_storage import BatchStage
|
|
12
12
|
from data_designer.engine.processing.ginja.environment import WithJinja2UserTemplateRendering
|
|
13
13
|
from data_designer.engine.processing.processors.base import Processor
|
|
14
14
|
from data_designer.engine.processing.utils import deserialize_json_values
|
|
15
|
+
from data_designer.lazy_heavy_imports import pd
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
import pandas as pd
|
|
15
19
|
|
|
16
20
|
logger = logging.getLogger(__name__)
|
|
17
21
|
|
|
18
22
|
|
|
19
23
|
class SchemaTransformProcessor(WithJinja2UserTemplateRendering, Processor[SchemaTransformProcessorConfig]):
|
|
20
|
-
@staticmethod
|
|
21
|
-
def metadata() -> ConfigurableTaskMetadata:
|
|
22
|
-
return ConfigurableTaskMetadata(
|
|
23
|
-
name="schema_transform_processor",
|
|
24
|
-
description="Generate dataset with transformed schema using a Jinja2 template.",
|
|
25
|
-
)
|
|
26
|
-
|
|
27
24
|
@property
|
|
28
25
|
def template_as_str(self) -> str:
|
|
29
26
|
return json.dumps(self.config.template)
|
|
@@ -1,13 +1,18 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
4
6
|
import ast
|
|
5
7
|
import json
|
|
6
8
|
import logging
|
|
7
9
|
import re
|
|
8
|
-
from typing import Any, TypeVar, overload
|
|
10
|
+
from typing import TYPE_CHECKING, Any, TypeVar, overload
|
|
11
|
+
|
|
12
|
+
from data_designer.lazy_heavy_imports import pd
|
|
9
13
|
|
|
10
|
-
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
import pandas as pd
|
|
11
16
|
|
|
12
17
|
logger = logging.getLogger(__name__)
|
|
13
18
|
|
|
@@ -52,7 +57,6 @@ def deserialize_json_values(data):
|
|
|
52
57
|
- Dictionary (potentially with nested JSON strings to deserialize)
|
|
53
58
|
- Some other object that can't be deserialized.
|
|
54
59
|
|
|
55
|
-
|
|
56
60
|
Returns:
|
|
57
61
|
Deserialized data in the corresponding format:
|
|
58
62
|
- Dictionary (when input is a single string)
|
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
4
6
|
from data_designer.engine.analysis.column_profilers.registry import (
|
|
5
7
|
ColumnProfilerRegistry,
|
|
6
8
|
create_default_column_profiler_registry,
|
|
@@ -1,11 +1,15 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
-
from
|
|
4
|
+
from __future__ import annotations
|
|
5
5
|
|
|
6
|
-
import
|
|
6
|
+
from typing import TYPE_CHECKING, Any
|
|
7
7
|
|
|
8
8
|
from data_designer.engine.resources.managed_dataset_repository import ManagedDatasetRepository
|
|
9
|
+
from data_designer.lazy_heavy_imports import pd
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
import pandas as pd
|
|
9
13
|
|
|
10
14
|
|
|
11
15
|
class ManagedDatasetGenerator:
|
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
4
6
|
import logging
|
|
5
7
|
import tempfile
|
|
6
8
|
import threading
|
|
@@ -9,13 +11,15 @@ from abc import ABC, abstractmethod
|
|
|
9
11
|
from dataclasses import dataclass
|
|
10
12
|
from functools import cached_property
|
|
11
13
|
from pathlib import Path
|
|
12
|
-
from typing import Any
|
|
13
|
-
|
|
14
|
-
import duckdb
|
|
15
|
-
import pandas as pd
|
|
14
|
+
from typing import TYPE_CHECKING, Any
|
|
16
15
|
|
|
17
16
|
from data_designer.config.utils.constants import LOCALES_WITH_MANAGED_DATASETS
|
|
18
17
|
from data_designer.engine.resources.managed_storage import LocalBlobStorageProvider, ManagedBlobStorage
|
|
18
|
+
from data_designer.lazy_heavy_imports import duckdb, pd
|
|
19
|
+
|
|
20
|
+
if TYPE_CHECKING:
|
|
21
|
+
import duckdb
|
|
22
|
+
import pandas as pd
|
|
19
23
|
|
|
20
24
|
logger = logging.getLogger(__name__)
|
|
21
25
|
|
|
@@ -52,7 +56,6 @@ class Table:
|
|
|
52
56
|
|
|
53
57
|
DataCatalog = list[Table]
|
|
54
58
|
|
|
55
|
-
|
|
56
59
|
# For now we hardcode the remote data catalog in code. This make it easier
|
|
57
60
|
# initialize the data catalog. Eventually we can make this work more
|
|
58
61
|
# dynamically once this data catalog pattern becomes more widely adopted.
|
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
4
6
|
import logging
|
|
5
7
|
from abc import ABC, abstractmethod
|
|
6
8
|
from collections.abc import Iterator
|
|
@@ -1,14 +1,18 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
4
6
|
from data_designer.config.base import ConfigBase
|
|
7
|
+
from data_designer.config.dataset_metadata import DatasetMetadata
|
|
5
8
|
from data_designer.config.models import ModelConfig
|
|
6
9
|
from data_designer.config.run_config import RunConfig
|
|
7
10
|
from data_designer.config.seed_source import SeedSource
|
|
8
11
|
from data_designer.config.utils.type_helpers import StrEnum
|
|
9
12
|
from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
|
|
10
13
|
from data_designer.engine.model_provider import ModelProviderRegistry
|
|
11
|
-
from data_designer.engine.models.
|
|
14
|
+
from data_designer.engine.models.factory import create_model_registry
|
|
15
|
+
from data_designer.engine.models.registry import ModelRegistry
|
|
12
16
|
from data_designer.engine.resources.managed_storage import ManagedBlobStorage, init_managed_blob_storage
|
|
13
17
|
from data_designer.engine.resources.seed_reader import SeedReader, SeedReaderRegistry
|
|
14
18
|
from data_designer.engine.secret_resolver import SecretResolver
|
|
@@ -27,6 +31,17 @@ class ResourceProvider(ConfigBase):
|
|
|
27
31
|
run_config: RunConfig = RunConfig()
|
|
28
32
|
seed_reader: SeedReader | None = None
|
|
29
33
|
|
|
34
|
+
def get_dataset_metadata(self) -> DatasetMetadata:
|
|
35
|
+
"""Get metadata about the dataset being generated.
|
|
36
|
+
|
|
37
|
+
Returns:
|
|
38
|
+
DatasetMetadata with seed column names and other metadata.
|
|
39
|
+
"""
|
|
40
|
+
seed_column_names = []
|
|
41
|
+
if self.seed_reader is not None:
|
|
42
|
+
seed_column_names = self.seed_reader.get_column_names()
|
|
43
|
+
return DatasetMetadata(seed_column_names=seed_column_names)
|
|
44
|
+
|
|
30
45
|
|
|
31
46
|
def create_resource_provider(
|
|
32
47
|
*,
|
|
@@ -39,12 +54,16 @@ def create_resource_provider(
|
|
|
39
54
|
seed_dataset_source: SeedSource | None = None,
|
|
40
55
|
run_config: RunConfig | None = None,
|
|
41
56
|
) -> ResourceProvider:
|
|
57
|
+
"""Factory function for creating a ResourceProvider instance.
|
|
58
|
+
This function triggers lazy loading of heavy dependencies like litellm.
|
|
59
|
+
"""
|
|
42
60
|
seed_reader = None
|
|
43
61
|
if seed_dataset_source:
|
|
44
62
|
seed_reader = seed_reader_registry.get_reader(
|
|
45
63
|
seed_dataset_source,
|
|
46
64
|
secret_resolver,
|
|
47
65
|
)
|
|
66
|
+
|
|
48
67
|
return ResourceProvider(
|
|
49
68
|
artifact_storage=artifact_storage,
|
|
50
69
|
model_registry=create_model_registry(
|
|
@@ -1,11 +1,12 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
4
6
|
from abc import ABC, abstractmethod
|
|
5
7
|
from collections.abc import Sequence
|
|
6
|
-
from typing import Generic, TypeVar, get_args, get_origin
|
|
8
|
+
from typing import TYPE_CHECKING, Generic, TypeVar, get_args, get_origin
|
|
7
9
|
|
|
8
|
-
import duckdb
|
|
9
10
|
from huggingface_hub import HfFileSystem
|
|
10
11
|
from typing_extensions import Self
|
|
11
12
|
|
|
@@ -17,6 +18,10 @@ from data_designer.config.seed_source import (
|
|
|
17
18
|
)
|
|
18
19
|
from data_designer.engine.secret_resolver import SecretResolver
|
|
19
20
|
from data_designer.errors import DataDesignerError
|
|
21
|
+
from data_designer.lazy_heavy_imports import duckdb
|
|
22
|
+
|
|
23
|
+
if TYPE_CHECKING:
|
|
24
|
+
import duckdb
|
|
20
25
|
|
|
21
26
|
|
|
22
27
|
class SeedReaderError(DataDesignerError): ...
|
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
4
6
|
from typing import Any
|
|
5
7
|
|
|
6
8
|
from pydantic import field_serializer, model_validator
|
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
4
6
|
from abc import ABC, abstractmethod
|
|
7
|
+
from typing import TYPE_CHECKING
|
|
5
8
|
|
|
6
|
-
import numpy as np
|
|
7
|
-
import pandas as pd
|
|
8
9
|
from numpy.typing import NDArray
|
|
9
10
|
|
|
10
11
|
from data_designer.config.base import ConfigBase
|
|
@@ -15,6 +16,11 @@ from data_designer.config.sampler_constraints import (
|
|
|
15
16
|
InequalityOperator,
|
|
16
17
|
ScalarInequalityConstraint,
|
|
17
18
|
)
|
|
19
|
+
from data_designer.lazy_heavy_imports import np, pd
|
|
20
|
+
|
|
21
|
+
if TYPE_CHECKING:
|
|
22
|
+
import numpy as np
|
|
23
|
+
import pandas as pd
|
|
18
24
|
|
|
19
25
|
|
|
20
26
|
class ConstraintChecker(ConfigBase, ABC):
|
|
@@ -1,24 +1,27 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
4
6
|
from abc import ABC, abstractmethod
|
|
5
|
-
from typing import Any, Generic, TypeVar
|
|
7
|
+
from typing import TYPE_CHECKING, Any, Generic, TypeVar
|
|
6
8
|
|
|
7
|
-
import numpy as np
|
|
8
|
-
import pandas as pd
|
|
9
9
|
from numpy.typing import NDArray
|
|
10
|
-
from scipy import stats
|
|
11
10
|
|
|
12
11
|
from data_designer.config.sampler_params import SamplerParamsT
|
|
13
12
|
from data_designer.engine.sampling_gen.utils import check_random_state
|
|
13
|
+
from data_designer.lazy_heavy_imports import np, pd, scipy
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
import numpy as np
|
|
17
|
+
import pandas as pd
|
|
18
|
+
import scipy
|
|
14
19
|
|
|
15
20
|
NumpyArray1dT = NDArray[Any]
|
|
16
21
|
RadomStateT = int | np.random.RandomState
|
|
17
22
|
|
|
18
|
-
|
|
19
23
|
GenericParamsT = TypeVar("GenericParamsT", bound=SamplerParamsT)
|
|
20
24
|
|
|
21
|
-
|
|
22
25
|
###########################################################
|
|
23
26
|
# Processing Mixins
|
|
24
27
|
# -----------------
|
|
@@ -208,7 +211,7 @@ class Sampler(DataSource[GenericParamsT], ABC):
|
|
|
208
211
|
class ScipyStatsSampler(Sampler[GenericParamsT], ABC):
|
|
209
212
|
@property
|
|
210
213
|
@abstractmethod
|
|
211
|
-
def distribution(self) -> stats.rv_continuous | stats.rv_discrete: ...
|
|
214
|
+
def distribution(self) -> scipy.stats.rv_continuous | scipy.stats.rv_discrete: ...
|
|
212
215
|
|
|
213
216
|
def sample(self, num_samples: int) -> NumpyArray1dT:
|
|
214
217
|
return self.distribution.rvs(size=num_samples, random_state=self.rng)
|
|
@@ -1,11 +1,10 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
-
import
|
|
4
|
+
from __future__ import annotations
|
|
5
5
|
|
|
6
|
-
import
|
|
7
|
-
|
|
8
|
-
from scipy import stats
|
|
6
|
+
import uuid
|
|
7
|
+
from typing import TYPE_CHECKING
|
|
9
8
|
|
|
10
9
|
from data_designer.config.sampler_params import (
|
|
11
10
|
BernoulliMixtureSamplerParams,
|
|
@@ -40,6 +39,12 @@ from data_designer.engine.sampling_gen.data_sources.errors import (
|
|
|
40
39
|
)
|
|
41
40
|
from data_designer.engine.sampling_gen.entities.dataset_based_person_fields import PERSONA_FIELDS, PII_FIELDS
|
|
42
41
|
from data_designer.engine.sampling_gen.people_gen import PeopleGen
|
|
42
|
+
from data_designer.lazy_heavy_imports import np, pd, scipy
|
|
43
|
+
|
|
44
|
+
if TYPE_CHECKING:
|
|
45
|
+
import numpy as np
|
|
46
|
+
import pandas as pd
|
|
47
|
+
import scipy
|
|
43
48
|
|
|
44
49
|
ONE_BILLION = 10**9
|
|
45
50
|
|
|
@@ -264,8 +269,8 @@ class ScipySampler(TypeConversionMixin, ScipyStatsSampler[ScipySamplerParams]):
|
|
|
264
269
|
"""Escape hatch sampler to give users access to any scipy.stats distribution."""
|
|
265
270
|
|
|
266
271
|
@property
|
|
267
|
-
def distribution(self) -> stats.rv_continuous | stats.rv_discrete:
|
|
268
|
-
return getattr(stats, self.params.dist_name)(**self.params.dist_params)
|
|
272
|
+
def distribution(self) -> scipy.stats.rv_continuous | scipy.stats.rv_discrete:
|
|
273
|
+
return getattr(scipy.stats, self.params.dist_name)(**self.params.dist_params)
|
|
269
274
|
|
|
270
275
|
def _validate(self) -> None:
|
|
271
276
|
_validate_scipy_distribution(self.params.dist_name, self.params.dist_params)
|
|
@@ -274,16 +279,16 @@ class ScipySampler(TypeConversionMixin, ScipyStatsSampler[ScipySamplerParams]):
|
|
|
274
279
|
@SamplerRegistry.register(SamplerType.BERNOULLI)
|
|
275
280
|
class BernoulliSampler(TypeConversionMixin, ScipyStatsSampler[BernoulliSamplerParams]):
|
|
276
281
|
@property
|
|
277
|
-
def distribution(self) -> stats.rv_discrete:
|
|
278
|
-
return stats.bernoulli(p=self.params.p)
|
|
282
|
+
def distribution(self) -> scipy.stats.rv_discrete:
|
|
283
|
+
return scipy.stats.bernoulli(p=self.params.p)
|
|
279
284
|
|
|
280
285
|
|
|
281
286
|
@SamplerRegistry.register(SamplerType.BERNOULLI_MIXTURE)
|
|
282
287
|
class BernoulliMixtureSampler(TypeConversionMixin, Sampler[BernoulliMixtureSamplerParams]):
|
|
283
288
|
def sample(self, num_samples: int) -> NumpyArray1dT:
|
|
284
|
-
return stats.bernoulli(p=self.params.p).rvs(size=num_samples) * getattr(
|
|
285
|
-
|
|
286
|
-
).rvs(size=num_samples)
|
|
289
|
+
return scipy.stats.bernoulli(p=self.params.p).rvs(size=num_samples) * getattr(
|
|
290
|
+
scipy.stats, self.params.dist_name
|
|
291
|
+
)(**self.params.dist_params).rvs(size=num_samples)
|
|
287
292
|
|
|
288
293
|
def _validate(self) -> None:
|
|
289
294
|
_validate_scipy_distribution(self.params.dist_name, self.params.dist_params)
|
|
@@ -292,29 +297,29 @@ class BernoulliMixtureSampler(TypeConversionMixin, Sampler[BernoulliMixtureSampl
|
|
|
292
297
|
@SamplerRegistry.register(SamplerType.BINOMIAL)
|
|
293
298
|
class BinomialSampler(TypeConversionMixin, ScipyStatsSampler[BinomialSamplerParams]):
|
|
294
299
|
@property
|
|
295
|
-
def distribution(self) -> stats.rv_discrete:
|
|
296
|
-
return stats.binom(n=self.params.n, p=self.params.p)
|
|
300
|
+
def distribution(self) -> scipy.stats.rv_discrete:
|
|
301
|
+
return scipy.stats.binom(n=self.params.n, p=self.params.p)
|
|
297
302
|
|
|
298
303
|
|
|
299
304
|
@SamplerRegistry.register(SamplerType.GAUSSIAN)
|
|
300
305
|
class GaussianSampler(TypeConversionMixin, ScipyStatsSampler[GaussianSamplerParams]):
|
|
301
306
|
@property
|
|
302
|
-
def distribution(self) -> stats.rv_continuous:
|
|
303
|
-
return stats.norm(loc=self.params.mean, scale=self.params.stddev)
|
|
307
|
+
def distribution(self) -> scipy.stats.rv_continuous:
|
|
308
|
+
return scipy.stats.norm(loc=self.params.mean, scale=self.params.stddev)
|
|
304
309
|
|
|
305
310
|
|
|
306
311
|
@SamplerRegistry.register(SamplerType.POISSON)
|
|
307
312
|
class PoissonSampler(TypeConversionMixin, ScipyStatsSampler[PoissonSamplerParams]):
|
|
308
313
|
@property
|
|
309
|
-
def distribution(self) -> stats.rv_discrete:
|
|
310
|
-
return stats.poisson(mu=self.params.mean)
|
|
314
|
+
def distribution(self) -> scipy.stats.rv_discrete:
|
|
315
|
+
return scipy.stats.poisson(mu=self.params.mean)
|
|
311
316
|
|
|
312
317
|
|
|
313
318
|
@SamplerRegistry.register(SamplerType.UNIFORM)
|
|
314
319
|
class UniformSampler(TypeConversionMixin, ScipyStatsSampler[UniformSamplerParams]):
|
|
315
320
|
@property
|
|
316
|
-
def distribution(self) -> stats.rv_continuous:
|
|
317
|
-
return stats.uniform(loc=self.params.low, scale=self.params.high - self.params.low)
|
|
321
|
+
def distribution(self) -> scipy.stats.rv_continuous:
|
|
322
|
+
return scipy.stats.uniform(loc=self.params.low, scale=self.params.high - self.params.low)
|
|
318
323
|
|
|
319
324
|
|
|
320
325
|
###################################################
|
|
@@ -328,14 +333,14 @@ def load_sampler(sampler_type: SamplerType, **params) -> DataSource:
|
|
|
328
333
|
|
|
329
334
|
|
|
330
335
|
def _validate_scipy_distribution(dist_name: str, dist_params: dict) -> None:
|
|
331
|
-
if not hasattr(stats, dist_name):
|
|
336
|
+
if not hasattr(scipy.stats, dist_name):
|
|
332
337
|
raise InvalidSamplerParamsError(f"Distribution {dist_name} not found in scipy.stats")
|
|
333
|
-
if not hasattr(getattr(stats, dist_name), "rvs"):
|
|
338
|
+
if not hasattr(getattr(scipy.stats, dist_name), "rvs"):
|
|
334
339
|
raise InvalidSamplerParamsError(
|
|
335
340
|
f"Distribution {dist_name} does not have a `rvs` method, which is required for sampling."
|
|
336
341
|
)
|
|
337
342
|
try:
|
|
338
|
-
getattr(stats, dist_name)(**dist_params)
|
|
343
|
+
getattr(scipy.stats, dist_name)(**dist_params)
|
|
339
344
|
except Exception:
|
|
340
345
|
raise InvalidSamplerParamsError(
|
|
341
346
|
f"Distribution parameters {dist_params} are not a valid for distribution '{dist_name}'"
|
|
@@ -10,8 +10,9 @@ This file contains all possible fields that:
|
|
|
10
10
|
Do not add any other code or logic in this file.
|
|
11
11
|
"""
|
|
12
12
|
|
|
13
|
-
|
|
13
|
+
from __future__ import annotations
|
|
14
14
|
|
|
15
|
+
REQUIRED_FIELDS = {"first_name", "last_name", "age", "locale"}
|
|
15
16
|
|
|
16
17
|
PII_FIELDS = [
|
|
17
18
|
# Core demographic fields
|
|
@@ -52,7 +53,6 @@ PII_FIELDS = [
|
|
|
52
53
|
"third_language",
|
|
53
54
|
]
|
|
54
55
|
|
|
55
|
-
|
|
56
56
|
PERSONA_FIELDS = [
|
|
57
57
|
# Core persona fields
|
|
58
58
|
"persona",
|
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
4
6
|
import random
|
|
5
7
|
from datetime import date, timedelta
|
|
6
8
|
from typing import Any, Literal, TypeAlias
|
|
@@ -1,12 +1,19 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
4
6
|
import random
|
|
5
7
|
from pathlib import Path
|
|
8
|
+
from typing import TYPE_CHECKING
|
|
6
9
|
|
|
7
|
-
import pandas as pd
|
|
8
10
|
from pydantic import BaseModel, Field, field_validator
|
|
9
11
|
|
|
12
|
+
from data_designer.lazy_heavy_imports import pd
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
import pandas as pd
|
|
16
|
+
|
|
10
17
|
ZIP_AREA_CODE_DATA = pd.read_parquet(Path(__file__).parent / "assets" / "zip_area_code_map.parquet")
|
|
11
18
|
ZIPCODE_AREA_CODE_MAP = dict(zip(ZIP_AREA_CODE_DATA["zipcode"], ZIP_AREA_CODE_DATA["area_code"]))
|
|
12
19
|
ZIPCODE_POPULATION_MAP = dict(zip(ZIP_AREA_CODE_DATA["zipcode"], ZIP_AREA_CODE_DATA["count"]))
|
|
@@ -6,18 +6,19 @@ from __future__ import annotations
|
|
|
6
6
|
from collections.abc import Callable
|
|
7
7
|
from typing import TYPE_CHECKING
|
|
8
8
|
|
|
9
|
-
import networkx as nx
|
|
10
|
-
import numpy as np
|
|
11
|
-
import pandas as pd
|
|
12
|
-
|
|
13
9
|
from data_designer.engine.sampling_gen.data_sources.base import RadomStateT
|
|
14
10
|
from data_designer.engine.sampling_gen.errors import RejectionSamplingError
|
|
15
11
|
from data_designer.engine.sampling_gen.jinja_utils import JinjaDataFrame
|
|
16
12
|
from data_designer.engine.sampling_gen.people_gen import create_people_gen_resource
|
|
17
13
|
from data_designer.engine.sampling_gen.schema import DataSchema
|
|
18
14
|
from data_designer.engine.sampling_gen.utils import check_random_state
|
|
15
|
+
from data_designer.lazy_heavy_imports import np, nx, pd
|
|
19
16
|
|
|
20
17
|
if TYPE_CHECKING:
|
|
18
|
+
import networkx as nx
|
|
19
|
+
import numpy as np
|
|
20
|
+
import pandas as pd
|
|
21
|
+
|
|
21
22
|
from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig
|
|
22
23
|
from data_designer.engine.resources.managed_dataset_generator import ManagedDatasetGenerator
|
|
23
24
|
from data_designer.engine.sampling_gen.column import ConditionalDataColumn
|
|
@@ -1,15 +1,19 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
-
import
|
|
5
|
-
from typing import Any
|
|
4
|
+
from __future__ import annotations
|
|
6
5
|
|
|
7
|
-
import
|
|
6
|
+
import ast
|
|
7
|
+
from typing import TYPE_CHECKING, Any
|
|
8
8
|
|
|
9
9
|
from data_designer.engine.processing.ginja.environment import (
|
|
10
10
|
UserTemplateSandboxEnvironment,
|
|
11
11
|
WithJinja2UserTemplateRendering,
|
|
12
12
|
)
|
|
13
|
+
from data_designer.lazy_heavy_imports import pd
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
import pandas as pd
|
|
13
17
|
|
|
14
18
|
|
|
15
19
|
class JinjaDataFrame(WithJinja2UserTemplateRendering):
|
|
@@ -10,9 +10,6 @@ from collections.abc import Callable
|
|
|
10
10
|
from copy import deepcopy
|
|
11
11
|
from typing import TYPE_CHECKING, Any, TypeAlias
|
|
12
12
|
|
|
13
|
-
import pandas as pd
|
|
14
|
-
from faker import Faker
|
|
15
|
-
|
|
16
13
|
from data_designer.config.utils.constants import DEFAULT_AGE_RANGE
|
|
17
14
|
from data_designer.engine.resources.managed_dataset_generator import ManagedDatasetGenerator
|
|
18
15
|
from data_designer.engine.sampling_gen.entities.dataset_based_person_fields import PERSONA_FIELDS, PII_FIELDS
|
|
@@ -22,12 +19,15 @@ from data_designer.engine.sampling_gen.entities.person import (
|
|
|
22
19
|
)
|
|
23
20
|
from data_designer.engine.sampling_gen.errors import ManagedDatasetGeneratorError
|
|
24
21
|
from data_designer.engine.sampling_gen.person_constants import faker_constants
|
|
22
|
+
from data_designer.lazy_heavy_imports import faker, pd
|
|
25
23
|
|
|
26
24
|
if TYPE_CHECKING:
|
|
27
|
-
|
|
25
|
+
import faker
|
|
26
|
+
import pandas as pd
|
|
28
27
|
|
|
28
|
+
from data_designer.engine.sampling_gen.schema import DataSchema
|
|
29
29
|
|
|
30
|
-
EngineT: TypeAlias = Faker | ManagedDatasetGenerator
|
|
30
|
+
EngineT: TypeAlias = faker.Faker | ManagedDatasetGenerator
|
|
31
31
|
|
|
32
32
|
|
|
33
33
|
class PeopleGen(ABC):
|
|
@@ -46,7 +46,7 @@ class PeopleGen(ABC):
|
|
|
46
46
|
|
|
47
47
|
class PeopleGenFaker(PeopleGen):
|
|
48
48
|
@property
|
|
49
|
-
def _fake(self) -> Faker:
|
|
49
|
+
def _fake(self) -> faker.Faker:
|
|
50
50
|
return self._engine
|
|
51
51
|
|
|
52
52
|
def try_fake_else_none(self, attr_name: str, none_fill: Any | None = None) -> type:
|
|
@@ -193,7 +193,7 @@ def create_people_gen_resource(
|
|
|
193
193
|
for params in [column.params, *list(column.conditional_params.values())]:
|
|
194
194
|
if params.people_gen_key not in people_gen_resource:
|
|
195
195
|
people_gen_resource[params.people_gen_key] = PeopleGenFaker(
|
|
196
|
-
engine=Faker(params.locale), locale=params.locale
|
|
196
|
+
engine=faker.Faker(params.locale), locale=params.locale
|
|
197
197
|
)
|
|
198
198
|
|
|
199
199
|
return people_gen_resource
|
|
@@ -4,8 +4,8 @@
|
|
|
4
4
|
from __future__ import annotations
|
|
5
5
|
|
|
6
6
|
from functools import cached_property
|
|
7
|
+
from typing import TYPE_CHECKING
|
|
7
8
|
|
|
8
|
-
import networkx as nx
|
|
9
9
|
from pydantic import BaseModel, Field, field_validator, model_validator
|
|
10
10
|
from typing_extensions import Self
|
|
11
11
|
|
|
@@ -14,6 +14,10 @@ from data_designer.config.sampler_constraints import ColumnConstraintT
|
|
|
14
14
|
from data_designer.config.sampler_params import SamplerType
|
|
15
15
|
from data_designer.engine.sampling_gen.column import ConditionalDataColumn
|
|
16
16
|
from data_designer.engine.sampling_gen.constraints import ConstraintChecker, get_constraint_checker
|
|
17
|
+
from data_designer.lazy_heavy_imports import nx
|
|
18
|
+
|
|
19
|
+
if TYPE_CHECKING:
|
|
20
|
+
import networkx as nx
|
|
17
21
|
|
|
18
22
|
|
|
19
23
|
class Dag(BaseModel):
|