data_designer_config-0.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. data_designer/config/__init__.py +149 -0
  2. data_designer/config/_version.py +34 -0
  3. data_designer/config/analysis/__init__.py +2 -0
  4. data_designer/config/analysis/column_profilers.py +159 -0
  5. data_designer/config/analysis/column_statistics.py +421 -0
  6. data_designer/config/analysis/dataset_profiler.py +84 -0
  7. data_designer/config/analysis/utils/errors.py +10 -0
  8. data_designer/config/analysis/utils/reporting.py +192 -0
  9. data_designer/config/base.py +69 -0
  10. data_designer/config/column_configs.py +476 -0
  11. data_designer/config/column_types.py +141 -0
  12. data_designer/config/config_builder.py +595 -0
  13. data_designer/config/data_designer_config.py +40 -0
  14. data_designer/config/dataset_builders.py +13 -0
  15. data_designer/config/dataset_metadata.py +18 -0
  16. data_designer/config/default_model_settings.py +129 -0
  17. data_designer/config/errors.py +24 -0
  18. data_designer/config/interface.py +55 -0
  19. data_designer/config/models.py +486 -0
  20. data_designer/config/preview_results.py +41 -0
  21. data_designer/config/processors.py +148 -0
  22. data_designer/config/run_config.py +56 -0
  23. data_designer/config/sampler_constraints.py +52 -0
  24. data_designer/config/sampler_params.py +639 -0
  25. data_designer/config/seed.py +116 -0
  26. data_designer/config/seed_source.py +84 -0
  27. data_designer/config/seed_source_types.py +19 -0
  28. data_designer/config/testing/__init__.py +6 -0
  29. data_designer/config/testing/fixtures.py +308 -0
  30. data_designer/config/utils/code_lang.py +93 -0
  31. data_designer/config/utils/constants.py +365 -0
  32. data_designer/config/utils/errors.py +21 -0
  33. data_designer/config/utils/info.py +94 -0
  34. data_designer/config/utils/io_helpers.py +258 -0
  35. data_designer/config/utils/misc.py +78 -0
  36. data_designer/config/utils/numerical_helpers.py +30 -0
  37. data_designer/config/utils/type_helpers.py +106 -0
  38. data_designer/config/utils/visualization.py +482 -0
  39. data_designer/config/validator_params.py +94 -0
  40. data_designer/errors.py +7 -0
  41. data_designer/lazy_heavy_imports.py +56 -0
  42. data_designer/logging.py +180 -0
  43. data_designer/plugin_manager.py +78 -0
  44. data_designer/plugins/__init__.py +8 -0
  45. data_designer/plugins/errors.py +15 -0
  46. data_designer/plugins/plugin.py +141 -0
  47. data_designer/plugins/registry.py +88 -0
  48. data_designer_config-0.4.0.dist-info/METADATA +75 -0
  49. data_designer_config-0.4.0.dist-info/RECORD +50 -0
  50. data_designer_config-0.4.0.dist-info/WHEEL +4 -0
data_designer/config/default_model_settings.py
@@ -0,0 +1,129 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from __future__ import annotations
+
+ import logging
+ import os
+ from functools import lru_cache
+ from pathlib import Path
+ from typing import Any, Literal
+
+ from data_designer.config.models import (
+     ChatCompletionInferenceParams,
+     EmbeddingInferenceParams,
+     InferenceParamsT,
+     ModelConfig,
+     ModelProvider,
+ )
+ from data_designer.config.utils.constants import (
+     MANAGED_ASSETS_PATH,
+     MODEL_CONFIGS_FILE_PATH,
+     MODEL_PROVIDERS_FILE_PATH,
+     PREDEFINED_PROVIDERS,
+     PREDEFINED_PROVIDERS_MODEL_MAP,
+ )
+ from data_designer.config.utils.io_helpers import load_config_file, save_config_file
+
+ logger = logging.getLogger(__name__)
+
+
+ def get_default_inference_parameters(
+     model_alias: Literal["text", "reasoning", "vision", "embedding"],
+     inference_parameters: dict[str, Any],
+ ) -> InferenceParamsT:
+     if model_alias == "reasoning":
+         return ChatCompletionInferenceParams(**inference_parameters)
+     elif model_alias == "vision":
+         return ChatCompletionInferenceParams(**inference_parameters)
+     elif model_alias == "embedding":
+         return EmbeddingInferenceParams(**inference_parameters)
+     else:
+         return ChatCompletionInferenceParams(**inference_parameters)
+
+
+ def get_builtin_model_configs() -> list[ModelConfig]:
+     model_configs = []
+     for provider, model_alias_map in PREDEFINED_PROVIDERS_MODEL_MAP.items():
+         for model_alias, settings in model_alias_map.items():
+             model_configs.append(
+                 ModelConfig(
+                     alias=f"{provider}-{model_alias}",
+                     model=settings["model"],
+                     provider=provider,
+                     inference_parameters=get_default_inference_parameters(
+                         model_alias, settings["inference_parameters"]
+                     ),
+                 )
+             )
+     return model_configs
+
+
+ def get_builtin_model_providers() -> list[ModelProvider]:
+     return [ModelProvider.model_validate(provider) for provider in PREDEFINED_PROVIDERS]
+
+
+ def get_default_model_configs() -> list[ModelConfig]:
+     if MODEL_CONFIGS_FILE_PATH.exists():
+         config_dict = load_config_file(MODEL_CONFIGS_FILE_PATH)
+         if "model_configs" in config_dict:
+             return [ModelConfig.model_validate(mc) for mc in config_dict["model_configs"]]
+     return []
+
+
+ def get_providers_with_missing_api_keys(providers: list[ModelProvider]) -> list[ModelProvider]:
+     providers_with_missing_keys = []
+
+     for provider in providers:
+         if provider.api_key is None:
+             # No API key specified at all
+             providers_with_missing_keys.append(provider)
+         elif provider.api_key.isupper() and "_" in provider.api_key:
+             # Looks like an environment variable name, check if it's set
+             if os.environ.get(provider.api_key) is None:
+                 providers_with_missing_keys.append(provider)
+         # else: It's an actual API key value (not an env var), so it's valid
+
+     return providers_with_missing_keys
+
+
+ def get_default_providers() -> list[ModelProvider]:
+     config_dict = _get_default_providers_file_content(MODEL_PROVIDERS_FILE_PATH)
+     if "providers" in config_dict:
+         return [ModelProvider.model_validate(p) for p in config_dict["providers"]]
+     return []
+
+
+ def get_default_provider_name() -> str | None:
+     return _get_default_providers_file_content(MODEL_PROVIDERS_FILE_PATH).get("default")
+
+
+ def resolve_seed_default_model_settings() -> None:
+     if not MODEL_CONFIGS_FILE_PATH.exists():
+         logger.debug(
+             f"🍾 Default model configs were not found, so writing the following to {str(MODEL_CONFIGS_FILE_PATH)!r}"
+         )
+         save_config_file(
+             MODEL_CONFIGS_FILE_PATH,
+             {"model_configs": [mc.model_dump(mode="json") for mc in get_builtin_model_configs()]},
+         )
+
+     if not MODEL_PROVIDERS_FILE_PATH.exists():
+         logger.debug(
+             f"🪄 Default model providers were not found, so writing the following to {str(MODEL_PROVIDERS_FILE_PATH)!r}"
+         )
+         save_config_file(
+             MODEL_PROVIDERS_FILE_PATH, {"providers": [p.model_dump(mode="json") for p in get_builtin_model_providers()]}
+         )
+
+     if not MANAGED_ASSETS_PATH.exists():
+         logger.debug(f"🏗️ Default managed assets path was not found, so creating it at {str(MANAGED_ASSETS_PATH)!r}")
+         MANAGED_ASSETS_PATH.mkdir(parents=True, exist_ok=True)
+
+
+ @lru_cache(maxsize=1)
+ def _get_default_providers_file_content(file_path: Path) -> dict[str, Any]:
+     """Load and cache the default providers file content."""
+     if file_path.exists():
+         return load_config_file(file_path)
+     raise FileNotFoundError(f"Default model providers file not found at {str(file_path)!r}")
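A note on the env-var convention above: get_providers_with_missing_api_keys treats an api_key that is all-uppercase and contains an underscore as the name of an environment variable and checks whether that variable is set; any other non-empty value is assumed to be a literal key. A minimal sketch of that behavior (the provider names, endpoint, and key values below are illustrative, not shipped defaults):

    import os

    from data_designer.config.default_model_settings import get_providers_with_missing_api_keys
    from data_designer.config.models import ModelProvider

    providers = [
        # Uppercase with an underscore: read as an env-var name, flagged if unset.
        ModelProvider(name="example-env", endpoint="https://api.example.com/v1", api_key="EXAMPLE_API_KEY"),
        # Anything else is treated as a literal key and counts as present.
        ModelProvider(name="example-literal", endpoint="https://api.example.com/v1", api_key="sk-abc123"),
    ]

    os.environ.pop("EXAMPLE_API_KEY", None)
    missing = get_providers_with_missing_api_keys(providers)
    assert [p.name for p in missing] == ["example-env"]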
data_designer/config/errors.py
@@ -0,0 +1,24 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from __future__ import annotations
+
+ from data_designer.errors import DataDesignerError
+
+
+ class BuilderConfigurationError(DataDesignerError): ...
+
+
+ class BuilderSerializationError(DataDesignerError): ...
+
+
+ class InvalidColumnTypeError(DataDesignerError): ...
+
+
+ class InvalidConfigError(DataDesignerError): ...
+
+
+ class InvalidFilePathError(DataDesignerError): ...
+
+
+ class InvalidFileFormatError(DataDesignerError): ...
data_designer/config/interface.py
@@ -0,0 +1,55 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from __future__ import annotations
+
+ from abc import ABC, abstractmethod
+ from typing import TYPE_CHECKING, Generic, Protocol, TypeVar
+
+ from data_designer.config.models import ModelConfig, ModelProvider
+ from data_designer.config.utils.constants import DEFAULT_NUM_RECORDS
+ from data_designer.config.utils.info import InterfaceInfo
+ from data_designer.lazy_heavy_imports import pd
+
+ if TYPE_CHECKING:
+     import pandas as pd
+
+     from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults
+     from data_designer.config.config_builder import DataDesignerConfigBuilder
+     from data_designer.config.preview_results import PreviewResults
+
+
+ class ResultsProtocol(Protocol):
+     def load_analysis(self) -> DatasetProfilerResults: ...
+     def load_dataset(self) -> pd.DataFrame: ...
+
+
+ ResultsT = TypeVar("ResultsT", bound=ResultsProtocol)
+
+
+ class DataDesignerInterface(ABC, Generic[ResultsT]):
+     @abstractmethod
+     def create(
+         self,
+         config_builder: DataDesignerConfigBuilder,
+         *,
+         num_records: int = DEFAULT_NUM_RECORDS,
+     ) -> ResultsT: ...
+
+     @abstractmethod
+     def preview(
+         self,
+         config_builder: DataDesignerConfigBuilder,
+         *,
+         num_records: int = DEFAULT_NUM_RECORDS,
+     ) -> PreviewResults: ...
+
+     @abstractmethod
+     def get_default_model_configs(self) -> list[ModelConfig]: ...
+
+     @abstractmethod
+     def get_default_model_providers(self) -> list[ModelProvider]: ...
+
+     @property
+     @abstractmethod
+     def info(self) -> InterfaceInfo: ...
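Since ResultsProtocol is a typing.Protocol, a backend's results class conforms structurally and never needs to inherit from it. A hypothetical sketch of a conforming implementation (the LocalResults class and its parquet storage layout are assumptions for illustration, not part of this package):

    import pandas as pd

    from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults

    class LocalResults:
        """Satisfies ResultsProtocol structurally; no subclassing needed."""

        def __init__(self, dataset_path: str, analysis: DatasetProfilerResults) -> None:
            self._dataset_path = dataset_path
            self._analysis = analysis

        def load_analysis(self) -> DatasetProfilerResults:
            return self._analysis

        def load_dataset(self) -> pd.DataFrame:
            # This sketch assumes datasets are written as parquet files.
            return pd.read_parquet(self._dataset_path)

A concrete engine would then subclass DataDesignerInterface[LocalResults] and return LocalResults instances from create().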
data_designer/config/models.py
@@ -0,0 +1,486 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from __future__ import annotations
+
+ import json
+ import logging
+ from abc import ABC, abstractmethod
+ from enum import Enum
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, TypeVar
+
+ from pydantic import BaseModel, Field, field_validator, model_validator
+ from typing_extensions import Self, TypeAlias
+
+ from data_designer.config.base import ConfigBase
+ from data_designer.config.errors import InvalidConfigError
+ from data_designer.config.utils.constants import (
+     MAX_TEMPERATURE,
+     MAX_TOP_P,
+     MIN_TEMPERATURE,
+     MIN_TOP_P,
+ )
+ from data_designer.config.utils.io_helpers import smart_load_yaml
+ from data_designer.lazy_heavy_imports import np
+
+ if TYPE_CHECKING:
+     import numpy as np
+
+ logger = logging.getLogger(__name__)
+
+
+ class Modality(str, Enum):
+     """Supported modality types for multimodal model data."""
+
+     IMAGE = "image"
+
+
+ class ModalityDataType(str, Enum):
+     """Data type formats for multimodal data."""
+
+     URL = "url"
+     BASE64 = "base64"
+
+
+ class ImageFormat(str, Enum):
+     """Supported image formats for image modality."""
+
+     PNG = "png"
+     JPG = "jpg"
+     JPEG = "jpeg"
+     GIF = "gif"
+     WEBP = "webp"
+
+
+ class DistributionType(str, Enum):
+     """Types of distributions for sampling inference parameters."""
+
+     UNIFORM = "uniform"
+     MANUAL = "manual"
+
+
+ class ModalityContext(ABC, BaseModel):
+     modality: Modality
+     column_name: str
+     data_type: ModalityDataType
+
+     @abstractmethod
+     def get_contexts(self, record: dict) -> list[dict[str, Any]]: ...
+
+
+ class ImageContext(ModalityContext):
+     """Configuration for providing image context to multimodal models.
+
+     Attributes:
+         modality: The modality type (always "image").
+         column_name: Name of the column containing image data.
+         data_type: Format of the image data ("url" or "base64").
+         image_format: Image format (required for base64 data).
+     """
+
+     modality: Modality = Modality.IMAGE
+     image_format: ImageFormat | None = None
+
+     def get_contexts(self, record: dict) -> list[dict[str, Any]]:
+         """Get the contexts for the image modality.
+
+         Args:
+             record: The record containing the image data. The data can be:
+                 - A JSON serialized list of strings
+                 - A list of strings
+                 - A single string
+
+         Returns:
+             A list of image contexts.
+         """
+         raw_value = record[self.column_name]
+
+         # Normalize to list of strings
+         if isinstance(raw_value, str):
+             # Try to parse as JSON first
+             try:
+                 parsed_value = json.loads(raw_value)
+                 if isinstance(parsed_value, list):
+                     context_values = parsed_value
+                 else:
+                     context_values = [raw_value]
+             except (json.JSONDecodeError, TypeError):
+                 context_values = [raw_value]
+         elif isinstance(raw_value, list):
+             context_values = raw_value
+         elif hasattr(raw_value, "__iter__") and not isinstance(raw_value, (str, bytes, dict)):
+             # Handle array-like objects (numpy arrays, pandas Series, etc.)
+             context_values = list(raw_value)
+         else:
+             context_values = [raw_value]
+
+         # Build context list
+         contexts = []
+         for context_value in context_values:
+             context = dict(type="image_url")
+             if self.data_type == ModalityDataType.URL:
+                 context["image_url"] = context_value
+             else:
+                 context["image_url"] = {
+                     "url": f"data:image/{self.image_format.value};base64,{context_value}",
+                     "format": self.image_format.value,
+                 }
+             contexts.append(context)
+
+         return contexts
+
+     @model_validator(mode="after")
+     def _validate_image_format(self) -> Self:
+         if self.data_type == ModalityDataType.BASE64 and self.image_format is None:
+             raise ValueError(f"image_format is required when data_type is {self.data_type.value}")
+         return self
+
+
+ DistributionParamsT = TypeVar("DistributionParamsT", bound=ConfigBase)
+
+
+ class Distribution(ABC, ConfigBase, Generic[DistributionParamsT]):
+     distribution_type: DistributionType
+     params: DistributionParamsT
+
+     @abstractmethod
+     def sample(self) -> float: ...
+
+
+ class ManualDistributionParams(ConfigBase):
+     """Parameters for manual distribution sampling.
+
+     Attributes:
+         values: List of possible values to sample from.
+         weights: Optional list of weights for each value. If not provided, all values have equal probability.
+     """
+
+     values: list[float] = Field(min_length=1)
+     weights: list[float] | None = None
+
+     @model_validator(mode="after")
+     def _normalize_weights(self) -> Self:
+         if self.weights is not None:
+             self.weights = [w / sum(self.weights) for w in self.weights]
+         return self
+
+     @model_validator(mode="after")
+     def _validate_equal_lengths(self) -> Self:
+         if self.weights and len(self.values) != len(self.weights):
+             raise ValueError("`values` and `weights` must have the same length")
+         return self
+
+
+ class ManualDistribution(Distribution[ManualDistributionParams]):
+     """Manual (discrete) distribution for sampling inference parameters.
+
+     Samples from a discrete set of values with optional weights. Useful for testing
+     specific values or creating custom probability distributions for temperature or top_p.
+
+     Attributes:
+         distribution_type: Type of distribution ("manual").
+         params: Distribution parameters (values, weights).
+     """
+
+     distribution_type: DistributionType | None = "manual"
+     params: ManualDistributionParams
+
+     def sample(self) -> float:
+         """Sample a value from the manual distribution.
+
+         Returns:
+             A float value sampled from the manual distribution.
+         """
+         return float(np.random.choice(self.params.values, p=self.params.weights))
+
+
+ class UniformDistributionParams(ConfigBase):
+     """Parameters for uniform distribution sampling.
+
+     Attributes:
+         low: Lower bound (inclusive).
+         high: Upper bound (exclusive).
+     """
+
+     low: float
+     high: float
+
+     @model_validator(mode="after")
+     def _validate_low_lt_high(self) -> Self:
+         if self.low >= self.high:
+             raise ValueError("`low` must be less than `high`")
+         return self
+
+
+ class UniformDistribution(Distribution[UniformDistributionParams]):
+     """Uniform distribution for sampling inference parameters.
+
+     Samples values uniformly between low and high bounds. Useful for exploring
+     a continuous range of values for temperature or top_p.
+
+     Attributes:
+         distribution_type: Type of distribution ("uniform").
+         params: Distribution parameters (low, high).
+     """
+
+     distribution_type: DistributionType | None = "uniform"
+     params: UniformDistributionParams
+
+     def sample(self) -> float:
+         """Sample a value from the uniform distribution.
+
+         Returns:
+             A float value sampled from the uniform distribution.
+         """
+         return float(np.random.uniform(low=self.params.low, high=self.params.high, size=1)[0])
+
+
+ DistributionT: TypeAlias = UniformDistribution | ManualDistribution
+
+
+ class GenerationType(str, Enum):
+     CHAT_COMPLETION = "chat-completion"
+     EMBEDDING = "embedding"
+
+
+ class BaseInferenceParams(ConfigBase, ABC):
+     """Base configuration for inference parameters.
+
+     Attributes:
+         generation_type: Type of generation (chat-completion or embedding). Acts as discriminator.
+         max_parallel_requests: Maximum number of parallel requests to the model API.
+         timeout: Timeout in seconds for each request.
+         extra_body: Additional parameters to pass to the model API.
+     """
+
+     generation_type: GenerationType
+     max_parallel_requests: int = Field(default=4, ge=1)
+     timeout: int | None = Field(default=None, ge=1)
+     extra_body: dict[str, Any] | None = None
+
+     @property
+     def generate_kwargs(self) -> dict[str, Any]:
+         """Get the generate kwargs for the inference parameters.
+
+         Returns:
+             A dictionary of the generate kwargs.
+         """
+         result = {}
+         if self.timeout is not None:
+             result["timeout"] = self.timeout
+         if self.extra_body is not None and self.extra_body != {}:
+             result["extra_body"] = self.extra_body
+         return result
+
+     def format_for_display(self) -> str:
+         """Format inference parameters for display.
+
+         Returns:
+             Formatted string of inference parameters
+         """
+         params_dict = self.model_dump(exclude_none=True, mode="json")
+
+         if not params_dict:
+             return "(none)"
+
+         parts = []
+         for key, value in params_dict.items():
+             formatted_value = self._format_value(key, value)
+             parts.append(f"{key}={formatted_value}")
+         return ", ".join(parts)
+
+     def _format_value(self, key: str, value: Any) -> str:
+         """Format a single parameter value. Override in subclasses for custom formatting.
+
+         Args:
+             key: Parameter name
+             value: Parameter value
+
+         Returns:
+             Formatted string representation of the value
+         """
+         if isinstance(value, float):
+             return f"{value:.2f}"
+         return str(value)
+
+
+ class ChatCompletionInferenceParams(BaseInferenceParams):
+     """Configuration for LLM inference parameters.
+
+     Attributes:
+         generation_type: Type of generation, always "chat-completion" for this class.
+         temperature: Sampling temperature (0.0-2.0). Can be a fixed value or a distribution for dynamic sampling.
+         top_p: Nucleus sampling probability (0.0-1.0). Can be a fixed value or a distribution for dynamic sampling.
+         max_tokens: Maximum number of tokens to generate in the response.
+     """
+
+     generation_type: Literal[GenerationType.CHAT_COMPLETION] = GenerationType.CHAT_COMPLETION
+     temperature: float | DistributionT | None = None
+     top_p: float | DistributionT | None = None
+     max_tokens: int | None = Field(default=None, ge=1)
+
+     @property
+     def generate_kwargs(self) -> dict[str, Any]:
+         result = super().generate_kwargs
+         if self.temperature is not None:
+             result["temperature"] = (
+                 self.temperature.sample() if hasattr(self.temperature, "sample") else self.temperature
+             )
+         if self.top_p is not None:
+             result["top_p"] = self.top_p.sample() if hasattr(self.top_p, "sample") else self.top_p
+         if self.max_tokens is not None:
+             result["max_tokens"] = self.max_tokens
+         return result
+
+     @model_validator(mode="after")
+     def _validate_temperature(self) -> Self:
+         return self._run_validation(
+             value=self.temperature,
+             param_name="temperature",
+             min_value=MIN_TEMPERATURE,
+             max_value=MAX_TEMPERATURE,
+         )
+
+     @model_validator(mode="after")
+     def _validate_top_p(self) -> Self:
+         return self._run_validation(
+             value=self.top_p,
+             param_name="top_p",
+             min_value=MIN_TOP_P,
+             max_value=MAX_TOP_P,
+         )
+
+     def _run_validation(
+         self,
+         value: float | DistributionT | None,
+         param_name: str,
+         min_value: float,
+         max_value: float,
+     ) -> Self:
+         if value is None:
+             return self
+         value_err = ValueError(f"{param_name} defined in model config must be between {min_value} and {max_value}")
+         if isinstance(value, Distribution):
+             if value.distribution_type == DistributionType.UNIFORM:
+                 if value.params.low < min_value or value.params.high > max_value:
+                     raise value_err
+             elif value.distribution_type == DistributionType.MANUAL:
+                 if any(not self._is_value_in_range(v, min_value, max_value) for v in value.params.values):
+                     raise value_err
+         else:
+             if not self._is_value_in_range(value, min_value, max_value):
+                 raise value_err
+         return self
+
+     def _is_value_in_range(self, value: float, min_value: float, max_value: float) -> bool:
+         return min_value <= value <= max_value
+
+     def _format_value(self, key: str, value: Any) -> str:
+         """Format chat completion parameter values, including distributions.
+
+         Args:
+             key: Parameter name
+             value: Parameter value
+
+         Returns:
+             Formatted string representation of the value
+         """
+         if isinstance(value, dict) and "distribution_type" in value:
+             return "dist"
+         return super()._format_value(key, value)
+
+
+ class EmbeddingInferenceParams(BaseInferenceParams):
+     """Configuration for embedding generation parameters.
+
+     Attributes:
+         generation_type: Type of generation, always "embedding" for this class.
+         encoding_format: Format of the embedding encoding ("float" or "base64").
+         dimensions: Number of dimensions for the embedding.
+     """
+
+     generation_type: Literal[GenerationType.EMBEDDING] = GenerationType.EMBEDDING
+     encoding_format: Literal["float", "base64"] = "float"
+     dimensions: int | None = None
+
+     @property
+     def generate_kwargs(self) -> dict[str, float | int]:
+         result = super().generate_kwargs
+         if self.encoding_format is not None:
+             result["encoding_format"] = self.encoding_format
+         if self.dimensions is not None:
+             result["dimensions"] = self.dimensions
+         return result
+
+
+ InferenceParamsT: TypeAlias = Annotated[
+     ChatCompletionInferenceParams | EmbeddingInferenceParams, Field(discriminator="generation_type")
+ ]
+
+
+ class ModelConfig(ConfigBase):
+     """Configuration for a model used for generation.
+
+     Attributes:
+         alias: User-defined alias to reference in column configurations.
+         model: Model identifier (e.g., from build.nvidia.com or other providers).
+         inference_parameters: Inference parameters for the model (temperature, top_p, max_tokens, etc.).
+             The generation_type is determined by the type of inference_parameters.
+         provider: Optional model provider name if using custom providers.
+         skip_health_check: Whether to skip the health check for this model. Defaults to False.
+     """
+
+     alias: str
+     model: str
+     inference_parameters: InferenceParamsT = Field(default_factory=ChatCompletionInferenceParams)
+     provider: str | None = None
+     skip_health_check: bool = False
+
+     @property
+     def generation_type(self) -> GenerationType:
+         """Get the generation type from the inference parameters."""
+         return self.inference_parameters.generation_type
+
+     @field_validator("inference_parameters", mode="before")
+     @classmethod
+     def _convert_inference_parameters(cls, value: Any) -> Any:
+         """Convert raw dict to appropriate inference parameters type based on field presence."""
+         if isinstance(value, dict):
+             # Infer type from presence of embedding-specific fields
+             if "encoding_format" in value or "dimensions" in value:
+                 return EmbeddingInferenceParams(**value)
+             else:
+                 return ChatCompletionInferenceParams(**value)
+         return value
+
+
+ class ModelProvider(ConfigBase):
+     """Configuration for a custom model provider.
+
+     Attributes:
+         name: Name of the model provider.
+         endpoint: API endpoint URL for the provider.
+         provider_type: Provider type (default: "openai"). Determines the API format to use.
+         api_key: Optional API key for authentication.
+         extra_body: Additional parameters to pass in API requests.
+         extra_headers: Additional headers to pass in API requests.
+     """
+
+     name: str
+     endpoint: str
+     provider_type: str = "openai"
+     api_key: str | None = None
+     extra_body: dict[str, Any] | None = None
+     extra_headers: dict[str, str] | None = None
+
+
+ def load_model_configs(model_configs: list[ModelConfig] | str | Path) -> list[ModelConfig]:
+     if isinstance(model_configs, list) and all(isinstance(mc, ModelConfig) for mc in model_configs):
+         return model_configs
+     json_config = smart_load_yaml(model_configs)
+     if "model_configs" not in json_config:
+         raise InvalidConfigError(
+             "The list of model configs must be provided under model_configs in the configuration file."
+         )
+     return [ModelConfig.model_validate(mc) for mc in json_config["model_configs"]]
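Because temperature and top_p accept either a fixed float or a distribution, generate_kwargs draws fresh values on every access when a distribution is configured. A short sketch of how these models compose (the alias and model id are illustrative):

    from data_designer.config.models import (
        ModelConfig,
        UniformDistribution,
        UniformDistributionParams,
    )

    config = ModelConfig(
        alias="my-text-model",  # illustrative alias
        model="meta/llama-3.1-8b-instruct",  # illustrative model id
        inference_parameters={
            # The plain dict is coerced by the _convert_inference_parameters
            # validator; with no embedding-specific fields present, it becomes
            # a ChatCompletionInferenceParams instance.
            "temperature": UniformDistribution(params=UniformDistributionParams(low=0.5, high=0.9)),
            "top_p": 0.95,
            "max_tokens": 1024,
        },
    )

    kwargs = config.inference_parameters.generate_kwargs
    assert 0.5 <= kwargs["temperature"] <= 0.9  # re-sampled on each access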