data-designer-config 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/config/__init__.py +149 -0
- data_designer/config/_version.py +34 -0
- data_designer/config/analysis/__init__.py +2 -0
- data_designer/config/analysis/column_profilers.py +159 -0
- data_designer/config/analysis/column_statistics.py +421 -0
- data_designer/config/analysis/dataset_profiler.py +84 -0
- data_designer/config/analysis/utils/errors.py +10 -0
- data_designer/config/analysis/utils/reporting.py +192 -0
- data_designer/config/base.py +69 -0
- data_designer/config/column_configs.py +476 -0
- data_designer/config/column_types.py +141 -0
- data_designer/config/config_builder.py +595 -0
- data_designer/config/data_designer_config.py +40 -0
- data_designer/config/dataset_builders.py +13 -0
- data_designer/config/dataset_metadata.py +18 -0
- data_designer/config/default_model_settings.py +129 -0
- data_designer/config/errors.py +24 -0
- data_designer/config/interface.py +55 -0
- data_designer/config/models.py +486 -0
- data_designer/config/preview_results.py +41 -0
- data_designer/config/processors.py +148 -0
- data_designer/config/run_config.py +56 -0
- data_designer/config/sampler_constraints.py +52 -0
- data_designer/config/sampler_params.py +639 -0
- data_designer/config/seed.py +116 -0
- data_designer/config/seed_source.py +84 -0
- data_designer/config/seed_source_types.py +19 -0
- data_designer/config/testing/__init__.py +6 -0
- data_designer/config/testing/fixtures.py +308 -0
- data_designer/config/utils/code_lang.py +93 -0
- data_designer/config/utils/constants.py +365 -0
- data_designer/config/utils/errors.py +21 -0
- data_designer/config/utils/info.py +94 -0
- data_designer/config/utils/io_helpers.py +258 -0
- data_designer/config/utils/misc.py +78 -0
- data_designer/config/utils/numerical_helpers.py +30 -0
- data_designer/config/utils/type_helpers.py +106 -0
- data_designer/config/utils/visualization.py +482 -0
- data_designer/config/validator_params.py +94 -0
- data_designer/errors.py +7 -0
- data_designer/lazy_heavy_imports.py +56 -0
- data_designer/logging.py +180 -0
- data_designer/plugin_manager.py +78 -0
- data_designer/plugins/__init__.py +8 -0
- data_designer/plugins/errors.py +15 -0
- data_designer/plugins/plugin.py +141 -0
- data_designer/plugins/registry.py +88 -0
- data_designer_config-0.4.0.dist-info/METADATA +75 -0
- data_designer_config-0.4.0.dist-info/RECORD +50 -0
- data_designer_config-0.4.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,365 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import os
|
|
7
|
+
from enum import Enum
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
|
|
10
|
+
from rich.theme import Theme
|
|
11
|
+
|
|
12
|
+
# Record count used when the caller does not request a specific number (consumers live elsewhere).
DEFAULT_NUM_RECORDS = 10

# Tiny float tolerance, and (presumably) the number of decimal places used in reports.
EPSILON = 1e-8
REPORTING_PRECISION = 2

# Highlighting style name for HTML reprs (NOTE(review): presumably a Pygments style name — confirm).
DEFAULT_REPR_HTML_STYLE = "nord"

REPR_HTML_FIXED_WIDTH = 1000

