data-designer 0.3.8rc2__py3-none-any.whl → 0.4.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/cli/commands/__init__.py +1 -1
- data_designer/interface/__init__.py +21 -1
- data_designer/{_version.py → interface/_version.py} +2 -2
- data_designer/interface/data_designer.py +1 -7
- {data_designer-0.3.8rc2.dist-info → data_designer-0.4.0rc1.dist-info}/METADATA +10 -42
- data_designer-0.4.0rc1.dist-info/RECORD +39 -0
- data_designer/__init__.py +0 -17
- data_designer/config/__init__.py +0 -2
- data_designer/config/analysis/__init__.py +0 -2
- data_designer/config/analysis/column_profilers.py +0 -159
- data_designer/config/analysis/column_statistics.py +0 -421
- data_designer/config/analysis/dataset_profiler.py +0 -84
- data_designer/config/analysis/utils/errors.py +0 -10
- data_designer/config/analysis/utils/reporting.py +0 -192
- data_designer/config/base.py +0 -69
- data_designer/config/column_configs.py +0 -470
- data_designer/config/column_types.py +0 -141
- data_designer/config/config_builder.py +0 -595
- data_designer/config/data_designer_config.py +0 -40
- data_designer/config/dataset_builders.py +0 -13
- data_designer/config/dataset_metadata.py +0 -18
- data_designer/config/default_model_settings.py +0 -129
- data_designer/config/errors.py +0 -24
- data_designer/config/exports.py +0 -145
- data_designer/config/interface.py +0 -55
- data_designer/config/models.py +0 -455
- data_designer/config/preview_results.py +0 -41
- data_designer/config/processors.py +0 -148
- data_designer/config/run_config.py +0 -51
- data_designer/config/sampler_constraints.py +0 -52
- data_designer/config/sampler_params.py +0 -639
- data_designer/config/seed.py +0 -116
- data_designer/config/seed_source.py +0 -84
- data_designer/config/seed_source_types.py +0 -19
- data_designer/config/utils/code_lang.py +0 -82
- data_designer/config/utils/constants.py +0 -363
- data_designer/config/utils/errors.py +0 -21
- data_designer/config/utils/info.py +0 -94
- data_designer/config/utils/io_helpers.py +0 -258
- data_designer/config/utils/misc.py +0 -78
- data_designer/config/utils/numerical_helpers.py +0 -30
- data_designer/config/utils/type_helpers.py +0 -106
- data_designer/config/utils/visualization.py +0 -482
- data_designer/config/validator_params.py +0 -94
- data_designer/engine/__init__.py +0 -2
- data_designer/engine/analysis/column_profilers/base.py +0 -49
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +0 -153
- data_designer/engine/analysis/column_profilers/registry.py +0 -22
- data_designer/engine/analysis/column_statistics.py +0 -145
- data_designer/engine/analysis/dataset_profiler.py +0 -149
- data_designer/engine/analysis/errors.py +0 -9
- data_designer/engine/analysis/utils/column_statistics_calculations.py +0 -234
- data_designer/engine/analysis/utils/judge_score_processing.py +0 -132
- data_designer/engine/column_generators/__init__.py +0 -2
- data_designer/engine/column_generators/generators/__init__.py +0 -2
- data_designer/engine/column_generators/generators/base.py +0 -122
- data_designer/engine/column_generators/generators/embedding.py +0 -35
- data_designer/engine/column_generators/generators/expression.py +0 -55
- data_designer/engine/column_generators/generators/llm_completion.py +0 -113
- data_designer/engine/column_generators/generators/samplers.py +0 -69
- data_designer/engine/column_generators/generators/seed_dataset.py +0 -144
- data_designer/engine/column_generators/generators/validation.py +0 -140
- data_designer/engine/column_generators/registry.py +0 -60
- data_designer/engine/column_generators/utils/errors.py +0 -15
- data_designer/engine/column_generators/utils/generator_classification.py +0 -43
- data_designer/engine/column_generators/utils/judge_score_factory.py +0 -58
- data_designer/engine/column_generators/utils/prompt_renderer.py +0 -100
- data_designer/engine/compiler.py +0 -97
- data_designer/engine/configurable_task.py +0 -71
- data_designer/engine/dataset_builders/artifact_storage.py +0 -283
- data_designer/engine/dataset_builders/column_wise_builder.py +0 -335
- data_designer/engine/dataset_builders/errors.py +0 -15
- data_designer/engine/dataset_builders/multi_column_configs.py +0 -46
- data_designer/engine/dataset_builders/utils/__init__.py +0 -2
- data_designer/engine/dataset_builders/utils/concurrency.py +0 -212
- data_designer/engine/dataset_builders/utils/config_compiler.py +0 -62
- data_designer/engine/dataset_builders/utils/dag.py +0 -62
- data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +0 -200
- data_designer/engine/dataset_builders/utils/errors.py +0 -15
- data_designer/engine/errors.py +0 -51
- data_designer/engine/model_provider.py +0 -77
- data_designer/engine/models/__init__.py +0 -2
- data_designer/engine/models/errors.py +0 -300
- data_designer/engine/models/facade.py +0 -287
- data_designer/engine/models/factory.py +0 -42
- data_designer/engine/models/litellm_overrides.py +0 -179
- data_designer/engine/models/parsers/__init__.py +0 -2
- data_designer/engine/models/parsers/errors.py +0 -34
- data_designer/engine/models/parsers/parser.py +0 -235
- data_designer/engine/models/parsers/postprocessors.py +0 -93
- data_designer/engine/models/parsers/tag_parsers.py +0 -62
- data_designer/engine/models/parsers/types.py +0 -84
- data_designer/engine/models/recipes/base.py +0 -81
- data_designer/engine/models/recipes/response_recipes.py +0 -293
- data_designer/engine/models/registry.py +0 -146
- data_designer/engine/models/telemetry.py +0 -359
- data_designer/engine/models/usage.py +0 -73
- data_designer/engine/models/utils.py +0 -38
- data_designer/engine/processing/ginja/__init__.py +0 -2
- data_designer/engine/processing/ginja/ast.py +0 -65
- data_designer/engine/processing/ginja/environment.py +0 -463
- data_designer/engine/processing/ginja/exceptions.py +0 -56
- data_designer/engine/processing/ginja/record.py +0 -32
- data_designer/engine/processing/gsonschema/__init__.py +0 -2
- data_designer/engine/processing/gsonschema/exceptions.py +0 -15
- data_designer/engine/processing/gsonschema/schema_transformers.py +0 -83
- data_designer/engine/processing/gsonschema/types.py +0 -10
- data_designer/engine/processing/gsonschema/validators.py +0 -202
- data_designer/engine/processing/processors/base.py +0 -13
- data_designer/engine/processing/processors/drop_columns.py +0 -42
- data_designer/engine/processing/processors/registry.py +0 -25
- data_designer/engine/processing/processors/schema_transform.py +0 -49
- data_designer/engine/processing/utils.py +0 -169
- data_designer/engine/registry/base.py +0 -99
- data_designer/engine/registry/data_designer_registry.py +0 -39
- data_designer/engine/registry/errors.py +0 -12
- data_designer/engine/resources/managed_dataset_generator.py +0 -39
- data_designer/engine/resources/managed_dataset_repository.py +0 -197
- data_designer/engine/resources/managed_storage.py +0 -65
- data_designer/engine/resources/resource_provider.py +0 -77
- data_designer/engine/resources/seed_reader.py +0 -154
- data_designer/engine/sampling_gen/column.py +0 -91
- data_designer/engine/sampling_gen/constraints.py +0 -100
- data_designer/engine/sampling_gen/data_sources/base.py +0 -217
- data_designer/engine/sampling_gen/data_sources/errors.py +0 -12
- data_designer/engine/sampling_gen/data_sources/sources.py +0 -347
- data_designer/engine/sampling_gen/entities/__init__.py +0 -2
- data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
- data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +0 -86
- data_designer/engine/sampling_gen/entities/email_address_utils.py +0 -171
- data_designer/engine/sampling_gen/entities/errors.py +0 -10
- data_designer/engine/sampling_gen/entities/national_id_utils.py +0 -102
- data_designer/engine/sampling_gen/entities/person.py +0 -144
- data_designer/engine/sampling_gen/entities/phone_number.py +0 -128
- data_designer/engine/sampling_gen/errors.py +0 -26
- data_designer/engine/sampling_gen/generator.py +0 -122
- data_designer/engine/sampling_gen/jinja_utils.py +0 -64
- data_designer/engine/sampling_gen/people_gen.py +0 -199
- data_designer/engine/sampling_gen/person_constants.py +0 -56
- data_designer/engine/sampling_gen/schema.py +0 -147
- data_designer/engine/sampling_gen/schema_builder.py +0 -61
- data_designer/engine/sampling_gen/utils.py +0 -46
- data_designer/engine/secret_resolver.py +0 -82
- data_designer/engine/validation.py +0 -367
- data_designer/engine/validators/__init__.py +0 -19
- data_designer/engine/validators/base.py +0 -38
- data_designer/engine/validators/local_callable.py +0 -39
- data_designer/engine/validators/python.py +0 -254
- data_designer/engine/validators/remote.py +0 -89
- data_designer/engine/validators/sql.py +0 -65
- data_designer/errors.py +0 -7
- data_designer/essentials/__init__.py +0 -33
- data_designer/lazy_heavy_imports.py +0 -54
- data_designer/logging.py +0 -163
- data_designer/plugin_manager.py +0 -78
- data_designer/plugins/__init__.py +0 -8
- data_designer/plugins/errors.py +0 -15
- data_designer/plugins/plugin.py +0 -141
- data_designer/plugins/registry.py +0 -88
- data_designer/plugins/testing/__init__.py +0 -10
- data_designer/plugins/testing/stubs.py +0 -116
- data_designer/plugins/testing/utils.py +0 -20
- data_designer-0.3.8rc2.dist-info/RECORD +0 -196
- data_designer-0.3.8rc2.dist-info/licenses/LICENSE +0 -201
- {data_designer-0.3.8rc2.dist-info → data_designer-0.4.0rc1.dist-info}/WHEEL +0 -0
- {data_designer-0.3.8rc2.dist-info → data_designer-0.4.0rc1.dist-info}/entry_points.txt +0 -0
|
@@ -1,300 +0,0 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
-
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
-
|
|
4
|
-
from __future__ import annotations
|
|
5
|
-
|
|
6
|
-
import logging
|
|
7
|
-
from collections.abc import Callable
|
|
8
|
-
from functools import wraps
|
|
9
|
-
from typing import TYPE_CHECKING, Any
|
|
10
|
-
|
|
11
|
-
from pydantic import BaseModel
|
|
12
|
-
|
|
13
|
-
from data_designer.engine.errors import DataDesignerError
|
|
14
|
-
from data_designer.lazy_heavy_imports import litellm
|
|
15
|
-
|
|
16
|
-
if TYPE_CHECKING:
|
|
17
|
-
import litellm
|
|
18
|
-
|
|
19
|
-
logger = logging.getLogger(__name__)
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
def get_exception_primary_cause(exception: BaseException) -> BaseException:
|
|
23
|
-
"""Returns the primary cause of an exception by walking backwards.
|
|
24
|
-
|
|
25
|
-
This recursive walkback halts when it arrives at an exception which
|
|
26
|
-
has no provided __cause__ (e.g. __cause__ is None).
|
|
27
|
-
|
|
28
|
-
Args:
|
|
29
|
-
exception (Exception): An exception to start from.
|
|
30
|
-
|
|
31
|
-
Raises:
|
|
32
|
-
RecursionError: if for some reason exceptions have circular
|
|
33
|
-
dependencies (seems impossible in practice).
|
|
34
|
-
"""
|
|
35
|
-
if exception.__cause__ is None:
|
|
36
|
-
return exception
|
|
37
|
-
else:
|
|
38
|
-
return get_exception_primary_cause(exception.__cause__)
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
class GenerationValidationFailureError(Exception): ...
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
class ModelRateLimitError(DataDesignerError): ...
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
class ModelTimeoutError(DataDesignerError): ...
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
class ModelContextWindowExceededError(DataDesignerError): ...
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
class ModelAuthenticationError(DataDesignerError): ...
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
class ModelPermissionDeniedError(DataDesignerError): ...
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
class ModelNotFoundError(DataDesignerError): ...
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
class ModelUnsupportedParamsError(DataDesignerError): ...
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
class ModelBadRequestError(DataDesignerError): ...
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
class ModelInternalServerError(DataDesignerError): ...
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
class ModelAPIError(DataDesignerError): ...
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
class ModelUnprocessableEntityError(DataDesignerError): ...
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
class ModelAPIConnectionError(DataDesignerError): ...
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
class ModelStructuredOutputError(DataDesignerError): ...
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
class ModelGenerationValidationFailureError(DataDesignerError): ...
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
class FormattedLLMErrorMessage(BaseModel):
|
|
87
|
-
cause: str
|
|
88
|
-
solution: str
|
|
89
|
-
|
|
90
|
-
def __str__(self) -> str:
|
|
91
|
-
return "\n".join(
|
|
92
|
-
[
|
|
93
|
-
" |----------",
|
|
94
|
-
f" | Cause: {self.cause}",
|
|
95
|
-
f" | Solution: {self.solution}",
|
|
96
|
-
" |----------",
|
|
97
|
-
]
|
|
98
|
-
)
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
def handle_llm_exceptions(
|
|
102
|
-
exception: Exception, model_name: str, model_provider_name: str, purpose: str | None = None
|
|
103
|
-
) -> None:
|
|
104
|
-
"""Handle LLM-related exceptions and convert them to appropriate DataDesignerError errors.
|
|
105
|
-
|
|
106
|
-
This method centralizes the exception handling logic for LLM operations,
|
|
107
|
-
making it reusable across different contexts.
|
|
108
|
-
|
|
109
|
-
Args:
|
|
110
|
-
exception: The exception that was raised
|
|
111
|
-
model_name: Name of the model that was being used
|
|
112
|
-
model_provider_name: Name of the model provider that was being used
|
|
113
|
-
purpose: The purpose of the model usage to show as context in the error message
|
|
114
|
-
Raises:
|
|
115
|
-
DataDesignerError: A more user-friendly error with appropriate error type and message
|
|
116
|
-
"""
|
|
117
|
-
purpose = purpose or "running generation"
|
|
118
|
-
authentication_error = FormattedLLMErrorMessage(
|
|
119
|
-
cause=f"The API key provided for model {model_name!r} was found to be invalid or expired while {purpose}.",
|
|
120
|
-
solution=f"Verify your API key for model provider and update it in your settings for model provider {model_provider_name!r}.",
|
|
121
|
-
)
|
|
122
|
-
err_msg_parser = DownstreamLLMExceptionMessageParser(model_name, model_provider_name, purpose)
|
|
123
|
-
match exception:
|
|
124
|
-
# Common errors that can come from LiteLLM
|
|
125
|
-
case litellm.exceptions.APIError():
|
|
126
|
-
raise err_msg_parser.parse_api_error(exception, authentication_error) from None
|
|
127
|
-
|
|
128
|
-
case litellm.exceptions.APIConnectionError():
|
|
129
|
-
raise ModelAPIConnectionError(
|
|
130
|
-
FormattedLLMErrorMessage(
|
|
131
|
-
cause=f"Connection to model {model_name!r} hosted on model provider {model_provider_name!r} failed while {purpose}.",
|
|
132
|
-
solution="Check your network/proxy/firewall settings.",
|
|
133
|
-
)
|
|
134
|
-
) from None
|
|
135
|
-
|
|
136
|
-
case litellm.exceptions.AuthenticationError():
|
|
137
|
-
raise ModelAuthenticationError(authentication_error) from None
|
|
138
|
-
|
|
139
|
-
case litellm.exceptions.ContextWindowExceededError():
|
|
140
|
-
raise err_msg_parser.parse_context_window_exceeded_error(exception) from None
|
|
141
|
-
|
|
142
|
-
case litellm.exceptions.UnsupportedParamsError():
|
|
143
|
-
raise ModelUnsupportedParamsError(
|
|
144
|
-
FormattedLLMErrorMessage(
|
|
145
|
-
cause=f"One or more of the parameters you provided were found to be unsupported by model {model_name!r} while {purpose}.",
|
|
146
|
-
solution=f"Review the documentation for model provider {model_provider_name!r} and adjust your request.",
|
|
147
|
-
)
|
|
148
|
-
) from None
|
|
149
|
-
|
|
150
|
-
case litellm.exceptions.BadRequestError():
|
|
151
|
-
raise err_msg_parser.parse_bad_request_error(exception) from None
|
|
152
|
-
|
|
153
|
-
case litellm.exceptions.InternalServerError():
|
|
154
|
-
raise ModelInternalServerError(
|
|
155
|
-
FormattedLLMErrorMessage(
|
|
156
|
-
cause=f"Model {model_name!r} is currently experiencing internal server issues while {purpose}.",
|
|
157
|
-
solution=f"Try again in a few moments. Check with your model provider {model_provider_name!r} if the issue persists.",
|
|
158
|
-
)
|
|
159
|
-
) from None
|
|
160
|
-
|
|
161
|
-
case litellm.exceptions.NotFoundError():
|
|
162
|
-
raise ModelNotFoundError(
|
|
163
|
-
FormattedLLMErrorMessage(
|
|
164
|
-
cause=f"The specified model {model_name!r} could not be found while {purpose}.",
|
|
165
|
-
solution=f"Check that the model name is correct and supported by your model provider {model_provider_name!r} and try again.",
|
|
166
|
-
)
|
|
167
|
-
) from None
|
|
168
|
-
|
|
169
|
-
case litellm.exceptions.PermissionDeniedError():
|
|
170
|
-
raise ModelPermissionDeniedError(
|
|
171
|
-
FormattedLLMErrorMessage(
|
|
172
|
-
cause=f"Your API key was found to lack the necessary permissions to use model {model_name!r} while {purpose}.",
|
|
173
|
-
solution=f"Use an API key that has the right permissions for the model or use a model the API key in use has access to in model provider {model_provider_name!r}.",
|
|
174
|
-
)
|
|
175
|
-
) from None
|
|
176
|
-
|
|
177
|
-
case litellm.exceptions.RateLimitError():
|
|
178
|
-
raise ModelRateLimitError(
|
|
179
|
-
FormattedLLMErrorMessage(
|
|
180
|
-
cause=f"You have exceeded the rate limit for model {model_name!r} while {purpose}.",
|
|
181
|
-
solution="Wait and try again in a few moments.",
|
|
182
|
-
)
|
|
183
|
-
) from None
|
|
184
|
-
|
|
185
|
-
case litellm.exceptions.Timeout():
|
|
186
|
-
raise ModelTimeoutError(
|
|
187
|
-
FormattedLLMErrorMessage(
|
|
188
|
-
cause=f"The request to model {model_name!r} timed out while {purpose}.",
|
|
189
|
-
solution="Check your connection and try again. You may need to increase the timeout setting for the model.",
|
|
190
|
-
)
|
|
191
|
-
) from None
|
|
192
|
-
|
|
193
|
-
case litellm.exceptions.UnprocessableEntityError():
|
|
194
|
-
raise ModelUnprocessableEntityError(
|
|
195
|
-
FormattedLLMErrorMessage(
|
|
196
|
-
cause=f"The request to model {model_name!r} failed despite correct request format while {purpose}.",
|
|
197
|
-
solution="This is most likely temporary. Try again in a few moments.",
|
|
198
|
-
)
|
|
199
|
-
) from None
|
|
200
|
-
|
|
201
|
-
# Parsing and validation errors
|
|
202
|
-
case GenerationValidationFailureError():
|
|
203
|
-
raise ModelGenerationValidationFailureError(
|
|
204
|
-
FormattedLLMErrorMessage(
|
|
205
|
-
cause=f"The provided output schema was unable to be parsed from model {model_name!r} responses while {purpose}.",
|
|
206
|
-
solution="This is most likely temporary as we make additional attempts. If you continue to see more of this, simplify or modify the output schema for structured output and try again. If you are attempting token-intensive tasks like generations with high-reasoning effort, ensure that max_tokens in the model config is high enough to reach completion.",
|
|
207
|
-
)
|
|
208
|
-
) from None
|
|
209
|
-
|
|
210
|
-
case DataDesignerError():
|
|
211
|
-
raise exception from None
|
|
212
|
-
|
|
213
|
-
case _:
|
|
214
|
-
raise DataDesignerError(
|
|
215
|
-
FormattedLLMErrorMessage(
|
|
216
|
-
cause=f"An unexpected error occurred while {purpose}.",
|
|
217
|
-
solution=f"Review the stack trace for more details: {exception}",
|
|
218
|
-
)
|
|
219
|
-
) from exception
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
def catch_llm_exceptions(func: Callable) -> Callable:
|
|
223
|
-
"""This decorator should be used on any `ModelFacade` method that could potentially raise
|
|
224
|
-
exceptions that should turn into upstream user-facing errors.
|
|
225
|
-
"""
|
|
226
|
-
|
|
227
|
-
@wraps(func)
|
|
228
|
-
def wrapper(model_facade: Any, *args, **kwargs):
|
|
229
|
-
try:
|
|
230
|
-
return func(model_facade, *args, **kwargs)
|
|
231
|
-
except Exception as e:
|
|
232
|
-
logger.debug(
|
|
233
|
-
"\n".join(
|
|
234
|
-
[
|
|
235
|
-
"",
|
|
236
|
-
"|----------",
|
|
237
|
-
f"| Caught an exception downstream of type {type(e)!r}. Re-raising it below as a custom error with more context.",
|
|
238
|
-
"|----------",
|
|
239
|
-
]
|
|
240
|
-
),
|
|
241
|
-
exc_info=True,
|
|
242
|
-
stack_info=True,
|
|
243
|
-
)
|
|
244
|
-
handle_llm_exceptions(
|
|
245
|
-
e, model_facade.model_name, model_facade.model_provider_name, purpose=kwargs.get("purpose")
|
|
246
|
-
)
|
|
247
|
-
|
|
248
|
-
return wrapper
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
class DownstreamLLMExceptionMessageParser:
|
|
252
|
-
def __init__(self, model_name: str, model_provider_name: str, purpose: str):
|
|
253
|
-
self.model_name = model_name
|
|
254
|
-
self.model_provider_name = model_provider_name
|
|
255
|
-
self.purpose = purpose
|
|
256
|
-
|
|
257
|
-
def parse_bad_request_error(self, exception: litellm.exceptions.BadRequestError) -> DataDesignerError:
|
|
258
|
-
err_msg = FormattedLLMErrorMessage(
|
|
259
|
-
cause=f"The request for model {self.model_name!r} was found to be malformed or missing required parameters while {self.purpose}.",
|
|
260
|
-
solution="Check your request parameters and try again.",
|
|
261
|
-
)
|
|
262
|
-
if "is not a multimodal model" in str(exception):
|
|
263
|
-
err_msg = FormattedLLMErrorMessage(
|
|
264
|
-
cause=f"Model {self.model_name!r} is not a multimodal model, but it looks like you are trying to provide multimodal context while {self.purpose}.",
|
|
265
|
-
solution="Check your request parameters and try again.",
|
|
266
|
-
)
|
|
267
|
-
return ModelBadRequestError(err_msg)
|
|
268
|
-
|
|
269
|
-
def parse_context_window_exceeded_error(
|
|
270
|
-
self, exception: litellm.exceptions.ContextWindowExceededError
|
|
271
|
-
) -> DataDesignerError:
|
|
272
|
-
cause = f"The input data for model '{self.model_name}' was found to exceed its supported context width while {self.purpose}."
|
|
273
|
-
try:
|
|
274
|
-
if "OpenAIException - This model's maximum context length is " in str(exception):
|
|
275
|
-
openai_exception_cause = (
|
|
276
|
-
str(exception).split("OpenAIException - ")[1].split("\n")[0].split(" Please reduce ")[0]
|
|
277
|
-
)
|
|
278
|
-
cause = f"{cause} {openai_exception_cause}"
|
|
279
|
-
except Exception:
|
|
280
|
-
pass
|
|
281
|
-
finally:
|
|
282
|
-
return ModelContextWindowExceededError(
|
|
283
|
-
FormattedLLMErrorMessage(
|
|
284
|
-
cause=cause,
|
|
285
|
-
solution="Check the model's supported max context width. Adjust the length of your input along with completions and try again.",
|
|
286
|
-
)
|
|
287
|
-
)
|
|
288
|
-
|
|
289
|
-
def parse_api_error(
|
|
290
|
-
self, exception: litellm.exceptions.InternalServerError, auth_error_msg: FormattedLLMErrorMessage
|
|
291
|
-
) -> DataDesignerError:
|
|
292
|
-
if "Error code: 403" in str(exception):
|
|
293
|
-
return ModelAuthenticationError(auth_error_msg)
|
|
294
|
-
|
|
295
|
-
return ModelAPIError(
|
|
296
|
-
FormattedLLMErrorMessage(
|
|
297
|
-
cause=f"An unexpected API error occurred with model {self.model_name!r} while {self.purpose}.",
|
|
298
|
-
solution=f"Try again in a few moments. Check with your model provider {self.model_provider_name!r} if the issue persists.",
|
|
299
|
-
)
|
|
300
|
-
)
|
|
@@ -1,287 +0,0 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
-
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
-
|
|
4
|
-
from __future__ import annotations
|
|
5
|
-
|
|
6
|
-
import logging
|
|
7
|
-
from collections.abc import Callable
|
|
8
|
-
from copy import deepcopy
|
|
9
|
-
from typing import TYPE_CHECKING, Any
|
|
10
|
-
|
|
11
|
-
from data_designer.config.models import GenerationType, ModelConfig, ModelProvider
|
|
12
|
-
from data_designer.engine.model_provider import ModelProviderRegistry
|
|
13
|
-
from data_designer.engine.models.errors import (
|
|
14
|
-
GenerationValidationFailureError,
|
|
15
|
-
catch_llm_exceptions,
|
|
16
|
-
get_exception_primary_cause,
|
|
17
|
-
)
|
|
18
|
-
from data_designer.engine.models.litellm_overrides import CustomRouter, LiteLLMRouterDefaultKwargs
|
|
19
|
-
from data_designer.engine.models.parsers.errors import ParserException
|
|
20
|
-
from data_designer.engine.models.usage import ModelUsageStats, RequestUsageStats, TokenUsageStats
|
|
21
|
-
from data_designer.engine.models.utils import prompt_to_messages, str_to_message
|
|
22
|
-
from data_designer.engine.secret_resolver import SecretResolver
|
|
23
|
-
from data_designer.lazy_heavy_imports import litellm
|
|
24
|
-
|
|
25
|
-
if TYPE_CHECKING:
|
|
26
|
-
import litellm
|
|
27
|
-
|
|
28
|
-
logger = logging.getLogger(__name__)
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
class ModelFacade:
|
|
32
|
-
def __init__(
|
|
33
|
-
self,
|
|
34
|
-
model_config: ModelConfig,
|
|
35
|
-
secret_resolver: SecretResolver,
|
|
36
|
-
model_provider_registry: ModelProviderRegistry,
|
|
37
|
-
):
|
|
38
|
-
self._model_config = model_config
|
|
39
|
-
self._secret_resolver = secret_resolver
|
|
40
|
-
self._model_provider_registry = model_provider_registry
|
|
41
|
-
self._litellm_deployment = self._get_litellm_deployment(model_config)
|
|
42
|
-
self._router = CustomRouter([self._litellm_deployment], **LiteLLMRouterDefaultKwargs().model_dump())
|
|
43
|
-
self._usage_stats = ModelUsageStats()
|
|
44
|
-
|
|
45
|
-
@property
|
|
46
|
-
def model_name(self) -> str:
|
|
47
|
-
return self._model_config.model
|
|
48
|
-
|
|
49
|
-
@property
|
|
50
|
-
def model_provider(self) -> ModelProvider:
|
|
51
|
-
return self._model_provider_registry.get_provider(self._model_config.provider)
|
|
52
|
-
|
|
53
|
-
@property
|
|
54
|
-
def model_generation_type(self) -> GenerationType:
|
|
55
|
-
return self._model_config.generation_type
|
|
56
|
-
|
|
57
|
-
@property
|
|
58
|
-
def model_provider_name(self) -> str:
|
|
59
|
-
return self.model_provider.name
|
|
60
|
-
|
|
61
|
-
@property
|
|
62
|
-
def model_alias(self) -> str:
|
|
63
|
-
return self._model_config.alias
|
|
64
|
-
|
|
65
|
-
@property
|
|
66
|
-
def usage_stats(self) -> ModelUsageStats:
|
|
67
|
-
return self._usage_stats
|
|
68
|
-
|
|
69
|
-
def completion(
|
|
70
|
-
self, messages: list[dict[str, str]], skip_usage_tracking: bool = False, **kwargs
|
|
71
|
-
) -> litellm.ModelResponse:
|
|
72
|
-
logger.debug(
|
|
73
|
-
f"Prompting model {self.model_name!r}...",
|
|
74
|
-
extra={"model": self.model_name, "messages": messages},
|
|
75
|
-
)
|
|
76
|
-
response = None
|
|
77
|
-
kwargs = self.consolidate_kwargs(**kwargs)
|
|
78
|
-
try:
|
|
79
|
-
response = self._router.completion(model=self.model_name, messages=messages, **kwargs)
|
|
80
|
-
logger.debug(
|
|
81
|
-
f"Received completion from model {self.model_name!r}",
|
|
82
|
-
extra={
|
|
83
|
-
"model": self.model_name,
|
|
84
|
-
"response": response,
|
|
85
|
-
"text": response.choices[0].message.content,
|
|
86
|
-
"usage": self._usage_stats.model_dump(),
|
|
87
|
-
},
|
|
88
|
-
)
|
|
89
|
-
return response
|
|
90
|
-
except Exception as e:
|
|
91
|
-
raise e
|
|
92
|
-
finally:
|
|
93
|
-
if not skip_usage_tracking and response is not None:
|
|
94
|
-
self._track_usage(response)
|
|
95
|
-
|
|
96
|
-
def consolidate_kwargs(self, **kwargs) -> dict[str, Any]:
|
|
97
|
-
# Remove purpose from kwargs to avoid passing it to the model
|
|
98
|
-
kwargs.pop("purpose", None)
|
|
99
|
-
kwargs = {**self._model_config.inference_parameters.generate_kwargs, **kwargs}
|
|
100
|
-
if self.model_provider.extra_body:
|
|
101
|
-
kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body}
|
|
102
|
-
if self.model_provider.extra_headers:
|
|
103
|
-
kwargs["extra_headers"] = self.model_provider.extra_headers
|
|
104
|
-
return kwargs
|
|
105
|
-
|
|
106
|
-
@catch_llm_exceptions
|
|
107
|
-
def generate_text_embeddings(
|
|
108
|
-
self, input_texts: list[str], skip_usage_tracking: bool = False, **kwargs
|
|
109
|
-
) -> list[list[float]]:
|
|
110
|
-
logger.debug(
|
|
111
|
-
f"Generating embeddings with model {self.model_name!r}...",
|
|
112
|
-
extra={
|
|
113
|
-
"model": self.model_name,
|
|
114
|
-
"input_count": len(input_texts),
|
|
115
|
-
},
|
|
116
|
-
)
|
|
117
|
-
kwargs = self.consolidate_kwargs(**kwargs)
|
|
118
|
-
response = None
|
|
119
|
-
try:
|
|
120
|
-
response = self._router.embedding(model=self.model_name, input=input_texts, **kwargs)
|
|
121
|
-
logger.debug(
|
|
122
|
-
f"Received embeddings from model {self.model_name!r}",
|
|
123
|
-
extra={
|
|
124
|
-
"model": self.model_name,
|
|
125
|
-
"embedding_count": len(response.data) if response.data else 0,
|
|
126
|
-
"usage": self._usage_stats.model_dump(),
|
|
127
|
-
},
|
|
128
|
-
)
|
|
129
|
-
if response.data and len(response.data) == len(input_texts):
|
|
130
|
-
return [data["embedding"] for data in response.data]
|
|
131
|
-
else:
|
|
132
|
-
raise ValueError(f"Expected {len(input_texts)} embeddings, but received {len(response.data)}")
|
|
133
|
-
except Exception as e:
|
|
134
|
-
raise e
|
|
135
|
-
finally:
|
|
136
|
-
if not skip_usage_tracking and response is not None:
|
|
137
|
-
self._track_usage_from_embedding(response)
|
|
138
|
-
|
|
139
|
-
@catch_llm_exceptions
def generate(
    self,
    prompt: str,
    *,
    parser: Callable[[str], Any],
    system_prompt: str | None = None,
    multi_modal_context: list[dict[str, Any]] | None = None,
    max_correction_steps: int = 0,
    max_conversation_restarts: int = 0,
    skip_usage_tracking: bool = False,
    purpose: str | None = None,
    **kwargs,
) -> tuple[Any, str | None]:
    """Generate a parsed output with correction steps.

    This generation call will attempt to generate an output which is
    valid according to the specified parser, where "valid" implies
    that the parser can process the LLM response without raising
    an exception.

    `ParserExceptions` are routed back
    to the LLM as new rounds in the conversation, where the LLM is provided its
    earlier response along with the "user" role responding with the exception string
    (not traceback). This will continue for the number of rounds specified by
    `max_correction_steps`.

    Args:
        prompt (str): Task prompt.
        system_prompt (str, optional): Optional system instructions. If not specified,
            no system message is provided and the model should use its default system
            prompt.
        parser (func(str) -> Any): A function applied to the LLM response which processes
            an LLM response into some output object.
        multi_modal_context (list[dict], optional): Extra message content passed through
            to `prompt_to_messages` — presumably image/audio message parts; confirm
            against that helper's contract.
        max_correction_steps (int): Maximum number of correction rounds permitted
            within a single conversation. Note, many rounds can lead to increasing
            context size without necessarily improving performance -- small language
            models can enter repeated cycles which will not be solved with more steps.
            Default: `0` (no correction).
        max_conversation_restarts (int): Maximum number of full conversation restarts permitted
            if generation fails. Default: `0` (no restarts).
        skip_usage_tracking (bool): Whether to skip usage tracking. Default: `False`.
        purpose (str): The purpose of the model usage to show as context in the error message.
            It is expected to be used by the @catch_llm_exceptions decorator.
        **kwargs: Additional arguments to pass to the model.

    Returns:
        Tuple of (parsed output object, reasoning trace or None).

    Raises:
        GenerationValidationFailureError: If the maximum number of retries or
            correction steps are met and the last response fails
            generation validation.
    """
    output_obj = None
    curr_num_correction_steps = 0
    curr_num_restarts = 0
    curr_generation_attempt = 0
    # Total budget: each restart gets its own full run of correction steps.
    max_generation_attempts = (max_correction_steps + 1) * (max_conversation_restarts + 1)

    starting_messages = prompt_to_messages(
        user_prompt=prompt, system_prompt=system_prompt, multi_modal_context=multi_modal_context
    )
    # Keep a pristine copy so a restart can rebuild the conversation from scratch.
    messages = deepcopy(starting_messages)

    while True:
        curr_generation_attempt += 1
        logger.debug(
            f"Starting generation attempt {curr_generation_attempt} of {max_generation_attempts} attempts."
        )

        completion_response = self.completion(messages, skip_usage_tracking=skip_usage_tracking, **kwargs)
        response = completion_response.choices[0].message.content or ""
        # reasoning_content is provider-dependent; absent attribute yields None.
        reasoning_trace = getattr(completion_response.choices[0].message, "reasoning_content", None)

        if reasoning_trace:
            ## There are generally some extra newlines with how these get parsed.
            response = response.strip()
            reasoning_trace = reasoning_trace.strip()

        curr_num_correction_steps += 1

        try:
            output_obj = parser(response)  # type: ignore - if not a string will cause a ParserException below
            break
        except ParserException as exc:
            if max_correction_steps == 0 and max_conversation_restarts == 0:
                # No retry budget at all: fail immediately on first parse error.
                raise GenerationValidationFailureError(
                    "Unsuccessful generation attempt. No retries were attempted."
                ) from exc
            if curr_num_correction_steps <= max_correction_steps:
                ## Add turns to loop-back errors for correction
                messages += [
                    str_to_message(content=response, role="assistant"),
                    str_to_message(content=str(get_exception_primary_cause(exc)), role="user"),
                ]
            elif curr_num_restarts < max_conversation_restarts:
                # Correction budget exhausted: reset the conversation entirely.
                curr_num_correction_steps = 0
                curr_num_restarts += 1
                messages = deepcopy(starting_messages)
            else:
                raise GenerationValidationFailureError(
                    f"Unsuccessful generation attempt despite {max_generation_attempts} attempts."
                ) from exc
    return output_obj, reasoning_trace
