data-designer 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/__init__.py +15 -0
- data_designer/_version.py +34 -0
- data_designer/cli/README.md +236 -0
- data_designer/cli/__init__.py +6 -0
- data_designer/cli/commands/__init__.py +2 -0
- data_designer/cli/commands/list.py +130 -0
- data_designer/cli/commands/models.py +10 -0
- data_designer/cli/commands/providers.py +11 -0
- data_designer/cli/commands/reset.py +100 -0
- data_designer/cli/controllers/__init__.py +7 -0
- data_designer/cli/controllers/model_controller.py +246 -0
- data_designer/cli/controllers/provider_controller.py +317 -0
- data_designer/cli/forms/__init__.py +20 -0
- data_designer/cli/forms/builder.py +51 -0
- data_designer/cli/forms/field.py +180 -0
- data_designer/cli/forms/form.py +59 -0
- data_designer/cli/forms/model_builder.py +125 -0
- data_designer/cli/forms/provider_builder.py +76 -0
- data_designer/cli/main.py +44 -0
- data_designer/cli/repositories/__init__.py +8 -0
- data_designer/cli/repositories/base.py +39 -0
- data_designer/cli/repositories/model_repository.py +42 -0
- data_designer/cli/repositories/provider_repository.py +43 -0
- data_designer/cli/services/__init__.py +7 -0
- data_designer/cli/services/model_service.py +116 -0
- data_designer/cli/services/provider_service.py +111 -0
- data_designer/cli/ui.py +448 -0
- data_designer/cli/utils.py +47 -0
- data_designer/config/__init__.py +2 -0
- data_designer/config/analysis/column_profilers.py +89 -0
- data_designer/config/analysis/column_statistics.py +274 -0
- data_designer/config/analysis/dataset_profiler.py +60 -0
- data_designer/config/analysis/utils/errors.py +8 -0
- data_designer/config/analysis/utils/reporting.py +188 -0
- data_designer/config/base.py +68 -0
- data_designer/config/column_configs.py +354 -0
- data_designer/config/column_types.py +168 -0
- data_designer/config/config_builder.py +660 -0
- data_designer/config/data_designer_config.py +40 -0
- data_designer/config/dataset_builders.py +11 -0
- data_designer/config/datastore.py +151 -0
- data_designer/config/default_model_settings.py +123 -0
- data_designer/config/errors.py +19 -0
- data_designer/config/interface.py +54 -0
- data_designer/config/models.py +231 -0
- data_designer/config/preview_results.py +32 -0
- data_designer/config/processors.py +41 -0
- data_designer/config/sampler_constraints.py +51 -0
- data_designer/config/sampler_params.py +604 -0
- data_designer/config/seed.py +145 -0
- data_designer/config/utils/code_lang.py +83 -0
- data_designer/config/utils/constants.py +313 -0
- data_designer/config/utils/errors.py +19 -0
- data_designer/config/utils/info.py +88 -0
- data_designer/config/utils/io_helpers.py +273 -0
- data_designer/config/utils/misc.py +81 -0
- data_designer/config/utils/numerical_helpers.py +28 -0
- data_designer/config/utils/type_helpers.py +100 -0
- data_designer/config/utils/validation.py +336 -0
- data_designer/config/utils/visualization.py +427 -0
- data_designer/config/validator_params.py +96 -0
- data_designer/engine/__init__.py +2 -0
- data_designer/engine/analysis/column_profilers/base.py +55 -0
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +160 -0
- data_designer/engine/analysis/column_profilers/registry.py +20 -0
- data_designer/engine/analysis/column_statistics.py +142 -0
- data_designer/engine/analysis/dataset_profiler.py +125 -0
- data_designer/engine/analysis/errors.py +7 -0
- data_designer/engine/analysis/utils/column_statistics_calculations.py +209 -0
- data_designer/engine/analysis/utils/judge_score_processing.py +128 -0
- data_designer/engine/column_generators/__init__.py +2 -0
- data_designer/engine/column_generators/generators/__init__.py +2 -0
- data_designer/engine/column_generators/generators/base.py +61 -0
- data_designer/engine/column_generators/generators/expression.py +63 -0
- data_designer/engine/column_generators/generators/llm_generators.py +172 -0
- data_designer/engine/column_generators/generators/samplers.py +75 -0
- data_designer/engine/column_generators/generators/seed_dataset.py +149 -0
- data_designer/engine/column_generators/generators/validation.py +147 -0
- data_designer/engine/column_generators/registry.py +56 -0
- data_designer/engine/column_generators/utils/errors.py +13 -0
- data_designer/engine/column_generators/utils/judge_score_factory.py +57 -0
- data_designer/engine/column_generators/utils/prompt_renderer.py +98 -0
- data_designer/engine/configurable_task.py +82 -0
- data_designer/engine/dataset_builders/artifact_storage.py +181 -0
- data_designer/engine/dataset_builders/column_wise_builder.py +287 -0
- data_designer/engine/dataset_builders/errors.py +13 -0
- data_designer/engine/dataset_builders/multi_column_configs.py +44 -0
- data_designer/engine/dataset_builders/utils/__init__.py +2 -0
- data_designer/engine/dataset_builders/utils/concurrency.py +184 -0
- data_designer/engine/dataset_builders/utils/config_compiler.py +60 -0
- data_designer/engine/dataset_builders/utils/dag.py +56 -0
- data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +190 -0
- data_designer/engine/dataset_builders/utils/errors.py +13 -0
- data_designer/engine/errors.py +49 -0
- data_designer/engine/model_provider.py +75 -0
- data_designer/engine/models/__init__.py +2 -0
- data_designer/engine/models/errors.py +308 -0
- data_designer/engine/models/facade.py +225 -0
- data_designer/engine/models/litellm_overrides.py +162 -0
- data_designer/engine/models/parsers/__init__.py +2 -0
- data_designer/engine/models/parsers/errors.py +34 -0
- data_designer/engine/models/parsers/parser.py +236 -0
- data_designer/engine/models/parsers/postprocessors.py +93 -0
- data_designer/engine/models/parsers/tag_parsers.py +60 -0
- data_designer/engine/models/parsers/types.py +82 -0
- data_designer/engine/models/recipes/base.py +79 -0
- data_designer/engine/models/recipes/response_recipes.py +291 -0
- data_designer/engine/models/registry.py +118 -0
- data_designer/engine/models/usage.py +75 -0
- data_designer/engine/models/utils.py +38 -0
- data_designer/engine/processing/ginja/__init__.py +2 -0
- data_designer/engine/processing/ginja/ast.py +64 -0
- data_designer/engine/processing/ginja/environment.py +461 -0
- data_designer/engine/processing/ginja/exceptions.py +54 -0
- data_designer/engine/processing/ginja/record.py +30 -0
- data_designer/engine/processing/gsonschema/__init__.py +2 -0
- data_designer/engine/processing/gsonschema/exceptions.py +8 -0
- data_designer/engine/processing/gsonschema/schema_transformers.py +81 -0
- data_designer/engine/processing/gsonschema/types.py +8 -0
- data_designer/engine/processing/gsonschema/validators.py +143 -0
- data_designer/engine/processing/processors/base.py +15 -0
- data_designer/engine/processing/processors/drop_columns.py +46 -0
- data_designer/engine/processing/processors/registry.py +20 -0
- data_designer/engine/processing/utils.py +120 -0
- data_designer/engine/registry/base.py +97 -0
- data_designer/engine/registry/data_designer_registry.py +37 -0
- data_designer/engine/registry/errors.py +10 -0
- data_designer/engine/resources/managed_dataset_generator.py +35 -0
- data_designer/engine/resources/managed_dataset_repository.py +194 -0
- data_designer/engine/resources/managed_storage.py +63 -0
- data_designer/engine/resources/resource_provider.py +46 -0
- data_designer/engine/resources/seed_dataset_data_store.py +66 -0
- data_designer/engine/sampling_gen/column.py +89 -0
- data_designer/engine/sampling_gen/constraints.py +95 -0
- data_designer/engine/sampling_gen/data_sources/base.py +214 -0
- data_designer/engine/sampling_gen/data_sources/errors.py +10 -0
- data_designer/engine/sampling_gen/data_sources/sources.py +342 -0
- data_designer/engine/sampling_gen/entities/__init__.py +2 -0
- data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
- data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +64 -0
- data_designer/engine/sampling_gen/entities/email_address_utils.py +169 -0
- data_designer/engine/sampling_gen/entities/errors.py +8 -0
- data_designer/engine/sampling_gen/entities/national_id_utils.py +100 -0
- data_designer/engine/sampling_gen/entities/person.py +142 -0
- data_designer/engine/sampling_gen/entities/phone_number.py +122 -0
- data_designer/engine/sampling_gen/errors.py +24 -0
- data_designer/engine/sampling_gen/generator.py +121 -0
- data_designer/engine/sampling_gen/jinja_utils.py +60 -0
- data_designer/engine/sampling_gen/people_gen.py +203 -0
- data_designer/engine/sampling_gen/person_constants.py +54 -0
- data_designer/engine/sampling_gen/schema.py +143 -0
- data_designer/engine/sampling_gen/schema_builder.py +59 -0
- data_designer/engine/sampling_gen/utils.py +40 -0
- data_designer/engine/secret_resolver.py +80 -0
- data_designer/engine/validators/__init__.py +17 -0
- data_designer/engine/validators/base.py +36 -0
- data_designer/engine/validators/local_callable.py +34 -0
- data_designer/engine/validators/python.py +245 -0
- data_designer/engine/validators/remote.py +83 -0
- data_designer/engine/validators/sql.py +60 -0
- data_designer/errors.py +5 -0
- data_designer/essentials/__init__.py +137 -0
- data_designer/interface/__init__.py +2 -0
- data_designer/interface/data_designer.py +351 -0
- data_designer/interface/errors.py +16 -0
- data_designer/interface/results.py +55 -0
- data_designer/logging.py +161 -0
- data_designer/plugin_manager.py +83 -0
- data_designer/plugins/__init__.py +6 -0
- data_designer/plugins/errors.py +10 -0
- data_designer/plugins/plugin.py +69 -0
- data_designer/plugins/registry.py +86 -0
- data_designer-0.1.0.dist-info/METADATA +173 -0
- data_designer-0.1.0.dist-info/RECORD +177 -0
- data_designer-0.1.0.dist-info/WHEEL +4 -0
- data_designer-0.1.0.dist-info/entry_points.txt +2 -0
- data_designer-0.1.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,246 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
from data_designer.cli.forms.model_builder import ModelFormBuilder
|
|
7
|
+
from data_designer.cli.repositories.model_repository import ModelRepository
|
|
8
|
+
from data_designer.cli.repositories.provider_repository import ProviderRepository
|
|
9
|
+
from data_designer.cli.services.model_service import ModelService
|
|
10
|
+
from data_designer.cli.services.provider_service import ProviderService
|
|
11
|
+
from data_designer.cli.ui import (
|
|
12
|
+
confirm_action,
|
|
13
|
+
console,
|
|
14
|
+
display_config_preview,
|
|
15
|
+
print_error,
|
|
16
|
+
print_header,
|
|
17
|
+
print_info,
|
|
18
|
+
print_success,
|
|
19
|
+
print_text,
|
|
20
|
+
print_warning,
|
|
21
|
+
select_with_arrows,
|
|
22
|
+
)
|
|
23
|
+
from data_designer.config.models import ModelConfig
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class ModelController:
|
|
27
|
+
"""Controller for model configuration workflows."""
|
|
28
|
+
|
|
29
|
+
def __init__(self, config_dir: Path):
|
|
30
|
+
self.config_dir = config_dir
|
|
31
|
+
self.model_repository = ModelRepository(config_dir)
|
|
32
|
+
self.model_service = ModelService(self.model_repository)
|
|
33
|
+
self.provider_repository = ProviderRepository(config_dir)
|
|
34
|
+
self.provider_service = ProviderService(self.provider_repository)
|
|
35
|
+
|
|
36
|
+
def run(self) -> None:
|
|
37
|
+
"""Main entry point for model configuration."""
|
|
38
|
+
print_header("Configure Models")
|
|
39
|
+
|
|
40
|
+
# Check if providers are configured
|
|
41
|
+
available_providers = self._get_available_providers()
|
|
42
|
+
|
|
43
|
+
if not available_providers:
|
|
44
|
+
print_error("No providers available!")
|
|
45
|
+
print_info("Please run 'data-designer config providers' first")
|
|
46
|
+
return
|
|
47
|
+
|
|
48
|
+
print_info(f"Configuration directory: {self.config_dir}")
|
|
49
|
+
console.print()
|
|
50
|
+
|
|
51
|
+
# Check for existing configuration
|
|
52
|
+
models = self.model_service.list_all()
|
|
53
|
+
|
|
54
|
+
if models:
|
|
55
|
+
self._show_existing_config()
|
|
56
|
+
mode = self._select_mode()
|
|
57
|
+
else:
|
|
58
|
+
print_info("No models configured yet")
|
|
59
|
+
console.print()
|
|
60
|
+
mode = "add"
|
|
61
|
+
|
|
62
|
+
if mode is None:
|
|
63
|
+
print_info("No changes made")
|
|
64
|
+
return
|
|
65
|
+
|
|
66
|
+
# Execute selected mode
|
|
67
|
+
mode_handlers = {
|
|
68
|
+
"add": self._handle_add,
|
|
69
|
+
"update": self._handle_update,
|
|
70
|
+
"delete": self._handle_delete,
|
|
71
|
+
"delete_all": self._handle_delete_all,
|
|
72
|
+
}
|
|
73
|
+
|
|
74
|
+
handler = mode_handlers.get(mode)
|
|
75
|
+
if handler:
|
|
76
|
+
handler(available_providers)
|
|
77
|
+
|
|
78
|
+
def _get_available_providers(self) -> list[str]:
|
|
79
|
+
"""Get list of available providers."""
|
|
80
|
+
return [p.name for p in self.provider_service.list_all()]
|
|
81
|
+
|
|
82
|
+
def _show_existing_config(self) -> None:
|
|
83
|
+
"""Display current configuration."""
|
|
84
|
+
registry = self.model_repository.load()
|
|
85
|
+
if not registry:
|
|
86
|
+
return
|
|
87
|
+
|
|
88
|
+
print_info(f"Found {len(registry.model_configs)} configured model(s)")
|
|
89
|
+
console.print()
|
|
90
|
+
|
|
91
|
+
# Display configuration
|
|
92
|
+
config_dict = registry.model_dump(mode="json", exclude_none=True)
|
|
93
|
+
display_config_preview(config_dict, "Current Configuration")
|
|
94
|
+
console.print()
|
|
95
|
+
|
|
96
|
+
def _select_mode(self) -> str | None:
|
|
97
|
+
"""Prompt user to select operation mode."""
|
|
98
|
+
options = {
|
|
99
|
+
"add": "Add a new model",
|
|
100
|
+
"update": "Update an existing model",
|
|
101
|
+
"delete": "Delete a model",
|
|
102
|
+
"delete_all": "Delete all models",
|
|
103
|
+
"exit": "Exit without changes",
|
|
104
|
+
}
|
|
105
|
+
|
|
106
|
+
result = select_with_arrows(
|
|
107
|
+
options,
|
|
108
|
+
"What would you like to do?",
|
|
109
|
+
default_key="add",
|
|
110
|
+
allow_back=False,
|
|
111
|
+
)
|
|
112
|
+
|
|
113
|
+
return None if result == "exit" or result is None else result
|
|
114
|
+
|
|
115
|
+
def _handle_add(self, available_providers: list[str]) -> None:
|
|
116
|
+
"""Handle adding new models."""
|
|
117
|
+
existing_aliases = {m.alias for m in self.model_service.list_all()}
|
|
118
|
+
|
|
119
|
+
while True:
|
|
120
|
+
# Print message before starting configuration
|
|
121
|
+
console.print()
|
|
122
|
+
print_text("🚀 Starting a new model configuration")
|
|
123
|
+
console.print()
|
|
124
|
+
|
|
125
|
+
# Create builder with current existing aliases
|
|
126
|
+
builder = ModelFormBuilder(existing_aliases, available_providers)
|
|
127
|
+
model = builder.run()
|
|
128
|
+
|
|
129
|
+
if model is None:
|
|
130
|
+
break
|
|
131
|
+
|
|
132
|
+
# Attempt to add
|
|
133
|
+
try:
|
|
134
|
+
self.model_service.add(model)
|
|
135
|
+
print_success(f"Model '{model.alias}' added successfully")
|
|
136
|
+
existing_aliases.add(model.alias)
|
|
137
|
+
except ValueError as e:
|
|
138
|
+
print_error(f"Failed to add model: {e}")
|
|
139
|
+
break
|
|
140
|
+
|
|
141
|
+
# Ask if they want to add more
|
|
142
|
+
if not self._confirm_add_another():
|
|
143
|
+
break
|
|
144
|
+
|
|
145
|
+
def _handle_update(self, available_providers: list[str]) -> None:
|
|
146
|
+
"""Handle updating an existing model."""
|
|
147
|
+
models = self.model_service.list_all()
|
|
148
|
+
if not models:
|
|
149
|
+
print_error("No models to update")
|
|
150
|
+
return
|
|
151
|
+
|
|
152
|
+
# Select model to update
|
|
153
|
+
selected_alias = self._select_model(models, "Select model to update")
|
|
154
|
+
if selected_alias is None:
|
|
155
|
+
return
|
|
156
|
+
|
|
157
|
+
model = self.model_service.get_by_alias(selected_alias)
|
|
158
|
+
if not model:
|
|
159
|
+
print_error(f"Model '{selected_alias}' not found")
|
|
160
|
+
return
|
|
161
|
+
|
|
162
|
+
# Check if model has distribution-based parameters
|
|
163
|
+
if hasattr(model.inference_parameters.temperature, "sample") or hasattr(
|
|
164
|
+
model.inference_parameters.top_p, "sample"
|
|
165
|
+
):
|
|
166
|
+
print_warning(
|
|
167
|
+
"This model uses distribution-based inference parameters, "
|
|
168
|
+
"which cannot be edited via the CLI. Please edit the configuration file directly."
|
|
169
|
+
)
|
|
170
|
+
return
|
|
171
|
+
|
|
172
|
+
# Run builder with existing data
|
|
173
|
+
existing_aliases = {m.alias for m in models if m.alias != selected_alias}
|
|
174
|
+
builder = ModelFormBuilder(existing_aliases, available_providers)
|
|
175
|
+
initial_data = model.model_dump(mode="json", exclude_none=True)
|
|
176
|
+
updated_model = builder.run(initial_data)
|
|
177
|
+
|
|
178
|
+
if updated_model:
|
|
179
|
+
try:
|
|
180
|
+
self.model_service.update(selected_alias, updated_model)
|
|
181
|
+
print_success(f"Model '{updated_model.alias}' updated successfully")
|
|
182
|
+
except ValueError as e:
|
|
183
|
+
print_error(f"Failed to update model: {e}")
|
|
184
|
+
|
|
185
|
+
def _handle_delete(self, available_providers: list[str]) -> None:
|
|
186
|
+
"""Handle deleting a model."""
|
|
187
|
+
models = self.model_service.list_all()
|
|
188
|
+
if not models:
|
|
189
|
+
print_error("No models to delete")
|
|
190
|
+
return
|
|
191
|
+
|
|
192
|
+
# Select model to delete
|
|
193
|
+
selected_alias = self._select_model(models, "Select model to delete")
|
|
194
|
+
if selected_alias is None:
|
|
195
|
+
return
|
|
196
|
+
|
|
197
|
+
# Confirm deletion
|
|
198
|
+
console.print()
|
|
199
|
+
if confirm_action(f"Delete model '{selected_alias}'?", default=False):
|
|
200
|
+
try:
|
|
201
|
+
self.model_service.delete(selected_alias)
|
|
202
|
+
print_success(f"Model '{selected_alias}' deleted successfully")
|
|
203
|
+
except ValueError as e:
|
|
204
|
+
print_error(f"Failed to delete model: {e}")
|
|
205
|
+
|
|
206
|
+
def _handle_delete_all(self, available_providers: list[str]) -> None:
|
|
207
|
+
"""Handle deleting all models."""
|
|
208
|
+
models = self.model_service.list_all()
|
|
209
|
+
if not models:
|
|
210
|
+
print_error("No models to delete")
|
|
211
|
+
return
|
|
212
|
+
|
|
213
|
+
# List models to be deleted
|
|
214
|
+
console.print()
|
|
215
|
+
model_count = len(models)
|
|
216
|
+
model_aliases = ", ".join([f"'{m.alias}'" for m in models])
|
|
217
|
+
|
|
218
|
+
if confirm_action(
|
|
219
|
+
f"⚠️ Delete ALL ({model_count}) model(s): {model_aliases}?\n This action cannot be undone.", default=False
|
|
220
|
+
):
|
|
221
|
+
try:
|
|
222
|
+
# Delete the entire configuration file
|
|
223
|
+
self.model_repository.delete()
|
|
224
|
+
print_success(f"All ({model_count}) model(s) deleted successfully")
|
|
225
|
+
except Exception as e:
|
|
226
|
+
print_error(f"Failed to delete all models: {e}")
|
|
227
|
+
|
|
228
|
+
def _select_model(self, models: list[ModelConfig], prompt: str, default: str | None = None) -> str | None:
|
|
229
|
+
"""Helper to select a model from list."""
|
|
230
|
+
options = {m.alias: f"{m.alias} ({m.model})" for m in models}
|
|
231
|
+
return select_with_arrows(
|
|
232
|
+
options,
|
|
233
|
+
prompt,
|
|
234
|
+
default_key=default or models[0].alias,
|
|
235
|
+
allow_back=False,
|
|
236
|
+
)
|
|
237
|
+
|
|
238
|
+
def _confirm_add_another(self) -> bool:
|
|
239
|
+
"""Ask if user wants to add another model."""
|
|
240
|
+
result = select_with_arrows(
|
|
241
|
+
{"yes": "Add another model", "no": "Finish"},
|
|
242
|
+
"Add another model?",
|
|
243
|
+
default_key="no",
|
|
244
|
+
allow_back=False,
|
|
245
|
+
)
|
|
246
|
+
return result == "yes"
|
|
@@ -0,0 +1,317 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
import copy
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
from data_designer.cli.forms.provider_builder import ProviderFormBuilder
|
|
8
|
+
from data_designer.cli.repositories.model_repository import ModelRepository
|
|
9
|
+
from data_designer.cli.repositories.provider_repository import ProviderRepository
|
|
10
|
+
from data_designer.cli.services.model_service import ModelService
|
|
11
|
+
from data_designer.cli.services.provider_service import ProviderService
|
|
12
|
+
from data_designer.cli.ui import (
|
|
13
|
+
confirm_action,
|
|
14
|
+
console,
|
|
15
|
+
display_config_preview,
|
|
16
|
+
print_error,
|
|
17
|
+
print_header,
|
|
18
|
+
print_info,
|
|
19
|
+
print_success,
|
|
20
|
+
print_warning,
|
|
21
|
+
select_with_arrows,
|
|
22
|
+
)
|
|
23
|
+
from data_designer.engine.model_provider import ModelProvider
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class ProviderController:
    """Interactive CLI workflow for managing model provider configurations.

    Providers are handled through a ``ProviderService`` backed by a
    ``ProviderRepository`` rooted at ``config_dir``. A ``ModelService`` is also
    held so that deleting providers can cascade to the model configs that
    reference them.
    """

    def __init__(self, config_dir: Path):
        self.config_dir = config_dir
        self.repository = ProviderRepository(config_dir)
        self.service = ProviderService(self.repository)
        # Used to find and delete model configs that point at a provider.
        self.model_repository = ModelRepository(config_dir)
        self.model_service = ModelService(self.model_repository)

    def run(self) -> None:
        """Entry point: show the current providers and dispatch the chosen action."""
        print_header("Configure Model Providers")
        print_info(f"Configuration directory: {self.config_dir}")
        console.print()

        providers = self.service.list_all()

        if providers:
            self._show_existing_config()
            mode = self._select_mode()
        else:
            # Nothing to update or delete yet, so go straight to adding.
            print_info("No providers configured yet")
            console.print()
            mode = "add"

        if mode is None:
            print_info("No changes made")
            return

        mode_handlers = {
            "add": self._handle_add,
            "update": self._handle_update,
            "delete": self._handle_delete,
            "delete_all": self._handle_delete_all,
            "change_default": self._handle_change_default,
        }

        handler = mode_handlers.get(mode)
        if handler:
            handler()

    def _show_existing_config(self) -> None:
        """Print the provider count and a preview of the registry with API keys masked."""
        registry = self.repository.load()
        if not registry:
            return

        print_info(f"Found {len(registry.providers)} configured provider(s)")
        console.print()

        config_dict = registry.model_dump(mode="json", exclude_none=True)
        masked_config = self._mask_api_keys(config_dict)
        display_config_preview(masked_config, "Current Configuration")
        console.print()

    def _mask_api_keys(self, config: dict) -> dict:
        """Return a deep copy of ``config`` with literal API keys masked for display.

        Values shaped like environment-variable names (``[A-Z_][A-Z0-9_]*``,
        e.g. ``NVIDIA_API_KEY``) are treated as references and shown unchanged.
        Any other non-empty key becomes ``***`` followed by its last four
        characters when the key is longer than 8 characters, or plain ``***``
        otherwise, so that at most a small fraction of a key is ever revealed.

        Args:
            config: Provider registry dumped to a dict; entries are read from
                its optional ``"providers"`` list.

        Returns:
            A masked deep copy; ``config`` itself is not modified.
        """
        import re

        masked = copy.deepcopy(config)

        for provider in masked.get("providers") or []:
            api_key = provider.get("api_key")
            if not api_key:
                continue
            # ``str.isupper()`` alone would also accept real secrets that are
            # upper-case with digits or dashes (e.g. "NVAPI-ABC123") and print
            # them in full, so require a strict env-var-name shape instead.
            if isinstance(api_key, str) and re.fullmatch(r"[A-Z_][A-Z0-9_]*", api_key):
                continue
            if isinstance(api_key, str) and len(api_key) > 8:
                provider["api_key"] = "***" + api_key[-4:]
            else:
                provider["api_key"] = "***"

        return masked

    def _select_mode(self) -> str | None:
        """Ask which action to perform; ``None`` means exit or no selection.

        The "change default" option is only offered when more than one
        provider is configured.
        """
        options = {
            "add": "Add a new provider",
            "update": "Update an existing provider",
            "delete": "Delete a provider",
            "delete_all": "Delete all providers",
        }

        if len(self.service.list_all()) > 1:
            options["change_default"] = "Change default provider"

        options["exit"] = "Exit without changes"

        result = select_with_arrows(
            options,
            "What would you like to do?",
            default_key="add",
            allow_back=False,
        )

        return None if result == "exit" or result is None else result

    def _handle_add(self, available_providers: None = None) -> None:
        """Run the provider form repeatedly, adding each result, until the user stops.

        The set of taken names is extended after every successful add, so each
        new form receives an up-to-date set. The loop ends when the form is
        cancelled, when the service raises ``ValueError``, or when the user
        picks "Finish".
        """
        existing_names = {p.name for p in self.service.list_all()}

        while True:
            builder = ProviderFormBuilder(existing_names)
            provider = builder.run()

            if provider is None:
                break

            try:
                self.service.add(provider)
                print_success(f"Provider '{provider.name}' added successfully")
                existing_names.add(provider.name)
            except ValueError as e:
                print_error(f"Failed to add provider: {e}")
                break

            if not self._confirm_add_another():
                break

    def _handle_update(self) -> None:
        """Pick an existing provider, re-run the form pre-filled with its values, and save the result."""
        providers = self.service.list_all()
        if not providers:
            print_error("No providers to update")
            return

        selected_name = self._select_provider(providers, "Select provider to update")
        if selected_name is None:
            return

        provider = self.service.get_by_name(selected_name)
        if not provider:
            print_error(f"Provider '{selected_name}' not found")
            return

        # The provider being edited is allowed to keep its own name.
        existing_names = {p.name for p in providers if p.name != selected_name}
        builder = ProviderFormBuilder(existing_names)
        initial_data = provider.model_dump(mode="json", exclude_none=True)
        updated_provider = builder.run(initial_data)

        if updated_provider:
            try:
                self.service.update(selected_name, updated_provider)
                print_success(f"Provider '{updated_provider.name}' updated successfully")
            except ValueError as e:
                print_error(f"Failed to update provider: {e}")

    def _handle_delete(self) -> None:
        """Delete one provider after confirmation, cascading to its model configs.

        Associated models are deleted before the provider. The two deletions
        are separate calls, so a failure in the second can leave the models
        already removed.
        """
        providers = self.service.list_all()
        if not providers:
            print_error("No providers to delete")
            return

        selected_name = self._select_provider(providers, "Select provider to delete")
        if selected_name is None:
            return

        associated_models = self.model_service.find_by_provider(selected_name)

        console.print()

        if associated_models:
            model_count = len(associated_models)
            model_aliases = ", ".join([f"'{m.alias}'" for m in associated_models])

            print_warning(f"Provider '{selected_name}' has {model_count} associated model config(s): {model_aliases}")
            console.print()

            if confirm_action(
                f"Delete provider '{selected_name}' and its {model_count} associated model config(s)?", default=False
            ):
                try:
                    model_aliases_to_delete = [m.alias for m in associated_models]
                    self.model_service.delete_by_aliases(model_aliases_to_delete)

                    self.service.delete(selected_name)

                    print_success(
                        f"Provider '{selected_name}' and {model_count} associated model(s) deleted successfully"
                    )
                except ValueError as e:
                    print_error(f"Failed to delete provider and associated models: {e}")
        else:
            if confirm_action(f"Delete provider '{selected_name}'?", default=False):
                try:
                    self.service.delete(selected_name)
                    print_success(f"Provider '{selected_name}' deleted successfully")
                except ValueError as e:
                    print_error(f"Failed to delete provider: {e}")

    def _handle_delete_all(self) -> None:
        """Delete every provider (and the model configs referencing them) after confirmation.

        Associated models are removed via the model service first; the provider
        configuration is then dropped as a whole through the repository.
        """
        providers = self.service.list_all()
        if not providers:
            print_error("No providers to delete")
            return

        all_models = self.model_service.list_all()
        provider_names_set = {p.name for p in providers}
        associated_models = [m for m in all_models if m.provider in provider_names_set]

        console.print()
        provider_count = len(providers)
        provider_names = ", ".join([f"'{p.name}'" for p in providers])

        if associated_models:
            model_count = len(associated_models)
            print_warning(f"Deleting all providers will also affect {model_count} associated model config(s)")
            console.print()

            if confirm_action(
                f"⚠️ Delete ALL ({provider_count}) provider(s): {provider_names} and {model_count} associated model(s)?\n This action cannot be undone.",
                default=False,
            ):
                try:
                    model_aliases_to_delete = [m.alias for m in associated_models]
                    self.model_service.delete_by_aliases(model_aliases_to_delete)

                    # Drops the entire provider configuration in one call.
                    self.repository.delete()

                    print_success(
                        f"All ({provider_count}) provider(s) and {model_count} associated model(s) deleted successfully"
                    )
                except Exception as e:
                    print_error(f"Failed to delete all providers and associated models: {e}")
        else:
            if confirm_action(
                f"⚠️ Delete ALL ({provider_count}) provider(s): {provider_names}?\n This action cannot be undone.",
                default=False,
            ):
                try:
                    self.repository.delete()
                    print_success(f"All ({provider_count}) provider(s) deleted successfully")
                except Exception as e:
                    print_error(f"Failed to delete all providers: {e}")

    def _handle_change_default(self) -> None:
        """Let the user pick a new default provider and persist it.

        Re-selecting the current default is reported as "No changes made"
        instead of returning silently.
        """
        providers = self.service.list_all()
        current_default = self.service.get_default()

        print_info(f"Current default: {current_default}")
        console.print()

        selected_name = self._select_provider(providers, "Select new default provider", default=current_default)
        if not selected_name:
            return

        if selected_name == current_default:
            print_info("No changes made")
            return

        try:
            self.service.set_default(selected_name)
            print_success(f"Default provider changed to '{selected_name}'")
        except ValueError as e:
            print_error(f"Failed to change default: {e}")

    def _select_provider(self, providers: list[ModelProvider], prompt: str, default: str | None = None) -> str | None:
        """Let the user choose one of ``providers`` and return the chosen name (or ``None``).

        The pre-selected entry is ``default`` when truthy, otherwise the first
        provider; callers must therefore pass a non-empty list.
        """
        options = {p.name: f"{p.name} ({p.endpoint})" for p in providers}
        return select_with_arrows(
            options,
            prompt,
            default_key=default or providers[0].name,
            allow_back=False,
        )

    def _confirm_add_another(self) -> bool:
        """Return ``True`` only when the user chooses to add another provider."""
        result = select_with_arrows(
            {"yes": "Add another provider", "no": "Finish"},
            "Add another provider?",
            default_key="no",
            allow_back=False,
        )
        return result == "yes"
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from data_designer.cli.forms.builder import FormBuilder
|
|
5
|
+
from data_designer.cli.forms.field import Field, NumericField, SelectField, TextField, ValidationError
|
|
6
|
+
from data_designer.cli.forms.form import Form
|
|
7
|
+
from data_designer.cli.forms.model_builder import ModelFormBuilder
|
|
8
|
+
from data_designer.cli.forms.provider_builder import ProviderFormBuilder
|
|
9
|
+
|
|
10
|
+
# Public API of ``data_designer.cli.forms`` (names re-exported from the submodules imported above).
__all__ = [
    "Field", "Form", "FormBuilder", "ModelFormBuilder", "NumericField",
    "ProviderFormBuilder", "SelectField", "TextField", "ValidationError",
]
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from typing import Any, Generic, TypeVar
|
|
6
|
+
|
|
7
|
+
from pydantic import BaseModel
|
|
8
|
+
|
|
9
|
+
from data_designer.cli.forms.form import Form
|
|
10
|
+
from data_designer.cli.ui import confirm_action, print_error
|
|
11
|
+
|
|
12
|
+
# Type of configuration object a concrete builder produces.
T = TypeVar("T", bound=BaseModel)


