data-designer 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/__init__.py +15 -0
- data_designer/_version.py +34 -0
- data_designer/cli/README.md +236 -0
- data_designer/cli/__init__.py +6 -0
- data_designer/cli/commands/__init__.py +2 -0
- data_designer/cli/commands/list.py +130 -0
- data_designer/cli/commands/models.py +10 -0
- data_designer/cli/commands/providers.py +11 -0
- data_designer/cli/commands/reset.py +100 -0
- data_designer/cli/controllers/__init__.py +7 -0
- data_designer/cli/controllers/model_controller.py +246 -0
- data_designer/cli/controllers/provider_controller.py +317 -0
- data_designer/cli/forms/__init__.py +20 -0
- data_designer/cli/forms/builder.py +51 -0
- data_designer/cli/forms/field.py +180 -0
- data_designer/cli/forms/form.py +59 -0
- data_designer/cli/forms/model_builder.py +125 -0
- data_designer/cli/forms/provider_builder.py +76 -0
- data_designer/cli/main.py +44 -0
- data_designer/cli/repositories/__init__.py +8 -0
- data_designer/cli/repositories/base.py +39 -0
- data_designer/cli/repositories/model_repository.py +42 -0
- data_designer/cli/repositories/provider_repository.py +43 -0
- data_designer/cli/services/__init__.py +7 -0
- data_designer/cli/services/model_service.py +116 -0
- data_designer/cli/services/provider_service.py +111 -0
- data_designer/cli/ui.py +448 -0
- data_designer/cli/utils.py +47 -0
- data_designer/config/__init__.py +2 -0
- data_designer/config/analysis/column_profilers.py +89 -0
- data_designer/config/analysis/column_statistics.py +274 -0
- data_designer/config/analysis/dataset_profiler.py +60 -0
- data_designer/config/analysis/utils/errors.py +8 -0
- data_designer/config/analysis/utils/reporting.py +188 -0
- data_designer/config/base.py +68 -0
- data_designer/config/column_configs.py +354 -0
- data_designer/config/column_types.py +168 -0
- data_designer/config/config_builder.py +660 -0
- data_designer/config/data_designer_config.py +40 -0
- data_designer/config/dataset_builders.py +11 -0
- data_designer/config/datastore.py +151 -0
- data_designer/config/default_model_settings.py +123 -0
- data_designer/config/errors.py +19 -0
- data_designer/config/interface.py +54 -0
- data_designer/config/models.py +231 -0
- data_designer/config/preview_results.py +32 -0
- data_designer/config/processors.py +41 -0
- data_designer/config/sampler_constraints.py +51 -0
- data_designer/config/sampler_params.py +604 -0
- data_designer/config/seed.py +145 -0
- data_designer/config/utils/code_lang.py +83 -0
- data_designer/config/utils/constants.py +313 -0
- data_designer/config/utils/errors.py +19 -0
- data_designer/config/utils/info.py +88 -0
- data_designer/config/utils/io_helpers.py +273 -0
- data_designer/config/utils/misc.py +81 -0
- data_designer/config/utils/numerical_helpers.py +28 -0
- data_designer/config/utils/type_helpers.py +100 -0
- data_designer/config/utils/validation.py +336 -0
- data_designer/config/utils/visualization.py +427 -0
- data_designer/config/validator_params.py +96 -0
- data_designer/engine/__init__.py +2 -0
- data_designer/engine/analysis/column_profilers/base.py +55 -0
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +160 -0
- data_designer/engine/analysis/column_profilers/registry.py +20 -0
- data_designer/engine/analysis/column_statistics.py +142 -0
- data_designer/engine/analysis/dataset_profiler.py +125 -0
- data_designer/engine/analysis/errors.py +7 -0
- data_designer/engine/analysis/utils/column_statistics_calculations.py +209 -0
- data_designer/engine/analysis/utils/judge_score_processing.py +128 -0
- data_designer/engine/column_generators/__init__.py +2 -0
- data_designer/engine/column_generators/generators/__init__.py +2 -0
- data_designer/engine/column_generators/generators/base.py +61 -0
- data_designer/engine/column_generators/generators/expression.py +63 -0
- data_designer/engine/column_generators/generators/llm_generators.py +172 -0
- data_designer/engine/column_generators/generators/samplers.py +75 -0
- data_designer/engine/column_generators/generators/seed_dataset.py +149 -0
- data_designer/engine/column_generators/generators/validation.py +147 -0
- data_designer/engine/column_generators/registry.py +56 -0
- data_designer/engine/column_generators/utils/errors.py +13 -0
- data_designer/engine/column_generators/utils/judge_score_factory.py +57 -0
- data_designer/engine/column_generators/utils/prompt_renderer.py +98 -0
- data_designer/engine/configurable_task.py +82 -0
- data_designer/engine/dataset_builders/artifact_storage.py +181 -0
- data_designer/engine/dataset_builders/column_wise_builder.py +287 -0
- data_designer/engine/dataset_builders/errors.py +13 -0
- data_designer/engine/dataset_builders/multi_column_configs.py +44 -0
- data_designer/engine/dataset_builders/utils/__init__.py +2 -0
- data_designer/engine/dataset_builders/utils/concurrency.py +184 -0
- data_designer/engine/dataset_builders/utils/config_compiler.py +60 -0
- data_designer/engine/dataset_builders/utils/dag.py +56 -0
- data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +190 -0
- data_designer/engine/dataset_builders/utils/errors.py +13 -0
- data_designer/engine/errors.py +49 -0
- data_designer/engine/model_provider.py +75 -0
- data_designer/engine/models/__init__.py +2 -0
- data_designer/engine/models/errors.py +308 -0
- data_designer/engine/models/facade.py +225 -0
- data_designer/engine/models/litellm_overrides.py +162 -0
- data_designer/engine/models/parsers/__init__.py +2 -0
- data_designer/engine/models/parsers/errors.py +34 -0
- data_designer/engine/models/parsers/parser.py +236 -0
- data_designer/engine/models/parsers/postprocessors.py +93 -0
- data_designer/engine/models/parsers/tag_parsers.py +60 -0
- data_designer/engine/models/parsers/types.py +82 -0
- data_designer/engine/models/recipes/base.py +79 -0
- data_designer/engine/models/recipes/response_recipes.py +291 -0
- data_designer/engine/models/registry.py +118 -0
- data_designer/engine/models/usage.py +75 -0
- data_designer/engine/models/utils.py +38 -0
- data_designer/engine/processing/ginja/__init__.py +2 -0
- data_designer/engine/processing/ginja/ast.py +64 -0
- data_designer/engine/processing/ginja/environment.py +461 -0
- data_designer/engine/processing/ginja/exceptions.py +54 -0
- data_designer/engine/processing/ginja/record.py +30 -0
- data_designer/engine/processing/gsonschema/__init__.py +2 -0
- data_designer/engine/processing/gsonschema/exceptions.py +8 -0
- data_designer/engine/processing/gsonschema/schema_transformers.py +81 -0
- data_designer/engine/processing/gsonschema/types.py +8 -0
- data_designer/engine/processing/gsonschema/validators.py +143 -0
- data_designer/engine/processing/processors/base.py +15 -0
- data_designer/engine/processing/processors/drop_columns.py +46 -0
- data_designer/engine/processing/processors/registry.py +20 -0
- data_designer/engine/processing/utils.py +120 -0
- data_designer/engine/registry/base.py +97 -0
- data_designer/engine/registry/data_designer_registry.py +37 -0
- data_designer/engine/registry/errors.py +10 -0
- data_designer/engine/resources/managed_dataset_generator.py +35 -0
- data_designer/engine/resources/managed_dataset_repository.py +194 -0
- data_designer/engine/resources/managed_storage.py +63 -0
- data_designer/engine/resources/resource_provider.py +46 -0
- data_designer/engine/resources/seed_dataset_data_store.py +66 -0
- data_designer/engine/sampling_gen/column.py +89 -0
- data_designer/engine/sampling_gen/constraints.py +95 -0
- data_designer/engine/sampling_gen/data_sources/base.py +214 -0
- data_designer/engine/sampling_gen/data_sources/errors.py +10 -0
- data_designer/engine/sampling_gen/data_sources/sources.py +342 -0
- data_designer/engine/sampling_gen/entities/__init__.py +2 -0
- data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
- data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +64 -0
- data_designer/engine/sampling_gen/entities/email_address_utils.py +169 -0
- data_designer/engine/sampling_gen/entities/errors.py +8 -0
- data_designer/engine/sampling_gen/entities/national_id_utils.py +100 -0
- data_designer/engine/sampling_gen/entities/person.py +142 -0
- data_designer/engine/sampling_gen/entities/phone_number.py +122 -0
- data_designer/engine/sampling_gen/errors.py +24 -0
- data_designer/engine/sampling_gen/generator.py +121 -0
- data_designer/engine/sampling_gen/jinja_utils.py +60 -0
- data_designer/engine/sampling_gen/people_gen.py +203 -0
- data_designer/engine/sampling_gen/person_constants.py +54 -0
- data_designer/engine/sampling_gen/schema.py +143 -0
- data_designer/engine/sampling_gen/schema_builder.py +59 -0
- data_designer/engine/sampling_gen/utils.py +40 -0
- data_designer/engine/secret_resolver.py +80 -0
- data_designer/engine/validators/__init__.py +17 -0
- data_designer/engine/validators/base.py +36 -0
- data_designer/engine/validators/local_callable.py +34 -0
- data_designer/engine/validators/python.py +245 -0
- data_designer/engine/validators/remote.py +83 -0
- data_designer/engine/validators/sql.py +60 -0
- data_designer/errors.py +5 -0
- data_designer/essentials/__init__.py +137 -0
- data_designer/interface/__init__.py +2 -0
- data_designer/interface/data_designer.py +351 -0
- data_designer/interface/errors.py +16 -0
- data_designer/interface/results.py +55 -0
- data_designer/logging.py +161 -0
- data_designer/plugin_manager.py +83 -0
- data_designer/plugins/__init__.py +6 -0
- data_designer/plugins/errors.py +10 -0
- data_designer/plugins/plugin.py +69 -0
- data_designer/plugins/registry.py +86 -0
- data_designer-0.1.0.dist-info/METADATA +173 -0
- data_designer-0.1.0.dist-info/RECORD +177 -0
- data_designer-0.1.0.dist-info/WHEEL +4 -0
- data_designer-0.1.0.dist-info/entry_points.txt +2 -0
- data_designer-0.1.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,180 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from collections.abc import Callable
|
|
6
|
+
from typing import Any, Generic, TypeVar
|
|
7
|
+
|
|
8
|
+
from data_designer.cli.utils import validate_numeric_range
|
|
9
|
+
|
|
10
|
+
T = TypeVar("T")
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ValidationError(Exception):
    """Signals that a value offered to a form field failed validation."""
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Field(ABC, Generic[T]):
    """Abstract form field holding prompt metadata and a validated value.

    Concrete subclasses implement :meth:`prompt_user` to collect the input.
    """

    def __init__(
        self,
        name: str,
        prompt: str,
        default: T | None = None,
        required: bool = True,
        validator: Callable[[str], tuple[bool, str | None]] | None = None,
        help_text: str | None = None,
    ):
        # Nothing is stored until a value is assigned through the setter.
        self._value: T | None = None
        self.name, self.prompt = name, prompt
        self.default, self.required = default, required
        self.validator, self.help_text = validator, help_text

    @property
    def value(self) -> T | None:
        """Return the value currently stored on the field (``None`` if unset)."""
        return self._value

    @value.setter
    def value(self, val: T) -> None:
        """Store ``val`` after running it through the validator, if one is set.

        Raises:
            ValidationError: If the validator reports the value as invalid.
        """
        check = self.validator
        if check:
            # Validators operate on text, so non-string values are stringified.
            text = val if isinstance(val, str) else str(val)
            accepted, message = check(text)
            if not accepted:
                raise ValidationError(message or "Invalid value")
        self._value = val

    @abstractmethod
    def prompt_user(self, allow_back: bool = False) -> T | None | Any:
        """Ask the user for this field's value; implemented by subclasses."""
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class TextField(Field[str]):
    """Field that collects a line of text from the user."""

    def __init__(
        self,
        name: str,
        prompt: str,
        default: str | None = None,
        required: bool = True,
        validator: Callable[[str], tuple[bool, str | None]] | None = None,
        completions: list[str] | None = None,
        mask: bool = False,
        help_text: str | None = None,
    ):
        super().__init__(
            name=name,
            prompt=prompt,
            default=default,
            required=required,
            validator=validator,
            help_text=help_text,
        )
        # Both settings are forwarded unchanged to the text prompt.
        self.mask = mask
        self.completions = completions

    def prompt_user(self, allow_back: bool = False) -> str | None | Any:
        """Show the text prompt and return whatever it produced.

        The ``BACK`` sentinel from the UI layer is handed back untouched so
        the caller can detect a "go back" request.
        """
        from data_designer.cli.ui import BACK, prompt_text_input

        answer = prompt_text_input(
            self.prompt,
            default=self.default,
            validator=self.validator,
            mask=self.mask,
            completions=self.completions,
            allow_back=allow_back,
        )
        return BACK if answer is BACK else answer
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class SelectField(Field[str]):
    """Field whose value is picked from a fixed mapping of options."""

    def __init__(
        self,
        name: str,
        prompt: str,
        options: dict[str, str],
        default: str | None = None,
        required: bool = True,
        help_text: str | None = None,
    ):
        # A selection field never carries a validator.
        super().__init__(
            name=name,
            prompt=prompt,
            default=default,
            required=required,
            validator=None,
            help_text=help_text,
        )
        self.options = options

    def prompt_user(self, allow_back: bool = False) -> str | None | Any:
        """Show the selection menu and return whatever it produced.

        ``self.default`` is forwarded as the menu's ``default_key``. The
        ``BACK`` sentinel from the UI layer is handed back untouched.
        """
        from data_designer.cli.ui import BACK, select_with_arrows

        choice = select_with_arrows(
            self.options,
            self.prompt,
            default_key=self.default,
            allow_back=allow_back,
        )
        return BACK if choice is BACK else choice
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
class NumericField(Field[float]):
    """Numeric input field with optional lower/upper bound validation.

    The user types text; a validator built in ``__init__`` checks that the
    text is a number within the configured bounds, and :meth:`prompt_user`
    converts the accepted text to ``float``.
    """

    def __init__(
        self,
        name: str,
        prompt: str,
        default: float | None = None,
        min_value: float | None = None,
        max_value: float | None = None,
        required: bool = True,
        help_text: str | None = None,
    ):
        self.min_value = min_value
        self.max_value = max_value

        def range_validator(value: str) -> tuple[bool, str | None]:
            """Return ``(ok, error_message)`` for the raw text ``value``."""
            # An optional field may be left blank.
            if not value and not required:
                return True, None

            # Parse up front so unparsable input always gets the
            # "not a number" message, whichever bounds are configured.
            try:
                num = float(value)
            except (TypeError, ValueError):
                return False, "Must be a valid number"
            # NaN parses successfully but compares False against every
            # bound, so it would otherwise slip through the checks below.
            if num != num:
                return False, "Must be a valid number"

            if min_value is not None and max_value is not None:
                # Two-sided ranges are delegated to the shared CLI helper;
                # only its validity flag is needed here.
                is_valid, _ = validate_numeric_range(value, min_value, max_value)
                if not is_valid:
                    return False, f"Value must be between {min_value} and {max_value}"
                return True, None

            if min_value is not None and num < min_value:
                return False, f"Value must be >= {min_value}"
            if max_value is not None and num > max_value:
                return False, f"Value must be <= {max_value}"
            return True, None

        super().__init__(name, prompt, default, required, range_validator, help_text)

    def prompt_user(self, allow_back: bool = False) -> float | None | Any:
        """Prompt for a number.

        Returns:
            The ``BACK`` sentinel when the prompt returns it, ``None`` when
            the prompt result is empty or ``None``, otherwise the entered
            text converted to ``float``.
        """
        from data_designer.cli.ui import BACK, prompt_text_input

        # The text prompt works with strings, so the default is stringified.
        default_str = str(self.default) if self.default is not None else None

        result = prompt_text_input(
            self.prompt,
            default=default_str,
            validator=self.validator,
            allow_back=allow_back,
        )

        if result is BACK:
            return BACK

        # NOTE(review): a blank answer on an optional field becomes None
        # here; a caller that reads None as "cancelled" cannot tell the two
        # apart - confirm this is intended.
        return float(result) if result else None
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from data_designer.cli.forms.field import Field
|
|
7
|
+
from data_designer.cli.ui import BACK, print_error
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Form:
    """Ordered group of fields that are prompted and read back together."""

    def __init__(self, name: str, fields: list[Field]):
        self.name = name
        self.fields = fields
        # Index by field name for direct lookup; a later duplicate name wins.
        self._field_map = {entry.name: entry for entry in fields}

    def get_field(self, name: str) -> Field | None:
        """Look up a field by name, returning ``None`` when it is unknown."""
        return self._field_map.get(name)

    def get_values(self) -> dict[str, Any]:
        """Collect ``{field name: value}`` for every field whose value is set."""
        collected: dict[str, Any] = {}
        for entry in self.fields:
            if entry.value is not None:
                collected[entry.name] = entry.value
        return collected

    def set_values(self, values: dict[str, Any]) -> None:
        """Assign values to the matching fields; unknown names are skipped."""
        for key, new_value in values.items():
            target = self.get_field(key)
            if not target:
                continue
            target.value = new_value

    def prompt_all(self, allow_back: bool = True) -> dict[str, Any] | None:
        """Prompt for each field in order, supporting back navigation.

        Returns:
            The collected values, or ``None`` as soon as any prompt returns
            ``None`` (treated as the user cancelling).
        """
        position = 0

        while position < len(self.fields):
            current = self.fields[position]

            # Going back is never offered on the first field.
            answer = current.prompt_user(allow_back=allow_back and position > 0)

            if answer is None:
                # Cancelled by the user.
                return None

            if answer is BACK:
                # Step back one field; on the first field just ask again.
                if position > 0:
                    position -= 1
                continue

            try:
                current.value = answer
            except Exception as exc:
                # Rejected value: report it and ask for the same field again.
                print_error(f"Validation error: {exc}")
                continue

            position += 1

        return self.get_values()
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from data_designer.cli.forms.builder import FormBuilder
|
|
7
|
+
from data_designer.cli.forms.field import NumericField, SelectField, TextField
|
|
8
|
+
from data_designer.cli.forms.form import Form
|
|
9
|
+
from data_designer.config.models import ModelConfig
|
|
10
|
+
from data_designer.config.utils.constants import MAX_TEMPERATURE, MAX_TOP_P, MIN_TEMPERATURE, MIN_TOP_P
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ModelFormBuilder(FormBuilder[ModelConfig]):
    """Builds the interactive form for a model configuration.

    ``create_form`` assembles the fields (optionally pre-filled from existing
    data) and ``build_config`` turns the collected answers into a
    ``ModelConfig``.
    """

    def __init__(self, existing_aliases: set[str] | None = None, available_providers: list[str] | None = None):
        super().__init__("Model Configuration")
        self.existing_aliases = existing_aliases or set()
        self.available_providers = available_providers or []

    def create_form(self, initial_data: dict[str, Any] | None = None) -> Form:
        """Create the model configuration form.

        Args:
            initial_data: Optional values used as field defaults. Inference
                parameter defaults are read from its ``"inference_parameters"``
                entry; a missing or ``None`` entry falls back to the built-in
                defaults (0.7 / 0.9 / 2048).
        """
        data = initial_data or {}
        # `or {}` rather than a `.get` default: the key may be present with
        # an explicit None value, which a `.get` default would not replace.
        params = data.get("inference_parameters") or {}

        fields = []

        # Model alias
        fields.append(
            TextField(
                "alias",
                "Model alias (used in your configs)",
                default=data.get("alias"),
                required=True,
                validator=self._validate_alias,
            )
        )

        # Model ID
        fields.append(
            TextField(
                "model",
                "Model ID",
                default=data.get("model"),
                required=True,
                validator=lambda x: (False, "Model ID is required") if not x else (True, None),
            )
        )

        # The provider is only asked for when there is a real choice; with a
        # single provider, build_config fills it in automatically.
        if len(self.available_providers) > 1:
            provider_options = {p: p for p in self.available_providers}
            fields.append(
                SelectField(
                    "provider",
                    "Select provider for this model",
                    options=provider_options,
                    default=data.get("provider", self.available_providers[0]),
                )
            )

        # Inference parameters
        fields.extend(
            [
                NumericField(
                    "temperature",
                    f"Temperature ({MIN_TEMPERATURE}-{MAX_TEMPERATURE})",
                    default=params.get("temperature", 0.7),
                    min_value=MIN_TEMPERATURE,
                    max_value=MAX_TEMPERATURE,
                ),
                NumericField(
                    "top_p",
                    f"Top P ({MIN_TOP_P}-{MAX_TOP_P})",
                    default=params.get("top_p", 0.9),
                    min_value=MIN_TOP_P,
                    max_value=MAX_TOP_P,
                ),
                NumericField(
                    "max_tokens",
                    "Max tokens",
                    default=params.get("max_tokens", 2048),
                    min_value=1,
                    max_value=100000,
                ),
            ]
        )

        return Form(self.title, fields)

    def _validate_alias(self, alias: str) -> tuple[bool, str | None]:
        """Reject an empty alias or one already present in ``existing_aliases``."""
        if not alias:
            return False, "Model alias is required"
        if alias in self.existing_aliases:
            return False, f"Model alias '{alias}' already exists"
        return True, None

    def build_config(self, form_data: dict[str, Any]) -> ModelConfig:
        """Build a ``ModelConfig`` from the collected form data.

        The provider comes from the form when it was asked for, otherwise
        from the single available provider, otherwise it is ``None``.
        """
        if "provider" in form_data:
            provider = form_data["provider"]
        elif len(self.available_providers) == 1:
            provider = self.available_providers[0]
        else:
            provider = None

        return ModelConfig(
            alias=form_data["alias"],
            model=form_data["model"],
            provider=provider,
            inference_parameters={
                "temperature": form_data["temperature"],
                "top_p": form_data["top_p"],
                # Numeric fields yield floats; the token limit is passed as an int.
                "max_tokens": int(form_data["max_tokens"]),
                "max_parallel_requests": 4,
            },
        )
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from data_designer.cli.forms.builder import FormBuilder
|
|
7
|
+
from data_designer.cli.forms.field import TextField
|
|
8
|
+
from data_designer.cli.forms.form import Form
|
|
9
|
+
from data_designer.cli.utils import validate_url
|
|
10
|
+
from data_designer.engine.model_provider import ModelProvider
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ProviderFormBuilder(FormBuilder[ModelProvider]):
    """Assembles the interactive form for a provider and builds the result."""

    def __init__(self, existing_names: set[str] | None = None):
        super().__init__("Provider Configuration")
        self.existing_names = existing_names or set()

    def create_form(self, initial_data: dict[str, Any] | None = None) -> Form:
        """Return the provider form, pre-filled from ``initial_data`` if given."""
        seed = initial_data or {}

        name_field = TextField(
            "name",
            "Provider name",
            default=seed.get("name"),
            required=True,
            validator=self._validate_name,
        )
        endpoint_field = TextField(
            "endpoint",
            "API endpoint URL",
            default=seed.get("endpoint"),
            required=True,
            validator=self._validate_endpoint,
        )
        type_field = TextField(
            "provider_type",
            "Provider type",
            default=seed.get("provider_type", "openai"),
            required=True,
        )
        key_field = TextField(
            "api_key",
            "API key or environment variable name",
            default=seed.get("api_key"),
            required=False,
        )

        return Form(self.title, [name_field, endpoint_field, type_field, key_field])

    def _validate_name(self, name: str) -> tuple[bool, str | None]:
        """Reject an empty name or one already present in ``existing_names``."""
        if not name:
            return False, "Provider name is required"
        taken = name in self.existing_names
        return (False, f"Provider '{name}' already exists") if taken else (True, None)

    def _validate_endpoint(self, endpoint: str) -> tuple[bool, str | None]:
        """Reject an empty endpoint or one that ``validate_url`` does not accept."""
        if not endpoint:
            return False, "Endpoint URL is required"
        if validate_url(endpoint):
            return True, None
        return False, "Invalid URL format (must start with http:// or https://)"

    def build_config(self, form_data: dict[str, Any]) -> ModelProvider:
        """Turn collected form data into a ``ModelProvider``.

        ``api_key`` is optional in the form, so it is read with ``.get``.
        """
        api_key = form_data.get("api_key")
        return ModelProvider(
            name=form_data["name"],
            endpoint=form_data["endpoint"],
            provider_type=form_data["provider_type"],
            api_key=api_key,
        )
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
import typer
|
|
5
|
+
|
|
6
|
+
from data_designer.cli.commands import list as list_cmd
|
|
7
|
+
from data_designer.cli.commands import models, providers, reset
|
|
8
|
+
from data_designer.config.default_model_settings import resolve_seed_default_model_settings
|
|
9
|
+
from data_designer.config.utils.misc import can_run_data_designer_locally
|
|
10
|
+
|
|
11
|
+
# Default model settings are resolved as a side effect of importing this
# module (only when a local run is possible) so that they are available
# when the library is used.
if can_run_data_designer_locally():
    resolve_seed_default_model_settings()

