data-designer 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/__init__.py +15 -0
- data_designer/_version.py +34 -0
- data_designer/cli/README.md +236 -0
- data_designer/cli/__init__.py +6 -0
- data_designer/cli/commands/__init__.py +2 -0
- data_designer/cli/commands/list.py +130 -0
- data_designer/cli/commands/models.py +10 -0
- data_designer/cli/commands/providers.py +11 -0
- data_designer/cli/commands/reset.py +100 -0
- data_designer/cli/controllers/__init__.py +7 -0
- data_designer/cli/controllers/model_controller.py +246 -0
- data_designer/cli/controllers/provider_controller.py +317 -0
- data_designer/cli/forms/__init__.py +20 -0
- data_designer/cli/forms/builder.py +51 -0
- data_designer/cli/forms/field.py +180 -0
- data_designer/cli/forms/form.py +59 -0
- data_designer/cli/forms/model_builder.py +125 -0
- data_designer/cli/forms/provider_builder.py +76 -0
- data_designer/cli/main.py +44 -0
- data_designer/cli/repositories/__init__.py +8 -0
- data_designer/cli/repositories/base.py +39 -0
- data_designer/cli/repositories/model_repository.py +42 -0
- data_designer/cli/repositories/provider_repository.py +43 -0
- data_designer/cli/services/__init__.py +7 -0
- data_designer/cli/services/model_service.py +116 -0
- data_designer/cli/services/provider_service.py +111 -0
- data_designer/cli/ui.py +448 -0
- data_designer/cli/utils.py +47 -0
- data_designer/config/__init__.py +2 -0
- data_designer/config/analysis/column_profilers.py +89 -0
- data_designer/config/analysis/column_statistics.py +274 -0
- data_designer/config/analysis/dataset_profiler.py +60 -0
- data_designer/config/analysis/utils/errors.py +8 -0
- data_designer/config/analysis/utils/reporting.py +188 -0
- data_designer/config/base.py +68 -0
- data_designer/config/column_configs.py +354 -0
- data_designer/config/column_types.py +168 -0
- data_designer/config/config_builder.py +660 -0
- data_designer/config/data_designer_config.py +40 -0
- data_designer/config/dataset_builders.py +11 -0
- data_designer/config/datastore.py +151 -0
- data_designer/config/default_model_settings.py +123 -0
- data_designer/config/errors.py +19 -0
- data_designer/config/interface.py +54 -0
- data_designer/config/models.py +231 -0
- data_designer/config/preview_results.py +32 -0
- data_designer/config/processors.py +41 -0
- data_designer/config/sampler_constraints.py +51 -0
- data_designer/config/sampler_params.py +604 -0
- data_designer/config/seed.py +145 -0
- data_designer/config/utils/code_lang.py +83 -0
- data_designer/config/utils/constants.py +313 -0
- data_designer/config/utils/errors.py +19 -0
- data_designer/config/utils/info.py +88 -0
- data_designer/config/utils/io_helpers.py +273 -0
- data_designer/config/utils/misc.py +81 -0
- data_designer/config/utils/numerical_helpers.py +28 -0
- data_designer/config/utils/type_helpers.py +100 -0
- data_designer/config/utils/validation.py +336 -0
- data_designer/config/utils/visualization.py +427 -0
- data_designer/config/validator_params.py +96 -0
- data_designer/engine/__init__.py +2 -0
- data_designer/engine/analysis/column_profilers/base.py +55 -0
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +160 -0
- data_designer/engine/analysis/column_profilers/registry.py +20 -0
- data_designer/engine/analysis/column_statistics.py +142 -0
- data_designer/engine/analysis/dataset_profiler.py +125 -0
- data_designer/engine/analysis/errors.py +7 -0
- data_designer/engine/analysis/utils/column_statistics_calculations.py +209 -0
- data_designer/engine/analysis/utils/judge_score_processing.py +128 -0
- data_designer/engine/column_generators/__init__.py +2 -0
- data_designer/engine/column_generators/generators/__init__.py +2 -0
- data_designer/engine/column_generators/generators/base.py +61 -0
- data_designer/engine/column_generators/generators/expression.py +63 -0
- data_designer/engine/column_generators/generators/llm_generators.py +172 -0
- data_designer/engine/column_generators/generators/samplers.py +75 -0
- data_designer/engine/column_generators/generators/seed_dataset.py +149 -0
- data_designer/engine/column_generators/generators/validation.py +147 -0
- data_designer/engine/column_generators/registry.py +56 -0
- data_designer/engine/column_generators/utils/errors.py +13 -0
- data_designer/engine/column_generators/utils/judge_score_factory.py +57 -0
- data_designer/engine/column_generators/utils/prompt_renderer.py +98 -0
- data_designer/engine/configurable_task.py +82 -0
- data_designer/engine/dataset_builders/artifact_storage.py +181 -0
- data_designer/engine/dataset_builders/column_wise_builder.py +287 -0
- data_designer/engine/dataset_builders/errors.py +13 -0
- data_designer/engine/dataset_builders/multi_column_configs.py +44 -0
- data_designer/engine/dataset_builders/utils/__init__.py +2 -0
- data_designer/engine/dataset_builders/utils/concurrency.py +184 -0
- data_designer/engine/dataset_builders/utils/config_compiler.py +60 -0
- data_designer/engine/dataset_builders/utils/dag.py +56 -0
- data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +190 -0
- data_designer/engine/dataset_builders/utils/errors.py +13 -0
- data_designer/engine/errors.py +49 -0
- data_designer/engine/model_provider.py +75 -0
- data_designer/engine/models/__init__.py +2 -0
- data_designer/engine/models/errors.py +308 -0
- data_designer/engine/models/facade.py +225 -0
- data_designer/engine/models/litellm_overrides.py +162 -0
- data_designer/engine/models/parsers/__init__.py +2 -0
- data_designer/engine/models/parsers/errors.py +34 -0
- data_designer/engine/models/parsers/parser.py +236 -0
- data_designer/engine/models/parsers/postprocessors.py +93 -0
- data_designer/engine/models/parsers/tag_parsers.py +60 -0
- data_designer/engine/models/parsers/types.py +82 -0
- data_designer/engine/models/recipes/base.py +79 -0
- data_designer/engine/models/recipes/response_recipes.py +291 -0
- data_designer/engine/models/registry.py +118 -0
- data_designer/engine/models/usage.py +75 -0
- data_designer/engine/models/utils.py +38 -0
- data_designer/engine/processing/ginja/__init__.py +2 -0
- data_designer/engine/processing/ginja/ast.py +64 -0
- data_designer/engine/processing/ginja/environment.py +461 -0
- data_designer/engine/processing/ginja/exceptions.py +54 -0
- data_designer/engine/processing/ginja/record.py +30 -0
- data_designer/engine/processing/gsonschema/__init__.py +2 -0
- data_designer/engine/processing/gsonschema/exceptions.py +8 -0
- data_designer/engine/processing/gsonschema/schema_transformers.py +81 -0
- data_designer/engine/processing/gsonschema/types.py +8 -0
- data_designer/engine/processing/gsonschema/validators.py +143 -0
- data_designer/engine/processing/processors/base.py +15 -0
- data_designer/engine/processing/processors/drop_columns.py +46 -0
- data_designer/engine/processing/processors/registry.py +20 -0
- data_designer/engine/processing/utils.py +120 -0
- data_designer/engine/registry/base.py +97 -0
- data_designer/engine/registry/data_designer_registry.py +37 -0
- data_designer/engine/registry/errors.py +10 -0
- data_designer/engine/resources/managed_dataset_generator.py +35 -0
- data_designer/engine/resources/managed_dataset_repository.py +194 -0
- data_designer/engine/resources/managed_storage.py +63 -0
- data_designer/engine/resources/resource_provider.py +46 -0
- data_designer/engine/resources/seed_dataset_data_store.py +66 -0
- data_designer/engine/sampling_gen/column.py +89 -0
- data_designer/engine/sampling_gen/constraints.py +95 -0
- data_designer/engine/sampling_gen/data_sources/base.py +214 -0
- data_designer/engine/sampling_gen/data_sources/errors.py +10 -0
- data_designer/engine/sampling_gen/data_sources/sources.py +342 -0
- data_designer/engine/sampling_gen/entities/__init__.py +2 -0
- data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
- data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +64 -0
- data_designer/engine/sampling_gen/entities/email_address_utils.py +169 -0
- data_designer/engine/sampling_gen/entities/errors.py +8 -0
- data_designer/engine/sampling_gen/entities/national_id_utils.py +100 -0
- data_designer/engine/sampling_gen/entities/person.py +142 -0
- data_designer/engine/sampling_gen/entities/phone_number.py +122 -0
- data_designer/engine/sampling_gen/errors.py +24 -0
- data_designer/engine/sampling_gen/generator.py +121 -0
- data_designer/engine/sampling_gen/jinja_utils.py +60 -0
- data_designer/engine/sampling_gen/people_gen.py +203 -0
- data_designer/engine/sampling_gen/person_constants.py +54 -0
- data_designer/engine/sampling_gen/schema.py +143 -0
- data_designer/engine/sampling_gen/schema_builder.py +59 -0
- data_designer/engine/sampling_gen/utils.py +40 -0
- data_designer/engine/secret_resolver.py +80 -0
- data_designer/engine/validators/__init__.py +17 -0
- data_designer/engine/validators/base.py +36 -0
- data_designer/engine/validators/local_callable.py +34 -0
- data_designer/engine/validators/python.py +245 -0
- data_designer/engine/validators/remote.py +83 -0
- data_designer/engine/validators/sql.py +60 -0
- data_designer/errors.py +5 -0
- data_designer/essentials/__init__.py +137 -0
- data_designer/interface/__init__.py +2 -0
- data_designer/interface/data_designer.py +351 -0
- data_designer/interface/errors.py +16 -0
- data_designer/interface/results.py +55 -0
- data_designer/logging.py +161 -0
- data_designer/plugin_manager.py +83 -0
- data_designer/plugins/__init__.py +6 -0
- data_designer/plugins/errors.py +10 -0
- data_designer/plugins/plugin.py +69 -0
- data_designer/plugins/registry.py +86 -0
- data_designer-0.1.0.dist-info/METADATA +173 -0
- data_designer-0.1.0.dist-info/RECORD +177 -0
- data_designer-0.1.0.dist-info/WHEEL +4 -0
- data_designer-0.1.0.dist-info/entry_points.txt +2 -0
- data_designer-0.1.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from abc import ABC
|
|
5
|
+
from enum import Enum
|
|
6
|
+
from typing import Optional, Union
|
|
7
|
+
|
|
8
|
+
from pydantic import Field, field_validator, model_validator
|
|
9
|
+
from typing_extensions import Self
|
|
10
|
+
|
|
11
|
+
from .base import ConfigBase
|
|
12
|
+
from .datastore import DatastoreSettings
|
|
13
|
+
from .utils.io_helpers import (
|
|
14
|
+
VALID_DATASET_FILE_EXTENSIONS,
|
|
15
|
+
validate_dataset_file_path,
|
|
16
|
+
validate_path_contains_files_of_type,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class SamplingStrategy(str, Enum):
    """Order in which rows are drawn from a seed dataset.

    Members mix in ``str``, so each compares equal to its value
    (for example ``SamplingStrategy.SHUFFLE == "shuffle"``).
    """

    ORDERED = "ordered"  # read rows sequentially in their original order
    SHUFFLE = "shuffle"  # randomly shuffle rows before sampling
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class IndexRange(ConfigBase):
    """A contiguous span of row indices where both ``start`` and ``end`` are included."""

    start: int = Field(ge=0, description="The start index of the index range (inclusive)")
    end: int = Field(ge=0, description="The end index of the index range (inclusive)")

    @model_validator(mode="after")
    def _validate_index_range(self) -> Self:
        # An inverted span is rejected; start == end is allowed and selects one row.
        if self.end < self.start:
            raise ValueError("'start' index must be less than or equal to 'end' index")
        return self

    @property
    def size(self) -> int:
        """Number of rows in the span, counting both endpoints."""
        return 1 + (self.end - self.start)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class PartitionBlock(ConfigBase):
    """Selects one of ``num_partitions`` contiguous, zero-based blocks of a dataset.

    Every partition spans ``dataset_size // num_partitions`` rows, except the last
    one, which additionally absorbs any remainder rows.
    """

    index: int = Field(default=0, ge=0, description="The index of the partition to sample from")
    num_partitions: int = Field(default=1, ge=1, description="The total number of partitions in the dataset")

    @model_validator(mode="after")
    def _validate_partition_block(self) -> Self:
        if self.index >= self.num_partitions:
            raise ValueError("'index' must be less than 'num_partitions'")
        return self

    def to_index_range(self, dataset_size: int) -> IndexRange:
        """Convert this partition into an inclusive IndexRange over ``dataset_size`` rows.

        Args:
            dataset_size: Total number of rows in the dataset being partitioned.

        Returns:
            The IndexRange covering this partition's rows.

        Raises:
            ValueError: If ``dataset_size`` is smaller than ``num_partitions`` (including an
                empty dataset), since at least one partition would then contain no rows.
        """
        # Without this guard, partition_size would be 0: every non-last partition would
        # build an invalid range (end=-1) and the last one would silently cover every row.
        if dataset_size < self.num_partitions:
            raise ValueError(
                f"Cannot split a dataset of {dataset_size} row(s) into {self.num_partitions} "
                "non-empty partitions; 'num_partitions' must not exceed the dataset size."
            )

        partition_size = dataset_size // self.num_partitions
        start = self.index * partition_size

        # For the last partition, extend to the end of the dataset to include remainder rows
        if self.index == self.num_partitions - 1:
            end = dataset_size - 1
        else:
            end = ((self.index + 1) * partition_size) - 1
        return IndexRange(start=start, end=end)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class SeedConfig(ConfigBase):
    """Settings that control how rows are drawn from a seed dataset.

    Attributes:
        dataset: Path or identifier of the seed dataset.
        sampling_strategy: Order in which rows are read. Defaults to ``ORDERED``.
            - ``SamplingStrategy.ORDERED``: rows are read sequentially in their original order.
            - ``SamplingStrategy.SHUFFLE``: rows are randomly shuffled before sampling. When a
              ``selection_strategy`` is also given, shuffling happens only inside the selected
              range/partition.
        selection_strategy: Optional restriction to a subset of the dataset. Defaults to
            ``None`` (no restriction).
            - ``IndexRange``: an inclusive range of row indices (e.g. rows 100-199).
            - ``PartitionBlock``: one of N contiguous partitions, indexed from zero
              (index=0 is the first partition). Each partition holds ``dataset_size // N``
              rows and the last partition also takes any remainder rows.

    Examples:
        Read rows sequentially from start to end:
            SeedConfig(dataset="my_data.parquet", sampling_strategy=SamplingStrategy.ORDERED)

        Read rows in random order:
            SeedConfig(dataset="my_data.parquet", sampling_strategy=SamplingStrategy.SHUFFLE)

        Read a specific index range (rows 100-199):
            SeedConfig(
                dataset="my_data.parquet",
                sampling_strategy=SamplingStrategy.ORDERED,
                selection_strategy=IndexRange(start=100, end=199),
            )

        Read random rows from a specific index range (shuffles within rows 100-199):
            SeedConfig(
                dataset="my_data.parquet",
                sampling_strategy=SamplingStrategy.SHUFFLE,
                selection_strategy=IndexRange(start=100, end=199),
            )

        Read partition 2 (the 3rd partition, zero-based) of 5, about 20% of the dataset:
            SeedConfig(
                dataset="my_data.parquet",
                sampling_strategy=SamplingStrategy.ORDERED,
                selection_strategy=PartitionBlock(index=2, num_partitions=5),
            )

        Read shuffled rows from partition 0 of 10 (shuffles within that partition):
            SeedConfig(
                dataset="my_data.parquet",
                sampling_strategy=SamplingStrategy.SHUFFLE,
                selection_strategy=PartitionBlock(index=0, num_partitions=10),
            )
    """

    dataset: str  # path or identifier of the seed dataset
    sampling_strategy: SamplingStrategy = SamplingStrategy.ORDERED  # read order
    selection_strategy: Optional[Union[IndexRange, PartitionBlock]] = None  # None = whole dataset
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
class SeedDatasetReference(ABC, ConfigBase):
    """Abstract base for references that locate a seed dataset.

    Subclasses decide how ``dataset`` is interpreted, for example as a datastore
    location or as a local filesystem path.
    """

    dataset: str
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class DatastoreSeedDatasetReference(SeedDatasetReference):
    """Seed dataset held in a datastore and addressed as ``"<repo_id>/<filename>"``."""

    datastore_settings: DatastoreSettings

    @property
    def repo_id(self) -> str:
        """Portion of ``dataset`` before its last ``/``; an empty string when there is no ``/``."""
        head, _sep, _tail = self.dataset.rpartition("/")
        return head

    @property
    def filename(self) -> str:
        """Portion of ``dataset`` after its last ``/``; the whole string when there is no ``/``."""
        _head, _sep, tail = self.dataset.rpartition("/")
        return tail
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class LocalSeedDatasetReference(SeedDatasetReference):
    """Seed dataset read from the local filesystem.

    ``dataset`` is either the path to a single dataset file, or a directory path
    followed by a wildcard pattern ``*<ext>`` for one of the extensions in
    ``VALID_DATASET_FILE_EXTENSIONS``.
    """

    @field_validator("dataset", mode="after")
    def validate_dataset_is_file(cls, v: str) -> str:
        """Validate that ``v`` names a dataset file, or a directory holding files of a given type.

        Args:
            v: The ``dataset`` value being validated.

        Returns:
            ``v`` unchanged when validation succeeds.

        Raises:
            Propagates whatever ``validate_path_contains_files_of_type`` or
            ``validate_dataset_file_path`` raise for an invalid path.
        """
        wildcard_suffixes = tuple(f"*{ext}" for ext in VALID_DATASET_FILE_EXTENSIONS)
        if v.endswith(wildcard_suffixes):
            # Split only on the LAST "*." so a "*." appearing earlier in the directory
            # path is not mistaken for the wildcard. With zero or one occurrence this
            # matches a plain split.
            parts = v.rsplit("*.", 1)
            file_path = parts[0]
            file_extension = parts[-1]
            validate_path_contains_files_of_type(file_path, file_extension)
        else:
            validate_dataset_file_path(v)
        return v
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from enum import Enum
|
|
7
|
+
from typing import Union
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class CodeLang(str, Enum):
    """Programming languages known to Data Designer.

    SQL members carry a dialect after a colon (``"sql:<dialect>"``); every other
    member's value is just a language name.
    """

    GO = "go"
    JAVASCRIPT = "javascript"
    JAVA = "java"
    KOTLIN = "kotlin"
    PYTHON = "python"
    RUBY = "ruby"
    RUST = "rust"
    SCALA = "scala"
    SWIFT = "swift"
    TYPESCRIPT = "typescript"
    SQL_SQLITE = "sql:sqlite"
    SQL_TSQL = "sql:tsql"
    SQL_BIGQUERY = "sql:bigquery"
    SQL_MYSQL = "sql:mysql"
    SQL_POSTGRES = "sql:postgres"
    SQL_ANSI = "sql:ansi"

    @staticmethod
    def parse(value: Union[str, CodeLang]) -> tuple[str, Union[str, None]]:
        """Split a language spec into ``(language, dialect)``.

        The dialect is the second ``:``-separated field, or None when there is no ``:``.
        """
        raw = value.value if isinstance(value, CodeLang) else value
        language, *extra_fields = raw.split(":")
        dialect = extra_fields[0] if extra_fields else None
        return (language, dialect)

    @staticmethod
    def parse_lang(value: Union[str, CodeLang]) -> str:
        """Return only the language part of ``value``."""
        language, _dialect = CodeLang.parse(value)
        return language

    @staticmethod
    def parse_dialect(value: Union[str, CodeLang]) -> Union[str, None]:
        """Return only the dialect part of ``value`` (None if absent)."""
        _language, dialect = CodeLang.parse(value)
        return dialect

    @staticmethod
    def supported_values() -> set[str]:
        """Return the string value of every member."""
        return {member.value for member in CodeLang}
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
# Every CodeLang member whose value has the form "sql:<dialect>"
# (SQL_SQLITE, SQL_TSQL, SQL_BIGQUERY, SQL_MYSQL, SQL_POSTGRES, SQL_ANSI).
SQL_DIALECTS: set[CodeLang] = {lang for lang in CodeLang if lang.value.startswith("sql:")}
|
|
55
|
+
|
|
56
|
+
##########################################################
|
|
57
|
+
# Helper functions
|
|
58
|
+
##########################################################
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def code_lang_to_syntax_lexer(code_lang: Union[CodeLang, str]) -> str:
    """Convert the code language to a syntax lexer alias for Pygments.

    Args:
        code_lang: A CodeLang member or its string value (e.g. ``"sql:postgres"``).

    Returns:
        The Pygments lexer alias for the language. A value with no entry in the
        mapping is returned unchanged, presumably so callers can pass an alias
        Pygments already understands.

    Reference: https://pygments.org/docs/lexers/
    """
    code_lang_to_lexer = {
        CodeLang.GO: "golang",
        CodeLang.JAVASCRIPT: "javascript",
        CodeLang.JAVA: "java",
        CodeLang.KOTLIN: "kotlin",
        CodeLang.PYTHON: "python",
        CodeLang.RUBY: "ruby",
        CodeLang.RUST: "rust",
        CodeLang.SCALA: "scala",
        CodeLang.SWIFT: "swift",
        # Previously missing: TYPESCRIPT fell through and the enum member itself was returned.
        CodeLang.TYPESCRIPT: "typescript",
        CodeLang.SQL_SQLITE: "sql",
        CodeLang.SQL_ANSI: "sql",
        CodeLang.SQL_TSQL: "tsql",
        CodeLang.SQL_BIGQUERY: "sql",
        CodeLang.SQL_MYSQL: "mysql",
        CodeLang.SQL_POSTGRES: "postgres",
    }
    # CodeLang mixes in str, so plain-string inputs such as "python" hit the same entries as members.
    return code_lang_to_lexer.get(code_lang, code_lang)
|
|
@@ -0,0 +1,313 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from enum import Enum
|
|
5
|
+
import os
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
from rich.theme import Theme
|
|
9
|
+
|
|
10
|
+
# Default number of records, presumably used when a caller does not request a specific count.
DEFAULT_NUM_RECORDS = 10

# Small tolerance constant; presumably guards float comparisons/divisions — confirm at call sites.
EPSILON = 1e-8
# Decimal places presumably used when rounding reported statistics.
REPORTING_PRECISION = 2
|
|
14
|
+
|
|
15
|
+
# Default style name for HTML reprs ("nord", matching the NordColor palette below).
DEFAULT_REPR_HTML_STYLE = "nord"

# Maximum width, in pixels, of the rendered ".code" container.
REPR_HTML_FIXED_WIDTH = 1000
# HTML wrapper for highlighted code. The template is formatted twice:
#   1. Here, .format(fixed_width=...) fills in the max-width and un-doubles the escapes:
#      "{{css}}" -> "{css}", "{{highlighted_html}}" -> "{highlighted_html}", "{{{{" -> "{{".
#   2. The result therefore still contains "{css}" and "{highlighted_html}" placeholders
#      (and "{{"/"}}" escapes for the CSS braces), presumably filled by a later .format() call.
# Do not "simplify" the brace escaping without accounting for both passes.
REPR_HTML_TEMPLATE = """
<meta charset="UTF-8">
<style>
{{css}}

.code {{{{
padding: 4px;
border: 1px solid grey;
border-radius: 4px;
max-width: {fixed_width}px;
width: 100%;
display: inline-block;
box-sizing: border-box;
text-align: left;
vertical-align: top;
line-height: normal;
overflow-x: auto;
}}}}

.code pre {{{{
white-space: pre-wrap; /* CSS 3 */
white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
white-space: -pre-wrap; /* Opera 4-6 */
white-space: -o-pre-wrap; /* Opera 7 */
word-wrap: break-word;
overflow-wrap: break-word;
margin: 0;
}}}}
</style>
{{highlighted_html}}
""".format(fixed_width=REPR_HTML_FIXED_WIDTH)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class NordColor(Enum):
    """Hex color codes of the Nord palette (NORD0-NORD15), used by RICH_CONSOLE_THEME."""

    NORD0 = "#2E3440"  # Darkest gray (background)
    NORD1 = "#3B4252"  # Dark gray
    NORD2 = "#434C5E"  # Medium dark gray
    NORD3 = "#4C566A"  # Lighter dark gray
    NORD4 = "#D8DEE9"  # Light gray (default text)
    NORD5 = "#E5E9F0"  # Very light gray
    NORD6 = "#ECEFF4"  # Almost white
    NORD7 = "#8FBCBB"  # Teal
    NORD8 = "#88C0D0"  # Light cyan
    NORD9 = "#81A1C1"  # Soft blue
    NORD10 = "#5E81AC"  # Darker blue
    NORD11 = "#BF616A"  # Red
    NORD12 = "#D08770"  # Orange
    NORD13 = "#EBCB8B"  # Yellow
    NORD14 = "#A3BE8C"  # Green
    NORD15 = "#B48EAD"  # Purple
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
# Rich theme that maps rich's "repr.*" highlighter styles onto Nord palette colors.
RICH_CONSOLE_THEME = Theme(
    {
        "repr.number": NordColor.NORD15.value,  # Purple for numbers
        "repr.string": NordColor.NORD14.value,  # Green for strings
        "repr.bool_true": NordColor.NORD9.value,  # Blue for True
        "repr.bool_false": NordColor.NORD9.value,  # Blue for False
        "repr.none": NordColor.NORD11.value,  # Red for None
        "repr.brace": NordColor.NORD7.value,  # Teal for brackets/braces
        "repr.comma": NordColor.NORD7.value,  # Teal for commas
        "repr.ellipsis": NordColor.NORD7.value,  # Teal for ellipsis
        "repr.attrib_name": NordColor.NORD3.value,  # Lighter dark gray (NORD3) for attribute names
        "repr.attrib_equal": NordColor.NORD7.value,  # Teal for equals signs
        "repr.call": NordColor.NORD10.value,  # Darker blue for function calls
        "repr.function_name": NordColor.NORD10.value,  # Darker blue for function names
        "repr.class_name": NordColor.NORD12.value,  # Orange for class names
        "repr.module_name": NordColor.NORD8.value,  # Light cyan for module names
        "repr.error": NordColor.NORD11.value,  # Red for errors
        "repr.warning": NordColor.NORD13.value,  # Yellow for warnings
    }
)
|
|
90
|
+
|
|
91
|
+
# Default colors (rich color names) presumably used for histogram labels and bars.
DEFAULT_HIST_NAME_COLOR = "medium_purple1"

DEFAULT_HIST_VALUE_COLOR = "pale_green3"


# Age bounds, presumably in years, for sampled persons. DEFAULT_AGE_RANGE is [low, high];
# its upper bound equals MAX_AGE.
DEFAULT_AGE_RANGE = [18, 114]
MIN_AGE = 0
MAX_AGE = 114

# Locales that presumably have managed (downloadable) person datasets — confirm with the
# managed dataset repository.
LOCALES_WITH_MANAGED_DATASETS = ["en_US", "ja_JP", "en_IN", "hi_IN"]
|
|
101
|
+
|
|
102
|
+
# Two-letter postal codes: the 50 states, the District of Columbia, and five inhabited territories.
US_STATES_AND_MAJOR_TERRITORIES = {
    # States
    "AK",
    "AL",
    "AR",
    "AZ",
    "CA",
    "CO",
    "CT",
    "DE",
    "FL",
    "GA",
    "HI",
    "IA",
    "ID",
    "IL",
    "IN",
    "KS",
    "KY",
    "LA",
    "MA",
    "MD",
    "ME",
    "MI",
    "MN",
    "MO",
    "MS",
    "MT",
    "NC",
    "ND",
    "NE",
    "NH",
    "NJ",
    "NM",
    "NV",
    "NY",
    "OH",
    "OK",
    "OR",
    "PA",
    "RI",
    "SC",
    "SD",
    "TN",
    "TX",
    "UT",
    "VA",
    "VT",
    "WA",
    "WI",
    "WV",
    "WY",
    # D.C.
    "DC",
    # Territories (American Samoa, Guam, Northern Mariana Islands, Puerto Rico, U.S. Virgin Islands)
    "AS",
    "GU",
    "MP",
    "PR",
    "VI",
}
|
|
163
|
+
|
|
164
|
+
# Bounds for LLM inference parameters, presumably used to validate model configs.
MAX_TEMPERATURE = 2.0
MIN_TEMPERATURE = 0.0
MAX_TOP_P = 1.0
MIN_TOP_P = 0.0
MIN_MAX_TOKENS = 1
# Suffix presumably appended to a column's name to name the column holding its reasoning trace.
REASONING_TRACE_COLUMN_POSTFIX = "__reasoning_trace"
|
|
170
|
+
|
|
171
|
+
# Supported locale codes. The list (including non-ISO entries like "dk_DK" and the
# deprecated "fr_QC" note) appears to mirror Faker's available locales — confirm
# before adding or removing entries.
AVAILABLE_LOCALES = [
    "ar_AA",
    "ar_AE",
    "ar_BH",
    "ar_EG",
    "ar_JO",
    "ar_PS",
    "ar_SA",
    "az_AZ",
    "bg_BG",
    "bn_BD",
    "bs_BA",
    "cs_CZ",
    "da_DK",
    "de",
    "de_AT",
    "de_CH",
    "de_DE",
    "dk_DK",  # NOTE(review): non-ISO Danish code kept alongside "da_DK"; presumably for compatibility
    "el_CY",
    "el_GR",
    "en",
    "en_AU",
    "en_BD",
    "en_CA",
    "en_GB",
    "en_IE",
    "en_IN",
    "en_NZ",
    "en_PH",
    "en_TH",
    "en_US",
    "es",
    "es_AR",
    "es_CA",
    "es_CL",
    "es_CO",
    "es_ES",
    "es_MX",
    "et_EE",
    "fa_IR",
    "fi_FI",
    "fil_PH",
    "fr_BE",
    "fr_CA",
    "fr_CH",
    "fr_FR",
    # "fr_QC" is intentionally omitted: it is deprecated, use "fr_CA" instead.
    "ga_IE",
    "he_IL",
    "hi_IN",
    "hr_HR",
    "hu_HU",
    "hy_AM",
    "id_ID",
    "it_CH",
    "it_IT",
    "ja_JP",
    "ka_GE",
    "ko_KR",
    "la",
    "lb_LU",
    "lt_LT",
    "lv_LV",
    "mt_MT",
    "ne_NP",
    "nl_BE",
    "nl_NL",
    "no_NO",
    "or_IN",
    "pl_PL",
    "pt_BR",
    "pt_PT",
    "ro_RO",
    "ru_RU",
    "sk_SK",
    "sl_SI",
    "sq_AL",
    "sv_SE",
    "ta_IN",
    "th",
    "th_TH",
    "tl_PH",
    "tr_TR",
    "tw_GH",
    "uk_UA",
    "vi_VN",
    "zh_CN",
    "zh_TW",
    "zu_ZA",
]
|
|
262
|
+
|
|
263
|
+
# Environment variable that overrides the Data Designer home directory.
DATA_DESIGNER_HOME_ENV_VAR = "DATA_DESIGNER_HOME"

# Resolved once at import time: the env var's value if set, otherwise ~/.data-designer.
# NOTE(review): a literal "~" inside the env var value is not expanded (no expanduser()).
DATA_DESIGNER_HOME = Path(os.getenv(DATA_DESIGNER_HOME_ENV_VAR, Path.home() / ".data-designer"))

# Environment variable that overrides the managed-assets directory.
MANAGED_ASSETS_PATH_ENV_VAR = "DATA_DESIGNER_MANAGED_ASSETS_PATH"

# Defaults to <DATA_DESIGNER_HOME>/managed-assets when the env var is unset (also resolved at import time).
MANAGED_ASSETS_PATH = Path(os.getenv(MANAGED_ASSETS_PATH_ENV_VAR, DATA_DESIGNER_HOME / "managed-assets"))

MODEL_CONFIGS_FILE_NAME = "model_configs.yaml"

# Model configs YAML file, located directly inside DATA_DESIGNER_HOME.
MODEL_CONFIGS_FILE_PATH = DATA_DESIGNER_HOME / MODEL_CONFIGS_FILE_NAME

MODEL_PROVIDERS_FILE_NAME = "model_providers.yaml"

# Model providers YAML file, located directly inside DATA_DESIGNER_HOME.
MODEL_PROVIDERS_FILE_PATH = DATA_DESIGNER_HOME / MODEL_PROVIDERS_FILE_NAME

# Names of the built-in providers, and the environment variables that hold their API keys.
NVIDIA_PROVIDER_NAME = "nvidia"

NVIDIA_API_KEY_ENV_VAR_NAME = "NVIDIA_API_KEY"

OPENAI_PROVIDER_NAME = "openai"

OPENAI_API_KEY_ENV_VAR_NAME = "OPENAI_API_KEY"
|
|
286
|
+
|
|
287
|
+
# Built-in model provider definitions. Both use the "openai" provider_type (OpenAI-compatible API).
# "api_key" holds the NAME of an environment variable, not the secret itself; it is presumably
# resolved to the real key at runtime — confirm with the secret resolver.
PREDEFINED_PROVIDERS = [
    {
        "name": NVIDIA_PROVIDER_NAME,
        "endpoint": "https://integrate.api.nvidia.com/v1",
        "provider_type": "openai",
        "api_key": NVIDIA_API_KEY_ENV_VAR_NAME,
    },
    {
        "name": OPENAI_PROVIDER_NAME,
        "endpoint": "https://api.openai.com/v1",
        "provider_type": "openai",
        "api_key": OPENAI_API_KEY_ENV_VAR_NAME,
    },
]

# Default model id per built-in provider, keyed by role ("text", "reasoning", "vision").
PREDEFINED_PROVIDERS_MODEL_MAP = {
    NVIDIA_PROVIDER_NAME: {
        "text": "nvidia/nvidia-nemotron-nano-9b-v2",
        "reasoning": "openai/gpt-oss-20b",
        "vision": "nvidia/nemotron-nano-12b-v2-vl",
    },
    OPENAI_PROVIDER_NAME: {
        "text": "gpt-4.1",
        "reasoning": "gpt-5",
        "vision": "gpt-5",
    },
}
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from ...errors import DataDesignerError
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class UserJinjaTemplateSyntaxError(DataDesignerError):
    """Error category for syntax errors in user-supplied Jinja templates."""
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class InvalidEnumValueError(DataDesignerError):
    """Error category for values that do not match any member of the expected enum."""
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class InvalidTypeUnionError(DataDesignerError):
    """Error category for type unions that are not valid in the given context."""
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class InvalidDiscriminatorFieldError(DataDesignerError):
    """Error category for a missing or invalid discriminator field on a tagged union."""
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class DatasetSampleDisplayError(DataDesignerError):
    """Error category for failures while displaying a dataset sample."""
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from enum import Enum
|
|
6
|
+
from typing import Literal, TypeVar
|
|
7
|
+
|
|
8
|
+
from ..models import ModelConfig, ModelProvider
|
|
9
|
+
from ..sampler_params import SamplerType
|
|
10
|
+
from .type_helpers import get_sampler_params
|
|
11
|
+
from .visualization import display_model_configs_table, display_model_providers_table, display_sampler_table
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class InfoType(str, Enum):
    """Kinds of information an InfoDisplay implementation can render."""

    SAMPLERS = "samplers"  # sampler types and their parameters
    MODEL_CONFIGS = "model_configs"  # configured models
    MODEL_PROVIDERS = "model_providers"  # configured model providers
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# Info types accepted by ConfigBuilderInfo.display.
ConfigBuilderInfoType = Literal[InfoType.SAMPLERS, InfoType.MODEL_CONFIGS]
# Info types accepted by InterfaceInfo.display.
DataDesignerInfoType = Literal[InfoType.MODEL_PROVIDERS]
# Type variable for the abstract InfoDisplay.display signature; subclasses narrow it.
InfoTypeT = TypeVar("InfoTypeT", bound=InfoType)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class InfoDisplay(ABC):
    """Abstract interface for objects that render informational tables.

    Concrete subclasses accept only the subset of InfoType they support and
    raise ValueError for anything else.
    """

    @abstractmethod
    def display(self, info_type: InfoTypeT, **kwargs) -> None:
        """Render the information selected by ``info_type``.

        Args:
            info_type: Which kind of information to render.
            **kwargs: Subclass-specific display options.
        """
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class ConfigBuilderInfo(InfoDisplay):
    """Renders sampler and model-config tables for a config builder.

    Args:
        model_configs: Model configurations shown for ``InfoType.MODEL_CONFIGS``.
    """

    def __init__(self, model_configs: list[ModelConfig]):
        self._model_configs = model_configs
        self._sampler_params = get_sampler_params()

    def display(self, info_type: ConfigBuilderInfoType, **kwargs) -> None:
        """Render the table selected by ``info_type``.

        Args:
            info_type: ``InfoType.SAMPLERS`` or ``InfoType.MODEL_CONFIGS``.
            **kwargs: For SAMPLERS, an optional ``sampler_type`` restricts output to one sampler.

        Raises:
            ValueError: If ``info_type`` is neither SAMPLERS nor MODEL_CONFIGS.
        """
        if info_type == InfoType.SAMPLERS:
            self._display_sampler_info(sampler_type=kwargs.get("sampler_type"))
            return
        if info_type == InfoType.MODEL_CONFIGS:
            display_model_configs_table(self._model_configs)
            return
        raise ValueError(
            f"Unsupported info_type: {str(info_type)!r}. "
            f"ConfigBuilderInfo only supports {InfoType.SAMPLERS.value!r} and {InfoType.MODEL_CONFIGS.value!r}."
        )

    def _display_sampler_info(self, sampler_type: SamplerType | None) -> None:
        # No sampler given: show every sampler's parameters in one table.
        if sampler_type is None:
            display_sampler_table(self._sampler_params)
            return
        # SamplerType(...) validates the value before the title is built, e.g. "my_type" -> "My Type Sampler".
        readable_name = SamplerType(sampler_type).value.replace("_", " ").title()
        display_sampler_table(
            {sampler_type: self._sampler_params[sampler_type]},
            title=f"{readable_name} Sampler",
        )
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class InterfaceInfo(InfoDisplay):
    """Renders the model-provider table for the Data Designer interface.

    Args:
        model_providers: Providers shown for ``InfoType.MODEL_PROVIDERS``.
    """

    def __init__(self, model_providers: list[ModelProvider]):
        self._model_providers = model_providers

    def display(self, info_type: DataDesignerInfoType, **kwargs) -> None:
        """Render the table selected by ``info_type``.

        Args:
            info_type: Must be ``InfoType.MODEL_PROVIDERS``.
            **kwargs: Accepted for interface compatibility; unused.

        Raises:
            ValueError: If ``info_type`` is anything other than MODEL_PROVIDERS.
        """
        if info_type != InfoType.MODEL_PROVIDERS:
            raise ValueError(
                f"Unsupported info_type: {str(info_type)!r}. InterfaceInfo only supports {InfoType.MODEL_PROVIDERS.value!r}."
            )
        display_model_providers_table(self._model_providers)
|