# Two-stage template: the width is baked in right here, while the doubled braces survive as
# literal ``{css}`` / ``{highlighted_html}`` placeholders (and single CSS braces) for a later
# ``str.format(css=..., highlighted_html=...)`` call.
REPR_HTML_TEMPLATE = f"""
<meta charset="UTF-8">
<style>
{{css}}

.code {{{{
    padding: 4px;
    border: 1px solid grey;
    border-radius: 4px;
    max-width: {REPR_HTML_FIXED_WIDTH}px;
    width: 100%;
    display: inline-block;
    box-sizing: border-box;
    text-align: left;
    vertical-align: top;
    line-height: normal;
    overflow-x: auto;
}}}}

.code pre {{{{
    white-space: pre-wrap;       /* CSS 3 */
    white-space: -moz-pre-wrap;  /* Mozilla, since 1999 */
    white-space: -pre-wrap;      /* Opera 4-6 */
    white-space: -o-pre-wrap;    /* Opera 7 */
    word-wrap: break-word;
    overflow-wrap: break-word;
    margin: 0;
}}}}
</style>
{{highlighted_html}}
"""
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class NordColor(Enum):
    """Hex colour codes of the 16-colour Nord palette (``nord0`` through ``nord15``)."""

    # Polar Night: the dark base shades.
    NORD0 = "#2E3440"  # darkest gray (background)
    NORD1 = "#3B4252"  # dark gray
    NORD2 = "#434C5E"  # medium dark gray
    NORD3 = "#4C566A"  # lighter dark gray

    # Snow Storm: the bright shades.
    NORD4 = "#D8DEE9"  # light gray (default text)
    NORD5 = "#E5E9F0"  # very light gray
    NORD6 = "#ECEFF4"  # almost white

    # Frost: cool accent colours.
    NORD7 = "#8FBCBB"  # teal
    NORD8 = "#88C0D0"  # light cyan
    NORD9 = "#81A1C1"  # soft blue
    NORD10 = "#5E81AC"  # darker blue

    # Aurora: warm accent colours.
    NORD11 = "#BF616A"  # red
    NORD12 = "#D08770"  # orange
    NORD13 = "#EBCB8B"  # yellow
    NORD14 = "#A3BE8C"  # green
    NORD15 = "#B48EAD"  # purple
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
# Rich pretty-printer ("repr.*") highlight styles, drawn from the Nord palette.
_RICH_CONSOLE_STYLES = {
    # Literal values
    "repr.number": NordColor.NORD15.value,  # purple
    "repr.string": NordColor.NORD14.value,  # green
    "repr.bool_true": NordColor.NORD9.value,  # soft blue
    "repr.bool_false": NordColor.NORD9.value,  # soft blue
    "repr.none": NordColor.NORD11.value,  # red
    # Punctuation
    "repr.brace": NordColor.NORD7.value,  # teal: brackets and braces
    "repr.comma": NordColor.NORD7.value,  # teal
    "repr.ellipsis": NordColor.NORD7.value,  # teal
    # ``name=value`` pairs
    "repr.attrib_name": NordColor.NORD3.value,  # lighter dark gray: the attribute name
    "repr.attrib_equal": NordColor.NORD7.value,  # teal: the ``=`` sign
    # Callables, classes and modules
    "repr.call": NordColor.NORD10.value,  # darker blue
    "repr.function_name": NordColor.NORD10.value,  # darker blue
    "repr.class_name": NordColor.NORD12.value,  # orange
    "repr.module_name": NordColor.NORD8.value,  # light cyan
    # Diagnostics
    "repr.error": NordColor.NORD11.value,  # red
    "repr.warning": NordColor.NORD13.value,  # yellow
}
RICH_CONSOLE_THEME = Theme(_RICH_CONSOLE_STYLES)
|
|
92
|
+
|
|
93
|
+
# Colours for the name and value parts of histogram output
# (NOTE(review): presumably Rich named colours — confirm at the rendering site).
DEFAULT_HIST_NAME_COLOR = "medium_purple1"
DEFAULT_HIST_VALUE_COLOR = "pale_green3"

# Age bounds; the default range reuses MAX_AGE so the two values cannot drift apart.
MIN_AGE = 0
MAX_AGE = 114
DEFAULT_AGE_RANGE = [18, MAX_AGE]

