data-designer 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88) hide show
  1. data_designer/_version.py +2 -2
  2. data_designer/cli/README.md +15 -1
  3. data_designer/cli/commands/download.py +56 -0
  4. data_designer/cli/commands/list.py +4 -18
  5. data_designer/cli/controllers/__init__.py +2 -1
  6. data_designer/cli/controllers/download_controller.py +217 -0
  7. data_designer/cli/controllers/model_controller.py +4 -3
  8. data_designer/cli/forms/field.py +65 -19
  9. data_designer/cli/forms/model_builder.py +251 -44
  10. data_designer/cli/main.py +11 -1
  11. data_designer/cli/repositories/persona_repository.py +88 -0
  12. data_designer/cli/services/__init__.py +2 -1
  13. data_designer/cli/services/download_service.py +97 -0
  14. data_designer/cli/ui.py +131 -0
  15. data_designer/cli/utils.py +34 -0
  16. data_designer/config/analysis/__init__.py +2 -0
  17. data_designer/config/analysis/column_profilers.py +75 -7
  18. data_designer/config/analysis/column_statistics.py +192 -48
  19. data_designer/config/analysis/dataset_profiler.py +23 -5
  20. data_designer/config/analysis/utils/reporting.py +3 -3
  21. data_designer/config/base.py +3 -3
  22. data_designer/config/column_configs.py +27 -6
  23. data_designer/config/column_types.py +24 -17
  24. data_designer/config/config_builder.py +34 -26
  25. data_designer/config/data_designer_config.py +7 -7
  26. data_designer/config/datastore.py +6 -6
  27. data_designer/config/default_model_settings.py +27 -34
  28. data_designer/config/exports.py +14 -1
  29. data_designer/config/models.py +155 -29
  30. data_designer/config/preview_results.py +5 -4
  31. data_designer/config/processors.py +109 -4
  32. data_designer/config/sampler_constraints.py +1 -2
  33. data_designer/config/sampler_params.py +31 -31
  34. data_designer/config/seed.py +1 -2
  35. data_designer/config/utils/code_lang.py +4 -5
  36. data_designer/config/utils/constants.py +31 -8
  37. data_designer/config/utils/io_helpers.py +5 -5
  38. data_designer/config/utils/misc.py +1 -4
  39. data_designer/config/utils/numerical_helpers.py +2 -2
  40. data_designer/config/utils/type_helpers.py +3 -3
  41. data_designer/config/utils/validation.py +39 -9
  42. data_designer/config/utils/visualization.py +62 -15
  43. data_designer/config/validator_params.py +4 -8
  44. data_designer/engine/analysis/column_profilers/base.py +0 -7
  45. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +2 -3
  46. data_designer/engine/analysis/column_statistics.py +16 -16
  47. data_designer/engine/analysis/dataset_profiler.py +25 -4
  48. data_designer/engine/analysis/utils/column_statistics_calculations.py +71 -49
  49. data_designer/engine/analysis/utils/judge_score_processing.py +5 -5
  50. data_designer/engine/column_generators/generators/base.py +34 -0
  51. data_designer/engine/column_generators/generators/embedding.py +45 -0
  52. data_designer/engine/column_generators/generators/{llm_generators.py → llm_completion.py} +17 -49
  53. data_designer/engine/column_generators/registry.py +4 -2
  54. data_designer/engine/column_generators/utils/judge_score_factory.py +5 -6
  55. data_designer/engine/configurable_task.py +2 -2
  56. data_designer/engine/dataset_builders/artifact_storage.py +14 -5
  57. data_designer/engine/dataset_builders/column_wise_builder.py +12 -8
  58. data_designer/engine/dataset_builders/utils/concurrency.py +6 -6
  59. data_designer/engine/models/facade.py +66 -9
  60. data_designer/engine/models/litellm_overrides.py +5 -6
  61. data_designer/engine/models/parsers/errors.py +2 -4
  62. data_designer/engine/models/parsers/parser.py +2 -3
  63. data_designer/engine/models/parsers/postprocessors.py +3 -4
  64. data_designer/engine/models/parsers/types.py +4 -4
  65. data_designer/engine/models/registry.py +20 -11
  66. data_designer/engine/models/usage.py +7 -9
  67. data_designer/engine/processing/ginja/ast.py +1 -2
  68. data_designer/engine/processing/processors/drop_columns.py +1 -1
  69. data_designer/engine/processing/processors/registry.py +3 -0
  70. data_designer/engine/processing/processors/schema_transform.py +53 -0
  71. data_designer/engine/processing/utils.py +40 -2
  72. data_designer/engine/registry/base.py +12 -12
  73. data_designer/engine/sampling_gen/constraints.py +1 -2
  74. data_designer/engine/sampling_gen/data_sources/base.py +14 -14
  75. data_designer/engine/sampling_gen/entities/phone_number.py +1 -2
  76. data_designer/engine/sampling_gen/people_gen.py +3 -7
  77. data_designer/engine/validators/base.py +2 -2
  78. data_designer/interface/data_designer.py +12 -0
  79. data_designer/interface/results.py +36 -0
  80. data_designer/logging.py +2 -2
  81. data_designer/plugin_manager.py +3 -3
  82. data_designer/plugins/plugin.py +3 -3
  83. data_designer/plugins/registry.py +2 -2
  84. {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/METADATA +9 -9
  85. {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/RECORD +88 -81
  86. {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/WHEEL +0 -0
  87. {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/entry_points.txt +0 -0
  88. {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/licenses/LICENSE +0 -0
@@ -6,8 +6,13 @@ from typing import Any
6
6
  from data_designer.cli.forms.builder import FormBuilder
7
7
  from data_designer.cli.forms.field import NumericField, SelectField, TextField
8
8
  from data_designer.cli.forms.form import Form
9
- from data_designer.config.models import ModelConfig
10
- from data_designer.config.utils.constants import MAX_TEMPERATURE, MAX_TOP_P, MIN_TEMPERATURE, MIN_TOP_P
9
+ from data_designer.cli.ui import confirm_action, print_error, print_text
10
+ from data_designer.config.models import (
11
+ ChatCompletionInferenceParams,
12
+ EmbeddingInferenceParams,
13
+ GenerationType,
14
+ ModelConfig,
15
+ )
11
16
 
12
17
 
13
18
  class ModelFormBuilder(FormBuilder[ModelConfig]):
@@ -19,7 +24,7 @@ class ModelFormBuilder(FormBuilder[ModelConfig]):
19
24
  self.available_providers = available_providers or []
20
25
 
21
26
  def create_form(self, initial_data: dict[str, Any] | None = None) -> Form:
22
- """Create the model configuration form."""
27
+ """Create the model configuration form with basic fields."""
23
28
  fields = []
24
29
 
25
30
  # Model alias
@@ -29,7 +34,7 @@ class ModelFormBuilder(FormBuilder[ModelConfig]):
29
34
  "Model alias (used in your configs)",
30
35
  default=initial_data.get("alias") if initial_data else None,
31
36
  required=True,
32
- validator=self._validate_alias,
37
+ validator=self.validate_alias,
33
38
  )
34
39
  )
35
40
 
@@ -37,10 +42,10 @@ class ModelFormBuilder(FormBuilder[ModelConfig]):
37
42
  fields.append(
38
43
  TextField(
39
44
  "model",
40
- "Model ID",
45
+ "Model",
41
46
  default=initial_data.get("model") if initial_data else None,
42
47
  required=True,
43
- validator=lambda x: (False, "Model ID is required") if not x else (True, None),
48
+ validator=lambda x: (False, "Model is required") if not x else (True, None),
44
49
  )
45
50
  )
46
51
 
@@ -61,46 +66,222 @@ class ModelFormBuilder(FormBuilder[ModelConfig]):
61
66
  # Single provider - will be set automatically
62
67
  pass
63
68
 
64
- # Inference parameters
65
- fields.extend(
66
- [
69
+ # Generation type
70
+ # Extract from inference_parameters if present (for existing models)
71
+ default_gen_type = GenerationType.CHAT_COMPLETION
72
+ if initial_data:
73
+ inference_params = initial_data.get("inference_parameters", {})
74
+ default_gen_type = inference_params.get("generation_type", default_gen_type)
75
+
76
+ fields.append(
77
+ SelectField(
78
+ "generation_type",
79
+ "Generation type",
80
+ options={
81
+ GenerationType.CHAT_COMPLETION: "Chat completion",
82
+ GenerationType.EMBEDDING: "Embedding",
83
+ },
84
+ default=default_gen_type,
85
+ )
86
+ )
87
+
88
+ return Form(self.title, fields)
89
+
90
+ def create_inference_params_form(
91
+ self, generation_type: GenerationType, initial_params: dict[str, Any] | None = None
92
+ ) -> Form:
93
+ """Create generation-type-specific inference parameters form."""
94
+ initial_params = initial_params or {}
95
+ fields = []
96
+
97
+ if generation_type == GenerationType.CHAT_COMPLETION:
98
+ # Temperature
99
+ fields.append(
67
100
  NumericField(
68
101
  "temperature",
69
- f"Temperature ({MIN_TEMPERATURE}-{MAX_TEMPERATURE})",
70
- default=initial_data.get("inference_parameters", {}).get("temperature", 0.7)
71
- if initial_data
72
- else 0.7,
73
- min_value=MIN_TEMPERATURE,
74
- max_value=MAX_TEMPERATURE,
75
- ),
102
+ "Temperature <dim>(0.0-2.0)</dim>",
103
+ default=initial_params.get("temperature"),
104
+ min_value=0.0,
105
+ max_value=2.0,
106
+ required=False,
107
+ help_text="Higher values make output more random, lower values more deterministic",
108
+ )
109
+ )
110
+
111
+ # Top P
112
+ fields.append(
76
113
  NumericField(
77
114
  "top_p",
78
- f"Top P ({MIN_TOP_P}-{MAX_TOP_P})",
79
- default=initial_data.get("inference_parameters", {}).get("top_p", 0.9) if initial_data else 0.9,
80
- min_value=MIN_TOP_P,
81
- max_value=MAX_TOP_P,
82
- ),
115
+ "Top P <dim>(0.0-1.0)</dim>",
116
+ default=initial_params.get("top_p"),
117
+ min_value=0.0,
118
+ max_value=1.0,
119
+ required=False,
120
+ help_text="Controls diversity via nucleus sampling",
121
+ )
122
+ )
123
+
124
+ # Max tokens
125
+ fields.append(
83
126
  NumericField(
84
127
  "max_tokens",
85
- "Max tokens",
86
- default=initial_data.get("inference_parameters", {}).get("max_tokens", 2048)
87
- if initial_data
88
- else 2048,
89
- min_value=1,
90
- max_value=100000,
91
- ),
92
- ]
93
- )
128
+ "Max tokens <dim>(maximum total tokens including input and output)</dim>",
129
+ default=initial_params.get("max_tokens"),
130
+ min_value=1.0,
131
+ required=False,
132
+ help_text="Maximum number of tokens including both input prompt and generated response",
133
+ )
134
+ )
94
135
 
