data-designer 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/__init__.py +15 -0
- data_designer/_version.py +34 -0
- data_designer/cli/README.md +236 -0
- data_designer/cli/__init__.py +6 -0
- data_designer/cli/commands/__init__.py +2 -0
- data_designer/cli/commands/list.py +130 -0
- data_designer/cli/commands/models.py +10 -0
- data_designer/cli/commands/providers.py +11 -0
- data_designer/cli/commands/reset.py +100 -0
- data_designer/cli/controllers/__init__.py +7 -0
- data_designer/cli/controllers/model_controller.py +246 -0
- data_designer/cli/controllers/provider_controller.py +317 -0
- data_designer/cli/forms/__init__.py +20 -0
- data_designer/cli/forms/builder.py +51 -0
- data_designer/cli/forms/field.py +180 -0
- data_designer/cli/forms/form.py +59 -0
- data_designer/cli/forms/model_builder.py +125 -0
- data_designer/cli/forms/provider_builder.py +76 -0
- data_designer/cli/main.py +44 -0
- data_designer/cli/repositories/__init__.py +8 -0
- data_designer/cli/repositories/base.py +39 -0
- data_designer/cli/repositories/model_repository.py +42 -0
- data_designer/cli/repositories/provider_repository.py +43 -0
- data_designer/cli/services/__init__.py +7 -0
- data_designer/cli/services/model_service.py +116 -0
- data_designer/cli/services/provider_service.py +111 -0
- data_designer/cli/ui.py +448 -0
- data_designer/cli/utils.py +47 -0
- data_designer/config/__init__.py +2 -0
- data_designer/config/analysis/column_profilers.py +89 -0
- data_designer/config/analysis/column_statistics.py +274 -0
- data_designer/config/analysis/dataset_profiler.py +60 -0
- data_designer/config/analysis/utils/errors.py +8 -0
- data_designer/config/analysis/utils/reporting.py +188 -0
- data_designer/config/base.py +68 -0
- data_designer/config/column_configs.py +354 -0
- data_designer/config/column_types.py +168 -0
- data_designer/config/config_builder.py +660 -0
- data_designer/config/data_designer_config.py +40 -0
- data_designer/config/dataset_builders.py +11 -0
- data_designer/config/datastore.py +151 -0
- data_designer/config/default_model_settings.py +123 -0
- data_designer/config/errors.py +19 -0
- data_designer/config/interface.py +54 -0
- data_designer/config/models.py +231 -0
- data_designer/config/preview_results.py +32 -0
- data_designer/config/processors.py +41 -0
- data_designer/config/sampler_constraints.py +51 -0
- data_designer/config/sampler_params.py +604 -0
- data_designer/config/seed.py +145 -0
- data_designer/config/utils/code_lang.py +83 -0
- data_designer/config/utils/constants.py +313 -0
- data_designer/config/utils/errors.py +19 -0
- data_designer/config/utils/info.py +88 -0
- data_designer/config/utils/io_helpers.py +273 -0
- data_designer/config/utils/misc.py +81 -0
- data_designer/config/utils/numerical_helpers.py +28 -0
- data_designer/config/utils/type_helpers.py +100 -0
- data_designer/config/utils/validation.py +336 -0
- data_designer/config/utils/visualization.py +427 -0
- data_designer/config/validator_params.py +96 -0
- data_designer/engine/__init__.py +2 -0
- data_designer/engine/analysis/column_profilers/base.py +55 -0
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +160 -0
- data_designer/engine/analysis/column_profilers/registry.py +20 -0
- data_designer/engine/analysis/column_statistics.py +142 -0
- data_designer/engine/analysis/dataset_profiler.py +125 -0
- data_designer/engine/analysis/errors.py +7 -0
- data_designer/engine/analysis/utils/column_statistics_calculations.py +209 -0
- data_designer/engine/analysis/utils/judge_score_processing.py +128 -0
- data_designer/engine/column_generators/__init__.py +2 -0
- data_designer/engine/column_generators/generators/__init__.py +2 -0
- data_designer/engine/column_generators/generators/base.py +61 -0
- data_designer/engine/column_generators/generators/expression.py +63 -0
- data_designer/engine/column_generators/generators/llm_generators.py +172 -0
- data_designer/engine/column_generators/generators/samplers.py +75 -0
- data_designer/engine/column_generators/generators/seed_dataset.py +149 -0
- data_designer/engine/column_generators/generators/validation.py +147 -0
- data_designer/engine/column_generators/registry.py +56 -0
- data_designer/engine/column_generators/utils/errors.py +13 -0
- data_designer/engine/column_generators/utils/judge_score_factory.py +57 -0
- data_designer/engine/column_generators/utils/prompt_renderer.py +98 -0
- data_designer/engine/configurable_task.py +82 -0
- data_designer/engine/dataset_builders/artifact_storage.py +181 -0
- data_designer/engine/dataset_builders/column_wise_builder.py +287 -0
- data_designer/engine/dataset_builders/errors.py +13 -0
- data_designer/engine/dataset_builders/multi_column_configs.py +44 -0
- data_designer/engine/dataset_builders/utils/__init__.py +2 -0
- data_designer/engine/dataset_builders/utils/concurrency.py +184 -0
- data_designer/engine/dataset_builders/utils/config_compiler.py +60 -0
- data_designer/engine/dataset_builders/utils/dag.py +56 -0
- data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +190 -0
- data_designer/engine/dataset_builders/utils/errors.py +13 -0
- data_designer/engine/errors.py +49 -0
- data_designer/engine/model_provider.py +75 -0
- data_designer/engine/models/__init__.py +2 -0
- data_designer/engine/models/errors.py +308 -0
- data_designer/engine/models/facade.py +225 -0
- data_designer/engine/models/litellm_overrides.py +162 -0
- data_designer/engine/models/parsers/__init__.py +2 -0
- data_designer/engine/models/parsers/errors.py +34 -0
- data_designer/engine/models/parsers/parser.py +236 -0
- data_designer/engine/models/parsers/postprocessors.py +93 -0
- data_designer/engine/models/parsers/tag_parsers.py +60 -0
- data_designer/engine/models/parsers/types.py +82 -0
- data_designer/engine/models/recipes/base.py +79 -0
- data_designer/engine/models/recipes/response_recipes.py +291 -0
- data_designer/engine/models/registry.py +118 -0
- data_designer/engine/models/usage.py +75 -0
- data_designer/engine/models/utils.py +38 -0
- data_designer/engine/processing/ginja/__init__.py +2 -0
- data_designer/engine/processing/ginja/ast.py +64 -0
- data_designer/engine/processing/ginja/environment.py +461 -0
- data_designer/engine/processing/ginja/exceptions.py +54 -0
- data_designer/engine/processing/ginja/record.py +30 -0
- data_designer/engine/processing/gsonschema/__init__.py +2 -0
- data_designer/engine/processing/gsonschema/exceptions.py +8 -0
- data_designer/engine/processing/gsonschema/schema_transformers.py +81 -0
- data_designer/engine/processing/gsonschema/types.py +8 -0
- data_designer/engine/processing/gsonschema/validators.py +143 -0
- data_designer/engine/processing/processors/base.py +15 -0
- data_designer/engine/processing/processors/drop_columns.py +46 -0
- data_designer/engine/processing/processors/registry.py +20 -0
- data_designer/engine/processing/utils.py +120 -0
- data_designer/engine/registry/base.py +97 -0
- data_designer/engine/registry/data_designer_registry.py +37 -0
- data_designer/engine/registry/errors.py +10 -0
- data_designer/engine/resources/managed_dataset_generator.py +35 -0
- data_designer/engine/resources/managed_dataset_repository.py +194 -0
- data_designer/engine/resources/managed_storage.py +63 -0
- data_designer/engine/resources/resource_provider.py +46 -0
- data_designer/engine/resources/seed_dataset_data_store.py +66 -0
- data_designer/engine/sampling_gen/column.py +89 -0
- data_designer/engine/sampling_gen/constraints.py +95 -0
- data_designer/engine/sampling_gen/data_sources/base.py +214 -0
- data_designer/engine/sampling_gen/data_sources/errors.py +10 -0
- data_designer/engine/sampling_gen/data_sources/sources.py +342 -0
- data_designer/engine/sampling_gen/entities/__init__.py +2 -0
- data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
- data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +64 -0
- data_designer/engine/sampling_gen/entities/email_address_utils.py +169 -0
- data_designer/engine/sampling_gen/entities/errors.py +8 -0
- data_designer/engine/sampling_gen/entities/national_id_utils.py +100 -0
- data_designer/engine/sampling_gen/entities/person.py +142 -0
- data_designer/engine/sampling_gen/entities/phone_number.py +122 -0
- data_designer/engine/sampling_gen/errors.py +24 -0
- data_designer/engine/sampling_gen/generator.py +121 -0
- data_designer/engine/sampling_gen/jinja_utils.py +60 -0
- data_designer/engine/sampling_gen/people_gen.py +203 -0
- data_designer/engine/sampling_gen/person_constants.py +54 -0
- data_designer/engine/sampling_gen/schema.py +143 -0
- data_designer/engine/sampling_gen/schema_builder.py +59 -0
- data_designer/engine/sampling_gen/utils.py +40 -0
- data_designer/engine/secret_resolver.py +80 -0
- data_designer/engine/validators/__init__.py +17 -0
- data_designer/engine/validators/base.py +36 -0
- data_designer/engine/validators/local_callable.py +34 -0
- data_designer/engine/validators/python.py +245 -0
- data_designer/engine/validators/remote.py +83 -0
- data_designer/engine/validators/sql.py +60 -0
- data_designer/errors.py +5 -0
- data_designer/essentials/__init__.py +137 -0
- data_designer/interface/__init__.py +2 -0
- data_designer/interface/data_designer.py +351 -0
- data_designer/interface/errors.py +16 -0
- data_designer/interface/results.py +55 -0
- data_designer/logging.py +161 -0
- data_designer/plugin_manager.py +83 -0
- data_designer/plugins/__init__.py +6 -0
- data_designer/plugins/errors.py +10 -0
- data_designer/plugins/plugin.py +69 -0
- data_designer/plugins/registry.py +86 -0
- data_designer-0.1.0.dist-info/METADATA +173 -0
- data_designer-0.1.0.dist-info/RECORD +177 -0
- data_designer-0.1.0.dist-info/WHEEL +4 -0
- data_designer-0.1.0.dist-info/entry_points.txt +2 -0
- data_designer-0.1.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from copy import deepcopy
|
|
5
|
+
import logging
|
|
6
|
+
from typing import Any, overload
|
|
7
|
+
|
|
8
|
+
from jsonschema import Draft202012Validator, ValidationError, validators
|
|
9
|
+
|
|
10
|
+
from data_designer.engine.processing.gsonschema.exceptions import JSONSchemaValidationError
|
|
11
|
+
from data_designer.engine.processing.gsonschema.schema_transformers import forbid_additional_properties
|
|
12
|
+
from data_designer.engine.processing.gsonschema.types import DataObjectT, JSONSchemaT, T_primitive
|
|
13
|
+
|
|
14
|
+
DEFAULT_JSONSCHEMA_VALIDATOR = Draft202012Validator
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def prune_additional_properties(
    _, allow_additional_properties: bool, instance: DataObjectT, schema: JSONSchemaT
) -> None:
    """A JSONSchemaValidator extension function to prune additional properties.

    Operates on an individual schema in-place.

    Args:
        allow_additional_properties (bool): The value of the `additionalProperties`
            field for this schema.
        instance (DataObjectT): The data object being validated.
        schema (JSONSchemaT): The schema for this object.

    Returns:
        Nothing (in place)
    """
    # Pruning only applies to mapping instances whose schema forbids
    # additional properties; anything else is left untouched.
    if allow_additional_properties or not isinstance(instance, dict):
        return

    # Keys declared under the schema's "properties" are the only ones kept.
    declared_keys = schema.get("properties", {}).keys()

    # Snapshot the offending keys first so the dict can be mutated safely.
    extra_keys = [key for key in instance if key not in declared_keys]
    for key in extra_keys:
        del instance[key]
        logger.info(f"Unspecified property removed from data object: {key}.")

    if extra_keys:
        logger.info(f"{len(extra_keys)} unspecified properties removed from data object.")
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def extend_jsonschema_validator_with_pruning(validator):
    """Modify behavior of a jsonschema.Validator to use pruning.

    Validators extended using this function will prune extra,
    unspecified fields, rather than raising a ValidationError,
    when encountering them while `additionalProperties: False` is
    set in the validating schema.

    Args:
        validator (Type[jsonschema.Validator]): A validator class
            to extend with pruning behavior.

    Returns:
        Type[jsonschema.Validator]: A validator class that will
            prune extra fields.
    """
    return validators.extend(validator, {"additionalProperties": prune_additional_properties})
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
## We don't expect the outer data type (e.g. dict, list, or const) to be
|
|
74
|
+
## modified by the pruning action.
|
|
75
|
+
# Mapping in -> mapping out.
@overload
def validate(
    obj: dict[str, Any],
    schema: JSONSchemaT,
    pruning: bool = False,
    no_extra_properties: bool = False,
) -> dict[str, Any]: ...


