data-designer 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177) hide show
  1. data_designer/__init__.py +15 -0
  2. data_designer/_version.py +34 -0
  3. data_designer/cli/README.md +236 -0
  4. data_designer/cli/__init__.py +6 -0
  5. data_designer/cli/commands/__init__.py +2 -0
  6. data_designer/cli/commands/list.py +130 -0
  7. data_designer/cli/commands/models.py +10 -0
  8. data_designer/cli/commands/providers.py +11 -0
  9. data_designer/cli/commands/reset.py +100 -0
  10. data_designer/cli/controllers/__init__.py +7 -0
  11. data_designer/cli/controllers/model_controller.py +246 -0
  12. data_designer/cli/controllers/provider_controller.py +317 -0
  13. data_designer/cli/forms/__init__.py +20 -0
  14. data_designer/cli/forms/builder.py +51 -0
  15. data_designer/cli/forms/field.py +180 -0
  16. data_designer/cli/forms/form.py +59 -0
  17. data_designer/cli/forms/model_builder.py +125 -0
  18. data_designer/cli/forms/provider_builder.py +76 -0
  19. data_designer/cli/main.py +44 -0
  20. data_designer/cli/repositories/__init__.py +8 -0
  21. data_designer/cli/repositories/base.py +39 -0
  22. data_designer/cli/repositories/model_repository.py +42 -0
  23. data_designer/cli/repositories/provider_repository.py +43 -0
  24. data_designer/cli/services/__init__.py +7 -0
  25. data_designer/cli/services/model_service.py +116 -0
  26. data_designer/cli/services/provider_service.py +111 -0
  27. data_designer/cli/ui.py +448 -0
  28. data_designer/cli/utils.py +47 -0
  29. data_designer/config/__init__.py +2 -0
  30. data_designer/config/analysis/column_profilers.py +89 -0
  31. data_designer/config/analysis/column_statistics.py +274 -0
  32. data_designer/config/analysis/dataset_profiler.py +60 -0
  33. data_designer/config/analysis/utils/errors.py +8 -0
  34. data_designer/config/analysis/utils/reporting.py +188 -0
  35. data_designer/config/base.py +68 -0
  36. data_designer/config/column_configs.py +354 -0
  37. data_designer/config/column_types.py +168 -0
  38. data_designer/config/config_builder.py +660 -0
  39. data_designer/config/data_designer_config.py +40 -0
  40. data_designer/config/dataset_builders.py +11 -0
  41. data_designer/config/datastore.py +151 -0
  42. data_designer/config/default_model_settings.py +123 -0
  43. data_designer/config/errors.py +19 -0
  44. data_designer/config/interface.py +54 -0
  45. data_designer/config/models.py +231 -0
  46. data_designer/config/preview_results.py +32 -0
  47. data_designer/config/processors.py +41 -0
  48. data_designer/config/sampler_constraints.py +51 -0
  49. data_designer/config/sampler_params.py +604 -0
  50. data_designer/config/seed.py +145 -0
  51. data_designer/config/utils/code_lang.py +83 -0
  52. data_designer/config/utils/constants.py +313 -0
  53. data_designer/config/utils/errors.py +19 -0
  54. data_designer/config/utils/info.py +88 -0
  55. data_designer/config/utils/io_helpers.py +273 -0
  56. data_designer/config/utils/misc.py +81 -0
  57. data_designer/config/utils/numerical_helpers.py +28 -0
  58. data_designer/config/utils/type_helpers.py +100 -0
  59. data_designer/config/utils/validation.py +336 -0
  60. data_designer/config/utils/visualization.py +427 -0
  61. data_designer/config/validator_params.py +96 -0
  62. data_designer/engine/__init__.py +2 -0
  63. data_designer/engine/analysis/column_profilers/base.py +55 -0
  64. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +160 -0
  65. data_designer/engine/analysis/column_profilers/registry.py +20 -0
  66. data_designer/engine/analysis/column_statistics.py +142 -0
  67. data_designer/engine/analysis/dataset_profiler.py +125 -0
  68. data_designer/engine/analysis/errors.py +7 -0
  69. data_designer/engine/analysis/utils/column_statistics_calculations.py +209 -0
  70. data_designer/engine/analysis/utils/judge_score_processing.py +128 -0
  71. data_designer/engine/column_generators/__init__.py +2 -0
  72. data_designer/engine/column_generators/generators/__init__.py +2 -0
  73. data_designer/engine/column_generators/generators/base.py +61 -0
  74. data_designer/engine/column_generators/generators/expression.py +63 -0
  75. data_designer/engine/column_generators/generators/llm_generators.py +172 -0
  76. data_designer/engine/column_generators/generators/samplers.py +75 -0
  77. data_designer/engine/column_generators/generators/seed_dataset.py +149 -0
  78. data_designer/engine/column_generators/generators/validation.py +147 -0
  79. data_designer/engine/column_generators/registry.py +56 -0
  80. data_designer/engine/column_generators/utils/errors.py +13 -0
  81. data_designer/engine/column_generators/utils/judge_score_factory.py +57 -0
  82. data_designer/engine/column_generators/utils/prompt_renderer.py +98 -0
  83. data_designer/engine/configurable_task.py +82 -0
  84. data_designer/engine/dataset_builders/artifact_storage.py +181 -0
  85. data_designer/engine/dataset_builders/column_wise_builder.py +287 -0
  86. data_designer/engine/dataset_builders/errors.py +13 -0
  87. data_designer/engine/dataset_builders/multi_column_configs.py +44 -0
  88. data_designer/engine/dataset_builders/utils/__init__.py +2 -0
  89. data_designer/engine/dataset_builders/utils/concurrency.py +184 -0
  90. data_designer/engine/dataset_builders/utils/config_compiler.py +60 -0
  91. data_designer/engine/dataset_builders/utils/dag.py +56 -0
  92. data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +190 -0
  93. data_designer/engine/dataset_builders/utils/errors.py +13 -0
  94. data_designer/engine/errors.py +49 -0
  95. data_designer/engine/model_provider.py +75 -0
  96. data_designer/engine/models/__init__.py +2 -0
  97. data_designer/engine/models/errors.py +308 -0
  98. data_designer/engine/models/facade.py +225 -0
  99. data_designer/engine/models/litellm_overrides.py +162 -0
  100. data_designer/engine/models/parsers/__init__.py +2 -0
  101. data_designer/engine/models/parsers/errors.py +34 -0
  102. data_designer/engine/models/parsers/parser.py +236 -0
  103. data_designer/engine/models/parsers/postprocessors.py +93 -0
  104. data_designer/engine/models/parsers/tag_parsers.py +60 -0
  105. data_designer/engine/models/parsers/types.py +82 -0
  106. data_designer/engine/models/recipes/base.py +79 -0
  107. data_designer/engine/models/recipes/response_recipes.py +291 -0
  108. data_designer/engine/models/registry.py +118 -0
  109. data_designer/engine/models/usage.py +75 -0
  110. data_designer/engine/models/utils.py +38 -0
  111. data_designer/engine/processing/ginja/__init__.py +2 -0
  112. data_designer/engine/processing/ginja/ast.py +64 -0
  113. data_designer/engine/processing/ginja/environment.py +461 -0
  114. data_designer/engine/processing/ginja/exceptions.py +54 -0
  115. data_designer/engine/processing/ginja/record.py +30 -0
  116. data_designer/engine/processing/gsonschema/__init__.py +2 -0
  117. data_designer/engine/processing/gsonschema/exceptions.py +8 -0
  118. data_designer/engine/processing/gsonschema/schema_transformers.py +81 -0
  119. data_designer/engine/processing/gsonschema/types.py +8 -0
  120. data_designer/engine/processing/gsonschema/validators.py +143 -0
  121. data_designer/engine/processing/processors/base.py +15 -0
  122. data_designer/engine/processing/processors/drop_columns.py +46 -0
  123. data_designer/engine/processing/processors/registry.py +20 -0
  124. data_designer/engine/processing/utils.py +120 -0
  125. data_designer/engine/registry/base.py +97 -0
  126. data_designer/engine/registry/data_designer_registry.py +37 -0
  127. data_designer/engine/registry/errors.py +10 -0
  128. data_designer/engine/resources/managed_dataset_generator.py +35 -0
  129. data_designer/engine/resources/managed_dataset_repository.py +194 -0
  130. data_designer/engine/resources/managed_storage.py +63 -0
  131. data_designer/engine/resources/resource_provider.py +46 -0
  132. data_designer/engine/resources/seed_dataset_data_store.py +66 -0
  133. data_designer/engine/sampling_gen/column.py +89 -0
  134. data_designer/engine/sampling_gen/constraints.py +95 -0
  135. data_designer/engine/sampling_gen/data_sources/base.py +214 -0
  136. data_designer/engine/sampling_gen/data_sources/errors.py +10 -0
  137. data_designer/engine/sampling_gen/data_sources/sources.py +342 -0
  138. data_designer/engine/sampling_gen/entities/__init__.py +2 -0
  139. data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
  140. data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +64 -0
  141. data_designer/engine/sampling_gen/entities/email_address_utils.py +169 -0
  142. data_designer/engine/sampling_gen/entities/errors.py +8 -0
  143. data_designer/engine/sampling_gen/entities/national_id_utils.py +100 -0
  144. data_designer/engine/sampling_gen/entities/person.py +142 -0
  145. data_designer/engine/sampling_gen/entities/phone_number.py +122 -0
  146. data_designer/engine/sampling_gen/errors.py +24 -0
  147. data_designer/engine/sampling_gen/generator.py +121 -0
  148. data_designer/engine/sampling_gen/jinja_utils.py +60 -0
  149. data_designer/engine/sampling_gen/people_gen.py +203 -0
  150. data_designer/engine/sampling_gen/person_constants.py +54 -0
  151. data_designer/engine/sampling_gen/schema.py +143 -0
  152. data_designer/engine/sampling_gen/schema_builder.py +59 -0
  153. data_designer/engine/sampling_gen/utils.py +40 -0
  154. data_designer/engine/secret_resolver.py +80 -0
  155. data_designer/engine/validators/__init__.py +17 -0
  156. data_designer/engine/validators/base.py +36 -0
  157. data_designer/engine/validators/local_callable.py +34 -0
  158. data_designer/engine/validators/python.py +245 -0
  159. data_designer/engine/validators/remote.py +83 -0
  160. data_designer/engine/validators/sql.py +60 -0
  161. data_designer/errors.py +5 -0
  162. data_designer/essentials/__init__.py +137 -0
  163. data_designer/interface/__init__.py +2 -0
  164. data_designer/interface/data_designer.py +351 -0
  165. data_designer/interface/errors.py +16 -0
  166. data_designer/interface/results.py +55 -0
  167. data_designer/logging.py +161 -0
  168. data_designer/plugin_manager.py +83 -0
  169. data_designer/plugins/__init__.py +6 -0
  170. data_designer/plugins/errors.py +10 -0
  171. data_designer/plugins/plugin.py +69 -0
  172. data_designer/plugins/registry.py +86 -0
  173. data_designer-0.1.0.dist-info/METADATA +173 -0
  174. data_designer-0.1.0.dist-info/RECORD +177 -0
  175. data_designer-0.1.0.dist-info/WHEEL +4 -0
  176. data_designer-0.1.0.dist-info/entry_points.txt +2 -0
  177. data_designer-0.1.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,180 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from abc import ABC, abstractmethod
