data-designer 0.1.5__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/_version.py +2 -2
- data_designer/cli/README.md +15 -1
- data_designer/cli/commands/download.py +56 -0
- data_designer/cli/commands/list.py +4 -18
- data_designer/cli/controllers/__init__.py +2 -1
- data_designer/cli/controllers/download_controller.py +217 -0
- data_designer/cli/controllers/model_controller.py +4 -3
- data_designer/cli/forms/field.py +65 -19
- data_designer/cli/forms/model_builder.py +251 -44
- data_designer/cli/main.py +11 -1
- data_designer/cli/repositories/persona_repository.py +88 -0
- data_designer/cli/services/__init__.py +2 -1
- data_designer/cli/services/download_service.py +97 -0
- data_designer/cli/ui.py +131 -0
- data_designer/cli/utils.py +34 -0
- data_designer/config/analysis/__init__.py +2 -0
- data_designer/config/analysis/column_profilers.py +75 -7
- data_designer/config/analysis/column_statistics.py +192 -48
- data_designer/config/analysis/dataset_profiler.py +23 -5
- data_designer/config/analysis/utils/reporting.py +3 -3
- data_designer/config/base.py +3 -3
- data_designer/config/column_configs.py +27 -6
- data_designer/config/column_types.py +24 -17
- data_designer/config/config_builder.py +34 -26
- data_designer/config/data_designer_config.py +7 -7
- data_designer/config/datastore.py +6 -6
- data_designer/config/default_model_settings.py +27 -34
- data_designer/config/exports.py +8 -0
- data_designer/config/models.py +155 -29
- data_designer/config/preview_results.py +6 -8
- data_designer/config/processors.py +63 -2
- data_designer/config/sampler_constraints.py +1 -2
- data_designer/config/sampler_params.py +31 -31
- data_designer/config/seed.py +1 -2
- data_designer/config/utils/code_lang.py +4 -5
- data_designer/config/utils/constants.py +31 -8
- data_designer/config/utils/io_helpers.py +5 -5
- data_designer/config/utils/misc.py +1 -4
- data_designer/config/utils/numerical_helpers.py +2 -2
- data_designer/config/utils/type_helpers.py +3 -3
- data_designer/config/utils/validation.py +7 -8
- data_designer/config/utils/visualization.py +32 -17
- data_designer/config/validator_params.py +4 -8
- data_designer/engine/analysis/column_profilers/base.py +0 -7
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +2 -3
- data_designer/engine/analysis/column_statistics.py +16 -16
- data_designer/engine/analysis/dataset_profiler.py +25 -4
- data_designer/engine/analysis/utils/column_statistics_calculations.py +71 -49
- data_designer/engine/analysis/utils/judge_score_processing.py +5 -5
- data_designer/engine/column_generators/generators/base.py +34 -0
- data_designer/engine/column_generators/generators/embedding.py +45 -0
- data_designer/engine/column_generators/generators/{llm_generators.py → llm_completion.py} +17 -49
- data_designer/engine/column_generators/registry.py +4 -2
- data_designer/engine/column_generators/utils/judge_score_factory.py +5 -6
- data_designer/engine/configurable_task.py +2 -2
- data_designer/engine/dataset_builders/artifact_storage.py +1 -2
- data_designer/engine/dataset_builders/column_wise_builder.py +11 -10
- data_designer/engine/dataset_builders/utils/concurrency.py +6 -6
- data_designer/engine/models/facade.py +66 -9
- data_designer/engine/models/litellm_overrides.py +5 -6
- data_designer/engine/models/parsers/errors.py +2 -4
- data_designer/engine/models/parsers/parser.py +2 -3
- data_designer/engine/models/parsers/postprocessors.py +3 -4
- data_designer/engine/models/parsers/types.py +4 -4
- data_designer/engine/models/registry.py +20 -11
- data_designer/engine/models/usage.py +7 -9
- data_designer/engine/processing/ginja/ast.py +1 -2
- data_designer/engine/processing/utils.py +40 -2
- data_designer/engine/registry/base.py +12 -12
- data_designer/engine/sampling_gen/constraints.py +1 -2
- data_designer/engine/sampling_gen/data_sources/base.py +14 -14
- data_designer/engine/sampling_gen/entities/phone_number.py +1 -2
- data_designer/engine/sampling_gen/people_gen.py +3 -7
- data_designer/engine/validators/base.py +2 -2
- data_designer/logging.py +2 -2
- data_designer/plugin_manager.py +3 -3
- data_designer/plugins/plugin.py +3 -3
- data_designer/plugins/registry.py +2 -2
- {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/METADATA +1 -1
- {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/RECORD +83 -77
- {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/WHEEL +0 -0
- {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/entry_points.txt +0 -0
- {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -6,8 +6,13 @@ from typing import Any
|
|
|
6
6
|
from data_designer.cli.forms.builder import FormBuilder
|
|
7
7
|
from data_designer.cli.forms.field import NumericField, SelectField, TextField
|
|
8
8
|
from data_designer.cli.forms.form import Form
|
|
9
|
-
from data_designer.
|
|
10
|
-
from data_designer.config.
|
|
9
|
+
from data_designer.cli.ui import confirm_action, print_error, print_text
|
|
10
|
+
from data_designer.config.models import (
|
|
11
|
+
ChatCompletionInferenceParams,
|
|
12
|
+
EmbeddingInferenceParams,
|
|
13
|
+
GenerationType,
|
|
14
|
+
ModelConfig,
|
|
15
|
+
)
|
|
11
16
|
|
|
12
17
|
|
|
13
18
|
class ModelFormBuilder(FormBuilder[ModelConfig]):
|
|
@@ -19,7 +24,7 @@ class ModelFormBuilder(FormBuilder[ModelConfig]):
|
|
|
19
24
|
self.available_providers = available_providers or []
|
|
20
25
|
|
|
21
26
|
def create_form(self, initial_data: dict[str, Any] | None = None) -> Form:
|
|
22
|
-
"""Create the model configuration form."""
|
|
27
|
+
"""Create the model configuration form with basic fields."""
|
|
23
28
|
fields = []
|
|
24
29
|
|
|
25
30
|
# Model alias
|
|
@@ -29,7 +34,7 @@ class ModelFormBuilder(FormBuilder[ModelConfig]):
|
|
|
29
34
|
"Model alias (used in your configs)",
|
|
30
35
|
default=initial_data.get("alias") if initial_data else None,
|
|
31
36
|
required=True,
|
|
32
|
-
validator=self.
|
|
37
|
+
validator=self.validate_alias,
|
|
33
38
|
)
|
|
34
39
|
)
|
|
35
40
|
|
|
@@ -37,10 +42,10 @@ class ModelFormBuilder(FormBuilder[ModelConfig]):
|
|
|
37
42
|
fields.append(
|
|
38
43
|
TextField(
|
|
39
44
|
"model",
|
|
40
|
-
"Model
|
|
45
|
+
"Model",
|
|
41
46
|
default=initial_data.get("model") if initial_data else None,
|
|
42
47
|
required=True,
|
|
43
|
-
validator=lambda x: (False, "Model
|
|
48
|
+
validator=lambda x: (False, "Model is required") if not x else (True, None),
|
|
44
49
|
)
|
|
45
50
|
)
|
|
46
51
|
|
|
@@ -61,46 +66,222 @@ class ModelFormBuilder(FormBuilder[ModelConfig]):
|
|
|
61
66
|
# Single provider - will be set automatically
|
|
62
67
|
pass
|
|
63
68
|
|
|
64
|
-
#
|
|
65
|
-
|
|
66
|
-
|
|
69
|
+
# Generation type
|
|
70
|
+
# Extract from inference_parameters if present (for existing models)
|
|
71
|
+
default_gen_type = GenerationType.CHAT_COMPLETION
|
|
72
|
+
if initial_data:
|
|
73
|
+
inference_params = initial_data.get("inference_parameters", {})
|
|
74
|
+
default_gen_type = inference_params.get("generation_type", default_gen_type)
|
|
75
|
+
|
|
76
|
+
fields.append(
|
|
77
|
+
SelectField(
|
|
78
|
+
"generation_type",
|
|
79
|
+
"Generation type",
|
|
80
|
+
options={
|
|
81
|
+
GenerationType.CHAT_COMPLETION: "Chat completion",
|
|
82
|
+
GenerationType.EMBEDDING: "Embedding",
|
|
83
|
+
},
|
|
84
|
+
default=default_gen_type,
|
|
85
|
+
)
|
|
86
|
+
)
|
|
87
|
+
|
|
88
|
+
return Form(self.title, fields)
|
|
89
|
+
|
|
90
|
+
def create_inference_params_form(
    self, generation_type: GenerationType, initial_params: dict[str, Any] | None = None
) -> Form:
    """Create generation-type-specific inference parameters form.

    Chat-completion models get temperature, top-p and max-token fields.
    Any other generation type (embedding) gets encoding-format and dimension
    fields. Every generation type then gets the shared max-parallel-requests
    and timeout fields, in that order.

    Args:
        generation_type: Generation type chosen in the basic model form.
        initial_params: Existing inference parameters whose values are used as
            field defaults; None means no existing values.

    Returns:
        A Form titled "<title> - Inference Parameters" with the fields above.
    """
    initial_params = initial_params or {}
    fields = []

    if generation_type == GenerationType.CHAT_COMPLETION:
        # Temperature
        fields.append(
            NumericField(
                "temperature",
                "Temperature <dim>(0.0-2.0)</dim>",
                default=initial_params.get("temperature"),
                min_value=0.0,
                max_value=2.0,
                required=False,
                help_text="Higher values make output more random, lower values more deterministic",
            )
        )

        # Top P
        fields.append(
            NumericField(
                "top_p",
                "Top P <dim>(0.0-1.0)</dim>",
                default=initial_params.get("top_p"),
                min_value=0.0,
                max_value=1.0,
                required=False,
                help_text="Controls diversity via nucleus sampling",
            )
        )

        # Max tokens
        fields.append(
            NumericField(
                "max_tokens",
                "Max tokens <dim>(maximum total tokens including input and output)</dim>",
                default=initial_params.get("max_tokens"),
                min_value=1.0,
                required=False,
                help_text="Maximum number of tokens including both input prompt and generated response",
            )
        )

    else:  # EMBEDDING
        # Encoding format
        fields.append(
            TextField(
                "encoding_format",
                "Encoding format <dim>(float or base64)</dim>",
                default=initial_params.get("encoding_format"),
                required=False,
                validator=self.validate_encoding_format,
            )
        )

        # Dimensions
        fields.append(
            NumericField(
                "dimensions",
                "Dimensions <dim>(number of dimensions for embeddings)</dim>",
                default=initial_params.get("dimensions"),
                min_value=1.0,
                required=False,
                help_text="Model-specific dimension size (e.g., 1024, 768)",
            )
        )

    # Fields common to every generation type. They are defined once here instead
    # of being copied into each branch above, so the two branches cannot drift apart.
    fields.append(
        NumericField(
            "max_parallel_requests",
            "Max parallel requests <dim>(default: 4)</dim>",
            default=initial_params.get("max_parallel_requests", 4),
            min_value=1.0,
            required=False,
            help_text="Maximum number of parallel API requests",
        )
    )
    fields.append(
        NumericField(
            "timeout",
            "Timeout in seconds <dim>(optional)</dim>",
            default=initial_params.get("timeout"),
            min_value=1.0,
            required=False,
            help_text="Timeout for each API request in seconds",
        )
    )

    return Form(f"{self.title} - Inference Parameters", fields)
|
|
209
|
+
|
|
210
|
+
def build_inference_params(self, generation_type: GenerationType, params_data: dict[str, Any]) -> dict[str, Any]:
    """Translate inference-parameter form answers into a parameters dict.

    Only answered fields are copied, so the inference-parameter model's own
    defaults apply to anything left blank. ``max_tokens``, ``dimensions``,
    ``max_parallel_requests`` and ``timeout`` are converted with ``int()``.

    Args:
        generation_type: Generation type the parameters belong to.
        params_data: Raw values collected by the inference-parameters form.

    Returns:
        Dict containing only the parameters that were provided.
    """
    collected: dict[str, Any] = {}

    def _take(key: str, convert: type | None = None, *, truthy_only: bool = False) -> None:
        # Chat fields keep any non-None value (so 0.0 survives); embedding
        # fields are skipped for every falsy value, e.g. an empty string.
        value = params_data.get(key)
        keep = bool(value) if truthy_only else value is not None
        if keep:
            collected[key] = value if convert is None else convert(value)

    if generation_type == GenerationType.CHAT_COMPLETION:
        _take("temperature")
        _take("top_p")
        _take("max_tokens", int)
    else:  # EMBEDDING
        _take("encoding_format", truthy_only=True)
        _take("dimensions", int, truthy_only=True)

    # Shared by both generation types.
    _take("max_parallel_requests", int)
    _take("timeout", int)

    return collected
|
|
236
|
+
|
|
237
|
+
def run(self, initial_data: dict[str, Any] | None = None) -> ModelConfig | None:
    """Interactively collect a model configuration in two stages.

    Stage one prompts for the basic model settings, including the generation
    type. Stage two prompts for the inference parameters that match that
    generation type. Backing out of either stage asks whether to cancel; if the
    user declines, the flow restarts at stage one. If building the config fails,
    the error is shown and the user may retry.

    Args:
        initial_data: Existing configuration values used to pre-populate the
            forms, or None for a new model.

    Returns:
        The built ModelConfig, or None if the user cancelled or declined to retry.
    """
    basic_form = self.create_form(initial_data)
    if initial_data:
        basic_form.set_values(initial_data)
    existing_params = initial_data.get("inference_parameters") if initial_data else None

    while True:
        # Stage 1: basic model settings.
        basic_result = basic_form.prompt_all(allow_back=True)
        if basic_result is None:
            if confirm_action("Cancel configuration?", default=False):
                return None
            continue

        # Stage 2: generation-type-specific inference parameters.
        generation_type = basic_result.get("generation_type", GenerationType.CHAT_COMPLETION)
        if generation_type == GenerationType.CHAT_COMPLETION:
            gen_type_name = "chat completion"
        else:
            gen_type_name = "embedding"
        print_text(
            f"⚙️ Configuring {gen_type_name} inference parameters [dim](Press Enter to keep current value or skip)[/dim]\n"
        )

        params_form = self.create_inference_params_form(generation_type, existing_params)
        params_result = params_form.prompt_all(allow_back=True)
        if params_result is None:
            if confirm_action("Cancel configuration?", default=False):
                return None
            continue

        full_data = {
            **basic_result,
            "inference_parameters": self.build_inference_params(generation_type, params_result),
        }
        try:
            return self.build_config(full_data)
        except Exception as e:
            print_error(f"Configuration error: {e}")
            if not confirm_action("Try again?", default=True):
                return None
|
|
104
285
|
|
|
105
286
|
def build_config(self, form_data: dict[str, Any]) -> ModelConfig:
|
|
106
287
|
"""Build ModelConfig from form data."""
|
|
@@ -112,14 +293,40 @@ class ModelFormBuilder(FormBuilder[ModelConfig]):
|
|
|
112
293
|
else:
|
|
113
294
|
provider = None
|
|
114
295
|
|
|
296
|
+
# Get generation type (from form data, used to determine which inference params to create)
|
|
297
|
+
generation_type = form_data.get("generation_type", GenerationType.CHAT_COMPLETION)
|
|
298
|
+
|
|
299
|
+
# Get inference parameters dict
|
|
300
|
+
inference_params_dict = form_data.get("inference_parameters", {})
|
|
301
|
+
|
|
302
|
+
# Create the appropriate inference parameters type based on generation_type
|
|
303
|
+
# The generation_type will be set automatically by the inference params class
|
|
304
|
+
if generation_type == GenerationType.EMBEDDING:
|
|
305
|
+
inference_params = EmbeddingInferenceParams(**inference_params_dict)
|
|
306
|
+
else:
|
|
307
|
+
inference_params = ChatCompletionInferenceParams(**inference_params_dict)
|
|
308
|
+
|
|
115
309
|
return ModelConfig(
|
|
116
310
|
alias=form_data["alias"],
|
|
117
311
|
model=form_data["model"],
|
|
118
312
|
provider=provider,
|
|
119
|
-
inference_parameters=
|
|
120
|
-
"temperature": form_data["temperature"],
|
|
121
|
-
"top_p": form_data["top_p"],
|
|
122
|
-
"max_tokens": int(form_data["max_tokens"]),
|
|
123
|
-
"max_parallel_requests": 4,
|
|
124
|
-
},
|
|
313
|
+
inference_parameters=inference_params,
|
|
125
314
|
)
|
|
315
|
+
|
|
316
|
+
def validate_alias(self, alias: str) -> tuple[bool, str | None]:
|
|
317
|
+
"""Validate model alias."""
|
|
318
|
+
if not alias:
|
|
319
|
+
return False, "Model alias is required"
|
|
320
|
+
if alias in self.existing_aliases:
|
|
321
|
+
return False, f"Model alias '{alias}' already exists"
|
|
322
|
+
return True, None
|
|
323
|
+
|
|
324
|
+
def validate_encoding_format(self, value: str) -> tuple[bool, str | None]:
|
|
325
|
+
"""Validate encoding format for embedding models."""
|
|
326
|
+
if not value:
|
|
327
|
+
return True, None # Optional field
|
|
328
|
+
if value.lower() in ("clear", "none", "default"):
|
|
329
|
+
return True, None # Allow clearing keywords
|
|
330
|
+
if value not in ("float", "base64"):
|
|
331
|
+
return False, "Must be either 'float' or 'base64'"
|
|
332
|
+
return True, None
|
data_designer/cli/main.py
CHANGED
|
@@ -3,8 +3,8 @@
|
|
|
3
3
|
|
|
4
4
|
import typer
|
|
5
5
|
|
|
6
|
+
from data_designer.cli.commands import download, models, providers, reset
|
|
6
7
|
from data_designer.cli.commands import list as list_cmd
|
|
7
|
-
from data_designer.cli.commands import models, providers, reset
|
|
8
8
|
from data_designer.config.default_model_settings import resolve_seed_default_model_settings
|
|
9
9
|
from data_designer.config.utils.misc import can_run_data_designer_locally
|
|
10
10
|
|
|
@@ -32,7 +32,17 @@ config_app.command(name="models", help="Configure models interactively")(models.
|
|
|
32
32
|
config_app.command(name="list", help="List current configurations")(list_cmd.list_command)
|
|
33
33
|
config_app.command(name="reset", help="Reset configuration files")(reset.reset_command)
|
|
34
34
|
|
|
35
|
+
# Create download command group.
# It is mounted on the main app below under the name "download"; no_args_is_help=True
# makes a bare `download` invocation print help instead of running anything.
download_app = typer.Typer(
    name="download",
    help="Download assets for Data Designer",
    no_args_is_help=True,
)
# Register download.personas_command as the `download personas` subcommand.
download_app.command(name="personas", help="Download Nemotron-Persona datasets")(download.personas_command)

# Add command groups to main app
app.add_typer(config_app, name="config")
app.add_typer(download_app, name="download")
|
|
36
46
|
|
|
37
47
|
|
|
38
48
|
def main() -> None:
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from pydantic import BaseModel
|
|
5
|
+
|
|
6
|
+
from data_designer.config.utils.constants import (
|
|
7
|
+
NEMOTRON_PERSONAS_DATASET_PREFIX,
|
|
8
|
+
NEMOTRON_PERSONAS_DATASET_SIZES,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class PersonaLocale(BaseModel):
    """Metadata for a single persona locale."""

    # Locale code, e.g. 'en_US' or 'ja_JP'.
    code: str
    # Size value copied from NEMOTRON_PERSONAS_DATASET_SIZES for this code (format not visible here).
    size: str
    # NGC dataset name; PersonaRepository builds it as NEMOTRON_PERSONAS_DATASET_PREFIX + code.lower().
    dataset_name: str
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class PersonaLocaleRegistry(BaseModel):
    """Registry for available persona locales.

    Holds the list of known locales together with the prefix of their
    NGC dataset names.
    """

    # Every known locale.
    locales: list[PersonaLocale]
    # Prefix of the persona dataset names; defaults to NEMOTRON_PERSONAS_DATASET_PREFIX.
    dataset_prefix: str = NEMOTRON_PERSONAS_DATASET_PREFIX
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class PersonaRepository:
    """Read-only catalogue of the built-in persona locales.

    The locale list is built once, at construction time, from
    ``NEMOTRON_PERSONAS_DATASET_SIZES``. Unlike the ConfigRepository
    subclasses, it holds reference data about what is available in NGC rather
    than user configuration.
    """

    def __init__(self) -> None:
        """Build the locale registry from the package constants."""
        self._registry = self._initialize_registry()

    def _initialize_registry(self) -> PersonaLocaleRegistry:
        """Create one PersonaLocale per entry of NEMOTRON_PERSONAS_DATASET_SIZES."""
        prefix = NEMOTRON_PERSONAS_DATASET_PREFIX
        return PersonaLocaleRegistry(
            locales=[
                PersonaLocale(code=code, size=size, dataset_name=f"{prefix}{code.lower()}")
                for code, size in NEMOTRON_PERSONAS_DATASET_SIZES.items()
            ]
        )

    def list_all(self) -> list[PersonaLocale]:
        """Return every available persona locale.

        Returns:
            A new list, so callers can modify it without affecting the registry.
        """
        return self._registry.locales[:]

    def get_by_code(self, code: str) -> PersonaLocale | None:
        """Look up a locale by its code.

        Args:
            code: Locale code (e.g., 'en_US', 'ja_JP')

        Returns:
            The first PersonaLocale whose code matches, or None if none does.
        """
        for candidate in self._registry.locales:
            if candidate.code == code:
                return candidate
        return None

    def get_dataset_name(self, code: str) -> str | None:
        """Return the NGC dataset name for a locale.

        Args:
            code: Locale code (e.g., 'en_US', 'ja_JP')

        Returns:
            The dataset name, or None if the locale is unknown.
        """
        match = self.get_by_code(code)
        if not match:
            return None
        return match.dataset_name

    def get_dataset_prefix(self) -> str:
        """Return the prefix shared by the persona dataset names in the registry.

        Returns:
            The registry's dataset prefix string.
        """
        return self._registry.dataset_prefix
|
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
from data_designer.cli.services.download_service import DownloadService
|
|
4
5
|
from data_designer.cli.services.model_service import ModelService
|
|
5
6
|
from data_designer.cli.services.provider_service import ProviderService
|
|
6
7
|
|
|
7
|
-
__all__ = ["ModelService", "ProviderService"]
|
|
8
|
+
__all__ = ["DownloadService", "ModelService", "ProviderService"]
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
import glob
|
|
5
|
+
import shutil
|
|
6
|
+
import subprocess
|
|
7
|
+
import tempfile
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
|
|
10
|
+
from data_designer.cli.repositories.persona_repository import PersonaRepository
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class DownloadService:
    """Business logic for downloading assets via NGC CLI.

    Persona datasets are downloaded with the ``ngc`` command-line tool into a
    temporary directory. Their parquet files are then moved into
    ``<config_dir>/managed-assets/datasets``.
    """

    def __init__(self, config_dir: Path, persona_repository: "PersonaRepository") -> None:
        self.config_dir = config_dir
        self.managed_assets_dir = config_dir / "managed-assets" / "datasets"
        self.persona_repository = persona_repository

    def get_available_locales(self) -> dict[str, str]:
        """Get dictionary of available persona locales (locale code -> locale code)."""
        return {locale.code: locale.code for locale in self.persona_repository.list_all()}

    def download_persona_dataset(self, locale: str) -> Path:
        """Download persona dataset for a specific locale using NGC CLI and move to managed assets.

        Args:
            locale: Locale code (e.g., 'en_US', 'ja_JP')

        Returns:
            Path to the managed assets datasets directory

        Raises:
            ValueError: If locale is invalid
            subprocess.CalledProcessError: If NGC CLI command fails
            FileNotFoundError: If the ``ngc`` executable cannot be found, or the
                download produced no parquet files
        """
        locale_obj = self.persona_repository.get_by_code(locale)
        if not locale_obj:
            raise ValueError(f"Invalid locale: {locale}")

        self.managed_assets_dir.mkdir(parents=True, exist_ok=True)

        # Use temporary directory for download
        with tempfile.TemporaryDirectory() as temp_dir:
            # Run NGC CLI download command (without version to get latest)
            cmd = [
                "ngc",
                "registry",
                "resource",
                "download-version",
                f"nvidia/nemotron-personas/{locale_obj.dataset_name}",
                "--dest",
                temp_dir,
            ]
            subprocess.run(cmd, check=True)

            # Human-readable pattern for the error message. The real search escapes
            # both literal parts so glob metacharacters (e.g. '[' in TMPDIR) are not
            # interpreted as pattern syntax.
            download_pattern = f"{temp_dir}/{locale_obj.dataset_name}*/*.parquet"
            search_pattern = f"{glob.escape(temp_dir)}/{glob.escape(locale_obj.dataset_name)}*/*.parquet"
            parquet_files = sorted(glob.glob(search_pattern))

            if not parquet_files:
                raise FileNotFoundError(f"No parquet files found matching pattern: {download_pattern}")

            # Move each parquet file to managed assets
            for parquet_file in parquet_files:
                source = Path(parquet_file)
                shutil.move(str(source), str(self.managed_assets_dir / source.name))

        return self.managed_assets_dir

    def get_managed_assets_directory(self) -> Path:
        """Get the directory where managed datasets are stored."""
        return self.managed_assets_dir

    def is_locale_downloaded(self, locale: str) -> bool:
        """Check if a locale has already been downloaded to managed assets.

        Args:
            locale: Locale code to check

        Returns:
            True if the locale is known and ``<managed_assets_dir>/<locale>.parquet``
            exists as a file
        """
        locale_obj = self.persona_repository.get_by_code(locale)
        if not locale_obj:
            return False

        # Test the exact path instead of globbing: a glob built from the config
        # directory would misread metacharacters such as '[' in that path and
        # report a present file as missing.
        return (self.managed_assets_dir / f"{locale}.parquet").is_file()
|