# Two-letter postal codes: the 50 states, D.C., and five territories.
# fmt: off
US_STATES_AND_MAJOR_TERRITORIES = {
    # States
    "AK", "AL", "AR", "AZ", "CA", "CO", "CT", "DE", "FL", "GA",
    "HI", "IA", "ID", "IL", "IN", "KS", "KY", "LA", "MA", "MD",
    "ME", "MI", "MN", "MO", "MS", "MT", "NC", "ND", "NE", "NH",
    "NJ", "NM", "NV", "NY", "OH", "OK", "OR", "PA", "RI", "SC",
    "SD", "TN", "TX", "UT", "VA", "VT", "WA", "WI", "WV", "WY",
    # District of Columbia
    "DC",
    # Territories: American Samoa, Guam, Northern Mariana Islands, Puerto Rico, U.S. Virgin Islands
    "AS", "GU", "MP", "PR", "VI",
}
# fmt: on
|
|
163
|
+
|
|
164
|
+
# Allowed ranges for LLM sampling parameters (validation happens outside this module).
MIN_TEMPERATURE, MAX_TEMPERATURE = 0.0, 2.0
MIN_TOP_P, MAX_TOP_P = 0.0, 1.0
MIN_MAX_TOKENS = 1

# Suffix for column names that carry trace output (NOTE(review): presumably appended to the
# generating column's name — confirm at the use site).
TRACE_COLUMN_POSTFIX = "__trace"
|
|
170
|
+
|
|
171
|
+
# Supported locale codes, in alphabetical order, grouped by first letter.
# NOTE(review): the set appears to mirror Faker's locale list — confirm before editing.
# "fr_QC" is intentionally absent: it is deprecated, use "fr_CA" instead.
# fmt: off
AVAILABLE_LOCALES = [
    "ar_AA", "ar_AE", "ar_BH", "ar_EG", "ar_JO", "ar_PS", "ar_SA", "az_AZ",
    "bg_BG", "bn_BD", "bs_BA",
    "cs_CZ",
    "da_DK", "de", "de_AT", "de_CH", "de_DE", "dk_DK",
    "el_CY", "el_GR",
    "en", "en_AU", "en_BD", "en_CA", "en_GB", "en_IE",
    "en_IN", "en_NZ", "en_PH", "en_TH", "en_US",
    "es", "es_AR", "es_CA", "es_CL", "es_CO", "es_ES", "es_MX",
    "et_EE",
    "fa_IR", "fi_FI", "fil_PH", "fr_BE", "fr_CA", "fr_CH", "fr_FR",
    "ga_IE",
    "he_IL", "hi_IN", "hr_HR", "hu_HU", "hy_AM",
    "id_ID", "it_CH", "it_IT",
    "ja_JP",
    "ka_GE", "ko_KR",
    "la", "lb_LU", "lt_LT", "lv_LV",
    "mt_MT",
    "ne_NP", "nl_BE", "nl_NL", "no_NO",
    "or_IN",
    "pl_PL", "pt_BR", "pt_PT",
    "ro_RO", "ru_RU",
    "sk_SK", "sl_SI", "sq_AL", "sv_SE",
    "ta_IN", "th", "th_TH", "tl_PH", "tr_TR", "tw_GH",
    "uk_UA",
    "vi_VN",
    "zh_CN", "zh_TW", "zu_ZA",
]
# fmt: on
|
|
262
|
+
|
|
263
|
+
DATA_DESIGNER_HOME_ENV_VAR = "DATA_DESIGNER_HOME"
MANAGED_ASSETS_PATH_ENV_VAR = "DATA_DESIGNER_MANAGED_ASSETS_PATH"


def _path_from_env(env_var: str, default: Path) -> Path:
    """Return the directory configured in ``env_var``, falling back to ``default``.

    Args:
        env_var: Name of the environment variable to read.
        default: Path returned when the variable is unset, empty, or whitespace-only.

    Returns:
        The configured path with a leading ``~`` expanded, or ``default``.
    """
    raw = os.environ.get(env_var)
    # An empty value would otherwise become Path(""), i.e. the current working directory.
    if raw is None or not raw.strip():
        return default
    # os.path.expanduser returns the value unchanged (rather than raising) when ``~`` can't be resolved.
    return Path(os.path.expanduser(raw))