# "config" sub-application and its four commands.
config_app = typer.Typer(
    name="config",
    help="Manage configuration files",
    no_args_is_help=True,
)
config_app.command(name="providers", help="Configure model providers interactively")(providers.providers_command)
config_app.command(name="models", help="Configure models interactively")(models.models_command)
config_app.command(name="list", help="List current configurations")(list_cmd.list_command)
config_app.command(name="reset", help="Reset configuration files")(reset.reset_command)

# Root Typer application; the config group is mounted under "config".
app = typer.Typer(
    name="data-designer",
    help="Data Designer CLI - Configure model providers and models for synthetic data generation",
    add_completion=False,
    no_args_is_help=True,
    rich_markup_mode="rich",
)
app.add_typer(config_app, name="config")
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def main() -> None:
    """Run the Typer application; this is the CLI's entry point."""
    app()


# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from data_designer.cli.repositories.base import ConfigRepository
|
|
5
|
+
from data_designer.cli.repositories.model_repository import ModelRepository
|
|
6
|
+
from data_designer.cli.repositories.provider_repository import ProviderRepository
|
|
7
|
+
|
|
8
|
+
__all__ = ["ConfigRepository", "ModelRepository", "ProviderRepository"]
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Generic, TypeVar
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel
|
|
9
|
+
|
|
10
|
+
T = TypeVar("T", bound=BaseModel)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ConfigRepository(ABC, Generic[T]):
    """Abstract base for persisting a configuration object of type ``T``.

    Subclasses provide the file location (``config_file``) and the
    ``load``/``save`` logic; the existence check and deletion are implemented
    here on top of ``config_file``.
    """

    def __init__(self, config_dir: Path):
        self.config_dir = config_dir

    @property
    @abstractmethod
    def config_file(self) -> Path:
        """Get the configuration file path."""

    @abstractmethod
    def load(self) -> T | None:
        """Load configuration from file."""

    @abstractmethod
    def save(self, config: T) -> None:
        """Save configuration to file."""

    def exists(self) -> bool:
        """Check if configuration file exists."""
        return self.config_file.exists()

    def delete(self) -> None:
        """Delete the configuration file; a missing file is not an error."""
        # A single unlink call instead of "check, then delete": the file can
        # disappear between the two steps, which would raise
        # FileNotFoundError.
        self.config_file.unlink(missing_ok=True)
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel
|
|
7
|
+
|
|
8
|
+
from data_designer.cli.repositories.base import ConfigRepository
|
|
9
|
+
from data_designer.config.models import ModelConfig
|
|
10
|
+
from data_designer.config.utils.constants import MODEL_CONFIGS_FILE_NAME
|
|
11
|
+
from data_designer.config.utils.io_helpers import load_config_file, save_config_file
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ModelConfigRegistry(BaseModel):
    """Pydantic container for a list of ``ModelConfig`` entries."""

    # Required field: there is no default value.
    model_configs: list[ModelConfig]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class ModelRepository(ConfigRepository[ModelConfigRegistry]):
    """Repository that reads and writes the model configuration file."""

    @property
    def config_file(self) -> Path:
        """Get the model configuration file path (inside ``config_dir``)."""
        return self.config_dir / MODEL_CONFIGS_FILE_NAME

    def load(self) -> ModelConfigRegistry | None:
        """Load model configuration from file.

        Returns:
            The validated registry, or ``None`` when the file does not exist
            or cannot be read or validated. The failure case is logged, so a
            broken file can be told apart from a missing one.
        """
        import logging  # function-scope import keeps this block self-contained

        if not self.exists():
            return None

        try:
            config_dict = load_config_file(self.config_file)
            return ModelConfigRegistry.model_validate(config_dict)
        except Exception as exc:
            # Still best-effort (callers keep getting None), but no longer
            # silent about an unreadable or invalid file.
            logging.getLogger(__name__).warning(
                "Failed to load model configuration from %s: %s", self.config_file, exc
            )
            return None

    def save(self, config: ModelConfigRegistry) -> None:
        """Save model configuration to file, omitting fields that are ``None``."""
        config_dict = config.model_dump(mode="json", exclude_none=True)
        save_config_file(self.config_file, config_dict)
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel
|
|
7
|
+
|
|
8
|
+
from data_designer.cli.repositories.base import ConfigRepository
|
|
9
|
+
from data_designer.config.models import ModelProvider
|
|
10
|
+
from data_designer.config.utils.constants import MODEL_PROVIDERS_FILE_NAME
|
|
11
|
+
from data_designer.config.utils.io_helpers import load_config_file, save_config_file
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ModelProviderRegistry(BaseModel):
    """Pydantic container for a list of ``ModelProvider`` entries."""

    # Required field: there is no default value.
    providers: list[ModelProvider]
    # Optional; presumably the name of the default provider - TODO confirm.
    default: str | None = None
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ProviderRepository(ConfigRepository[ModelProviderRegistry]):
    """Repository that reads and writes the provider configuration file."""

    @property
    def config_file(self) -> Path:
        """Get the provider configuration file path (inside ``config_dir``)."""
        return self.config_dir / MODEL_PROVIDERS_FILE_NAME

    def load(self) -> ModelProviderRegistry | None:
        """Load provider configuration from file.

        Returns:
            The validated registry, or ``None`` when the file does not exist
            or cannot be read or validated. The failure case is logged, so a
            broken file can be told apart from a missing one.
        """
        import logging  # function-scope import keeps this block self-contained

        if not self.exists():
            return None

        try:
            config_dict = load_config_file(self.config_file)
            return ModelProviderRegistry.model_validate(config_dict)
        except Exception as exc:
            # Still best-effort (callers keep getting None), but no longer
            # silent about an unreadable or invalid file.
            logging.getLogger(__name__).warning(
                "Failed to load provider configuration from %s: %s", self.config_file, exc
            )
            return None

    def save(self, config: ModelProviderRegistry) -> None:
        """Save provider configuration to file, omitting fields that are ``None``."""
        config_dict = config.model_dump(mode="json", exclude_none=True)
        save_config_file(self.config_file, config_dict)
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from data_designer.cli.services.model_service import ModelService
|
|
5
|
+
from data_designer.cli.services.provider_service import ProviderService
|
|
6
|
+
|
|
7
|
+
__all__ = ["ModelService", "ProviderService"]
|