5
+ from collections.abc import Callable
6
+ from typing import Any, Generic, TypeVar
7
+
8
+ from data_designer.cli.utils import validate_numeric_range
9
+
10
+ T = TypeVar("T")
11
+
12
+
13
class ValidationError(Exception):
    """Field validation error.

    Raised by the ``Field.value`` setter when the field's validator reports
    the assigned value as invalid.
    """
15
+
16
+
17
class Field(ABC, Generic[T]):
    """Abstract base for a single form field.

    A field keeps its name, prompt text, default, required flag, optional
    validator and help text, together with the value most recently assigned
    to it. Concrete subclasses implement ``prompt_user`` to collect input.
    """

    def __init__(
        self,
        name: str,
        prompt: str,
        default: T | None = None,
        required: bool = True,
        validator: Callable[[str], tuple[bool, str | None]] | None = None,
        help_text: str | None = None,
    ):
        """Record the field settings; the stored value starts out as ``None``."""
        self.name, self.prompt = name, prompt
        self.default, self.required = default, required
        self.validator, self.help_text = validator, help_text
        self._value: T | None = None

    @property
    def value(self) -> T | None:
        """Return the value most recently accepted by the setter."""
        return self._value

    @value.setter
    def value(self, val: T) -> None:
        """Validate ``val`` when a validator is configured, then store it.

        Raises:
            ValidationError: If the validator reports the value as invalid.
        """
        check = self.validator
        if check:
            # Validators operate on strings, so other values are stringified first.
            candidate = val if isinstance(val, str) else str(val)
            accepted, message = check(candidate)
            if not accepted:
                raise ValidationError(message or "Invalid value")
        self._value = val

    @abstractmethod
    def prompt_user(self, allow_back: bool = False) -> T | None | Any:
        """Ask the user for this field's input and return what was entered."""