DATA_DESIGNER_HOME = _path_from_env(DATA_DESIGNER_HOME_ENV_VAR, Path.home() / ".data-designer")
MANAGED_ASSETS_PATH = _path_from_env(MANAGED_ASSETS_PATH_ENV_VAR, DATA_DESIGNER_HOME / "managed-assets")

MODEL_CONFIGS_FILE_NAME = "model_configs.yaml"
MODEL_CONFIGS_FILE_PATH = DATA_DESIGNER_HOME / MODEL_CONFIGS_FILE_NAME

MODEL_PROVIDERS_FILE_NAME = "model_providers.yaml"
MODEL_PROVIDERS_FILE_PATH = DATA_DESIGNER_HOME / MODEL_PROVIDERS_FILE_NAME
|
|
278
|
+
|
|
279
|
+
NVIDIA_PROVIDER_NAME = "nvidia"
NVIDIA_API_KEY_ENV_VAR_NAME = "NVIDIA_API_KEY"

OPENAI_PROVIDER_NAME = "openai"
OPENAI_API_KEY_ENV_VAR_NAME = "OPENAI_API_KEY"

OPENROUTER_PROVIDER_NAME = "openrouter"
OPENROUTER_API_KEY_ENV_VAR_NAME = "OPENROUTER_API_KEY"

# Built-in providers. All use the "openai" provider type; "api_key" holds the *name* of an
# environment variable (presumably resolved to the actual secret by the consumer).
PREDEFINED_PROVIDERS = [
    {
        "name": NVIDIA_PROVIDER_NAME,
        "endpoint": "https://integrate.api.nvidia.com/v1",
        "provider_type": "openai",
        "api_key": NVIDIA_API_KEY_ENV_VAR_NAME,
    },
    {
        "name": OPENAI_PROVIDER_NAME,
        "endpoint": "https://api.openai.com/v1",
        "provider_type": "openai",
        "api_key": OPENAI_API_KEY_ENV_VAR_NAME,
    },
    {
        "name": OPENROUTER_PROVIDER_NAME,
        "endpoint": "https://openrouter.ai/api/v1",
        "provider_type": "openai",
        "api_key": OPENROUTER_API_KEY_ENV_VAR_NAME,
    },
]


# Default inference parameters per model role.
DEFAULT_TEXT_INFERENCE_PARAMS = {"temperature": 0.85, "top_p": 0.95}
DEFAULT_REASONING_INFERENCE_PARAMS = {"temperature": 0.35, "top_p": 0.95}
DEFAULT_VISION_INFERENCE_PARAMS = {"temperature": 0.85, "top_p": 0.95}
DEFAULT_EMBEDDING_INFERENCE_PARAMS = {"encoding_format": "float"}
NEMOTRON_3_NANO_30B_A3B_INFERENCE_PARAMS = {"temperature": 1.0, "top_p": 1.0}


def _model_entry(model: str, inference_parameters: dict) -> dict:
    """Build one model-map entry holding a shallow copy of ``inference_parameters``.

    The parameter dicts above are shared module-level objects. Copying them keeps a mutation of
    one provider's entry from leaking into the module defaults or into other providers.
    """
    return {"model": model, "inference_parameters": dict(inference_parameters)}


