data-designer 0.1.5__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/_version.py +2 -2
- data_designer/cli/README.md +15 -1
- data_designer/cli/commands/download.py +56 -0
- data_designer/cli/commands/list.py +4 -18
- data_designer/cli/controllers/__init__.py +2 -1
- data_designer/cli/controllers/download_controller.py +217 -0
- data_designer/cli/controllers/model_controller.py +4 -3
- data_designer/cli/forms/field.py +65 -19
- data_designer/cli/forms/model_builder.py +251 -44
- data_designer/cli/main.py +11 -1
- data_designer/cli/repositories/persona_repository.py +88 -0
- data_designer/cli/services/__init__.py +2 -1
- data_designer/cli/services/download_service.py +97 -0
- data_designer/cli/ui.py +131 -0
- data_designer/cli/utils.py +34 -0
- data_designer/config/analysis/__init__.py +2 -0
- data_designer/config/analysis/column_profilers.py +75 -7
- data_designer/config/analysis/column_statistics.py +192 -48
- data_designer/config/analysis/dataset_profiler.py +23 -5
- data_designer/config/analysis/utils/reporting.py +3 -3
- data_designer/config/base.py +3 -3
- data_designer/config/column_configs.py +27 -6
- data_designer/config/column_types.py +24 -17
- data_designer/config/config_builder.py +36 -27
- data_designer/config/data_designer_config.py +7 -7
- data_designer/config/datastore.py +6 -6
- data_designer/config/default_model_settings.py +27 -34
- data_designer/config/exports.py +8 -0
- data_designer/config/models.py +155 -29
- data_designer/config/preview_results.py +6 -8
- data_designer/config/processors.py +63 -2
- data_designer/config/sampler_constraints.py +1 -2
- data_designer/config/sampler_params.py +50 -31
- data_designer/config/seed.py +1 -2
- data_designer/config/utils/code_lang.py +4 -5
- data_designer/config/utils/constants.py +31 -8
- data_designer/config/utils/io_helpers.py +5 -5
- data_designer/config/utils/misc.py +1 -4
- data_designer/config/utils/numerical_helpers.py +2 -2
- data_designer/config/utils/type_helpers.py +3 -3
- data_designer/config/utils/validation.py +7 -8
- data_designer/config/utils/visualization.py +32 -17
- data_designer/config/validator_params.py +4 -8
- data_designer/engine/analysis/column_profilers/base.py +0 -7
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +2 -3
- data_designer/engine/analysis/column_statistics.py +16 -16
- data_designer/engine/analysis/dataset_profiler.py +25 -4
- data_designer/engine/analysis/utils/column_statistics_calculations.py +71 -49
- data_designer/engine/analysis/utils/judge_score_processing.py +5 -5
- data_designer/engine/column_generators/generators/base.py +34 -0
- data_designer/engine/column_generators/generators/embedding.py +45 -0
- data_designer/engine/column_generators/generators/{llm_generators.py → llm_completion.py} +17 -49
- data_designer/engine/column_generators/registry.py +4 -2
- data_designer/engine/column_generators/utils/judge_score_factory.py +5 -6
- data_designer/engine/configurable_task.py +2 -2
- data_designer/engine/dataset_builders/artifact_storage.py +1 -2
- data_designer/engine/dataset_builders/column_wise_builder.py +58 -15
- data_designer/engine/dataset_builders/utils/concurrency.py +6 -6
- data_designer/engine/models/facade.py +66 -9
- data_designer/engine/models/litellm_overrides.py +5 -6
- data_designer/engine/models/parsers/errors.py +2 -4
- data_designer/engine/models/parsers/parser.py +2 -3
- data_designer/engine/models/parsers/postprocessors.py +3 -4
- data_designer/engine/models/parsers/types.py +4 -4
- data_designer/engine/models/registry.py +47 -12
- data_designer/engine/models/telemetry.py +355 -0
- data_designer/engine/models/usage.py +7 -9
- data_designer/engine/processing/ginja/ast.py +1 -2
- data_designer/engine/processing/utils.py +40 -2
- data_designer/engine/registry/base.py +12 -12
- data_designer/engine/sampling_gen/constraints.py +1 -2
- data_designer/engine/sampling_gen/data_sources/base.py +14 -14
- data_designer/engine/sampling_gen/entities/phone_number.py +1 -2
- data_designer/engine/sampling_gen/people_gen.py +3 -7
- data_designer/engine/validators/base.py +2 -2
- data_designer/logging.py +2 -2
- data_designer/plugin_manager.py +3 -3
- data_designer/plugins/plugin.py +3 -3
- data_designer/plugins/registry.py +2 -2
- {data_designer-0.1.5.dist-info → data_designer-0.2.1.dist-info}/METADATA +32 -1
- {data_designer-0.1.5.dist-info → data_designer-0.2.1.dist-info}/RECORD +84 -77
- {data_designer-0.1.5.dist-info → data_designer-0.2.1.dist-info}/WHEEL +0 -0
- {data_designer-0.1.5.dist-info → data_designer-0.2.1.dist-info}/entry_points.txt +0 -0
- {data_designer-0.1.5.dist-info → data_designer-0.2.1.dist-info}/licenses/LICENSE +0 -0
data_designer/_version.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '0.1
|
|
32
|
-
__version_tuple__ = version_tuple = (0,
|
|
31
|
+
__version__ = version = '0.2.1'
|
|
32
|
+
__version_tuple__ = version_tuple = (0, 2, 1)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
data_designer/cli/README.md
CHANGED
|
@@ -129,8 +129,10 @@ class ConfigRepository(ABC, Generic[T]):
|
|
|
129
129
|
- Field-level validation
|
|
130
130
|
- Auto-completion support
|
|
131
131
|
- History navigation (arrow keys)
|
|
132
|
-
-
|
|
132
|
+
- Current value display when editing (`(current value: X)` instead of `(default: X)`)
|
|
133
|
+
- Value clearing support (type `'clear'` to remove optional parameter values)
|
|
133
134
|
- Back navigation support
|
|
135
|
+
- Empty input handling (Enter key keeps current value or skips optional fields)
|
|
134
136
|
|
|
135
137
|
#### 6. **UI Utilities** (`ui.py`)
|
|
136
138
|
- **Purpose**: User interface utilities for terminal output and input
|
|
@@ -179,17 +181,29 @@ model_configs:
|
|
|
179
181
|
model: meta/llama-3.1-70b-instruct
|
|
180
182
|
provider: nvidia
|
|
181
183
|
inference_parameters:
|
|
184
|
+
generation_type: chat-completion
|
|
182
185
|
temperature: 0.7
|
|
183
186
|
top_p: 0.9
|
|
184
187
|
max_tokens: 2048
|
|
185
188
|
max_parallel_requests: 4
|
|
189
|
+
timeout: 60
|
|
186
190
|
- alias: gpt-4
|
|
187
191
|
model: gpt-4-turbo
|
|
188
192
|
provider: openai
|
|
189
193
|
inference_parameters:
|
|
194
|
+
generation_type: chat-completion
|
|
190
195
|
temperature: 0.8
|
|
191
196
|
top_p: 0.95
|
|
192
197
|
max_tokens: 4096
|
|
198
|
+
max_parallel_requests: 4
|
|
199
|
+
- alias: embedder
|
|
200
|
+
model: text-embedding-3-large
|
|
201
|
+
provider: openai
|
|
202
|
+
inference_parameters:
|
|
203
|
+
generation_type: embedding
|
|
204
|
+
encoding_format: float
|
|
205
|
+
dimensions: 1024
|
|
206
|
+
max_parallel_requests: 4
|
|
193
207
|
```
|
|
194
208
|
|
|
195
209
|
## Usage Examples
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
import typer
|
|
5
|
+
|
|
6
|
+
from data_designer.cli.controllers.download_controller import DownloadController
|
|
7
|
+
from data_designer.config.utils.constants import DATA_DESIGNER_HOME
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def personas_command(
    locales: list[str] = typer.Option(
        None,
        "--locale",
        "-l",
        help="Locales to download (en_US, en_IN, hi_Deva_IN, hi_Latn_IN, ja_JP). Can be specified multiple times.",
    ),
    all_locales: bool = typer.Option(
        False,
        "--all",
        help="Download all available locales",
    ),
    dry_run: bool = typer.Option(
        False,
        "--dry-run",
        help="Show what would be downloaded without actually downloading",
    ),
    list_available: bool = typer.Option(
        False,
        "--list",
        help="List available persona datasets and their sizes",
    ),
) -> None:
    """Download Nemotron-Personas datasets for synthetic data generation.

    Examples:
        # List available datasets
        data-designer download personas --list

        # Interactive selection
        data-designer download personas

        # Download specific locales
        data-designer download personas --locale en_US --locale ja_JP

        # Download all available locales
        data-designer download personas --all

        # Preview what would be downloaded
        data-designer download personas --all --dry-run
    """
    # NOTE: the docstring above doubles as the CLI help text shown by typer.
    controller = DownloadController(DATA_DESIGNER_HOME)

    # `--list` only reports what is available; it wins over every other flag.
    if list_available:
        controller.list_personas()
        return

    # Otherwise run the (possibly interactive / dry-run) download workflow.
    controller.run_personas(locales=locales, all_locales=all_locales, dry_run=dry_run)
|
|
@@ -95,32 +95,18 @@ def display_models(model_repo: ModelRepository) -> None:
|
|
|
95
95
|
# Display as table
|
|
96
96
|
table = Table(title="Model Configurations", border_style=NordColor.NORD8.value)
|
|
97
97
|
table.add_column("Alias", style=NordColor.NORD14.value, no_wrap=True)
|
|
98
|
-
table.add_column("Model
|
|
98
|
+
table.add_column("Model", style=NordColor.NORD4.value)
|
|
99
99
|
table.add_column("Provider", style=NordColor.NORD9.value, no_wrap=True)
|
|
100
|
-
table.add_column("
|
|
101
|
-
table.add_column("Top P", style=NordColor.NORD15.value, justify="right")
|
|
102
|
-
table.add_column("Max Tokens", style=NordColor.NORD15.value, justify="right")
|
|
100
|
+
table.add_column("Inference Parameters", style=NordColor.NORD15.value)
|
|
103
101
|
|
|
104
102
|
for mc in registry.model_configs:
|
|
105
|
-
|
|
106
|
-
temp_display = (
|
|
107
|
-
f"{mc.inference_parameters.temperature:.2f}"
|
|
108
|
-
if isinstance(mc.inference_parameters.temperature, (int, float))
|
|
109
|
-
else "dist"
|
|
110
|
-
)
|
|
111
|
-
top_p_display = (
|
|
112
|
-
f"{mc.inference_parameters.top_p:.2f}"
|
|
113
|
-
if isinstance(mc.inference_parameters.top_p, (int, float))
|
|
114
|
-
else "dist"
|
|
115
|
-
)
|
|
103
|
+
params_display = mc.inference_parameters.format_for_display()
|
|
116
104
|
|
|
117
105
|
table.add_row(
|
|
118
106
|
mc.alias,
|
|
119
107
|
mc.model,
|
|
120
108
|
mc.provider or "(default)",
|
|
121
|
-
|
|
122
|
-
top_p_display,
|
|
123
|
-
str(mc.inference_parameters.max_tokens) if mc.inference_parameters.max_tokens else "(none)",
|
|
109
|
+
params_display,
|
|
124
110
|
)
|
|
125
111
|
|
|
126
112
|
console.print(table)
|
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
from data_designer.cli.controllers.download_controller import DownloadController
|
|
4
5
|
from data_designer.cli.controllers.model_controller import ModelController
|
|
5
6
|
from data_designer.cli.controllers.provider_controller import ProviderController
|
|
6
7
|
|
|
7
|
-
__all__ = ["ModelController", "ProviderController"]
|
|
8
|
+
__all__ = ["DownloadController", "ModelController", "ProviderController"]
|
|
@@ -0,0 +1,217 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
import subprocess
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
from data_designer.cli.repositories.persona_repository import PersonaRepository
|
|
8
|
+
from data_designer.cli.services.download_service import DownloadService
|
|
9
|
+
from data_designer.cli.ui import (
|
|
10
|
+
confirm_action,
|
|
11
|
+
console,
|
|
12
|
+
print_error,
|
|
13
|
+
print_header,
|
|
14
|
+
print_info,
|
|
15
|
+
print_success,
|
|
16
|
+
print_text,
|
|
17
|
+
select_multiple_with_arrows,
|
|
18
|
+
)
|
|
19
|
+
from data_designer.cli.utils import check_ngc_cli_available, get_ngc_version
|
|
20
|
+
|
|
21
|
+
# Shown to users as the place to create an NVIDIA NGC account.
NGC_URL = "https://catalog.ngc.nvidia.com/"
# Shown to users as the place to install the NGC CLI, which persona downloads require.
NGC_CLI_INSTALL_URL = "https://org.ngc.nvidia.com/setup/installers/cli"
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class DownloadController:
|
|
26
|
+
"""Controller for asset download workflows."""
|
|
27
|
+
|
|
28
|
+
def __init__(self, config_dir: Path):
|
|
29
|
+
self.config_dir = config_dir
|
|
30
|
+
self.persona_repository = PersonaRepository()
|
|
31
|
+
self.service = DownloadService(config_dir, self.persona_repository)
|
|
32
|
+
|
|
33
|
+
def list_personas(self) -> None:
|
|
34
|
+
"""List available persona datasets and their sizes."""
|
|
35
|
+
print_header("Available Nemotron-Persona Datasets")
|
|
36
|
+
console.print()
|
|
37
|
+
|
|
38
|
+
available_locales = self.persona_repository.list_all()
|
|
39
|
+
|
|
40
|
+
print_text("📦 Available locales:")
|
|
41
|
+
console.print()
|
|
42
|
+
|
|
43
|
+
for locale in available_locales:
|
|
44
|
+
already_downloaded = self.service.is_locale_downloaded(locale.code)
|
|
45
|
+
status = " (downloaded)" if already_downloaded else ""
|
|
46
|
+
print_text(f" • {locale.code}: {locale.size}{status}")
|
|
47
|
+
|
|
48
|
+
console.print()
|
|
49
|
+
print_info(f"Total: {len(available_locales)} datasets available")
|
|
50
|
+
|
|
51
|
+
def run_personas(self, locales: list[str] | None, all_locales: bool, dry_run: bool = False) -> None:
|
|
52
|
+
"""Main entry point for persona dataset downloads.
|
|
53
|
+
|
|
54
|
+
Args:
|
|
55
|
+
locales: List of locale codes to download (if provided via CLI flags)
|
|
56
|
+
all_locales: If True, download all available locales
|
|
57
|
+
dry_run: If True, only show what would be downloaded without actually downloading
|
|
58
|
+
"""
|
|
59
|
+
header = "Download Nemotron-Persona Datasets (Dry Run)" if dry_run else "Download Nemotron-Persona Datasets"
|
|
60
|
+
print_header(header)
|
|
61
|
+
print_info(f"Datasets will be saved to: {self.service.get_managed_assets_directory()}")
|
|
62
|
+
console.print()
|
|
63
|
+
|
|
64
|
+
# Check NGC CLI availability (skip checking in dry run mode)
|
|
65
|
+
if not dry_run and not check_ngc_cli_with_instructions():
|
|
66
|
+
return
|
|
67
|
+
|
|
68
|
+
# Determine which locales to download
|
|
69
|
+
selected_locales = self._determine_locales(locales, all_locales)
|
|
70
|
+
|
|
71
|
+
if not selected_locales:
|
|
72
|
+
print_info("No locales selected")
|
|
73
|
+
return
|
|
74
|
+
|
|
75
|
+
# Show what will be downloaded
|
|
76
|
+
console.print()
|
|
77
|
+
action = "Would download" if dry_run else "Will download"
|
|
78
|
+
print_text(f"📦 {action} {len(selected_locales)} Nemotron-Persona dataset(s):")
|
|
79
|
+
for locale_code in selected_locales:
|
|
80
|
+
locale = self.persona_repository.get_by_code(locale_code)
|
|
81
|
+
already_downloaded = self.service.is_locale_downloaded(locale_code)
|
|
82
|
+
status = " - already exists, will update" if already_downloaded else ""
|
|
83
|
+
size = locale.size if locale else "unknown"
|
|
84
|
+
print_text(f" • {locale_code} ({size}){status}")
|
|
85
|
+
|
|
86
|
+
console.print()
|
|
87
|
+
|
|
88
|
+
# In dry run mode, exit here
|
|
89
|
+
if dry_run:
|
|
90
|
+
print_info("Dry run complete - no files were downloaded")
|
|
91
|
+
return
|
|
92
|
+
|
|
93
|
+
# Confirm download
|
|
94
|
+
if not confirm_action("Proceed with download?", default=True):
|
|
95
|
+
print_info("Download cancelled")
|
|
96
|
+
return
|
|
97
|
+
|
|
98
|
+
# Download each locale
|
|
99
|
+
console.print()
|
|
100
|
+
successful = []
|
|
101
|
+
failed = []
|
|
102
|
+
|
|
103
|
+
for locale in selected_locales:
|
|
104
|
+
if self._download_locale(locale):
|
|
105
|
+
successful.append(locale)
|
|
106
|
+
else:
|
|
107
|
+
failed.append(locale)
|
|
108
|
+
|
|
109
|
+
# Summary
|
|
110
|
+
console.print()
|
|
111
|
+
if successful:
|
|
112
|
+
print_success(f"Successfully downloaded {len(successful)} dataset(s): {', '.join(successful)}")
|
|
113
|
+
print_info(f"Saved datasets to: {self.service.get_managed_assets_directory()}")
|
|
114
|
+
|
|
115
|
+
if failed:
|
|
116
|
+
print_error(f"Failed to download {len(failed)} dataset(s): {', '.join(failed)}")
|
|
117
|
+
|
|
118
|
+
def _determine_locales(self, locales: list[str] | None, all_locales: bool) -> list[str]:
|
|
119
|
+
"""Determine which locales to download based on user input.
|
|
120
|
+
|
|
121
|
+
Args:
|
|
122
|
+
locales: List of locales from CLI flags (may be None)
|
|
123
|
+
all_locales: Whether to download all locales
|
|
124
|
+
|
|
125
|
+
Returns:
|
|
126
|
+
List of locale codes to download
|
|
127
|
+
"""
|
|
128
|
+
available_locales = self.service.get_available_locales()
|
|
129
|
+
|
|
130
|
+
# If --all flag is set, return all locales
|
|
131
|
+
if all_locales:
|
|
132
|
+
return list(available_locales.keys())
|
|
133
|
+
|
|
134
|
+
# If locales specified via flags, validate and return them
|
|
135
|
+
if locales:
|
|
136
|
+
invalid_locales = [loc for loc in locales if loc not in available_locales]
|
|
137
|
+
if invalid_locales:
|
|
138
|
+
print_error(f"Invalid locale(s): {', '.join(invalid_locales)}")
|
|
139
|
+
print_info(f"Available locales: {', '.join(available_locales.keys())}")
|
|
140
|
+
return []
|
|
141
|
+
return locales
|
|
142
|
+
|
|
143
|
+
# Interactive multi-select
|
|
144
|
+
return self._select_locales_interactive(available_locales)
|
|
145
|
+
|
|
146
|
+
def _select_locales_interactive(self, available_locales: dict[str, str]) -> list[str]:
|
|
147
|
+
"""Interactive multi-select for locales.
|
|
148
|
+
|
|
149
|
+
Args:
|
|
150
|
+
available_locales: Dictionary of {locale_code: description}
|
|
151
|
+
|
|
152
|
+
Returns:
|
|
153
|
+
List of selected locale codes
|
|
154
|
+
"""
|
|
155
|
+
console.print()
|
|
156
|
+
print_text("Select locales you want to download:")
|
|
157
|
+
console.print()
|
|
158
|
+
|
|
159
|
+
selected = select_multiple_with_arrows(
|
|
160
|
+
options=available_locales,
|
|
161
|
+
prompt_text="Use ↑/↓ to navigate, Space to toggle ✓, Enter to confirm:",
|
|
162
|
+
default_keys=None,
|
|
163
|
+
allow_empty=False,
|
|
164
|
+
)
|
|
165
|
+
|
|
166
|
+
return selected if selected else []
|
|
167
|
+
|
|
168
|
+
def _download_locale(self, locale: str) -> bool:
|
|
169
|
+
"""Download a single locale using NGC CLI.
|
|
170
|
+
|
|
171
|
+
Args:
|
|
172
|
+
locale: Locale code to download
|
|
173
|
+
|
|
174
|
+
Returns:
|
|
175
|
+
True if download succeeded, False otherwise
|
|
176
|
+
"""
|
|
177
|
+
# Print header before download (NGC CLI will show its own progress)
|
|
178
|
+
print_text(f"📦 Downloading Nemotron-Persona dataset for {locale}...")
|
|
179
|
+
console.print()
|
|
180
|
+
|
|
181
|
+
try:
|
|
182
|
+
self.service.download_persona_dataset(locale)
|
|
183
|
+
console.print()
|
|
184
|
+
print_success(f"✓ Downloaded Nemotron-Persona dataset for {locale}")
|
|
185
|
+
return True
|
|
186
|
+
|
|
187
|
+
except subprocess.CalledProcessError as e:
|
|
188
|
+
console.print()
|
|
189
|
+
print_error(f"✗ Failed to download Nemotron-Persona dataset for {locale}")
|
|
190
|
+
print_error(f"NGC CLI error: {e}")
|
|
191
|
+
return False
|
|
192
|
+
|
|
193
|
+
except Exception as e:
|
|
194
|
+
console.print()
|
|
195
|
+
print_error(f"✗ Failed to download Nemotron-Persona dataset for {locale}")
|
|
196
|
+
print_error(f"Unexpected error: {e}")
|
|
197
|
+
return False
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def check_ngc_cli_with_instructions() -> bool:
    """Check if NGC CLI is installed and guide user if not.

    Returns:
        True when the NGC CLI is available (its version string is printed when
        one can be obtained); False after printing setup instructions when it
        is not found.
    """
    if check_ngc_cli_available():
        version = get_ngc_version()
        if version:
            print_info(version)
        return True

    print_error("NGC CLI not found!")
    console.print()
    print_text("The NGC CLI is required to download the Nemotron-Personas datasets.")
    console.print()
    print_text("To download the Nemotron-Personas datasets, follow these steps:")
    print_text(f" 1. Create an NVIDIA NGC account: {NGC_URL}")
    print_text(f" 2. Install the NGC CLI: {NGC_CLI_INSTALL_URL}")
    # Imperative mood, consistent with the other numbered steps.
    print_text(" 3. Follow the install instructions to set up the NGC CLI")
    print_text(" 4. Run 'data-designer download personas'")
    return False
|
|
@@ -160,9 +160,10 @@ class ModelController:
|
|
|
160
160
|
return
|
|
161
161
|
|
|
162
162
|
# Check if model has distribution-based parameters
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
163
|
+
params_dict = model.inference_parameters.model_dump(mode="json", exclude_none=True)
|
|
164
|
+
has_distribution = any(isinstance(v, dict) and "distribution_type" in v for v in params_dict.values())
|
|
165
|
+
|
|
166
|
+
if has_distribution:
|
|
166
167
|
print_warning(
|
|
167
168
|
"This model uses distribution-based inference parameters, "
|
|
168
169
|
"which cannot be edited via the CLI. Please edit the configuration file directly."
|
data_designer/cli/forms/field.py
CHANGED
|
@@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
|
|
|
5
5
|
from collections.abc import Callable
|
|
6
6
|
from typing import Any, Generic, TypeVar
|
|
7
7
|
|
|
8
|
+
from data_designer.cli.ui import BACK, prompt_text_input, select_with_arrows
|
|
8
9
|
from data_designer.cli.utils import validate_numeric_range
|
|
9
10
|
|
|
10
11
|
T = TypeVar("T")
|
|
@@ -40,8 +41,14 @@ class Field(ABC, Generic[T]):
|
|
|
40
41
|
return self._value
|
|
41
42
|
|
|
42
43
|
@value.setter
|
|
43
|
-
def value(self, val: T) -> None:
|
|
44
|
-
"""Set and validate the field value."""
|
|
44
|
+
def value(self, val: T | str) -> None:
|
|
45
|
+
"""Set and validate the field value. Converts empty strings to None for optional fields."""
|
|
46
|
+
# Handle empty string for optional fields (clearing the value)
|
|
47
|
+
if val == "" and not self.required:
|
|
48
|
+
self._value = None
|
|
49
|
+
return
|
|
50
|
+
|
|
51
|
+
# Standard validation for non-empty values
|
|
45
52
|
if self.validator:
|
|
46
53
|
# For string validators, convert to string first if needed
|
|
47
54
|
val_str = str(val) if not isinstance(val, str) else val
|
|
@@ -50,6 +57,40 @@ class Field(ABC, Generic[T]):
|
|
|
50
57
|
raise ValidationError(error_msg or "Invalid value")
|
|
51
58
|
self._value = val
|
|
52
59
|
|
|
60
|
+
def _build_prompt_text(self) -> str:
|
|
61
|
+
"""Build prompt text with current value information."""
|
|
62
|
+
has_current_value = self.default is not None
|
|
63
|
+
|
|
64
|
+
if has_current_value:
|
|
65
|
+
# Show as "current" instead of "default" with dimmed styling
|
|
66
|
+
if not self.required:
|
|
67
|
+
return f"{self.prompt} <dim>(current value: {self.default}, type 'clear' to remove)</dim>"
|
|
68
|
+
return f"{self.prompt} <dim>(current value: {self.default})</dim>"
|
|
69
|
+
|
|
70
|
+
return self.prompt
|
|
71
|
+
|
|
72
|
+
def _handle_prompt_result(self, result: str | None | Any) -> str | None | Any:
    """Normalize a raw prompt response.

    - ``BACK`` and ``None`` (user cancelled with ESC) are returned unchanged.
    - Empty input yields the current value (``self.default``) when one exists,
      otherwise ``""``.
    - The keywords ``clear``/``none``/``default`` (any case) yield ``""``,
      which the ``value`` setter treats as "clear" for optional fields.
    - Anything else is returned as typed.
    """
    # Navigation signals pass straight through.
    if result is BACK or result is None:
        return result

    if not result:
        # Pressing Enter keeps whatever value the field already has.
        return "" if self.default is None else self.default

    if result.lower() in ("clear", "none", "default"):
        return ""

    return result
|
|
93
|
+
|
|
53
94
|
@abstractmethod
|
|
54
95
|
def prompt_user(self, allow_back: bool = False) -> T | None | Any:
|
|
55
96
|
"""Prompt user for input."""
|
|
@@ -75,21 +116,19 @@ class TextField(Field[str]):
|
|
|
75
116
|
|
|
76
117
|
def prompt_user(self, allow_back: bool = False) -> str | None | Any:
    """Prompt user for text input.

    The current value is embedded in the prompt label, so ``default=None`` is
    passed to ``prompt_text_input`` to avoid showing it twice. The raw
    response is normalized by ``_handle_prompt_result``.
    """
    raw = prompt_text_input(
        self._build_prompt_text(),
        default=None,
        validator=self.validator,
        mask=self.mask,
        completions=self.completions,
        allow_back=allow_back,
    )
    return self._handle_prompt_result(raw)
|
|
93
132
|
|
|
94
133
|
|
|
95
134
|
class SelectField(Field[str]):
|
|
@@ -109,8 +148,6 @@ class SelectField(Field[str]):
|
|
|
109
148
|
|
|
110
149
|
def prompt_user(self, allow_back: bool = False) -> str | None | Any:
|
|
111
150
|
"""Prompt user for selection."""
|
|
112
|
-
from data_designer.cli.ui import BACK, select_with_arrows
|
|
113
|
-
|
|
114
151
|
result = select_with_arrows(
|
|
115
152
|
self.options,
|
|
116
153
|
self.prompt,
|
|
@@ -144,6 +181,9 @@ class NumericField(Field[float]):
|
|
|
144
181
|
def range_validator(value: str) -> tuple[bool, str | None]:
|
|
145
182
|
if not value and not required:
|
|
146
183
|
return True, None
|
|
184
|
+
# Allow special keywords to clear the value
|
|
185
|
+
if value and value.lower() in ("clear", "none", "default"):
|
|
186
|
+
return True, None
|
|
147
187
|
if min_value is not None and max_value is not None:
|
|
148
188
|
is_valid, parsed = validate_numeric_range(value, min_value, max_value)
|
|
149
189
|
if not is_valid:
|
|
@@ -163,18 +203,24 @@ class NumericField(Field[float]):
|
|
|
163
203
|
|
|
164
204
|
def prompt_user(self, allow_back: bool = False) -> float | None | Any:
    """Prompt user for numeric input.

    Returns ``BACK``, ``None`` or ``""`` unchanged; otherwise typed text is
    converted with ``float`` and a non-string current value (returned on empty
    input) is passed through as-is.
    """
    # The current value lives in the label, so no default is forwarded (avoids duplicate text).
    raw = prompt_text_input(
        self._build_prompt_text(),
        default=None,
        validator=self.validator,
        allow_back=allow_back,
    )
    handled = self._handle_prompt_result(raw)

    # Navigation markers and cleared/skipped input are not numbers.
    if handled is BACK or handled is None or handled == "":
        return handled

    return float(handled) if isinstance(handled, str) else handled
|