56
+
57
+
58
class TextField(Field[str]):
    """Free-text form field with optional completions and masked input."""

    def __init__(
        self,
        name: str,
        prompt: str,
        default: str | None = None,
        required: bool = True,
        validator: Callable[[str], tuple[bool, str | None]] | None = None,
        completions: list[str] | None = None,
        mask: bool = False,
        help_text: str | None = None,
    ):
        """Store the text-specific options on top of the common field settings."""
        super().__init__(
            name,
            prompt,
            default=default,
            required=required,
            validator=validator,
            help_text=help_text,
        )
        self.mask = mask
        self.completions = completions

    def prompt_user(self, allow_back: bool = False) -> str | None | Any:
        """Ask for text via ``prompt_text_input`` and return its result.

        The ``BACK`` sentinel from the UI module is passed through unchanged.
        """
        # Imported lazily, inside the method, exactly as the module did before.
        from data_designer.cli.ui import BACK, prompt_text_input

        answer = prompt_text_input(
            self.prompt,
            default=self.default,
            validator=self.validator,
            mask=self.mask,
            completions=self.completions,
            allow_back=allow_back,
        )
        return BACK if answer is BACK else answer
93
+
94
+
95
class SelectField(Field[str]):
    """Form field that lets the user pick one entry from a fixed set of options."""

    def __init__(
        self,
        name: str,
        prompt: str,
        options: dict[str, str],
        default: str | None = None,
        required: bool = True,
        help_text: str | None = None,
    ):
        """Store the options; selection fields are created without a validator."""
        super().__init__(
            name,
            prompt,
            default=default,
            required=required,
            validator=None,
            help_text=help_text,
        )
        self.options = options

    def prompt_user(self, allow_back: bool = False) -> str | None | Any:
        """Show the options via ``select_with_arrows`` and return its result.

        The ``BACK`` sentinel from the UI module is passed through unchanged.
        """
        # Imported lazily, inside the method, exactly as the module did before.
        from data_designer.cli.ui import BACK, select_with_arrows

        choice = select_with_arrows(
            self.options,
            self.prompt,
            default_key=self.default,
            allow_back=allow_back,
        )
        return BACK if choice is BACK else choice
