data-designer 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/__init__.py +15 -0
- data_designer/_version.py +34 -0
- data_designer/cli/README.md +236 -0
- data_designer/cli/__init__.py +6 -0
- data_designer/cli/commands/__init__.py +2 -0
- data_designer/cli/commands/list.py +130 -0
- data_designer/cli/commands/models.py +10 -0
- data_designer/cli/commands/providers.py +11 -0
- data_designer/cli/commands/reset.py +100 -0
- data_designer/cli/controllers/__init__.py +7 -0
- data_designer/cli/controllers/model_controller.py +246 -0
- data_designer/cli/controllers/provider_controller.py +317 -0
- data_designer/cli/forms/__init__.py +20 -0
- data_designer/cli/forms/builder.py +51 -0
- data_designer/cli/forms/field.py +180 -0
- data_designer/cli/forms/form.py +59 -0
- data_designer/cli/forms/model_builder.py +125 -0
- data_designer/cli/forms/provider_builder.py +76 -0
- data_designer/cli/main.py +44 -0
- data_designer/cli/repositories/__init__.py +8 -0
- data_designer/cli/repositories/base.py +39 -0
- data_designer/cli/repositories/model_repository.py +42 -0
- data_designer/cli/repositories/provider_repository.py +43 -0
- data_designer/cli/services/__init__.py +7 -0
- data_designer/cli/services/model_service.py +116 -0
- data_designer/cli/services/provider_service.py +111 -0
- data_designer/cli/ui.py +448 -0
- data_designer/cli/utils.py +47 -0
- data_designer/config/__init__.py +2 -0
- data_designer/config/analysis/column_profilers.py +89 -0
- data_designer/config/analysis/column_statistics.py +274 -0
- data_designer/config/analysis/dataset_profiler.py +60 -0
- data_designer/config/analysis/utils/errors.py +8 -0
- data_designer/config/analysis/utils/reporting.py +188 -0
- data_designer/config/base.py +68 -0
- data_designer/config/column_configs.py +354 -0
- data_designer/config/column_types.py +168 -0
- data_designer/config/config_builder.py +660 -0
- data_designer/config/data_designer_config.py +40 -0
- data_designer/config/dataset_builders.py +11 -0
- data_designer/config/datastore.py +151 -0
- data_designer/config/default_model_settings.py +123 -0
- data_designer/config/errors.py +19 -0
- data_designer/config/interface.py +54 -0
- data_designer/config/models.py +231 -0
- data_designer/config/preview_results.py +32 -0
- data_designer/config/processors.py +41 -0
- data_designer/config/sampler_constraints.py +51 -0
- data_designer/config/sampler_params.py +604 -0
- data_designer/config/seed.py +145 -0
- data_designer/config/utils/code_lang.py +83 -0
- data_designer/config/utils/constants.py +313 -0
- data_designer/config/utils/errors.py +19 -0
- data_designer/config/utils/info.py +88 -0
- data_designer/config/utils/io_helpers.py +273 -0
- data_designer/config/utils/misc.py +81 -0
- data_designer/config/utils/numerical_helpers.py +28 -0
- data_designer/config/utils/type_helpers.py +100 -0
- data_designer/config/utils/validation.py +336 -0
- data_designer/config/utils/visualization.py +427 -0
- data_designer/config/validator_params.py +96 -0
- data_designer/engine/__init__.py +2 -0
- data_designer/engine/analysis/column_profilers/base.py +55 -0
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +160 -0
- data_designer/engine/analysis/column_profilers/registry.py +20 -0
- data_designer/engine/analysis/column_statistics.py +142 -0
- data_designer/engine/analysis/dataset_profiler.py +125 -0
- data_designer/engine/analysis/errors.py +7 -0
- data_designer/engine/analysis/utils/column_statistics_calculations.py +209 -0
- data_designer/engine/analysis/utils/judge_score_processing.py +128 -0
- data_designer/engine/column_generators/__init__.py +2 -0
- data_designer/engine/column_generators/generators/__init__.py +2 -0
- data_designer/engine/column_generators/generators/base.py +61 -0
- data_designer/engine/column_generators/generators/expression.py +63 -0
- data_designer/engine/column_generators/generators/llm_generators.py +172 -0
- data_designer/engine/column_generators/generators/samplers.py +75 -0
- data_designer/engine/column_generators/generators/seed_dataset.py +149 -0
- data_designer/engine/column_generators/generators/validation.py +147 -0
- data_designer/engine/column_generators/registry.py +56 -0
- data_designer/engine/column_generators/utils/errors.py +13 -0
- data_designer/engine/column_generators/utils/judge_score_factory.py +57 -0
- data_designer/engine/column_generators/utils/prompt_renderer.py +98 -0
- data_designer/engine/configurable_task.py +82 -0
- data_designer/engine/dataset_builders/artifact_storage.py +181 -0
- data_designer/engine/dataset_builders/column_wise_builder.py +287 -0
- data_designer/engine/dataset_builders/errors.py +13 -0
- data_designer/engine/dataset_builders/multi_column_configs.py +44 -0
- data_designer/engine/dataset_builders/utils/__init__.py +2 -0
- data_designer/engine/dataset_builders/utils/concurrency.py +184 -0
- data_designer/engine/dataset_builders/utils/config_compiler.py +60 -0
- data_designer/engine/dataset_builders/utils/dag.py +56 -0
- data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +190 -0
- data_designer/engine/dataset_builders/utils/errors.py +13 -0
- data_designer/engine/errors.py +49 -0
- data_designer/engine/model_provider.py +75 -0
- data_designer/engine/models/__init__.py +2 -0
- data_designer/engine/models/errors.py +308 -0
- data_designer/engine/models/facade.py +225 -0
- data_designer/engine/models/litellm_overrides.py +162 -0
- data_designer/engine/models/parsers/__init__.py +2 -0
- data_designer/engine/models/parsers/errors.py +34 -0
- data_designer/engine/models/parsers/parser.py +236 -0
- data_designer/engine/models/parsers/postprocessors.py +93 -0
- data_designer/engine/models/parsers/tag_parsers.py +60 -0
- data_designer/engine/models/parsers/types.py +82 -0
- data_designer/engine/models/recipes/base.py +79 -0
- data_designer/engine/models/recipes/response_recipes.py +291 -0
- data_designer/engine/models/registry.py +118 -0
- data_designer/engine/models/usage.py +75 -0
- data_designer/engine/models/utils.py +38 -0
- data_designer/engine/processing/ginja/__init__.py +2 -0
- data_designer/engine/processing/ginja/ast.py +64 -0
- data_designer/engine/processing/ginja/environment.py +461 -0
- data_designer/engine/processing/ginja/exceptions.py +54 -0
- data_designer/engine/processing/ginja/record.py +30 -0
- data_designer/engine/processing/gsonschema/__init__.py +2 -0
- data_designer/engine/processing/gsonschema/exceptions.py +8 -0
- data_designer/engine/processing/gsonschema/schema_transformers.py +81 -0
- data_designer/engine/processing/gsonschema/types.py +8 -0
- data_designer/engine/processing/gsonschema/validators.py +143 -0
- data_designer/engine/processing/processors/base.py +15 -0
- data_designer/engine/processing/processors/drop_columns.py +46 -0
- data_designer/engine/processing/processors/registry.py +20 -0
- data_designer/engine/processing/utils.py +120 -0
- data_designer/engine/registry/base.py +97 -0
- data_designer/engine/registry/data_designer_registry.py +37 -0
- data_designer/engine/registry/errors.py +10 -0
- data_designer/engine/resources/managed_dataset_generator.py +35 -0
- data_designer/engine/resources/managed_dataset_repository.py +194 -0
- data_designer/engine/resources/managed_storage.py +63 -0
- data_designer/engine/resources/resource_provider.py +46 -0
- data_designer/engine/resources/seed_dataset_data_store.py +66 -0
- data_designer/engine/sampling_gen/column.py +89 -0
- data_designer/engine/sampling_gen/constraints.py +95 -0
- data_designer/engine/sampling_gen/data_sources/base.py +214 -0
- data_designer/engine/sampling_gen/data_sources/errors.py +10 -0
- data_designer/engine/sampling_gen/data_sources/sources.py +342 -0
- data_designer/engine/sampling_gen/entities/__init__.py +2 -0
- data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
- data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +64 -0
- data_designer/engine/sampling_gen/entities/email_address_utils.py +169 -0
- data_designer/engine/sampling_gen/entities/errors.py +8 -0
- data_designer/engine/sampling_gen/entities/national_id_utils.py +100 -0
- data_designer/engine/sampling_gen/entities/person.py +142 -0
- data_designer/engine/sampling_gen/entities/phone_number.py +122 -0
- data_designer/engine/sampling_gen/errors.py +24 -0
- data_designer/engine/sampling_gen/generator.py +121 -0
- data_designer/engine/sampling_gen/jinja_utils.py +60 -0
- data_designer/engine/sampling_gen/people_gen.py +203 -0
- data_designer/engine/sampling_gen/person_constants.py +54 -0
- data_designer/engine/sampling_gen/schema.py +143 -0
- data_designer/engine/sampling_gen/schema_builder.py +59 -0
- data_designer/engine/sampling_gen/utils.py +40 -0
- data_designer/engine/secret_resolver.py +80 -0
- data_designer/engine/validators/__init__.py +17 -0
- data_designer/engine/validators/base.py +36 -0
- data_designer/engine/validators/local_callable.py +34 -0
- data_designer/engine/validators/python.py +245 -0
- data_designer/engine/validators/remote.py +83 -0
- data_designer/engine/validators/sql.py +60 -0
- data_designer/errors.py +5 -0
- data_designer/essentials/__init__.py +137 -0
- data_designer/interface/__init__.py +2 -0
- data_designer/interface/data_designer.py +351 -0
- data_designer/interface/errors.py +16 -0
- data_designer/interface/results.py +55 -0
- data_designer/logging.py +161 -0
- data_designer/plugin_manager.py +83 -0
- data_designer/plugins/__init__.py +6 -0
- data_designer/plugins/errors.py +10 -0
- data_designer/plugins/plugin.py +69 -0
- data_designer/plugins/registry.py +86 -0
- data_designer-0.1.0.dist-info/METADATA +173 -0
- data_designer-0.1.0.dist-info/RECORD +177 -0
- data_designer-0.1.0.dist-info/WHEEL +4 -0
- data_designer-0.1.0.dist-info/entry_points.txt +2 -0
- data_designer-0.1.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,225 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from collections.abc import Callable
|
|
7
|
+
from copy import deepcopy
|
|
8
|
+
import logging
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
from litellm.types.router import DeploymentTypedDict, LiteLLM_Params
|
|
12
|
+
from litellm.types.utils import ModelResponse
|
|
13
|
+
|
|
14
|
+
from data_designer.config.models import ModelConfig, ModelProvider
|
|
15
|
+
from data_designer.engine.model_provider import ModelProviderRegistry
|
|
16
|
+
from data_designer.engine.models.errors import (
|
|
17
|
+
GenerationValidationFailureError,
|
|
18
|
+
catch_llm_exceptions,
|
|
19
|
+
get_exception_primary_cause,
|
|
20
|
+
)
|
|
21
|
+
from data_designer.engine.models.litellm_overrides import CustomRouter, LiteLLMRouterDefaultKwargs
|
|
22
|
+
from data_designer.engine.models.parsers.errors import ParserException
|
|
23
|
+
from data_designer.engine.models.usage import ModelUsageStats, RequestUsageStats, TokenUsageStats
|
|
24
|
+
from data_designer.engine.models.utils import prompt_to_messages, str_to_message
|
|
25
|
+
from data_designer.engine.secret_resolver import SecretResolver
|
|
26
|
+
|
|
27
|
+
# Module-level logger; ModelFacade emits debug records through it for prompts,
# received completions, and generation attempts.
logger = logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class ModelFacade:
    """Wrapper around a LiteLLM router bound to one configured model.

    Builds a single-deployment ``CustomRouter`` from the model config. It resolves
    the provider endpoint and API key, issues completion requests, and records
    request/token usage. ``generate`` adds parser-validated generation with
    correction rounds and conversation restarts.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        secret_resolver: SecretResolver,
        model_provider_registry: ModelProviderRegistry,
    ):
        self._model_config = model_config
        self._secret_resolver = secret_resolver
        self._model_provider_registry = model_provider_registry
        # The router is constructed with exactly one deployment: this model.
        self._litellm_deployment = self._get_litellm_deployment(model_config)
        self._router = CustomRouter([self._litellm_deployment], **LiteLLMRouterDefaultKwargs().model_dump())
        self._usage_stats = ModelUsageStats()

    @property
    def model_name(self) -> str:
        """Model identifier from the model config."""
        return self._model_config.model

    @property
    def model_provider(self) -> ModelProvider:
        """Provider for this model, looked up in the registry on every access."""
        return self._model_provider_registry.get_provider(self._model_config.provider)

    @property
    def model_provider_name(self) -> str:
        """Name of this model's provider."""
        return self.model_provider.name

    @property
    def model_alias(self) -> str:
        """Alias from the model config."""
        return self._model_config.alias

    @property
    def usage_stats(self) -> ModelUsageStats:
        """Accumulated request/token usage for this facade."""
        return self._usage_stats

    def completion(self, messages: list[dict[str, str]], skip_usage_tracking: bool = False, **kwargs) -> ModelResponse:
        """Send ``messages`` to the model through the router and return the raw response.

        Args:
            messages: Chat messages to send.
            skip_usage_tracking: If True, this request is not recorded in ``usage_stats``.
            **kwargs: Extra arguments forwarded to the router's ``completion`` call. If the
                provider defines ``extra_body``, it is merged into any caller-supplied
                ``extra_body``, with the provider's keys taking precedence.

        Returns:
            The ``ModelResponse`` returned by the router.

        Raises:
            Any exception raised by the router propagates unchanged. If the router call
            itself raises, the request is recorded as failed unless tracking is skipped.
        """
        logger.debug(
            f"Prompting model {self.model_name!r}...",
            extra={"model": self.model_name, "messages": messages, "sensitive": True},
        )
        response = None
        provider_extra_body = self.model_provider.extra_body
        if provider_extra_body:
            # Treat a caller-supplied `extra_body=None` as empty instead of failing on `**None`.
            kwargs["extra_body"] = {**(kwargs.get("extra_body") or {}), **provider_extra_body}
        try:
            response = self._router.completion(self.model_name, messages, **kwargs)
            logger.debug(
                f"Received completion from model {self.model_name!r}",
                extra={
                    "model": self.model_name,
                    "response": response,
                    "text": response.choices[0].message.content,
                    "usage": self._usage_stats.model_dump(),
                },
            )
            return response
        finally:
            # `response` is still None here when the router raised; that is tracked as a failure.
            if not skip_usage_tracking:
                self._track_usage(response)

    @catch_llm_exceptions
    def generate(
        self,
        prompt: str,
        *,
        parser: Callable[[str], Any],
        system_prompt: str | None = None,
        multi_modal_context: list[dict[str, Any]] | None = None,
        max_correction_steps: int = 0,
        max_conversation_restarts: int = 0,
        skip_usage_tracking: bool = False,
        purpose: str | None = None,
        **kwargs,
    ) -> tuple[Any, str | None]:
        """Generate a parsed output, with optional correction rounds and restarts.

        This call tries to produce an output that is valid according to the given
        parser, where "valid" means the parser processes the LLM response without
        raising an exception.

        ``ParserException``s are routed back to the LLM as new rounds in the
        conversation. The LLM's earlier response is appended as an "assistant" turn,
        followed by a "user" turn containing the string form of the exception's
        primary cause (not a traceback). This continues for up to
        ``max_correction_steps`` rounds per conversation. After that, the
        conversation restarts from the original messages, up to
        ``max_conversation_restarts`` times. In total at most
        ``(max_correction_steps + 1) * (max_conversation_restarts + 1)`` generations
        are attempted.

        Args:
            prompt (str): Task prompt.
            parser (func(str) -> Any): A function applied to the LLM response that turns
                it into some output object.
            system_prompt (str, optional): Optional system instructions. If not specified,
                no system message is provided and the model should use its default system
                prompt.
            multi_modal_context (list[dict[str, Any]], optional): Extra multi-modal content
                passed to ``prompt_to_messages`` together with the prompts.
            max_correction_steps (int): Maximum number of correction rounds permitted
                within a single conversation. Note, many rounds can lead to increasing
                context size without necessarily improving performance -- small language
                models can enter repeated cycles which will not be solved with more steps.
                Default: `0` (no correction).
            max_conversation_restarts (int): Maximum number of full conversation restarts
                permitted if generation fails. Default: `0` (no restarts).
            skip_usage_tracking (bool): Whether to skip usage tracking. Default: `False`.
            purpose (str): The purpose of the model usage, shown as context in the error
                message. It is expected to be used by the @catch_llm_exceptions decorator.
            **kwargs: Additional arguments forwarded to ``completion`` and on to the model.

        Returns:
            tuple[Any, str | None]: The parser's output for the accepted response, and that
            response's ``reasoning_content`` (whitespace-stripped when non-empty), or
            ``None`` if the response has none.

        Raises:
            GenerationValidationFailureError: If the final permitted attempt still fails
                parser validation.
        """
        output_obj = None
        curr_num_correction_steps = 0
        curr_num_restarts = 0
        curr_generation_attempt = 0
        # Each conversation allows 1 initial attempt plus `max_correction_steps` corrections,
        # and there are 1 + `max_conversation_restarts` conversations.
        max_generation_attempts = (max_correction_steps + 1) * (max_conversation_restarts + 1)

        starting_messages = prompt_to_messages(
            user_prompt=prompt, system_prompt=system_prompt, multi_modal_context=multi_modal_context
        )
        messages = deepcopy(starting_messages)

        while True:
            curr_generation_attempt += 1
            logger.debug(
                f"Starting generation attempt {curr_generation_attempt} of {max_generation_attempts} attempts."
            )

            completion_response = self.completion(messages, skip_usage_tracking=skip_usage_tracking, **kwargs)
            message = completion_response.choices[0].message
            response = message.content or ""
            reasoning_trace = getattr(message, "reasoning_content", None)

            if reasoning_trace:
                # There are generally some extra newlines with how these get parsed.
                response = response.strip()
                reasoning_trace = reasoning_trace.strip()

            curr_num_correction_steps += 1

            try:
                output_obj = parser(response)  # type: ignore - if not a string will cause a ParserException below
                break
            except ParserException as exc:
                if max_correction_steps == 0 and max_conversation_restarts == 0:
                    raise GenerationValidationFailureError(
                        "Unsuccessful generation attempt. No retries were attempted."
                    ) from exc
                if curr_num_correction_steps <= max_correction_steps:
                    # Feed the failed response and the error back to the model as a new turn.
                    messages += [
                        str_to_message(content=response, role="assistant"),
                        str_to_message(content=str(get_exception_primary_cause(exc)), role="user"),
                    ]
                elif curr_num_restarts < max_conversation_restarts:
                    # Corrections exhausted for this conversation: start over from the original messages.
                    curr_num_correction_steps = 0
                    curr_num_restarts += 1
                    messages = deepcopy(starting_messages)
                else:
                    raise GenerationValidationFailureError(
                        f"Unsuccessful generation attempt despite {max_generation_attempts} attempts."
                    ) from exc
        return output_obj, reasoning_trace

    def _get_litellm_deployment(self, model_config: ModelConfig) -> DeploymentTypedDict:
        """Build the single LiteLLM router deployment for ``model_config``.

        The LiteLLM model id is ``"<provider_type>/<model>"`` and the API base is the
        provider's endpoint. The provider's API key, if configured, is resolved through
        the secret resolver.
        """
        provider = self._model_provider_registry.get_provider(model_config.provider)
        api_key = None
        if provider.api_key:
            api_key = self._secret_resolver.resolve(provider.api_key)
        # Use a placeholder when no key is configured or the resolved key is empty.
        api_key = api_key or "not-used-but-required"

        litellm_params = LiteLLM_Params(
            model=f"{provider.provider_type}/{model_config.model}",
            api_base=provider.endpoint,
            api_key=api_key,
        )
        return {
            "model_name": model_config.model,
            "litellm_params": litellm_params.model_dump(),
        }

    def _track_usage(self, response: ModelResponse | None) -> None:
        """Record the outcome of one completion request in ``usage_stats``.

        A ``None`` response counts as a failed request. Any other response counts as
        a successful request. Token counts are added only when the response reports
        both prompt and completion tokens.
        """
        if response is None:
            self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1))
            return
        success = RequestUsageStats(successful_requests=1, failed_requests=0)
        usage = response.usage
        if usage is not None and usage.prompt_tokens is not None and usage.completion_tokens is not None:
            self._usage_stats.extend(
                token_usage=TokenUsageStats(
                    prompt_tokens=usage.prompt_tokens,
                    completion_tokens=usage.completion_tokens,
                ),
                request_usage=success,
            )
        else:
            # Still count the successful request when the provider omits token usage.
            self._usage_stats.extend(request_usage=success)
