data-designer 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177)
  1. data_designer/__init__.py +15 -0
  2. data_designer/_version.py +34 -0
  3. data_designer/cli/README.md +236 -0
  4. data_designer/cli/__init__.py +6 -0
  5. data_designer/cli/commands/__init__.py +2 -0
  6. data_designer/cli/commands/list.py +130 -0
  7. data_designer/cli/commands/models.py +10 -0
  8. data_designer/cli/commands/providers.py +11 -0
  9. data_designer/cli/commands/reset.py +100 -0
  10. data_designer/cli/controllers/__init__.py +7 -0
  11. data_designer/cli/controllers/model_controller.py +246 -0
  12. data_designer/cli/controllers/provider_controller.py +317 -0
  13. data_designer/cli/forms/__init__.py +20 -0
  14. data_designer/cli/forms/builder.py +51 -0
  15. data_designer/cli/forms/field.py +180 -0
  16. data_designer/cli/forms/form.py +59 -0
  17. data_designer/cli/forms/model_builder.py +125 -0
  18. data_designer/cli/forms/provider_builder.py +76 -0
  19. data_designer/cli/main.py +44 -0
  20. data_designer/cli/repositories/__init__.py +8 -0
  21. data_designer/cli/repositories/base.py +39 -0
  22. data_designer/cli/repositories/model_repository.py +42 -0
  23. data_designer/cli/repositories/provider_repository.py +43 -0
  24. data_designer/cli/services/__init__.py +7 -0
  25. data_designer/cli/services/model_service.py +116 -0
  26. data_designer/cli/services/provider_service.py +111 -0
  27. data_designer/cli/ui.py +448 -0
  28. data_designer/cli/utils.py +47 -0
  29. data_designer/config/__init__.py +2 -0
  30. data_designer/config/analysis/column_profilers.py +89 -0
  31. data_designer/config/analysis/column_statistics.py +274 -0
  32. data_designer/config/analysis/dataset_profiler.py +60 -0
  33. data_designer/config/analysis/utils/errors.py +8 -0
  34. data_designer/config/analysis/utils/reporting.py +188 -0
  35. data_designer/config/base.py +68 -0
  36. data_designer/config/column_configs.py +354 -0
  37. data_designer/config/column_types.py +168 -0
  38. data_designer/config/config_builder.py +660 -0
  39. data_designer/config/data_designer_config.py +40 -0
  40. data_designer/config/dataset_builders.py +11 -0
  41. data_designer/config/datastore.py +151 -0
  42. data_designer/config/default_model_settings.py +123 -0
  43. data_designer/config/errors.py +19 -0
  44. data_designer/config/interface.py +54 -0
  45. data_designer/config/models.py +231 -0
  46. data_designer/config/preview_results.py +32 -0
  47. data_designer/config/processors.py +41 -0
  48. data_designer/config/sampler_constraints.py +51 -0
  49. data_designer/config/sampler_params.py +604 -0
  50. data_designer/config/seed.py +145 -0
  51. data_designer/config/utils/code_lang.py +83 -0
  52. data_designer/config/utils/constants.py +313 -0
  53. data_designer/config/utils/errors.py +19 -0
  54. data_designer/config/utils/info.py +88 -0
  55. data_designer/config/utils/io_helpers.py +273 -0
  56. data_designer/config/utils/misc.py +81 -0
  57. data_designer/config/utils/numerical_helpers.py +28 -0
  58. data_designer/config/utils/type_helpers.py +100 -0
  59. data_designer/config/utils/validation.py +336 -0
  60. data_designer/config/utils/visualization.py +427 -0
  61. data_designer/config/validator_params.py +96 -0
  62. data_designer/engine/__init__.py +2 -0
  63. data_designer/engine/analysis/column_profilers/base.py +55 -0
  64. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +160 -0
  65. data_designer/engine/analysis/column_profilers/registry.py +20 -0
  66. data_designer/engine/analysis/column_statistics.py +142 -0
  67. data_designer/engine/analysis/dataset_profiler.py +125 -0
  68. data_designer/engine/analysis/errors.py +7 -0
  69. data_designer/engine/analysis/utils/column_statistics_calculations.py +209 -0
  70. data_designer/engine/analysis/utils/judge_score_processing.py +128 -0
  71. data_designer/engine/column_generators/__init__.py +2 -0
  72. data_designer/engine/column_generators/generators/__init__.py +2 -0
  73. data_designer/engine/column_generators/generators/base.py +61 -0
  74. data_designer/engine/column_generators/generators/expression.py +63 -0
  75. data_designer/engine/column_generators/generators/llm_generators.py +172 -0
  76. data_designer/engine/column_generators/generators/samplers.py +75 -0
  77. data_designer/engine/column_generators/generators/seed_dataset.py +149 -0
  78. data_designer/engine/column_generators/generators/validation.py +147 -0
  79. data_designer/engine/column_generators/registry.py +56 -0
  80. data_designer/engine/column_generators/utils/errors.py +13 -0
  81. data_designer/engine/column_generators/utils/judge_score_factory.py +57 -0
  82. data_designer/engine/column_generators/utils/prompt_renderer.py +98 -0
  83. data_designer/engine/configurable_task.py +82 -0
  84. data_designer/engine/dataset_builders/artifact_storage.py +181 -0
  85. data_designer/engine/dataset_builders/column_wise_builder.py +287 -0
  86. data_designer/engine/dataset_builders/errors.py +13 -0
  87. data_designer/engine/dataset_builders/multi_column_configs.py +44 -0
  88. data_designer/engine/dataset_builders/utils/__init__.py +2 -0
  89. data_designer/engine/dataset_builders/utils/concurrency.py +184 -0
  90. data_designer/engine/dataset_builders/utils/config_compiler.py +60 -0
  91. data_designer/engine/dataset_builders/utils/dag.py +56 -0
  92. data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +190 -0
  93. data_designer/engine/dataset_builders/utils/errors.py +13 -0
  94. data_designer/engine/errors.py +49 -0
  95. data_designer/engine/model_provider.py +75 -0
  96. data_designer/engine/models/__init__.py +2 -0
  97. data_designer/engine/models/errors.py +308 -0
  98. data_designer/engine/models/facade.py +225 -0
  99. data_designer/engine/models/litellm_overrides.py +162 -0
  100. data_designer/engine/models/parsers/__init__.py +2 -0
  101. data_designer/engine/models/parsers/errors.py +34 -0
  102. data_designer/engine/models/parsers/parser.py +236 -0
  103. data_designer/engine/models/parsers/postprocessors.py +93 -0
  104. data_designer/engine/models/parsers/tag_parsers.py +60 -0
  105. data_designer/engine/models/parsers/types.py +82 -0
  106. data_designer/engine/models/recipes/base.py +79 -0
  107. data_designer/engine/models/recipes/response_recipes.py +291 -0
  108. data_designer/engine/models/registry.py +118 -0
  109. data_designer/engine/models/usage.py +75 -0
  110. data_designer/engine/models/utils.py +38 -0
  111. data_designer/engine/processing/ginja/__init__.py +2 -0
  112. data_designer/engine/processing/ginja/ast.py +64 -0
  113. data_designer/engine/processing/ginja/environment.py +461 -0
  114. data_designer/engine/processing/ginja/exceptions.py +54 -0
  115. data_designer/engine/processing/ginja/record.py +30 -0
  116. data_designer/engine/processing/gsonschema/__init__.py +2 -0
  117. data_designer/engine/processing/gsonschema/exceptions.py +8 -0
  118. data_designer/engine/processing/gsonschema/schema_transformers.py +81 -0
  119. data_designer/engine/processing/gsonschema/types.py +8 -0
  120. data_designer/engine/processing/gsonschema/validators.py +143 -0
  121. data_designer/engine/processing/processors/base.py +15 -0
  122. data_designer/engine/processing/processors/drop_columns.py +46 -0
  123. data_designer/engine/processing/processors/registry.py +20 -0
  124. data_designer/engine/processing/utils.py +120 -0
  125. data_designer/engine/registry/base.py +97 -0
  126. data_designer/engine/registry/data_designer_registry.py +37 -0
  127. data_designer/engine/registry/errors.py +10 -0
  128. data_designer/engine/resources/managed_dataset_generator.py +35 -0
  129. data_designer/engine/resources/managed_dataset_repository.py +194 -0
  130. data_designer/engine/resources/managed_storage.py +63 -0
  131. data_designer/engine/resources/resource_provider.py +46 -0
  132. data_designer/engine/resources/seed_dataset_data_store.py +66 -0
  133. data_designer/engine/sampling_gen/column.py +89 -0
  134. data_designer/engine/sampling_gen/constraints.py +95 -0
  135. data_designer/engine/sampling_gen/data_sources/base.py +214 -0
  136. data_designer/engine/sampling_gen/data_sources/errors.py +10 -0
  137. data_designer/engine/sampling_gen/data_sources/sources.py +342 -0
  138. data_designer/engine/sampling_gen/entities/__init__.py +2 -0
  139. data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
  140. data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +64 -0
  141. data_designer/engine/sampling_gen/entities/email_address_utils.py +169 -0
  142. data_designer/engine/sampling_gen/entities/errors.py +8 -0
  143. data_designer/engine/sampling_gen/entities/national_id_utils.py +100 -0
  144. data_designer/engine/sampling_gen/entities/person.py +142 -0
  145. data_designer/engine/sampling_gen/entities/phone_number.py +122 -0
  146. data_designer/engine/sampling_gen/errors.py +24 -0
  147. data_designer/engine/sampling_gen/generator.py +121 -0
  148. data_designer/engine/sampling_gen/jinja_utils.py +60 -0
  149. data_designer/engine/sampling_gen/people_gen.py +203 -0
  150. data_designer/engine/sampling_gen/person_constants.py +54 -0
  151. data_designer/engine/sampling_gen/schema.py +143 -0
  152. data_designer/engine/sampling_gen/schema_builder.py +59 -0
  153. data_designer/engine/sampling_gen/utils.py +40 -0
  154. data_designer/engine/secret_resolver.py +80 -0
  155. data_designer/engine/validators/__init__.py +17 -0
  156. data_designer/engine/validators/base.py +36 -0
  157. data_designer/engine/validators/local_callable.py +34 -0
  158. data_designer/engine/validators/python.py +245 -0
  159. data_designer/engine/validators/remote.py +83 -0
  160. data_designer/engine/validators/sql.py +60 -0
  161. data_designer/errors.py +5 -0
  162. data_designer/essentials/__init__.py +137 -0
  163. data_designer/interface/__init__.py +2 -0
  164. data_designer/interface/data_designer.py +351 -0
  165. data_designer/interface/errors.py +16 -0
  166. data_designer/interface/results.py +55 -0
  167. data_designer/logging.py +161 -0
  168. data_designer/plugin_manager.py +83 -0
  169. data_designer/plugins/__init__.py +6 -0
  170. data_designer/plugins/errors.py +10 -0
  171. data_designer/plugins/plugin.py +69 -0
  172. data_designer/plugins/registry.py +86 -0
  173. data_designer-0.1.0.dist-info/METADATA +173 -0
  174. data_designer-0.1.0.dist-info/RECORD +177 -0
  175. data_designer-0.1.0.dist-info/WHEEL +4 -0
  176. data_designer-0.1.0.dist-info/entry_points.txt +2 -0
  177. data_designer-0.1.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,225 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from __future__ import annotations
+
+ from collections.abc import Callable
+ from copy import deepcopy
+ import logging
+ from typing import Any
+
+ from litellm.types.router import DeploymentTypedDict, LiteLLM_Params
+ from litellm.types.utils import ModelResponse
+
+ from data_designer.config.models import ModelConfig, ModelProvider
+ from data_designer.engine.model_provider import ModelProviderRegistry
+ from data_designer.engine.models.errors import (
+     GenerationValidationFailureError,
+     catch_llm_exceptions,
+     get_exception_primary_cause,
+ )
+ from data_designer.engine.models.litellm_overrides import CustomRouter, LiteLLMRouterDefaultKwargs
+ from data_designer.engine.models.parsers.errors import ParserException
+ from data_designer.engine.models.usage import ModelUsageStats, RequestUsageStats, TokenUsageStats
+ from data_designer.engine.models.utils import prompt_to_messages, str_to_message
+ from data_designer.engine.secret_resolver import SecretResolver
+
+ logger = logging.getLogger(__name__)
+
+
+ class ModelFacade:
+     def __init__(
+         self,
+         model_config: ModelConfig,
+         secret_resolver: SecretResolver,
+         model_provider_registry: ModelProviderRegistry,
+     ):
+         self._model_config = model_config
+         self._secret_resolver = secret_resolver
+         self._model_provider_registry = model_provider_registry
+         self._litellm_deployment = self._get_litellm_deployment(model_config)
+         self._router = CustomRouter([self._litellm_deployment], **LiteLLMRouterDefaultKwargs().model_dump())
+         self._usage_stats = ModelUsageStats()
+
+     @property
+     def model_name(self) -> str:
+         return self._model_config.model
+
+     @property
+     def model_provider(self) -> ModelProvider:
+         return self._model_provider_registry.get_provider(self._model_config.provider)
+
+     @property
+     def model_provider_name(self) -> str:
+         return self.model_provider.name
+
+     @property
+     def model_alias(self) -> str:
+         return self._model_config.alias
+
+     @property
+     def usage_stats(self) -> ModelUsageStats:
+         return self._usage_stats
+
+     def completion(self, messages: list[dict[str, str]], skip_usage_tracking: bool = False, **kwargs) -> ModelResponse:
+         logger.debug(
+             f"Prompting model {self.model_name!r}...",
+             extra={"model": self.model_name, "messages": messages, "sensitive": True},
+         )
+         response = None
+         if self.model_provider.extra_body:
+             kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body}
+         try:
+             response = self._router.completion(self.model_name, messages, **kwargs)
+             logger.debug(
+                 f"Received completion from model {self.model_name!r}",
+                 extra={
+                     "model": self.model_name,
+                     "response": response,
+                     "text": response.choices[0].message.content,
+                     "usage": self._usage_stats.model_dump(),
+                 },
+             )
+             return response
+         except Exception as e:
+             raise e
+         finally:
+             if not skip_usage_tracking:
+                 self._track_usage(response)
+
+     @catch_llm_exceptions
+     def generate(
+         self,
+         prompt: str,
+         *,
+         parser: Callable[[str], Any],
+         system_prompt: str | None = None,
+         multi_modal_context: list[dict[str, Any]] | None = None,
+         max_correction_steps: int = 0,
+         max_conversation_restarts: int = 0,
+         skip_usage_tracking: bool = False,
+         purpose: str | None = None,
+         **kwargs,
+     ) -> tuple[Any, str | None]:
+         """Generate a parsed output with correction steps.
+
+         This generation call attempts to produce an output that is
+         valid according to the specified parser, where "valid" means
+         the parser can process the LLM response without raising
+         an exception.
+
+         `ParserException`s are routed back to the LLM as new rounds in the
+         conversation: the LLM is shown its earlier response along with a
+         "user" turn containing the exception message (not the traceback).
+         This continues for at most the number of rounds specified by
+         `max_correction_steps`.
+
+         Args:
+             prompt (str): Task prompt.
+             system_prompt (str, optional): Optional system instructions. If not specified,
+                 no system message is provided and the model should use its default system
+                 prompt.
+             parser (func(str) -> Any): A function applied to the LLM response which processes
+                 an LLM response into some output object.
+             max_correction_steps (int): Maximum number of correction rounds permitted
+                 within a single conversation. Note that many rounds can lead to increasing
+                 context size without necessarily improving performance -- small language
+                 models can enter repeated cycles which will not be solved with more steps.
+                 Default: `0` (no correction).
+             max_conversation_restarts (int): Maximum number of full conversation restarts permitted
+                 if generation fails. Default: `0` (no restarts).
+             skip_usage_tracking (bool): Whether to skip usage tracking. Default: `False`.
+             purpose (str): The purpose of the model usage, shown as context in the error message.
+                 It is expected to be used by the @catch_llm_exceptions decorator.
+             **kwargs: Additional arguments to pass to the model.
+
+         Raises:
+             GenerationValidationFailureError: If the maximum number of restarts or
+                 correction steps is reached and the last response still fails
+                 generation validation.
+         """
+         output_obj = None
+         curr_num_correction_steps = 0
+         curr_num_restarts = 0
+         curr_generation_attempt = 0
+         max_generation_attempts = (max_correction_steps + 1) * (max_conversation_restarts + 1)
+
+         starting_messages = prompt_to_messages(
+             user_prompt=prompt, system_prompt=system_prompt, multi_modal_context=multi_modal_context
+         )
+         messages = deepcopy(starting_messages)
+
+         while True:
+             curr_generation_attempt += 1
+             logger.debug(
+                 f"Starting generation attempt {curr_generation_attempt} of {max_generation_attempts} attempts."
+             )
+
+             completion_response = self.completion(messages, skip_usage_tracking=skip_usage_tracking, **kwargs)
+             response = completion_response.choices[0].message.content or ""
+             reasoning_trace = getattr(completion_response.choices[0].message, "reasoning_content", None)
+
+             if reasoning_trace:
+                 ## There are generally some extra newlines with how these get parsed.
+                 response = response.strip()
+                 reasoning_trace = reasoning_trace.strip()
+
+             curr_num_correction_steps += 1
+
+             try:
+                 output_obj = parser(response)  # type: ignore - if not a string will cause a ParserException below
+                 break
+             except ParserException as exc:
+                 if max_correction_steps == 0 and max_conversation_restarts == 0:
+                     raise GenerationValidationFailureError(
+                         "Unsuccessful generation attempt. No retries were attempted."
+                     ) from exc
+                 if curr_num_correction_steps <= max_correction_steps:
+                     ## Add turns to loop-back errors for correction
+                     messages += [
+                         str_to_message(content=response, role="assistant"),
+                         str_to_message(content=str(get_exception_primary_cause(exc)), role="user"),
+                     ]
+                 elif curr_num_restarts < max_conversation_restarts:
+                     curr_num_correction_steps = 0
+                     curr_num_restarts += 1
+                     messages = deepcopy(starting_messages)
+                 else:
+                     raise GenerationValidationFailureError(
+                         f"Unsuccessful generation attempt despite {max_generation_attempts} attempts."
+                     ) from exc
+         return output_obj, reasoning_trace
+
+     def _get_litellm_deployment(self, model_config: ModelConfig) -> DeploymentTypedDict:
+         provider = self._model_provider_registry.get_provider(model_config.provider)
+         api_key = None
+         if provider.api_key:
+             api_key = self._secret_resolver.resolve(provider.api_key)
+         api_key = api_key or "not-used-but-required"
+
+         litellm_params = LiteLLM_Params(
+             model=f"{provider.provider_type}/{model_config.model}",
+             api_base=provider.endpoint,
+             api_key=api_key,
+         )
+         return {
+             "model_name": model_config.model,
+             "litellm_params": litellm_params.model_dump(),
+         }
+
+     def _track_usage(self, response: ModelResponse | None) -> None:
+         if response is None:
+             self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1))
+             return
+         if (
+             response.usage is not None
+             and response.usage.prompt_tokens is not None
+             and response.usage.completion_tokens is not None
+         ):
+             self._usage_stats.extend(
+                 token_usage=TokenUsageStats(
+                     prompt_tokens=response.usage.prompt_tokens,
+                     completion_tokens=response.usage.completion_tokens,
+                 ),
+                 request_usage=RequestUsageStats(successful_requests=1, failed_requests=0),
+             )
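
The correction loop in `ModelFacade.generate` above caps the total number of model calls at `(max_correction_steps + 1) * (max_conversation_restarts + 1)`. Below is a minimal standalone sketch of that accounting; the stub model, the plain `ValueError`, and the helper names are illustrative stand-ins, not part of the package API (the real method calls litellm and raises `ParserException` / `GenerationValidationFailureError`).

```python
from copy import deepcopy


def generate_with_corrections(
    call_model, parser, prompt, *, max_correction_steps=0, max_conversation_restarts=0
):
    # Mirrors the attempt accounting in ModelFacade.generate (stand-in names only).
    max_attempts = (max_correction_steps + 1) * (max_conversation_restarts + 1)
    starting_messages = [{"role": "user", "content": prompt}]
    messages = deepcopy(starting_messages)
    corrections = restarts = 0

    while True:
        response = call_model(messages)
        corrections += 1
        try:
            return parser(response)
        except ValueError as exc:  # stands in for ParserException
            if corrections <= max_correction_steps:
                # Loop the parse error back to the "model" as a new user turn.
                messages += [
                    {"role": "assistant", "content": response},
                    {"role": "user", "content": str(exc)},
                ]
            elif restarts < max_conversation_restarts:
                corrections, restarts = 0, restarts + 1
                messages = deepcopy(starting_messages)
            else:
                raise RuntimeError(f"Failed after {max_attempts} attempts") from exc


# A stub "model" that only answers correctly on its third call.
replies = iter(["oops", "still wrong", "42"])
result = generate_with_corrections(
    lambda messages: next(replies),
    int,  # parser: raises ValueError until the reply is an integer string
    "What is 6 * 7?",
    max_correction_steps=1,
    max_conversation_restarts=1,
)
print(result)  # 42
```

With `max_correction_steps=1` and `max_conversation_restarts=1`, the sketch tolerates two bad responses (one correction turn, then one fresh restart) before succeeding on the third call.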
@@ -0,0 +1,162 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from __future__ import annotations
+
+ import random
+ import threading
+ from typing import Optional, Union
+
+ import httpx
+ import litellm
+ from litellm import RetryPolicy
+ from litellm.caching.in_memory_cache import InMemoryCache
+ from litellm.router import Router
+ from pydantic import BaseModel, Field
+ from typing_extensions import override
+
+ from data_designer.logging import quiet_noisy_logger
+
+ DEFAULT_MAX_CALLBACKS = 1000
+
+
+ class LiteLLMRouterDefaultKwargs(BaseModel):
+     ## Number of seconds to wait initially after a connection
+     ## failure.
+     initial_retry_after_s: float = 2.0
+
+     ## Jitter percentage added during exponential backoff to
+     ## smooth repeated retries over time.
+     jitter_pct: float = 0.2
+
+     ## Maximum number of seconds to wait for an API request
+     ## before letting it die. Will trigger a retry.
+     timeout: float = 60.0
+
+     ## Sets the default retry policy, including the number
+     ## of retries to use in particular scenarios.
+     retry_policy: RetryPolicy = Field(
+         default_factory=lambda: RetryPolicy(
+             RateLimitErrorRetries=3,
+             TimeoutErrorRetries=3,
+         )
+     )
+
+
+ class ThreadSafeCache(InMemoryCache):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+         self._lock = threading.RLock()
+
+     def get_cache(self, key, **kwargs):
+         with self._lock:
+             return super().get_cache(key, **kwargs)
+
+     def set_cache(self, key, value, **kwargs):
+         with self._lock:
+             super().set_cache(key, value, **kwargs)
+
+     def batch_get_cache(self, keys: list, **kwargs):
+         with self._lock:
+             return super().batch_get_cache(keys, **kwargs)
+
+     def delete_cache(self, key):
+         with self._lock:
+             super().delete_cache(key)
+
+     def evict_cache(self):
+         with self._lock:
+             super().evict_cache()
+
+     def increment_cache(self, key, value: int, **kwargs) -> int:
+         with self._lock:
+             return super().increment_cache(key, value, **kwargs)
+
+     def flush_cache(self):
+         with self._lock:
+             super().flush_cache()
+
+
+ class CustomRouter(Router):
+     def __init__(
+         self,
+         *args,
+         initial_retry_after_s: float,
+         jitter_pct: float,
+         **kwargs,
+     ):
+         super().__init__(*args, **kwargs)
+         self._initial_retry_after_s = initial_retry_after_s
+         self._jitter_pct = jitter_pct
+
+     def _extract_retry_delay_from_headers(self, e: Exception) -> Optional[Union[int, float]]:
+         """
+         Most of this code logic was extracted directly from the parent
+         `Router`'s `_time_to_sleep_before_retry` function. Our override
+         of that method below should only affect requests where the server
+         didn't explicitly return a desired retry delay. If the server did
+         return this info, we'll simply use the retry value returned here.
+         """
+
+         response_headers: Optional[httpx.Headers] = None
+         if hasattr(e, "response") and hasattr(e.response, "headers"):  # type: ignore
+             response_headers = e.response.headers  # type: ignore
+         if hasattr(e, "litellm_response_headers"):
+             response_headers = e.litellm_response_headers  # type: ignore
+
+         retry_after = litellm.utils._get_retry_after_from_exception_header(response_headers)
+
+         # If the API asks us to wait a certain amount of time (and it's a reasonable amount), just do what it says.
+         if retry_after is not None and 0 < retry_after <= 60:
+             return retry_after
+         else:
+             return None
+
+     @override
+     def _time_to_sleep_before_retry(
+         self,
+         e: Exception,
+         remaining_retries: int,
+         num_retries: int,
+         healthy_deployments: Optional[list] = None,
+         all_deployments: Optional[list] = None,
+     ) -> Union[int, float]:
+         """
+         Implements exponential backoff for retries.
+
+         Technically, litellm's `Router` already implements some
+         form of exponential backoff. However, that backoff
+         is not customizable w.r.t. jitter and initial delay
+         timing. For that reason, we override this method to
+         utilize our own custom instance variables, deferring
+         to the existing implementation wherever we can.
+         """
+
+         # If the response headers indicated how long we should wait,
+         # use that information.
+         if retry_after := self._extract_retry_delay_from_headers(e):
+             return retry_after
+
+         return self.calculate_exponential_backoff(
+             initial_retry_after_s=self._initial_retry_after_s,
+             current_retry=num_retries - remaining_retries,
+             jitter_pct=self._jitter_pct,
+         )
+
+     @staticmethod
+     def calculate_exponential_backoff(initial_retry_after_s: float, current_retry: int, jitter_pct: float) -> float:
+         sleep_s = initial_retry_after_s * (pow(2.0, current_retry))
+         jitter = 1.0 + random.uniform(-jitter_pct, jitter_pct)
+         return sleep_s * jitter
+
+
+ def apply_litellm_patches():
+     litellm.in_memory_llm_clients_cache = ThreadSafeCache()
+
+     # Workaround for the litellm issue described in https://github.com/BerriAI/litellm/issues/9792
+     litellm.litellm_core_utils.logging_callback_manager.LoggingCallbackManager.MAX_CALLBACKS = DEFAULT_MAX_CALLBACKS
+
+     quiet_noisy_logger("httpx")
+     quiet_noisy_logger("LiteLLM")
+     quiet_noisy_logger("LiteLLM Router")
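
`CustomRouter.calculate_exponential_backoff` above computes `initial_retry_after_s * 2**current_retry` and scales it by a random jitter factor drawn from `[1 - jitter_pct, 1 + jitter_pct]`; header-supplied retry delays in the range (0, 60] seconds bypass the formula entirely. A quick standalone check of the delays produced by the defaults in `LiteLLMRouterDefaultKwargs` (2.0 s initial delay, 20% jitter); this re-implements the arithmetic rather than importing the wheel:

```python
import random


def backoff(initial_retry_after_s: float, current_retry: int, jitter_pct: float) -> float:
    # Same arithmetic as CustomRouter.calculate_exponential_backoff:
    # exponential growth with a multiplicative +/- jitter.
    sleep_s = initial_retry_after_s * (2.0 ** current_retry)
    return sleep_s * (1.0 + random.uniform(-jitter_pct, jitter_pct))


random.seed(0)
for retry in range(4):
    print(f"retry {retry}: ~{backoff(2.0, retry, 0.2):.2f}s")
# Jitter-free delays would be 2s, 4s, 8s, 16s; with 20% jitter
# each value lands within +/-20% of its nominal delay.
```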
@@ -0,0 +1,2 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
@@ -0,0 +1,34 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from typing import Optional
+
+
+ class ParserException(Exception):
+     """Identifies errors resulting from generic parser errors.
+
+     Attributes:
+         source (str | None): The source string that the parser
+             attempted to parse.
+     """
+
+     source: Optional[str]
+
+     @staticmethod
+     def _log_format(source: str) -> str:
+         ## NOTE: The point of this was to be able to report offending
+         ## failure cases to the logs. This might not be what we want
+         ## to do in all cases. In the meantime, this note is left
+         ## for later review.
+         #
+         # return f"<source>{source}</source>"
+         return ""
+
+     def __init__(self, msg: Optional[str] = None, source: Optional[str] = None):
+         msg = "" if msg is None else msg.strip()
+
+         if source is not None:
+             msg += self._log_format(source)
+
+         super().__init__(msg)
+         self.source = source
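
Parsers plug into the correction loop in `ModelFacade.generate` by raising `ParserException` when a response cannot be processed; the exception message is what gets echoed back to the model as the next "user" turn. A small illustrative sketch (assuming the wheel is installed; the JSON requirement here is just an example, not something the package mandates):

```python
import json

from data_designer.engine.models.parsers.errors import ParserException


def parse_json_object(response: str) -> dict:
    # Hypothetical example parser: insist on a top-level JSON object and raise
    # ParserException otherwise, so generate() can loop the message back for correction.
    try:
        obj = json.loads(response)
    except json.JSONDecodeError as exc:
        raise ParserException(f"Response was not valid JSON: {exc}", source=response) from exc
    if not isinstance(obj, dict):
        raise ParserException("Expected a JSON object at the top level.", source=response)
    return obj
```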
@@ -0,0 +1,236 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from functools import reduce
+ from typing import Optional
+
+ from lxml import etree
+ from lxml.etree import _Element
+ import marko
+
+ from data_designer.engine.models.parsers.postprocessors import merge_text_blocks
+ import data_designer.engine.models.parsers.tag_parsers as tp
+ from data_designer.engine.models.parsers.types import (
+     LLMStructuredResponse,
+     PostProcessor,
+     TagParser,
+ )
+
+ DEFAULT_TAG_PARSERS = {
+     "pre.code": tp.code_block_parser,
+     "p.code": tp.inline_code_parser,
+     "p": tp.text_parser,
+     "pre": tp.text_parser,
+     "": tp.text_parser_keep_markup,
+ }
+
+ DEFAULT_POST_PROCESSORS = [merge_text_blocks]
+
+
+ def _patch_tags_before_code_fences(response: str) -> str:
+     """Patch to add a line break between a closing tag and a following code block.
+
+     Marko conversion of MD->HTML has a quirk. If there is a case like
+     the following, it will not convert the code block at all:
+
+         ...
+         </ending_tag>
+         ```syntax
+         ...
+
+     We want to find these cases and simply introduce an additional
+     line break.
+     """
+
+     return response.replace(">\n```", ">\n\n```")
+
+
+ class LLMResponseParser:
+     """
+     Parses Language Model (LLM) responses containing a mixture of Markdown and custom markup into structured data.
+
+     The `LLMResponseParser` class facilitates the translation of LLM-generated responses, which may include
+     Markdown and custom markup tags, into a structured format using ElementTree. It allows for customizable
+     parsing behavior through the registration of tag-specific parsers and post-processors.
+
+     ## Description
+
+     The core functionality of this class enables LLMs to respond using Markdown along with any custom
+     prompted markup specified by the system or task. The parsing process involves converting the Markdown
+     and markup into an ElementTree, then processing each element using registered tag parsers to produce
+     a list of structured `BaseModel` instances. Post-processors can further refine the structured response.
+
+     ### Tag Parsers
+
+     Tag parsers are responsible for handling specific markup tags within the LLM response. They can be
+     registered with the parser using dot-path notation to manage hierarchical tag structures. This allows
+     downstream tasks to customize how specific elements are processed into `BaseModel` instances.
+
+     ### Post-Processors
+
+     Post-processors are functions that operate on the list of parsed blocks to perform additional
+     transformations or aggregations. They are applied after the initial parsing of the response.
+
+     Attributes:
+         tag_parsers (dict[str, TagParser]): A dictionary mapping tag paths to their corresponding `TagParser` instances.
+         postprocessors (list[PostProcessor]): A list of post-processing functions to apply to the structured response.
+
+     Example:
+         ```python
+         class CodeBlock(BaseModel):
+             code: str
+             syntax: Optional[str] = None
+
+
+         class CodeBlockParser:
+             def __call__(self, element: _Element) -> CodeBlock:
+                 # Implementation details...
+                 return CodeBlock(code=element.text, syntax=element.get("class"))
+
+
+         parser = LLMResponseParser(
+             tag_parsers={
+                 "pre.code": CodeBlockParser(),
+             }
+         )
+
+         out = parser.parse('```json\n{"answer": 42}\n```')
+         print(out.parsed)
+         # Output: [CodeBlock(code='{"answer": 42}\n', syntax='json')]
+         ```
+     """
+
+     tag_parsers: dict[str, TagParser]
+     postprocessors: list[PostProcessor]
+
+     def __init__(
+         self,
+         tag_parsers: Optional[dict[str, TagParser]] = None,
+         postprocessors: Optional[list[PostProcessor]] = None,
+     ):
+         """
+         Initializes the LLMResponseParser with optional tag parsers and post-processors.
+
+         Args:
+             tag_parsers (Optional[dict[str, TagParser]]): A dictionary mapping tag paths to `TagParser` instances.
+                 If provided, these parsers will be merged with the default tag parsers.
+             postprocessors (Optional[list[PostProcessor]]): A list of post-processing functions to apply
+                 to the structured response. If not provided, a default post-processor `merge_text_blocks`
+                 is used.
+
+         Attributes:
+             tag_parsers (dict[str, TagParser]): Initialized with default tag parsers, updated with any provided.
+             postprocessors (list[PostProcessor]): Initialized with default post-processors or the provided list.
+         """
+         self.tag_parsers = {**DEFAULT_TAG_PARSERS}
+         if tag_parsers:
+             self.tag_parsers.update(tag_parsers)
+
+         self.postprocessors = [
+             merge_text_blocks,
+         ]
+         if postprocessors is not None:
+             self.postprocessors = postprocessors
+
+     def lookup_parser(self, element: _Element) -> TagParser:
+         """
+         Resolves and retrieves the appropriate `TagParser` for a given XML element based on its tag hierarchy.
+
+         The method constructs the dot-path lineage of the element's tags, starting from the root and moving
+         towards the specific element. It then attempts to find the most specific matching `TagParser` by
+         progressively reducing the specificity of the tag path until a matching parser is found.
+
+         Args:
+             element (_Element): The XML element for which to find the corresponding `TagParser`.
+
+         Returns:
+             TagParser: The `TagParser` instance that matches the element's tag path.
+
+         Raises:
+             KeyError: If no matching `TagParser` is found for the element's tag path.
+         """
+         # Get the dot path lineage of this tag, sans root.
+         # Note that the lineage comes back in reverse order.
+         parents = [e.tag for e in element.iterancestors()][::-1]
+         lineage = [*parents, element.tag]
+
+         # Now attempt to match up with the tag parser names.
+         # Starts from the full lineage (most specific), and
+         # breaks on the first hit. So this should properly
+         # prioritize specific parsers over general ones.
+         while lineage:
+             tag_path = ".".join(lineage)
+             if tag_path not in self.tag_parsers:
+                 lineage.pop(0)
+             else:
+                 break
+
+         # Tag path can be an empty string, which hits the
+         # default parsing option specified by the "" entry
+         # of the tag parsers dict.
+         tag_path = ".".join(lineage)
+         return self.tag_parsers[tag_path]
+
+     def postprocess(self, structured_response: LLMStructuredResponse) -> LLMStructuredResponse:
+         """
+         Applies post-processing functions to the structured response.
+
+         If no post-processors are registered, the original structured response is returned.
+         Otherwise, each post-processor is applied in sequence to transform the response.
+
+         Args:
+             structured_response (LLMStructuredResponse): The initial structured response to be post-processed.
+
+         Returns:
+             LLMStructuredResponse: The post-processed structured response.
+         """
+         if not self.postprocessors:
+             return structured_response
+
+         return reduce(lambda acc, func: func(acc), self.postprocessors, structured_response)
+
+     def parse(self, md_response: str) -> LLMStructuredResponse:
+         """
+         Parses a Markdown-formatted LLM response into a structured `LLMStructuredResponse`.
+
+         The parsing process involves converting the Markdown and custom markup into an XML tree,
+         iterating over each element in a depth-first traversal to apply the appropriate
+         `TagParser`, and then applying any registered post-processors to the resulting structured data.
+
+         Args:
+             md_response (str): The Markdown-formatted response from the LLM, potentially containing custom markup.
+
+         Returns:
+             LLMStructuredResponse: The structured representation of the parsed response, containing parsed blocks.
+
+         Raises:
+             etree.XMLSyntaxError: If the provided Markdown cannot be converted into a valid XML structure.
+         """
+         response = marko.convert(_patch_tags_before_code_fences(md_response))
+         output = LLMStructuredResponse(response=md_response, markup=response)
+
+         # Generate document tree
+         parser = etree.HTMLParser(recover=True, remove_blank_text=True)
+         root = etree.fromstring(response, parser=parser)
+         tags = root.iter() if root is not None else []
+
+         # Iterate over tags, depth first
+         for element in tags:
+             if element == root or element.tag == "body":
+                 continue
+
+             parsed_block = self.lookup_parser(element)(element)
+
+             # Make a quick check for dead text blocks, which
+             # can happen with container tags like <pre>, <ul>, and <ol>.
+             drop_block = isinstance(parsed_block, tp.TextBlock) and not parsed_block.text.strip()
+
+             if not drop_block:
+                 output.parsed.append(parsed_block)
+
+             # Check tails -- inelegant, but they're always text.
+             # Don't add the tail if it is just blank space.
+             if element.tail and element.tail.strip():
+                 output.parsed.append(tp.TextBlock(text=element.tail))
+
+         return self.postprocess(output)
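
To make the dot-path lookup in `lookup_parser` concrete: fenced Markdown code is rendered as `<pre><code>...</code></pre>`, so a parser registered under `"pre.code"` takes precedence over the more general `"pre"` and `""` defaults. The sketch below follows the pattern from the class docstring; it assumes the wheel is installed, and the handling of marko's `language-` class prefix is an assumption about the default HTML rendering rather than documented package behavior.

```python
from typing import Optional

from lxml.etree import _Element
from pydantic import BaseModel

from data_designer.engine.models.parsers.parser import LLMResponseParser


class CodeBlock(BaseModel):
    code: str
    syntax: Optional[str] = None


def parse_code_block(element: _Element) -> CodeBlock:
    # Registered under the dot path "pre.code", so it overrides the default
    # code_block_parser for fenced code while leaving other tags untouched.
    syntax = (element.get("class") or "").removeprefix("language-") or None
    return CodeBlock(code=element.text or "", syntax=syntax)


parser = LLMResponseParser(tag_parsers={"pre.code": parse_code_block})
out = parser.parse('Answer below:\n\n```json\n{"answer": 42}\n```\n')
for block in out.parsed:
    print(type(block).__name__, block)
```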