data-designer 0.3.8rc1__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. data_designer/cli/commands/__init__.py +1 -1
  2. data_designer/interface/__init__.py +21 -1
  3. data_designer/{_version.py → interface/_version.py} +2 -2
  4. data_designer/interface/data_designer.py +8 -11
  5. {data_designer-0.3.8rc1.dist-info → data_designer-0.4.0.dist-info}/METADATA +10 -42
  6. data_designer-0.4.0.dist-info/RECORD +39 -0
  7. data_designer/__init__.py +0 -17
  8. data_designer/config/__init__.py +0 -2
  9. data_designer/config/analysis/__init__.py +0 -2
  10. data_designer/config/analysis/column_profilers.py +0 -159
  11. data_designer/config/analysis/column_statistics.py +0 -421
  12. data_designer/config/analysis/dataset_profiler.py +0 -84
  13. data_designer/config/analysis/utils/errors.py +0 -10
  14. data_designer/config/analysis/utils/reporting.py +0 -192
  15. data_designer/config/base.py +0 -69
  16. data_designer/config/column_configs.py +0 -470
  17. data_designer/config/column_types.py +0 -141
  18. data_designer/config/config_builder.py +0 -595
  19. data_designer/config/data_designer_config.py +0 -40
  20. data_designer/config/dataset_builders.py +0 -13
  21. data_designer/config/dataset_metadata.py +0 -18
  22. data_designer/config/default_model_settings.py +0 -121
  23. data_designer/config/errors.py +0 -24
  24. data_designer/config/exports.py +0 -145
  25. data_designer/config/interface.py +0 -55
  26. data_designer/config/models.py +0 -455
  27. data_designer/config/preview_results.py +0 -41
  28. data_designer/config/processors.py +0 -148
  29. data_designer/config/run_config.py +0 -48
  30. data_designer/config/sampler_constraints.py +0 -52
  31. data_designer/config/sampler_params.py +0 -639
  32. data_designer/config/seed.py +0 -116
  33. data_designer/config/seed_source.py +0 -84
  34. data_designer/config/seed_source_types.py +0 -19
  35. data_designer/config/utils/code_lang.py +0 -82
  36. data_designer/config/utils/constants.py +0 -363
  37. data_designer/config/utils/errors.py +0 -21
  38. data_designer/config/utils/info.py +0 -94
  39. data_designer/config/utils/io_helpers.py +0 -258
  40. data_designer/config/utils/misc.py +0 -78
  41. data_designer/config/utils/numerical_helpers.py +0 -30
  42. data_designer/config/utils/type_helpers.py +0 -106
  43. data_designer/config/utils/visualization.py +0 -482
  44. data_designer/config/validator_params.py +0 -94
  45. data_designer/engine/__init__.py +0 -2
  46. data_designer/engine/analysis/column_profilers/base.py +0 -49
  47. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +0 -153
  48. data_designer/engine/analysis/column_profilers/registry.py +0 -22
  49. data_designer/engine/analysis/column_statistics.py +0 -145
  50. data_designer/engine/analysis/dataset_profiler.py +0 -149
  51. data_designer/engine/analysis/errors.py +0 -9
  52. data_designer/engine/analysis/utils/column_statistics_calculations.py +0 -234
  53. data_designer/engine/analysis/utils/judge_score_processing.py +0 -132
  54. data_designer/engine/column_generators/__init__.py +0 -2
  55. data_designer/engine/column_generators/generators/__init__.py +0 -2
  56. data_designer/engine/column_generators/generators/base.py +0 -122
  57. data_designer/engine/column_generators/generators/embedding.py +0 -35
  58. data_designer/engine/column_generators/generators/expression.py +0 -55
  59. data_designer/engine/column_generators/generators/llm_completion.py +0 -113
  60. data_designer/engine/column_generators/generators/samplers.py +0 -69
  61. data_designer/engine/column_generators/generators/seed_dataset.py +0 -144
  62. data_designer/engine/column_generators/generators/validation.py +0 -140
  63. data_designer/engine/column_generators/registry.py +0 -60
  64. data_designer/engine/column_generators/utils/errors.py +0 -15
  65. data_designer/engine/column_generators/utils/generator_classification.py +0 -43
  66. data_designer/engine/column_generators/utils/judge_score_factory.py +0 -58
  67. data_designer/engine/column_generators/utils/prompt_renderer.py +0 -100
  68. data_designer/engine/compiler.py +0 -97
  69. data_designer/engine/configurable_task.py +0 -71
  70. data_designer/engine/dataset_builders/artifact_storage.py +0 -283
  71. data_designer/engine/dataset_builders/column_wise_builder.py +0 -338
  72. data_designer/engine/dataset_builders/errors.py +0 -15
  73. data_designer/engine/dataset_builders/multi_column_configs.py +0 -46
  74. data_designer/engine/dataset_builders/utils/__init__.py +0 -2
  75. data_designer/engine/dataset_builders/utils/concurrency.py +0 -215
  76. data_designer/engine/dataset_builders/utils/config_compiler.py +0 -62
  77. data_designer/engine/dataset_builders/utils/dag.py +0 -62
  78. data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +0 -200
  79. data_designer/engine/dataset_builders/utils/errors.py +0 -15
  80. data_designer/engine/errors.py +0 -51
  81. data_designer/engine/model_provider.py +0 -77
  82. data_designer/engine/models/__init__.py +0 -2
  83. data_designer/engine/models/errors.py +0 -300
  84. data_designer/engine/models/facade.py +0 -287
  85. data_designer/engine/models/factory.py +0 -42
  86. data_designer/engine/models/litellm_overrides.py +0 -179
  87. data_designer/engine/models/parsers/__init__.py +0 -2
  88. data_designer/engine/models/parsers/errors.py +0 -34
  89. data_designer/engine/models/parsers/parser.py +0 -235
  90. data_designer/engine/models/parsers/postprocessors.py +0 -93
  91. data_designer/engine/models/parsers/tag_parsers.py +0 -62
  92. data_designer/engine/models/parsers/types.py +0 -84
  93. data_designer/engine/models/recipes/base.py +0 -81
  94. data_designer/engine/models/recipes/response_recipes.py +0 -293
  95. data_designer/engine/models/registry.py +0 -146
  96. data_designer/engine/models/telemetry.py +0 -359
  97. data_designer/engine/models/usage.py +0 -73
  98. data_designer/engine/models/utils.py +0 -38
  99. data_designer/engine/processing/ginja/__init__.py +0 -2
  100. data_designer/engine/processing/ginja/ast.py +0 -65
  101. data_designer/engine/processing/ginja/environment.py +0 -463
  102. data_designer/engine/processing/ginja/exceptions.py +0 -56
  103. data_designer/engine/processing/ginja/record.py +0 -32
  104. data_designer/engine/processing/gsonschema/__init__.py +0 -2
  105. data_designer/engine/processing/gsonschema/exceptions.py +0 -15
  106. data_designer/engine/processing/gsonschema/schema_transformers.py +0 -83
  107. data_designer/engine/processing/gsonschema/types.py +0 -10
  108. data_designer/engine/processing/gsonschema/validators.py +0 -202
  109. data_designer/engine/processing/processors/base.py +0 -13
  110. data_designer/engine/processing/processors/drop_columns.py +0 -42
  111. data_designer/engine/processing/processors/registry.py +0 -25
  112. data_designer/engine/processing/processors/schema_transform.py +0 -49
  113. data_designer/engine/processing/utils.py +0 -169
  114. data_designer/engine/registry/base.py +0 -99
  115. data_designer/engine/registry/data_designer_registry.py +0 -39
  116. data_designer/engine/registry/errors.py +0 -12
  117. data_designer/engine/resources/managed_dataset_generator.py +0 -39
  118. data_designer/engine/resources/managed_dataset_repository.py +0 -197
  119. data_designer/engine/resources/managed_storage.py +0 -65
  120. data_designer/engine/resources/resource_provider.py +0 -77
  121. data_designer/engine/resources/seed_reader.py +0 -154
  122. data_designer/engine/sampling_gen/column.py +0 -91
  123. data_designer/engine/sampling_gen/constraints.py +0 -100
  124. data_designer/engine/sampling_gen/data_sources/base.py +0 -217
  125. data_designer/engine/sampling_gen/data_sources/errors.py +0 -12
  126. data_designer/engine/sampling_gen/data_sources/sources.py +0 -347
  127. data_designer/engine/sampling_gen/entities/__init__.py +0 -2
  128. data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
  129. data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +0 -86
  130. data_designer/engine/sampling_gen/entities/email_address_utils.py +0 -171
  131. data_designer/engine/sampling_gen/entities/errors.py +0 -10
  132. data_designer/engine/sampling_gen/entities/national_id_utils.py +0 -102
  133. data_designer/engine/sampling_gen/entities/person.py +0 -144
  134. data_designer/engine/sampling_gen/entities/phone_number.py +0 -128
  135. data_designer/engine/sampling_gen/errors.py +0 -26
  136. data_designer/engine/sampling_gen/generator.py +0 -122
  137. data_designer/engine/sampling_gen/jinja_utils.py +0 -64
  138. data_designer/engine/sampling_gen/people_gen.py +0 -199
  139. data_designer/engine/sampling_gen/person_constants.py +0 -56
  140. data_designer/engine/sampling_gen/schema.py +0 -147
  141. data_designer/engine/sampling_gen/schema_builder.py +0 -61
  142. data_designer/engine/sampling_gen/utils.py +0 -46
  143. data_designer/engine/secret_resolver.py +0 -82
  144. data_designer/engine/validation.py +0 -367
  145. data_designer/engine/validators/__init__.py +0 -19
  146. data_designer/engine/validators/base.py +0 -38
  147. data_designer/engine/validators/local_callable.py +0 -39
  148. data_designer/engine/validators/python.py +0 -254
  149. data_designer/engine/validators/remote.py +0 -89
  150. data_designer/engine/validators/sql.py +0 -65
  151. data_designer/errors.py +0 -7
  152. data_designer/essentials/__init__.py +0 -33
  153. data_designer/lazy_heavy_imports.py +0 -54
  154. data_designer/logging.py +0 -163
  155. data_designer/plugin_manager.py +0 -78
  156. data_designer/plugins/__init__.py +0 -8
  157. data_designer/plugins/errors.py +0 -15
  158. data_designer/plugins/plugin.py +0 -141
  159. data_designer/plugins/registry.py +0 -88
  160. data_designer/plugins/testing/__init__.py +0 -10
  161. data_designer/plugins/testing/stubs.py +0 -116
  162. data_designer/plugins/testing/utils.py +0 -20
  163. data_designer-0.3.8rc1.dist-info/RECORD +0 -196
  164. data_designer-0.3.8rc1.dist-info/licenses/LICENSE +0 -201
  165. {data_designer-0.3.8rc1.dist-info → data_designer-0.4.0.dist-info}/WHEEL +0 -0
  166. {data_designer-0.3.8rc1.dist-info → data_designer-0.4.0.dist-info}/entry_points.txt +0 -0