|
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import random
|
|
7
|
+
import threading
|
|
8
|
+
from typing import Optional, Union
|
|
9
|
+
|
|
10
|
+
import httpx
|
|
11
|
+
import litellm
|
|
12
|
+
from litellm import RetryPolicy
|
|
13
|
+
from litellm.caching.in_memory_cache import InMemoryCache
|
|
14
|
+
from litellm.router import Router
|
|
15
|
+
from pydantic import BaseModel, Field
|
|
16
|
+
from typing_extensions import override
|
|
17
|
+
|
|
18
|
+
from data_designer.logging import quiet_noisy_logger
|
|
19
|
+
|
|
20
|
+
# Value assigned to LiteLLM's LoggingCallbackManager.MAX_CALLBACKS by
# `apply_litellm_patches`, as a workaround for
# https://github.com/BerriAI/litellm/issues/9792.
DEFAULT_MAX_CALLBACKS = 1000
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _default_retry_policy() -> RetryPolicy:
    """Build the default router retry policy: 3 retries for rate-limit errors and 3 for timeouts."""
    return RetryPolicy(RateLimitErrorRetries=3, TimeoutErrorRetries=3)


class LiteLLMRouterDefaultKwargs(BaseModel):
    """Default keyword arguments for constructing a ``CustomRouter``.

    An instance's ``model_dump()`` is unpacked into the router constructor, so
    each field below is passed as a keyword argument of the same name.
    """

    # Base delay, in seconds, before the first retry after a connection failure.
    initial_retry_after_s: float = Field(default=2.0)

    # Jitter applied to each exponential-backoff delay, as a fraction
    # (0.2 means +/-20%), so repeated retries are spread out over time.
    jitter_pct: float = Field(default=0.2)

    # Maximum number of seconds to wait for an API request before giving up
    # on it; this triggers a retry.
    timeout: float = Field(default=60.0)

    # Default retry policy: how many retries to allow for particular error types.
    retry_policy: RetryPolicy = Field(default_factory=_default_retry_policy)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class ThreadSafeCache(InMemoryCache):
    """``InMemoryCache`` whose overridden operations run under one shared lock.

    Each method defined here acquires ``self._lock`` before delegating to the
    ``InMemoryCache`` implementation. The lock is a ``threading.RLock``, so a
    guarded method that calls another guarded method on the same thread does not
    deadlock. Base-class methods that are not overridden here are not guarded.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._lock = threading.RLock()

    def get_cache(self, key, **kwargs):
        """Lock-guarded ``InMemoryCache.get_cache``."""
        with self._lock:
            cached = super().get_cache(key, **kwargs)
        return cached

    def set_cache(self, key, value, **kwargs):
        """Lock-guarded ``InMemoryCache.set_cache``."""
        with self._lock:
            super().set_cache(key, value, **kwargs)

    def batch_get_cache(self, keys: list, **kwargs):
        """Lock-guarded ``InMemoryCache.batch_get_cache``."""
        with self._lock:
            cached_values = super().batch_get_cache(keys, **kwargs)
        return cached_values

    def delete_cache(self, key):
        """Lock-guarded ``InMemoryCache.delete_cache``."""
        with self._lock:
            super().delete_cache(key)

    def evict_cache(self):
        """Lock-guarded ``InMemoryCache.evict_cache``."""
        with self._lock:
            super().evict_cache()

    def increment_cache(self, key, value: int, **kwargs) -> int:
        """Lock-guarded ``InMemoryCache.increment_cache``."""
        with self._lock:
            result = super().increment_cache(key, value, **kwargs)
        return result

    def flush_cache(self):
        """Lock-guarded ``InMemoryCache.flush_cache``."""
        with self._lock:
            super().flush_cache()
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class CustomRouter(Router):
    """LiteLLM ``Router`` whose retry delay uses configurable exponential backoff.

    If a failed request's response headers specify a retry delay in (0, 60]
    seconds, that delay is used as-is. Otherwise the delay is
    ``initial_retry_after_s * 2**current_retry``, multiplied by a random jitter
    factor drawn from ``[1 - jitter_pct, 1 + jitter_pct]`` and floored at 0.
    """

    def __init__(
        self,
        *args,
        initial_retry_after_s: float,
        jitter_pct: float,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self._initial_retry_after_s = initial_retry_after_s
        self._jitter_pct = jitter_pct

    def _extract_retry_delay_from_headers(self, e: Exception) -> Optional[Union[int, float]]:
        """Return the server-requested retry delay for ``e`` if it is usable.

        Most of this logic was extracted directly from the parent ``Router``'s
        ``_time_to_sleep_before_retry``. Our override of that method below only
        changes requests where the server did not explicitly return a retry delay.
        If the server did return one, the value returned here is used.

        Returns:
            The delay in seconds when the headers specify one in (0, 60], else ``None``.
        """
        response_headers: Optional[httpx.Headers] = None
        if hasattr(e, "response") and hasattr(e.response, "headers"):  # type: ignore
            response_headers = e.response.headers  # type: ignore
        # LiteLLM-attached headers, when present, take precedence over the raw response's.
        if hasattr(e, "litellm_response_headers"):
            response_headers = e.litellm_response_headers  # type: ignore

        retry_after = litellm.utils._get_retry_after_from_exception_header(response_headers)

        # If the API asks us to wait a certain amount of time (and it's a reasonable amount), just do what it says.
        if retry_after is not None and 0 < retry_after <= 60:
            return retry_after
        return None

    @override
    def _time_to_sleep_before_retry(
        self,
        e: Exception,
        remaining_retries: int,
        num_retries: int,
        healthy_deployments: Optional[list] = None,
        all_deployments: Optional[list] = None,
    ) -> Union[int, float]:
        """Return how long to sleep before the next retry of a failed request.

        litellm's ``Router`` already implements a form of exponential backoff, but
        its jitter and initial delay are not customizable. This override uses this
        router's own settings instead. A server-provided retry delay (see
        ``_extract_retry_delay_from_headers``) is used when available.
        ``healthy_deployments`` and ``all_deployments`` are accepted for signature
        compatibility but ignored.

        Returns:
            The delay in seconds.
        """
        # If the response headers indicated how long we should wait, use that information.
        if retry_after := self._extract_retry_delay_from_headers(e):
            return retry_after

        return self.calculate_exponential_backoff(
            initial_retry_after_s=self._initial_retry_after_s,
            current_retry=num_retries - remaining_retries,
            jitter_pct=self._jitter_pct,
        )

    @staticmethod
    def calculate_exponential_backoff(initial_retry_after_s: float, current_retry: int, jitter_pct: float) -> float:
        """Compute a jittered exponential backoff delay.

        Args:
            initial_retry_after_s: Delay for ``current_retry == 0``, in seconds.
            current_retry: Exponent applied to the doubling factor.
            jitter_pct: Half-width of the uniform jitter factor around 1.0
                (0.2 means the delay is scaled by a factor in [0.8, 1.2]).

        Returns:
            ``initial_retry_after_s * 2**current_retry * jitter``, never negative.
        """
        sleep_s = initial_retry_after_s * pow(2.0, current_retry)
        jitter = 1.0 + random.uniform(-jitter_pct, jitter_pct)
        # With jitter_pct > 1 the factor can drop below zero. A negative delay is
        # meaningless, and time.sleep raises ValueError on negative input.
        return max(0.0, sleep_s * jitter)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def apply_litellm_patches():
    """Apply process-wide LiteLLM adjustments.

    - Replace ``litellm.in_memory_llm_clients_cache`` with a ``ThreadSafeCache``.
    - Set ``LoggingCallbackManager.MAX_CALLBACKS`` to ``DEFAULT_MAX_CALLBACKS`` as a
      workaround for https://github.com/BerriAI/litellm/issues/9792.
    - Pass the ``httpx``, ``LiteLLM`` and ``LiteLLM Router`` loggers to
      ``quiet_noisy_logger``.
    """
    litellm.in_memory_llm_clients_cache = ThreadSafeCache()

    callback_manager_cls = litellm.litellm_core_utils.logging_callback_manager.LoggingCallbackManager
    callback_manager_cls.MAX_CALLBACKS = DEFAULT_MAX_CALLBACKS

    for noisy_logger_name in ("httpx", "LiteLLM", "LiteLLM Router"):
        quiet_noisy_logger(noisy_logger_name)
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class ParserException(Exception):
    """Raised when a parser cannot process an LLM response.

    Attributes:
        source (str | None): The raw text the parser attempted to parse, or
            ``None`` if it was not supplied.
    """

    source: Optional[str]

    @staticmethod
    def _log_format(source: str) -> str:
        # NOTE: This hook was meant to append the offending source text, wrapped
        # as "<source>...</source>", to the exception message so failures could be
        # reported in logs. That may not be desirable in every case, so it
        # currently contributes nothing; left in place for later review.
        return ""

    def __init__(self, msg: Optional[str] = None, source: Optional[str] = None):
        message = "" if msg is None else msg.strip()
        suffix = "" if source is None else self._log_format(source)
        super().__init__(message + suffix)
        self.source = source
|
|
@@ -0,0 +1,236 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from functools import reduce
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
from lxml import etree
|
|
8
|
+
from lxml.etree import _Element
|
|
9
|
+
import marko
|
|
10
|
+
|
|
11
|
+
from data_designer.engine.models.parsers.postprocessors import merge_text_blocks
|
|
12
|
+
import data_designer.engine.models.parsers.tag_parsers as tp
|
|
13
|
+
from data_designer.engine.models.parsers.types import (
|
|
14
|
+
LLMStructuredResponse,
|
|
15
|
+
PostProcessor,
|
|
16
|
+
TagParser,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
# Default mapping from tag path to the parser applied to matching elements of the
# converted response. Keys use dot-path notation for nested tags (e.g. "pre.code"
# appears to target a <code> element inside <pre>, as produced for fenced code).
# Caller-supplied `tag_parsers` are merged over these in LLMResponseParser.__init__.
# NOTE(review): the "" key presumably acts as the fallback for elements with no more
# specific entry; confirm against the parse implementation.
DEFAULT_TAG_PARSERS = {
    "pre.code": tp.code_block_parser,
    "p.code": tp.inline_code_parser,
    "p": tp.text_parser,
    "pre": tp.text_parser,
    "": tp.text_parser_keep_markup,
}

# Post-processors applied to the list of parsed blocks when the caller supplies none.
DEFAULT_POST_PROCESSORS = [merge_text_blocks]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _patch_tags_before_code_fences(response: str) -> str:
|
|
31
|
+
"""Patch to add a linebreak between a tag prior to a code block.
|
|
32
|
+
|
|
33
|
+
Marko conversion of MD->HTML has a quirk. If there is a case like
|
|
34
|
+
the following, it will not convert the code block at all:
|
|
35
|
+
|
|
36
|
+
...
|
|
37
|
+
</ending_tag>
|
|
38
|
+
```syntax
|
|
39
|
+
...
|
|
40
|
+
|
|
41
|
+
We want to find these cases and simply introduce an additional
|
|
42
|
+
line break.
|
|
43
|
+
"""
|
|
44
|
+
|
|
45
|
+
return response.replace(">\n```", ">\n\n```")
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class LLMResponseParser:
    """
    Parses Language Model (LLM) responses containing a mixture of Markdown and custom markup into structured data.

    The `LLMResponseParser` class facilitates the translation of LLM-generated responses, which may include
    Markdown and custom markup tags, into a structured format using ElementTree. It allows for customizable
    parsing behavior through the registration of tag-specific parsers and post-processors.

    ## Description

    The core functionality of this class enables LLMs to respond using Markdown along with any custom
    prompted markup specified by the system or task. The parsing process involves converting the Markdown
    and markup into an ElementTree, then processing each element using registered tag parsers to produce
    a list of structured `BaseModel` instances. Post-processors can further refine the structured response.

    ### Tag Parsers

    Tag parsers are responsible for handling specific markup tags within the LLM response. They can be
    registered with the parser using dot-path notation to manage hierarchical tag structures. This allows
    downstream tasks to customize how specific elements are processed into `BaseModel` instances.

    ### Post-Processors

    Post-processors are functions that operate on the list of parsed blocks to perform additional
    transformations or aggregations. They are applied after the initial parsing of the response.

    Attributes:
        tag_parsers (dict[str, TagParser]): A dictionary mapping tag paths to their corresponding `TagParser` instances.
        postprocessors (list[PostProcessor]): A list of post-processing functions to apply to the structured response.

    Example:
        ```python
        class CodeBlock(BaseModel):
            code: str
            syntax: Optional[str] = None


        class CodeBlockParser:
            def __call__(self, element: _Element) -> CodeBlock:
                # marko renders fenced code as <pre><code class="language-...">,
                # so strip the prefix to recover the bare syntax name.
                css_class = element.get("class") or ""
                syntax = css_class.removeprefix("language-") or None
                return CodeBlock(code=element.text, syntax=syntax)


        parser = LLMResponseParser(
            tag_parsers={
                "pre.code": CodeBlockParser(),
            }
        )

        out = parser.parse('```json\\n{"answer": 42}\\n```')
        print(out.parsed)
        # Output: [CodeBlock(code='{"answer": 42}\\n', syntax='json')]
        ```
    """

    tag_parsers: dict[str, TagParser]
    postprocessors: list[PostProcessor]

    def __init__(
        self,
        tag_parsers: Optional[dict[str, TagParser]] = None,
        postprocessors: Optional[list[PostProcessor]] = None,
    ):
        """
        Initializes the LLMResponseParser with optional tag parsers and post-processors.

        Args:
            tag_parsers (Optional[dict[str, TagParser]]): A dictionary mapping tag paths to `TagParser` instances.
                If provided, these parsers will be merged with the default tag parsers (provided entries win).
            postprocessors (Optional[list[PostProcessor]]): A list of post-processing functions to apply
                to the structured response. If not provided, a default post-processor `merge_text_blocks`
                is used. An explicitly provided list (even an empty one) replaces the default.

        Attributes:
            tag_parsers (dict[str, TagParser]): Initialized with default tag parsers, updated with any provided.
            postprocessors (list[PostProcessor]): Initialized with default post-processors or a copy of the provided list.
        """
        # Copy the defaults so per-instance registrations never mutate the
        # module-level DEFAULT_TAG_PARSERS mapping.
        self.tag_parsers = {**DEFAULT_TAG_PARSERS}
        if tag_parsers:
            self.tag_parsers.update(tag_parsers)

        # Copy a caller-supplied list too, mirroring the tag_parsers handling, so
        # later changes to the caller's list do not leak into this parser.
        self.postprocessors = [merge_text_blocks] if postprocessors is None else list(postprocessors)

    def lookup_parser(self, element: _Element) -> TagParser:
        """
        Resolves and retrieves the appropriate `TagParser` for a given XML element based on its tag hierarchy.

        The method constructs the dot-path lineage of the element's tags, starting from the root and moving
        towards the specific element. It then attempts to find the most specific matching `TagParser` by
        progressively dropping ancestors from the root side until a matching parser is found. The final
        candidate is the empty path "", which selects the default parser if one is registered.

        Args:
            element (_Element): The XML element for which to find the corresponding `TagParser`.

        Returns:
            TagParser: The `TagParser` instance that matches the element's tag path.

        Raises:
            KeyError: If no matching `TagParser` is found for the element's tag path.
        """
        # iterancestors() walks from the parent up to the root, so reverse it
        # to get a root-first lineage, then append the element itself.
        ancestors = [ancestor.tag for ancestor in element.iterancestors()][::-1]
        lineage = [*ancestors, element.tag]

        # Try every suffix of the lineage, longest (most specific) first, so
        # specific parsers take priority over general ones. The last suffix is
        # empty and joins to "", the default-parser entry of `tag_parsers`.
        for start in range(len(lineage) + 1):
            tag_path = ".".join(lineage[start:])
            if tag_path in self.tag_parsers:
                return self.tag_parsers[tag_path]

        raise KeyError(f"No tag parser registered for tag path {'.'.join(lineage)!r} or any of its suffixes")

    def postprocess(self, structured_response: LLMStructuredResponse) -> LLMStructuredResponse:
        """
        Applies post-processing functions to the structured response.

        If no post-processors are registered, the original structured response is returned.
        Otherwise, each post-processor is applied in sequence to transform the response.

        Args:
            structured_response (LLMStructuredResponse): The initial structured response to be post-processed.

        Returns:
            LLMStructuredResponse: The post-processed structured response.
        """
        if not self.postprocessors:
            return structured_response

        # Feed the output of each post-processor into the next, in list order.
        return reduce(lambda acc, func: func(acc), self.postprocessors, structured_response)

    def parse(self, md_response: str) -> LLMStructuredResponse:
        """
        Parses a Markdown-formatted LLM response into a structured `LLMStructuredResponse`.

        The parsing process involves converting the Markdown and custom markup into an XML tree,
        iterating over each element in a depth-first traversal to apply the appropriate
        `TagParser`, and then applying any registered post-processors to the resulting structured data.
        Comments and processing instructions in the markup are not parsed into blocks, but any text
        that follows them is kept.

        Args:
            md_response (str): The Markdown-formatted response from the LLM, potentially containing custom markup.

        Returns:
            LLMStructuredResponse: The structured representation of the parsed response, containing parsed blocks.

        Raises:
            etree.XMLSyntaxError: If lxml cannot build a document tree from the rendered markup. The parser
                runs with ``recover=True``, so most malformed markup is tolerated rather than raising.
            KeyError: Propagated from `lookup_parser` when an element matches no registered tag parser.
        """
        response = marko.convert(_patch_tags_before_code_fences(md_response))
        output = LLMStructuredResponse(response=md_response, markup=response)

        # Generate document tree
        parser = etree.HTMLParser(recover=True, remove_blank_text=True)
        root = etree.fromstring(response, parser=parser)
        tags = root.iter() if root is not None else []

        # Iterate over tags, depth first
        for element in tags:
            # The document wrappers (root and <body>) carry no content of their own.
            if element is root or element.tag == "body":
                continue

            # iter() also yields comments and processing instructions (e.g. an
            # HTML `<!-- ... -->` emitted by the LLM). Their `.tag` is not a string,
            # so they cannot be resolved to a dot-path parser; skip the node itself
            # but still fall through to keep any tail text that follows it.
            if isinstance(element.tag, str):
                parsed_block = self.lookup_parser(element)(element)

                # Make a quick check for dead text blocks, which
                # can happen with container tags like <pre>, <ul>, and <ol>.
                drop_block = isinstance(parsed_block, tp.TextBlock) and not parsed_block.text.strip()
                if not drop_block:
                    output.parsed.append(parsed_block)

            # Check tails -- inelegant, but they're always text.
            # Don't add the tail if it is just blank space.
            if element.tail and element.tail.strip():
                output.parsed.append(tp.TextBlock(text=element.tail))

        return self.postprocess(output)
|