data-designer 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/__init__.py +15 -0
- data_designer/_version.py +34 -0
- data_designer/cli/README.md +236 -0
- data_designer/cli/__init__.py +6 -0
- data_designer/cli/commands/__init__.py +2 -0
- data_designer/cli/commands/list.py +130 -0
- data_designer/cli/commands/models.py +10 -0
- data_designer/cli/commands/providers.py +11 -0
- data_designer/cli/commands/reset.py +100 -0
- data_designer/cli/controllers/__init__.py +7 -0
- data_designer/cli/controllers/model_controller.py +246 -0
- data_designer/cli/controllers/provider_controller.py +317 -0
- data_designer/cli/forms/__init__.py +20 -0
- data_designer/cli/forms/builder.py +51 -0
- data_designer/cli/forms/field.py +180 -0
- data_designer/cli/forms/form.py +59 -0
- data_designer/cli/forms/model_builder.py +125 -0
- data_designer/cli/forms/provider_builder.py +76 -0
- data_designer/cli/main.py +44 -0
- data_designer/cli/repositories/__init__.py +8 -0
- data_designer/cli/repositories/base.py +39 -0
- data_designer/cli/repositories/model_repository.py +42 -0
- data_designer/cli/repositories/provider_repository.py +43 -0
- data_designer/cli/services/__init__.py +7 -0
- data_designer/cli/services/model_service.py +116 -0
- data_designer/cli/services/provider_service.py +111 -0
- data_designer/cli/ui.py +448 -0
- data_designer/cli/utils.py +47 -0
- data_designer/config/__init__.py +2 -0
- data_designer/config/analysis/column_profilers.py +89 -0
- data_designer/config/analysis/column_statistics.py +274 -0
- data_designer/config/analysis/dataset_profiler.py +60 -0
- data_designer/config/analysis/utils/errors.py +8 -0
- data_designer/config/analysis/utils/reporting.py +188 -0
- data_designer/config/base.py +68 -0
- data_designer/config/column_configs.py +354 -0
- data_designer/config/column_types.py +168 -0
- data_designer/config/config_builder.py +660 -0
- data_designer/config/data_designer_config.py +40 -0
- data_designer/config/dataset_builders.py +11 -0
- data_designer/config/datastore.py +151 -0
- data_designer/config/default_model_settings.py +123 -0
- data_designer/config/errors.py +19 -0
- data_designer/config/interface.py +54 -0
- data_designer/config/models.py +231 -0
- data_designer/config/preview_results.py +32 -0
- data_designer/config/processors.py +41 -0
- data_designer/config/sampler_constraints.py +51 -0
- data_designer/config/sampler_params.py +604 -0
- data_designer/config/seed.py +145 -0
- data_designer/config/utils/code_lang.py +83 -0
- data_designer/config/utils/constants.py +313 -0
- data_designer/config/utils/errors.py +19 -0
- data_designer/config/utils/info.py +88 -0
- data_designer/config/utils/io_helpers.py +273 -0
- data_designer/config/utils/misc.py +81 -0
- data_designer/config/utils/numerical_helpers.py +28 -0
- data_designer/config/utils/type_helpers.py +100 -0
- data_designer/config/utils/validation.py +336 -0
- data_designer/config/utils/visualization.py +427 -0
- data_designer/config/validator_params.py +96 -0
- data_designer/engine/__init__.py +2 -0
- data_designer/engine/analysis/column_profilers/base.py +55 -0
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +160 -0
- data_designer/engine/analysis/column_profilers/registry.py +20 -0
- data_designer/engine/analysis/column_statistics.py +142 -0
- data_designer/engine/analysis/dataset_profiler.py +125 -0
- data_designer/engine/analysis/errors.py +7 -0
- data_designer/engine/analysis/utils/column_statistics_calculations.py +209 -0
- data_designer/engine/analysis/utils/judge_score_processing.py +128 -0
- data_designer/engine/column_generators/__init__.py +2 -0
- data_designer/engine/column_generators/generators/__init__.py +2 -0
- data_designer/engine/column_generators/generators/base.py +61 -0
- data_designer/engine/column_generators/generators/expression.py +63 -0
- data_designer/engine/column_generators/generators/llm_generators.py +172 -0
- data_designer/engine/column_generators/generators/samplers.py +75 -0
- data_designer/engine/column_generators/generators/seed_dataset.py +149 -0
- data_designer/engine/column_generators/generators/validation.py +147 -0
- data_designer/engine/column_generators/registry.py +56 -0
- data_designer/engine/column_generators/utils/errors.py +13 -0
- data_designer/engine/column_generators/utils/judge_score_factory.py +57 -0
- data_designer/engine/column_generators/utils/prompt_renderer.py +98 -0
- data_designer/engine/configurable_task.py +82 -0
- data_designer/engine/dataset_builders/artifact_storage.py +181 -0
- data_designer/engine/dataset_builders/column_wise_builder.py +287 -0
- data_designer/engine/dataset_builders/errors.py +13 -0
- data_designer/engine/dataset_builders/multi_column_configs.py +44 -0
- data_designer/engine/dataset_builders/utils/__init__.py +2 -0
- data_designer/engine/dataset_builders/utils/concurrency.py +184 -0
- data_designer/engine/dataset_builders/utils/config_compiler.py +60 -0
- data_designer/engine/dataset_builders/utils/dag.py +56 -0
- data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +190 -0
- data_designer/engine/dataset_builders/utils/errors.py +13 -0
- data_designer/engine/errors.py +49 -0
- data_designer/engine/model_provider.py +75 -0
- data_designer/engine/models/__init__.py +2 -0
- data_designer/engine/models/errors.py +308 -0
- data_designer/engine/models/facade.py +225 -0
- data_designer/engine/models/litellm_overrides.py +162 -0
- data_designer/engine/models/parsers/__init__.py +2 -0
- data_designer/engine/models/parsers/errors.py +34 -0
- data_designer/engine/models/parsers/parser.py +236 -0
- data_designer/engine/models/parsers/postprocessors.py +93 -0
- data_designer/engine/models/parsers/tag_parsers.py +60 -0
- data_designer/engine/models/parsers/types.py +82 -0
- data_designer/engine/models/recipes/base.py +79 -0
- data_designer/engine/models/recipes/response_recipes.py +291 -0
- data_designer/engine/models/registry.py +118 -0
- data_designer/engine/models/usage.py +75 -0
- data_designer/engine/models/utils.py +38 -0
- data_designer/engine/processing/ginja/__init__.py +2 -0
- data_designer/engine/processing/ginja/ast.py +64 -0
- data_designer/engine/processing/ginja/environment.py +461 -0
- data_designer/engine/processing/ginja/exceptions.py +54 -0
- data_designer/engine/processing/ginja/record.py +30 -0
- data_designer/engine/processing/gsonschema/__init__.py +2 -0
- data_designer/engine/processing/gsonschema/exceptions.py +8 -0
- data_designer/engine/processing/gsonschema/schema_transformers.py +81 -0
- data_designer/engine/processing/gsonschema/types.py +8 -0
- data_designer/engine/processing/gsonschema/validators.py +143 -0
- data_designer/engine/processing/processors/base.py +15 -0
- data_designer/engine/processing/processors/drop_columns.py +46 -0
- data_designer/engine/processing/processors/registry.py +20 -0
- data_designer/engine/processing/utils.py +120 -0
- data_designer/engine/registry/base.py +97 -0
- data_designer/engine/registry/data_designer_registry.py +37 -0
- data_designer/engine/registry/errors.py +10 -0
- data_designer/engine/resources/managed_dataset_generator.py +35 -0
- data_designer/engine/resources/managed_dataset_repository.py +194 -0
- data_designer/engine/resources/managed_storage.py +63 -0
- data_designer/engine/resources/resource_provider.py +46 -0
- data_designer/engine/resources/seed_dataset_data_store.py +66 -0
- data_designer/engine/sampling_gen/column.py +89 -0
- data_designer/engine/sampling_gen/constraints.py +95 -0
- data_designer/engine/sampling_gen/data_sources/base.py +214 -0
- data_designer/engine/sampling_gen/data_sources/errors.py +10 -0
- data_designer/engine/sampling_gen/data_sources/sources.py +342 -0
- data_designer/engine/sampling_gen/entities/__init__.py +2 -0
- data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
- data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +64 -0
- data_designer/engine/sampling_gen/entities/email_address_utils.py +169 -0
- data_designer/engine/sampling_gen/entities/errors.py +8 -0
- data_designer/engine/sampling_gen/entities/national_id_utils.py +100 -0
- data_designer/engine/sampling_gen/entities/person.py +142 -0
- data_designer/engine/sampling_gen/entities/phone_number.py +122 -0
- data_designer/engine/sampling_gen/errors.py +24 -0
- data_designer/engine/sampling_gen/generator.py +121 -0
- data_designer/engine/sampling_gen/jinja_utils.py +60 -0
- data_designer/engine/sampling_gen/people_gen.py +203 -0
- data_designer/engine/sampling_gen/person_constants.py +54 -0
- data_designer/engine/sampling_gen/schema.py +143 -0
- data_designer/engine/sampling_gen/schema_builder.py +59 -0
- data_designer/engine/sampling_gen/utils.py +40 -0
- data_designer/engine/secret_resolver.py +80 -0
- data_designer/engine/validators/__init__.py +17 -0
- data_designer/engine/validators/base.py +36 -0
- data_designer/engine/validators/local_callable.py +34 -0
- data_designer/engine/validators/python.py +245 -0
- data_designer/engine/validators/remote.py +83 -0
- data_designer/engine/validators/sql.py +60 -0
- data_designer/errors.py +5 -0
- data_designer/essentials/__init__.py +137 -0
- data_designer/interface/__init__.py +2 -0
- data_designer/interface/data_designer.py +351 -0
- data_designer/interface/errors.py +16 -0
- data_designer/interface/results.py +55 -0
- data_designer/logging.py +161 -0
- data_designer/plugin_manager.py +83 -0
- data_designer/plugins/__init__.py +6 -0
- data_designer/plugins/errors.py +10 -0
- data_designer/plugins/plugin.py +69 -0
- data_designer/plugins/registry.py +86 -0
- data_designer-0.1.0.dist-info/METADATA +173 -0
- data_designer-0.1.0.dist-info/RECORD +177 -0
- data_designer-0.1.0.dist-info/WHEEL +4 -0
- data_designer-0.1.0.dist-info/entry_points.txt +2 -0
- data_designer-0.1.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from typing import Annotated, Optional
|
|
7
|
+
|
|
8
|
+
from pydantic import Field
|
|
9
|
+
|
|
10
|
+
from .analysis.column_profilers import ColumnProfilerConfigT
|
|
11
|
+
from .base import ExportableConfigBase
|
|
12
|
+
from .column_types import ColumnConfigT
|
|
13
|
+
from .models import ModelConfig
|
|
14
|
+
from .processors import ProcessorConfig
|
|
15
|
+
from .sampler_constraints import ColumnConstraintT
|
|
16
|
+
from .seed import SeedConfig
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class DataDesignerConfig(ExportableConfigBase):
    """Configuration for NeMo Data Designer.

    This class defines the main configuration structure for NeMo Data Designer,
    which orchestrates the generation of synthetic data.

    Attributes:
        columns: Required list of column configurations defining how each column
            should be generated. Must contain at least one column.
        model_configs: Optional list of model configurations for LLM-based generation.
            Each model config defines the model, provider, and inference parameters.
        seed_config: Optional seed dataset settings to use for generation.
        constraints: Optional list of column constraints.
        profilers: Optional list of column profilers for analyzing generated data characteristics.
        processors: Optional list of processor configurations applied while building the dataset.
    """

    # `column_type` acts as the pydantic discriminator across the ColumnConfigT union.
    columns: list[Annotated[ColumnConfigT, Field(discriminator="column_type")]] = Field(min_length=1)
    model_configs: Optional[list[ModelConfig]] = None
    seed_config: Optional[SeedConfig] = None
    constraints: Optional[list[ColumnConstraintT]] = None
    profilers: Optional[list[ColumnProfilerConfigT]] = None
    processors: Optional[list[ProcessorConfig]] = None
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from enum import Enum
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class BuildStage(str, Enum):
    """Named stages of the dataset build lifecycle (before/after batch and generation)."""

    PRE_BATCH = "pre_batch"
    POST_BATCH = "post_batch"
    PRE_GENERATION = "pre_generation"
    POST_GENERATION = "post_generation"