125
+
126
+
127
class NumericField(Field[float]):
    """Numeric input field with optional lower and upper bounds."""

    def __init__(
        self,
        name: str,
        prompt: str,
        default: float | None = None,
        min_value: float | None = None,
        max_value: float | None = None,
        required: bool = True,
        help_text: str | None = None,
    ):
        """Build a range validator from the bounds and initialise the field.

        Args:
            name: Field name, used as the key for the field's value.
            prompt: Text shown to the user.
            default: Default value offered at the prompt.
            min_value: Inclusive lower bound, or ``None`` for no lower bound.
            max_value: Inclusive upper bound, or ``None`` for no upper bound.
            required: When false, empty input passes validation.
            help_text: Optional help text stored on the field.
        """
        import math

        self.min_value = min_value
        self.max_value = max_value

        def range_validator(value: str) -> tuple[bool, str | None]:
            """Return ``(is_valid, error_message)`` for the raw input string."""
            if not value and not required:
                return True, None
            if min_value is not None and max_value is not None:
                # With both bounds set the check is delegated to the shared helper;
                # only its validity flag is needed here.
                is_valid, _ = validate_numeric_range(value, min_value, max_value)
                if not is_valid:
                    return False, f"Value must be between {min_value} and {max_value}"
                return True, None
            try:
                num = float(value)
            except ValueError:
                return False, "Must be a valid number"
            # float() accepts "nan" and "inf". NaN compares false against every
            # bound and would slip through, so non-finite input is rejected.
            if not math.isfinite(num):
                return False, "Must be a valid number"
            if min_value is not None and num < min_value:
                return False, f"Value must be >= {min_value}"
            if max_value is not None and num > max_value:
                return False, f"Value must be <= {max_value}"
            return True, None

        super().__init__(name, prompt, default, required, range_validator, help_text)

    def prompt_user(self, allow_back: bool = False) -> float | None | Any:
        """Prompt for a number and return it as ``float``.

        Returns the ``BACK`` sentinel unchanged, and ``None`` when the prompt
        yields an empty or falsy result.
        """
        from data_designer.cli.ui import BACK, prompt_text_input

        # The text prompt works with strings, so the numeric default is stringified.
        default_str = str(self.default) if self.default is not None else None

        result = prompt_text_input(
            self.prompt,
            default=default_str,
            validator=self.validator,
            allow_back=allow_back,
        )

        if result is BACK:
            return BACK

        return float(result) if result else None