# Default model (and its inference parameters) for each provider and role.
PREDEFINED_PROVIDERS_MODEL_MAP = {
    NVIDIA_PROVIDER_NAME: {
        "text": _model_entry("nvidia/nemotron-3-nano-30b-a3b", NEMOTRON_3_NANO_30B_A3B_INFERENCE_PARAMS),
        "reasoning": _model_entry("openai/gpt-oss-20b", DEFAULT_REASONING_INFERENCE_PARAMS),
        "vision": _model_entry("nvidia/nemotron-nano-12b-v2-vl", DEFAULT_VISION_INFERENCE_PARAMS),
        "embedding": _model_entry(
            "nvidia/llama-3.2-nv-embedqa-1b-v2",
            DEFAULT_EMBEDDING_INFERENCE_PARAMS | {"extra_body": {"input_type": "query"}},
        ),
    },
    OPENAI_PROVIDER_NAME: {
        "text": _model_entry("gpt-4.1", DEFAULT_TEXT_INFERENCE_PARAMS),
        "reasoning": _model_entry("gpt-5", DEFAULT_REASONING_INFERENCE_PARAMS),
        "vision": _model_entry("gpt-5", DEFAULT_VISION_INFERENCE_PARAMS),
        "embedding": _model_entry("text-embedding-3-large", DEFAULT_EMBEDDING_INFERENCE_PARAMS),
    },
    OPENROUTER_PROVIDER_NAME: {
        "text": _model_entry("nvidia/nemotron-3-nano-30b-a3b", NEMOTRON_3_NANO_30B_A3B_INFERENCE_PARAMS),
        "reasoning": _model_entry("openai/gpt-oss-20b", DEFAULT_REASONING_INFERENCE_PARAMS),
        "vision": _model_entry("nvidia/nemotron-nano-12b-v2-vl", DEFAULT_VISION_INFERENCE_PARAMS),
        "embedding": _model_entry("openai/text-embedding-3-large", DEFAULT_EMBEDDING_INFERENCE_PARAMS),
    },
}
|
|
351
|
+
|
|
352
|
+
# Persona locale metadata, consumed by the CLI and the person sampler.
# Maps each locale that has a managed Nemotron Personas dataset to its human-readable size.
NEMOTRON_PERSONAS_DATASET_SIZES = {
    "en_US": "1.24 GB",
    "en_IN": "2.39 GB",
    "en_SG": "0.30 GB",
    "hi_Deva_IN": "4.14 GB",
    "hi_Latn_IN": "2.7 GB",
    "ja_JP": "1.69 GB",
    "pt_BR": "2.33 GB",
}

# Iterating a dict yields its keys in insertion order, so this keeps the order above.
LOCALES_WITH_MANAGED_DATASETS = list(NEMOTRON_PERSONAS_DATASET_SIZES)