|
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import TYPE_CHECKING, Optional, Union
|
|
9
|
+
|
|
10
|
+
from huggingface_hub import HfApi, HfFileSystem
|
|
11
|
+
import pandas as pd
|
|
12
|
+
import pyarrow.parquet as pq
|
|
13
|
+
from pydantic import BaseModel, Field
|
|
14
|
+
|
|
15
|
+
from .errors import InvalidConfigError, InvalidFileFormatError, InvalidFilePathError
|
|
16
|
+
from .utils.io_helpers import VALID_DATASET_FILE_EXTENSIONS, validate_path_contains_files_of_type
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from .seed import SeedDatasetReference
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class DatastoreSettings(BaseModel):
    """Configuration for interacting with a datastore."""

    # Base URL of the datastore service (HF Hub or a compatible endpoint).
    endpoint: str = Field(
        ...,
        description="Datastore endpoint. Use 'https://huggingface.co' for the Hugging Face Hub.",
    )
    # Optional auth token; None means unauthenticated access.
    token: Optional[str] = Field(default=None, description="If needed, token to use for authentication.")
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def get_file_column_names(file_path: Union[str, Path], file_type: str) -> list[str]:
    """Extract column names based on file type. Supports glob patterns like '../path/*.parquet'.

    Args:
        file_path: Path (or glob pattern) of a dataset file; for a pattern the
            first matching file in sorted order is inspected.
            NOTE(review): one internal caller passes an open file-like object
            here — `Path(file_path)` assumes a path-like; confirm against callers.
        file_type: One of "parquet", "json", "jsonl", or "csv".

    Returns:
        The column names, or [] when the file exists but cannot be parsed.

    Raises:
        InvalidFilePathError: If a glob pattern matches no files, or the file
            type is unsupported.
    """
    file_path = Path(file_path)
    if "*" in str(file_path):
        matching_files = sorted(file_path.parent.glob(file_path.name))
        if not matching_files:
            raise InvalidFilePathError(f"🛑 No files found matching pattern: {str(file_path)!r}")
        logger.debug(f"0️⃣ Using the first matching file in {str(file_path)!r} to determine column names in seed dataset")
        file_path = matching_files[0]

    if file_type == "parquet":
        try:
            schema = pq.read_schema(file_path)
            # Older pyarrow schemas expose `.names`; otherwise iterate fields.
            if hasattr(schema, "names"):
                return schema.names
            return [field.name for field in schema]
        except Exception as e:
            logger.warning(f"Failed to process parquet file {file_path}: {e}")
            return []
    elif file_type in ["json", "jsonl"]:
        # Consistency fix: a malformed JSON file previously raised a raw pandas
        # error, while the parquet/csv branches log a warning and return [].
        try:
            return pd.read_json(file_path, orient="records", lines=True, nrows=1).columns.tolist()
        except ValueError as e:
            logger.warning(f"Failed to process JSON file {file_path}: {e}")
            return []
    elif file_type == "csv":
        try:
            df = pd.read_csv(file_path, nrows=1)
            return df.columns.tolist()
        except (pd.errors.EmptyDataError, pd.errors.ParserError) as e:
            logger.warning(f"Failed to process CSV file {file_path}: {e}")
            return []
    else:
        raise InvalidFilePathError(f"🛑 Unsupported file type: {file_type!r}")
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def fetch_seed_dataset_column_names(seed_dataset_reference: SeedDatasetReference) -> list[str]:
    """Resolve a seed dataset's column names from either a datastore or a local file."""
    # References without datastore settings point at a local file on disk.
    if not hasattr(seed_dataset_reference, "datastore_settings"):
        return _fetch_seed_dataset_column_names_from_local_file(seed_dataset_reference.dataset)
    return _fetch_seed_dataset_column_names_from_datastore(
        seed_dataset_reference.repo_id,
        seed_dataset_reference.filename,
        seed_dataset_reference.datastore_settings,
    )
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def resolve_datastore_settings(datastore_settings: DatastoreSettings | dict | None) -> DatastoreSettings:
    """Coerce settings (object or dict) into a DatastoreSettings instance.

    Raises:
        InvalidConfigError: If settings are None or of an unsupported type.
    """
    if datastore_settings is None:
        raise InvalidConfigError("🛑 Datastore settings are required in order to upload datasets to the datastore.")
    if isinstance(datastore_settings, dict):
        return DatastoreSettings.model_validate(datastore_settings)
    if isinstance(datastore_settings, DatastoreSettings):
        return datastore_settings
    raise InvalidConfigError(
        "🛑 Invalid datastore settings format. Must be DatastoreSettings object or dictionary."
    )
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def upload_to_hf_hub(
    dataset_path: Union[str, Path],
    filename: str,
    repo_id: str,
    datastore_settings: DatastoreSettings,
    **kwargs,
) -> str:
    """Upload a local dataset file to a Hugging Face-compatible dataset repo.

    Args:
        dataset_path: Local path of the dataset file to upload.
        filename: Name the file should have in the repo; its extension must
            match the local file's extension.
        repo_id: Target dataset repository id.
        datastore_settings: Endpoint/token settings (object or dict).
        **kwargs: Passed through to `HfApi.upload_file`.

    Returns:
        The repo-relative identifier "<repo_id>/<filename>" of the uploaded file.

    Raises:
        InvalidConfigError: If datastore settings are missing or invalid.
        InvalidFilePathError: If `dataset_path` is not a valid file.
        InvalidFileFormatError: If the file format or extension is invalid.
    """
    datastore_settings = resolve_datastore_settings(datastore_settings)
    dataset_path = _validate_dataset_path(dataset_path)
    filename_ext = filename.split(".")[-1].lower()
    if dataset_path.suffix.lower()[1:] != filename_ext:
        raise InvalidFileFormatError(
            f"🛑 Dataset file extension {dataset_path.suffix!r} does not match `filename` extension .{filename_ext!r}"
        )

    hfapi = HfApi(endpoint=datastore_settings.endpoint, token=datastore_settings.token)
    # exist_ok: re-uploading into an existing repo is a supported flow.
    hfapi.create_repo(repo_id, exist_ok=True, repo_type="dataset")
    hfapi.upload_file(
        path_or_fileobj=dataset_path,
        path_in_repo=filename,
        repo_id=repo_id,
        repo_type="dataset",
        **kwargs,
    )
    # Bug fix: previously returned a literal "(unknown)" placeholder instead of
    # the uploaded file's name.
    return f"{repo_id}/{filename}"
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def _fetch_seed_dataset_column_names_from_datastore(
    repo_id: str,
    filename: str,
    datastore_settings: Optional[Union[DatastoreSettings, dict]] = None,
) -> list[str]:
    """Read column names of a remote seed dataset without downloading the whole file.

    Raises:
        InvalidFileFormatError: If `filename` has an unsupported extension.
        InvalidConfigError: If datastore settings cannot be resolved.
    """
    file_type = filename.split(".")[-1]
    if f".{file_type}" not in VALID_DATASET_FILE_EXTENSIONS:
        raise InvalidFileFormatError(f"🛑 Unsupported file type: {filename!r}")

    datastore_settings = resolve_datastore_settings(datastore_settings)
    fs = HfFileSystem(endpoint=datastore_settings.endpoint, token=datastore_settings.token)

    # Bug fix: the path previously contained a literal "(unknown)" placeholder,
    # leaving the `filename` parameter unused.
    with fs.open(f"datasets/{repo_id}/{filename}") as f:
        return get_file_column_names(f, file_type)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def _fetch_seed_dataset_column_names_from_local_file(dataset_path: str | Path) -> list[str]:
    """Validate a local dataset path (globs allowed) and return its column names."""
    validated_path = _validate_dataset_path(dataset_path, allow_glob_pattern=True)
    file_type = str(validated_path).split(".")[-1]
    return get_file_column_names(validated_path, file_type)
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def _validate_dataset_path(dataset_path: Union[str, Path], allow_glob_pattern: bool = False) -> Path:
    """Validate that `dataset_path` is an uploadable dataset file (or glob pattern).

    Returns the path as a `Path`; raises on missing files or unsupported formats.
    """
    path_str = str(dataset_path)

    if allow_glob_pattern and "*" in path_str:
        # Split "dir/*.ext" into the directory prefix and the extension, then
        # check the directory actually contains files of that type.
        prefix, _, extension = path_str.partition("*.")
        validate_path_contains_files_of_type(prefix, path_str.split("*.")[-1] if not extension else extension)
        return Path(dataset_path)

    resolved = Path(dataset_path)
    if not resolved.is_file():
        raise InvalidFilePathError("🛑 To upload a dataset to the datastore, you must provide a valid file path.")
    if not resolved.name.endswith(tuple(VALID_DATASET_FILE_EXTENSIONS)):
        raise InvalidFileFormatError(
            "🛑 Dataset files must be in `parquet`, `csv`, or `json` (orient='records', lines=True) format."
        )
    return resolved
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
from functools import lru_cache
|
|
6
|
+
import logging
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any, Literal, Optional
|
|
9
|
+
|
|
10
|
+
from .models import InferenceParameters, ModelConfig, ModelProvider
|
|
11
|
+
from .utils.constants import (
|
|
12
|
+
MANAGED_ASSETS_PATH,
|
|
13
|
+
MODEL_CONFIGS_FILE_PATH,
|
|
14
|
+
MODEL_PROVIDERS_FILE_PATH,
|
|
15
|
+
PREDEFINED_PROVIDERS,
|
|
16
|
+
PREDEFINED_PROVIDERS_MODEL_MAP,
|
|
17
|
+
)
|
|
18
|
+
from .utils.info import ConfigBuilderInfo, InfoType, InterfaceInfo
|
|
19
|
+
from .utils.io_helpers import load_config_file, save_config_file
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def get_default_text_alias_inference_parameters() -> InferenceParameters:
    """Default sampling parameters for the "text" model alias."""
    defaults = {"temperature": 0.85, "top_p": 0.95}
    return InferenceParameters(**defaults)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def get_default_reasoning_alias_inference_parameters() -> InferenceParameters:
    """Default sampling parameters for the "reasoning" model alias."""
    defaults = {"temperature": 0.35, "top_p": 0.95}
    return InferenceParameters(**defaults)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def get_default_vision_alias_inference_parameters() -> InferenceParameters:
    """Default sampling parameters for the "vision" model alias."""
    defaults = {"temperature": 0.85, "top_p": 0.95}
    return InferenceParameters(**defaults)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def get_default_inference_parameters(model_alias: Literal["text", "reasoning", "vision"]) -> InferenceParameters:
    """Return the default inference parameters for a model alias.

    Unrecognized aliases fall back to the "text" defaults, matching the
    original if/elif chain's else branch.
    """
    factories = {
        "reasoning": get_default_reasoning_alias_inference_parameters,
        "vision": get_default_vision_alias_inference_parameters,
    }
    factory = factories.get(model_alias, get_default_text_alias_inference_parameters)
    return factory()
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def get_builtin_model_configs() -> list[ModelConfig]:
    """Build one ModelConfig per (provider, alias) pair in the predefined map."""
    return [
        ModelConfig(
            alias=f"{provider}-{model_alias}",
            model=model_id,
            provider=provider,
            inference_parameters=get_default_inference_parameters(model_alias),
        )
        for provider, model_alias_map in PREDEFINED_PROVIDERS_MODEL_MAP.items()
        for model_alias, model_id in model_alias_map.items()
    ]
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def get_builtin_model_providers() -> list[ModelProvider]:
    """Validate the predefined provider definitions into ModelProvider objects."""
    return list(map(ModelProvider.model_validate, PREDEFINED_PROVIDERS))
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def get_default_model_configs() -> list[ModelConfig]:
    """Load model configs from the on-disk defaults file.

    Raises:
        FileNotFoundError: If the file is missing. NOTE: the same error is also
            raised when the file exists but has no "model_configs" key
            (behavior preserved from the original implementation).
    """
    if MODEL_CONFIGS_FILE_PATH.exists():
        raw = load_config_file(MODEL_CONFIGS_FILE_PATH)
        if "model_configs" in raw:
            return [ModelConfig.model_validate(mc) for mc in raw["model_configs"]]
    raise FileNotFoundError(f"Default model configs file not found at {str(MODEL_CONFIGS_FILE_PATH)!r}")
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def get_default_providers() -> list[ModelProvider]:
    """Load provider definitions from the cached defaults file; [] when absent."""
    content = _get_default_providers_file_content(MODEL_PROVIDERS_FILE_PATH)
    if "providers" not in content:
        return []
    return [ModelProvider.model_validate(entry) for entry in content["providers"]]
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def get_default_provider_name() -> Optional[str]:
    """Return the "default" provider name from the providers file, if set."""
    content = _get_default_providers_file_content(MODEL_PROVIDERS_FILE_PATH)
    return content.get("default")
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def resolve_seed_default_model_settings() -> None:
    """Seed on-disk defaults (model configs, providers, assets dir) when missing.

    Each step is independent: files that already exist are left untouched.
    """
    if not MODEL_CONFIGS_FILE_PATH.exists():
        logger.info(
            f"🍾 Default model configs were not found, so writing the following to {str(MODEL_CONFIGS_FILE_PATH)!r}"
        )
        # Hoisted: the original called get_builtin_model_configs() twice,
        # rebuilding every ModelConfig for the save step.
        model_configs = get_builtin_model_configs()
        config_builder_info = ConfigBuilderInfo(model_configs=model_configs)
        config_builder_info.display(info_type=InfoType.MODEL_CONFIGS)
        save_config_file(
            MODEL_CONFIGS_FILE_PATH, {"model_configs": [mc.model_dump() for mc in model_configs]}
        )

    if not MODEL_PROVIDERS_FILE_PATH.exists():
        logger.info(
            f"🪄 Default model providers were not found, so writing the following to {str(MODEL_PROVIDERS_FILE_PATH)!r}"
        )
        # Hoisted: same double-call pattern as above for providers.
        model_providers = get_builtin_model_providers()
        interface_info = InterfaceInfo(model_providers=model_providers)
        interface_info.display(info_type=InfoType.MODEL_PROVIDERS)
        save_config_file(
            MODEL_PROVIDERS_FILE_PATH, {"providers": [p.model_dump() for p in model_providers]}
        )

    if not MANAGED_ASSETS_PATH.exists():
        logger.debug(f"🏗️ Default managed assets path was not found, so creating it at {str(MANAGED_ASSETS_PATH)!r}")
        MANAGED_ASSETS_PATH.mkdir(parents=True, exist_ok=True)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