def _get_litellm_deployment(self, model_config: ModelConfig) -> litellm.DeploymentTypedDict:
    """Build a litellm router deployment entry for the given model config."""
    provider = self._model_provider_registry.get_provider(model_config.provider)

    # Resolve the provider's API key if one is configured; litellm requires a
    # non-empty key even for endpoints that do not authenticate.
    resolved_key = None
    if provider.api_key:
        resolved_key = self._secret_resolver.resolve(provider.api_key)

    params = litellm.LiteLLM_Params(
        model=f"{provider.provider_type}/{model_config.model}",
        api_base=provider.endpoint,
        api_key=resolved_key or "not-used-but-required",
    )
    return {
        "model_name": model_config.model,
        "litellm_params": params.model_dump(),
    }
def _track_usage(self, response: litellm.types.utils.ModelResponse | None) -> None:
    """Record token/request usage from a completion response.

    A missing response counts as one failed request. A response without
    complete token counts records nothing at all.
    """
    if response is None:
        self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1))
        return

    usage = response.usage
    if usage is None or usage.prompt_tokens is None or usage.completion_tokens is None:
        # NOTE(review): a response with incomplete token counts is not recorded
        # as a successful request either — confirm this is intentional.
        return

    self._usage_stats.extend(
        token_usage=TokenUsageStats(
            input_tokens=usage.prompt_tokens,
            output_tokens=usage.completion_tokens,
        ),
        request_usage=RequestUsageStats(successful_requests=1, failed_requests=0),
    )