95
- return Form(self.title, fields)
136
+ # Max parallel requests
137
+ fields.append(
138
+ NumericField(
139
+ "max_parallel_requests",
140
+ "Max parallel requests <dim>(default: 4)</dim>",
141
+ default=initial_params.get("max_parallel_requests", 4),
142
+ min_value=1.0,
143
+ required=False,
144
+ help_text="Maximum number of parallel API requests",
145
+ )
146
+ )
96
147
 
97
- def _validate_alias(self, alias: str) -> tuple[bool, str | None]:
98
- """Validate model alias."""
99
- if not alias:
100
- return False, "Model alias is required"
101
- if alias in self.existing_aliases:
102
- return False, f"Model alias '{alias}' already exists"
103
- return True, None
148
+ # Timeout
149
+ fields.append(
150
+ NumericField(
151
+ "timeout",
152
+ "Timeout in seconds <dim>(optional)</dim>",
153
+ default=initial_params.get("timeout"),
154
+ min_value=1.0,
155
+ required=False,
156
+ help_text="Timeout for each API request in seconds",
157
+ )
158
+ )
159
+
160
+ else: # EMBEDDING
161
+ # Encoding format
162
+ fields.append(
163
+ TextField(
164
+ "encoding_format",
165
+ "Encoding format <dim>(float or base64)</dim>",
166
+ default=initial_params.get("encoding_format"),
167
+ required=False,
168
+ validator=self.validate_encoding_format,
169
+ )
170
+ )
171
+
172
+ # Dimensions
173
+ fields.append(
174
+ NumericField(
175
+ "dimensions",
176
+ "Dimensions <dim>(number of dimensions for embeddings)</dim>",
177
+ default=initial_params.get("dimensions"),
178
+ min_value=1.0,
179
+ required=False,
180
+ help_text="Model-specific dimension size (e.g., 1024, 768)",
181
+ )
182
+ )
183
+
184
+ # Max parallel requests (common field)
185
+ fields.append(
186
+ NumericField(
187
+ "max_parallel_requests",
188
+ "Max parallel requests <dim>(default: 4)</dim>",
189
+ default=initial_params.get("max_parallel_requests", 4),
190
+ min_value=1.0,
191
+ required=False,
192
+ help_text="Maximum number of parallel API requests",
193
+ )
194
+ )
195
+
196
+ # Timeout (common field)
197
+ fields.append(
198
+ NumericField(
199
+ "timeout",
200
+ "Timeout in seconds <dim>(optional)</dim>",
201
+ default=initial_params.get("timeout"),
202
+ min_value=1.0,
203
+ required=False,
204
+ help_text="Timeout for each API request in seconds",
205
+ )
206
+ )
207
+
208
+ return Form(f"{self.title} - Inference Parameters", fields)
209
+
210
+ def build_inference_params(self, generation_type: GenerationType, params_data: dict[str, Any]) -> dict[str, Any]:
211
+ """Build inference parameters dictionary from form data with proper type conversions."""
212
+ inference_params = {}
213
+
214
+ if generation_type == GenerationType.CHAT_COMPLETION:
215
+ if params_data.get("temperature") is not None:
216
+ inference_params["temperature"] = params_data["temperature"]
217
+ if params_data.get("top_p") is not None:
218
+ inference_params["top_p"] = params_data["top_p"]
219
+ if params_data.get("max_tokens") is not None:
220
+ inference_params["max_tokens"] = int(params_data["max_tokens"])
221
+
222
+ else: # EMBEDDING
223
+ # Only include fields with actual values; Pydantic will use defaults for missing fields
224
+ if params_data.get("encoding_format"):
225
+ inference_params["encoding_format"] = params_data["encoding_format"]
226
+ if params_data.get("dimensions"):
227
+ inference_params["dimensions"] = int(params_data["dimensions"])
228
+
229
+ # Common fields for both generation types
230
+ if params_data.get("max_parallel_requests") is not None:
231
+ inference_params["max_parallel_requests"] = int(params_data["max_parallel_requests"])
232
+ if params_data.get("timeout") is not None:
233
+ inference_params["timeout"] = int(params_data["timeout"])
234
+
235
+ return inference_params
236
+
237
+ def run(self, initial_data: dict[str, Any] | None = None) -> ModelConfig | None:
238
+ """Run the interactive form with two-step process for generation-type-specific parameters."""
239
+ # Step 1: Collect basic model configuration
240
+ basic_form = self.create_form(initial_data)
241
+
242
+ if initial_data:
243
+ basic_form.set_values(initial_data)
244
+
245
+ while True:
246
+ basic_result = basic_form.prompt_all(allow_back=True)
247
+
248
+ if basic_result is None:
249
+ if confirm_action("Cancel configuration?", default=False):
250
+ return None
251
+ continue
252
+
253
+ # Step 2: Collect generation-type-specific inference parameters
254
+ generation_type = basic_result.get("generation_type", GenerationType.CHAT_COMPLETION)
255
+ initial_params = initial_data.get("inference_parameters") if initial_data else None
256
+
257
+ # Print message to indicate we're now configuring inference parameters
258
+ gen_type_name = "chat completion" if generation_type == GenerationType.CHAT_COMPLETION else "embedding"
259
+ print_text(
260
+ f"⚙️ Configuring {gen_type_name} inference parameters [dim](Press Enter to keep current value or skip)[/dim]\n"
261
+ )
262
+
263
+ params_form = self.create_inference_params_form(generation_type, initial_params)
264
+
265
+ params_result = params_form.prompt_all(allow_back=True)
266
+
267
+ if params_result is None:
268
+ if confirm_action("Cancel configuration?", default=False):
269
+ return None
270
+ continue
271
+
272
+ # Build inference_parameters dict from individual fields
273
+ inference_params = self.build_inference_params(generation_type, params_result)
274
+
275
+ # Merge results
276
+ full_data = {**basic_result, "inference_parameters": inference_params}
277
+
278
+ try:
279
+ config = self.build_config(full_data)
280
+ return config
281
+ except Exception as e:
282
+ print_error(f"Configuration error: {e}")
283
+ if not confirm_action("Try again?", default=True):
284
+ return None
104
285
 