@@ -0,0 +1,59 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Any
5
+
6
+ from data_designer.cli.forms.field import Field
7
+ from data_designer.cli.ui import BACK, print_error
8
+
9
+
10
class Form:
    """An ordered group of fields that is prompted and read as one unit."""

    def __init__(self, name: str, fields: list[Field]):
        """Keep the fields in order and index them by name for lookup."""
        self.name = name
        self.fields = fields
        self._field_map = dict((entry.name, entry) for entry in fields)

    def get_field(self, name: str) -> Field | None:
        """Return the field registered under ``name``, or ``None`` if absent."""
        return self._field_map.get(name)

    def get_values(self) -> dict[str, Any]:
        """Return ``{field name: value}`` for every field whose value is not ``None``."""
        collected: dict[str, Any] = {}
        for entry in self.fields:
            if entry.value is not None:
                collected[entry.name] = entry.value
        return collected

    def set_values(self, values: dict[str, Any]) -> None:
        """Assign each given value to the field of the same name.

        Names that do not match a field are ignored.
        """
        for key, new_value in values.items():
            target = self.get_field(key)
            if not target:
                continue
            target.value = new_value

    def prompt_all(self, allow_back: bool = True) -> dict[str, Any] | None:
        """Prompt for every field in order, supporting back navigation.

        Returns the collected values, or ``None`` as soon as any prompt
        returns ``None``.
        """
        position = 0

        while position < len(self.fields):
            current = self.fields[position]
            # Going back is never offered on the first field.
            answer = current.prompt_user(allow_back=allow_back and position > 0)

            if answer is None:
                # User cancelled
                return None

            if answer is BACK:
                # Step back one field, never below the first one.
                position = max(position - 1, 0)
                continue

            try:
                current.value = answer
            except Exception as exc:
                # A rejected value is reported and the same field is asked again.
                print_error(f"Validation error: {exc}")
                continue
            position += 1

        return self.get_values()