def _track_usage_from_embedding(self, response: litellm.types.utils.EmbeddingResponse | None) -> None:
    """Record token/request usage from an embedding response.

    A missing response counts as one failed request; embedding calls
    consume only input tokens, so output tokens are recorded as zero.
    """
    if response is None:
        self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1))
        return

    usage = response.usage
    if usage is None or usage.prompt_tokens is None:
        # Token counts unavailable: record nothing for this response.
        return

    self._usage_stats.extend(
        token_usage=TokenUsageStats(
            input_tokens=usage.prompt_tokens,
            output_tokens=0,
        ),
        request_usage=RequestUsageStats(successful_requests=1, failed_requests=0),
    )
@@ -1,42 +0,0 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
-
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
-
|
|
4
|
-
from __future__ import annotations
|
|
5
|
-
|
|
6
|
-
from typing import TYPE_CHECKING
|
|
7
|
-
|
|
8
|
-
from data_designer.config.models import ModelConfig
|
|
9
|
-
from data_designer.engine.model_provider import ModelProviderRegistry
|
|
10
|
-
from data_designer.engine.secret_resolver import SecretResolver
|
|
11
|
-
|
|
12
|
-
if TYPE_CHECKING:
|
|
13
|
-
from data_designer.engine.models.registry import ModelRegistry
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
def create_model_registry(
    *,
    model_configs: list[ModelConfig] | None = None,
    secret_resolver: SecretResolver,
    model_provider_registry: ModelProviderRegistry,
) -> ModelRegistry:
    """Factory for a `ModelRegistry`.

    Imports of the heavy dependencies (litellm, httpx, via the facade and
    registry modules) are deliberately deferred to call time — the standard
    lazy-initialization factory pattern.
    """
    from data_designer.engine.models.facade import ModelFacade
    from data_designer.engine.models.litellm_overrides import apply_litellm_patches
    from data_designer.engine.models.registry import ModelRegistry

    # Patch litellm before any facade can be constructed.
    apply_litellm_patches()

    # Thin indirection so the registry constructs facades on demand.
    def _facade_factory(model_config, secret_resolver, model_provider_registry):
        return ModelFacade(model_config, secret_resolver, model_provider_registry)

    return ModelRegistry(
        model_configs=model_configs,
        secret_resolver=secret_resolver,
        model_provider_registry=model_provider_registry,
        model_facade_factory=_facade_factory,
    )