NEMOTRON_PERSONAS_DATASET_PREFIX = "nemotron-personas-dataset-"
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from data_designer.errors import DataDesignerError
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class UserJinjaTemplateSyntaxError(DataDesignerError):
    """Error type for a syntax problem in a user-supplied Jinja template (raised outside this module)."""
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class InvalidEnumValueError(DataDesignerError):
    """Error type for a value that is not a valid member of the expected enum (raised outside this module)."""
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class InvalidTypeUnionError(DataDesignerError):
    """Error type for an invalid type union; the exact conditions are defined by the raising code."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class InvalidDiscriminatorFieldError(DataDesignerError):
    """Error type for an invalid discriminator field (presumably on a discriminated union — confirm at raise sites)."""
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class DatasetSampleDisplayError(DataDesignerError):
    """Error type for failures while displaying a dataset sample (raised outside this module)."""
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
from enum import Enum
|
|
8
|
+
from typing import Literal, TypeVar
|
|
9
|
+
|
|
10
|
+
from data_designer.config.models import ModelConfig, ModelProvider
|
|
11
|
+
from data_designer.config.sampler_params import SamplerType
|
|
12
|
+
from data_designer.config.utils.type_helpers import get_sampler_params
|
|
13
|
+
from data_designer.config.utils.visualization import (
|
|
14
|
+
display_model_configs_table,
|
|
15
|
+
display_model_providers_table,
|
|
16
|
+
display_sampler_table,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class InfoType(str, Enum):
    """Categories of reference information that an ``InfoDisplay`` implementation can render.

    Members subclass ``str``, so they compare equal to their plain string values.
    """

    SAMPLERS = "samplers"
    MODEL_CONFIGS = "model_configs"
    MODEL_PROVIDERS = "model_providers"


# Subset of InfoType accepted by ConfigBuilderInfo.display.
ConfigBuilderInfoType = Literal[InfoType.SAMPLERS, InfoType.MODEL_CONFIGS]
# Subset of InfoType accepted by InterfaceInfo.display.
DataDesignerInfoType = Literal[InfoType.MODEL_PROVIDERS]
# Type variable bounded by InfoType, used in the abstract InfoDisplay.display signature.
InfoTypeT = TypeVar("InfoTypeT", bound=InfoType)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class InfoDisplay(ABC):
    """Abstract interface for objects that render one or more kinds of ``InfoType`` information.

    The concrete subclasses in this module narrow ``info_type`` to the members they support.
    """

    @abstractmethod
    def display(self, info_type: InfoTypeT, **kwargs) -> None:
        """Render the information selected by ``info_type``.

        Args:
            info_type: Which kind of information to display.
            **kwargs: Extra, implementation-specific display options.
        """
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class ConfigBuilderInfo(InfoDisplay):
    """Shows sampler-parameter and model-configuration tables for a config builder."""

    def __init__(self, model_configs: list[ModelConfig]):
        # Parameter schemas for the available samplers, looked up by sampler type below.
        self._sampler_params = get_sampler_params()
        self._model_configs = model_configs

    def display(self, info_type: ConfigBuilderInfoType, **kwargs) -> None:
        """Show the table selected by ``info_type``.

        Args:
            info_type: ``InfoType.SAMPLERS`` or ``InfoType.MODEL_CONFIGS``; anything else is rejected.
            **kwargs: For ``SAMPLERS``, an optional ``sampler_type`` limits the table to that one
                sampler. Other keyword arguments are ignored.

        Raises:
            ValueError: If ``info_type`` is not one of the two supported values.
        """
        if info_type == InfoType.SAMPLERS:
            self._display_sampler_info(sampler_type=kwargs.get("sampler_type"))
            return
        if info_type == InfoType.MODEL_CONFIGS:
            display_model_configs_table(self._model_configs)
            return
        raise ValueError(
            f"Unsupported info_type: {str(info_type)!r}. "
            f"ConfigBuilderInfo only supports {InfoType.SAMPLERS.value!r} and {InfoType.MODEL_CONFIGS.value!r}."
        )

    def _display_sampler_info(self, sampler_type: SamplerType | None) -> None:
        """Show parameters for a single sampler, or for every sampler when ``sampler_type`` is None."""
        if sampler_type is None:
            display_sampler_table(self._sampler_params)
            return
        # Converting through SamplerType validates the value and gives a readable title; the table
        # lookup still uses the caller's value as given.
        pretty_name = SamplerType(sampler_type).value.replace("_", " ").title()
        display_sampler_table({sampler_type: self._sampler_params[sampler_type]}, title=f"{pretty_name} Sampler")
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class InterfaceInfo(InfoDisplay):
    """Shows the model-provider table for the Data Designer interface."""

    def __init__(self, model_providers: list[ModelProvider]):
        self._model_providers = model_providers

    def display(self, info_type: DataDesignerInfoType, **kwargs) -> None:
        """Show the table selected by ``info_type``.

        Args:
            info_type: Must be ``InfoType.MODEL_PROVIDERS``; anything else is rejected.
            **kwargs: Accepted for interface compatibility and ignored.

        Raises:
            ValueError: If ``info_type`` is not ``InfoType.MODEL_PROVIDERS``.
        """
        if info_type != InfoType.MODEL_PROVIDERS:
            raise ValueError(
                f"Unsupported info_type: {str(info_type)!r}. InterfaceInfo only supports {InfoType.MODEL_PROVIDERS.value!r}."
            )
        display_model_providers_table(self._model_providers)
|