data-designer 0.3.3__py3-none-any.whl → 0.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (176) hide show
  1. data_designer/__init__.py +2 -0
  2. data_designer/_version.py +2 -2
  3. data_designer/cli/__init__.py +2 -0
  4. data_designer/cli/commands/download.py +2 -0
  5. data_designer/cli/commands/list.py +2 -0
  6. data_designer/cli/commands/models.py +2 -0
  7. data_designer/cli/commands/providers.py +2 -0
  8. data_designer/cli/commands/reset.py +2 -0
  9. data_designer/cli/controllers/__init__.py +2 -0
  10. data_designer/cli/controllers/download_controller.py +2 -0
  11. data_designer/cli/controllers/model_controller.py +6 -1
  12. data_designer/cli/controllers/provider_controller.py +6 -1
  13. data_designer/cli/forms/__init__.py +2 -0
  14. data_designer/cli/forms/builder.py +2 -0
  15. data_designer/cli/forms/field.py +2 -0
  16. data_designer/cli/forms/form.py +2 -0
  17. data_designer/cli/forms/model_builder.py +2 -0
  18. data_designer/cli/forms/provider_builder.py +2 -0
  19. data_designer/cli/main.py +2 -0
  20. data_designer/cli/repositories/__init__.py +2 -0
  21. data_designer/cli/repositories/base.py +2 -0
  22. data_designer/cli/repositories/model_repository.py +2 -0
  23. data_designer/cli/repositories/persona_repository.py +2 -0
  24. data_designer/cli/repositories/provider_repository.py +2 -0
  25. data_designer/cli/services/__init__.py +2 -0
  26. data_designer/cli/services/download_service.py +2 -0
  27. data_designer/cli/services/model_service.py +2 -0
  28. data_designer/cli/services/provider_service.py +2 -0
  29. data_designer/cli/ui.py +2 -0
  30. data_designer/cli/utils.py +2 -0
  31. data_designer/config/analysis/column_profilers.py +2 -0
  32. data_designer/config/analysis/column_statistics.py +8 -5
  33. data_designer/config/analysis/dataset_profiler.py +9 -3
  34. data_designer/config/analysis/utils/errors.py +2 -0
  35. data_designer/config/analysis/utils/reporting.py +7 -3
  36. data_designer/config/base.py +1 -0
  37. data_designer/config/column_configs.py +77 -7
  38. data_designer/config/column_types.py +33 -36
  39. data_designer/config/dataset_builders.py +2 -0
  40. data_designer/config/dataset_metadata.py +18 -0
  41. data_designer/config/default_model_settings.py +1 -0
  42. data_designer/config/errors.py +2 -0
  43. data_designer/config/exports.py +2 -0
  44. data_designer/config/interface.py +3 -2
  45. data_designer/config/models.py +7 -2
  46. data_designer/config/preview_results.py +9 -1
  47. data_designer/config/processors.py +2 -0
  48. data_designer/config/run_config.py +19 -5
  49. data_designer/config/sampler_constraints.py +2 -0
  50. data_designer/config/sampler_params.py +7 -2
  51. data_designer/config/seed.py +2 -0
  52. data_designer/config/seed_source.py +9 -3
  53. data_designer/config/seed_source_types.py +2 -0
  54. data_designer/config/utils/constants.py +2 -0
  55. data_designer/config/utils/errors.py +2 -0
  56. data_designer/config/utils/info.py +2 -0
  57. data_designer/config/utils/io_helpers.py +8 -3
  58. data_designer/config/utils/misc.py +2 -2
  59. data_designer/config/utils/numerical_helpers.py +2 -0
  60. data_designer/config/utils/type_helpers.py +2 -0
  61. data_designer/config/utils/visualization.py +19 -11
  62. data_designer/config/validator_params.py +2 -0
  63. data_designer/engine/analysis/column_profilers/base.py +9 -8
  64. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +15 -19
  65. data_designer/engine/analysis/column_profilers/registry.py +2 -0
  66. data_designer/engine/analysis/column_statistics.py +5 -2
  67. data_designer/engine/analysis/dataset_profiler.py +12 -9
  68. data_designer/engine/analysis/errors.py +2 -0
  69. data_designer/engine/analysis/utils/column_statistics_calculations.py +7 -4
  70. data_designer/engine/analysis/utils/judge_score_processing.py +7 -3
  71. data_designer/engine/column_generators/generators/base.py +26 -14
  72. data_designer/engine/column_generators/generators/embedding.py +4 -11
  73. data_designer/engine/column_generators/generators/expression.py +7 -16
  74. data_designer/engine/column_generators/generators/llm_completion.py +13 -47
  75. data_designer/engine/column_generators/generators/samplers.py +8 -14
  76. data_designer/engine/column_generators/generators/seed_dataset.py +9 -15
  77. data_designer/engine/column_generators/generators/validation.py +9 -20
  78. data_designer/engine/column_generators/registry.py +2 -0
  79. data_designer/engine/column_generators/utils/errors.py +2 -0
  80. data_designer/engine/column_generators/utils/generator_classification.py +2 -0
  81. data_designer/engine/column_generators/utils/judge_score_factory.py +2 -0
  82. data_designer/engine/column_generators/utils/prompt_renderer.py +4 -2
  83. data_designer/engine/compiler.py +3 -6
  84. data_designer/engine/configurable_task.py +12 -13
  85. data_designer/engine/dataset_builders/artifact_storage.py +87 -8
  86. data_designer/engine/dataset_builders/column_wise_builder.py +34 -35
  87. data_designer/engine/dataset_builders/errors.py +2 -0
  88. data_designer/engine/dataset_builders/multi_column_configs.py +2 -0
  89. data_designer/engine/dataset_builders/utils/concurrency.py +13 -4
  90. data_designer/engine/dataset_builders/utils/config_compiler.py +2 -0
  91. data_designer/engine/dataset_builders/utils/dag.py +7 -2
  92. data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +35 -25
  93. data_designer/engine/dataset_builders/utils/errors.py +2 -0
  94. data_designer/engine/errors.py +2 -0
  95. data_designer/engine/model_provider.py +2 -0
  96. data_designer/engine/models/errors.py +23 -31
  97. data_designer/engine/models/facade.py +12 -9
  98. data_designer/engine/models/factory.py +42 -0
  99. data_designer/engine/models/litellm_overrides.py +16 -11
  100. data_designer/engine/models/parsers/errors.py +2 -0
  101. data_designer/engine/models/parsers/parser.py +2 -2
  102. data_designer/engine/models/parsers/postprocessors.py +1 -0
  103. data_designer/engine/models/parsers/tag_parsers.py +2 -0
  104. data_designer/engine/models/parsers/types.py +2 -0
  105. data_designer/engine/models/recipes/base.py +2 -0
  106. data_designer/engine/models/recipes/response_recipes.py +2 -0
  107. data_designer/engine/models/registry.py +11 -18
  108. data_designer/engine/models/telemetry.py +6 -2
  109. data_designer/engine/processing/ginja/ast.py +2 -0
  110. data_designer/engine/processing/ginja/environment.py +2 -0
  111. data_designer/engine/processing/ginja/exceptions.py +2 -0
  112. data_designer/engine/processing/ginja/record.py +2 -0
  113. data_designer/engine/processing/gsonschema/exceptions.py +9 -2
  114. data_designer/engine/processing/gsonschema/schema_transformers.py +2 -0
  115. data_designer/engine/processing/gsonschema/types.py +2 -0
  116. data_designer/engine/processing/gsonschema/validators.py +10 -6
  117. data_designer/engine/processing/processors/base.py +1 -5
  118. data_designer/engine/processing/processors/drop_columns.py +7 -10
  119. data_designer/engine/processing/processors/registry.py +2 -0
  120. data_designer/engine/processing/processors/schema_transform.py +7 -10
  121. data_designer/engine/processing/utils.py +7 -3
  122. data_designer/engine/registry/base.py +2 -0
  123. data_designer/engine/registry/data_designer_registry.py +2 -0
  124. data_designer/engine/registry/errors.py +2 -0
  125. data_designer/engine/resources/managed_dataset_generator.py +6 -2
  126. data_designer/engine/resources/managed_dataset_repository.py +8 -5
  127. data_designer/engine/resources/managed_storage.py +2 -0
  128. data_designer/engine/resources/resource_provider.py +20 -1
  129. data_designer/engine/resources/seed_reader.py +7 -2
  130. data_designer/engine/sampling_gen/column.py +2 -0
  131. data_designer/engine/sampling_gen/constraints.py +8 -2
  132. data_designer/engine/sampling_gen/data_sources/base.py +10 -7
  133. data_designer/engine/sampling_gen/data_sources/errors.py +2 -0
  134. data_designer/engine/sampling_gen/data_sources/sources.py +27 -22
  135. data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +2 -2
  136. data_designer/engine/sampling_gen/entities/email_address_utils.py +2 -0
  137. data_designer/engine/sampling_gen/entities/errors.py +2 -0
  138. data_designer/engine/sampling_gen/entities/national_id_utils.py +2 -0
  139. data_designer/engine/sampling_gen/entities/person.py +2 -0
  140. data_designer/engine/sampling_gen/entities/phone_number.py +8 -1
  141. data_designer/engine/sampling_gen/errors.py +2 -0
  142. data_designer/engine/sampling_gen/generator.py +5 -4
  143. data_designer/engine/sampling_gen/jinja_utils.py +7 -3
  144. data_designer/engine/sampling_gen/people_gen.py +7 -7
  145. data_designer/engine/sampling_gen/person_constants.py +2 -0
  146. data_designer/engine/sampling_gen/schema.py +5 -1
  147. data_designer/engine/sampling_gen/schema_builder.py +2 -0
  148. data_designer/engine/sampling_gen/utils.py +7 -1
  149. data_designer/engine/secret_resolver.py +2 -0
  150. data_designer/engine/validation.py +2 -2
  151. data_designer/engine/validators/__init__.py +2 -0
  152. data_designer/engine/validators/base.py +2 -0
  153. data_designer/engine/validators/local_callable.py +7 -2
  154. data_designer/engine/validators/python.py +7 -1
  155. data_designer/engine/validators/remote.py +7 -1
  156. data_designer/engine/validators/sql.py +8 -3
  157. data_designer/errors.py +2 -0
  158. data_designer/essentials/__init__.py +2 -0
  159. data_designer/interface/data_designer.py +36 -39
  160. data_designer/interface/errors.py +2 -0
  161. data_designer/interface/results.py +9 -2
  162. data_designer/lazy_heavy_imports.py +54 -0
  163. data_designer/logging.py +2 -0
  164. data_designer/plugins/__init__.py +2 -0
  165. data_designer/plugins/errors.py +2 -0
  166. data_designer/plugins/plugin.py +0 -1
  167. data_designer/plugins/registry.py +2 -0
  168. data_designer/plugins/testing/__init__.py +2 -0
  169. data_designer/plugins/testing/stubs.py +21 -43
  170. data_designer/plugins/testing/utils.py +2 -0
  171. {data_designer-0.3.3.dist-info → data_designer-0.3.5.dist-info}/METADATA +19 -4
  172. data_designer-0.3.5.dist-info/RECORD +196 -0
  173. data_designer-0.3.3.dist-info/RECORD +0 -193
  174. {data_designer-0.3.3.dist-info → data_designer-0.3.5.dist-info}/WHEEL +0 -0
  175. {data_designer-0.3.3.dist-info → data_designer-0.3.5.dist-info}/entry_points.txt +0 -0
  176. {data_designer-0.3.3.dist-info → data_designer-0.3.5.dist-info}/licenses/LICENSE +0 -0
