data-designer-engine 0.4.0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (114)
  1. data_designer/engine/__init__.py +2 -0
  2. data_designer/engine/_version.py +34 -0
  3. data_designer/engine/analysis/column_profilers/base.py +49 -0
  4. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +153 -0
  5. data_designer/engine/analysis/column_profilers/registry.py +22 -0
  6. data_designer/engine/analysis/column_statistics.py +145 -0
  7. data_designer/engine/analysis/dataset_profiler.py +149 -0
  8. data_designer/engine/analysis/errors.py +9 -0
  9. data_designer/engine/analysis/utils/column_statistics_calculations.py +234 -0
  10. data_designer/engine/analysis/utils/judge_score_processing.py +132 -0
  11. data_designer/engine/column_generators/__init__.py +2 -0
  12. data_designer/engine/column_generators/generators/__init__.py +2 -0
  13. data_designer/engine/column_generators/generators/base.py +122 -0
  14. data_designer/engine/column_generators/generators/embedding.py +35 -0
  15. data_designer/engine/column_generators/generators/expression.py +55 -0
  16. data_designer/engine/column_generators/generators/llm_completion.py +116 -0
  17. data_designer/engine/column_generators/generators/samplers.py +69 -0
  18. data_designer/engine/column_generators/generators/seed_dataset.py +144 -0
  19. data_designer/engine/column_generators/generators/validation.py +140 -0
  20. data_designer/engine/column_generators/registry.py +60 -0
  21. data_designer/engine/column_generators/utils/errors.py +15 -0
  22. data_designer/engine/column_generators/utils/generator_classification.py +43 -0
  23. data_designer/engine/column_generators/utils/judge_score_factory.py +58 -0
  24. data_designer/engine/column_generators/utils/prompt_renderer.py +100 -0
  25. data_designer/engine/compiler.py +97 -0
  26. data_designer/engine/configurable_task.py +71 -0
  27. data_designer/engine/dataset_builders/artifact_storage.py +283 -0
  28. data_designer/engine/dataset_builders/column_wise_builder.py +354 -0
  29. data_designer/engine/dataset_builders/errors.py +15 -0
  30. data_designer/engine/dataset_builders/multi_column_configs.py +46 -0
  31. data_designer/engine/dataset_builders/utils/__init__.py +2 -0
  32. data_designer/engine/dataset_builders/utils/concurrency.py +212 -0
  33. data_designer/engine/dataset_builders/utils/config_compiler.py +62 -0
  34. data_designer/engine/dataset_builders/utils/dag.py +62 -0
  35. data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +200 -0
  36. data_designer/engine/dataset_builders/utils/errors.py +15 -0
  37. data_designer/engine/dataset_builders/utils/progress_tracker.py +122 -0
  38. data_designer/engine/errors.py +51 -0
  39. data_designer/engine/model_provider.py +77 -0
  40. data_designer/engine/models/__init__.py +2 -0
  41. data_designer/engine/models/errors.py +300 -0
  42. data_designer/engine/models/facade.py +284 -0
  43. data_designer/engine/models/factory.py +42 -0
  44. data_designer/engine/models/litellm_overrides.py +179 -0
  45. data_designer/engine/models/parsers/__init__.py +2 -0
  46. data_designer/engine/models/parsers/errors.py +34 -0
  47. data_designer/engine/models/parsers/parser.py +235 -0
  48. data_designer/engine/models/parsers/postprocessors.py +93 -0
  49. data_designer/engine/models/parsers/tag_parsers.py +62 -0
  50. data_designer/engine/models/parsers/types.py +84 -0
  51. data_designer/engine/models/recipes/base.py +81 -0
  52. data_designer/engine/models/recipes/response_recipes.py +293 -0
  53. data_designer/engine/models/registry.py +151 -0
  54. data_designer/engine/models/telemetry.py +362 -0
  55. data_designer/engine/models/usage.py +73 -0
  56. data_designer/engine/models/utils.py +101 -0
  57. data_designer/engine/processing/ginja/__init__.py +2 -0
  58. data_designer/engine/processing/ginja/ast.py +65 -0
  59. data_designer/engine/processing/ginja/environment.py +463 -0
  60. data_designer/engine/processing/ginja/exceptions.py +56 -0
  61. data_designer/engine/processing/ginja/record.py +32 -0
  62. data_designer/engine/processing/gsonschema/__init__.py +2 -0
  63. data_designer/engine/processing/gsonschema/exceptions.py +15 -0
  64. data_designer/engine/processing/gsonschema/schema_transformers.py +83 -0
  65. data_designer/engine/processing/gsonschema/types.py +10 -0
  66. data_designer/engine/processing/gsonschema/validators.py +202 -0
  67. data_designer/engine/processing/processors/base.py +13 -0
  68. data_designer/engine/processing/processors/drop_columns.py +42 -0
  69. data_designer/engine/processing/processors/registry.py +25 -0
  70. data_designer/engine/processing/processors/schema_transform.py +71 -0
  71. data_designer/engine/processing/utils.py +169 -0
  72. data_designer/engine/registry/base.py +99 -0
  73. data_designer/engine/registry/data_designer_registry.py +39 -0
  74. data_designer/engine/registry/errors.py +12 -0
  75. data_designer/engine/resources/managed_dataset_generator.py +39 -0
  76. data_designer/engine/resources/managed_dataset_repository.py +197 -0
  77. data_designer/engine/resources/managed_storage.py +65 -0
  78. data_designer/engine/resources/resource_provider.py +77 -0
  79. data_designer/engine/resources/seed_reader.py +154 -0
  80. data_designer/engine/sampling_gen/column.py +91 -0
  81. data_designer/engine/sampling_gen/constraints.py +100 -0
  82. data_designer/engine/sampling_gen/data_sources/base.py +217 -0
  83. data_designer/engine/sampling_gen/data_sources/errors.py +12 -0
  84. data_designer/engine/sampling_gen/data_sources/sources.py +347 -0
  85. data_designer/engine/sampling_gen/entities/__init__.py +2 -0
  86. data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
  87. data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +90 -0
  88. data_designer/engine/sampling_gen/entities/email_address_utils.py +171 -0
  89. data_designer/engine/sampling_gen/entities/errors.py +10 -0
  90. data_designer/engine/sampling_gen/entities/national_id_utils.py +102 -0
  91. data_designer/engine/sampling_gen/entities/person.py +144 -0
  92. data_designer/engine/sampling_gen/entities/phone_number.py +128 -0
  93. data_designer/engine/sampling_gen/errors.py +26 -0
  94. data_designer/engine/sampling_gen/generator.py +122 -0
  95. data_designer/engine/sampling_gen/jinja_utils.py +64 -0
  96. data_designer/engine/sampling_gen/people_gen.py +199 -0
  97. data_designer/engine/sampling_gen/person_constants.py +56 -0
  98. data_designer/engine/sampling_gen/schema.py +147 -0
  99. data_designer/engine/sampling_gen/schema_builder.py +61 -0
  100. data_designer/engine/sampling_gen/utils.py +46 -0
  101. data_designer/engine/secret_resolver.py +82 -0
  102. data_designer/engine/testing/__init__.py +12 -0
  103. data_designer/engine/testing/stubs.py +133 -0
  104. data_designer/engine/testing/utils.py +20 -0
  105. data_designer/engine/validation.py +367 -0
  106. data_designer/engine/validators/__init__.py +19 -0
  107. data_designer/engine/validators/base.py +38 -0
  108. data_designer/engine/validators/local_callable.py +39 -0
  109. data_designer/engine/validators/python.py +254 -0
  110. data_designer/engine/validators/remote.py +89 -0
  111. data_designer/engine/validators/sql.py +65 -0
  112. data_designer_engine-0.4.0.dist-info/METADATA +50 -0
  113. data_designer_engine-0.4.0.dist-info/RECORD +114 -0
  114. data_designer_engine-0.4.0.dist-info/WHEEL +4 -0
data_designer/engine/models/errors.py
@@ -0,0 +1,300 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from __future__ import annotations
+
+ import logging
+ from collections.abc import Callable
+ from functools import wraps
+ from typing import TYPE_CHECKING, Any
+
+ from pydantic import BaseModel
+
+ from data_designer.engine.errors import DataDesignerError
+ from data_designer.lazy_heavy_imports import litellm
+
+ if TYPE_CHECKING:
+     import litellm
+
+ logger = logging.getLogger(__name__)
+
+
+ def get_exception_primary_cause(exception: BaseException) -> BaseException:
+     """Return the primary cause of an exception by walking backwards.
+
+     This recursive walkback halts when it arrives at an exception that
+     has no provided __cause__ (i.e., __cause__ is None).
+
+     Args:
+         exception (BaseException): An exception to start from.
+
+     Returns:
+         The innermost exception in the __cause__ chain.
+
+     Raises:
+         RecursionError: If the exception causes somehow form a cycle
+             (seems impossible in practice).
+     """
+     if exception.__cause__ is None:
+         return exception
+     else:
+         return get_exception_primary_cause(exception.__cause__)
+
+
+ class GenerationValidationFailureError(Exception): ...
+
+
+ class ModelRateLimitError(DataDesignerError): ...
+
+
+ class ModelTimeoutError(DataDesignerError): ...
+
+
+ class ModelContextWindowExceededError(DataDesignerError): ...
+
+
+ class ModelAuthenticationError(DataDesignerError): ...
+
+
+ class ModelPermissionDeniedError(DataDesignerError): ...
+
+
+ class ModelNotFoundError(DataDesignerError): ...
+
+
+ class ModelUnsupportedParamsError(DataDesignerError): ...
+
+
+ class ModelBadRequestError(DataDesignerError): ...
+
+
+ class ModelInternalServerError(DataDesignerError): ...
+
+
+ class ModelAPIError(DataDesignerError): ...
+
+
+ class ModelUnprocessableEntityError(DataDesignerError): ...
+
+
+ class ModelAPIConnectionError(DataDesignerError): ...
+
+
+ class ModelStructuredOutputError(DataDesignerError): ...
+
+
+ class ModelGenerationValidationFailureError(DataDesignerError): ...
+
+
+ class FormattedLLMErrorMessage(BaseModel):
+     cause: str
+     solution: str
+
+     def __str__(self) -> str:
+         return "\n".join(
+             [
+                 " |----------",
+                 f" | Cause: {self.cause}",
+                 f" | Solution: {self.solution}",
+                 " |----------",
+             ]
+         )
+
+
+ def handle_llm_exceptions(
+     exception: Exception, model_name: str, model_provider_name: str, purpose: str | None = None
+ ) -> None:
+     """Handle LLM-related exceptions and convert them to appropriate DataDesignerError errors.
+
+     This function centralizes the exception-handling logic for LLM operations,
+     making it reusable across different contexts.
+
+     Args:
+         exception: The exception that was raised.
+         model_name: Name of the model that was being used.
+         model_provider_name: Name of the model provider that was being used.
+         purpose: The purpose of the model usage, shown as context in the error message.
+
+     Raises:
+         DataDesignerError: A more user-friendly error with an appropriate error type and message.
+     """
+     purpose = purpose or "running generation"
+     authentication_error = FormattedLLMErrorMessage(
+         cause=f"The API key provided for model {model_name!r} was found to be invalid or expired while {purpose}.",
+         solution=f"Verify your API key for model provider {model_provider_name!r} and update it in your settings.",
+     )
+     err_msg_parser = DownstreamLLMExceptionMessageParser(model_name, model_provider_name, purpose)
+     match exception:
+         # Common errors that can come from LiteLLM
+         case litellm.exceptions.APIError():
+             raise err_msg_parser.parse_api_error(exception, authentication_error) from None
+
+         case litellm.exceptions.APIConnectionError():
+             raise ModelAPIConnectionError(
+                 FormattedLLMErrorMessage(
+                     cause=f"Connection to model {model_name!r} hosted on model provider {model_provider_name!r} failed while {purpose}.",
+                     solution="Check your network/proxy/firewall settings.",
+                 )
+             ) from None
+
+         case litellm.exceptions.AuthenticationError():
+             raise ModelAuthenticationError(authentication_error) from None
+
+         case litellm.exceptions.ContextWindowExceededError():
+             raise err_msg_parser.parse_context_window_exceeded_error(exception) from None
+
+         case litellm.exceptions.UnsupportedParamsError():
+             raise ModelUnsupportedParamsError(
+                 FormattedLLMErrorMessage(
+                     cause=f"One or more of the parameters you provided are unsupported by model {model_name!r} while {purpose}.",
+                     solution=f"Review the documentation for model provider {model_provider_name!r} and adjust your request.",
+                 )
+             ) from None
+
+         case litellm.exceptions.BadRequestError():
+             raise err_msg_parser.parse_bad_request_error(exception) from None
+
+         case litellm.exceptions.InternalServerError():
+             raise ModelInternalServerError(
+                 FormattedLLMErrorMessage(
+                     cause=f"Model {model_name!r} is currently experiencing internal server issues while {purpose}.",
+                     solution=f"Try again in a few moments. Check with your model provider {model_provider_name!r} if the issue persists.",
+                 )
+             ) from None
+
+         case litellm.exceptions.NotFoundError():
+             raise ModelNotFoundError(
+                 FormattedLLMErrorMessage(
+                     cause=f"The specified model {model_name!r} could not be found while {purpose}.",
+                     solution=f"Check that the model name is correct and supported by your model provider {model_provider_name!r} and try again.",
+                 )
+             ) from None
+
+         case litellm.exceptions.PermissionDeniedError():
+             raise ModelPermissionDeniedError(
+                 FormattedLLMErrorMessage(
+                     cause=f"Your API key lacks the necessary permissions to use model {model_name!r} while {purpose}.",
+                     solution=f"Use an API key with the right permissions for this model, or switch to a model that your current API key can access on model provider {model_provider_name!r}.",
+                 )
+             ) from None
+
+         case litellm.exceptions.RateLimitError():
+             raise ModelRateLimitError(
+                 FormattedLLMErrorMessage(
+                     cause=f"You have exceeded the rate limit for model {model_name!r} while {purpose}.",
+                     solution="Wait and try again in a few moments.",
+                 )
+             ) from None
+
+         case litellm.exceptions.Timeout():
+             raise ModelTimeoutError(
+                 FormattedLLMErrorMessage(
+                     cause=f"The request to model {model_name!r} timed out while {purpose}.",
+                     solution="Check your connection and try again. You may need to increase the timeout setting for the model.",
+                 )
+             ) from None
+
+         case litellm.exceptions.UnprocessableEntityError():
+             raise ModelUnprocessableEntityError(
+                 FormattedLLMErrorMessage(
+                     cause=f"The request to model {model_name!r} failed despite a correct request format while {purpose}.",
+                     solution="This is most likely temporary. Try again in a few moments.",
+                 )
+             ) from None
+
+         # Parsing and validation errors
+         case GenerationValidationFailureError():
+             raise ModelGenerationValidationFailureError(
+                 FormattedLLMErrorMessage(
+                     cause=f"Responses from model {model_name!r} could not be parsed into the provided output schema while {purpose}.",
+                     solution="This is most likely temporary, as additional attempts are made automatically. If it persists, simplify or modify the output schema for structured output and try again. If you are attempting token-intensive tasks such as generations with high reasoning effort, ensure that max_tokens in the model config is high enough to reach completion.",
+                 )
+             ) from None
+
+         case DataDesignerError():
+             raise exception from None
+
+         case _:
+             raise DataDesignerError(
+                 FormattedLLMErrorMessage(
+                     cause=f"An unexpected error occurred while {purpose}.",
+                     solution=f"Review the stack trace for more details: {exception}",
+                 )
+             ) from exception
+
+
+ def catch_llm_exceptions(func: Callable) -> Callable:
+     """Decorate any `ModelFacade` method that could raise exceptions
+     that should be turned into upstream user-facing errors.
+     """
+
+     @wraps(func)
+     def wrapper(model_facade: Any, *args, **kwargs):
+         try:
+             return func(model_facade, *args, **kwargs)
+         except Exception as e:
+             logger.debug(
+                 "\n".join(
+                     [
+                         "",
+                         "|----------",
+                         f"| Caught an exception downstream of type {type(e)!r}. Re-raising it below as a custom error with more context.",
+                         "|----------",
+                     ]
+                 ),
+                 exc_info=True,
+                 stack_info=True,
+             )
+             handle_llm_exceptions(
+                 e, model_facade.model_name, model_facade.model_provider_name, purpose=kwargs.get("purpose")
+             )
+
+     return wrapper
+
+
+ class DownstreamLLMExceptionMessageParser:
+     def __init__(self, model_name: str, model_provider_name: str, purpose: str):
+         self.model_name = model_name
+         self.model_provider_name = model_provider_name
+         self.purpose = purpose
+
+     def parse_bad_request_error(self, exception: litellm.exceptions.BadRequestError) -> DataDesignerError:
+         err_msg = FormattedLLMErrorMessage(
+             cause=f"The request for model {self.model_name!r} was malformed or missing required parameters while {self.purpose}.",
+             solution="Check your request parameters and try again.",
+         )
+         if "is not a multimodal model" in str(exception):
+             err_msg = FormattedLLMErrorMessage(
+                 cause=f"Model {self.model_name!r} is not a multimodal model, but it looks like you are trying to provide multimodal context while {self.purpose}.",
+                 solution="Check your request parameters and try again.",
+             )
+         return ModelBadRequestError(err_msg)
+
+     def parse_context_window_exceeded_error(
+         self, exception: litellm.exceptions.ContextWindowExceededError
+     ) -> DataDesignerError:
+         cause = f"The input data for model {self.model_name!r} exceeds its supported context window while {self.purpose}."
+         try:
+             if "OpenAIException - This model's maximum context length is " in str(exception):
+                 openai_exception_cause = (
+                     str(exception).split("OpenAIException - ")[1].split("\n")[0].split(" Please reduce ")[0]
+                 )
+                 cause = f"{cause} {openai_exception_cause}"
+         except Exception:
+             # Best effort only: fall back to the generic cause message.
+             pass
+         return ModelContextWindowExceededError(
+             FormattedLLMErrorMessage(
+                 cause=cause,
+                 solution="Check the model's maximum supported context window. Adjust the length of your input and requested completions and try again.",
+             )
+         )
+
+     def parse_api_error(
+         self, exception: litellm.exceptions.APIError, auth_error_msg: FormattedLLMErrorMessage
+     ) -> DataDesignerError:
+         if "Error code: 403" in str(exception):
+             return ModelAuthenticationError(auth_error_msg)
+
+         return ModelAPIError(
+             FormattedLLMErrorMessage(
+                 cause=f"An unexpected API error occurred with model {self.model_name!r} while {self.purpose}.",
+                 solution=f"Try again in a few moments. Check with your model provider {self.model_provider_name!r} if the issue persists.",
+             )
+         )
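
A note on two building blocks above. `get_exception_primary_cause` walks the explicit `__cause__` chain (set by `raise ... from ...`) down to the innermost exception, and `FormattedLLMErrorMessage.__str__` renders the boxed cause/solution block. A minimal sketch of both, assuming only that the module above is importable:

    from data_designer.engine.models.errors import (
        FormattedLLMErrorMessage,
        get_exception_primary_cause,
    )

    try:
        try:
            raise ValueError("root cause")
        except ValueError as inner:
            raise RuntimeError("wrapper") from inner
    except RuntimeError as outer:
        root = get_exception_primary_cause(outer)
        assert isinstance(root, ValueError)  # walkback stops where __cause__ is None

    msg = FormattedLLMErrorMessage(cause="API key expired.", solution="Rotate the key.")
    print(msg)  # renders the " | Cause: ... | Solution: ..." box shown above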
data_designer/engine/models/facade.py
@@ -0,0 +1,284 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from __future__ import annotations
+
+ import logging
+ from collections.abc import Callable
+ from copy import deepcopy
+ from typing import TYPE_CHECKING, Any
+
+ from data_designer.config.models import GenerationType, ModelConfig, ModelProvider
+ from data_designer.engine.model_provider import ModelProviderRegistry
+ from data_designer.engine.models.errors import (
+     GenerationValidationFailureError,
+     catch_llm_exceptions,
+     get_exception_primary_cause,
+ )
+ from data_designer.engine.models.litellm_overrides import CustomRouter, LiteLLMRouterDefaultKwargs
+ from data_designer.engine.models.parsers.errors import ParserException
+ from data_designer.engine.models.usage import ModelUsageStats, RequestUsageStats, TokenUsageStats
+ from data_designer.engine.models.utils import ChatMessage, prompt_to_messages
+ from data_designer.engine.secret_resolver import SecretResolver
+ from data_designer.lazy_heavy_imports import litellm
+
+ if TYPE_CHECKING:
+     import litellm
+
+ logger = logging.getLogger(__name__)
+
+
+ class ModelFacade:
+     def __init__(
+         self,
+         model_config: ModelConfig,
+         secret_resolver: SecretResolver,
+         model_provider_registry: ModelProviderRegistry,
+     ):
+         self._model_config = model_config
+         self._secret_resolver = secret_resolver
+         self._model_provider_registry = model_provider_registry
+         self._litellm_deployment = self._get_litellm_deployment(model_config)
+         self._router = CustomRouter([self._litellm_deployment], **LiteLLMRouterDefaultKwargs().model_dump())
+         self._usage_stats = ModelUsageStats()
+
+     @property
+     def model_name(self) -> str:
+         return self._model_config.model
+
+     @property
+     def model_provider(self) -> ModelProvider:
+         return self._model_provider_registry.get_provider(self._model_config.provider)
+
+     @property
+     def model_generation_type(self) -> GenerationType:
+         return self._model_config.generation_type
+
+     @property
+     def model_provider_name(self) -> str:
+         return self.model_provider.name
+
+     @property
+     def model_alias(self) -> str:
+         return self._model_config.alias
+
+     @property
+     def usage_stats(self) -> ModelUsageStats:
+         return self._usage_stats
+
+     def completion(
+         self, messages: list[ChatMessage], skip_usage_tracking: bool = False, **kwargs
+     ) -> litellm.ModelResponse:
+         message_payloads = [message.to_dict() for message in messages]
+         logger.debug(
+             f"Prompting model {self.model_name!r}...",
+             extra={"model": self.model_name, "messages": message_payloads},
+         )
+         response = None
+         kwargs = self.consolidate_kwargs(**kwargs)
+         try:
+             response = self._router.completion(model=self.model_name, messages=message_payloads, **kwargs)
+             logger.debug(
+                 f"Received completion from model {self.model_name!r}",
+                 extra={
+                     "model": self.model_name,
+                     "response": response,
+                     "text": response.choices[0].message.content,
+                     "usage": self._usage_stats.model_dump(),
+                 },
+             )
+             return response
+         finally:
+             if not skip_usage_tracking:
+                 # _track_usage records a failed request when response is still None.
+                 self._track_usage(response)
+
+     def consolidate_kwargs(self, **kwargs) -> dict[str, Any]:
+         # Remove purpose from kwargs to avoid passing it to the model
+         kwargs.pop("purpose", None)
+         kwargs = {**self._model_config.inference_parameters.generate_kwargs, **kwargs}
+         if self.model_provider.extra_body:
+             kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body}
+         if self.model_provider.extra_headers:
+             kwargs["extra_headers"] = self.model_provider.extra_headers
+         return kwargs
+
+     @catch_llm_exceptions
+     def generate_text_embeddings(
+         self, input_texts: list[str], skip_usage_tracking: bool = False, **kwargs
+     ) -> list[list[float]]:
+         logger.debug(
+             f"Generating embeddings with model {self.model_name!r}...",
+             extra={
+                 "model": self.model_name,
+                 "input_count": len(input_texts),
+             },
+         )
+         kwargs = self.consolidate_kwargs(**kwargs)
+         response = None
+         try:
+             response = self._router.embedding(model=self.model_name, input=input_texts, **kwargs)
+             logger.debug(
+                 f"Received embeddings from model {self.model_name!r}",
+                 extra={
+                     "model": self.model_name,
+                     "embedding_count": len(response.data) if response.data else 0,
+                     "usage": self._usage_stats.model_dump(),
+                 },
+             )
+             if response.data and len(response.data) == len(input_texts):
+                 return [data["embedding"] for data in response.data]
+             else:
+                 raise ValueError(
+                     f"Expected {len(input_texts)} embeddings, but received "
+                     f"{len(response.data) if response.data else 0}"
+                 )
+         finally:
+             if not skip_usage_tracking:
+                 # _track_usage_from_embedding records a failed request when response is still None.
+                 self._track_usage_from_embedding(response)
+
+     @catch_llm_exceptions
+     def generate(
+         self,
+         prompt: str,
+         *,
+         parser: Callable[[str], Any],
+         system_prompt: str | None = None,
+         multi_modal_context: list[dict[str, Any]] | None = None,
+         max_correction_steps: int = 0,
+         max_conversation_restarts: int = 0,
+         skip_usage_tracking: bool = False,
+         purpose: str | None = None,
+         **kwargs,
+     ) -> tuple[Any, list[ChatMessage]]:
+         """Generate a parsed output with correction steps.
+
+         This generation call attempts to produce an output that is
+         valid according to the specified parser, where "valid" means
+         that the parser can process the LLM response without raising
+         an exception.
+
+         `ParserException`s are routed back to the LLM as new rounds in the
+         conversation: the LLM is given its earlier response along with a
+         "user" message containing the exception string (not the traceback).
+         This continues for the number of rounds specified by
+         `max_correction_steps`.
+
+         Args:
+             prompt (str): Task prompt.
+             parser (func(str) -> Any): A function applied to the LLM response that processes
+                 the response into some output object.
+             system_prompt (str, optional): Optional system instructions. If not specified,
+                 no system message is provided and the model should use its default system
+                 prompt.
+             multi_modal_context (list[dict], optional): Optional multimodal content to include
+                 in the user message.
+             max_correction_steps (int): Maximum number of correction rounds permitted
+                 within a single conversation. Note that many rounds can grow the
+                 context without necessarily improving performance; small language
+                 models can enter repeated cycles that more steps will not fix.
+                 Default: `0` (no correction).
+             max_conversation_restarts (int): Maximum number of full conversation restarts permitted
+                 if generation fails. Default: `0` (no restarts).
+             skip_usage_tracking (bool): Whether to skip usage tracking. Default: `False`.
+             purpose (str): The purpose of the model usage, shown as context in the error message.
+                 It is expected to be used by the @catch_llm_exceptions decorator.
+             **kwargs: Additional arguments to pass to the model.
+
+         Returns:
+             A tuple containing:
+             - The parsed output object from the parser.
+             - The full trace of ChatMessage entries in the conversation, including any
+               corrections and reasoning traces. Callers can decide whether to store this.
+
+         Raises:
+             GenerationValidationFailureError: If the correction-step and restart budgets
+                 are exhausted and the final response still fails generation validation.
+         """
+         output_obj = None
+         curr_num_correction_steps = 0
+         curr_num_restarts = 0
+
+         starting_messages = prompt_to_messages(
+             user_prompt=prompt, system_prompt=system_prompt, multi_modal_context=multi_modal_context
+         )
+         messages: list[ChatMessage] = deepcopy(starting_messages)
+
+         while True:
+             completion_response = self.completion(messages, skip_usage_tracking=skip_usage_tracking, **kwargs)
+             response = completion_response.choices[0].message.content or ""
+             reasoning_trace = getattr(completion_response.choices[0].message, "reasoning_content", None)
+             messages.append(ChatMessage.as_assistant(content=response, reasoning_content=reasoning_trace or None))
+             curr_num_correction_steps += 1
+
+             try:
+                 output_obj = parser(response)  # type: ignore - a non-string response raises a ParserException below
+                 break
+             except ParserException as exc:
+                 if max_correction_steps == 0 and max_conversation_restarts == 0:
+                     raise GenerationValidationFailureError(
+                         "Unsuccessful generation attempt. No retries were attempted."
+                     ) from exc
+
+                 if curr_num_correction_steps <= max_correction_steps:
+                     # Add a user message with the error for correction
+                     messages.append(ChatMessage.as_user(content=str(get_exception_primary_cause(exc))))
+
+                 elif curr_num_restarts < max_conversation_restarts:
+                     curr_num_correction_steps = 0
+                     curr_num_restarts += 1
+                     messages = deepcopy(starting_messages)
+
+                 else:
+                     raise GenerationValidationFailureError(
+                         f"Unsuccessful generation despite {max_correction_steps} correction steps "
+                         f"and {max_conversation_restarts} conversation restarts."
+                     ) from exc
+
+         return output_obj, messages
+
+     def _get_litellm_deployment(self, model_config: ModelConfig) -> litellm.DeploymentTypedDict:
+         provider = self._model_provider_registry.get_provider(model_config.provider)
+         api_key = None
+         if provider.api_key:
+             api_key = self._secret_resolver.resolve(provider.api_key)
+         api_key = api_key or "not-used-but-required"
+
+         litellm_params = litellm.LiteLLM_Params(
+             model=f"{provider.provider_type}/{model_config.model}",
+             api_base=provider.endpoint,
+             api_key=api_key,
+         )
+         return {
+             "model_name": model_config.model,
+             "litellm_params": litellm_params.model_dump(),
+         }
+
+     def _track_usage(self, response: litellm.types.utils.ModelResponse | None) -> None:
+         if response is None:
+             self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1))
+             return
+         if (
+             response.usage is not None
+             and response.usage.prompt_tokens is not None
+             and response.usage.completion_tokens is not None
+         ):
+             self._usage_stats.extend(
+                 token_usage=TokenUsageStats(
+                     input_tokens=response.usage.prompt_tokens,
+                     output_tokens=response.usage.completion_tokens,
+                 ),
+                 request_usage=RequestUsageStats(successful_requests=1, failed_requests=0),
+             )
+
+     def _track_usage_from_embedding(self, response: litellm.types.utils.EmbeddingResponse | None) -> None:
+         if response is None:
+             self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1))
+             return
+         if response.usage is not None and response.usage.prompt_tokens is not None:
+             self._usage_stats.extend(
+                 token_usage=TokenUsageStats(
+                     input_tokens=response.usage.prompt_tokens,
+                     output_tokens=0,
+                 ),
+                 request_usage=RequestUsageStats(successful_requests=1, failed_requests=0),
+             )
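
How the correction loop in `generate` is meant to be driven: the parser raises `ParserException` on invalid output, and the facade feeds the exception text back to the model as a new user turn. A hypothetical usage sketch; the `facade` instance and `json_parser` are illustrative rather than part of the package, and it assumes `ParserException` accepts a message string:

    import json

    from data_designer.engine.models.parsers.errors import ParserException

    def json_parser(text: str) -> dict:
        try:
            return json.loads(text)
        except json.JSONDecodeError as exc:
            # The exception string becomes the correction message sent back to the model.
            raise ParserException(f"Response was not valid JSON: {exc}") from exc

    output, trace = facade.generate(
        prompt="Return a JSON object with a 'city' key.",
        parser=json_parser,
        max_correction_steps=2,       # up to 2 in-conversation corrections per attempt
        max_conversation_restarts=1,  # then 1 fresh conversation before raising
    )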
data_designer/engine/models/factory.py
@@ -0,0 +1,42 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING
+
+ from data_designer.config.models import ModelConfig
+ from data_designer.engine.model_provider import ModelProviderRegistry
+ from data_designer.engine.secret_resolver import SecretResolver
+
+ if TYPE_CHECKING:
+     from data_designer.engine.models.registry import ModelRegistry
+
+
+ def create_model_registry(
+     *,
+     model_configs: list[ModelConfig] | None = None,
+     secret_resolver: SecretResolver,
+     model_provider_registry: ModelProviderRegistry,
+ ) -> ModelRegistry:
+     """Factory function for creating a ModelRegistry instance.
+
+     Heavy dependencies (litellm, httpx) are deferred until this function is called.
+     This is a factory function pattern: imports inside factories are idiomatic Python
+     for lazy initialization.
+     """
+     from data_designer.engine.models.facade import ModelFacade
+     from data_designer.engine.models.litellm_overrides import apply_litellm_patches
+     from data_designer.engine.models.registry import ModelRegistry
+
+     apply_litellm_patches()
+
+     def model_facade_factory(model_config, secret_resolver, model_provider_registry):
+         return ModelFacade(model_config, secret_resolver, model_provider_registry)
+
+     return ModelRegistry(
+         model_configs=model_configs,
+         secret_resolver=secret_resolver,
+         model_provider_registry=model_provider_registry,
+         model_facade_factory=model_facade_factory,
+     )
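
The docstring's point about deferred heavy imports can be illustrated with a generic sketch (the names below are illustrative, not from this package): importing inside the factory means the cost of a module like `httpx` is paid on the first call rather than at module import time.

    def create_client(base_url: str):
        import httpx  # deferred: imported on first call, not at module import

        return httpx.Client(base_url=base_url)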