105
286
  def build_config(self, form_data: dict[str, Any]) -> ModelConfig:
106
287
  """Build ModelConfig from form data."""
@@ -112,14 +293,40 @@ class ModelFormBuilder(FormBuilder[ModelConfig]):
112
293
  else:
113
294
  provider = None
114
295
 
296
+ # Get generation type (from form data, used to determine which inference params to create)
297
+ generation_type = form_data.get("generation_type", GenerationType.CHAT_COMPLETION)
298
+
299
+ # Get inference parameters dict
300
+ inference_params_dict = form_data.get("inference_parameters", {})
301
+
302
+ # Create the appropriate inference parameters type based on generation_type
303
+ # The generation_type will be set automatically by the inference params class
304
+ if generation_type == GenerationType.EMBEDDING:
305
+ inference_params = EmbeddingInferenceParams(**inference_params_dict)
306
+ else:
307
+ inference_params = ChatCompletionInferenceParams(**inference_params_dict)
308
+
115
309
  return ModelConfig(
116
310
  alias=form_data["alias"],
117
311
  model=form_data["model"],
118
312
  provider=provider,
119
- inference_parameters={
120
- "temperature": form_data["temperature"],
121
- "top_p": form_data["top_p"],
122
- "max_tokens": int(form_data["max_tokens"]),
123
- "max_parallel_requests": 4,
124
- },
313
+ inference_parameters=inference_params,
125
314
  )