# Sequence in -> sequence out.
@overload
def validate(
    obj: list[Any],
    schema: JSONSchemaT,
    pruning: bool = False,
    no_extra_properties: bool = False,
) -> list[Any]: ...


# Primitive in -> primitive out.
@overload
def validate(
    obj: T_primitive,
    schema: JSONSchemaT,
    pruning: bool = False,
    no_extra_properties: bool = False,
) -> T_primitive: ...
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def validate(
    obj: DataObjectT,
    schema: JSONSchemaT,
    pruning: bool = False,
    no_extra_properties: bool = False,
) -> DataObjectT:
    """Validate a data object against a JSONSchema.

    Args:
        obj (DataObjectT): A data structure to validate against the
            schema.
        schema (JSONSchemaT): A valid JSONSchema to use to validate
            the provided data object.
        pruning (bool): If set to `True`, then the default behavior for
            `additionalProperties: False` is set to prune non-specified
            properties instead of raising a ValidationError.
            Default: `False`.
        no_extra_properties (bool): If set to `True`, then
            `additionalProperties: False` is set on all the schema
            and all of its sub-schemas. This operation overrides any
            existing settings of `additionalProperties` within the
            schema. Default: `False`.

    Returns:
        DataObjectT: A deep copy of `obj`, possibly pruned when
            `pruning` is enabled.

    Raises:
        JSONSchemaValidationError: Raised when the data object does
            not conform to the (possibly transformed) schema.
    """
    # Work on a copy so pruning never mutates the caller's object.
    validated = deepcopy(obj)

    validator_cls = (
        extend_jsonschema_validator_with_pruning(DEFAULT_JSONSCHEMA_VALIDATOR)
        if pruning
        else DEFAULT_JSONSCHEMA_VALIDATOR
    )

    effective_schema = forbid_additional_properties(schema) if no_extra_properties else schema

    try:
        validator_cls(effective_schema).validate(validated)
    except ValidationError as exc:
        # Re-wrap in the package's exception type, preserving the cause.
        raise JSONSchemaValidationError(str(exc)) from exc

    return validated
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
|
|
6
|
+
from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata, DataT, TaskConfigT
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Processor(ConfigurableTask[TaskConfigT], ABC):
    """Abstract base class for dataset post-processing tasks."""

    # Static description of the task (name, description, required resources).
    @staticmethod
    @abstractmethod
    def metadata() -> ConfigurableTaskMetadata: ...

    # Transform `data` and return it. `current_batch_number` identifies the
    # batch being processed; implementations treat None as preview mode
    # (see DropColumnsProcessor) — confirm for other implementations.
    @abstractmethod
    def process(self, data: DataT, *, current_batch_number: int | None = None) -> DataT: ...
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
import logging
|
|
5
|
+
|
|
6
|
+
import pandas as pd
|
|
7
|
+
|
|
8
|
+
from data_designer.config.processors import DropColumnsProcessorConfig
|
|
9
|
+
from data_designer.engine.configurable_task import ConfigurableTaskMetadata
|
|
10
|
+
from data_designer.engine.dataset_builders.artifact_storage import BatchStage
|
|
11
|
+
from data_designer.engine.processing.processors.base import Processor
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class DropColumnsProcessor(Processor[DropColumnsProcessorConfig]):
    """Processor that removes the configured columns from the dataset.

    In batch mode (``current_batch_number`` is not None), the columns are
    archived as a parquet file before being removed.
    """

    @staticmethod
    def metadata() -> ConfigurableTaskMetadata:
        """Static task metadata used for registration and description."""
        return ConfigurableTaskMetadata(
            name="drop_columns",
            description="Drop columns from the input dataset.",
            required_resources=None,
        )

    def process(self, data: pd.DataFrame, *, current_batch_number: int | None = None) -> pd.DataFrame:
        """Drop the configured columns from `data` (in place) and return it.

        Args:
            data: Dataset to process; mutated in place.
            current_batch_number: Batch index, or None when previewing.
                When set, the dropped columns are archived before removal.

        Returns:
            The same DataFrame with the configured columns removed.
        """
        logger.info(f"🙈 Dropping columns: {self.config.column_names}")
        # Partition configured names into present/missing up front so the
        # archive step below only ever indexes columns that exist — indexing
        # data[...] with an absent column name would raise KeyError.
        present = [c for c in self.config.column_names if c in data.columns]
        for column in self.config.column_names:
            if column not in data.columns:
                logger.warning(f"⚠️ Cannot drop column: `{column}` not found in the dataset.")
        if current_batch_number is not None:  # not in preview mode
            self._save_dropped_columns_if_needed(data, current_batch_number, present)
        # Single vectorized drop instead of one drop() call per column.
        data.drop(columns=present, inplace=True)
        return data

    def _save_dropped_columns_if_needed(
        self, data: pd.DataFrame, current_batch_number: int, column_names: list[str]
    ) -> None:
        """Archive the soon-to-be-dropped columns as a parquet batch file."""
        if not column_names:
            # Nothing to archive — avoid writing an empty parquet file.
            return
        logger.debug("📦 Saving dropped columns to dropped-columns directory")
        dropped_column_parquet_file_name = self.artifact_storage.create_batch_file_path(
            batch_number=current_batch_number,
            batch_stage=BatchStage.DROPPED_COLUMNS,
        ).name
        self.artifact_storage.write_parquet_file(
            parquet_file_name=dropped_column_parquet_file_name,
            dataframe=data[column_names],
            batch_stage=BatchStage.DROPPED_COLUMNS,
        )
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from data_designer.config.base import ConfigBase
|
|
5
|
+
from data_designer.config.processors import (
|
|
6
|
+
DropColumnsProcessorConfig,
|
|
7
|
+
ProcessorType,
|
|
8
|
+
)
|
|
9
|
+
from data_designer.engine.processing.processors.base import Processor
|
|
10
|
+
from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor
|
|
11
|
+
from data_designer.engine.registry.base import TaskRegistry
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ProcessorRegistry(TaskRegistry[str, Processor, ConfigBase]):
    """Singleton registry of dataset processors, keyed by processor type name."""

    ...
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def create_default_processor_registry() -> ProcessorRegistry:
    """Build a ProcessorRegistry pre-populated with the built-in processors."""
    default_registry = ProcessorRegistry()
    default_registry.register(
        ProcessorType.DROP_COLUMNS,
        DropColumnsProcessor,
        DropColumnsProcessorConfig,
        False,
    )
    return default_registry
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
import json
|
|
5
|
+
import logging
|
|
6
|
+
from typing import Any, TypeVar, Union, overload
|
|
7
|
+
|
|
8
|
+
import pandas as pd
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
T = TypeVar("T")
|
|
13
|
+
K = TypeVar("K")
|
|
14
|
+
V = TypeVar("V")
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def concat_datasets(datasets: list[pd.DataFrame]) -> pd.DataFrame:
    """Column-wise concatenate datasets that share the same row count.

    Args:
        datasets: DataFrames with mutually unique column names and equal lengths.

    Returns:
        A single DataFrame containing all columns from the inputs.

    Raises:
        ValueError: If any column name appears in more than one dataset, or
            if the datasets do not all have the same number of rows.
    """
    _verify_columns_are_unique(datasets)
    _verify_dataset_lengths_are_equal(datasets)
    emoji = " + ".join(["💾"] * len(datasets))
    logger.info(f"({emoji}) Concatenating {len(datasets)} datasets")
    # pd.concat accepts the list directly; copying it element-by-element
    # with a comprehension was an unnecessary allocation.
    return pd.concat(datasets, axis=1)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# Overloads to help static type checker better understand