@lru_cache(maxsize=1)
def _get_default_providers_file_content(file_path: Path) -> dict[str, Any]:
    """Load and cache the default providers file content.

    NOTE: because of lru_cache, edits to the file after the first read are not
    seen within the same process.
    """
    if not file_path.exists():
        raise FileNotFoundError(f"Default model providers file not found at {str(file_path)!r}")
    return load_config_file(file_path)
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from ..errors import DataDesignerError
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
# Error raised for invalid builder configuration.
class BuilderConfigurationError(DataDesignerError): ...
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# Error raised for an unrecognized or invalid column type.
class InvalidColumnTypeError(DataDesignerError): ...
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# Error raised for invalid configuration values.
class InvalidConfigError(DataDesignerError): ...
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# Error raised for missing or invalid file paths.
class InvalidFilePathError(DataDesignerError): ...
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Error raised for unsupported or mismatched file formats.
class InvalidFileFormatError(DataDesignerError): ...
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar
|
|
8
|
+
|
|
9
|
+
import pandas as pd
|
|
10
|
+
|
|
11
|
+
from .models import ModelConfig, ModelProvider
|
|
12
|
+
from .utils.constants import DEFAULT_NUM_RECORDS
|
|
13
|
+
from .utils.info import InterfaceInfo
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from .analysis.dataset_profiler import DatasetProfilerResults
|
|
17
|
+
from .config_builder import DataDesignerConfigBuilder
|
|
18
|
+
from .preview_results import PreviewResults
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ResultsProtocol(Protocol):
    """Structural interface for generation results: profiling and dataset access."""

    # Load the dataset-profiler analysis for the generated data.
    def load_analysis(self) -> DatasetProfilerResults: ...
    # Load the generated dataset as a pandas DataFrame.
    def load_dataset(self) -> pd.DataFrame: ...
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# Any results type satisfying ResultsProtocol; parameterizes DataDesignerInterface.
ResultsT = TypeVar("ResultsT", bound=ResultsProtocol)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class DataDesignerInterface(ABC, Generic[ResultsT]):
    """Abstract interface for Data Designer backends.

    Generic over the results type produced by `create`.
    """

    # Run a full generation of `num_records` records from the given config.
    @abstractmethod
    def create(
        self,
        config_builder: DataDesignerConfigBuilder,
        *,
        num_records: int = DEFAULT_NUM_RECORDS,
    ) -> ResultsT: ...

    # Generate a small preview of `num_records` records without a full run.
    @abstractmethod
    def preview(
        self,
        config_builder: DataDesignerConfigBuilder,
        *,
        num_records: int = DEFAULT_NUM_RECORDS,
    ) -> PreviewResults: ...

    # Default model configurations known to this backend.
    @abstractmethod
    def get_default_model_configs(self) -> list[ModelConfig]: ...

    # Default model providers known to this backend.
    @abstractmethod
    def get_default_model_providers(self) -> list[ModelProvider]: ...

    # Informational summary of this interface (models, providers, etc.).
    @property
    @abstractmethod
    def info(self) -> InterfaceInfo: ...