class FormBuilder(ABC, Generic[T]):
    """Base class for interactive, form-driven configuration builders.

    Subclasses decide which questions to ask (``create_form``) and how to turn
    the answers into a configuration object (``build_config``). ``run`` drives
    the prompt / build / retry loop.
    """

    def __init__(self, title: str):
        self.title = title

    @abstractmethod
    def create_form(self, initial_data: dict[str, Any] | None = None) -> Form:
        """Return the ``Form`` to present, optionally informed by ``initial_data``."""

    @abstractmethod
    def build_config(self, form_data: dict[str, Any]) -> T:
        """Convert the collected ``form_data`` into a configuration object."""

    def run(self, initial_data: dict[str, Any] | None = None) -> T | None:
        """Prompt the user until a configuration is built or the user gives up.

        Args:
            initial_data: Optional existing values; when non-empty they are also
                pushed into the form via ``set_values`` before prompting.

        Returns:
            The object returned by ``build_config``, or ``None`` when the user
            confirms cancellation or declines to retry after a build error.
        """
        form = self.create_form(initial_data)
        if initial_data:
            form.set_values(initial_data)

        while True:
            answers = form.prompt_all(allow_back=True)
            if answers is not None:
                try:
                    return self.build_config(answers)
                except Exception as error:
                    # Any build failure is shown to the user, who may retry the form.
                    print_error(f"Configuration error: {error}")
                    if not confirm_action("Try again?", default=True):
                        return None
            elif confirm_action("Cancel configuration?", default=False):
                # prompt_all returned None (user backed out) and confirmed the cancel.
                return None