|
|
26
|
+
# the input/output types of the deserialize_json_values function.
|
|
27
|
+
# str in -> whatever the JSON decodes to (or the original str).
@overload
def deserialize_json_values(data: str) -> Union[dict[str, Any], list[Any], Any]: ...


# list in -> list out (elements deserialized where possible).
@overload
def deserialize_json_values(data: list[T]) -> list[Any]: ...


# dict in -> dict out with the same keys (values deserialized where possible).
@overload
def deserialize_json_values(data: dict[K, V]) -> dict[K, Any]: ...


# Anything else passes through unchanged.
@overload
def deserialize_json_values(data: T) -> T: ...
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def deserialize_json_values(data):
    """De-serialize JSON strings in various input formats.

    Args:
        data: Input data in one of four formats:
            - Single string (JSON string to deserialize)
            - List of strings (list of JSON strings to deserialize)
            - Dictionary (potentially with nested JSON strings to deserialize)
            - Some other object that can't be deserialized.

    Returns:
        Deserialized data in the corresponding format:
            - Dictionary (when input is a single string)
            - List of dictionaries (when input is a list of strings)
            - Dictionary (when input is a dictionary, with nested JSON strings deserialized)
            - The original object (if there is no deserialization to perform)
    """

    def _loads_or_keep(text):
        # Best-effort decode: a string that isn't valid JSON passes through.
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            return text

    # Case 1: a lone string.
    if isinstance(data, str):
        return _loads_or_keep(data)

    # Case 2: a list — decode string items, recurse into everything else.
    if isinstance(data, list):
        return [
            _loads_or_keep(item) if isinstance(item, str) else deserialize_json_values(item)
            for item in data
        ]

    # Case 3: a dict — decode string values, recurse into containers,
    # keep all other values as-is.
    if isinstance(data, dict):
        decoded = {}
        for key, value in data.items():
            if isinstance(value, str):
                decoded[key] = _loads_or_keep(value)
            elif isinstance(value, (dict, list)):
                decoded[key] = deserialize_json_values(value)
            else:
                decoded[key] = value
        return decoded

    # Fallback: nothing to deserialize.
    return data
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def _verify_columns_are_unique(datasets: list[pd.DataFrame]) -> None:
    """Raise ValueError if any column name appears in more than one dataset."""
    seen_columns: set = set()
    for frame in datasets:
        frame_columns = set(frame.columns)
        overlapping_columns = seen_columns & frame_columns
        if overlapping_columns:
            raise ValueError(
                f"🛑 Input datasets have overlapping columns: {overlapping_columns} "
                "Please ensure that the column names are unique."
            )
        seen_columns |= frame_columns
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def _verify_dataset_lengths_are_equal(datasets: list[pd.DataFrame]) -> None:
    """Raise ValueError unless all datasets have the same number of rows.

    An empty input list is trivially valid.
    """
    # Set comprehension avoids materializing an intermediate list (C403).
    if len({len(df) for df in datasets}) > 1:
        raise ValueError(
            "🛑 Input datasets have different lengths. Please ensure that the datasets have the same number of rows."
        )
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
import threading
|
|
5
|
+
from typing import Any, Generic, Type, TypeVar
|
|
6
|
+
|
|
7
|
+
from data_designer.config.base import ConfigBase
|
|
8
|
+
from data_designer.config.utils.type_helpers import StrEnum
|
|
9
|
+
from data_designer.engine.configurable_task import ConfigurableTask
|
|
10
|
+
from data_designer.engine.registry.errors import NotFoundInRegistryError, RegistryItemNotTypeError
|
|
11
|
+
|
|
12
|
+
EnumNameT = TypeVar("EnumNameT", bound=StrEnum)
|
|
13
|
+
TaskT = TypeVar("TaskT", bound=ConfigurableTask)
|
|
14
|
+
TaskConfigT = TypeVar("TaskConfigT", bound=ConfigBase)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class TaskRegistry(Generic[EnumNameT, TaskT, TaskConfigT]):
    """Thread-safe singleton registry mapping task names to task/config classes.

    Each concrete subclass is its own singleton and keeps four mappings:
    name -> task class, task class -> name, name -> config class, and
    config class -> name.
    """

    # registered type name -> type
    _registry: dict[EnumNameT, Type[TaskT]] = {}
    # type -> registered type name
    _reverse_registry: dict[Type[TaskT], EnumNameT] = {}

    # registered type name -> config type
    _config_registry: dict[EnumNameT, Type[TaskConfigT]] = {}
    # config type -> registered type name
    _reverse_config_registry: dict[Type[TaskConfigT], EnumNameT] = {}

    # all registries are singletons
    _instance = None
    _lock = threading.Lock()

    def __init_subclass__(cls, **kwargs) -> None:
        """Give each subclass its own isolated registries, lock, and singleton.

        Without this, every subclass would mutate the dicts defined on
        TaskRegistry itself (`cls._registry[name] = ...` mutates the shared
        object rather than rebinding it), so items registered on one registry
        subclass would leak into all other registries.
        """
        super().__init_subclass__(**kwargs)
        cls._registry = {}
        cls._reverse_registry = {}
        cls._config_registry = {}
        cls._reverse_config_registry = {}
        cls._instance = None
        cls._lock = threading.Lock()

    @classmethod
    def register(
        cls,
        name: EnumNameT,
        task: Type[TaskT],
        config: Type[TaskConfigT],
        raise_on_collision: bool = False,
    ) -> None:
        """Register a task class and its config class under `name`.

        Args:
            name: Registered type name for the task.
            task: Task class to register (must be a class).
            config: Config class to register (must be a class).
            raise_on_collision: If True, registering an existing name raises
                ValueError; otherwise the call is a silent no-op.

        Raises:
            ValueError: If `name` is taken and `raise_on_collision` is True.
            RegistryItemNotTypeError: If `task` or `config` is not a class.
        """
        # Check-and-insert under the lock so two concurrent registrations of
        # the same name cannot interleave.
        with cls._lock:
            if cls._has_been_registered(name):
                if not raise_on_collision:
                    return
                raise ValueError(f"{name} has already been registered!")

            cls._raise_if_not_type(task)
            cls._raise_if_not_type(config)

            cls._registry[name] = task
            cls._reverse_registry[task] = name
            cls._config_registry[name] = config
            cls._reverse_config_registry[config] = name

    @classmethod
    def get_task_type(cls, name: EnumNameT) -> Type[TaskT]:
        """Return the task class registered under `name`."""
        cls._raise_if_not_registered(name, cls._registry)
        return cls._registry[name]

    @classmethod
    def get_config_type(cls, name: EnumNameT) -> Type[TaskConfigT]:
        """Return the config class registered under `name`."""
        cls._raise_if_not_registered(name, cls._config_registry)
        return cls._config_registry[name]

    @classmethod
    def get_registered_name(cls, task: Type[TaskT]) -> EnumNameT:
        """Return the name under which `task` was registered."""
        cls._raise_if_not_registered(task, cls._reverse_registry)
        return cls._reverse_registry[task]

    @classmethod
    def get_for_config_type(cls, config: Type[TaskConfigT]) -> Type[TaskT]:
        """Return the task class whose config class is `config`."""
        cls._raise_if_not_registered(config, cls._reverse_config_registry)
        name = cls._reverse_config_registry[config]
        return cls.get_task_type(name)

    @classmethod
    def _has_been_registered(cls, name: EnumNameT) -> bool:
        """True if `name` already has a task entry."""
        return name in cls._registry

    @classmethod
    def _raise_if_not_registered(cls, key: EnumNameT | Type[TaskT] | Type[TaskConfigT], mapping: dict) -> None:
        """Raise NotFoundInRegistryError if `key` has no entry in `mapping`.

        Non-name keys (classes) are first checked to actually be classes.
        """
        if not (isinstance(key, StrEnum) or isinstance(key, str)):
            cls._raise_if_not_type(key)
        if key not in mapping:
            raise NotFoundInRegistryError(f"{key} not found in registry")

    @classmethod
    def _raise_if_not_type(cls, obj: Any) -> None:
        """Raise RegistryItemNotTypeError unless `obj` is a class."""
        if not isinstance(obj, type):
            raise RegistryItemNotTypeError(f"{obj} is not a class!")

    def __new__(cls, *args, **kwargs):
        """Registry is a singleton (double-checked locking per subclass)."""
        if not cls._instance:
            with cls._lock:
                if not cls._instance:
                    cls._instance = super().__new__(cls)
        return cls._instance
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from data_designer.engine.analysis.column_profilers.registry import (
|
|
5
|
+
ColumnProfilerRegistry,
|
|
6
|
+
create_default_column_profiler_registry,
|
|
7
|
+
)
|
|
8
|
+
from data_designer.engine.column_generators.registry import (
|
|
9
|
+
ColumnGeneratorRegistry,
|
|
10
|
+
create_default_column_generator_registry,
|
|
11
|
+
)
|
|
12
|
+
from data_designer.engine.processing.processors.registry import ProcessorRegistry, create_default_processor_registry
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class DataDesignerRegistry:
    """Aggregates the three task registries used by the engine.

    Any registry not supplied is created with its default contents.
    """

    def __init__(
        self,
        *,
        column_generator_registry: ColumnGeneratorRegistry | None = None,
        column_profiler_registry: ColumnProfilerRegistry | None = None,
        processor_registry: ProcessorRegistry | None = None,
    ):
        if column_generator_registry is None:
            column_generator_registry = create_default_column_generator_registry()
        if column_profiler_registry is None:
            column_profiler_registry = create_default_column_profiler_registry()
        if processor_registry is None:
            processor_registry = create_default_processor_registry()

        self._column_generator_registry = column_generator_registry
        self._column_profiler_registry = column_profiler_registry
        self._processor_registry = processor_registry

    @property
    def column_generators(self) -> ColumnGeneratorRegistry:
        """Registry of column generator tasks."""
        return self._column_generator_registry

    @property
    def column_profilers(self) -> ColumnProfilerRegistry:
        """Registry of column profiler tasks."""
        return self._column_profiler_registry

    @property
    def processors(self) -> ProcessorRegistry:
        """Registry of dataset processors."""
        return self._processor_registry
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from data_designer.engine.errors import DataDesignerError
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class NotFoundInRegistryError(DataDesignerError):
    """Raised when a lookup key or class has no entry in a registry."""

    ...