@@ -0,0 +1,125 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Any
5
+
6
+ from data_designer.cli.forms.builder import FormBuilder
7
+ from data_designer.cli.forms.field import NumericField, SelectField, TextField
8
+ from data_designer.cli.forms.form import Form
9
+ from data_designer.config.models import ModelConfig
10
+ from data_designer.config.utils.constants import MAX_TEMPERATURE, MAX_TOP_P, MIN_TEMPERATURE, MIN_TOP_P
11
+
12
+
13
class ModelFormBuilder(FormBuilder[ModelConfig]):
    """Builds interactive forms for model configuration."""

    def __init__(self, existing_aliases: set[str] | None = None, available_providers: list[str] | None = None):
        """Remember the aliases already in use and the providers that can be chosen."""
        super().__init__("Model Configuration")
        self.existing_aliases = existing_aliases or set()
        self.available_providers = available_providers or []

    def create_form(self, initial_data: dict[str, Any] | None = None) -> Form:
        """Create the model configuration form.

        Args:
            initial_data: Optional existing values used as prompt defaults.
                Its ``"inference_parameters"`` entry, when present, supplies
                the temperature, top_p and max_tokens defaults.

        Returns:
            A form with alias, model ID, an optional provider selection and
            the three inference parameters.
        """
        initial = initial_data or {}
        # The key may be present with a None value; treat that like "absent"
        # instead of failing on None.get(...).
        inference_params = initial.get("inference_parameters") or {}

        fields = [
            # Model alias
            TextField(
                "alias",
                "Model alias (used in your configs)",
                default=initial.get("alias"),
                required=True,
                validator=self._validate_alias,
            ),
            # Model ID
            TextField(
                "model",
                "Model ID",
                default=initial.get("model"),
                required=True,
                validator=lambda x: (False, "Model ID is required") if not x else (True, None),
            ),
        ]

        # Provider is only asked for when there is a real choice; a single
        # available provider is assigned automatically in build_config.
        if len(self.available_providers) > 1:
            provider_options = {p: p for p in self.available_providers}
            default_provider = initial.get("provider")
            # Fall back to the first provider when the stored one is missing
            # or is not among the selectable options.
            if default_provider not in provider_options:
                default_provider = self.available_providers[0]
            fields.append(
                SelectField(
                    "provider",
                    "Select provider for this model",
                    options=provider_options,
                    default=default_provider,
                )
            )

        # Inference parameters
        fields.extend(
            [
                NumericField(
                    "temperature",
                    f"Temperature ({MIN_TEMPERATURE}-{MAX_TEMPERATURE})",
                    default=inference_params.get("temperature", 0.7),
                    min_value=MIN_TEMPERATURE,
                    max_value=MAX_TEMPERATURE,
                ),
                NumericField(
                    "top_p",
                    f"Top P ({MIN_TOP_P}-{MAX_TOP_P})",
                    default=inference_params.get("top_p", 0.9),
                    min_value=MIN_TOP_P,
                    max_value=MAX_TOP_P,
                ),
                NumericField(
                    "max_tokens",
                    "Max tokens",
                    default=inference_params.get("max_tokens", 2048),
                    min_value=1,
                    max_value=100000,
                ),
            ]
        )

        return Form(self.title, fields)

    def _validate_alias(self, alias: str) -> tuple[bool, str | None]:
        """Validate model alias: it must be non-empty and not already in use."""
        if not alias:
            return False, "Model alias is required"
        if alias in self.existing_aliases:
            return False, f"Model alias '{alias}' already exists"
        return True, None

    def build_config(self, form_data: dict[str, Any]) -> ModelConfig:
        """Build ModelConfig from form data.

        The provider comes from the form when it was asked for, otherwise from
        the single available provider, otherwise it is ``None``.
        ``max_parallel_requests`` is always set to 4.

        Raises:
            KeyError: If ``form_data`` lacks alias, model, temperature, top_p
                or max_tokens.
        """
        # Determine provider
        if "provider" in form_data:
            provider = form_data["provider"]
        elif len(self.available_providers) == 1:
            provider = self.available_providers[0]
        else:
            provider = None

        return ModelConfig(
            alias=form_data["alias"],
            model=form_data["model"],
            provider=provider,
            inference_parameters={
                "temperature": form_data["temperature"],
                "top_p": form_data["top_p"],
                # The numeric field yields a float; the token limit is stored as int.
                "max_tokens": int(form_data["max_tokens"]),
                "max_parallel_requests": 4,
            },
        )
@@ -0,0 +1,76 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Any
5
+
6
+ from data_designer.cli.forms.builder import FormBuilder
7
+ from data_designer.cli.forms.field import TextField
8
+ from data_designer.cli.forms.form import Form
9
+ from data_designer.cli.utils import validate_url
10
+ from data_designer.engine.model_provider import ModelProvider
11
+
12
+
13
class ProviderFormBuilder(FormBuilder[ModelProvider]):
    """Creates the interactive form used to configure a model provider."""

    def __init__(self, existing_names: set[str] | None = None):
        """Remember the provider names that are already taken."""
        super().__init__("Provider Configuration")
        self.existing_names = existing_names or set()

    def create_form(self, initial_data: dict[str, Any] | None = None) -> Form:
        """Assemble the provider form, using ``initial_data`` for prompt defaults."""
        seed = initial_data or {}

        name_field = TextField(
            "name",
            "Provider name",
            default=seed.get("name"),
            required=True,
            validator=self._validate_name,
        )
        endpoint_field = TextField(
            "endpoint",
            "API endpoint URL",
            default=seed.get("endpoint"),
            required=True,
            validator=self._validate_endpoint,
        )
        type_field = TextField(
            "provider_type",
            "Provider type",
            default=seed.get("provider_type", "openai"),
            required=True,
        )
        # The API key is the only optional entry of the form.
        key_field = TextField(
            "api_key",
            "API key or environment variable name",
            default=seed.get("api_key"),
            required=False,
        )

        return Form(self.title, [name_field, endpoint_field, type_field, key_field])

    def _validate_name(self, name: str) -> tuple[bool, str | None]:
        """Accept a provider name that is non-empty and not already used."""
        problem: str | None = None
        if not name:
            problem = "Provider name is required"
        elif name in self.existing_names:
            problem = f"Provider '{name}' already exists"
        return problem is None, problem

    def _validate_endpoint(self, endpoint: str) -> tuple[bool, str | None]:
        """Accept an endpoint that is non-empty and passes ``validate_url``."""
        problem: str | None = None
        if not endpoint:
            problem = "Endpoint URL is required"
        elif not validate_url(endpoint):
            problem = "Invalid URL format (must start with http:// or https://)"
        return problem is None, problem

    def build_config(self, form_data: dict[str, Any]) -> ModelProvider:
        """Turn collected form values into a ``ModelProvider``.

        ``api_key`` is optional and falls back to ``None``; the other three
        keys must be present in ``form_data``.
        """
        settings = {
            "name": form_data["name"],
            "endpoint": form_data["endpoint"],
            "provider_type": form_data["provider_type"],
            "api_key": form_data.get("api_key"),
        }
        return ModelProvider(**settings)
