data-designer 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/__init__.py +15 -0
- data_designer/_version.py +34 -0
- data_designer/cli/README.md +236 -0
- data_designer/cli/__init__.py +6 -0
- data_designer/cli/commands/__init__.py +2 -0
- data_designer/cli/commands/list.py +130 -0
- data_designer/cli/commands/models.py +10 -0
- data_designer/cli/commands/providers.py +11 -0
- data_designer/cli/commands/reset.py +100 -0
- data_designer/cli/controllers/__init__.py +7 -0
- data_designer/cli/controllers/model_controller.py +246 -0
- data_designer/cli/controllers/provider_controller.py +317 -0
- data_designer/cli/forms/__init__.py +20 -0
- data_designer/cli/forms/builder.py +51 -0
- data_designer/cli/forms/field.py +180 -0
- data_designer/cli/forms/form.py +59 -0
- data_designer/cli/forms/model_builder.py +125 -0
- data_designer/cli/forms/provider_builder.py +76 -0
- data_designer/cli/main.py +44 -0
- data_designer/cli/repositories/__init__.py +8 -0
- data_designer/cli/repositories/base.py +39 -0
- data_designer/cli/repositories/model_repository.py +42 -0
- data_designer/cli/repositories/provider_repository.py +43 -0
- data_designer/cli/services/__init__.py +7 -0
- data_designer/cli/services/model_service.py +116 -0
- data_designer/cli/services/provider_service.py +111 -0
- data_designer/cli/ui.py +448 -0
- data_designer/cli/utils.py +47 -0
- data_designer/config/__init__.py +2 -0
- data_designer/config/analysis/column_profilers.py +89 -0
- data_designer/config/analysis/column_statistics.py +274 -0
- data_designer/config/analysis/dataset_profiler.py +60 -0
- data_designer/config/analysis/utils/errors.py +8 -0
- data_designer/config/analysis/utils/reporting.py +188 -0
- data_designer/config/base.py +68 -0
- data_designer/config/column_configs.py +354 -0
- data_designer/config/column_types.py +168 -0
- data_designer/config/config_builder.py +660 -0
- data_designer/config/data_designer_config.py +40 -0
- data_designer/config/dataset_builders.py +11 -0
- data_designer/config/datastore.py +151 -0
- data_designer/config/default_model_settings.py +123 -0
- data_designer/config/errors.py +19 -0
- data_designer/config/interface.py +54 -0
- data_designer/config/models.py +231 -0
- data_designer/config/preview_results.py +32 -0
- data_designer/config/processors.py +41 -0
- data_designer/config/sampler_constraints.py +51 -0
- data_designer/config/sampler_params.py +604 -0
- data_designer/config/seed.py +145 -0
- data_designer/config/utils/code_lang.py +83 -0
- data_designer/config/utils/constants.py +313 -0
- data_designer/config/utils/errors.py +19 -0
- data_designer/config/utils/info.py +88 -0
- data_designer/config/utils/io_helpers.py +273 -0
- data_designer/config/utils/misc.py +81 -0
- data_designer/config/utils/numerical_helpers.py +28 -0
- data_designer/config/utils/type_helpers.py +100 -0
- data_designer/config/utils/validation.py +336 -0
- data_designer/config/utils/visualization.py +427 -0
- data_designer/config/validator_params.py +96 -0
- data_designer/engine/__init__.py +2 -0
- data_designer/engine/analysis/column_profilers/base.py +55 -0
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +160 -0
- data_designer/engine/analysis/column_profilers/registry.py +20 -0
- data_designer/engine/analysis/column_statistics.py +142 -0
- data_designer/engine/analysis/dataset_profiler.py +125 -0
- data_designer/engine/analysis/errors.py +7 -0
- data_designer/engine/analysis/utils/column_statistics_calculations.py +209 -0
- data_designer/engine/analysis/utils/judge_score_processing.py +128 -0
- data_designer/engine/column_generators/__init__.py +2 -0
- data_designer/engine/column_generators/generators/__init__.py +2 -0
- data_designer/engine/column_generators/generators/base.py +61 -0
- data_designer/engine/column_generators/generators/expression.py +63 -0
- data_designer/engine/column_generators/generators/llm_generators.py +172 -0
- data_designer/engine/column_generators/generators/samplers.py +75 -0
- data_designer/engine/column_generators/generators/seed_dataset.py +149 -0
- data_designer/engine/column_generators/generators/validation.py +147 -0
- data_designer/engine/column_generators/registry.py +56 -0
- data_designer/engine/column_generators/utils/errors.py +13 -0
- data_designer/engine/column_generators/utils/judge_score_factory.py +57 -0
- data_designer/engine/column_generators/utils/prompt_renderer.py +98 -0
- data_designer/engine/configurable_task.py +82 -0
- data_designer/engine/dataset_builders/artifact_storage.py +181 -0
- data_designer/engine/dataset_builders/column_wise_builder.py +287 -0
- data_designer/engine/dataset_builders/errors.py +13 -0
- data_designer/engine/dataset_builders/multi_column_configs.py +44 -0
- data_designer/engine/dataset_builders/utils/__init__.py +2 -0
- data_designer/engine/dataset_builders/utils/concurrency.py +184 -0
- data_designer/engine/dataset_builders/utils/config_compiler.py +60 -0
- data_designer/engine/dataset_builders/utils/dag.py +56 -0
- data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +190 -0
- data_designer/engine/dataset_builders/utils/errors.py +13 -0
- data_designer/engine/errors.py +49 -0
- data_designer/engine/model_provider.py +75 -0
- data_designer/engine/models/__init__.py +2 -0
- data_designer/engine/models/errors.py +308 -0
- data_designer/engine/models/facade.py +225 -0
- data_designer/engine/models/litellm_overrides.py +162 -0
- data_designer/engine/models/parsers/__init__.py +2 -0
- data_designer/engine/models/parsers/errors.py +34 -0
- data_designer/engine/models/parsers/parser.py +236 -0
- data_designer/engine/models/parsers/postprocessors.py +93 -0
- data_designer/engine/models/parsers/tag_parsers.py +60 -0
- data_designer/engine/models/parsers/types.py +82 -0
- data_designer/engine/models/recipes/base.py +79 -0
- data_designer/engine/models/recipes/response_recipes.py +291 -0
- data_designer/engine/models/registry.py +118 -0
- data_designer/engine/models/usage.py +75 -0
- data_designer/engine/models/utils.py +38 -0
- data_designer/engine/processing/ginja/__init__.py +2 -0
- data_designer/engine/processing/ginja/ast.py +64 -0
- data_designer/engine/processing/ginja/environment.py +461 -0
- data_designer/engine/processing/ginja/exceptions.py +54 -0
- data_designer/engine/processing/ginja/record.py +30 -0
- data_designer/engine/processing/gsonschema/__init__.py +2 -0
- data_designer/engine/processing/gsonschema/exceptions.py +8 -0
- data_designer/engine/processing/gsonschema/schema_transformers.py +81 -0
- data_designer/engine/processing/gsonschema/types.py +8 -0
- data_designer/engine/processing/gsonschema/validators.py +143 -0
- data_designer/engine/processing/processors/base.py +15 -0
- data_designer/engine/processing/processors/drop_columns.py +46 -0
- data_designer/engine/processing/processors/registry.py +20 -0
- data_designer/engine/processing/utils.py +120 -0
- data_designer/engine/registry/base.py +97 -0
- data_designer/engine/registry/data_designer_registry.py +37 -0
- data_designer/engine/registry/errors.py +10 -0
- data_designer/engine/resources/managed_dataset_generator.py +35 -0
- data_designer/engine/resources/managed_dataset_repository.py +194 -0
- data_designer/engine/resources/managed_storage.py +63 -0
- data_designer/engine/resources/resource_provider.py +46 -0
- data_designer/engine/resources/seed_dataset_data_store.py +66 -0
- data_designer/engine/sampling_gen/column.py +89 -0
- data_designer/engine/sampling_gen/constraints.py +95 -0
- data_designer/engine/sampling_gen/data_sources/base.py +214 -0
- data_designer/engine/sampling_gen/data_sources/errors.py +10 -0
- data_designer/engine/sampling_gen/data_sources/sources.py +342 -0
- data_designer/engine/sampling_gen/entities/__init__.py +2 -0
- data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
- data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +64 -0
- data_designer/engine/sampling_gen/entities/email_address_utils.py +169 -0
- data_designer/engine/sampling_gen/entities/errors.py +8 -0
- data_designer/engine/sampling_gen/entities/national_id_utils.py +100 -0
- data_designer/engine/sampling_gen/entities/person.py +142 -0
- data_designer/engine/sampling_gen/entities/phone_number.py +122 -0
- data_designer/engine/sampling_gen/errors.py +24 -0
- data_designer/engine/sampling_gen/generator.py +121 -0
- data_designer/engine/sampling_gen/jinja_utils.py +60 -0
- data_designer/engine/sampling_gen/people_gen.py +203 -0
- data_designer/engine/sampling_gen/person_constants.py +54 -0
- data_designer/engine/sampling_gen/schema.py +143 -0
- data_designer/engine/sampling_gen/schema_builder.py +59 -0
- data_designer/engine/sampling_gen/utils.py +40 -0
- data_designer/engine/secret_resolver.py +80 -0
- data_designer/engine/validators/__init__.py +17 -0
- data_designer/engine/validators/base.py +36 -0
- data_designer/engine/validators/local_callable.py +34 -0
- data_designer/engine/validators/python.py +245 -0
- data_designer/engine/validators/remote.py +83 -0
- data_designer/engine/validators/sql.py +60 -0
- data_designer/errors.py +5 -0
- data_designer/essentials/__init__.py +137 -0
- data_designer/interface/__init__.py +2 -0
- data_designer/interface/data_designer.py +351 -0
- data_designer/interface/errors.py +16 -0
- data_designer/interface/results.py +55 -0
- data_designer/logging.py +161 -0
- data_designer/plugin_manager.py +83 -0
- data_designer/plugins/__init__.py +6 -0
- data_designer/plugins/errors.py +10 -0
- data_designer/plugins/plugin.py +69 -0
- data_designer/plugins/registry.py +86 -0
- data_designer-0.1.0.dist-info/METADATA +173 -0
- data_designer-0.1.0.dist-info/RECORD +177 -0
- data_designer-0.1.0.dist-info/WHEEL +4 -0
- data_designer-0.1.0.dist-info/entry_points.txt +2 -0
- data_designer-0.1.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,427 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from collections import OrderedDict
|
|
7
|
+
from enum import Enum
|
|
8
|
+
from functools import cached_property
|
|
9
|
+
import json
|
|
10
|
+
import os
|
|
11
|
+
from typing import TYPE_CHECKING, Optional, Union
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
import pandas as pd
|
|
15
|
+
from rich.console import Console, Group
|
|
16
|
+
from rich.padding import Padding
|
|
17
|
+
from rich.panel import Panel
|
|
18
|
+
from rich.pretty import Pretty
|
|
19
|
+
from rich.rule import Rule
|
|
20
|
+
from rich.syntax import Syntax
|
|
21
|
+
from rich.table import Table
|
|
22
|
+
from rich.text import Text
|
|
23
|
+
|
|
24
|
+
from ..base import ConfigBase
|
|
25
|
+
from ..column_types import DataDesignerColumnType
|
|
26
|
+
from ..models import ModelConfig, ModelProvider
|
|
27
|
+
from ..sampler_params import SamplerType
|
|
28
|
+
from .code_lang import code_lang_to_syntax_lexer
|
|
29
|
+
from .constants import NVIDIA_API_KEY_ENV_VAR_NAME, OPENAI_API_KEY_ENV_VAR_NAME
|
|
30
|
+
from .errors import DatasetSampleDisplayError
|
|
31
|
+
|
|
32
|
+
if TYPE_CHECKING:
|
|
33
|
+
from ..config_builder import DataDesignerConfigBuilder
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
console = Console()
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def get_nvidia_api_key() -> Optional[str]:
    """Return the NVIDIA API key from the environment, or None when unset."""
    return os.environ.get(NVIDIA_API_KEY_ENV_VAR_NAME)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def get_openai_api_key() -> Optional[str]:
    """Return the OpenAI API key from the environment, or None when unset."""
    return os.environ.get(OPENAI_API_KEY_ENV_VAR_NAME)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class ColorPalette(str, Enum):
    """Hex color constants used to style console output consistently."""

    NVIDIA_GREEN = "#76b900"
    PURPLE = "#9525c6"
    YELLOW = "#f9c500"
    BLUE = "#0074df"
    RED = "#e52020"
    ORANGE = "#ef9100"
    MAGENTA = "#d2308e"
    TEAL = "#1dbba4"
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class WithRecordSamplerMixin:
    """Mixin that lets a results object display (and cycle through) sample records."""

    # Next record index used when display_sample_record is called without an index.
    _display_cycle_index: int = 0

    @cached_property
    def _record_sampler_dataset(self) -> pd.DataFrame:
        """Resolve the DataFrame to sample records from.

        Prefers an in-memory ``dataset`` attribute; otherwise falls back to a
        callable ``load_dataset()`` method. Either source must yield a pandas
        DataFrame, else DatasetSampleDisplayError is raised.
        """
        if hasattr(self, "dataset") and self.dataset is not None and isinstance(self.dataset, pd.DataFrame):
            return self.dataset
        elif (
            hasattr(self, "load_dataset")
            and callable(self.load_dataset)
            and (dataset := self.load_dataset()) is not None
            and isinstance(dataset, pd.DataFrame)
        ):
            return dataset
        else:
            raise DatasetSampleDisplayError("No valid dataset found in results object.")

    def display_sample_record(
        self,
        index: Optional[int] = None,
        *,
        hide_seed_columns: bool = False,
        syntax_highlighting_theme: str = "dracula",
        background_color: Optional[str] = None,
    ) -> None:
        """Display a sample record from the Data Designer dataset preview.

        Args:
            index: Index of the record to display. If None, the next record will be displayed.
                This is useful for running the cell in a notebook multiple times.
            hide_seed_columns: If True, the columns from the seed dataset (if any) will not be displayed.
            syntax_highlighting_theme: Theme to use for syntax highlighting. See the `Syntax`
                documentation from `rich` for information about available themes.
            background_color: Background color to use for the record. See the `Syntax`
                documentation from `rich` for information about available background colors.

        Raises:
            DatasetSampleDisplayError: If the index is out of bounds or no dataset is available.
        """
        # FIX: an explicit None check so that index=0 selects record 0; the
        # previous `index or self._display_cycle_index` treated 0 as "not given".
        i = self._display_cycle_index if index is None else index

        # FIX: compute the length before indexing so the error message below
        # never references an unbound variable when iloc raises IndexError.
        num_records = len(self._record_sampler_dataset)
        try:
            record = self._record_sampler_dataset.iloc[i]
        except IndexError as err:
            raise DatasetSampleDisplayError(
                f"Index {i} is out of bounds for dataset of length {num_records}."
            ) from err

        display_sample_record(
            record=record,
            config_builder=self._config_builder,
            background_color=background_color,
            syntax_highlighting_theme=syntax_highlighting_theme,
            hide_seed_columns=hide_seed_columns,
            record_index=i,
        )
        # Only auto-advance when the caller did not pin an explicit index.
        if index is None:
            self._display_cycle_index = (self._display_cycle_index + 1) % num_records
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def create_rich_histogram_table(
    data: dict[str, Union[int, float]],
    column_names: tuple[str, str],
    name_style: str = ColorPalette.BLUE.value,
    value_style: str = ColorPalette.TEAL.value,
    title: Optional[str] = None,
    **kwargs,
) -> Table:
    """Build a two-column rich Table rendering `data` as a horizontal bar histogram.

    Args:
        data: Mapping of category name to numeric value.
        column_names: Headers for the (name, value) columns. (FIX: annotated as
            `tuple[str, str]` — these are passed to `Table.add_column` as header text,
            not `tuple[int, int]` as previously annotated.)
        name_style: rich style applied to the name column.
        value_style: rich style applied to the value column.
        title: Optional table title.
        **kwargs: Extra keyword arguments forwarded to `Table`.

    Returns:
        The populated rich `Table`.
    """
    table = Table(title=title, **kwargs)
    table.add_column(column_names[0], justify="right", style=name_style)
    table.add_column(column_names[1], justify="left", style=value_style)

    # FIX: default=0 prevents a ValueError when `data` is empty (max() on an
    # empty sequence); the loop below then simply adds no rows.
    max_count = max(data.values(), default=0)
    for name, value in data.items():
        # Bars are scaled to at most 20 characters; non-positive maxima get no bar.
        bar = "" if max_count <= 0 else "█" * int((value / max_count) * 20)
        table.add_row(str(name), f"{bar} {value:.1f}")

    return table
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def display_sample_record(
    record: Union[dict, pd.Series, pd.DataFrame],
    config_builder: DataDesignerConfigBuilder,
    background_color: Optional[str] = None,
    syntax_highlighting_theme: str = "dracula",
    record_index: Optional[int] = None,
    hide_seed_columns: bool = False,
):
    """Pretty-print a single dataset record to the console, grouped by column type.

    Columns are looked up from `config_builder` and rendered in order: seed
    columns, non-code generated columns, code columns (syntax-highlighted
    panels), validation columns, and LLM-judge columns.

    Args:
        record: A single record as a dict, pandas Series, or one-row DataFrame.
        config_builder: Config builder used to look up column configs by type.
        background_color: Background color passed through to rich's `Syntax`.
        syntax_highlighting_theme: rich `Syntax` theme used for code columns.
        record_index: If given, an "[index: N]" label is appended below the output.
        hide_seed_columns: If True, seed-dataset columns are not rendered.

    Raises:
        DatasetSampleDisplayError: If `record` holds more than one row or has an
            unsupported type.
    """
    # Normalize the input to a single pandas Series.
    if isinstance(record, (dict, pd.Series)):
        record = pd.DataFrame([record]).iloc[0]
    elif isinstance(record, pd.DataFrame):
        if record.shape[0] > 1:
            raise DatasetSampleDisplayError(
                f"The record must be a single record. You provided a DataFrame with {record.shape[0]} records."
            )
        record = record.iloc[0]
    else:
        raise DatasetSampleDisplayError(
            "The record must be a single record in a dictionary, pandas Series, "
            f"or pandas DataFrame. You provided: {type(record)}."
        )

    render_list = []
    table_kws = dict(show_lines=True, expand=True)

    # Seed columns (from the seed dataset), unless explicitly hidden.
    seed_columns = config_builder.get_columns_of_type(DataDesignerColumnType.SEED_DATASET)
    if not hide_seed_columns and len(seed_columns) > 0:
        table = Table(title="Seed Columns", **table_kws)
        table.add_column("Name")
        table.add_column("Value")
        for col in seed_columns:
            if not col.drop:
                table.add_row(col.name, convert_to_row_element(record[col.name]))
        render_list.append(pad_console_element(table))

    # All non-code generated columns share one "Generated Columns" table.
    non_code_columns = (
        config_builder.get_columns_of_type(DataDesignerColumnType.SAMPLER)
        + config_builder.get_columns_of_type(DataDesignerColumnType.EXPRESSION)
        + config_builder.get_columns_of_type(DataDesignerColumnType.LLM_TEXT)
        + config_builder.get_columns_of_type(DataDesignerColumnType.LLM_STRUCTURED)
    )
    if len(non_code_columns) > 0:
        table = Table(title="Generated Columns", **table_kws)
        table.add_column("Name")
        table.add_column("Value")
        for col in non_code_columns:
            if not col.drop:
                table.add_row(col.name, convert_to_row_element(record[col.name]))
        render_list.append(pad_console_element(table))

    # Each code column gets its own syntax-highlighted panel.
    for col in config_builder.get_columns_of_type(DataDesignerColumnType.LLM_CODE):
        panel = Panel(
            Syntax(
                record[col.name],
                lexer=code_lang_to_syntax_lexer(col.code_lang),
                theme=syntax_highlighting_theme,
                word_wrap=True,
                background_color=background_color,
            ),
            title=col.name,
            expand=True,
        )
        render_list.append(pad_console_element(panel))

    # Validation columns: surface the is_valid flag ahead of the other fields.
    validation_columns = config_builder.get_columns_of_type(DataDesignerColumnType.VALIDATION)
    if len(validation_columns) > 0:
        table = Table(title="Validation", **table_kws)
        table.add_column("Name")
        table.add_column("Value", ratio=1)
        for col in validation_columns:
            if not col.drop:
                # Add is_valid before other fields
                if "is_valid" in record[col.name]:
                    value_to_display = {"is_valid": record[col.name].get("is_valid")} | record[col.name]
                else:  # if columns treated separately
                    # Nested layout: each key is a per-column validation result dict.
                    value_to_display = {}
                    for col_name, validation_output in record[col.name].items():
                        value_to_display[col_name] = {
                            "is_valid": validation_output.get("is_valid", None)
                        } | validation_output

                table.add_row(col.name, convert_to_row_element(value_to_display))
        render_list.append(pad_console_element(table, (1, 0, 1, 0)))

    # One table per judge column: measures as columns, score/reasoning as the single row.
    llm_judge_columns = config_builder.get_columns_of_type(DataDesignerColumnType.LLM_JUDGE)
    if len(llm_judge_columns) > 0:
        for col in llm_judge_columns:
            if col.drop:
                continue
            table = Table(title=f"LLM-as-a-Judge: {col.name}", **table_kws)
            row = []
            judge = record[col.name]

            for measure, results in judge.items():
                table.add_column(measure)
                row.append(f"score: {results['score']}\nreasoning: {results['reasoning']}")
            table.add_row(*row)
            render_list.append(pad_console_element(table, (1, 0, 1, 0)))

    if record_index is not None:
        index_label = Text(f"[index: {record_index}]", justify="center")
        render_list.append(index_label)

    # markup=False: record values may contain bracketed text rich would misparse.
    console.print(Group(*render_list), markup=False)
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def display_sampler_table(
    sampler_params: dict[SamplerType, ConfigBase],
    title: Optional[str] = None,
) -> None:
    """Print a table describing each sampler's parameters, types, and constraints.

    Args:
        sampler_params: Mapping of sampler type to its parameter model class/instance.
        title: Optional rule title; defaults to "NeMo Data Designer Samplers".
    """
    table = Table(expand=True)
    for heading in ("Type", "Parameter", "Data Type"):
        table.add_column(heading)
    table.add_column("Required", justify="center")
    table.add_column("Constraints")

    for sampler_type, params in sampler_params.items():
        schema = params.model_json_schema()
        required_names = schema.get("required", [])
        # The sampler type is shown only on the first row of its section.
        for row_idx, (param_name, field_info) in enumerate(schema["properties"].items()):
            table.add_row(
                sampler_type if row_idx == 0 else "",
                param_name,
                _get_field_type(field_info),
                "✓" if param_name in required_names else "",
                _get_field_constraints(field_info, schema),
            )
        table.add_section()

    console.print(Group(Rule(title or "NeMo Data Designer Samplers", end="\n\n"), table))
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
def display_model_configs_table(model_configs: list[ModelConfig]) -> None:
    """Print a table of model configs, with an inline warning when the list is empty."""
    table_model_configs = Table(expand=True)
    for header in ("Alias", "Model", "Provider", "Temperature", "Top P"):
        table_model_configs.add_column(header)

    for model_config in model_configs:
        inference_params = model_config.inference_parameters
        table_model_configs.add_row(
            model_config.alias,
            model_config.model,
            model_config.provider,
            str(inference_params.temperature),
            str(inference_params.top_p),
        )

    group_args: list = [Rule(title="Model Configs"), table_model_configs]
    if not model_configs:
        # Show a dim centered notice between the rule and the (empty) table.
        group_args.insert(
            1,
            Text(
                "‼️ No model configs found. Please provide at least one model config to the config builder",
                style="dim",
                justify="center",
            ),
        )
    console.print(Group(*group_args))
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def display_model_providers_table(model_providers: list[ModelProvider]) -> None:
    """Print a table of model providers with their endpoints and masked API keys."""
    table_model_providers = Table(expand=True)
    for header in ("Name", "Endpoint", "API Key"):
        table_model_providers.add_column(header)

    for model_provider in model_providers:
        raw_key = model_provider.api_key
        # Keys that name a well-known env var are resolved from the environment
        # before masking; a missing env var is called out explicitly.
        if raw_key == OPENAI_API_KEY_ENV_VAR_NAME:
            env_value = get_openai_api_key()
            if env_value is not None:
                api_key = mask_api_key(env_value)
            else:
                api_key = f"* {OPENAI_API_KEY_ENV_VAR_NAME!r} not set in environment variables * "
        elif raw_key == NVIDIA_API_KEY_ENV_VAR_NAME:
            env_value = get_nvidia_api_key()
            if env_value is not None:
                api_key = mask_api_key(env_value)
            else:
                api_key = f"* {NVIDIA_API_KEY_ENV_VAR_NAME!r} not set in environment variables *"
        else:
            api_key = mask_api_key(raw_key)
        table_model_providers.add_row(model_provider.name, model_provider.endpoint, api_key)

    console.print(Group(Rule(title="Model Providers"), table_model_providers))
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
def mask_api_key(api_key: str | None) -> str:
|
|
324
|
+
"""Mask API keys for display.
|
|
325
|
+
|
|
326
|
+
Environment variable names (all uppercase) are kept visible.
|
|
327
|
+
Actual API keys are masked to show only the last 4 characters.
|
|
328
|
+
|
|
329
|
+
Args:
|
|
330
|
+
api_key: The API key to mask.
|
|
331
|
+
|
|
332
|
+
Returns:
|
|
333
|
+
Masked API key string or "(not set)" if None.
|
|
334
|
+
"""
|
|
335
|
+
if not api_key:
|
|
336
|
+
return "(not set)"
|
|
337
|
+
|
|
338
|
+
# Keep environment variable names visible
|
|
339
|
+
if api_key.isupper():
|
|
340
|
+
return api_key
|
|
341
|
+
|
|
342
|
+
# Mask actual API keys
|
|
343
|
+
return "***" + api_key[-4:] if len(api_key) > 4 else "***"
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
def convert_to_row_element(elem):
|
|
347
|
+
try:
|
|
348
|
+
elem = Pretty(json.loads(elem))
|
|
349
|
+
except (TypeError, json.JSONDecodeError):
|
|
350
|
+
pass
|
|
351
|
+
if isinstance(elem, (np.integer, np.floating, np.ndarray)):
|
|
352
|
+
elem = str(elem)
|
|
353
|
+
elif isinstance(elem, (list, dict)):
|
|
354
|
+
elem = Pretty(elem)
|
|
355
|
+
return elem
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
def pad_console_element(elem, padding=(1, 0, 1, 0)):
    """Wrap a rich renderable in padding; default adds one blank line above and below.

    The padding tuple follows rich's (top, right, bottom, left) convention.
    """
    return Padding(elem, padding)
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
def _get_field_type(field: dict) -> str:
|
|
363
|
+
"""Extract human-readable type information from a JSON Schema field."""
|
|
364
|
+
|
|
365
|
+
# single type
|
|
366
|
+
if "type" in field:
|
|
367
|
+
if field["type"] == "array":
|
|
368
|
+
return " | ".join([f"{f.strip()}[]" for f in _get_field_type(field["items"]).split("|")])
|
|
369
|
+
if field["type"] == "object":
|
|
370
|
+
return "dict"
|
|
371
|
+
return field["type"]
|
|
372
|
+
|
|
373
|
+
# union type
|
|
374
|
+
elif "anyOf" in field:
|
|
375
|
+
types = []
|
|
376
|
+
for f in field["anyOf"]:
|
|
377
|
+
if "$ref" in f:
|
|
378
|
+
types.append("enum")
|
|
379
|
+
elif f.get("type") == "array":
|
|
380
|
+
if "items" in f and "$ref" in f["items"]:
|
|
381
|
+
types.append("enum[]")
|
|
382
|
+
else:
|
|
383
|
+
types.append(f"{f['items']['type']}[]")
|
|
384
|
+
else:
|
|
385
|
+
types.append(f.get("type", ""))
|
|
386
|
+
return " | ".join(t for t in types if t)
|
|
387
|
+
|
|
388
|
+
return ""
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
def _get_field_constraints(field: dict, schema: dict) -> str:
|
|
392
|
+
"""Extract human-readable constraints from a JSON Schema field."""
|
|
393
|
+
constraints = []
|
|
394
|
+
|
|
395
|
+
# numeric constraints
|
|
396
|
+
if "minimum" in field:
|
|
397
|
+
constraints.append(f">= {field['minimum']}")
|
|
398
|
+
if "exclusiveMinimum" in field:
|
|
399
|
+
constraints.append(f"> {field['exclusiveMinimum']}")
|
|
400
|
+
if "maximum" in field:
|
|
401
|
+
constraints.append(f"<= {field['maximum']}")
|
|
402
|
+
if "exclusiveMaximum" in field:
|
|
403
|
+
constraints.append(f"< {field['exclusiveMaximum']}")
|
|
404
|
+
|
|
405
|
+
# string constraints
|
|
406
|
+
if "minLength" in field:
|
|
407
|
+
constraints.append(f"len > {field['minLength']}")
|
|
408
|
+
if "maxLength" in field:
|
|
409
|
+
constraints.append(f"len < {field['maxLength']}")
|
|
410
|
+
|
|
411
|
+
# array constraints
|
|
412
|
+
if "minItems" in field:
|
|
413
|
+
constraints.append(f"len > {field['minItems']}")
|
|
414
|
+
if "maxItems" in field:
|
|
415
|
+
constraints.append(f"len < {field['maxItems']}")
|
|
416
|
+
|
|
417
|
+
# enum constraints
|
|
418
|
+
if "enum" in _get_field_type(field) and "$defs" in schema:
|
|
419
|
+
enum_values = []
|
|
420
|
+
for defs in schema["$defs"].values():
|
|
421
|
+
if "enum" in defs:
|
|
422
|
+
enum_values.extend(defs["enum"])
|
|
423
|
+
if len(enum_values) > 0:
|
|
424
|
+
enum_values = OrderedDict.fromkeys(enum_values)
|
|
425
|
+
constraints.append(f"allowed: {', '.join(enum_values.keys())}")
|
|
426
|
+
|
|
427
|
+
return ", ".join(constraints)
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from enum import Enum
|
|
5
|
+
from typing import Any, Optional, Union
|
|
6
|
+
|
|
7
|
+
from pydantic import Field, field_serializer, model_validator
|
|
8
|
+
from typing_extensions import Self, TypeAlias
|
|
9
|
+
|
|
10
|
+
from .base import ConfigBase
|
|
11
|
+
from .utils.code_lang import SQL_DIALECTS, CodeLang
|
|
12
|
+
|
|
13
|
+
SUPPORTED_CODE_LANGUAGES = {CodeLang.PYTHON, *SQL_DIALECTS}
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ValidatorType(str, Enum):
    """Identifiers for the supported validator kinds."""

    CODE = "code"
    LOCAL_CALLABLE = "local_callable"
    REMOTE = "remote"
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class CodeValidatorParams(ConfigBase):
    """Parameters for the code validator, covering Python and SQL code.

    Attributes:
        code_lang: The language of the code to validate. Supported values include: `python`,
            `sql:sqlite`, `sql:postgres`, `sql:mysql`, `sql:tsql`, `sql:bigquery`, `sql:ansi`.
    """

    code_lang: CodeLang = Field(description="The language of the code to validate")

    @model_validator(mode="after")
    def validate_code_lang(self) -> Self:
        # Guard clause: supported languages pass straight through.
        if self.code_lang in SUPPORTED_CODE_LANGUAGES:
            return self
        raise ValueError(
            f"Unsupported code language, supported languages are: {[lang.value for lang in SUPPORTED_CODE_LANGUAGES]}"
        )
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class LocalCallableValidatorParams(ConfigBase):
    """Parameters for validation via a user-supplied callable.

    Attributes:
        validation_function: Function (`Callable[[pd.DataFrame], pd.DataFrame]`) used to
            validate the data. Its output must contain a boolean column `is_valid`.
        output_schema: JSON schema the callable's output is checked against. When omitted,
            the output is not validated.
    """

    validation_function: Any = Field(
        description="Function (Callable[[pd.DataFrame], pd.DataFrame]) to validate the data"
    )
    output_schema: Optional[dict[str, Any]] = Field(
        default=None, description="Expected schema for local callable validator's output"
    )

    @field_serializer("validation_function")
    def serialize_validation_function(self, v: Any) -> Any:
        # Callables are not serializable; persist the function's name instead.
        return v.__name__

    @model_validator(mode="after")
    def validate_validation_function(self) -> Self:
        # Guard clause: any callable is accepted.
        if callable(self.validation_function):
            return self
        raise ValueError("Validation function must be a callable")
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class RemoteValidatorParams(ConfigBase):
    """Configuration for remote validation. Sends data to a remote endpoint for validation.

    Attributes:
        endpoint_url: The URL of the remote endpoint.
        output_schema: The JSON schema for the remote validator's output. If not provided,
            the output will not be validated.
        timeout: The timeout for the HTTP request in seconds. Defaults to 30.0.
        max_retries: The maximum number of retry attempts. Defaults to 3.
        retry_backoff: The backoff factor for the retry delay in seconds. Defaults to 2.0.
        max_parallel_requests: The maximum number of parallel requests to make. Defaults to 4.
    """

    endpoint_url: str = Field(description="URL of the remote endpoint")
    output_schema: Optional[dict[str, Any]] = Field(
        default=None, description="Expected schema for remote validator's output"
    )
    timeout: float = Field(default=30.0, gt=0, description="The timeout for the HTTP request")
    max_retries: int = Field(default=3, ge=0, description="The maximum number of retry attempts")
    # gt=1: the factor must exceed 1 — presumably so retry delays grow between
    # attempts; confirm against the remote validator's retry implementation.
    retry_backoff: float = Field(default=2.0, gt=1, description="The backoff factor for the retry delay")
    max_parallel_requests: int = Field(default=4, ge=1, description="The maximum number of parallel requests to make")
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
# Union of all supported validator parameter models; configs declaring a
# validator carry one of these variants.
ValidatorParamsT: TypeAlias = Union[
    CodeValidatorParams,
    LocalCallableValidatorParams,
    RemoteValidatorParams,
]
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
import logging
|
|
8
|
+
|
|
9
|
+
import pandas as pd
|
|
10
|
+
import pyarrow as pa
|
|
11
|
+
from pydantic import BaseModel, model_validator
|
|
12
|
+
from typing_extensions import Self
|
|
13
|
+
|
|
14
|
+
from data_designer.config.base import ConfigBase
|
|
15
|
+
from data_designer.config.column_configs import SingleColumnConfig
|
|
16
|
+
from data_designer.config.column_types import DataDesignerColumnType
|
|
17
|
+
from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata, TaskConfigT
|
|
18
|
+
|
|
19
|
+
logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class ColumnConfigWithDataFrame(ConfigBase):
    """Pairs a single column's configuration with the DataFrame containing that column.

    Construction validates that the column exists and normalizes the DataFrame
    to pyarrow-backed dtypes.
    """

    # Configuration of the column being profiled.
    column_config: SingleColumnConfig
    # DataFrame that must contain the configured column.
    df: pd.DataFrame

    @model_validator(mode="after")
    def validate_column_exists(self) -> Self:
        # Fail fast if the configured column is missing from the data.
        if self.column_config.name not in self.df.columns:
            raise ValueError(f"Column {self.column_config.name!r} not found in DataFrame")
        return self

    @model_validator(mode="after")
    def ensure_pyarrow_backend(self) -> Self:
        # If any column uses a non-pyarrow dtype, round-trip the whole frame
        # through pyarrow so every column ends up with an ArrowDtype backend.
        if not all(isinstance(dtype, pd.ArrowDtype) for dtype in self.df.dtypes):
            self.df = pa.Table.from_pandas(self.df).to_pandas(types_mapper=pd.ArrowDtype)
        return self

    def as_tuple(self) -> tuple[SingleColumnConfig, pd.DataFrame]:
        """Return the (column_config, df) pair as a plain tuple."""
        return (self.column_config, self.df)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class ColumnProfilerMetadata(ConfigurableTaskMetadata):
    """Task metadata extended with the column types a profiler applies to."""

    # Column types this profiler can handle.
    applicable_column_types: list[DataDesignerColumnType]
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class ColumnProfiler(ConfigurableTask[TaskConfigT], ABC):
    """Abstract base class for column profilers.

    Subclasses must supply static metadata (including applicable column types)
    and a `profile` implementation producing a pydantic result model.
    """

    @staticmethod
    @abstractmethod
    def metadata() -> ColumnProfilerMetadata: ...

    @abstractmethod
    def profile(self, column_config_with_df: ColumnConfigWithDataFrame) -> BaseModel: ...

    def _initialize(self) -> None:
        # Log which profiler is being initialized, identified by its metadata name.
        logger.info(f"💫 Initializing column profiler: '{self.metadata().name}'")
|