@@ -6,25 +6,15 @@ from __future__ import annotations
6
6
  import logging
7
7
  from collections.abc import Callable
8
8
  from functools import wraps
9
- from typing import Any
10
-
11
- from litellm.exceptions import (
12
- APIConnectionError,
13
- APIError,
14
- AuthenticationError,
15
- BadRequestError,
16
- ContextWindowExceededError,
17
- InternalServerError,
18
- NotFoundError,
19
- PermissionDeniedError,
20
- RateLimitError,
21
- Timeout,
22
- UnprocessableEntityError,
23
- UnsupportedParamsError,
24
- )
9
+ from typing import TYPE_CHECKING, Any
10
+
25
11
  from pydantic import BaseModel
26
12
 
27
13
  from data_designer.engine.errors import DataDesignerError
14
+ from data_designer.lazy_heavy_imports import litellm
15
+
16
+ if TYPE_CHECKING:
17
+ import litellm
28
18
 
29
19
  logger = logging.getLogger(__name__)
30
20
 
@@ -132,10 +122,10 @@ def handle_llm_exceptions(
132
122
  err_msg_parser = DownstreamLLMExceptionMessageParser(model_name, model_provider_name, purpose)
133
123
  match exception:
134
124
  # Common errors that can come from LiteLLM
135
- case APIError():
125
+ case litellm.exceptions.APIError():
136
126
  raise err_msg_parser.parse_api_error(exception, authentication_error) from None
137
127
 
138
- case APIConnectionError():
128
+ case litellm.exceptions.APIConnectionError():
139
129
  raise ModelAPIConnectionError(
140
130
  FormattedLLMErrorMessage(
141
131
  cause=f"Connection to model {model_name!r} hosted on model provider {model_provider_name!r} failed while {purpose}.",
@@ -143,13 +133,13 @@ def handle_llm_exceptions(
143
133
  )
144
134
  ) from None
145
135
 
146
- case AuthenticationError():
136
+ case litellm.exceptions.AuthenticationError():
147
137
  raise ModelAuthenticationError(authentication_error) from None
148
138
 
149
- case ContextWindowExceededError():
139
+ case litellm.exceptions.ContextWindowExceededError():
150
140
  raise err_msg_parser.parse_context_window_exceeded_error(exception) from None
151
141
 
152
- case UnsupportedParamsError():
142
+ case litellm.exceptions.UnsupportedParamsError():
153
143
  raise ModelUnsupportedParamsError(
154
144
  FormattedLLMErrorMessage(
155
145
  cause=f"One or more of the parameters you provided were found to be unsupported by model {model_name!r} while {purpose}.",
@@ -157,10 +147,10 @@ def handle_llm_exceptions(
157
147
  )
158
148
  ) from None
159
149
 
160
- case BadRequestError():
150
+ case litellm.exceptions.BadRequestError():
161
151
  raise err_msg_parser.parse_bad_request_error(exception) from None
162
152
 
163
- case InternalServerError():
153
+ case litellm.exceptions.InternalServerError():
164
154
  raise ModelInternalServerError(
165
155
  FormattedLLMErrorMessage(
166
156
  cause=f"Model {model_name!r} is currently experiencing internal server issues while {purpose}.",
@@ -168,7 +158,7 @@ def handle_llm_exceptions(
168
158
  )
169
159
  ) from None
170
160
 
171
- case NotFoundError():
161
+ case litellm.exceptions.NotFoundError():
172
162
  raise ModelNotFoundError(
173
163
  FormattedLLMErrorMessage(
174
164
  cause=f"The specified model {model_name!r} could not be found while {purpose}.",
@@ -176,7 +166,7 @@ def handle_llm_exceptions(
176
166
  )
177
167
  ) from None
178
168
 
179
- case PermissionDeniedError():
169
+ case litellm.exceptions.PermissionDeniedError():
180
170
  raise ModelPermissionDeniedError(
181
171
  FormattedLLMErrorMessage(
182
172
  cause=f"Your API key was found to lack the necessary permissions to use model {model_name!r} while {purpose}.",
@@ -184,7 +174,7 @@ def handle_llm_exceptions(
184
174
  )
185
175
  ) from None
186
176
 
187
- case RateLimitError():
177
+ case litellm.exceptions.RateLimitError():
188
178
  raise ModelRateLimitError(
189
179
  FormattedLLMErrorMessage(
190
180
  cause=f"You have exceeded the rate limit for model {model_name!r} while {purpose}.",
@@ -192,7 +182,7 @@ def handle_llm_exceptions(
192
182
  )
193
183
  ) from None
194
184
 
195
- case Timeout():
185
+ case litellm.exceptions.Timeout():
196
186
  raise ModelTimeoutError(
197
187
  FormattedLLMErrorMessage(
198
188
  cause=f"The request to model {model_name!r} timed out while {purpose}.",
@@ -200,7 +190,7 @@ def handle_llm_exceptions(
200
190
  )
201
191
  ) from None
202
192
 
203
- case UnprocessableEntityError():
193
+ case litellm.exceptions.UnprocessableEntityError():
204
194
  raise ModelUnprocessableEntityError(
205
195
  FormattedLLMErrorMessage(
206
196
  cause=f"The request to model {model_name!r} failed despite correct request format while {purpose}.",
@@ -264,7 +254,7 @@ class DownstreamLLMExceptionMessageParser:
264
254
  self.model_provider_name = model_provider_name
265
255
  self.purpose = purpose
266
256
 
267
- def parse_bad_request_error(self, exception: BadRequestError) -> DataDesignerError:
257
+ def parse_bad_request_error(self, exception: litellm.exceptions.BadRequestError) -> DataDesignerError:
268
258
  err_msg = FormattedLLMErrorMessage(
269
259
  cause=f"The request for model {self.model_name!r} was found to be malformed or missing required parameters while {self.purpose}.",
270
260
  solution="Check your request parameters and try again.",
@@ -276,7 +266,9 @@ class DownstreamLLMExceptionMessageParser:
276
266
  )
277
267
  return ModelBadRequestError(err_msg)
278
268
 
279
- def parse_context_window_exceeded_error(self, exception: ContextWindowExceededError) -> DataDesignerError:
269
+ def parse_context_window_exceeded_error(
270
+ self, exception: litellm.exceptions.ContextWindowExceededError
271
+ ) -> DataDesignerError:
280
272
  cause = f"The input data for model '{self.model_name}' was found to exceed its supported context width while {self.purpose}."
281
273
  try:
282
274
  if "OpenAIException - This model's maximum context length is " in str(exception):
@@ -295,7 +287,7 @@ class DownstreamLLMExceptionMessageParser:
295
287
  )
296
288
 
297
289
  def parse_api_error(
298
- self, exception: InternalServerError, auth_error_msg: FormattedLLMErrorMessage
290
+ self, exception: litellm.exceptions.InternalServerError, auth_error_msg: FormattedLLMErrorMessage
299
291
  ) -> DataDesignerError:
300
292
  if "Error code: 403" in str(exception):
301
293
  return ModelAuthenticationError(auth_error_msg)
@@ -6,10 +6,7 @@ from __future__ import annotations
6
6
  import logging
7
7
  from collections.abc import Callable
8
8
  from copy import deepcopy
9
- from typing import Any
10
-
11
- from litellm.types.router import DeploymentTypedDict, LiteLLM_Params
12
- from litellm.types.utils import EmbeddingResponse, ModelResponse
9
+ from typing import TYPE_CHECKING, Any
13
10
 
14
11
  from data_designer.config.models import GenerationType, ModelConfig, ModelProvider
15
12
  from data_designer.engine.model_provider import ModelProviderRegistry
@@ -23,6 +20,10 @@ from data_designer.engine.models.parsers.errors import ParserException
23
20
  from data_designer.engine.models.usage import ModelUsageStats, RequestUsageStats, TokenUsageStats
24
21
  from data_designer.engine.models.utils import prompt_to_messages, str_to_message
25
22
  from data_designer.engine.secret_resolver import SecretResolver
23
+ from data_designer.lazy_heavy_imports import litellm
24
+
25
+ if TYPE_CHECKING:
26
+ import litellm
26
27
 
27
28
  logger = logging.getLogger(__name__)
28
29
 
@@ -65,7 +66,9 @@ class ModelFacade:
65
66
  def usage_stats(self) -> ModelUsageStats:
66
67
  return self._usage_stats
67
68
 
68
- def completion(self, messages: list[dict[str, str]], skip_usage_tracking: bool = False, **kwargs) -> ModelResponse:
69
+ def completion(
70
+ self, messages: list[dict[str, str]], skip_usage_tracking: bool = False, **kwargs
71
+ ) -> litellm.ModelResponse:
69
72
  logger.debug(
70
73
  f"Prompting model {self.model_name!r}...",
71
74
  extra={"model": self.model_name, "messages": messages},
@@ -236,14 +239,14 @@ class ModelFacade:
236
239
  ) from exc
237
240
  return output_obj, reasoning_trace
238
241
 
239
- def _get_litellm_deployment(self, model_config: ModelConfig) -> DeploymentTypedDict:
242
+ def _get_litellm_deployment(self, model_config: ModelConfig) -> litellm.DeploymentTypedDict:
240
243
  provider = self._model_provider_registry.get_provider(model_config.provider)
241
244
  api_key = None
242
245
  if provider.api_key:
243
246
  api_key = self._secret_resolver.resolve(provider.api_key)
244
247
  api_key = api_key or "not-used-but-required"
245
248
 
246
- litellm_params = LiteLLM_Params(
249
+ litellm_params = litellm.LiteLLM_Params(
247
250
  model=f"{provider.provider_type}/{model_config.model}",
248
251
  api_base=provider.endpoint,
249
252
  api_key=api_key,
@@ -253,7 +256,7 @@ class ModelFacade:
253
256
  "litellm_params": litellm_params.model_dump(),
254
257
  }
255
258
 
256
- def _track_usage(self, response: ModelResponse | None) -> None:
259
+ def _track_usage(self, response: litellm.types.utils.ModelResponse | None) -> None:
257
260
  if response is None:
258
261
  self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1))
259
262
  return
@@ -270,7 +273,7 @@ class ModelFacade:
270
273
  request_usage=RequestUsageStats(successful_requests=1, failed_requests=0),
271
274
  )
272
275
 
273
- def _track_usage_from_embedding(self, response: EmbeddingResponse | None) -> None:
276
+ def _track_usage_from_embedding(self, response: litellm.types.utils.EmbeddingResponse | None) -> None:
274
277
  if response is None:
275
278
  self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1))
276
279
  return
@@ -0,0 +1,42 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import TYPE_CHECKING
7
+
8
+ from data_designer.config.models import ModelConfig
9
+ from data_designer.engine.model_provider import ModelProviderRegistry
10
+ from data_designer.engine.secret_resolver import SecretResolver
11
+
12
+ if TYPE_CHECKING:
13
+ from data_designer.engine.models.registry import ModelRegistry
14
+
15
+
16
+ def create_model_registry(
17
+ *,
18
+ model_configs: list[ModelConfig] | None = None,
19
+ secret_resolver: SecretResolver,
20
+ model_provider_registry: ModelProviderRegistry,
21
+ ) -> ModelRegistry:
22
+ """Factory function for creating a ModelRegistry instance.
23
+
24
+ Heavy dependencies (litellm, httpx) are deferred until this function is called.
25
+ This is a factory function pattern - imports inside factories are idiomatic Python
26
+ for lazy initialization.
27
+ """
28
+ from data_designer.engine.models.facade import ModelFacade
29
+ from data_designer.engine.models.litellm_overrides import apply_litellm_patches
30
+ from data_designer.engine.models.registry import ModelRegistry
31
+
32
+ apply_litellm_patches()
33
+
34
+ def model_facade_factory(model_config, secret_resolver, model_provider_registry):
35
+ return ModelFacade(model_config, secret_resolver, model_provider_registry)
36
+
37
+ return ModelRegistry(
38
+ model_configs=model_configs,
39
+ secret_resolver=secret_resolver,
40
+ model_provider_registry=model_provider_registry,
41
+ model_facade_factory=model_facade_factory,
42
+ )
@@ -5,21 +5,26 @@ from __future__ import annotations
5
5
 
6
6
  import random
7
7
  import threading
8
+ from typing import TYPE_CHECKING
8
9
 
9
- import httpx
10
- import litellm
11
- from litellm import RetryPolicy
12
- from litellm.caching.in_memory_cache import InMemoryCache
13
- from litellm.litellm_core_utils.logging_callback_manager import LoggingCallbackManager
14
- from litellm.router import Router
15
10
  from pydantic import BaseModel, Field
16
11
  from typing_extensions import override
17
12
 
13
+ from data_designer.lazy_heavy_imports import httpx, litellm
18
14
  from data_designer.logging import quiet_noisy_logger
19
15
 
16
+ if TYPE_CHECKING:
17
+ import httpx
18
+ import litellm
19
+
20
20
  DEFAULT_MAX_CALLBACKS = 1000
21
21
 
22
22
 
23
+ def _get_logging_callback_manager():
24
+ """Lazy accessor for LoggingCallbackManager to avoid loading litellm at import time."""
25
+ return litellm.litellm_core_utils.logging_callback_manager.LoggingCallbackManager
26
+
27
+
23
28
  class LiteLLMRouterDefaultKwargs(BaseModel):
24
29
  ## Number of seconds to wait initially after a connection
25
30
  ## failure.
@@ -35,15 +40,15 @@ class LiteLLMRouterDefaultKwargs(BaseModel):
35
40
 
36
41
  ## Sets the default retry policy, including the number
37
42
  ## of retries to use in particular scenarios.
38
- retry_policy: RetryPolicy = Field(
39
- default_factory=lambda: RetryPolicy(
43
+ retry_policy: litellm.RetryPolicy = Field(
44
+ default_factory=lambda: litellm.RetryPolicy(
40
45
  RateLimitErrorRetries=3,
41
46
  TimeoutErrorRetries=3,
42
47
  )
43
48
  )
44
49
 
45
50
 
46
- class ThreadSafeCache(InMemoryCache):
51
+ class ThreadSafeCache(litellm.caching.in_memory_cache.InMemoryCache):
47
52
  def __init__(self, *args, **kwargs):
48
53
  super().__init__(*args, **kwargs)
49
54
 
@@ -78,7 +83,7 @@ class ThreadSafeCache(InMemoryCache):
78
83
  super().flush_cache()
79
84
 
80
85
 
81
- class CustomRouter(Router):
86
+ class CustomRouter(litellm.router.Router):
82
87
  def __init__(
83
88
  self,
84
89
  *args,
@@ -155,7 +160,7 @@ def apply_litellm_patches():
155
160
  litellm.in_memory_llm_clients_cache = ThreadSafeCache()
156
161
 
157
162
  # Workaround for the litellm issue described in https://github.com/BerriAI/litellm/issues/9792
158
- LoggingCallbackManager.MAX_CALLBACKS = DEFAULT_MAX_CALLBACKS
163
+ _get_logging_callback_manager().MAX_CALLBACKS = DEFAULT_MAX_CALLBACKS
159
164
 
160
165
  quiet_noisy_logger("httpx")
161
166
  quiet_noisy_logger("LiteLLM")
@@ -1,6 +1,8 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ from __future__ import annotations
5
+
4
6
 
5
7
  class ParserException(Exception):
6
8
  """Identifies errors resulting from generic parser errors.
@@ -1,6 +1,8 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ from __future__ import annotations
5
+
4
6
  from functools import reduce
5
7
 
6
8
  import marko
@@ -80,13 +82,11 @@ class LLMResponseParser:
80
82
  code: str
81
83
  syntax: Optional[str] = None
82
84
 
83
-
84
85
  class CodeBlockParser:
85
86
  def __call__(self, element: _Element) -> CodeBlock:
86
87
  # Implementation details...
87
88
  return CodeBlock(code=element.text, syntax=element.get("class"))
88
89
 
89
-
90
90
  parser = LLMResponseParser(
91
91
  tag_parsers={
92
92
  "pre.code": CodeBlockParser(),
@@ -1,6 +1,7 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ from __future__ import annotations
4
5
 
5
6
  import json_repair
6
7
  from pydantic import BaseModel, ValidationError
@@ -1,6 +1,8 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ from __future__ import annotations
5
+
4
6
  from lxml.etree import _Element
5
7
 
6
8
  from data_designer.engine.models.parsers.types import CodeBlock, TextBlock
@@ -1,6 +1,8 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ from __future__ import annotations
5
+
4
6
  from typing import Any, Protocol, runtime_checkable
5
7
 
6
8
  from lxml.etree import _Element
@@ -1,6 +1,8 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ from __future__ import annotations
5
+
4
6
  import abc
5
7
  from collections.abc import Callable
6
8
  from typing import Generic, TypeVar
@@ -1,6 +1,8 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ from __future__ import annotations
5
+
4
6
  import json
5
7
  from collections.abc import Callable
6
8
 
@@ -4,14 +4,17 @@
4
4
  from __future__ import annotations
5
5
 
6
6
  import logging
7
+ from collections.abc import Callable
8
+ from typing import TYPE_CHECKING
7
9
 
8
10
  from data_designer.config.models import GenerationType, ModelConfig
9
11
  from data_designer.engine.model_provider import ModelProvider, ModelProviderRegistry
10
- from data_designer.engine.models.facade import ModelFacade
11
- from data_designer.engine.models.litellm_overrides import apply_litellm_patches
12
12
  from data_designer.engine.models.usage import ModelUsageStats, RequestUsageStats, TokenUsageStats
13
13
  from data_designer.engine.secret_resolver import SecretResolver
14
14
 
15
+ if TYPE_CHECKING:
16
+ from data_designer.engine.models.facade import ModelFacade
17
+
15
18
  logger = logging.getLogger(__name__)
16
19
 
17
20
 
@@ -22,10 +25,12 @@ class ModelRegistry:
22
25
  secret_resolver: SecretResolver,
23
26
  model_provider_registry: ModelProviderRegistry,
24
27
  model_configs: list[ModelConfig] | None = None,
28
+ model_facade_factory: Callable[[ModelConfig, SecretResolver, ModelProviderRegistry], ModelFacade] | None = None,
25
29
  ):
26
30
  self._secret_resolver = secret_resolver
27
31
  self._model_provider_registry = model_provider_registry
28
- self._model_configs = {}
32
+ self._model_facade_factory = model_facade_factory
33
+ self._model_configs: dict[str, ModelConfig] = {}
29
34
  self._models: dict[str, ModelFacade] = {}
30
35
  self._set_model_configs(model_configs)
31
36
 
@@ -136,18 +141,6 @@ class ModelRegistry:
136
141
  # Models are now lazily initialized in get_model() when first requested
137
142
 
138
143
  def _get_model(self, model_config: ModelConfig) -> ModelFacade:
139
- return ModelFacade(model_config, self._secret_resolver, self._model_provider_registry)
140
-
141
-
142
- def create_model_registry(
143
- *,
144
- model_configs: list[ModelConfig] | None = None,
145
- secret_resolver: SecretResolver,
146
- model_provider_registry: ModelProviderRegistry,
147
- ) -> ModelRegistry:
148
- apply_litellm_patches()
149
- return ModelRegistry(
150
- model_configs=model_configs,
151
- secret_resolver=secret_resolver,
152
- model_provider_registry=model_provider_registry,
153
- )
144
+ if self._model_facade_factory is None:
145
+ raise RuntimeError("ModelRegistry was not initialized with a model_facade_factory")
146
+ return self._model_facade_factory(model_config, self._secret_resolver, self._model_provider_registry)
@@ -18,11 +18,15 @@ import platform
18
18
  from dataclasses import dataclass
19
19
  from datetime import datetime, timezone
20
20
  from enum import Enum
21
- from typing import Any, ClassVar
21
+ from typing import TYPE_CHECKING, Any, ClassVar
22
22
 
23
- import httpx
24
23
  from pydantic import BaseModel, Field
25
24
 
25
+ from data_designer.lazy_heavy_imports import httpx
26
+
27
+ if TYPE_CHECKING:
28
+ import httpx
29
+
26
30
  TELEMETRY_ENABLED = os.getenv("NEMO_TELEMETRY_ENABLED", "true").lower() in ("1", "true", "yes")
27
31
  CLIENT_ID = "184482118588404"
28
32
  NEMO_TELEMETRY_VERSION = "nemo-telemetry/1.0"
@@ -1,6 +1,8 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ from __future__ import annotations
5
+
4
6
  from collections import deque
5
7
 
6
8
  from jinja2 import nodes as j_nodes
@@ -1,6 +1,8 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ from __future__ import annotations
5
+
4
6
  import re
5
7
  from collections.abc import Callable
6
8
  from functools import partial, wraps
@@ -1,6 +1,8 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ from __future__ import annotations
5
+
4
6
  import re
5
7
 
6
8
  from jinja2 import TemplateAssertionError
@@ -1,6 +1,8 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ from __future__ import annotations
5
+
4
6
  import json
5
7
 
6
8
  from data_designer.config.utils.io_helpers import serialize_data
@@ -1,8 +1,15 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
- from jsonschema import ValidationError
4
+ from __future__ import annotations
5
5
 
6
+ from typing import TYPE_CHECKING
6
7
 
7
- class JSONSchemaValidationError(ValidationError):
8
+ from data_designer.lazy_heavy_imports import jsonschema
9
+
10
+ if TYPE_CHECKING:
11
+ import jsonschema
12
+
13
+
14
+ class JSONSchemaValidationError(jsonschema.ValidationError):
8
15
  """Alias of ValidationError to ease imports."""
@@ -1,6 +1,8 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ from __future__ import annotations
5
+
4
6
  from copy import deepcopy
5
7
  from typing import Any
6
8
 
@@ -1,6 +1,8 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ from __future__ import annotations
5
+
4
6
  from typing import Any, TypeVar
5
7
 
6
8
  T_primitive = TypeVar("T_primitive", str, int, float, bool)
@@ -1,19 +1,23 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ from __future__ import annotations
5
+
4
6
  import logging
5
7
  import re
6
8
  from copy import deepcopy
7
9
  from decimal import ROUND_HALF_UP, Decimal
8
- from typing import Any, overload
9
-
10
- from jsonschema import Draft202012Validator, ValidationError, validators
10
+ from typing import TYPE_CHECKING, Any, overload
11
11
 
12
12
  from data_designer.engine.processing.gsonschema.exceptions import JSONSchemaValidationError
13
13
  from data_designer.engine.processing.gsonschema.schema_transformers import forbid_additional_properties
14
14
  from data_designer.engine.processing.gsonschema.types import DataObjectT, JSONSchemaT, T_primitive
15
+ from data_designer.lazy_heavy_imports import jsonschema
16
+
17
+ if TYPE_CHECKING:
18
+ import jsonschema
15
19
 
16
- DEFAULT_JSONSCHEMA_VALIDATOR = Draft202012Validator
20
+ DEFAULT_JSONSCHEMA_VALIDATOR = jsonschema.Draft202012Validator
17
21
 
18
22
  logger = logging.getLogger(__name__)
19
23
 
@@ -69,7 +73,7 @@ def extend_jsonschema_validator_with_pruning(validator):
69
73
  Type[jsonschema.Validator]: A validator class that will
70
74
  prune extra fields.
71
75
  """
72
- return validators.extend(validator, {"additionalProperties": prune_additional_properties})
76
+ return jsonschema.validators.extend(validator, {"additionalProperties": prune_additional_properties})
73
77
 
74
78
 
75
79
  def _get_decimal_info_from_anyof(schema: dict) -> tuple[bool, int | None]:
@@ -190,7 +194,7 @@ def validate(
190
194
 
191
195
  try:
192
196
  validator(schema).validate(final_object)
193
- except ValidationError as exc:
197
+ except jsonschema.ValidationError as exc:
194
198
  raise JSONSchemaValidationError(str(exc)) from exc
195
199
 
196
200
  final_object = normalize_decimal_fields(final_object, schema)
@@ -5,13 +5,9 @@ from __future__ import annotations
5
5
 
6
6
  from abc import ABC, abstractmethod
7
7
 
8
- from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata, DataT, TaskConfigT
8
+ from data_designer.engine.configurable_task import ConfigurableTask, DataT, TaskConfigT
9
9
 
10
10
 
11
11
  class Processor(ConfigurableTask[TaskConfigT], ABC):
12
- @staticmethod
13
- @abstractmethod
14
- def metadata() -> ConfigurableTaskMetadata: ...
15
-
16
12
  @abstractmethod
17
13
  def process(self, data: DataT, *, current_batch_number: int | None = None) -> DataT: ...
@@ -1,26 +1,23 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
- import logging
4
+ from __future__ import annotations
5
5
 
6
- import pandas as pd
6
+ import logging
7
+ from typing import TYPE_CHECKING
7
8
 
8
9
  from data_designer.config.processors import DropColumnsProcessorConfig
9
- from data_designer.engine.configurable_task import ConfigurableTaskMetadata
10
10
  from data_designer.engine.dataset_builders.artifact_storage import BatchStage
11
11
  from data_designer.engine.processing.processors.base import Processor
12
+ from data_designer.lazy_heavy_imports import pd
13
+
14
+ if TYPE_CHECKING:
15
+ import pandas as pd
12
16
 
13
17
  logger = logging.getLogger(__name__)
14
18
 
15
19
 
16
20
  class DropColumnsProcessor(Processor[DropColumnsProcessorConfig]):
17
- @staticmethod
18
- def metadata() -> ConfigurableTaskMetadata:
19
- return ConfigurableTaskMetadata(
20
- name="drop_columns_processor",
21
- description="Drop columns from the input dataset.",
22
- )
23
-
24
21
  def process(self, data: pd.DataFrame, *, current_batch_number: int | None = None) -> pd.DataFrame:
25
22
  logger.info(f"🙈 Dropping columns: {self.config.column_names}")
26
23
  if current_batch_number is not None: # not in preview mode
@@ -1,6 +1,8 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ from __future__ import annotations
5
+
4
6
  from data_designer.config.base import ConfigBase
5
7
  from data_designer.config.processors import (
6
8
  DropColumnsProcessorConfig,