|
|
@@ -0,0 +1,231 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from enum import Enum
|
|
6
|
+
import logging
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any, Generic, List, Optional, TypeVar, Union
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
from pydantic import BaseModel, Field, model_validator
|
|
12
|
+
from typing_extensions import Self, TypeAlias
|
|
13
|
+
|
|
14
|
+
from .base import ConfigBase
|
|
15
|
+
from .errors import InvalidConfigError
|
|
16
|
+
from .utils.constants import (
|
|
17
|
+
MAX_TEMPERATURE,
|
|
18
|
+
MAX_TOP_P,
|
|
19
|
+
MIN_TEMPERATURE,
|
|
20
|
+
MIN_TOP_P,
|
|
21
|
+
)
|
|
22
|
+
from .utils.io_helpers import smart_load_yaml
|
|
23
|
+
|
|
24
|
+
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class Modality(str, Enum):
    """Supported non-text modalities (currently images only)."""

    IMAGE = "image"
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class ModalityDataType(str, Enum):
    """How modality data is carried in a record: a URL or base64-encoded bytes."""

    URL = "url"
    BASE64 = "base64"
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class ImageFormat(str, Enum):
    """Image formats accepted for base64 payloads (used in data-URL media types)."""

    PNG = "png"
    JPG = "jpg"
    JPEG = "jpeg"
    GIF = "gif"
    WEBP = "webp"
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class DistributionType(str, Enum):
    """Kinds of sampling distributions for inference parameters."""

    UNIFORM = "uniform"
    MANUAL = "manual"
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class ModalityContext(ABC, BaseModel):
    """Base class for per-modality context builders tied to a record column."""

    modality: Modality
    # Name of the record column whose value supplies the modality data.
    column_name: str
    data_type: ModalityDataType

    # Build the modality-specific context dict for a single record.
    @abstractmethod
    def get_context(self, record: dict) -> dict[str, Any]: ...
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class ImageContext(ModalityContext):
    """Image modality context: renders a record's image column as an image_url part."""

    modality: Modality = Modality.IMAGE
    image_format: Optional[ImageFormat] = None

    def get_context(self, record: dict) -> dict[str, Any]:
        """Build an {"type": "image_url", "image_url": ...} dict from `record`."""
        value = record[self.column_name]
        if self.data_type == ModalityDataType.URL:
            image_url = value
        else:
            # base64 payloads are wrapped in a data URL with the declared format.
            fmt = self.image_format.value
            image_url = {
                "url": f"data:image/{fmt};base64,{value}",
                "format": fmt,
            }
        return {"type": "image_url", "image_url": image_url}

    @model_validator(mode="after")
    def _validate_image_format(self) -> Self:
        """Require `image_format` whenever the payload is base64-encoded."""
        if self.data_type == ModalityDataType.BASE64 and self.image_format is None:
            raise ValueError(f"image_format is required when data_type is {self.data_type.value}")
        return self
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
# Parameter-model type for a Distribution subclass (must be a ConfigBase).
DistributionParamsT = TypeVar("DistributionParamsT", bound=ConfigBase)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class Distribution(ABC, ConfigBase, Generic[DistributionParamsT]):
    """Base class for sampling distributions, generic over their parameter model."""

    distribution_type: DistributionType
    params: DistributionParamsT

    # Draw a single value from this distribution.
    @abstractmethod
    def sample(self) -> float: ...
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class ManualDistributionParams(ConfigBase):
    """Parameters for a manual (discrete) distribution.

    Attributes:
        values: Candidate values to sample from (at least one required).
        weights: Optional sampling weights; normalized in-place to sum to 1.
    """

    values: List[float] = Field(min_length=1)
    weights: Optional[List[float]] = None

    @model_validator(mode="after")
    def _normalize_weights(self) -> Self:
        if self.weights is not None:
            # Hoisted: the original recomputed sum(self.weights) once per element
            # inside the comprehension (O(n^2) for n weights).
            total = sum(self.weights)
            self.weights = [w / total for w in self.weights]
        return self

    @model_validator(mode="after")
    def _validate_equal_lengths(self) -> Self:
        # Runs after normalization (pydantic applies validators in definition order).
        if self.weights and len(self.values) != len(self.weights):
            raise ValueError("`values` and `weights` must have the same length")
        return self
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class ManualDistribution(Distribution[ManualDistributionParams]):
    """Discrete distribution over explicit values with optional weights."""

    # NOTE(review): default is the raw string "manual"; pydantic coerces it to
    # DistributionType.MANUAL on validation.
    distribution_type: Optional[DistributionType] = "manual"
    params: ManualDistributionParams

    def sample(self) -> float:
        # weights=None makes np.random.choice sample uniformly over `values`.
        return float(np.random.choice(self.params.values, p=self.params.weights))
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class UniformDistributionParams(ConfigBase):
    """Bounds of a continuous uniform distribution.

    Attributes:
        low: Inclusive lower bound; must be strictly less than `high`.
        high: Upper bound.
    """

    low: float
    high: float

    @model_validator(mode="after")
    def _validate_low_lt_high(self) -> Self:
        """Reject inverted or degenerate (low == high) ranges."""
        if not self.low < self.high:
            raise ValueError("`low` must be less than `high`")
        return self
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class UniformDistribution(Distribution[UniformDistributionParams]):
    """Continuous uniform distribution over [low, high)."""

    # Use the enum member rather than the raw string "uniform": pydantic does
    # not validate default values, so a raw string would bypass enum coercion
    # and could fail the `== DistributionType.UNIFORM` comparison performed in
    # InferenceParameters._run_validation.
    distribution_type: Optional[DistributionType] = DistributionType.UNIFORM
    params: UniformDistributionParams

    def sample(self) -> float:
        """Draw a single float uniformly from [low, high)."""
        # np.random.uniform returns a scalar when `size` is omitted; the
        # original allocated a length-1 array and indexed into it.
        return float(np.random.uniform(low=self.params.low, high=self.params.high))
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
# Union of the concrete distributions accepted wherever a sampled parameter is allowed.
DistributionT: TypeAlias = Union[UniformDistribution, ManualDistribution]
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
class InferenceParameters(ConfigBase):
    """Sampling and request-control settings used when calling an LLM.

    `temperature` and `top_p` accept either a fixed float or a distribution
    (`DistributionT`); distribution-valued parameters are re-sampled on every
    access of `generate_kwargs`.
    """

    temperature: Optional[Union[float, DistributionT]] = None
    top_p: Optional[Union[float, DistributionT]] = None
    max_tokens: Optional[int] = Field(default=None, ge=1)
    # Client-side concurrency limit; intentionally never forwarded in generate_kwargs.
    max_parallel_requests: int = Field(default=4, ge=1)
    timeout: Optional[int] = Field(default=None, ge=1)
    # Extra payload presumably merged into the request body by the client — TODO confirm consumer.
    extra_body: Optional[dict[str, Any]] = None

    @property
    def generate_kwargs(self) -> dict[str, Any]:
        """Build the keyword arguments for a single generation call.

        Distribution-valued temperature/top_p are sampled anew on each access.
        Only parameters that are set are included; `extra_body` is included
        only when non-empty.
        """
        # NOTE: the return annotation was dict[str, Union[float, int]], which
        # became wrong once extra_body (a dict) was included — widened to Any.
        result: dict[str, Any] = {}
        if self.temperature is not None:
            result["temperature"] = self._sample_or_passthrough(self.temperature)
        if self.top_p is not None:
            result["top_p"] = self._sample_or_passthrough(self.top_p)
        if self.max_tokens is not None:
            result["max_tokens"] = self.max_tokens
        if self.timeout is not None:
            result["timeout"] = self.timeout
        if self.extra_body:  # treat None and {} the same: omit
            result["extra_body"] = self.extra_body
        return result

    @staticmethod
    def _sample_or_passthrough(value: Union[float, DistributionT]) -> float:
        """Resolve a parameter to a concrete float, sampling distributions."""
        return value.sample() if hasattr(value, "sample") else value

    @model_validator(mode="after")
    def _validate_temperature(self) -> Self:
        """Check temperature (scalar or distribution support) is within bounds."""
        return self._run_validation(
            value=self.temperature,
            param_name="temperature",
            min_value=MIN_TEMPERATURE,
            max_value=MAX_TEMPERATURE,
        )

    @model_validator(mode="after")
    def _validate_top_p(self) -> Self:
        """Check top_p (scalar or distribution support) is within bounds."""
        return self._run_validation(
            value=self.top_p,
            param_name="top_p",
            min_value=MIN_TOP_P,
            max_value=MAX_TOP_P,
        )

    def _run_validation(
        self,
        value: Union[float, DistributionT, None],
        param_name: str,
        min_value: float,
        max_value: float,
    ) -> Self:
        """Validate that a scalar, or a distribution's entire support, lies in
        [min_value, max_value]; no-op when value is None.

        Raises:
            ValueError: If the value (or any part of the distribution's
                support) falls outside the allowed range.
        """
        if value is None:
            return self
        value_err = ValueError(f"{param_name} defined in model config must be between {min_value} and {max_value}")
        if isinstance(value, Distribution):
            if value.distribution_type == DistributionType.UNIFORM:
                # The uniform support is bounded by [low, high].
                if value.params.low < min_value or value.params.high > max_value:
                    raise value_err
            elif value.distribution_type == DistributionType.MANUAL:
                # Every candidate value must individually be in range.
                if any(not self._is_value_in_range(v, min_value, max_value) for v in value.params.values):
                    raise value_err
        else:
            if not self._is_value_in_range(value, min_value, max_value):
                raise value_err
        return self

    def _is_value_in_range(self, value: float, min_value: float, max_value: float) -> bool:
        """Inclusive range membership check."""
        return min_value <= value <= max_value
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
class ModelConfig(ConfigBase):
    """Configuration for a single model, referenced elsewhere by its alias."""

    # Friendly name used to refer to this model in other parts of the config.
    alias: str
    # Provider-specific model identifier (e.g. a model name or deployment id).
    model: str
    # Generation-time parameters; defaults to an all-unset InferenceParameters.
    inference_parameters: InferenceParameters = Field(default_factory=InferenceParameters)
    # Provider name — presumably matches a ModelProvider.name; semantics of
    # None are resolved by the caller — TODO confirm.
    provider: Optional[str] = None
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
class ModelProvider(ConfigBase):
    """Connection settings for a model-serving endpoint."""

    # Provider name — presumably referenced by ModelConfig.provider.
    name: str
    # Base URL of the inference endpoint.
    endpoint: str
    # API flavor of the endpoint; defaults to the OpenAI-compatible protocol.
    provider_type: str = "openai"
    # Credential for the endpoint; None presumably means no auth — TODO confirm.
    api_key: Optional[str] = None
    # Extra request payload — NOTE(review): consumer not visible here; verify against client code.
    extra_body: Optional[dict[str, Any]] = None
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def load_model_configs(model_configs: Union[list[ModelConfig], str, Path]) -> list[ModelConfig]:
    """Load model configs from an in-memory list or a YAML file/string.

    Args:
        model_configs: Either a list of ModelConfig instances (returned as-is)
            or a path / YAML string whose top-level mapping contains a
            `model_configs` list.

    Returns:
        List of validated ModelConfig objects.

    Raises:
        InvalidConfigError: If a list contains non-ModelConfig entries, or the
            YAML document lacks a `model_configs` key.
    """
    if isinstance(model_configs, list):
        if all(isinstance(mc, ModelConfig) for mc in model_configs):
            return model_configs
        # Previously a list with non-ModelConfig entries fell through to
        # smart_load_yaml, which expects a path or YAML text — fail fast with
        # a clear message instead.
        raise InvalidConfigError("All entries in a model config list must be ModelConfig instances.")
    json_config = smart_load_yaml(model_configs)
    if "model_configs" not in json_config:
        raise InvalidConfigError(
            "The list of model configs must be provided under model_configs in the configuration file."
        )
    return [ModelConfig.model_validate(mc) for mc in json_config["model_configs"]]
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
import pandas as pd
|
|
9
|
+
|
|
10
|
+
from .analysis.dataset_profiler import DatasetProfilerResults
|
|
11
|
+
from .config_builder import DataDesignerConfigBuilder
|
|
12
|
+
from .utils.visualization import WithRecordSamplerMixin
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class PreviewResults(WithRecordSamplerMixin):
    """Container for the artifacts of a Data Designer preview run."""

    def __init__(
        self,
        *,
        config_builder: DataDesignerConfigBuilder,
        dataset: Optional[pd.DataFrame] = None,
        analysis: Optional[DatasetProfilerResults] = None,
    ):
        """Creates a new instance with results from a Data Designer preview run.

        Args:
            config_builder: Data Designer configuration builder.
            dataset: Dataset of the preview run.
            analysis: Analysis of the preview run.
        """
        self._config_builder = config_builder
        # Either artifact may be absent if the preview produced no output.
        self.dataset: Optional[pd.DataFrame] = dataset
        self.analysis: Optional[DatasetProfilerResults] = analysis
|