@@ -0,0 +1,44 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import typer
5
+
6
+ from data_designer.cli.commands import list as list_cmd
7
+ from data_designer.cli.commands import models, providers, reset
8
+ from data_designer.config.default_model_settings import resolve_seed_default_model_settings
9
+ from data_designer.config.utils.misc import can_run_data_designer_locally
10
+
11
# Resolve default model settings on import to ensure they are available when the library is used.
# NOTE: this is an import-time side effect; it only runs when
# can_run_data_designer_locally() returns a truthy value.
if can_run_data_designer_locally():
    resolve_seed_default_model_settings()

# Initialize Typer app with custom configuration
# (root application object; `main()` invokes it).
app = typer.Typer(
    name="data-designer",
    help="Data Designer CLI - Configure model providers and models for synthetic data generation",
    add_completion=False,
    no_args_is_help=True,
    rich_markup_mode="rich",
)

# Create config subcommand group
config_app = typer.Typer(
    name="config",
    help="Manage configuration files",
    no_args_is_help=True,
)
# Register the command functions from the imported command modules on the group.
# `command(...)` is used in its call form: it returns a decorator that is applied
# immediately to the already-defined function.
config_app.command(name="providers", help="Configure model providers interactively")(providers.providers_command)
config_app.command(name="models", help="Configure models interactively")(models.models_command)
config_app.command(name="list", help="List current configurations")(list_cmd.list_command)
config_app.command(name="reset", help="Reset configuration files")(reset.reset_command)

# Attach the group to the root app under the name "config".
app.add_typer(config_app, name="config")
36
+
37
+
38
def main() -> None:
    """Main entry point for the CLI.

    Invokes the module-level Typer ``app``.
    """
    app()


# Run the CLI when this module is executed as a script rather than imported.
if __name__ == "__main__":
    main()
@@ -0,0 +1,8 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from data_designer.cli.repositories.base import ConfigRepository
5
+ from data_designer.cli.repositories.model_repository import ModelRepository
6
+ from data_designer.cli.repositories.provider_repository import ProviderRepository
7
+
8
+ __all__ = ["ConfigRepository", "ModelRepository", "ProviderRepository"]
@@ -0,0 +1,39 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from abc import ABC, abstractmethod
5
+ from pathlib import Path
6
+ from typing import Generic, TypeVar
7
+
8
+ from pydantic import BaseModel
9
+
10
+ T = TypeVar("T", bound=BaseModel)
11
+
12
+
13
class ConfigRepository(ABC, Generic[T]):
    """Abstract base for configuration persistence.

    A repository is bound to a configuration directory. Subclasses name the
    file inside it and implement loading and saving.
    """

    def __init__(self, config_dir: Path):
        """Remember the directory that holds the configuration file."""
        self.config_dir = config_dir

    @property
    @abstractmethod
    def config_file(self) -> Path:
        """Get the configuration file path."""

    @abstractmethod
    def load(self) -> T | None:
        """Load configuration from file."""

    @abstractmethod
    def save(self, config: T) -> None:
        """Save configuration to file."""

    def exists(self) -> bool:
        """Check if configuration file exists."""
        return self.config_file.exists()

    def delete(self) -> None:
        """Delete configuration file; a missing file is not an error."""
        # A single unlink call avoids the window between an existence check
        # and the removal, in which another process could delete the file.
        self.config_file.unlink(missing_ok=True)
