data_designer_config-0.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/config/__init__.py +149 -0
- data_designer/config/_version.py +34 -0
- data_designer/config/analysis/__init__.py +2 -0
- data_designer/config/analysis/column_profilers.py +159 -0
- data_designer/config/analysis/column_statistics.py +421 -0
- data_designer/config/analysis/dataset_profiler.py +84 -0
- data_designer/config/analysis/utils/errors.py +10 -0
- data_designer/config/analysis/utils/reporting.py +192 -0
- data_designer/config/base.py +69 -0
- data_designer/config/column_configs.py +476 -0
- data_designer/config/column_types.py +141 -0
- data_designer/config/config_builder.py +595 -0
- data_designer/config/data_designer_config.py +40 -0
- data_designer/config/dataset_builders.py +13 -0
- data_designer/config/dataset_metadata.py +18 -0
- data_designer/config/default_model_settings.py +129 -0
- data_designer/config/errors.py +24 -0
- data_designer/config/interface.py +55 -0
- data_designer/config/models.py +486 -0
- data_designer/config/preview_results.py +41 -0
- data_designer/config/processors.py +148 -0
- data_designer/config/run_config.py +56 -0
- data_designer/config/sampler_constraints.py +52 -0
- data_designer/config/sampler_params.py +639 -0
- data_designer/config/seed.py +116 -0
- data_designer/config/seed_source.py +84 -0
- data_designer/config/seed_source_types.py +19 -0
- data_designer/config/testing/__init__.py +6 -0
- data_designer/config/testing/fixtures.py +308 -0
- data_designer/config/utils/code_lang.py +93 -0
- data_designer/config/utils/constants.py +365 -0
- data_designer/config/utils/errors.py +21 -0
- data_designer/config/utils/info.py +94 -0
- data_designer/config/utils/io_helpers.py +258 -0
- data_designer/config/utils/misc.py +78 -0
- data_designer/config/utils/numerical_helpers.py +30 -0
- data_designer/config/utils/type_helpers.py +106 -0
- data_designer/config/utils/visualization.py +482 -0
- data_designer/config/validator_params.py +94 -0
- data_designer/errors.py +7 -0
- data_designer/lazy_heavy_imports.py +56 -0
- data_designer/logging.py +180 -0
- data_designer/plugin_manager.py +78 -0
- data_designer/plugins/__init__.py +8 -0
- data_designer/plugins/errors.py +15 -0
- data_designer/plugins/plugin.py +141 -0
- data_designer/plugins/registry.py +88 -0
- data_designer_config-0.4.0.dist-info/METADATA +75 -0
- data_designer_config-0.4.0.dist-info/RECORD +50 -0
- data_designer_config-0.4.0.dist-info/WHEEL +4 -0

data_designer/config/default_model_settings.py
@@ -0,0 +1,129 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import logging
+import os
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Literal
+
+from data_designer.config.models import (
+    ChatCompletionInferenceParams,
+    EmbeddingInferenceParams,
+    InferenceParamsT,
+    ModelConfig,
+    ModelProvider,
+)
+from data_designer.config.utils.constants import (
+    MANAGED_ASSETS_PATH,
+    MODEL_CONFIGS_FILE_PATH,
+    MODEL_PROVIDERS_FILE_PATH,
+    PREDEFINED_PROVIDERS,
+    PREDEFINED_PROVIDERS_MODEL_MAP,
+)
+from data_designer.config.utils.io_helpers import load_config_file, save_config_file
+
+logger = logging.getLogger(__name__)
+
+
+def get_default_inference_parameters(
+    model_alias: Literal["text", "reasoning", "vision", "embedding"],
+    inference_parameters: dict[str, Any],
+) -> InferenceParamsT:
+    if model_alias == "reasoning":
+        return ChatCompletionInferenceParams(**inference_parameters)
+    elif model_alias == "vision":
+        return ChatCompletionInferenceParams(**inference_parameters)
+    elif model_alias == "embedding":
+        return EmbeddingInferenceParams(**inference_parameters)
+    else:
+        return ChatCompletionInferenceParams(**inference_parameters)
+
+
+def get_builtin_model_configs() -> list[ModelConfig]:
+    model_configs = []
+    for provider, model_alias_map in PREDEFINED_PROVIDERS_MODEL_MAP.items():
+        for model_alias, settings in model_alias_map.items():
+            model_configs.append(
+                ModelConfig(
+                    alias=f"{provider}-{model_alias}",
+                    model=settings["model"],
+                    provider=provider,
+                    inference_parameters=get_default_inference_parameters(
+                        model_alias, settings["inference_parameters"]
+                    ),
+                )
+            )
+    return model_configs
+
+
+def get_builtin_model_providers() -> list[ModelProvider]:
+    return [ModelProvider.model_validate(provider) for provider in PREDEFINED_PROVIDERS]
+
+
+def get_default_model_configs() -> list[ModelConfig]:
+    if MODEL_CONFIGS_FILE_PATH.exists():
+        config_dict = load_config_file(MODEL_CONFIGS_FILE_PATH)
+        if "model_configs" in config_dict:
+            return [ModelConfig.model_validate(mc) for mc in config_dict["model_configs"]]
+    return []
+
+
+def get_providers_with_missing_api_keys(providers: list[ModelProvider]) -> list[ModelProvider]:
+    providers_with_missing_keys = []
+
+    for provider in providers:
+        if provider.api_key is None:
+            # No API key specified at all
+            providers_with_missing_keys.append(provider)
+        elif provider.api_key.isupper() and "_" in provider.api_key:
+            # Looks like an environment variable name, check if it's set
+            if os.environ.get(provider.api_key) is None:
+                providers_with_missing_keys.append(provider)
+        # else: It's an actual API key value (not an env var), so it's valid
+
+    return providers_with_missing_keys
+
+
+def get_default_providers() -> list[ModelProvider]:
+    config_dict = _get_default_providers_file_content(MODEL_PROVIDERS_FILE_PATH)
+    if "providers" in config_dict:
+        return [ModelProvider.model_validate(p) for p in config_dict["providers"]]
+    return []
+
+
+def get_default_provider_name() -> str | None:
+    return _get_default_providers_file_content(MODEL_PROVIDERS_FILE_PATH).get("default")
+
+
+def resolve_seed_default_model_settings() -> None:
+    if not MODEL_CONFIGS_FILE_PATH.exists():
+        logger.debug(
+            f"🍾 Default model configs were not found, so writing the following to {str(MODEL_CONFIGS_FILE_PATH)!r}"
+        )
+        save_config_file(
+            MODEL_CONFIGS_FILE_PATH,
+            {"model_configs": [mc.model_dump(mode="json") for mc in get_builtin_model_configs()]},
+        )
+
+    if not MODEL_PROVIDERS_FILE_PATH.exists():
+        logger.debug(
+            f"🪄 Default model providers were not found, so writing the following to {str(MODEL_PROVIDERS_FILE_PATH)!r}"
+        )
+        save_config_file(
+            MODEL_PROVIDERS_FILE_PATH, {"providers": [p.model_dump(mode="json") for p in get_builtin_model_providers()]}
+        )
+
+    if not MANAGED_ASSETS_PATH.exists():
+        logger.debug(f"🏗️ Default managed assets path was not found, so creating it at {str(MANAGED_ASSETS_PATH)!r}")
+        MANAGED_ASSETS_PATH.mkdir(parents=True, exist_ok=True)
+
+
+@lru_cache(maxsize=1)
+def _get_default_providers_file_content(file_path: Path) -> dict[str, Any]:
+    """Load and cache the default providers file content."""
+    if file_path.exists():
+        return load_config_file(file_path)
+    raise FileNotFoundError(f"Default model providers file not found at {str(file_path)!r}")
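
Note the convention encoded in get_providers_with_missing_api_keys above: an api_key value that is all-uppercase and contains an underscore is treated as the name of an environment variable rather than a literal key, and the provider is only flagged when that variable is unset. A minimal sketch of the three cases (the provider names and endpoints are illustrative, not taken from the package):

    import os

    from data_designer.config.default_model_settings import get_providers_with_missing_api_keys
    from data_designer.config.models import ModelProvider

    providers = [
        # Env-var reference: flagged unless NVIDIA_API_KEY is set in the environment.
        ModelProvider(name="nvidia", endpoint="https://integrate.api.nvidia.com/v1", api_key="NVIDIA_API_KEY"),
        # Literal key value: never flagged.
        ModelProvider(name="local", endpoint="http://localhost:8000/v1", api_key="sk-local-123"),
        # No key at all: always flagged.
        ModelProvider(name="anon", endpoint="http://localhost:9000/v1"),
    ]

    os.environ.pop("NVIDIA_API_KEY", None)
    missing = get_providers_with_missing_api_keys(providers)
    print([p.name for p in missing])  # ['nvidia', 'anon']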

data_designer/config/errors.py
@@ -0,0 +1,24 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from data_designer.errors import DataDesignerError
+
+
+class BuilderConfigurationError(DataDesignerError): ...
+
+
+class BuilderSerializationError(DataDesignerError): ...
+
+
+class InvalidColumnTypeError(DataDesignerError): ...
+
+
+class InvalidConfigError(DataDesignerError): ...
+
+
+class InvalidFilePathError(DataDesignerError): ...
+
+
+class InvalidFileFormatError(DataDesignerError): ...
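
Every class in this module subclasses DataDesignerError, so callers can catch the whole configuration-error family with a single except clause. A small sketch:

    from data_designer.config.errors import InvalidConfigError
    from data_designer.errors import DataDesignerError

    try:
        raise InvalidConfigError("missing `model_configs` key")
    except DataDesignerError as err:
        # One handler covers every error type defined in this module.
        print(f"{type(err).__name__}: {err}")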

data_designer/config/interface.py
@@ -0,0 +1,55 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Generic, Protocol, TypeVar
+
+from data_designer.config.models import ModelConfig, ModelProvider
+from data_designer.config.utils.constants import DEFAULT_NUM_RECORDS
+from data_designer.config.utils.info import InterfaceInfo
+from data_designer.lazy_heavy_imports import pd
+
+if TYPE_CHECKING:
+    import pandas as pd
+
+    from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults
+    from data_designer.config.config_builder import DataDesignerConfigBuilder
+    from data_designer.config.preview_results import PreviewResults
+
+
+class ResultsProtocol(Protocol):
+    def load_analysis(self) -> DatasetProfilerResults: ...
+    def load_dataset(self) -> pd.DataFrame: ...
+
+
+ResultsT = TypeVar("ResultsT", bound=ResultsProtocol)
+
+
+class DataDesignerInterface(ABC, Generic[ResultsT]):
+    @abstractmethod
+    def create(
+        self,
+        config_builder: DataDesignerConfigBuilder,
+        *,
+        num_records: int = DEFAULT_NUM_RECORDS,
+    ) -> ResultsT: ...
+
+    @abstractmethod
+    def preview(
+        self,
+        config_builder: DataDesignerConfigBuilder,
+        *,
+        num_records: int = DEFAULT_NUM_RECORDS,
+    ) -> PreviewResults: ...
+
+    @abstractmethod
+    def get_default_model_configs(self) -> list[ModelConfig]: ...
+
+    @abstractmethod
+    def get_default_model_providers(self) -> list[ModelProvider]: ...
+
+    @property
+    @abstractmethod
+    def info(self) -> InterfaceInfo: ...
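
ResultsProtocol is a structural typing.Protocol, so a results class only needs matching method signatures, and DataDesignerInterface is generic over whichever results type an implementation returns. A hedged sketch of a concrete implementation (LocalResults and LocalDataDesigner are hypothetical names, not part of this package):

    import pandas as pd

    from data_designer.config.interface import DataDesignerInterface

    class LocalResults:
        # Satisfies ResultsProtocol structurally; no inheritance required.
        def load_analysis(self):
            raise NotImplementedError

        def load_dataset(self) -> pd.DataFrame:
            return pd.DataFrame()

    class LocalDataDesigner(DataDesignerInterface[LocalResults]):
        def create(self, config_builder, *, num_records=100) -> LocalResults:
            return LocalResults()

        def preview(self, config_builder, *, num_records=100):
            raise NotImplementedError

        def get_default_model_configs(self):
            return []

        def get_default_model_providers(self):
            return []

        @property
        def info(self):
            raise NotImplementedError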

data_designer/config/models.py
@@ -0,0 +1,486 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import json
+import logging
+from abc import ABC, abstractmethod
+from enum import Enum
+from pathlib import Path
+from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, TypeVar
+
+from pydantic import BaseModel, Field, field_validator, model_validator
+from typing_extensions import Self, TypeAlias
+
+from data_designer.config.base import ConfigBase
+from data_designer.config.errors import InvalidConfigError
+from data_designer.config.utils.constants import (
+    MAX_TEMPERATURE,
+    MAX_TOP_P,
+    MIN_TEMPERATURE,
+    MIN_TOP_P,
+)
+from data_designer.config.utils.io_helpers import smart_load_yaml
+from data_designer.lazy_heavy_imports import np
+
+if TYPE_CHECKING:
+    import numpy as np
+
+logger = logging.getLogger(__name__)
+
+
+class Modality(str, Enum):
+    """Supported modality types for multimodal model data."""
+
+    IMAGE = "image"
+
+
+class ModalityDataType(str, Enum):
+    """Data type formats for multimodal data."""
+
+    URL = "url"
+    BASE64 = "base64"
+
+
+class ImageFormat(str, Enum):
+    """Supported image formats for image modality."""
+
+    PNG = "png"
+    JPG = "jpg"
+    JPEG = "jpeg"
+    GIF = "gif"
+    WEBP = "webp"
+
+
+class DistributionType(str, Enum):
+    """Types of distributions for sampling inference parameters."""
+
+    UNIFORM = "uniform"
+    MANUAL = "manual"
+
+
+class ModalityContext(ABC, BaseModel):
+    modality: Modality
+    column_name: str
+    data_type: ModalityDataType
+
+    @abstractmethod
+    def get_contexts(self, record: dict) -> list[dict[str, Any]]: ...
+
+
+class ImageContext(ModalityContext):
+    """Configuration for providing image context to multimodal models.
+
+    Attributes:
+        modality: The modality type (always "image").
+        column_name: Name of the column containing image data.
+        data_type: Format of the image data ("url" or "base64").
+        image_format: Image format (required for base64 data).
+    """
+
+    modality: Modality = Modality.IMAGE
+    image_format: ImageFormat | None = None
+
+    def get_contexts(self, record: dict) -> list[dict[str, Any]]:
+        """Get the contexts for the image modality.
+
+        Args:
+            record: The record containing the image data. The data can be:
+                - A JSON serialized list of strings
+                - A list of strings
+                - A single string
+
+        Returns:
+            A list of image contexts.
+        """
+        raw_value = record[self.column_name]
+
+        # Normalize to list of strings
+        if isinstance(raw_value, str):
+            # Try to parse as JSON first
+            try:
+                parsed_value = json.loads(raw_value)
+                if isinstance(parsed_value, list):
+                    context_values = parsed_value
+                else:
+                    context_values = [raw_value]
+            except (json.JSONDecodeError, TypeError):
+                context_values = [raw_value]
+        elif isinstance(raw_value, list):
+            context_values = raw_value
+        elif hasattr(raw_value, "__iter__") and not isinstance(raw_value, (str, bytes, dict)):
+            # Handle array-like objects (numpy arrays, pandas Series, etc.)
+            context_values = list(raw_value)
+        else:
+            context_values = [raw_value]
+
+        # Build context list
+        contexts = []
+        for context_value in context_values:
+            context = dict(type="image_url")
+            if self.data_type == ModalityDataType.URL:
+                context["image_url"] = context_value
+            else:
+                context["image_url"] = {
+                    "url": f"data:image/{self.image_format.value};base64,{context_value}",
+                    "format": self.image_format.value,
+                }
+            contexts.append(context)
+
+        return contexts
+
+    @model_validator(mode="after")
+    def _validate_image_format(self) -> Self:
+        if self.data_type == ModalityDataType.BASE64 and self.image_format is None:
+            raise ValueError(f"image_format is required when data_type is {self.data_type.value}")
+        return self
+
+
+DistributionParamsT = TypeVar("DistributionParamsT", bound=ConfigBase)
+
+
+class Distribution(ABC, ConfigBase, Generic[DistributionParamsT]):
+    distribution_type: DistributionType
+    params: DistributionParamsT
+
+    @abstractmethod
+    def sample(self) -> float: ...
+
+
+class ManualDistributionParams(ConfigBase):
+    """Parameters for manual distribution sampling.
+
+    Attributes:
+        values: List of possible values to sample from.
+        weights: Optional list of weights for each value. If not provided, all values have equal probability.
+    """
+
+    values: list[float] = Field(min_length=1)
+    weights: list[float] | None = None
+
+    @model_validator(mode="after")
+    def _normalize_weights(self) -> Self:
+        if self.weights is not None:
+            self.weights = [w / sum(self.weights) for w in self.weights]
+        return self
+
+    @model_validator(mode="after")
+    def _validate_equal_lengths(self) -> Self:
+        if self.weights and len(self.values) != len(self.weights):
+            raise ValueError("`values` and `weights` must have the same length")
+        return self
+
+
+class ManualDistribution(Distribution[ManualDistributionParams]):
+    """Manual (discrete) distribution for sampling inference parameters.
+
+    Samples from a discrete set of values with optional weights. Useful for testing
+    specific values or creating custom probability distributions for temperature or top_p.
+
+    Attributes:
+        distribution_type: Type of distribution ("manual").
+        params: Distribution parameters (values, weights).
+    """
+
+    distribution_type: DistributionType | None = "manual"
+    params: ManualDistributionParams
+
+    def sample(self) -> float:
+        """Sample a value from the manual distribution.
+
+        Returns:
+            A float value sampled from the manual distribution.
+        """
+        return float(np.random.choice(self.params.values, p=self.params.weights))
+
+
+class UniformDistributionParams(ConfigBase):
+    """Parameters for uniform distribution sampling.
+
+    Attributes:
+        low: Lower bound (inclusive).
+        high: Upper bound (exclusive).
+    """
+
+    low: float
+    high: float
+
+    @model_validator(mode="after")
+    def _validate_low_lt_high(self) -> Self:
+        if self.low >= self.high:
+            raise ValueError("`low` must be less than `high`")
+        return self
+
+
+class UniformDistribution(Distribution[UniformDistributionParams]):
+    """Uniform distribution for sampling inference parameters.
+
+    Samples values uniformly between low and high bounds. Useful for exploring
+    a continuous range of values for temperature or top_p.
+
+    Attributes:
+        distribution_type: Type of distribution ("uniform").
+        params: Distribution parameters (low, high).
+    """
+
+    distribution_type: DistributionType | None = "uniform"
+    params: UniformDistributionParams
+
+    def sample(self) -> float:
+        """Sample a value from the uniform distribution.
+
+        Returns:
+            A float value sampled from the uniform distribution.
+        """
+        return float(np.random.uniform(low=self.params.low, high=self.params.high, size=1)[0])
+
+
+DistributionT: TypeAlias = UniformDistribution | ManualDistribution
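
An aside on the two concrete distributions just defined: both validate their parameters at construction time and re-sample on every call. An illustrative sketch, not part of the diff:

    from data_designer.config.models import (
        ManualDistribution,
        ManualDistributionParams,
        UniformDistribution,
        UniformDistributionParams,
    )

    uniform = UniformDistribution(params=UniformDistributionParams(low=0.5, high=0.9))
    uniform.sample()  # e.g. 0.7314, drawn uniformly from [0.5, 0.9)

    # Weights [3, 1] are normalized to [0.75, 0.25] by the _normalize_weights validator.
    manual = ManualDistribution(params=ManualDistributionParams(values=[0.7, 1.0], weights=[3, 1]))
    manual.sample()  # 0.7 with probability 0.75, 1.0 with probability 0.25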
+
+
+class GenerationType(str, Enum):
+    CHAT_COMPLETION = "chat-completion"
+    EMBEDDING = "embedding"
+
+
+class BaseInferenceParams(ConfigBase, ABC):
+    """Base configuration for inference parameters.
+
+    Attributes:
+        generation_type: Type of generation (chat-completion or embedding). Acts as discriminator.
+        max_parallel_requests: Maximum number of parallel requests to the model API.
+        timeout: Timeout in seconds for each request.
+        extra_body: Additional parameters to pass to the model API.
+    """
+
+    generation_type: GenerationType
+    max_parallel_requests: int = Field(default=4, ge=1)
+    timeout: int | None = Field(default=None, ge=1)
+    extra_body: dict[str, Any] | None = None
+
+    @property
+    def generate_kwargs(self) -> dict[str, Any]:
+        """Get the generate kwargs for the inference parameters.
+
+        Returns:
+            A dictionary of the generate kwargs.
+        """
+        result = {}
+        if self.timeout is not None:
+            result["timeout"] = self.timeout
+        if self.extra_body is not None and self.extra_body != {}:
+            result["extra_body"] = self.extra_body
+        return result
+
+    def format_for_display(self) -> str:
+        """Format inference parameters for display.
+
+        Returns:
+            Formatted string of inference parameters
+        """
+        params_dict = self.model_dump(exclude_none=True, mode="json")
+
+        if not params_dict:
+            return "(none)"
+
+        parts = []
+        for key, value in params_dict.items():
+            formatted_value = self._format_value(key, value)
+            parts.append(f"{key}={formatted_value}")
+        return ", ".join(parts)
+
+    def _format_value(self, key: str, value: Any) -> str:
+        """Format a single parameter value. Override in subclasses for custom formatting.
+
+        Args:
+            key: Parameter name
+            value: Parameter value
+
+        Returns:
+            Formatted string representation of the value
+        """
+        if isinstance(value, float):
+            return f"{value:.2f}"
+        return str(value)
+
+
+class ChatCompletionInferenceParams(BaseInferenceParams):
+    """Configuration for LLM inference parameters.
+
+    Attributes:
+        generation_type: Type of generation, always "chat-completion" for this class.
+        temperature: Sampling temperature (0.0-2.0). Can be a fixed value or a distribution for dynamic sampling.
+        top_p: Nucleus sampling probability (0.0-1.0). Can be a fixed value or a distribution for dynamic sampling.
+        max_tokens: Maximum number of tokens to generate in the response.
+    """
+
+    generation_type: Literal[GenerationType.CHAT_COMPLETION] = GenerationType.CHAT_COMPLETION
+    temperature: float | DistributionT | None = None
+    top_p: float | DistributionT | None = None
+    max_tokens: int | None = Field(default=None, ge=1)
+
+    @property
+    def generate_kwargs(self) -> dict[str, Any]:
+        result = super().generate_kwargs
+        if self.temperature is not None:
+            result["temperature"] = (
+                self.temperature.sample() if hasattr(self.temperature, "sample") else self.temperature
+            )
+        if self.top_p is not None:
+            result["top_p"] = self.top_p.sample() if hasattr(self.top_p, "sample") else self.top_p
+        if self.max_tokens is not None:
+            result["max_tokens"] = self.max_tokens
+        return result
+
+    @model_validator(mode="after")
+    def _validate_temperature(self) -> Self:
+        return self._run_validation(
+            value=self.temperature,
+            param_name="temperature",
+            min_value=MIN_TEMPERATURE,
+            max_value=MAX_TEMPERATURE,
+        )
+
+    @model_validator(mode="after")
+    def _validate_top_p(self) -> Self:
+        return self._run_validation(
+            value=self.top_p,
+            param_name="top_p",
+            min_value=MIN_TOP_P,
+            max_value=MAX_TOP_P,
+        )
+
+    def _run_validation(
+        self,
+        value: float | DistributionT | None,
+        param_name: str,
+        min_value: float,
+        max_value: float,
+    ) -> Self:
+        if value is None:
+            return self
+        value_err = ValueError(f"{param_name} defined in model config must be between {min_value} and {max_value}")
+        if isinstance(value, Distribution):
+            if value.distribution_type == DistributionType.UNIFORM:
+                if value.params.low < min_value or value.params.high > max_value:
+                    raise value_err
+            elif value.distribution_type == DistributionType.MANUAL:
+                if any(not self._is_value_in_range(v, min_value, max_value) for v in value.params.values):
+                    raise value_err
+        else:
+            if not self._is_value_in_range(value, min_value, max_value):
+                raise value_err
+        return self
+
+    def _is_value_in_range(self, value: float, min_value: float, max_value: float) -> bool:
+        return min_value <= value <= max_value
+
+    def _format_value(self, key: str, value: Any) -> str:
+        """Format chat completion parameter values, including distributions.
+
+        Args:
+            key: Parameter name
+            value: Parameter value
+
+        Returns:
+            Formatted string representation of the value
+        """
+        if isinstance(value, dict) and "distribution_type" in value:
+            return "dist"
+        return super()._format_value(key, value)
+
+
+class EmbeddingInferenceParams(BaseInferenceParams):
+    """Configuration for embedding generation parameters.
+
+    Attributes:
+        generation_type: Type of generation, always "embedding" for this class.
+        encoding_format: Format of the embedding encoding ("float" or "base64").
+        dimensions: Number of dimensions for the embedding.
+    """
+
+    generation_type: Literal[GenerationType.EMBEDDING] = GenerationType.EMBEDDING
+    encoding_format: Literal["float", "base64"] = "float"
+    dimensions: int | None = None
+
+    @property
+    def generate_kwargs(self) -> dict[str, float | int]:
+        result = super().generate_kwargs
+        if self.encoding_format is not None:
+            result["encoding_format"] = self.encoding_format
+        if self.dimensions is not None:
+            result["dimensions"] = self.dimensions
+        return result
+
+
+InferenceParamsT: TypeAlias = Annotated[
+    ChatCompletionInferenceParams | EmbeddingInferenceParams, Field(discriminator="generation_type")
+]
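
Because temperature and top_p accept either a fixed float or a distribution, each access to generate_kwargs can draw fresh values, which is what makes per-request parameter sampling work. An illustrative sketch, not part of the diff:

    from data_designer.config.models import (
        ChatCompletionInferenceParams,
        UniformDistribution,
        UniformDistributionParams,
    )

    params = ChatCompletionInferenceParams(
        temperature=UniformDistribution(params=UniformDistributionParams(low=0.5, high=0.9)),
        top_p=0.95,
        max_tokens=1024,
    )

    # temperature is re-sampled on each access; top_p and max_tokens stay fixed.
    params.generate_kwargs  # e.g. {'temperature': 0.8131, 'top_p': 0.95, 'max_tokens': 1024}
    params.generate_kwargs  # e.g. {'temperature': 0.5467, 'top_p': 0.95, 'max_tokens': 1024}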
+
+
+class ModelConfig(ConfigBase):
+    """Configuration for a model used for generation.
+
+    Attributes:
+        alias: User-defined alias to reference in column configurations.
+        model: Model identifier (e.g., from build.nvidia.com or other providers).
+        inference_parameters: Inference parameters for the model (temperature, top_p, max_tokens, etc.).
+            The generation_type is determined by the type of inference_parameters.
+        provider: Optional model provider name if using custom providers.
+        skip_health_check: Whether to skip the health check for this model. Defaults to False.
+    """
+
+    alias: str
+    model: str
+    inference_parameters: InferenceParamsT = Field(default_factory=ChatCompletionInferenceParams)
+    provider: str | None = None
+    skip_health_check: bool = False
+
+    @property
+    def generation_type(self) -> GenerationType:
+        """Get the generation type from the inference parameters."""
+        return self.inference_parameters.generation_type
+
+    @field_validator("inference_parameters", mode="before")
+    @classmethod
+    def _convert_inference_parameters(cls, value: Any) -> Any:
+        """Convert raw dict to appropriate inference parameters type based on field presence."""
+        if isinstance(value, dict):
+            # Infer type from presence of embedding-specific fields
+            if "encoding_format" in value or "dimensions" in value:
+                return EmbeddingInferenceParams(**value)
+            else:
+                return ChatCompletionInferenceParams(**value)
+        return value
+
+
+class ModelProvider(ConfigBase):
+    """Configuration for a custom model provider.
+
+    Attributes:
+        name: Name of the model provider.
+        endpoint: API endpoint URL for the provider.
+        provider_type: Provider type (default: "openai"). Determines the API format to use.
+        api_key: Optional API key for authentication.
+        extra_body: Additional parameters to pass in API requests.
+        extra_headers: Additional headers to pass in API requests.
+    """
+
+    name: str
+    endpoint: str
+    provider_type: str = "openai"
+    api_key: str | None = None
+    extra_body: dict[str, Any] | None = None
+    extra_headers: dict[str, str] | None = None
+
+
+def load_model_configs(model_configs: list[ModelConfig] | str | Path) -> list[ModelConfig]:
+    if isinstance(model_configs, list) and all(isinstance(mc, ModelConfig) for mc in model_configs):
+        return model_configs
+    json_config = smart_load_yaml(model_configs)
+    if "model_configs" not in json_config:
+        raise InvalidConfigError(
+            "The list of model configs must be provided under model_configs in the configuration file."
+        )
+    return [ModelConfig.model_validate(mc) for mc in json_config["model_configs"]]