@@ -1,300 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- from __future__ import annotations
5
-
6
- import logging
7
- from collections.abc import Callable
8
- from functools import wraps
9
- from typing import TYPE_CHECKING, Any
10
-
11
- from pydantic import BaseModel
12
-
13
- from data_designer.engine.errors import DataDesignerError
14
- from data_designer.lazy_heavy_imports import litellm
15
-
16
- if TYPE_CHECKING:
17
- import litellm
18
-
19
- logger = logging.getLogger(__name__)
20
-
21
-
22
- def get_exception_primary_cause(exception: BaseException) -> BaseException:
23
- """Returns the primary cause of an exception by walking backwards.
24
-
25
- This recursive walkback halts when it arrives at an exception which
26
- has no provided __cause__ (e.g. __cause__ is None).
27
-
28
- Args:
29
- exception (Exception): An exception to start from.
30
-
31
- Raises:
32
- RecursionError: if for some reason exceptions have circular
33
- dependencies (seems impossible in practice).
34
- """
35
- if exception.__cause__ is None:
36
- return exception
37
- else:
38
- return get_exception_primary_cause(exception.__cause__)
39
-
40
-
41
- class GenerationValidationFailureError(Exception): ...
42
-
43
-
44
- class ModelRateLimitError(DataDesignerError): ...
45
-
46
-
47
- class ModelTimeoutError(DataDesignerError): ...
48
-
49
-
50
- class ModelContextWindowExceededError(DataDesignerError): ...
51
-
52
-
53
- class ModelAuthenticationError(DataDesignerError): ...
54
-
55
-
56
- class ModelPermissionDeniedError(DataDesignerError): ...
57
-
58
-
59
- class ModelNotFoundError(DataDesignerError): ...
60
-
61
-
62
- class ModelUnsupportedParamsError(DataDesignerError): ...
63
-
64
-
65
- class ModelBadRequestError(DataDesignerError): ...
66
-
67
-
68
- class ModelInternalServerError(DataDesignerError): ...
69
-
70
-
71
- class ModelAPIError(DataDesignerError): ...
72
-
73
-
74
- class ModelUnprocessableEntityError(DataDesignerError): ...
75
-
76
-
77
- class ModelAPIConnectionError(DataDesignerError): ...
78
-
79
-
80
- class ModelStructuredOutputError(DataDesignerError): ...
81
-
82
-
83
- class ModelGenerationValidationFailureError(DataDesignerError): ...
84
-
85
-
86
- class FormattedLLMErrorMessage(BaseModel):
87
- cause: str
88
- solution: str
89
-
90
- def __str__(self) -> str:
91
- return "\n".join(
92
- [
93
- " |----------",
94
- f" | Cause: {self.cause}",
95
- f" | Solution: {self.solution}",
96
- " |----------",
97
- ]
98
- )
99
-
100
-
101
- def handle_llm_exceptions(
102
- exception: Exception, model_name: str, model_provider_name: str, purpose: str | None = None
103
- ) -> None:
104
- """Handle LLM-related exceptions and convert them to appropriate DataDesignerError errors.
105
-
106
- This method centralizes the exception handling logic for LLM operations,
107
- making it reusable across different contexts.
108
-
109
- Args:
110
- exception: The exception that was raised
111
- model_name: Name of the model that was being used
112
- model_provider_name: Name of the model provider that was being used
113
- purpose: The purpose of the model usage to show as context in the error message
114
- Raises:
115
- DataDesignerError: A more user-friendly error with appropriate error type and message
116
- """
117
- purpose = purpose or "running generation"
118
- authentication_error = FormattedLLMErrorMessage(
119
- cause=f"The API key provided for model {model_name!r} was found to be invalid or expired while {purpose}.",
120
- solution=f"Verify your API key for model provider and update it in your settings for model provider {model_provider_name!r}.",
121
- )
122
- err_msg_parser = DownstreamLLMExceptionMessageParser(model_name, model_provider_name, purpose)
123
- match exception:
124
- # Common errors that can come from LiteLLM
125
- case litellm.exceptions.APIError():
126
- raise err_msg_parser.parse_api_error(exception, authentication_error) from None
127
-
128
- case litellm.exceptions.APIConnectionError():
129
- raise ModelAPIConnectionError(
130
- FormattedLLMErrorMessage(
131
- cause=f"Connection to model {model_name!r} hosted on model provider {model_provider_name!r} failed while {purpose}.",
132
- solution="Check your network/proxy/firewall settings.",
133
- )
134
- ) from None
135
-
136
- case litellm.exceptions.AuthenticationError():
137
- raise ModelAuthenticationError(authentication_error) from None
138
-
139
- case litellm.exceptions.ContextWindowExceededError():
140
- raise err_msg_parser.parse_context_window_exceeded_error(exception) from None
141
-
142
- case litellm.exceptions.UnsupportedParamsError():
143
- raise ModelUnsupportedParamsError(
144
- FormattedLLMErrorMessage(
145
- cause=f"One or more of the parameters you provided were found to be unsupported by model {model_name!r} while {purpose}.",
146
- solution=f"Review the documentation for model provider {model_provider_name!r} and adjust your request.",
147
- )
148
- ) from None
149
-
150
- case litellm.exceptions.BadRequestError():
151
- raise err_msg_parser.parse_bad_request_error(exception) from None
152
-
153
- case litellm.exceptions.InternalServerError():
154
- raise ModelInternalServerError(
155
- FormattedLLMErrorMessage(
156
- cause=f"Model {model_name!r} is currently experiencing internal server issues while {purpose}.",
157
- solution=f"Try again in a few moments. Check with your model provider {model_provider_name!r} if the issue persists.",
158
- )
159
- ) from None
160
-
161
- case litellm.exceptions.NotFoundError():
162
- raise ModelNotFoundError(
163
- FormattedLLMErrorMessage(
164
- cause=f"The specified model {model_name!r} could not be found while {purpose}.",
165
- solution=f"Check that the model name is correct and supported by your model provider {model_provider_name!r} and try again.",
166
- )
167
- ) from None
168
-
169
- case litellm.exceptions.PermissionDeniedError():
170
- raise ModelPermissionDeniedError(
171
- FormattedLLMErrorMessage(
172
- cause=f"Your API key was found to lack the necessary permissions to use model {model_name!r} while {purpose}.",
173
- solution=f"Use an API key that has the right permissions for the model or use a model the API key in use has access to in model provider {model_provider_name!r}.",
174
- )
175
- ) from None
176
-
177
- case litellm.exceptions.RateLimitError():
178
- raise ModelRateLimitError(
179
- FormattedLLMErrorMessage(
180
- cause=f"You have exceeded the rate limit for model {model_name!r} while {purpose}.",
181
- solution="Wait and try again in a few moments.",
182
- )
183
- ) from None
184
-
185
- case litellm.exceptions.Timeout():
186
- raise ModelTimeoutError(
187
- FormattedLLMErrorMessage(
188
- cause=f"The request to model {model_name!r} timed out while {purpose}.",
189
- solution="Check your connection and try again. You may need to increase the timeout setting for the model.",
190
- )
191
- ) from None
192
-
193
- case litellm.exceptions.UnprocessableEntityError():
194
- raise ModelUnprocessableEntityError(
195
- FormattedLLMErrorMessage(
196
- cause=f"The request to model {model_name!r} failed despite correct request format while {purpose}.",
197
- solution="This is most likely temporary. Try again in a few moments.",
198
- )
199
- ) from None
200
-
201
- # Parsing and validation errors
202
- case GenerationValidationFailureError():
203
- raise ModelGenerationValidationFailureError(
204
- FormattedLLMErrorMessage(
205
- cause=f"The provided output schema was unable to be parsed from model {model_name!r} responses while {purpose}.",
206
- solution="This is most likely temporary as we make additional attempts. If you continue to see more of this, simplify or modify the output schema for structured output and try again. If you are attempting token-intensive tasks like generations with high-reasoning effort, ensure that max_tokens in the model config is high enough to reach completion.",
207
- )
208
- ) from None
209
-
210
- case DataDesignerError():
211
- raise exception from None
212
-
213
- case _:
214
- raise DataDesignerError(
215
- FormattedLLMErrorMessage(
216
- cause=f"An unexpected error occurred while {purpose}.",
217
- solution=f"Review the stack trace for more details: {exception}",
218
- )
219
- ) from exception
220
-
221
-
222
- def catch_llm_exceptions(func: Callable) -> Callable:
223
- """This decorator should be used on any `ModelFacade` method that could potentially raise
224
- exceptions that should turn into upstream user-facing errors.
225
- """
226
-
227
- @wraps(func)
228
- def wrapper(model_facade: Any, *args, **kwargs):
229
- try:
230
- return func(model_facade, *args, **kwargs)
231
- except Exception as e:
232
- logger.debug(
233
- "\n".join(
234
- [
235
- "",
236
- "|----------",
237
- f"| Caught an exception downstream of type {type(e)!r}. Re-raising it below as a custom error with more context.",
238
- "|----------",
239
- ]
240
- ),
241
- exc_info=True,
242
- stack_info=True,
243
- )
244
- handle_llm_exceptions(
245
- e, model_facade.model_name, model_facade.model_provider_name, purpose=kwargs.get("purpose")
246
- )
247
-
248
- return wrapper
249
-
250
-
251
- class DownstreamLLMExceptionMessageParser:
252
- def __init__(self, model_name: str, model_provider_name: str, purpose: str):
253
- self.model_name = model_name
254
- self.model_provider_name = model_provider_name
255
- self.purpose = purpose
256
-
257
- def parse_bad_request_error(self, exception: litellm.exceptions.BadRequestError) -> DataDesignerError:
258
- err_msg = FormattedLLMErrorMessage(
259
- cause=f"The request for model {self.model_name!r} was found to be malformed or missing required parameters while {self.purpose}.",
260
- solution="Check your request parameters and try again.",
261
- )
262
- if "is not a multimodal model" in str(exception):
263
- err_msg = FormattedLLMErrorMessage(
264
- cause=f"Model {self.model_name!r} is not a multimodal model, but it looks like you are trying to provide multimodal context while {self.purpose}.",
265
- solution="Check your request parameters and try again.",
266
- )
267
- return ModelBadRequestError(err_msg)
268
-
269
- def parse_context_window_exceeded_error(
270
- self, exception: litellm.exceptions.ContextWindowExceededError
271
- ) -> DataDesignerError:
272
- cause = f"The input data for model '{self.model_name}' was found to exceed its supported context width while {self.purpose}."
273
- try:
274
- if "OpenAIException - This model's maximum context length is " in str(exception):
275
- openai_exception_cause = (
276
- str(exception).split("OpenAIException - ")[1].split("\n")[0].split(" Please reduce ")[0]
277
- )
278
- cause = f"{cause} {openai_exception_cause}"
279
- except Exception:
280
- pass
281
- finally:
282
- return ModelContextWindowExceededError(
283
- FormattedLLMErrorMessage(
284
- cause=cause,
285
- solution="Check the model's supported max context width. Adjust the length of your input along with completions and try again.",
286
- )
287
- )
288
-
289
- def parse_api_error(
290
- self, exception: litellm.exceptions.InternalServerError, auth_error_msg: FormattedLLMErrorMessage
291
- ) -> DataDesignerError:
292
- if "Error code: 403" in str(exception):
293
- return ModelAuthenticationError(auth_error_msg)
294
-
295
- return ModelAPIError(
296
- FormattedLLMErrorMessage(
297
- cause=f"An unexpected API error occurred with model {self.model_name!r} while {self.purpose}.",
298
- solution=f"Try again in a few moments. Check with your model provider {self.model_provider_name!r} if the issue persists.",
299
- )
300
- )
@@ -1,287 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- from __future__ import annotations
5
-
6
- import logging
7
- from collections.abc import Callable
8
- from copy import deepcopy
9
- from typing import TYPE_CHECKING, Any
10
-
11
- from data_designer.config.models import GenerationType, ModelConfig, ModelProvider
12
- from data_designer.engine.model_provider import ModelProviderRegistry
13
- from data_designer.engine.models.errors import (
14
- GenerationValidationFailureError,
15
- catch_llm_exceptions,
16
- get_exception_primary_cause,
17
- )
18
- from data_designer.engine.models.litellm_overrides import CustomRouter, LiteLLMRouterDefaultKwargs
19
- from data_designer.engine.models.parsers.errors import ParserException
20
- from data_designer.engine.models.usage import ModelUsageStats, RequestUsageStats, TokenUsageStats
21
- from data_designer.engine.models.utils import prompt_to_messages, str_to_message
22
- from data_designer.engine.secret_resolver import SecretResolver
23
- from data_designer.lazy_heavy_imports import litellm
24
-
25
- if TYPE_CHECKING:
26
- import litellm
27
-
28
- logger = logging.getLogger(__name__)
29
-
30
-
31
- class ModelFacade:
32
- def __init__(
33
- self,
34
- model_config: ModelConfig,
35
- secret_resolver: SecretResolver,
36
- model_provider_registry: ModelProviderRegistry,
37
- ):
38
- self._model_config = model_config
39
- self._secret_resolver = secret_resolver
40
- self._model_provider_registry = model_provider_registry
41
- self._litellm_deployment = self._get_litellm_deployment(model_config)
42
- self._router = CustomRouter([self._litellm_deployment], **LiteLLMRouterDefaultKwargs().model_dump())
43
- self._usage_stats = ModelUsageStats()
44
-
45
- @property
46
- def model_name(self) -> str:
47
- return self._model_config.model
48
-
49
- @property
50
- def model_provider(self) -> ModelProvider:
51
- return self._model_provider_registry.get_provider(self._model_config.provider)
52
-
53
- @property
54
- def model_generation_type(self) -> GenerationType:
55
- return self._model_config.generation_type
56
-
57
- @property
58
- def model_provider_name(self) -> str:
59
- return self.model_provider.name
60
-
61
- @property
62
- def model_alias(self) -> str:
63
- return self._model_config.alias
64
-
65
- @property
66
- def usage_stats(self) -> ModelUsageStats:
67
- return self._usage_stats
68
-
69
- def completion(
70
- self, messages: list[dict[str, str]], skip_usage_tracking: bool = False, **kwargs
71
- ) -> litellm.ModelResponse:
72
- logger.debug(
73
- f"Prompting model {self.model_name!r}...",
74
- extra={"model": self.model_name, "messages": messages},
75
- )
76
- response = None
77
- kwargs = self.consolidate_kwargs(**kwargs)
78
- try:
79
- response = self._router.completion(model=self.model_name, messages=messages, **kwargs)
80
- logger.debug(
81
- f"Received completion from model {self.model_name!r}",
82
- extra={
83
- "model": self.model_name,
84
- "response": response,
85
- "text": response.choices[0].message.content,
86
- "usage": self._usage_stats.model_dump(),
87
- },
88
- )
89
- return response
90
- except Exception as e:
91
- raise e
92
- finally:
93
- if not skip_usage_tracking and response is not None:
94
- self._track_usage(response)
95
-
96
- def consolidate_kwargs(self, **kwargs) -> dict[str, Any]:
97
- # Remove purpose from kwargs to avoid passing it to the model
98
- kwargs.pop("purpose", None)
99
- kwargs = {**self._model_config.inference_parameters.generate_kwargs, **kwargs}
100
- if self.model_provider.extra_body:
101
- kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body}
102
- if self.model_provider.extra_headers:
103
- kwargs["extra_headers"] = self.model_provider.extra_headers
104
- return kwargs
105
-
106
- @catch_llm_exceptions
107
- def generate_text_embeddings(
108
- self, input_texts: list[str], skip_usage_tracking: bool = False, **kwargs
109
- ) -> list[list[float]]:
110
- logger.debug(
111
- f"Generating embeddings with model {self.model_name!r}...",
112
- extra={
113
- "model": self.model_name,
114
- "input_count": len(input_texts),
115
- },
116
- )
117
- kwargs = self.consolidate_kwargs(**kwargs)
118
- response = None
119
- try:
120
- response = self._router.embedding(model=self.model_name, input=input_texts, **kwargs)
121
- logger.debug(
122
- f"Received embeddings from model {self.model_name!r}",
123
- extra={
124
- "model": self.model_name,
125
- "embedding_count": len(response.data) if response.data else 0,
126
- "usage": self._usage_stats.model_dump(),
127
- },
128
- )
129
- if response.data and len(response.data) == len(input_texts):
130
- return [data["embedding"] for data in response.data]
131
- else:
132
- raise ValueError(f"Expected {len(input_texts)} embeddings, but received {len(response.data)}")
133
- except Exception as e:
134
- raise e
135
- finally:
136
- if not skip_usage_tracking and response is not None:
137
- self._track_usage_from_embedding(response)
138
-
139
- @catch_llm_exceptions
140
- def generate(
141
- self,
142
- prompt: str,
143
- *,
144
- parser: Callable[[str], Any],
145
- system_prompt: str | None = None,
146
- multi_modal_context: list[dict[str, Any]] | None = None,
147
- max_correction_steps: int = 0,
148
- max_conversation_restarts: int = 0,
149
- skip_usage_tracking: bool = False,
150
- purpose: str | None = None,
151
- **kwargs,
152
- ) -> tuple[Any, str | None]:
153
- """Generate a parsed output with correction steps.
154
-
155
- This generation call will attempt to generate an output which is
156
- valid according to the specified parser, where "valid" implies
157
- that the parser can process the LLM response without raising
158
- an exception.
159
-
160
- `ParserExceptions` are routed back
161
- to the LLM as new rounds in the conversation, where the LLM is provided its
162
- earlier response along with the "user" role responding with the exception string
163
- (not traceback). This will continue for the number of rounds specified by
164
- `max_correction_steps`.
165
-
166
- Args:
167
- prompt (str): Task prompt.
168
- system_prompt (str, optional): Optional system instructions. If not specified,
169
- no system message is provided and the model should use its default system
170
- prompt.
171
- parser (func(str) -> Any): A function applied to the LLM response which processes
172
- an LLM response into some output object.
173
- max_correction_steps (int): Maximum number of correction rounds permitted
174
- within a single conversation. Note, many rounds can lead to increasing
175
- context size without necessarily improving performance -- small language
176
- models can enter repeated cycles which will not be solved with more steps.
177
- Default: `0` (no correction).
178
- max_conversation_restarts (int): Maximum number of full conversation restarts permitted
179
- if generation fails. Default: `0` (no restarts).
180
- skip_usage_tracking (bool): Whether to skip usage tracking. Default: `False`.
181
- purpose (str): The purpose of the model usage to show as context in the error message.
182
- It is expected to be used by the @catch_llm_exceptions decorator.
183
- **kwargs: Additional arguments to pass to the model.
184
-
185
- Raises:
186
- GenerationValidationFailureError: If the maximum number of retries or
187
- correction steps is reached and the last response fails on
188
- generation validation.
189
- """
190
- output_obj = None
191
- curr_num_correction_steps = 0
192
- curr_num_restarts = 0
193
- curr_generation_attempt = 0
194
- max_generation_attempts = (max_correction_steps + 1) * (max_conversation_restarts + 1)
195
-
196
- starting_messages = prompt_to_messages(
197
- user_prompt=prompt, system_prompt=system_prompt, multi_modal_context=multi_modal_context
198
- )
199
- messages = deepcopy(starting_messages)
200
-
201
- while True:
202
- curr_generation_attempt += 1
203
- logger.debug(
204
- f"Starting generation attempt {curr_generation_attempt} of {max_generation_attempts} attempts."
205
- )
206
-
207
- completion_response = self.completion(messages, skip_usage_tracking=skip_usage_tracking, **kwargs)
208
- response = completion_response.choices[0].message.content or ""
209
- reasoning_trace = getattr(completion_response.choices[0].message, "reasoning_content", None)
210
-
211
- if reasoning_trace:
212
- ## There are generally some extra newlines with how these get parsed.
213
- response = response.strip()
214
- reasoning_trace = reasoning_trace.strip()
215
-
216
- curr_num_correction_steps += 1
217
-
218
- try:
219
- output_obj = parser(response) # type: ignore - if not a string will cause a ParserException below
220
- break
221
- except ParserException as exc:
222
- if max_correction_steps == 0 and max_conversation_restarts == 0:
223
- raise GenerationValidationFailureError(
224
- "Unsuccessful generation attempt. No retries were attempted."
225
- ) from exc
226
- if curr_num_correction_steps <= max_correction_steps:
227
- ## Add turns to loop-back errors for correction
228
- messages += [
229
- str_to_message(content=response, role="assistant"),
230
- str_to_message(content=str(get_exception_primary_cause(exc)), role="user"),
231
- ]
232
- elif curr_num_restarts < max_conversation_restarts:
233
- curr_num_correction_steps = 0
234
- curr_num_restarts += 1
235
- messages = deepcopy(starting_messages)
236
- else:
237
- raise GenerationValidationFailureError(
238
- f"Unsuccessful generation attempt despite {max_generation_attempts} attempts."
239
- ) from exc
240
- return output_obj, reasoning_trace
241
-
242
- def _get_litellm_deployment(self, model_config: ModelConfig) -> litellm.DeploymentTypedDict:
243
- provider = self._model_provider_registry.get_provider(model_config.provider)
244
- api_key = None
245
- if provider.api_key:
246
- api_key = self._secret_resolver.resolve(provider.api_key)
247
- api_key = api_key or "not-used-but-required"
248
-
249
- litellm_params = litellm.LiteLLM_Params(
250
- model=f"{provider.provider_type}/{model_config.model}",
251
- api_base=provider.endpoint,
252
- api_key=api_key,
253
- )
254
- return {
255
- "model_name": model_config.model,
256
- "litellm_params": litellm_params.model_dump(),
257
- }
258
-
259
- def _track_usage(self, response: litellm.types.utils.ModelResponse | None) -> None:
260
- if response is None:
261
- self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1))
262
- return
263
- if (
264
- response.usage is not None
265
- and response.usage.prompt_tokens is not None
266
- and response.usage.completion_tokens is not None
267
- ):
268
- self._usage_stats.extend(
269
- token_usage=TokenUsageStats(
270
- input_tokens=response.usage.prompt_tokens,
271
- output_tokens=response.usage.completion_tokens,
272
- ),
273
- request_usage=RequestUsageStats(successful_requests=1, failed_requests=0),
274
- )
275
-
276
- def _track_usage_from_embedding(self, response: litellm.types.utils.EmbeddingResponse | None) -> None:
277
- if response is None:
278
- self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1))
279
- return
280
- if response.usage is not None and response.usage.prompt_tokens is not None:
281
- self._usage_stats.extend(
282
- token_usage=TokenUsageStats(
283
- input_tokens=response.usage.prompt_tokens,
284
- output_tokens=0,
285
- ),
286
- request_usage=RequestUsageStats(successful_requests=1, failed_requests=0),
287
- )
@@ -1,42 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- from __future__ import annotations
5
-
6
- from typing import TYPE_CHECKING
7
-
8
- from data_designer.config.models import ModelConfig
9
- from data_designer.engine.model_provider import ModelProviderRegistry
10
- from data_designer.engine.secret_resolver import SecretResolver
11
-
12
- if TYPE_CHECKING:
13
- from data_designer.engine.models.registry import ModelRegistry
14
-
15
-
16
- def create_model_registry(
17
- *,
18
- model_configs: list[ModelConfig] | None = None,
19
- secret_resolver: SecretResolver,
20
- model_provider_registry: ModelProviderRegistry,
21
- ) -> ModelRegistry:
22
- """Factory function for creating a ModelRegistry instance.
23
-
24
- Heavy dependencies (litellm, httpx) are deferred until this function is called.
25
- This is a factory function pattern - imports inside factories are idiomatic Python
26
- for lazy initialization.
27
- """
28
- from data_designer.engine.models.facade import ModelFacade
29
- from data_designer.engine.models.litellm_overrides import apply_litellm_patches
30
- from data_designer.engine.models.registry import ModelRegistry
31
-
32
- apply_litellm_patches()
33
-
34
- def model_facade_factory(model_config, secret_resolver, model_provider_registry):
35
- return ModelFacade(model_config, secret_resolver, model_provider_registry)
36
-
37
- return ModelRegistry(
38
- model_configs=model_configs,
39
- secret_resolver=secret_resolver,
40
- model_provider_registry=model_provider_registry,
41
- model_facade_factory=model_facade_factory,
42
- )