@@ -0,0 +1,42 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from pathlib import Path
5
+
6
+ from pydantic import BaseModel
7
+
8
+ from data_designer.cli.repositories.base import ConfigRepository
9
+ from data_designer.config.models import ModelConfig
10
+ from data_designer.config.utils.constants import MODEL_CONFIGS_FILE_NAME
11
+ from data_designer.config.utils.io_helpers import load_config_file, save_config_file
12
+
13
+
14
class ModelConfigRegistry(BaseModel):
    """Registry for model configurations.

    Pydantic model that ``ModelRepository`` validates loaded configuration
    data against and dumps when saving.
    """

    # Model configurations held by the registry; required (no default).
    model_configs: list[ModelConfig]
18
+
19
+
20
class ModelRepository(ConfigRepository[ModelConfigRegistry]):
    """Repository for model configurations."""

    @property
    def config_file(self) -> Path:
        """Get the model configuration file path."""
        return self.config_dir / MODEL_CONFIGS_FILE_NAME

    def load(self) -> ModelConfigRegistry | None:
        """Load model configuration from file.

        Returns:
            The validated registry, or ``None`` when the file does not exist
            or cannot be read or validated. A failed load is logged as a
            warning so it can be told apart from a missing file.
        """
        if not self.exists():
            return None

        try:
            config_dict = load_config_file(self.config_file)
            return ModelConfigRegistry.model_validate(config_dict)
        except Exception as exc:
            # Loading stays best-effort, but the failure is reported: otherwise
            # an unreadable file looks like "nothing configured" and a later
            # save() would overwrite it without any hint.
            import logging

            logging.getLogger(__name__).warning(
                "Failed to load model configuration from %s: %s", self.config_file, exc
            )
            return None

    def save(self, config: ModelConfigRegistry) -> None:
        """Save model configuration to file.

        The registry is dumped in JSON mode with ``None`` values omitted.
        """
        config_dict = config.model_dump(mode="json", exclude_none=True)
        save_config_file(self.config_file, config_dict)
@@ -0,0 +1,43 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from pathlib import Path
5
+
6
+ from pydantic import BaseModel
7
+
8
+ from data_designer.cli.repositories.base import ConfigRepository
9
+ from data_designer.config.models import ModelProvider
10
+ from data_designer.config.utils.constants import MODEL_PROVIDERS_FILE_NAME
11
+ from data_designer.config.utils.io_helpers import load_config_file, save_config_file
12
+
13
+
14
class ModelProviderRegistry(BaseModel):
    """Registry for model provider configurations.

    Pydantic model that ``ProviderRepository`` validates loaded configuration
    data against and dumps when saving.
    """

    # Provider configurations held by the registry; required (no default).
    providers: list[ModelProvider]
    # Optional string naming the default; None when unset.
    # NOTE(review): presumably the name of one of `providers` — confirm against callers.
    default: str | None = None
19
+
20
+
21
class ProviderRepository(ConfigRepository[ModelProviderRegistry]):
    """Repository for provider configurations."""

    @property
    def config_file(self) -> Path:
        """Get the provider configuration file path."""
        return self.config_dir / MODEL_PROVIDERS_FILE_NAME

    def load(self) -> ModelProviderRegistry | None:
        """Load provider configuration from file.

        Returns:
            The validated registry, or ``None`` when the file does not exist
            or cannot be read or validated. A failed load is logged as a
            warning so it can be told apart from a missing file.
        """
        if not self.exists():
            return None

        try:
            config_dict = load_config_file(self.config_file)
            return ModelProviderRegistry.model_validate(config_dict)
        except Exception as exc:
            # Loading stays best-effort, but the failure is reported: otherwise
            # an unreadable file looks like "nothing configured" and a later
            # save() would overwrite it without any hint.
            import logging

            logging.getLogger(__name__).warning(
                "Failed to load provider configuration from %s: %s", self.config_file, exc
            )
            return None

    def save(self, config: ModelProviderRegistry) -> None:
        """Save provider configuration to file.

        The registry is dumped in JSON mode with ``None`` values omitted.
        """
        config_dict = config.model_dump(mode="json", exclude_none=True)
        save_config_file(self.config_file, config_dict)
@@ -0,0 +1,7 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from data_designer.cli.services.model_service import ModelService
5
+ from data_designer.cli.services.provider_service import ProviderService
6
+
7
+ __all__ = ["ModelService", "ProviderService"]