class RegistryItemNotTypeError(DataDesignerError):
    """Raised when an object being registered or looked up is not a class."""

    ...
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import pandas as pd
|
|
7
|
+
|
|
8
|
+
from data_designer.engine.resources.managed_dataset_repository import ManagedDatasetRepository
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ManagedDatasetGenerator:
|
|
12
|
+
def __init__(self, managed_datasets: ManagedDatasetRepository, dataset_name: str):
|
|
13
|
+
self.managed_datasets = managed_datasets
|
|
14
|
+
self.dataset_name = dataset_name
|
|
15
|
+
|
|
16
|
+
def generate_samples(
|
|
17
|
+
self,
|
|
18
|
+
size: int = 1,
|
|
19
|
+
evidence: dict[str, Any | list[Any]] = {},
|
|
20
|
+
) -> pd.DataFrame:
|
|
21
|
+
parameters = []
|
|
22
|
+
query = f"select * from {self.dataset_name}"
|
|
23
|
+
if evidence:
|
|
24
|
+
where_conditions = []
|
|
25
|
+
for column, values in evidence.items():
|
|
26
|
+
if values:
|
|
27
|
+
values = values if isinstance(values, list) else [values]
|
|
28
|
+
formatted_values = ["?"] * len(values)
|
|
29
|
+
condition = f"{column} IN ({', '.join(formatted_values)})"
|
|
30
|
+
where_conditions.append(condition)
|
|
31
|
+
parameters.extend(values)
|
|
32
|
+
if where_conditions:
|
|
33
|
+
query += " where " + " and ".join(where_conditions)
|
|
34
|
+
query += f" order by random() limit {size}"
|
|
35
|
+
return self.managed_datasets.query(query, parameters)
|