315
+
316
+ def validate_alias(self, alias: str) -> tuple[bool, str | None]:
317
+ """Validate model alias."""
318
+ if not alias:
319
+ return False, "Model alias is required"
320
+ if alias in self.existing_aliases:
321
+ return False, f"Model alias '{alias}' already exists"
322
+ return True, None
323
+
324
+ def validate_encoding_format(self, value: str) -> tuple[bool, str | None]:
325
+ """Validate encoding format for embedding models."""
326
+ if not value:
327
+ return True, None # Optional field
328
+ if value.lower() in ("clear", "none", "default"):
329
+ return True, None # Allow clearing keywords
330
+ if value not in ("float", "base64"):
331
+ return False, "Must be either 'float' or 'base64'"
332
+ return True, None
data_designer/cli/main.py CHANGED
@@ -3,8 +3,8 @@
3
3
 
4
4
  import typer
5
5
 
6
+ from data_designer.cli.commands import download, models, providers, reset
6
7
  from data_designer.cli.commands import list as list_cmd
7
- from data_designer.cli.commands import models, providers, reset
8
8
  from data_designer.config.default_model_settings import resolve_seed_default_model_settings
9
9
  from data_designer.config.utils.misc import can_run_data_designer_locally
10
10
 
@@ -32,7 +32,17 @@ config_app.command(name="models", help="Configure models interactively")(models.
32
32
  config_app.command(name="list", help="List current configurations")(list_cmd.list_command)
33
33
  config_app.command(name="reset", help="Reset configuration files")(reset.reset_command)
34
34
 
35
+ # Create download command group
36
+ download_app = typer.Typer(
37
+ name="download",
38
+ help="Download assets for Data Designer",
39
+ no_args_is_help=True,
40
+ )
41
+ download_app.command(name="personas", help="Download Nemotron-Persona datasets")(download.personas_command)
42
+
43
+ # Add command groups to main app
35
44
  app.add_typer(config_app, name="config")
45
+ app.add_typer(download_app, name="download")
36
46
 
37
47
 
38
48
  def main() -> None:
@@ -0,0 +1,88 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from pydantic import BaseModel
5
+
6
+ from data_designer.config.utils.constants import (
7
+ NEMOTRON_PERSONAS_DATASET_PREFIX,
8
+ NEMOTRON_PERSONAS_DATASET_SIZES,
9
+ )
10
+
11
+
12
class PersonaLocale(BaseModel):
    """Metadata describing one downloadable persona locale."""

    # Locale code, e.g. 'en_US' or 'ja_JP'.
    code: str
    # Dataset size value taken from NEMOTRON_PERSONAS_DATASET_SIZES.
    size: str
    # Name of the corresponding NGC dataset for this locale.
    dataset_name: str
18
+
19
+
20
class PersonaLocaleRegistry(BaseModel):
    """Container for the full set of known persona locales."""

    # Every locale known to the repository.
    locales: list[PersonaLocale]
    # Common NGC dataset-name prefix shared by all persona datasets.
    dataset_prefix: str = NEMOTRON_PERSONAS_DATASET_PREFIX
25
+
26
+
27
class PersonaRepository:
    """Repository for persona locale metadata.

    This repository provides access to built-in persona locale metadata.
    Unlike ConfigRepository subclasses, this is read-only reference data
    about what's available in NGC, not user configuration.
    """

    def __init__(self) -> None:
        """Build the in-memory registry from the built-in constants."""
        self._registry = self._initialize_registry()

    def _initialize_registry(self) -> PersonaLocaleRegistry:
        """Create one registry entry per (code, size) pair in the constants table."""
        entries: list[PersonaLocale] = []
        for locale_code, locale_size in NEMOTRON_PERSONAS_DATASET_SIZES.items():
            entries.append(
                PersonaLocale(
                    code=locale_code,
                    size=locale_size,
                    # Dataset names are the shared prefix plus the lower-cased code.
                    dataset_name=f"{NEMOTRON_PERSONAS_DATASET_PREFIX}{locale_code.lower()}",
                )
            )
        return PersonaLocaleRegistry(locales=entries)

    def list_all(self) -> list[PersonaLocale]:
        """Get all available persona locales.

        Returns:
            List of all available persona locales
        """
        # Return a copy so callers cannot mutate the registry's own list.
        return list(self._registry.locales)

    def get_by_code(self, code: str) -> PersonaLocale | None:
        """Get a specific locale by code.

        Args:
            code: Locale code (e.g., 'en_US', 'ja_JP')

        Returns:
            PersonaLocale if found, None otherwise
        """
        for candidate in self._registry.locales:
            if candidate.code == code:
                return candidate
        return None

    def get_dataset_name(self, code: str) -> str | None:
        """Get the NGC dataset name for a locale.

        Args:
            code: Locale code (e.g., 'en_US', 'ja_JP')

        Returns:
            Dataset name if locale exists, None otherwise
        """
        match = self.get_by_code(code)
        if match is None:
            return None
        return match.dataset_name

    def get_dataset_prefix(self) -> str:
        """Get the dataset prefix for all persona datasets.

        Returns:
            Dataset prefix string
        """
        return self._registry.dataset_prefix
@@ -1,7 +1,8 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ from data_designer.cli.services.download_service import DownloadService
4
5
  from data_designer.cli.services.model_service import ModelService
5
6
  from data_designer.cli.services.provider_service import ProviderService
6
7
 
7
- __all__ = ["ModelService", "ProviderService"]
8
+ __all__ = ["DownloadService", "ModelService", "ProviderService"]
@@ -0,0 +1,97 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import glob
5
+ import shutil
6
+ import subprocess
7
+ import tempfile
8
+ from pathlib import Path
9
+
10
+ from data_designer.cli.repositories.persona_repository import PersonaRepository
11
+
12
+
13
class DownloadService:
    """Business logic for downloading assets via the NGC CLI.

    Downloads are staged in a temporary directory, then the resulting parquet
    files are moved into ``<config_dir>/managed-assets/datasets``.
    """

    def __init__(self, config_dir: Path, persona_repository: PersonaRepository):
        """Store the configuration directory and the persona metadata source."""
        self.config_dir = config_dir
        self.managed_assets_dir = config_dir / "managed-assets" / "datasets"
        self.persona_repository = persona_repository

    def get_available_locales(self) -> dict[str, str]:
        """Get dictionary of available persona locales (locale code -> locale code)."""
        return {entry.code: entry.code for entry in self.persona_repository.list_all()}

    def download_persona_dataset(self, locale: str) -> Path:
        """Download persona dataset for a specific locale using NGC CLI and move to managed assets.

        Args:
            locale: Locale code (e.g., 'en_US', 'ja_JP')

        Returns:
            Path to the managed assets datasets directory

        Raises:
            ValueError: If locale is invalid
            subprocess.CalledProcessError: If NGC CLI command fails
            FileNotFoundError: If the download produced no parquet files
        """
        locale_obj = self.persona_repository.get_by_code(locale)
        if locale_obj is None:
            raise ValueError(f"Invalid locale: {locale}")

        self.managed_assets_dir.mkdir(parents=True, exist_ok=True)

        # Fetch into a throwaway staging directory, then move only the parquet
        # payload into the managed-assets tree.
        with tempfile.TemporaryDirectory() as staging_dir:
            # No version argument: the NGC CLI resolves the latest version.
            ngc_cmd = [
                "ngc",
                "registry",
                "resource",
                "download-version",
                f"nvidia/nemotron-personas/{locale_obj.dataset_name}",
                "--dest",
                staging_dir,
            ]
            subprocess.run(ngc_cmd, check=True)

            download_pattern = f"{staging_dir}/{locale_obj.dataset_name}*/*.parquet"
            found = glob.glob(download_pattern)
            if not found:
                raise FileNotFoundError(f"No parquet files found matching pattern: {download_pattern}")

            for path_str in found:
                src = Path(path_str)
                # Keep the original file name when relocating into managed assets.
                shutil.move(str(src), str(self.managed_assets_dir / src.name))

        return self.managed_assets_dir

    def get_managed_assets_directory(self) -> Path:
        """Get the directory where managed datasets are stored."""
        return self.managed_assets_dir

    def is_locale_downloaded(self, locale: str) -> bool:
        """Check if a locale has already been downloaded to managed assets.

        Args:
            locale: Locale code to check

        Returns:
            True if the locale dataset exists in managed assets
        """
        if self.persona_repository.get_by_code(locale) is None:
            return False
        if not self.managed_assets_dir.exists():
            return False
        # NOTE(review): this matches '<locale>.parquet' literally, while the
        # download step keeps the NGC-provided file names — confirm the two
        # naming schemes actually agree.
        return bool(glob.glob(str(self.managed_assets_dir / f"{locale}.parquet")))
+ return len(parquet_files) > 0