data-designer 0.3.3__py3-none-any.whl → 0.3.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/__init__.py +2 -0
- data_designer/_version.py +2 -2
- data_designer/cli/__init__.py +2 -0
- data_designer/cli/commands/download.py +2 -0
- data_designer/cli/commands/list.py +2 -0
- data_designer/cli/commands/models.py +2 -0
- data_designer/cli/commands/providers.py +2 -0
- data_designer/cli/commands/reset.py +2 -0
- data_designer/cli/controllers/__init__.py +2 -0
- data_designer/cli/controllers/download_controller.py +2 -0
- data_designer/cli/controllers/model_controller.py +6 -1
- data_designer/cli/controllers/provider_controller.py +6 -1
- data_designer/cli/forms/__init__.py +2 -0
- data_designer/cli/forms/builder.py +2 -0
- data_designer/cli/forms/field.py +2 -0
- data_designer/cli/forms/form.py +2 -0
- data_designer/cli/forms/model_builder.py +2 -0
- data_designer/cli/forms/provider_builder.py +2 -0
- data_designer/cli/main.py +2 -0
- data_designer/cli/repositories/__init__.py +2 -0
- data_designer/cli/repositories/base.py +2 -0
- data_designer/cli/repositories/model_repository.py +2 -0
- data_designer/cli/repositories/persona_repository.py +2 -0
- data_designer/cli/repositories/provider_repository.py +2 -0
- data_designer/cli/services/__init__.py +2 -0
- data_designer/cli/services/download_service.py +2 -0
- data_designer/cli/services/model_service.py +2 -0
- data_designer/cli/services/provider_service.py +2 -0
- data_designer/cli/ui.py +2 -0
- data_designer/cli/utils.py +2 -0
- data_designer/config/analysis/column_profilers.py +2 -0
- data_designer/config/analysis/column_statistics.py +8 -5
- data_designer/config/analysis/dataset_profiler.py +9 -3
- data_designer/config/analysis/utils/errors.py +2 -0
- data_designer/config/analysis/utils/reporting.py +7 -3
- data_designer/config/base.py +1 -0
- data_designer/config/column_configs.py +77 -7
- data_designer/config/column_types.py +33 -36
- data_designer/config/dataset_builders.py +2 -0
- data_designer/config/dataset_metadata.py +18 -0
- data_designer/config/default_model_settings.py +1 -0
- data_designer/config/errors.py +2 -0
- data_designer/config/exports.py +2 -0
- data_designer/config/interface.py +3 -2
- data_designer/config/models.py +7 -2
- data_designer/config/preview_results.py +9 -1
- data_designer/config/processors.py +2 -0
- data_designer/config/run_config.py +19 -5
- data_designer/config/sampler_constraints.py +2 -0
- data_designer/config/sampler_params.py +7 -2
- data_designer/config/seed.py +2 -0
- data_designer/config/seed_source.py +9 -3
- data_designer/config/seed_source_types.py +2 -0
- data_designer/config/utils/constants.py +2 -0
- data_designer/config/utils/errors.py +2 -0
- data_designer/config/utils/info.py +2 -0
- data_designer/config/utils/io_helpers.py +8 -3
- data_designer/config/utils/misc.py +2 -2
- data_designer/config/utils/numerical_helpers.py +2 -0
- data_designer/config/utils/type_helpers.py +2 -0
- data_designer/config/utils/visualization.py +19 -11
- data_designer/config/validator_params.py +2 -0
- data_designer/engine/analysis/column_profilers/base.py +9 -8
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +15 -19
- data_designer/engine/analysis/column_profilers/registry.py +2 -0
- data_designer/engine/analysis/column_statistics.py +5 -2
- data_designer/engine/analysis/dataset_profiler.py +12 -9
- data_designer/engine/analysis/errors.py +2 -0
- data_designer/engine/analysis/utils/column_statistics_calculations.py +7 -4
- data_designer/engine/analysis/utils/judge_score_processing.py +7 -3
- data_designer/engine/column_generators/generators/base.py +26 -14
- data_designer/engine/column_generators/generators/embedding.py +4 -11
- data_designer/engine/column_generators/generators/expression.py +7 -16
- data_designer/engine/column_generators/generators/llm_completion.py +13 -47
- data_designer/engine/column_generators/generators/samplers.py +8 -14
- data_designer/engine/column_generators/generators/seed_dataset.py +9 -15
- data_designer/engine/column_generators/generators/validation.py +9 -20
- data_designer/engine/column_generators/registry.py +2 -0
- data_designer/engine/column_generators/utils/errors.py +2 -0
- data_designer/engine/column_generators/utils/generator_classification.py +2 -0
- data_designer/engine/column_generators/utils/judge_score_factory.py +2 -0
- data_designer/engine/column_generators/utils/prompt_renderer.py +4 -2
- data_designer/engine/compiler.py +3 -6
- data_designer/engine/configurable_task.py +12 -13
- data_designer/engine/dataset_builders/artifact_storage.py +87 -8
- data_designer/engine/dataset_builders/column_wise_builder.py +34 -35
- data_designer/engine/dataset_builders/errors.py +2 -0
- data_designer/engine/dataset_builders/multi_column_configs.py +2 -0
- data_designer/engine/dataset_builders/utils/concurrency.py +13 -4
- data_designer/engine/dataset_builders/utils/config_compiler.py +2 -0
- data_designer/engine/dataset_builders/utils/dag.py +7 -2
- data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +35 -25
- data_designer/engine/dataset_builders/utils/errors.py +2 -0
- data_designer/engine/errors.py +2 -0
- data_designer/engine/model_provider.py +2 -0
- data_designer/engine/models/errors.py +23 -31
- data_designer/engine/models/facade.py +12 -9
- data_designer/engine/models/factory.py +42 -0
- data_designer/engine/models/litellm_overrides.py +16 -11
- data_designer/engine/models/parsers/errors.py +2 -0
- data_designer/engine/models/parsers/parser.py +2 -2
- data_designer/engine/models/parsers/postprocessors.py +1 -0
- data_designer/engine/models/parsers/tag_parsers.py +2 -0
- data_designer/engine/models/parsers/types.py +2 -0
- data_designer/engine/models/recipes/base.py +2 -0
- data_designer/engine/models/recipes/response_recipes.py +2 -0
- data_designer/engine/models/registry.py +11 -18
- data_designer/engine/models/telemetry.py +6 -2
- data_designer/engine/processing/ginja/ast.py +2 -0
- data_designer/engine/processing/ginja/environment.py +2 -0
- data_designer/engine/processing/ginja/exceptions.py +2 -0
- data_designer/engine/processing/ginja/record.py +2 -0
- data_designer/engine/processing/gsonschema/exceptions.py +9 -2
- data_designer/engine/processing/gsonschema/schema_transformers.py +2 -0
- data_designer/engine/processing/gsonschema/types.py +2 -0
- data_designer/engine/processing/gsonschema/validators.py +10 -6
- data_designer/engine/processing/processors/base.py +1 -5
- data_designer/engine/processing/processors/drop_columns.py +7 -10
- data_designer/engine/processing/processors/registry.py +2 -0
- data_designer/engine/processing/processors/schema_transform.py +7 -10
- data_designer/engine/processing/utils.py +7 -3
- data_designer/engine/registry/base.py +2 -0
- data_designer/engine/registry/data_designer_registry.py +2 -0
- data_designer/engine/registry/errors.py +2 -0
- data_designer/engine/resources/managed_dataset_generator.py +6 -2
- data_designer/engine/resources/managed_dataset_repository.py +8 -5
- data_designer/engine/resources/managed_storage.py +2 -0
- data_designer/engine/resources/resource_provider.py +20 -1
- data_designer/engine/resources/seed_reader.py +7 -2
- data_designer/engine/sampling_gen/column.py +2 -0
- data_designer/engine/sampling_gen/constraints.py +8 -2
- data_designer/engine/sampling_gen/data_sources/base.py +10 -7
- data_designer/engine/sampling_gen/data_sources/errors.py +2 -0
- data_designer/engine/sampling_gen/data_sources/sources.py +27 -22
- data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +2 -2
- data_designer/engine/sampling_gen/entities/email_address_utils.py +2 -0
- data_designer/engine/sampling_gen/entities/errors.py +2 -0
- data_designer/engine/sampling_gen/entities/national_id_utils.py +2 -0
- data_designer/engine/sampling_gen/entities/person.py +2 -0
- data_designer/engine/sampling_gen/entities/phone_number.py +8 -1
- data_designer/engine/sampling_gen/errors.py +2 -0
- data_designer/engine/sampling_gen/generator.py +5 -4
- data_designer/engine/sampling_gen/jinja_utils.py +7 -3
- data_designer/engine/sampling_gen/people_gen.py +7 -7
- data_designer/engine/sampling_gen/person_constants.py +2 -0
- data_designer/engine/sampling_gen/schema.py +5 -1
- data_designer/engine/sampling_gen/schema_builder.py +2 -0
- data_designer/engine/sampling_gen/utils.py +7 -1
- data_designer/engine/secret_resolver.py +2 -0
- data_designer/engine/validation.py +2 -2
- data_designer/engine/validators/__init__.py +2 -0
- data_designer/engine/validators/base.py +2 -0
- data_designer/engine/validators/local_callable.py +7 -2
- data_designer/engine/validators/python.py +7 -1
- data_designer/engine/validators/remote.py +7 -1
- data_designer/engine/validators/sql.py +8 -3
- data_designer/errors.py +2 -0
- data_designer/essentials/__init__.py +2 -0
- data_designer/interface/data_designer.py +36 -39
- data_designer/interface/errors.py +2 -0
- data_designer/interface/results.py +9 -2
- data_designer/lazy_heavy_imports.py +54 -0
- data_designer/logging.py +2 -0
- data_designer/plugins/__init__.py +2 -0
- data_designer/plugins/errors.py +2 -0
- data_designer/plugins/plugin.py +0 -1
- data_designer/plugins/registry.py +2 -0
- data_designer/plugins/testing/__init__.py +2 -0
- data_designer/plugins/testing/stubs.py +21 -43
- data_designer/plugins/testing/utils.py +2 -0
- {data_designer-0.3.3.dist-info → data_designer-0.3.5.dist-info}/METADATA +19 -4
- data_designer-0.3.5.dist-info/RECORD +196 -0
- data_designer-0.3.3.dist-info/RECORD +0 -193
- {data_designer-0.3.3.dist-info → data_designer-0.3.5.dist-info}/WHEEL +0 -0
- {data_designer-0.3.3.dist-info → data_designer-0.3.5.dist-info}/entry_points.txt +0 -0
- {data_designer-0.3.3.dist-info → data_designer-0.3.5.dist-info}/licenses/LICENSE +0 -0
|
@@ -6,25 +6,15 @@ from __future__ import annotations
|
|
|
6
6
|
import logging
|
|
7
7
|
from collections.abc import Callable
|
|
8
8
|
from functools import wraps
|
|
9
|
-
from typing import Any
|
|
10
|
-
|
|
11
|
-
from litellm.exceptions import (
|
|
12
|
-
APIConnectionError,
|
|
13
|
-
APIError,
|
|
14
|
-
AuthenticationError,
|
|
15
|
-
BadRequestError,
|
|
16
|
-
ContextWindowExceededError,
|
|
17
|
-
InternalServerError,
|
|
18
|
-
NotFoundError,
|
|
19
|
-
PermissionDeniedError,
|
|
20
|
-
RateLimitError,
|
|
21
|
-
Timeout,
|
|
22
|
-
UnprocessableEntityError,
|
|
23
|
-
UnsupportedParamsError,
|
|
24
|
-
)
|
|
9
|
+
from typing import TYPE_CHECKING, Any
|
|
10
|
+
|
|
25
11
|
from pydantic import BaseModel
|
|
26
12
|
|
|
27
13
|
from data_designer.engine.errors import DataDesignerError
|
|
14
|
+
from data_designer.lazy_heavy_imports import litellm
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
import litellm
|
|
28
18
|
|
|
29
19
|
logger = logging.getLogger(__name__)
|
|
30
20
|
|
|
@@ -132,10 +122,10 @@ def handle_llm_exceptions(
|
|
|
132
122
|
err_msg_parser = DownstreamLLMExceptionMessageParser(model_name, model_provider_name, purpose)
|
|
133
123
|
match exception:
|
|
134
124
|
# Common errors that can come from LiteLLM
|
|
135
|
-
case APIError():
|
|
125
|
+
case litellm.exceptions.APIError():
|
|
136
126
|
raise err_msg_parser.parse_api_error(exception, authentication_error) from None
|
|
137
127
|
|
|
138
|
-
case APIConnectionError():
|
|
128
|
+
case litellm.exceptions.APIConnectionError():
|
|
139
129
|
raise ModelAPIConnectionError(
|
|
140
130
|
FormattedLLMErrorMessage(
|
|
141
131
|
cause=f"Connection to model {model_name!r} hosted on model provider {model_provider_name!r} failed while {purpose}.",
|
|
@@ -143,13 +133,13 @@ def handle_llm_exceptions(
|
|
|
143
133
|
)
|
|
144
134
|
) from None
|
|
145
135
|
|
|
146
|
-
case AuthenticationError():
|
|
136
|
+
case litellm.exceptions.AuthenticationError():
|
|
147
137
|
raise ModelAuthenticationError(authentication_error) from None
|
|
148
138
|
|
|
149
|
-
case ContextWindowExceededError():
|
|
139
|
+
case litellm.exceptions.ContextWindowExceededError():
|
|
150
140
|
raise err_msg_parser.parse_context_window_exceeded_error(exception) from None
|
|
151
141
|
|
|
152
|
-
case UnsupportedParamsError():
|
|
142
|
+
case litellm.exceptions.UnsupportedParamsError():
|
|
153
143
|
raise ModelUnsupportedParamsError(
|
|
154
144
|
FormattedLLMErrorMessage(
|
|
155
145
|
cause=f"One or more of the parameters you provided were found to be unsupported by model {model_name!r} while {purpose}.",
|
|
@@ -157,10 +147,10 @@ def handle_llm_exceptions(
|
|
|
157
147
|
)
|
|
158
148
|
) from None
|
|
159
149
|
|
|
160
|
-
case BadRequestError():
|
|
150
|
+
case litellm.exceptions.BadRequestError():
|
|
161
151
|
raise err_msg_parser.parse_bad_request_error(exception) from None
|
|
162
152
|
|
|
163
|
-
case InternalServerError():
|
|
153
|
+
case litellm.exceptions.InternalServerError():
|
|
164
154
|
raise ModelInternalServerError(
|
|
165
155
|
FormattedLLMErrorMessage(
|
|
166
156
|
cause=f"Model {model_name!r} is currently experiencing internal server issues while {purpose}.",
|
|
@@ -168,7 +158,7 @@ def handle_llm_exceptions(
|
|
|
168
158
|
)
|
|
169
159
|
) from None
|
|
170
160
|
|
|
171
|
-
case NotFoundError():
|
|
161
|
+
case litellm.exceptions.NotFoundError():
|
|
172
162
|
raise ModelNotFoundError(
|
|
173
163
|
FormattedLLMErrorMessage(
|
|
174
164
|
cause=f"The specified model {model_name!r} could not be found while {purpose}.",
|
|
@@ -176,7 +166,7 @@ def handle_llm_exceptions(
|
|
|
176
166
|
)
|
|
177
167
|
) from None
|
|
178
168
|
|
|
179
|
-
case PermissionDeniedError():
|
|
169
|
+
case litellm.exceptions.PermissionDeniedError():
|
|
180
170
|
raise ModelPermissionDeniedError(
|
|
181
171
|
FormattedLLMErrorMessage(
|
|
182
172
|
cause=f"Your API key was found to lack the necessary permissions to use model {model_name!r} while {purpose}.",
|
|
@@ -184,7 +174,7 @@ def handle_llm_exceptions(
|
|
|
184
174
|
)
|
|
185
175
|
) from None
|
|
186
176
|
|
|
187
|
-
case RateLimitError():
|
|
177
|
+
case litellm.exceptions.RateLimitError():
|
|
188
178
|
raise ModelRateLimitError(
|
|
189
179
|
FormattedLLMErrorMessage(
|
|
190
180
|
cause=f"You have exceeded the rate limit for model {model_name!r} while {purpose}.",
|
|
@@ -192,7 +182,7 @@ def handle_llm_exceptions(
|
|
|
192
182
|
)
|
|
193
183
|
) from None
|
|
194
184
|
|
|
195
|
-
case Timeout():
|
|
185
|
+
case litellm.exceptions.Timeout():
|
|
196
186
|
raise ModelTimeoutError(
|
|
197
187
|
FormattedLLMErrorMessage(
|
|
198
188
|
cause=f"The request to model {model_name!r} timed out while {purpose}.",
|
|
@@ -200,7 +190,7 @@ def handle_llm_exceptions(
|
|
|
200
190
|
)
|
|
201
191
|
) from None
|
|
202
192
|
|
|
203
|
-
case UnprocessableEntityError():
|
|
193
|
+
case litellm.exceptions.UnprocessableEntityError():
|
|
204
194
|
raise ModelUnprocessableEntityError(
|
|
205
195
|
FormattedLLMErrorMessage(
|
|
206
196
|
cause=f"The request to model {model_name!r} failed despite correct request format while {purpose}.",
|
|
@@ -264,7 +254,7 @@ class DownstreamLLMExceptionMessageParser:
|
|
|
264
254
|
self.model_provider_name = model_provider_name
|
|
265
255
|
self.purpose = purpose
|
|
266
256
|
|
|
267
|
-
def parse_bad_request_error(self, exception: BadRequestError) -> DataDesignerError:
|
|
257
|
+
def parse_bad_request_error(self, exception: litellm.exceptions.BadRequestError) -> DataDesignerError:
|
|
268
258
|
err_msg = FormattedLLMErrorMessage(
|
|
269
259
|
cause=f"The request for model {self.model_name!r} was found to be malformed or missing required parameters while {self.purpose}.",
|
|
270
260
|
solution="Check your request parameters and try again.",
|
|
@@ -276,7 +266,9 @@ class DownstreamLLMExceptionMessageParser:
|
|
|
276
266
|
)
|
|
277
267
|
return ModelBadRequestError(err_msg)
|
|
278
268
|
|
|
279
|
-
def parse_context_window_exceeded_error(
|
|
269
|
+
def parse_context_window_exceeded_error(
|
|
270
|
+
self, exception: litellm.exceptions.ContextWindowExceededError
|
|
271
|
+
) -> DataDesignerError:
|
|
280
272
|
cause = f"The input data for model '{self.model_name}' was found to exceed its supported context width while {self.purpose}."
|
|
281
273
|
try:
|
|
282
274
|
if "OpenAIException - This model's maximum context length is " in str(exception):
|
|
@@ -295,7 +287,7 @@ class DownstreamLLMExceptionMessageParser:
|
|
|
295
287
|
)
|
|
296
288
|
|
|
297
289
|
def parse_api_error(
|
|
298
|
-
self, exception: InternalServerError, auth_error_msg: FormattedLLMErrorMessage
|
|
290
|
+
self, exception: litellm.exceptions.InternalServerError, auth_error_msg: FormattedLLMErrorMessage
|
|
299
291
|
) -> DataDesignerError:
|
|
300
292
|
if "Error code: 403" in str(exception):
|
|
301
293
|
return ModelAuthenticationError(auth_error_msg)
|
|
@@ -6,10 +6,7 @@ from __future__ import annotations
|
|
|
6
6
|
import logging
|
|
7
7
|
from collections.abc import Callable
|
|
8
8
|
from copy import deepcopy
|
|
9
|
-
from typing import Any
|
|
10
|
-
|
|
11
|
-
from litellm.types.router import DeploymentTypedDict, LiteLLM_Params
|
|
12
|
-
from litellm.types.utils import EmbeddingResponse, ModelResponse
|
|
9
|
+
from typing import TYPE_CHECKING, Any
|
|
13
10
|
|
|
14
11
|
from data_designer.config.models import GenerationType, ModelConfig, ModelProvider
|
|
15
12
|
from data_designer.engine.model_provider import ModelProviderRegistry
|
|
@@ -23,6 +20,10 @@ from data_designer.engine.models.parsers.errors import ParserException
|
|
|
23
20
|
from data_designer.engine.models.usage import ModelUsageStats, RequestUsageStats, TokenUsageStats
|
|
24
21
|
from data_designer.engine.models.utils import prompt_to_messages, str_to_message
|
|
25
22
|
from data_designer.engine.secret_resolver import SecretResolver
|
|
23
|
+
from data_designer.lazy_heavy_imports import litellm
|
|
24
|
+
|
|
25
|
+
if TYPE_CHECKING:
|
|
26
|
+
import litellm
|
|
26
27
|
|
|
27
28
|
logger = logging.getLogger(__name__)
|
|
28
29
|
|
|
@@ -65,7 +66,9 @@ class ModelFacade:
|
|
|
65
66
|
def usage_stats(self) -> ModelUsageStats:
|
|
66
67
|
return self._usage_stats
|
|
67
68
|
|
|
68
|
-
def completion(
|
|
69
|
+
def completion(
|
|
70
|
+
self, messages: list[dict[str, str]], skip_usage_tracking: bool = False, **kwargs
|
|
71
|
+
) -> litellm.ModelResponse:
|
|
69
72
|
logger.debug(
|
|
70
73
|
f"Prompting model {self.model_name!r}...",
|
|
71
74
|
extra={"model": self.model_name, "messages": messages},
|
|
@@ -236,14 +239,14 @@ class ModelFacade:
|
|
|
236
239
|
) from exc
|
|
237
240
|
return output_obj, reasoning_trace
|
|
238
241
|
|
|
239
|
-
def _get_litellm_deployment(self, model_config: ModelConfig) -> DeploymentTypedDict:
|
|
242
|
+
def _get_litellm_deployment(self, model_config: ModelConfig) -> litellm.DeploymentTypedDict:
|
|
240
243
|
provider = self._model_provider_registry.get_provider(model_config.provider)
|
|
241
244
|
api_key = None
|
|
242
245
|
if provider.api_key:
|
|
243
246
|
api_key = self._secret_resolver.resolve(provider.api_key)
|
|
244
247
|
api_key = api_key or "not-used-but-required"
|
|
245
248
|
|
|
246
|
-
litellm_params = LiteLLM_Params(
|
|
249
|
+
litellm_params = litellm.LiteLLM_Params(
|
|
247
250
|
model=f"{provider.provider_type}/{model_config.model}",
|
|
248
251
|
api_base=provider.endpoint,
|
|
249
252
|
api_key=api_key,
|
|
@@ -253,7 +256,7 @@ class ModelFacade:
|
|
|
253
256
|
"litellm_params": litellm_params.model_dump(),
|
|
254
257
|
}
|
|
255
258
|
|
|
256
|
-
def _track_usage(self, response: ModelResponse | None) -> None:
|
|
259
|
+
def _track_usage(self, response: litellm.types.utils.ModelResponse | None) -> None:
|
|
257
260
|
if response is None:
|
|
258
261
|
self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1))
|
|
259
262
|
return
|
|
@@ -270,7 +273,7 @@ class ModelFacade:
|
|
|
270
273
|
request_usage=RequestUsageStats(successful_requests=1, failed_requests=0),
|
|
271
274
|
)
|
|
272
275
|
|
|
273
|
-
def _track_usage_from_embedding(self, response: EmbeddingResponse | None) -> None:
|
|
276
|
+
def _track_usage_from_embedding(self, response: litellm.types.utils.EmbeddingResponse | None) -> None:
|
|
274
277
|
if response is None:
|
|
275
278
|
self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1))
|
|
276
279
|
return
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
from data_designer.config.models import ModelConfig
|
|
9
|
+
from data_designer.engine.model_provider import ModelProviderRegistry
|
|
10
|
+
from data_designer.engine.secret_resolver import SecretResolver
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from data_designer.engine.models.registry import ModelRegistry
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def create_model_registry(
|
|
17
|
+
*,
|
|
18
|
+
model_configs: list[ModelConfig] | None = None,
|
|
19
|
+
secret_resolver: SecretResolver,
|
|
20
|
+
model_provider_registry: ModelProviderRegistry,
|
|
21
|
+
) -> ModelRegistry:
|
|
22
|
+
"""Factory function for creating a ModelRegistry instance.
|
|
23
|
+
|
|
24
|
+
Heavy dependencies (litellm, httpx) are deferred until this function is called.
|
|
25
|
+
This is a factory function pattern - imports inside factories are idiomatic Python
|
|
26
|
+
for lazy initialization.
|
|
27
|
+
"""
|
|
28
|
+
from data_designer.engine.models.facade import ModelFacade
|
|
29
|
+
from data_designer.engine.models.litellm_overrides import apply_litellm_patches
|
|
30
|
+
from data_designer.engine.models.registry import ModelRegistry
|
|
31
|
+
|
|
32
|
+
apply_litellm_patches()
|
|
33
|
+
|
|
34
|
+
def model_facade_factory(model_config, secret_resolver, model_provider_registry):
|
|
35
|
+
return ModelFacade(model_config, secret_resolver, model_provider_registry)
|
|
36
|
+
|
|
37
|
+
return ModelRegistry(
|
|
38
|
+
model_configs=model_configs,
|
|
39
|
+
secret_resolver=secret_resolver,
|
|
40
|
+
model_provider_registry=model_provider_registry,
|
|
41
|
+
model_facade_factory=model_facade_factory,
|
|
42
|
+
)
|
|
@@ -5,21 +5,26 @@ from __future__ import annotations
|
|
|
5
5
|
|
|
6
6
|
import random
|
|
7
7
|
import threading
|
|
8
|
+
from typing import TYPE_CHECKING
|
|
8
9
|
|
|
9
|
-
import httpx
|
|
10
|
-
import litellm
|
|
11
|
-
from litellm import RetryPolicy
|
|
12
|
-
from litellm.caching.in_memory_cache import InMemoryCache
|
|
13
|
-
from litellm.litellm_core_utils.logging_callback_manager import LoggingCallbackManager
|
|
14
|
-
from litellm.router import Router
|
|
15
10
|
from pydantic import BaseModel, Field
|
|
16
11
|
from typing_extensions import override
|
|
17
12
|
|
|
13
|
+
from data_designer.lazy_heavy_imports import httpx, litellm
|
|
18
14
|
from data_designer.logging import quiet_noisy_logger
|
|
19
15
|
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
import httpx
|
|
18
|
+
import litellm
|
|
19
|
+
|
|
20
20
|
DEFAULT_MAX_CALLBACKS = 1000
|
|
21
21
|
|
|
22
22
|
|
|
23
|
+
def _get_logging_callback_manager():
|
|
24
|
+
"""Lazy accessor for LoggingCallbackManager to avoid loading litellm at import time."""
|
|
25
|
+
return litellm.litellm_core_utils.logging_callback_manager.LoggingCallbackManager
|
|
26
|
+
|
|
27
|
+
|
|
23
28
|
class LiteLLMRouterDefaultKwargs(BaseModel):
|
|
24
29
|
## Number of seconds to wait initially after a connection
|
|
25
30
|
## failure.
|
|
@@ -35,15 +40,15 @@ class LiteLLMRouterDefaultKwargs(BaseModel):
|
|
|
35
40
|
|
|
36
41
|
## Sets the default retry policy, including the number
|
|
37
42
|
## of retries to use in particular scenarios.
|
|
38
|
-
retry_policy: RetryPolicy = Field(
|
|
39
|
-
default_factory=lambda: RetryPolicy(
|
|
43
|
+
retry_policy: litellm.RetryPolicy = Field(
|
|
44
|
+
default_factory=lambda: litellm.RetryPolicy(
|
|
40
45
|
RateLimitErrorRetries=3,
|
|
41
46
|
TimeoutErrorRetries=3,
|
|
42
47
|
)
|
|
43
48
|
)
|
|
44
49
|
|
|
45
50
|
|
|
46
|
-
class ThreadSafeCache(InMemoryCache):
|
|
51
|
+
class ThreadSafeCache(litellm.caching.in_memory_cache.InMemoryCache):
|
|
47
52
|
def __init__(self, *args, **kwargs):
|
|
48
53
|
super().__init__(*args, **kwargs)
|
|
49
54
|
|
|
@@ -78,7 +83,7 @@ class ThreadSafeCache(InMemoryCache):
|
|
|
78
83
|
super().flush_cache()
|
|
79
84
|
|
|
80
85
|
|
|
81
|
-
class CustomRouter(Router):
|
|
86
|
+
class CustomRouter(litellm.router.Router):
|
|
82
87
|
def __init__(
|
|
83
88
|
self,
|
|
84
89
|
*args,
|
|
@@ -155,7 +160,7 @@ def apply_litellm_patches():
|
|
|
155
160
|
litellm.in_memory_llm_clients_cache = ThreadSafeCache()
|
|
156
161
|
|
|
157
162
|
# Workaround for the litellm issue described in https://github.com/BerriAI/litellm/issues/9792
|
|
158
|
-
|
|
163
|
+
_get_logging_callback_manager().MAX_CALLBACKS = DEFAULT_MAX_CALLBACKS
|
|
159
164
|
|
|
160
165
|
quiet_noisy_logger("httpx")
|
|
161
166
|
quiet_noisy_logger("LiteLLM")
|
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
4
6
|
|
|
5
7
|
class ParserException(Exception):
|
|
6
8
|
"""Identifies errors resulting from generic parser errors.
|
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
4
6
|
from functools import reduce
|
|
5
7
|
|
|
6
8
|
import marko
|
|
@@ -80,13 +82,11 @@ class LLMResponseParser:
|
|
|
80
82
|
code: str
|
|
81
83
|
syntax: Optional[str] = None
|
|
82
84
|
|
|
83
|
-
|
|
84
85
|
class CodeBlockParser:
|
|
85
86
|
def __call__(self, element: _Element) -> CodeBlock:
|
|
86
87
|
# Implementation details...
|
|
87
88
|
return CodeBlock(code=element.text, syntax=element.get("class"))
|
|
88
89
|
|
|
89
|
-
|
|
90
90
|
parser = LLMResponseParser(
|
|
91
91
|
tag_parsers={
|
|
92
92
|
"pre.code": CodeBlockParser(),
|
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
4
6
|
from lxml.etree import _Element
|
|
5
7
|
|
|
6
8
|
from data_designer.engine.models.parsers.types import CodeBlock, TextBlock
|
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
4
6
|
from typing import Any, Protocol, runtime_checkable
|
|
5
7
|
|
|
6
8
|
from lxml.etree import _Element
|
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
4
6
|
import abc
|
|
5
7
|
from collections.abc import Callable
|
|
6
8
|
from typing import Generic, TypeVar
|
|
@@ -4,14 +4,17 @@
|
|
|
4
4
|
from __future__ import annotations
|
|
5
5
|
|
|
6
6
|
import logging
|
|
7
|
+
from collections.abc import Callable
|
|
8
|
+
from typing import TYPE_CHECKING
|
|
7
9
|
|
|
8
10
|
from data_designer.config.models import GenerationType, ModelConfig
|
|
9
11
|
from data_designer.engine.model_provider import ModelProvider, ModelProviderRegistry
|
|
10
|
-
from data_designer.engine.models.facade import ModelFacade
|
|
11
|
-
from data_designer.engine.models.litellm_overrides import apply_litellm_patches
|
|
12
12
|
from data_designer.engine.models.usage import ModelUsageStats, RequestUsageStats, TokenUsageStats
|
|
13
13
|
from data_designer.engine.secret_resolver import SecretResolver
|
|
14
14
|
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from data_designer.engine.models.facade import ModelFacade
|
|
17
|
+
|
|
15
18
|
logger = logging.getLogger(__name__)
|
|
16
19
|
|
|
17
20
|
|
|
@@ -22,10 +25,12 @@ class ModelRegistry:
|
|
|
22
25
|
secret_resolver: SecretResolver,
|
|
23
26
|
model_provider_registry: ModelProviderRegistry,
|
|
24
27
|
model_configs: list[ModelConfig] | None = None,
|
|
28
|
+
model_facade_factory: Callable[[ModelConfig, SecretResolver, ModelProviderRegistry], ModelFacade] | None = None,
|
|
25
29
|
):
|
|
26
30
|
self._secret_resolver = secret_resolver
|
|
27
31
|
self._model_provider_registry = model_provider_registry
|
|
28
|
-
self.
|
|
32
|
+
self._model_facade_factory = model_facade_factory
|
|
33
|
+
self._model_configs: dict[str, ModelConfig] = {}
|
|
29
34
|
self._models: dict[str, ModelFacade] = {}
|
|
30
35
|
self._set_model_configs(model_configs)
|
|
31
36
|
|
|
@@ -136,18 +141,6 @@ class ModelRegistry:
|
|
|
136
141
|
# Models are now lazily initialized in get_model() when first requested
|
|
137
142
|
|
|
138
143
|
def _get_model(self, model_config: ModelConfig) -> ModelFacade:
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
def create_model_registry(
|
|
143
|
-
*,
|
|
144
|
-
model_configs: list[ModelConfig] | None = None,
|
|
145
|
-
secret_resolver: SecretResolver,
|
|
146
|
-
model_provider_registry: ModelProviderRegistry,
|
|
147
|
-
) -> ModelRegistry:
|
|
148
|
-
apply_litellm_patches()
|
|
149
|
-
return ModelRegistry(
|
|
150
|
-
model_configs=model_configs,
|
|
151
|
-
secret_resolver=secret_resolver,
|
|
152
|
-
model_provider_registry=model_provider_registry,
|
|
153
|
-
)
|
|
144
|
+
if self._model_facade_factory is None:
|
|
145
|
+
raise RuntimeError("ModelRegistry was not initialized with a model_facade_factory")
|
|
146
|
+
return self._model_facade_factory(model_config, self._secret_resolver, self._model_provider_registry)
|
|
@@ -18,11 +18,15 @@ import platform
|
|
|
18
18
|
from dataclasses import dataclass
|
|
19
19
|
from datetime import datetime, timezone
|
|
20
20
|
from enum import Enum
|
|
21
|
-
from typing import Any, ClassVar
|
|
21
|
+
from typing import TYPE_CHECKING, Any, ClassVar
|
|
22
22
|
|
|
23
|
-
import httpx
|
|
24
23
|
from pydantic import BaseModel, Field
|
|
25
24
|
|
|
25
|
+
from data_designer.lazy_heavy_imports import httpx
|
|
26
|
+
|
|
27
|
+
if TYPE_CHECKING:
|
|
28
|
+
import httpx
|
|
29
|
+
|
|
26
30
|
TELEMETRY_ENABLED = os.getenv("NEMO_TELEMETRY_ENABLED", "true").lower() in ("1", "true", "yes")
|
|
27
31
|
CLIENT_ID = "184482118588404"
|
|
28
32
|
NEMO_TELEMETRY_VERSION = "nemo-telemetry/1.0"
|
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
4
6
|
import re
|
|
5
7
|
from collections.abc import Callable
|
|
6
8
|
from functools import partial, wraps
|
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
4
6
|
import json
|
|
5
7
|
|
|
6
8
|
from data_designer.config.utils.io_helpers import serialize_data
|
|
@@ -1,8 +1,15 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
-
from
|
|
4
|
+
from __future__ import annotations
|
|
5
5
|
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
6
7
|
|
|
7
|
-
|
|
8
|
+
from data_designer.lazy_heavy_imports import jsonschema
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
import jsonschema
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class JSONSchemaValidationError(jsonschema.ValidationError):
|
|
8
15
|
"""Alias of ValidationError to ease imports."""
|
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
4
6
|
from typing import Any, TypeVar
|
|
5
7
|
|
|
6
8
|
T_primitive = TypeVar("T_primitive", str, int, float, bool)
|
|
@@ -1,19 +1,23 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
4
6
|
import logging
|
|
5
7
|
import re
|
|
6
8
|
from copy import deepcopy
|
|
7
9
|
from decimal import ROUND_HALF_UP, Decimal
|
|
8
|
-
from typing import Any, overload
|
|
9
|
-
|
|
10
|
-
from jsonschema import Draft202012Validator, ValidationError, validators
|
|
10
|
+
from typing import TYPE_CHECKING, Any, overload
|
|
11
11
|
|
|
12
12
|
from data_designer.engine.processing.gsonschema.exceptions import JSONSchemaValidationError
|
|
13
13
|
from data_designer.engine.processing.gsonschema.schema_transformers import forbid_additional_properties
|
|
14
14
|
from data_designer.engine.processing.gsonschema.types import DataObjectT, JSONSchemaT, T_primitive
|
|
15
|
+
from data_designer.lazy_heavy_imports import jsonschema
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
import jsonschema
|
|
15
19
|
|
|
16
|
-
DEFAULT_JSONSCHEMA_VALIDATOR = Draft202012Validator
|
|
20
|
+
DEFAULT_JSONSCHEMA_VALIDATOR = jsonschema.Draft202012Validator
|
|
17
21
|
|
|
18
22
|
logger = logging.getLogger(__name__)
|
|
19
23
|
|
|
@@ -69,7 +73,7 @@ def extend_jsonschema_validator_with_pruning(validator):
|
|
|
69
73
|
Type[jsonschema.Validator]: A validator class that will
|
|
70
74
|
prune extra fields.
|
|
71
75
|
"""
|
|
72
|
-
return validators.extend(validator, {"additionalProperties": prune_additional_properties})
|
|
76
|
+
return jsonschema.validators.extend(validator, {"additionalProperties": prune_additional_properties})
|
|
73
77
|
|
|
74
78
|
|
|
75
79
|
def _get_decimal_info_from_anyof(schema: dict) -> tuple[bool, int | None]:
|
|
@@ -190,7 +194,7 @@ def validate(
|
|
|
190
194
|
|
|
191
195
|
try:
|
|
192
196
|
validator(schema).validate(final_object)
|
|
193
|
-
except ValidationError as exc:
|
|
197
|
+
except jsonschema.ValidationError as exc:
|
|
194
198
|
raise JSONSchemaValidationError(str(exc)) from exc
|
|
195
199
|
|
|
196
200
|
final_object = normalize_decimal_fields(final_object, schema)
|
|
@@ -5,13 +5,9 @@ from __future__ import annotations
|
|
|
5
5
|
|
|
6
6
|
from abc import ABC, abstractmethod
|
|
7
7
|
|
|
8
|
-
from data_designer.engine.configurable_task import ConfigurableTask,
|
|
8
|
+
from data_designer.engine.configurable_task import ConfigurableTask, DataT, TaskConfigT
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
class Processor(ConfigurableTask[TaskConfigT], ABC):
|
|
12
|
-
@staticmethod
|
|
13
|
-
@abstractmethod
|
|
14
|
-
def metadata() -> ConfigurableTaskMetadata: ...
|
|
15
|
-
|
|
16
12
|
@abstractmethod
|
|
17
13
|
def process(self, data: DataT, *, current_batch_number: int | None = None) -> DataT: ...
|
|
@@ -1,26 +1,23 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
-
import
|
|
4
|
+
from __future__ import annotations
|
|
5
5
|
|
|
6
|
-
import
|
|
6
|
+
import logging
|
|
7
|
+
from typing import TYPE_CHECKING
|
|
7
8
|
|
|
8
9
|
from data_designer.config.processors import DropColumnsProcessorConfig
|
|
9
|
-
from data_designer.engine.configurable_task import ConfigurableTaskMetadata
|
|
10
10
|
from data_designer.engine.dataset_builders.artifact_storage import BatchStage
|
|
11
11
|
from data_designer.engine.processing.processors.base import Processor
|
|
12
|
+
from data_designer.lazy_heavy_imports import pd
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
import pandas as pd
|
|
12
16
|
|
|
13
17
|
logger = logging.getLogger(__name__)
|
|
14
18
|
|
|
15
19
|
|
|
16
20
|
class DropColumnsProcessor(Processor[DropColumnsProcessorConfig]):
|
|
17
|
-
@staticmethod
|
|
18
|
-
def metadata() -> ConfigurableTaskMetadata:
|
|
19
|
-
return ConfigurableTaskMetadata(
|
|
20
|
-
name="drop_columns_processor",
|
|
21
|
-
description="Drop columns from the input dataset.",
|
|
22
|
-
)
|
|
23
|
-
|
|
24
21
|
def process(self, data: pd.DataFrame, *, current_batch_number: int | None = None) -> pd.DataFrame:
|
|
25
22
|
logger.info(f"🙈 Dropping columns: {self.config.column_names}")
|
|
26
23
|
if current_batch_number is not None: # not in preview mode
|
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
4
6
|
from data_designer.config.base import ConfigBase
|
|
5
7
|
from data_designer.config.processors import (
|
|
6
8
|
DropColumnsProcessorConfig,
|