data_designer-0.1.0-py3-none-any.whl
This diff represents the content of a publicly available package version released to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in the public registry; since 0.1.0 is the first release, every file appears as added.
- data_designer/__init__.py +15 -0
- data_designer/_version.py +34 -0
- data_designer/cli/README.md +236 -0
- data_designer/cli/__init__.py +6 -0
- data_designer/cli/commands/__init__.py +2 -0
- data_designer/cli/commands/list.py +130 -0
- data_designer/cli/commands/models.py +10 -0
- data_designer/cli/commands/providers.py +11 -0
- data_designer/cli/commands/reset.py +100 -0
- data_designer/cli/controllers/__init__.py +7 -0
- data_designer/cli/controllers/model_controller.py +246 -0
- data_designer/cli/controllers/provider_controller.py +317 -0
- data_designer/cli/forms/__init__.py +20 -0
- data_designer/cli/forms/builder.py +51 -0
- data_designer/cli/forms/field.py +180 -0
- data_designer/cli/forms/form.py +59 -0
- data_designer/cli/forms/model_builder.py +125 -0
- data_designer/cli/forms/provider_builder.py +76 -0
- data_designer/cli/main.py +44 -0
- data_designer/cli/repositories/__init__.py +8 -0
- data_designer/cli/repositories/base.py +39 -0
- data_designer/cli/repositories/model_repository.py +42 -0
- data_designer/cli/repositories/provider_repository.py +43 -0
- data_designer/cli/services/__init__.py +7 -0
- data_designer/cli/services/model_service.py +116 -0
- data_designer/cli/services/provider_service.py +111 -0
- data_designer/cli/ui.py +448 -0
- data_designer/cli/utils.py +47 -0
- data_designer/config/__init__.py +2 -0
- data_designer/config/analysis/column_profilers.py +89 -0
- data_designer/config/analysis/column_statistics.py +274 -0
- data_designer/config/analysis/dataset_profiler.py +60 -0
- data_designer/config/analysis/utils/errors.py +8 -0
- data_designer/config/analysis/utils/reporting.py +188 -0
- data_designer/config/base.py +68 -0
- data_designer/config/column_configs.py +354 -0
- data_designer/config/column_types.py +168 -0
- data_designer/config/config_builder.py +660 -0
- data_designer/config/data_designer_config.py +40 -0
- data_designer/config/dataset_builders.py +11 -0
- data_designer/config/datastore.py +151 -0
- data_designer/config/default_model_settings.py +123 -0
- data_designer/config/errors.py +19 -0
- data_designer/config/interface.py +54 -0
- data_designer/config/models.py +231 -0
- data_designer/config/preview_results.py +32 -0
- data_designer/config/processors.py +41 -0
- data_designer/config/sampler_constraints.py +51 -0
- data_designer/config/sampler_params.py +604 -0
- data_designer/config/seed.py +145 -0
- data_designer/config/utils/code_lang.py +83 -0
- data_designer/config/utils/constants.py +313 -0
- data_designer/config/utils/errors.py +19 -0
- data_designer/config/utils/info.py +88 -0
- data_designer/config/utils/io_helpers.py +273 -0
- data_designer/config/utils/misc.py +81 -0
- data_designer/config/utils/numerical_helpers.py +28 -0
- data_designer/config/utils/type_helpers.py +100 -0
- data_designer/config/utils/validation.py +336 -0
- data_designer/config/utils/visualization.py +427 -0
- data_designer/config/validator_params.py +96 -0
- data_designer/engine/__init__.py +2 -0
- data_designer/engine/analysis/column_profilers/base.py +55 -0
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +160 -0
- data_designer/engine/analysis/column_profilers/registry.py +20 -0
- data_designer/engine/analysis/column_statistics.py +142 -0
- data_designer/engine/analysis/dataset_profiler.py +125 -0
- data_designer/engine/analysis/errors.py +7 -0
- data_designer/engine/analysis/utils/column_statistics_calculations.py +209 -0
- data_designer/engine/analysis/utils/judge_score_processing.py +128 -0
- data_designer/engine/column_generators/__init__.py +2 -0
- data_designer/engine/column_generators/generators/__init__.py +2 -0
- data_designer/engine/column_generators/generators/base.py +61 -0
- data_designer/engine/column_generators/generators/expression.py +63 -0
- data_designer/engine/column_generators/generators/llm_generators.py +172 -0
- data_designer/engine/column_generators/generators/samplers.py +75 -0
- data_designer/engine/column_generators/generators/seed_dataset.py +149 -0
- data_designer/engine/column_generators/generators/validation.py +147 -0
- data_designer/engine/column_generators/registry.py +56 -0
- data_designer/engine/column_generators/utils/errors.py +13 -0
- data_designer/engine/column_generators/utils/judge_score_factory.py +57 -0
- data_designer/engine/column_generators/utils/prompt_renderer.py +98 -0
- data_designer/engine/configurable_task.py +82 -0
- data_designer/engine/dataset_builders/artifact_storage.py +181 -0
- data_designer/engine/dataset_builders/column_wise_builder.py +287 -0
- data_designer/engine/dataset_builders/errors.py +13 -0
- data_designer/engine/dataset_builders/multi_column_configs.py +44 -0
- data_designer/engine/dataset_builders/utils/__init__.py +2 -0
- data_designer/engine/dataset_builders/utils/concurrency.py +184 -0
- data_designer/engine/dataset_builders/utils/config_compiler.py +60 -0
- data_designer/engine/dataset_builders/utils/dag.py +56 -0
- data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +190 -0
- data_designer/engine/dataset_builders/utils/errors.py +13 -0
- data_designer/engine/errors.py +49 -0
- data_designer/engine/model_provider.py +75 -0
- data_designer/engine/models/__init__.py +2 -0
- data_designer/engine/models/errors.py +308 -0
- data_designer/engine/models/facade.py +225 -0
- data_designer/engine/models/litellm_overrides.py +162 -0
- data_designer/engine/models/parsers/__init__.py +2 -0
- data_designer/engine/models/parsers/errors.py +34 -0
- data_designer/engine/models/parsers/parser.py +236 -0
- data_designer/engine/models/parsers/postprocessors.py +93 -0
- data_designer/engine/models/parsers/tag_parsers.py +60 -0
- data_designer/engine/models/parsers/types.py +82 -0
- data_designer/engine/models/recipes/base.py +79 -0
- data_designer/engine/models/recipes/response_recipes.py +291 -0
- data_designer/engine/models/registry.py +118 -0
- data_designer/engine/models/usage.py +75 -0
- data_designer/engine/models/utils.py +38 -0
- data_designer/engine/processing/ginja/__init__.py +2 -0
- data_designer/engine/processing/ginja/ast.py +64 -0
- data_designer/engine/processing/ginja/environment.py +461 -0
- data_designer/engine/processing/ginja/exceptions.py +54 -0
- data_designer/engine/processing/ginja/record.py +30 -0
- data_designer/engine/processing/gsonschema/__init__.py +2 -0
- data_designer/engine/processing/gsonschema/exceptions.py +8 -0
- data_designer/engine/processing/gsonschema/schema_transformers.py +81 -0
- data_designer/engine/processing/gsonschema/types.py +8 -0
- data_designer/engine/processing/gsonschema/validators.py +143 -0
- data_designer/engine/processing/processors/base.py +15 -0
- data_designer/engine/processing/processors/drop_columns.py +46 -0
- data_designer/engine/processing/processors/registry.py +20 -0
- data_designer/engine/processing/utils.py +120 -0
- data_designer/engine/registry/base.py +97 -0
- data_designer/engine/registry/data_designer_registry.py +37 -0
- data_designer/engine/registry/errors.py +10 -0
- data_designer/engine/resources/managed_dataset_generator.py +35 -0
- data_designer/engine/resources/managed_dataset_repository.py +194 -0
- data_designer/engine/resources/managed_storage.py +63 -0
- data_designer/engine/resources/resource_provider.py +46 -0
- data_designer/engine/resources/seed_dataset_data_store.py +66 -0
- data_designer/engine/sampling_gen/column.py +89 -0
- data_designer/engine/sampling_gen/constraints.py +95 -0
- data_designer/engine/sampling_gen/data_sources/base.py +214 -0
- data_designer/engine/sampling_gen/data_sources/errors.py +10 -0
- data_designer/engine/sampling_gen/data_sources/sources.py +342 -0
- data_designer/engine/sampling_gen/entities/__init__.py +2 -0
- data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
- data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +64 -0
- data_designer/engine/sampling_gen/entities/email_address_utils.py +169 -0
- data_designer/engine/sampling_gen/entities/errors.py +8 -0
- data_designer/engine/sampling_gen/entities/national_id_utils.py +100 -0
- data_designer/engine/sampling_gen/entities/person.py +142 -0
- data_designer/engine/sampling_gen/entities/phone_number.py +122 -0
- data_designer/engine/sampling_gen/errors.py +24 -0
- data_designer/engine/sampling_gen/generator.py +121 -0
- data_designer/engine/sampling_gen/jinja_utils.py +60 -0
- data_designer/engine/sampling_gen/people_gen.py +203 -0
- data_designer/engine/sampling_gen/person_constants.py +54 -0
- data_designer/engine/sampling_gen/schema.py +143 -0
- data_designer/engine/sampling_gen/schema_builder.py +59 -0
- data_designer/engine/sampling_gen/utils.py +40 -0
- data_designer/engine/secret_resolver.py +80 -0
- data_designer/engine/validators/__init__.py +17 -0
- data_designer/engine/validators/base.py +36 -0
- data_designer/engine/validators/local_callable.py +34 -0
- data_designer/engine/validators/python.py +245 -0
- data_designer/engine/validators/remote.py +83 -0
- data_designer/engine/validators/sql.py +60 -0
- data_designer/errors.py +5 -0
- data_designer/essentials/__init__.py +137 -0
- data_designer/interface/__init__.py +2 -0
- data_designer/interface/data_designer.py +351 -0
- data_designer/interface/errors.py +16 -0
- data_designer/interface/results.py +55 -0
- data_designer/logging.py +161 -0
- data_designer/plugin_manager.py +83 -0
- data_designer/plugins/__init__.py +6 -0
- data_designer/plugins/errors.py +10 -0
- data_designer/plugins/plugin.py +69 -0
- data_designer/plugins/registry.py +86 -0
- data_designer-0.1.0.dist-info/METADATA +173 -0
- data_designer-0.1.0.dist-info/RECORD +177 -0
- data_designer-0.1.0.dist-info/WHEEL +4 -0
- data_designer-0.1.0.dist-info/entry_points.txt +2 -0
- data_designer-0.1.0.dist-info/licenses/LICENSE +201 -0
data_designer/engine/dataset_builders/utils/dataset_batch_manager.py (@@ -0,0 +1,190 @@)

```python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import logging
import shutil
from pathlib import Path
from typing import Callable, Container, Iterator

import pandas as pd
import pyarrow.parquet as pq

from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage, BatchStage
from data_designer.engine.dataset_builders.utils.errors import DatasetBatchManagementError

logger = logging.getLogger(__name__)


class DatasetBatchManager:
    def __init__(self, artifact_storage: ArtifactStorage):
        self._buffer: list[dict] = []
        self._current_batch_number = 0
        self._num_records_list: list[int] | None = None
        self._buffer_size: int | None = None
        self.artifact_storage = artifact_storage

    @property
    def num_batches(self) -> int:
        if self._num_records_list is None:
            return 0
        return len(self._num_records_list)

    @property
    def num_records_batch(self) -> int:
        if self._num_records_list is None or self._current_batch_number >= len(self._num_records_list):
            raise DatasetBatchManagementError("🛑 Invalid batch number or num_records_list not set.")
        return self._num_records_list[self._current_batch_number]

    @property
    def num_records_list(self) -> list[int]:
        if self._num_records_list is None:
            raise DatasetBatchManagementError("🛑 `num_records_list` is not set. Call start() first.")
        return self._num_records_list

    @property
    def num_records_in_buffer(self) -> int:
        return len(self._buffer)

    @property
    def buffer_is_empty(self) -> bool:
        return len(self._buffer) == 0

    @property
    def buffer_size(self) -> int:
        if self._buffer_size is None:
            raise DatasetBatchManagementError("🛑 `buffer_size` is not set. Call start() first.")
        return self._buffer_size

    def add_record(self, record: dict) -> None:
        self.add_records([record])

    def add_records(self, records: list[dict]) -> None:
        self._buffer.extend(records)
        if len(self._buffer) > self.buffer_size:
            raise DatasetBatchManagementError(
                f"🛑 Buffer size exceeded. Current: {len(self._buffer)}, Max: {self.buffer_size}. "
                "Flush the batch before adding more records."
            )

    def drop_records(self, index: Container[int]) -> None:
        self._buffer = [record for i, record in enumerate(self._buffer) if i not in index]

    def finish_batch(self, on_complete: Callable[[Path], None] | None = None) -> Path:
        """Finish the batch by moving the results from the partial results path to the final parquet folder.

        Returns:
            The path to the written parquet file.
        """
        if self._current_batch_number >= self.num_batches:
            raise DatasetBatchManagementError("🛑 All batches have been processed.")

        if not self.write():
            raise DatasetBatchManagementError("🛑 Batch finished without any results to write.")

        final_file_path = self.artifact_storage.move_partial_result_to_final_file_path(self._current_batch_number)

        self.artifact_storage.write_metadata(
            {
                "target_num_records": sum(self.num_records_list),
                "total_num_batches": self.num_batches,
                "buffer_size": self._buffer_size,
                "schema": {field.name: str(field.type) for field in pq.read_schema(final_file_path)},
                "file_paths": [str(f) for f in sorted(self.artifact_storage.final_dataset_path.glob("*.parquet"))],
                "num_records": self.num_records_list[: self._current_batch_number + 1],
                "num_completed_batches": self._current_batch_number + 1,
                "dataset_name": self.artifact_storage.dataset_name,
            }
        )
        self._current_batch_number += 1
        self._buffer: list[dict] = []

        if on_complete:
            on_complete(final_file_path)

        return final_file_path

    def finish(self) -> None:
        """Finish the dataset writing process by deleting the partial results path if it exists and is empty."""

        # If the partial results path is empty, delete it.
        if not any(self.artifact_storage.partial_results_path.iterdir()):
            self.artifact_storage.partial_results_path.rmdir()

        # Otherwise, log a warning, since existing partial results means the dataset is not complete.
        else:
            logger.warning("⚠️ Dataset writing finished with partial results.")

        self.reset()

    def get_current_batch_number(self) -> int:
        return self._current_batch_number

    def get_current_batch(self, *, as_dataframe: bool = False) -> pd.DataFrame | list[dict]:
        if as_dataframe:
            return pd.DataFrame(self._buffer)
        return self._buffer

    def iter_current_batch(self) -> Iterator[tuple[int, dict]]:
        for i, record in enumerate(self._buffer):
            yield i, record

    def reset(self, delete_files: bool = False) -> None:
        self._current_batch_number = 0
        self._buffer: list[dict] = []
        if delete_files:
            for dir_path in [
                self.artifact_storage.final_dataset_path,
                self.artifact_storage.partial_results_path,
                self.artifact_storage.dropped_columns_dataset_path,
                self.artifact_storage.base_dataset_path,
            ]:
                if dir_path.exists():
                    try:
                        shutil.rmtree(dir_path)
                    except OSError as e:
                        raise DatasetBatchManagementError(f"🛑 Failed to delete directory {dir_path}: {e}")

    def start(self, *, num_records: int, buffer_size: int) -> None:
        if num_records <= 0:
            raise DatasetBatchManagementError("🛑 num_records must be positive.")
        if buffer_size <= 0:
            raise DatasetBatchManagementError("🛑 buffer_size must be positive.")

        self._buffer_size = buffer_size
        self._num_records_list = [buffer_size] * (num_records // buffer_size)
        if remaining_records := num_records % buffer_size:
            self._num_records_list.append(remaining_records)
        self.reset()

    def write(self) -> Path | None:
        """Write the current batch to a parquet file.

        This method always writes results to the partial results path.

        Returns:
            The path to the written parquet file. If the buffer is empty, returns None.
        """
        if len(self._buffer) == 0:
            return None
        try:
            file_path = self.artifact_storage.write_batch_to_parquet_file(
                batch_number=self._current_batch_number,
                dataframe=pd.DataFrame(self._buffer),
                batch_stage=BatchStage.PARTIAL_RESULT,
            )
            return file_path
        except Exception as e:
            raise DatasetBatchManagementError(f"🛑 Failed to write batch {self._current_batch_number}: {e}")

    def update_record(self, index: int, record: dict) -> None:
        if index < 0 or index >= len(self._buffer):
            raise IndexError(f"🛑 Index {index} is out of bounds for buffer of size {len(self._buffer)}.")
        self._buffer[index] = record

    def update_records(self, records: list[dict]) -> None:
        if len(records) != len(self._buffer):
            raise DatasetBatchManagementError(
                f"🛑 Number of records to update ({len(records)}) must match "
                f"the number of records in the buffer ({len(self._buffer)})."
            )
        self._buffer = records
```
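Taken together, `start()`, `add_records()`, `finish_batch()`, and `finish()` define a strict life cycle: `start()` splits the requested record count into `num_records // buffer_size` full batches plus an optional remainder batch, the buffer may never exceed `buffer_size`, and each flush promotes a partial parquet file into the final dataset folder. A minimal sketch of that loop follows; the `ArtifactStorage(...)` construction is a placeholder, since its real signature lives in `artifact_storage.py` and is not shown in this diff.

```python
# Sketch only: how DatasetBatchManager appears intended to be driven.
from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
from data_designer.engine.dataset_builders.utils.dataset_batch_manager import DatasetBatchManager

storage = ArtifactStorage(...)  # hypothetical arguments; see artifact_storage.py
manager = DatasetBatchManager(storage)

# 25 records with buffer_size=10 -> num_records_list == [10, 10, 5]
manager.start(num_records=25, buffer_size=10)

for _ in range(manager.num_batches):
    while manager.num_records_in_buffer < manager.num_records_batch:
        manager.add_record({"text": "..."})
    # Writes the buffer to the partial results path, moves it into the
    # final dataset folder, and updates the metadata file.
    manager.finish_batch()

manager.finish()  # removes the now-empty partial results directory
```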
data_designer/engine/dataset_builders/utils/errors.py (@@ -0,0 +1,13 @@)

```python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from data_designer.engine.errors import DataDesignerError


class DatasetBatchManagementError(DataDesignerError): ...


class ConfigCompilationError(DataDesignerError): ...


class DAGCircularDependencyError(DataDesignerError): ...
```
data_designer/engine/errors.py (@@ -0,0 +1,49 @@)

```python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from pydantic import BaseModel, Field

from ..errors import DataDesignerError


class DataDesignerRuntimeError(DataDesignerError): ...


class UnknownModelAliasError(DataDesignerError): ...


class UnknownProviderError(DataDesignerError): ...


class NoModelProvidersError(DataDesignerError): ...


class SecretResolutionError(DataDesignerError): ...


class RemoteValidationSchemaError(DataDesignerError): ...


class LocalCallableValidationError(DataDesignerError): ...


class ErrorTrap(BaseModel):
    error_count: int = 0
    task_errors: dict[str, int] = Field(default_factory=dict)

    def _track_error(self, error: DataDesignerError) -> None:
        """Track a specific error type."""
        error_type = type(error).__name__
        if error_type not in self.task_errors:
            self.task_errors[error_type] = 0
        self.task_errors[error_type] += 1

    def handle_error(self, error: Exception) -> None:
        self.error_count += 1

        if not isinstance(error, DataDesignerError):
            error = DataDesignerError(str(error))

        self._track_error(error)
```
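`ErrorTrap` is a small counter: `handle_error` bumps a global count and, after wrapping anything that is not already a `DataDesignerError`, tallies per-type counts keyed by class name. An illustrative sketch (not part of the package):

```python
from data_designer.engine.errors import ErrorTrap, UnknownProviderError

trap = ErrorTrap()
trap.handle_error(UnknownProviderError("no provider named 'foo'"))
trap.handle_error(ValueError("bad input"))  # wrapped into DataDesignerError first

assert trap.error_count == 2
assert trap.task_errors == {"UnknownProviderError": 1, "DataDesignerError": 1}
```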
data_designer/engine/model_provider.py (@@ -0,0 +1,75 @@)

```python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from functools import cached_property

from pydantic import BaseModel, field_validator, model_validator
from typing_extensions import Self

from data_designer.config.models import ModelProvider
from data_designer.engine.errors import NoModelProvidersError, UnknownProviderError


class ModelProviderRegistry(BaseModel):
    providers: list[ModelProvider]
    default: str | None = None

    @field_validator("providers", mode="after")
    @classmethod
    def validate_providers_not_empty(cls, v: list[ModelProvider]) -> list[ModelProvider]:
        if len(v) == 0:
            raise ValueError("At least one model provider must be defined")
        return v

    @field_validator("providers", mode="after")
    @classmethod
    def validate_providers_have_unique_names(cls, v: list[ModelProvider]) -> list[ModelProvider]:
        names = set()
        dupes = set()
        for provider in v:
            if provider.name in names:
                dupes.add(provider.name)
            names.add(provider.name)

        if len(dupes) > 0:
            raise ValueError(f"Model providers must have unique names, found duplicates: {dupes}")
        return v

    @model_validator(mode="after")
    def check_implicit_default(self) -> Self:
        if self.default is None and len(self.providers) != 1:
            raise ValueError("A default provider must be specified if multiple model providers are defined")
        return self

    @model_validator(mode="after")
    def check_default_exists(self) -> Self:
        if self.default and self.default not in self._providers_dict:
            raise ValueError(f"Specified default {self.default!r} not found in providers list")
        return self

    def get_default_provider_name(self) -> str:
        return self.default or self.providers[0].name

    @cached_property
    def _providers_dict(self) -> dict[str, ModelProvider]:
        return {p.name: p for p in self.providers}

    def get_provider(self, name: str | None) -> ModelProvider:
        if name is None:
            name = self.get_default_provider_name()

        try:
            return self._providers_dict[name]
        except KeyError:
            raise UnknownProviderError(f"No provider named {name!r} registered")


def resolve_model_provider_registry(
    model_providers: list[ModelProvider], default_provider_name: str | None = None
) -> ModelProviderRegistry:
    if len(model_providers) == 0:
        raise NoModelProvidersError("At least one model provider must be defined")
    return ModelProviderRegistry(
        providers=model_providers,
        default=default_provider_name or model_providers[0].name,
    )
```
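The validators encode an asymmetry worth noting: a single provider may serve as an implicit default, but two or more providers require an explicit `default`; `resolve_model_provider_registry` satisfies that rule by falling back to the first provider's name. A sketch of the intended usage follows; any `ModelProvider` constructor arguments beyond `name` are assumptions, since its fields are defined in `data_designer/config/models.py` and not shown in this diff.

```python
from data_designer.config.models import ModelProvider
from data_designer.engine.model_provider import ModelProviderRegistry, resolve_model_provider_registry

# Assumption: ModelProvider can be built from a name plus endpoint details;
# any other required fields are elided here.
openai = ModelProvider(name="openai")
nim = ModelProvider(name="nim")

registry = resolve_model_provider_registry([openai, nim])  # default becomes "openai"
registry.get_provider(None)      # -> the default provider
registry.get_provider("nim")     # -> explicit lookup
# registry.get_provider("oops")  # would raise UnknownProviderError

# Direct construction with multiple providers and no default fails validation:
# ModelProviderRegistry(providers=[openai, nim])  # -> pydantic ValidationError
```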
data_designer/engine/models/errors.py (@@ -0,0 +1,308 @@)

```python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import logging
from collections.abc import Callable
from functools import wraps
from typing import Any

from litellm.exceptions import (
    APIConnectionError,
    APIError,
    AuthenticationError,
    BadRequestError,
    ContextWindowExceededError,
    InternalServerError,
    NotFoundError,
    PermissionDeniedError,
    RateLimitError,
    Timeout,
    UnprocessableEntityError,
    UnsupportedParamsError,
)
from pydantic import BaseModel

from data_designer.engine.errors import DataDesignerError

logger = logging.getLogger(__name__)


def get_exception_primary_cause(exception: BaseException) -> BaseException:
    """Returns the primary cause of an exception by walking backwards.

    This recursive walkback halts when it arrives at an exception which
    has no provided __cause__ (i.e. __cause__ is None).

    Args:
        exception (Exception): An exception to start from.

    Raises:
        RecursionError: if for some reason exceptions have circular
            dependencies (seems impossible in practice).
    """
    if exception.__cause__ is None:
        return exception
    return get_exception_primary_cause(exception.__cause__)


class GenerationValidationFailureError(Exception): ...


class ModelRateLimitError(DataDesignerError): ...


class ModelTimeoutError(DataDesignerError): ...


class ModelContextWindowExceededError(DataDesignerError): ...


class ModelAuthenticationError(DataDesignerError): ...


class ModelPermissionDeniedError(DataDesignerError): ...


class ModelNotFoundError(DataDesignerError): ...


class ModelUnsupportedParamsError(DataDesignerError): ...


class ModelBadRequestError(DataDesignerError): ...


class ModelInternalServerError(DataDesignerError): ...


class ModelAPIError(DataDesignerError): ...


class ModelUnprocessableEntityError(DataDesignerError): ...


class ModelAPIConnectionError(DataDesignerError): ...


class ModelStructuredOutputError(DataDesignerError): ...


class ModelGenerationValidationFailureError(DataDesignerError): ...


class FormattedLLMErrorMessage(BaseModel):
    cause: str
    solution: str

    def __str__(self) -> str:
        return "\n".join(
            [
                " |----------",
                f" | Cause: {self.cause}",
                f" | Solution: {self.solution}",
                " |----------",
            ]
        )


def handle_llm_exceptions(
    exception: Exception, model_name: str, model_provider_name: str, purpose: str | None = None
) -> None:
    """Handle LLM-related exceptions and convert them to appropriate DataDesignerError errors.

    This function centralizes the exception handling logic for LLM operations,
    making it reusable across different contexts.

    Args:
        exception: The exception that was raised.
        model_name: Name of the model that was being used.
        model_provider_name: Name of the model provider that was being used.
        purpose: The purpose of the model usage, shown as context in the error message.

    Raises:
        DataDesignerError: A more user-friendly error with an appropriate error type and message.
    """
    purpose = purpose or "running generation"
    authentication_error = FormattedLLMErrorMessage(
        cause=f"The API key provided for model {model_name!r} was found to be invalid or expired while {purpose}.",
        solution=f"Verify your API key and update it in your settings for model provider {model_provider_name!r}.",
    )
    err_msg_parser = DownstreamLLMExceptionMessageParser(model_name, model_provider_name, purpose)
    match exception:
        # Common errors that can come from LiteLLM
        case APIError():
            raise err_msg_parser.parse_api_error(exception, authentication_error) from None

        case APIConnectionError():
            raise ModelAPIConnectionError(
                FormattedLLMErrorMessage(
                    cause=f"Connection to model {model_name!r} hosted on model provider {model_provider_name!r} failed while {purpose}.",
                    solution="Check your network/proxy/firewall settings.",
                )
            ) from None

        case AuthenticationError():
            raise ModelAuthenticationError(authentication_error) from None

        case ContextWindowExceededError():
            raise err_msg_parser.parse_context_window_exceeded_error(exception) from None

        case UnsupportedParamsError():
            raise ModelUnsupportedParamsError(
                FormattedLLMErrorMessage(
                    cause=f"One or more of the parameters you provided were found to be unsupported by model {model_name!r} while {purpose}.",
                    solution=f"Review the documentation for model provider {model_provider_name!r} and adjust your request.",
                )
            ) from None

        case BadRequestError():
            raise err_msg_parser.parse_bad_request_error(exception) from None

        case InternalServerError():
            raise ModelInternalServerError(
                FormattedLLMErrorMessage(
                    cause=f"Model {model_name!r} is currently experiencing internal server issues while {purpose}.",
                    solution=f"Try again in a few moments. Check with your model provider {model_provider_name!r} if the issue persists.",
                )
            ) from None

        case NotFoundError():
            raise ModelNotFoundError(
                FormattedLLMErrorMessage(
                    cause=f"The specified model {model_name!r} could not be found while {purpose}.",
                    solution=f"Check that the model name is correct and supported by your model provider {model_provider_name!r}, then try again.",
                )
            ) from None

        case PermissionDeniedError():
            raise ModelPermissionDeniedError(
                FormattedLLMErrorMessage(
                    cause=f"Your API key was found to lack the necessary permissions to use model {model_name!r} while {purpose}.",
                    solution=f"Use an API key with the right permissions for the model, or switch to a model your API key can access in model provider {model_provider_name!r}.",
                )
            ) from None

        case RateLimitError():
            raise ModelRateLimitError(
                FormattedLLMErrorMessage(
                    cause=f"You have exceeded the rate limit for model {model_name!r} while {purpose}.",
                    solution="Wait and try again in a few moments.",
                )
            ) from None

        case Timeout():
            raise ModelTimeoutError(
                FormattedLLMErrorMessage(
                    cause=f"The request to model {model_name!r} timed out while {purpose}.",
                    solution="Check your connection and try again. You may need to increase the timeout setting for the model.",
                )
            ) from None

        case UnprocessableEntityError():
            raise ModelUnprocessableEntityError(
                FormattedLLMErrorMessage(
                    cause=f"The request to model {model_name!r} failed despite a correct request format while {purpose}.",
                    solution="This is most likely temporary. Try again in a few moments.",
                )
            ) from None

        # Parsing and validation errors
        case GenerationValidationFailureError():
            raise ModelGenerationValidationFailureError(
                FormattedLLMErrorMessage(
                    cause=f"The provided output schema could not be parsed from model {model_name!r} responses while {purpose}.",
                    solution=(
                        "This is most likely temporary as we make additional attempts. If you continue to see more of this, "
                        "simplify or modify the output schema for structured output and try again. If you are attempting "
                        "token-intensive tasks like generations with high reasoning effort, ensure that max_tokens in the "
                        "model config is high enough to reach completion."
                    ),
                )
            ) from None

        case DataDesignerError():
            raise exception from None

        case _:
            raise DataDesignerError(
                FormattedLLMErrorMessage(
                    cause=f"An unexpected error occurred while {purpose}.",
                    solution=f"Review the stack trace for more details: {exception}",
                )
            ) from exception


def catch_llm_exceptions(func: Callable) -> Callable:
    """This decorator should be used on any `ModelFacade` method that could potentially raise
    exceptions that should turn into upstream user-facing errors.
    """

    @wraps(func)
    def wrapper(model_facade: Any, *args, **kwargs):
        try:
            return func(model_facade, *args, **kwargs)
        except Exception as e:
            logger.debug(
                "\n".join(
                    [
                        "",
                        "|----------",
                        f"| Caught an exception downstream of type {type(e)!r}. Re-raising it below as a custom error with more context.",
                        "|----------",
                    ]
                ),
                exc_info=True,
                stack_info=True,
            )
            handle_llm_exceptions(
                e, model_facade.model_name, model_facade.model_provider_name, purpose=kwargs.get("purpose")
            )

    return wrapper


class DownstreamLLMExceptionMessageParser:
    def __init__(self, model_name: str, model_provider_name: str, purpose: str):
        self.model_name = model_name
        self.model_provider_name = model_provider_name
        self.purpose = purpose

    def parse_bad_request_error(self, exception: BadRequestError) -> DataDesignerError:
        err_msg = FormattedLLMErrorMessage(
            cause=f"The request for model {self.model_name!r} was found to be malformed or missing required parameters while {self.purpose}.",
            solution="Check your request parameters and try again.",
        )
        if "is not a multimodal model" in str(exception):
            err_msg = FormattedLLMErrorMessage(
                cause=f"Model {self.model_name!r} is not a multimodal model, but it looks like you are trying to provide multimodal context while {self.purpose}.",
                solution="Check your request parameters and try again.",
            )
        return ModelBadRequestError(err_msg)

    def parse_context_window_exceeded_error(self, exception: ContextWindowExceededError) -> DataDesignerError:
        cause = f"The input data for model {self.model_name!r} was found to exceed its supported context width while {self.purpose}."
        try:
            if "OpenAIException - This model's maximum context length is " in str(exception):
                openai_exception_cause = (
                    str(exception).split("OpenAIException - ")[1].split("\n")[0].split(" Please reduce ")[0]
                )
                cause = f"{cause} {openai_exception_cause}"
        except Exception:
            # Best-effort enrichment only; fall back to the generic cause.
            pass
        return ModelContextWindowExceededError(
            FormattedLLMErrorMessage(
                cause=cause,
                solution="Check the model's supported max context width. Adjust the length of your input along with completions and try again.",
            )
        )

    def parse_api_error(
        self, exception: APIError, auth_error_msg: FormattedLLMErrorMessage
    ) -> DataDesignerError:
        if "Error code: 403" in str(exception):
            return ModelAuthenticationError(auth_error_msg)

        return ModelAPIError(
            FormattedLLMErrorMessage(
                cause=f"An unexpected API error occurred with model {self.model_name!r} while {self.purpose}.",
                solution=f"Try again in a few moments. Check with your model provider {self.model_provider_name!r} if the issue persists.",
            )
        )
```
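`catch_llm_exceptions` assumes the wrapped method's first argument exposes `model_name` and `model_provider_name` attributes, and it forwards a `purpose` keyword argument, when present, into the error message. A hypothetical facade method, not the actual `ModelFacade` API from `facade.py`, might look like this:

```python
from data_designer.engine.models.errors import catch_llm_exceptions


class MyFacade:  # stand-in for ModelFacade; attribute names match what wrapper() reads
    model_name = "gpt-4o-mini"
    model_provider_name = "openai"

    @catch_llm_exceptions
    def generate(self, prompt: str, *, purpose: str | None = None) -> str:
        # Anything raised here (e.g. a litellm RateLimitError) is logged at
        # DEBUG with the traceback, then re-raised by handle_llm_exceptions
        # as a DataDesignerError subclass such as ModelRateLimitError.
        raise RuntimeError("downstream failure")


MyFacade().generate("hi", purpose="testing error handling")  # -> raises DataDesignerError
```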