data-designer 0.3.4__py3-none-any.whl → 0.3.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/__init__.py +2 -0
- data_designer/_version.py +2 -2
- data_designer/cli/__init__.py +2 -0
- data_designer/cli/commands/download.py +2 -0
- data_designer/cli/commands/list.py +2 -0
- data_designer/cli/commands/models.py +2 -0
- data_designer/cli/commands/providers.py +2 -0
- data_designer/cli/commands/reset.py +2 -0
- data_designer/cli/controllers/__init__.py +2 -0
- data_designer/cli/controllers/download_controller.py +2 -0
- data_designer/cli/controllers/model_controller.py +6 -1
- data_designer/cli/controllers/provider_controller.py +6 -1
- data_designer/cli/forms/__init__.py +2 -0
- data_designer/cli/forms/builder.py +2 -0
- data_designer/cli/forms/field.py +2 -0
- data_designer/cli/forms/form.py +2 -0
- data_designer/cli/forms/model_builder.py +2 -0
- data_designer/cli/forms/provider_builder.py +2 -0
- data_designer/cli/main.py +2 -0
- data_designer/cli/repositories/__init__.py +2 -0
- data_designer/cli/repositories/base.py +2 -0
- data_designer/cli/repositories/model_repository.py +2 -0
- data_designer/cli/repositories/persona_repository.py +2 -0
- data_designer/cli/repositories/provider_repository.py +2 -0
- data_designer/cli/services/__init__.py +2 -0
- data_designer/cli/services/download_service.py +2 -0
- data_designer/cli/services/model_service.py +2 -0
- data_designer/cli/services/provider_service.py +2 -0
- data_designer/cli/ui.py +2 -0
- data_designer/cli/utils.py +2 -0
- data_designer/config/analysis/column_profilers.py +2 -0
- data_designer/config/analysis/column_statistics.py +8 -5
- data_designer/config/analysis/dataset_profiler.py +9 -3
- data_designer/config/analysis/utils/errors.py +2 -0
- data_designer/config/analysis/utils/reporting.py +7 -3
- data_designer/config/column_configs.py +77 -7
- data_designer/config/column_types.py +33 -36
- data_designer/config/dataset_builders.py +2 -0
- data_designer/config/default_model_settings.py +1 -0
- data_designer/config/errors.py +2 -0
- data_designer/config/exports.py +2 -0
- data_designer/config/interface.py +3 -2
- data_designer/config/models.py +7 -2
- data_designer/config/preview_results.py +7 -3
- data_designer/config/processors.py +2 -0
- data_designer/config/run_config.py +2 -0
- data_designer/config/sampler_constraints.py +2 -0
- data_designer/config/sampler_params.py +7 -2
- data_designer/config/seed.py +2 -0
- data_designer/config/seed_source.py +7 -2
- data_designer/config/seed_source_types.py +2 -0
- data_designer/config/utils/constants.py +2 -0
- data_designer/config/utils/errors.py +2 -0
- data_designer/config/utils/info.py +2 -0
- data_designer/config/utils/io_helpers.py +8 -3
- data_designer/config/utils/misc.py +2 -2
- data_designer/config/utils/numerical_helpers.py +2 -0
- data_designer/config/utils/type_helpers.py +2 -0
- data_designer/config/utils/visualization.py +8 -4
- data_designer/config/validator_params.py +2 -0
- data_designer/engine/analysis/column_profilers/base.py +9 -8
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +15 -19
- data_designer/engine/analysis/column_profilers/registry.py +2 -0
- data_designer/engine/analysis/column_statistics.py +5 -2
- data_designer/engine/analysis/dataset_profiler.py +12 -9
- data_designer/engine/analysis/errors.py +2 -0
- data_designer/engine/analysis/utils/column_statistics_calculations.py +7 -4
- data_designer/engine/analysis/utils/judge_score_processing.py +7 -3
- data_designer/engine/column_generators/generators/base.py +26 -14
- data_designer/engine/column_generators/generators/embedding.py +4 -11
- data_designer/engine/column_generators/generators/expression.py +7 -16
- data_designer/engine/column_generators/generators/llm_completion.py +11 -37
- data_designer/engine/column_generators/generators/samplers.py +8 -14
- data_designer/engine/column_generators/generators/seed_dataset.py +9 -15
- data_designer/engine/column_generators/generators/validation.py +8 -20
- data_designer/engine/column_generators/registry.py +2 -0
- data_designer/engine/column_generators/utils/errors.py +2 -0
- data_designer/engine/column_generators/utils/generator_classification.py +2 -0
- data_designer/engine/column_generators/utils/judge_score_factory.py +2 -0
- data_designer/engine/column_generators/utils/prompt_renderer.py +4 -2
- data_designer/engine/compiler.py +3 -6
- data_designer/engine/configurable_task.py +12 -13
- data_designer/engine/dataset_builders/artifact_storage.py +87 -8
- data_designer/engine/dataset_builders/column_wise_builder.py +32 -34
- data_designer/engine/dataset_builders/errors.py +2 -0
- data_designer/engine/dataset_builders/multi_column_configs.py +2 -0
- data_designer/engine/dataset_builders/utils/config_compiler.py +2 -0
- data_designer/engine/dataset_builders/utils/dag.py +7 -2
- data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +9 -6
- data_designer/engine/dataset_builders/utils/errors.py +2 -0
- data_designer/engine/errors.py +2 -0
- data_designer/engine/model_provider.py +2 -0
- data_designer/engine/models/errors.py +23 -31
- data_designer/engine/models/facade.py +12 -9
- data_designer/engine/models/factory.py +42 -0
- data_designer/engine/models/litellm_overrides.py +22 -11
- data_designer/engine/models/parsers/errors.py +2 -0
- data_designer/engine/models/parsers/parser.py +2 -2
- data_designer/engine/models/parsers/postprocessors.py +1 -0
- data_designer/engine/models/parsers/tag_parsers.py +2 -0
- data_designer/engine/models/parsers/types.py +2 -0
- data_designer/engine/models/recipes/base.py +2 -0
- data_designer/engine/models/recipes/response_recipes.py +2 -0
- data_designer/engine/models/registry.py +11 -18
- data_designer/engine/models/telemetry.py +6 -2
- data_designer/engine/processing/ginja/ast.py +2 -0
- data_designer/engine/processing/ginja/environment.py +2 -0
- data_designer/engine/processing/ginja/exceptions.py +2 -0
- data_designer/engine/processing/ginja/record.py +2 -0
- data_designer/engine/processing/gsonschema/exceptions.py +9 -2
- data_designer/engine/processing/gsonschema/schema_transformers.py +2 -0
- data_designer/engine/processing/gsonschema/types.py +2 -0
- data_designer/engine/processing/gsonschema/validators.py +10 -6
- data_designer/engine/processing/processors/base.py +1 -5
- data_designer/engine/processing/processors/drop_columns.py +7 -10
- data_designer/engine/processing/processors/registry.py +2 -0
- data_designer/engine/processing/processors/schema_transform.py +7 -10
- data_designer/engine/processing/utils.py +7 -3
- data_designer/engine/registry/base.py +2 -0
- data_designer/engine/registry/data_designer_registry.py +2 -0
- data_designer/engine/registry/errors.py +2 -0
- data_designer/engine/resources/managed_dataset_generator.py +6 -2
- data_designer/engine/resources/managed_dataset_repository.py +8 -5
- data_designer/engine/resources/managed_storage.py +2 -0
- data_designer/engine/resources/resource_provider.py +8 -1
- data_designer/engine/resources/seed_reader.py +7 -2
- data_designer/engine/sampling_gen/column.py +2 -0
- data_designer/engine/sampling_gen/constraints.py +8 -2
- data_designer/engine/sampling_gen/data_sources/base.py +10 -7
- data_designer/engine/sampling_gen/data_sources/errors.py +2 -0
- data_designer/engine/sampling_gen/data_sources/sources.py +27 -22
- data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +2 -2
- data_designer/engine/sampling_gen/entities/email_address_utils.py +2 -0
- data_designer/engine/sampling_gen/entities/errors.py +2 -0
- data_designer/engine/sampling_gen/entities/national_id_utils.py +2 -0
- data_designer/engine/sampling_gen/entities/person.py +2 -0
- data_designer/engine/sampling_gen/entities/phone_number.py +8 -1
- data_designer/engine/sampling_gen/errors.py +2 -0
- data_designer/engine/sampling_gen/generator.py +5 -4
- data_designer/engine/sampling_gen/jinja_utils.py +7 -3
- data_designer/engine/sampling_gen/people_gen.py +7 -7
- data_designer/engine/sampling_gen/person_constants.py +2 -0
- data_designer/engine/sampling_gen/schema.py +5 -1
- data_designer/engine/sampling_gen/schema_builder.py +2 -0
- data_designer/engine/sampling_gen/utils.py +7 -1
- data_designer/engine/secret_resolver.py +2 -0
- data_designer/engine/validation.py +2 -2
- data_designer/engine/validators/__init__.py +2 -0
- data_designer/engine/validators/base.py +2 -0
- data_designer/engine/validators/local_callable.py +7 -2
- data_designer/engine/validators/python.py +7 -1
- data_designer/engine/validators/remote.py +7 -1
- data_designer/engine/validators/sql.py +8 -3
- data_designer/errors.py +2 -0
- data_designer/essentials/__init__.py +2 -0
- data_designer/interface/data_designer.py +23 -17
- data_designer/interface/errors.py +2 -0
- data_designer/interface/results.py +5 -2
- data_designer/lazy_heavy_imports.py +54 -0
- data_designer/logging.py +2 -0
- data_designer/plugins/__init__.py +2 -0
- data_designer/plugins/errors.py +2 -0
- data_designer/plugins/plugin.py +0 -1
- data_designer/plugins/registry.py +2 -0
- data_designer/plugins/testing/__init__.py +2 -0
- data_designer/plugins/testing/stubs.py +21 -43
- data_designer/plugins/testing/utils.py +2 -0
- {data_designer-0.3.4.dist-info → data_designer-0.3.6.dist-info}/METADATA +12 -5
- data_designer-0.3.6.dist-info/RECORD +196 -0
- data_designer-0.3.4.dist-info/RECORD +0 -194
- {data_designer-0.3.4.dist-info → data_designer-0.3.6.dist-info}/WHEEL +0 -0
- {data_designer-0.3.4.dist-info → data_designer-0.3.6.dist-info}/entry_points.txt +0 -0
- {data_designer-0.3.4.dist-info → data_designer-0.3.6.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,23 +1,30 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
4
6
|
import json
|
|
5
7
|
import logging
|
|
6
8
|
import shutil
|
|
7
9
|
from datetime import datetime
|
|
8
10
|
from functools import cached_property
|
|
9
11
|
from pathlib import Path
|
|
12
|
+
from typing import TYPE_CHECKING
|
|
10
13
|
|
|
11
|
-
import pandas as pd
|
|
12
14
|
from pydantic import BaseModel, field_validator, model_validator
|
|
13
15
|
|
|
14
16
|
from data_designer.config.utils.io_helpers import read_parquet_dataset
|
|
15
17
|
from data_designer.config.utils.type_helpers import StrEnum, resolve_string_enum
|
|
16
18
|
from data_designer.engine.dataset_builders.errors import ArtifactStorageError
|
|
19
|
+
from data_designer.lazy_heavy_imports import pd
|
|
20
|
+
|
|
21
|
+
if TYPE_CHECKING:
|
|
22
|
+
import pandas as pd
|
|
17
23
|
|
|
18
24
|
logger = logging.getLogger(__name__)
|
|
19
25
|
|
|
20
26
|
BATCH_FILE_NAME_FORMAT = "batch_{batch_number:05d}.parquet"
|
|
27
|
+
SDG_CONFIG_FILENAME = "sdg.json"
|
|
21
28
|
|
|
22
29
|
|
|
23
30
|
class BatchStage(StrEnum):
|
|
@@ -164,12 +171,6 @@ class ArtifactStorage(BaseModel):
|
|
|
164
171
|
shutil.move(partial_result_path, final_file_path)
|
|
165
172
|
return final_file_path
|
|
166
173
|
|
|
167
|
-
def write_configs(self, json_file_name: str, configs: list[dict]) -> Path:
|
|
168
|
-
self.mkdir_if_needed(self.base_dataset_path)
|
|
169
|
-
with open(self.base_dataset_path / json_file_name, "w") as file:
|
|
170
|
-
json.dump([c.model_dump(mode="json") for c in configs], file, indent=4)
|
|
171
|
-
return self.base_dataset_path / json_file_name
|
|
172
|
-
|
|
173
174
|
def write_batch_to_parquet_file(
|
|
174
175
|
self,
|
|
175
176
|
batch_number: int,
|
|
@@ -194,11 +195,89 @@ class ArtifactStorage(BaseModel):
|
|
|
194
195
|
dataframe.to_parquet(file_path, index=False)
|
|
195
196
|
return file_path
|
|
196
197
|
|
|
198
|
+
def get_parquet_file_paths(self) -> list[str]:
|
|
199
|
+
"""Get list of parquet file paths relative to base_dataset_path.
|
|
200
|
+
|
|
201
|
+
Returns:
|
|
202
|
+
List of relative paths to parquet files in the final dataset folder.
|
|
203
|
+
"""
|
|
204
|
+
return [str(f.relative_to(self.base_dataset_path)) for f in sorted(self.final_dataset_path.glob("*.parquet"))]
|
|
205
|
+
|
|
206
|
+
def get_processor_file_paths(self) -> dict[str, list[str]]:
|
|
207
|
+
"""Get processor output files organized by processor name.
|
|
208
|
+
|
|
209
|
+
Returns:
|
|
210
|
+
Dictionary mapping processor names to lists of relative file paths.
|
|
211
|
+
"""
|
|
212
|
+
processor_files: dict[str, list[str]] = {}
|
|
213
|
+
if self.processors_outputs_path.exists():
|
|
214
|
+
for processor_dir in sorted(self.processors_outputs_path.iterdir()):
|
|
215
|
+
if processor_dir.is_dir():
|
|
216
|
+
processor_name = processor_dir.name
|
|
217
|
+
processor_files[processor_name] = [
|
|
218
|
+
str(f.relative_to(self.base_dataset_path))
|
|
219
|
+
for f in sorted(processor_dir.rglob("*"))
|
|
220
|
+
if f.is_file()
|
|
221
|
+
]
|
|
222
|
+
return processor_files
|
|
223
|
+
|
|
224
|
+
def get_file_paths(self) -> dict[str, list[str] | dict[str, list[str]]]:
|
|
225
|
+
"""Get all file paths organized by type.
|
|
226
|
+
|
|
227
|
+
Returns:
|
|
228
|
+
Dictionary with 'parquet-files' and 'processor-files' keys.
|
|
229
|
+
"""
|
|
230
|
+
file_paths = {
|
|
231
|
+
"parquet-files": self.get_parquet_file_paths(),
|
|
232
|
+
}
|
|
233
|
+
processor_file_paths = self.get_processor_file_paths()
|
|
234
|
+
if processor_file_paths:
|
|
235
|
+
file_paths["processor-files"] = processor_file_paths
|
|
236
|
+
|
|
237
|
+
return file_paths
|
|
238
|
+
|
|
239
|
+
def read_metadata(self) -> dict:
|
|
240
|
+
"""Read metadata from the metadata.json file.
|
|
241
|
+
|
|
242
|
+
Returns:
|
|
243
|
+
Dictionary containing the metadata.
|
|
244
|
+
|
|
245
|
+
Raises:
|
|
246
|
+
FileNotFoundError: If metadata file doesn't exist.
|
|
247
|
+
"""
|
|
248
|
+
with open(self.metadata_file_path, "r") as file:
|
|
249
|
+
return json.load(file)
|
|
250
|
+
|
|
197
251
|
def write_metadata(self, metadata: dict) -> Path:
|
|
252
|
+
"""Write metadata to the metadata.json file.
|
|
253
|
+
|
|
254
|
+
Args:
|
|
255
|
+
metadata: Dictionary containing metadata to write.
|
|
256
|
+
|
|
257
|
+
Returns:
|
|
258
|
+
Path to the written metadata file.
|
|
259
|
+
"""
|
|
198
260
|
self.mkdir_if_needed(self.base_dataset_path)
|
|
199
261
|
with open(self.metadata_file_path, "w") as file:
|
|
200
|
-
json.dump(metadata, file)
|
|
262
|
+
json.dump(metadata, file, indent=4, sort_keys=True)
|
|
201
263
|
return self.metadata_file_path
|
|
202
264
|
|
|
265
|
+
def update_metadata(self, updates: dict) -> Path:
|
|
266
|
+
"""Update existing metadata with new fields.
|
|
267
|
+
|
|
268
|
+
Args:
|
|
269
|
+
updates: Dictionary of fields to add/update in metadata.
|
|
270
|
+
|
|
271
|
+
Returns:
|
|
272
|
+
Path to the updated metadata file.
|
|
273
|
+
"""
|
|
274
|
+
try:
|
|
275
|
+
existing_metadata = self.read_metadata()
|
|
276
|
+
except FileNotFoundError:
|
|
277
|
+
existing_metadata = {}
|
|
278
|
+
|
|
279
|
+
existing_metadata.update(updates)
|
|
280
|
+
return self.write_metadata(existing_metadata)
|
|
281
|
+
|
|
203
282
|
def _get_stage_path(self, stage: BatchStage) -> Path:
|
|
204
283
|
return getattr(self, resolve_string_enum(stage, BatchStage).value)
|
|
@@ -12,9 +12,9 @@ import uuid
|
|
|
12
12
|
from pathlib import Path
|
|
13
13
|
from typing import TYPE_CHECKING, Callable
|
|
14
14
|
|
|
15
|
-
import pandas as pd
|
|
16
|
-
|
|
17
15
|
from data_designer.config.column_types import ColumnConfigT
|
|
16
|
+
from data_designer.config.config_builder import BuilderConfig
|
|
17
|
+
from data_designer.config.data_designer_config import DataDesignerConfig
|
|
18
18
|
from data_designer.config.dataset_builders import BuildStage
|
|
19
19
|
from data_designer.config.processors import (
|
|
20
20
|
DropColumnsProcessorConfig,
|
|
@@ -27,40 +27,38 @@ from data_designer.engine.column_generators.generators.base import (
|
|
|
27
27
|
GenerationStrategy,
|
|
28
28
|
)
|
|
29
29
|
from data_designer.engine.column_generators.utils.generator_classification import column_type_is_model_generated
|
|
30
|
-
from data_designer.engine.
|
|
30
|
+
from data_designer.engine.compiler import compile_data_designer_config
|
|
31
|
+
from data_designer.engine.dataset_builders.artifact_storage import SDG_CONFIG_FILENAME, ArtifactStorage
|
|
31
32
|
from data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError
|
|
32
|
-
from data_designer.engine.dataset_builders.multi_column_configs import
|
|
33
|
-
DatasetBuilderColumnConfigT,
|
|
34
|
-
MultiColumnConfig,
|
|
35
|
-
)
|
|
33
|
+
from data_designer.engine.dataset_builders.multi_column_configs import MultiColumnConfig
|
|
36
34
|
from data_designer.engine.dataset_builders.utils.concurrency import (
|
|
37
35
|
MAX_CONCURRENCY_PER_NON_LLM_GENERATOR,
|
|
38
36
|
ConcurrentThreadExecutor,
|
|
39
37
|
)
|
|
40
|
-
from data_designer.engine.dataset_builders.utils.
|
|
41
|
-
|
|
42
|
-
)
|
|
38
|
+
from data_designer.engine.dataset_builders.utils.config_compiler import compile_dataset_builder_column_configs
|
|
39
|
+
from data_designer.engine.dataset_builders.utils.dataset_batch_manager import DatasetBatchManager
|
|
43
40
|
from data_designer.engine.models.telemetry import InferenceEvent, NemoSourceEnum, TaskStatusEnum, TelemetryHandler
|
|
44
41
|
from data_designer.engine.processing.processors.base import Processor
|
|
45
42
|
from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor
|
|
46
43
|
from data_designer.engine.registry.data_designer_registry import DataDesignerRegistry
|
|
47
44
|
from data_designer.engine.resources.resource_provider import ResourceProvider
|
|
45
|
+
from data_designer.lazy_heavy_imports import pd
|
|
48
46
|
|
|
49
47
|
if TYPE_CHECKING:
|
|
48
|
+
import pandas as pd
|
|
49
|
+
|
|
50
50
|
from data_designer.engine.column_generators.generators.base import ColumnGeneratorWithModelRegistry
|
|
51
51
|
from data_designer.engine.models.usage import ModelUsageStats
|
|
52
52
|
|
|
53
53
|
logger = logging.getLogger(__name__)
|
|
54
54
|
|
|
55
|
-
|
|
56
55
|
_CLIENT_VERSION: str = importlib.metadata.version("data_designer")
|
|
57
56
|
|
|
58
57
|
|
|
59
58
|
class ColumnWiseDatasetBuilder:
|
|
60
59
|
def __init__(
|
|
61
60
|
self,
|
|
62
|
-
|
|
63
|
-
processor_configs: list[ProcessorConfig],
|
|
61
|
+
data_designer_config: DataDesignerConfig,
|
|
64
62
|
resource_provider: ResourceProvider,
|
|
65
63
|
registry: DataDesignerRegistry | None = None,
|
|
66
64
|
):
|
|
@@ -68,8 +66,12 @@ class ColumnWiseDatasetBuilder:
|
|
|
68
66
|
self._resource_provider = resource_provider
|
|
69
67
|
self._records_to_drop: set[int] = set()
|
|
70
68
|
self._registry = registry or DataDesignerRegistry()
|
|
71
|
-
|
|
72
|
-
self.
|
|
69
|
+
|
|
70
|
+
self._data_designer_config = compile_data_designer_config(data_designer_config, resource_provider)
|
|
71
|
+
self._column_configs = compile_dataset_builder_column_configs(self._data_designer_config)
|
|
72
|
+
self._processors: dict[BuildStage, list[Processor]] = self._initialize_processors(
|
|
73
|
+
self._data_designer_config.processors or []
|
|
74
|
+
)
|
|
73
75
|
self._validate_column_configs()
|
|
74
76
|
|
|
75
77
|
@property
|
|
@@ -96,9 +98,8 @@ class ColumnWiseDatasetBuilder:
|
|
|
96
98
|
num_records: int,
|
|
97
99
|
on_batch_complete: Callable[[Path], None] | None = None,
|
|
98
100
|
) -> Path:
|
|
99
|
-
self._write_configs()
|
|
100
101
|
self._run_model_health_check_if_needed()
|
|
101
|
-
|
|
102
|
+
self._write_builder_config()
|
|
102
103
|
generators = self._initialize_generators()
|
|
103
104
|
start_time = time.perf_counter()
|
|
104
105
|
group_id = uuid.uuid4().hex
|
|
@@ -157,6 +158,12 @@ class ColumnWiseDatasetBuilder:
|
|
|
157
158
|
for config in self._column_configs
|
|
158
159
|
]
|
|
159
160
|
|
|
161
|
+
def _write_builder_config(self) -> None:
|
|
162
|
+
self.artifact_storage.mkdir_if_needed(self.artifact_storage.base_dataset_path)
|
|
163
|
+
BuilderConfig(data_designer=self._data_designer_config).to_json(
|
|
164
|
+
self.artifact_storage.base_dataset_path / SDG_CONFIG_FILENAME
|
|
165
|
+
)
|
|
166
|
+
|
|
160
167
|
def _run_batch(
|
|
161
168
|
self, generators: list[ColumnGenerator], *, batch_mode: str, save_partial_results: bool = True, group_id: str
|
|
162
169
|
) -> None:
|
|
@@ -164,15 +171,16 @@ class ColumnWiseDatasetBuilder:
|
|
|
164
171
|
for generator in generators:
|
|
165
172
|
generator.log_pre_generation()
|
|
166
173
|
try:
|
|
174
|
+
generation_strategy = generator.get_generation_strategy()
|
|
167
175
|
if generator.can_generate_from_scratch and self.batch_manager.buffer_is_empty:
|
|
168
176
|
self._run_from_scratch_column_generator(generator)
|
|
169
|
-
elif
|
|
177
|
+
elif generation_strategy == GenerationStrategy.CELL_BY_CELL:
|
|
170
178
|
self._run_cell_by_cell_generator(generator)
|
|
171
|
-
elif
|
|
179
|
+
elif generation_strategy == GenerationStrategy.FULL_COLUMN:
|
|
172
180
|
self._run_full_column_generator(generator)
|
|
173
181
|
else:
|
|
174
|
-
logger.error(f"❌ Unknown generation strategy: {
|
|
175
|
-
raise DatasetGenerationError(f"🛑 Unknown generation strategy: {
|
|
182
|
+
logger.error(f"❌ Unknown generation strategy: {generation_strategy}")
|
|
183
|
+
raise DatasetGenerationError(f"🛑 Unknown generation strategy: {generation_strategy}")
|
|
176
184
|
if save_partial_results:
|
|
177
185
|
self.batch_manager.write()
|
|
178
186
|
except Exception as e:
|
|
@@ -210,9 +218,9 @@ class ColumnWiseDatasetBuilder:
|
|
|
210
218
|
)
|
|
211
219
|
|
|
212
220
|
def _fan_out_with_threads(self, generator: ColumnGeneratorWithModelRegistry, max_workers: int) -> None:
|
|
213
|
-
if generator.
|
|
221
|
+
if generator.get_generation_strategy() != GenerationStrategy.CELL_BY_CELL:
|
|
214
222
|
raise DatasetGenerationError(
|
|
215
|
-
f"Generator {generator.
|
|
223
|
+
f"Generator {generator.name} is not a {GenerationStrategy.CELL_BY_CELL} "
|
|
216
224
|
"generator so concurrency through threads is not supported."
|
|
217
225
|
)
|
|
218
226
|
|
|
@@ -292,7 +300,7 @@ class ColumnWiseDatasetBuilder:
|
|
|
292
300
|
dataframe = processor.process(dataframe, current_batch_number=current_batch_number)
|
|
293
301
|
except Exception as e:
|
|
294
302
|
raise DatasetProcessingError(
|
|
295
|
-
f"🛑 Failed to process dataset with processor {processor.
|
|
303
|
+
f"🛑 Failed to process dataset with processor {processor.name} in stage {stage}: {e}"
|
|
296
304
|
) from e
|
|
297
305
|
return dataframe
|
|
298
306
|
|
|
@@ -307,16 +315,6 @@ class ColumnWiseDatasetBuilder:
|
|
|
307
315
|
def _worker_result_callback(self, result: dict, *, context: dict | None = None) -> None:
|
|
308
316
|
self.batch_manager.update_record(context["index"], result)
|
|
309
317
|
|
|
310
|
-
def _write_configs(self) -> None:
|
|
311
|
-
self.artifact_storage.write_configs(
|
|
312
|
-
json_file_name="column_configs.json",
|
|
313
|
-
configs=self._column_configs,
|
|
314
|
-
)
|
|
315
|
-
self.artifact_storage.write_configs(
|
|
316
|
-
json_file_name="model_configs.json",
|
|
317
|
-
configs=self._resource_provider.model_registry.model_configs.values(),
|
|
318
|
-
)
|
|
319
|
-
|
|
320
318
|
def _emit_batch_inference_events(
|
|
321
319
|
self, batch_mode: str, usage_deltas: dict[str, ModelUsageStats], group_id: str
|
|
322
320
|
) -> None:
|
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
4
6
|
from data_designer.config.column_types import DataDesignerColumnType
|
|
5
7
|
from data_designer.config.data_designer_config import DataDesignerConfig
|
|
6
8
|
from data_designer.config.processors import ProcessorConfig
|
|
@@ -1,13 +1,18 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
-
import
|
|
4
|
+
from __future__ import annotations
|
|
5
5
|
|
|
6
|
-
import
|
|
6
|
+
import logging
|
|
7
|
+
from typing import TYPE_CHECKING
|
|
7
8
|
|
|
8
9
|
from data_designer.config.column_types import ColumnConfigT
|
|
9
10
|
from data_designer.engine.column_generators.utils.generator_classification import column_type_used_in_execution_dag
|
|
10
11
|
from data_designer.engine.dataset_builders.utils.errors import DAGCircularDependencyError
|
|
12
|
+
from data_designer.lazy_heavy_imports import nx
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
import networkx as nx
|
|
11
16
|
|
|
12
17
|
logger = logging.getLogger(__name__)
|
|
13
18
|
|
|
@@ -1,16 +1,20 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
4
6
|
import logging
|
|
5
7
|
import shutil
|
|
6
8
|
from pathlib import Path
|
|
7
|
-
from typing import Callable, Container, Iterator
|
|
8
|
-
|
|
9
|
-
import pandas as pd
|
|
10
|
-
import pyarrow.parquet as pq
|
|
9
|
+
from typing import TYPE_CHECKING, Callable, Container, Iterator
|
|
11
10
|
|
|
12
11
|
from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage, BatchStage
|
|
13
12
|
from data_designer.engine.dataset_builders.utils.errors import DatasetBatchManagementError
|
|
13
|
+
from data_designer.lazy_heavy_imports import pd, pq
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
import pandas as pd
|
|
17
|
+
import pyarrow.parquet as pq
|
|
14
18
|
|
|
15
19
|
logger = logging.getLogger(__name__)
|
|
16
20
|
|
|
@@ -87,8 +91,7 @@ class DatasetBatchManager:
|
|
|
87
91
|
"total_num_batches": self.num_batches,
|
|
88
92
|
"buffer_size": self._buffer_size,
|
|
89
93
|
"schema": {field.name: str(field.type) for field in pq.read_schema(final_file_path)},
|
|
90
|
-
"file_paths":
|
|
91
|
-
"num_records": self.num_records_list[: self._current_batch_number + 1],
|
|
94
|
+
"file_paths": self.artifact_storage.get_file_paths(),
|
|
92
95
|
"num_completed_batches": self._current_batch_number + 1,
|
|
93
96
|
"dataset_name": self.artifact_storage.dataset_name,
|
|
94
97
|
}
|
data_designer/engine/errors.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
4
6
|
from pydantic import BaseModel, Field
|
|
5
7
|
|
|
6
8
|
from data_designer.errors import DataDesignerError
|
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
4
6
|
from functools import cached_property
|
|
5
7
|
|
|
6
8
|
from pydantic import BaseModel, field_validator, model_validator
|
|
@@ -6,25 +6,15 @@ from __future__ import annotations
|
|
|
6
6
|
import logging
|
|
7
7
|
from collections.abc import Callable
|
|
8
8
|
from functools import wraps
|
|
9
|
-
from typing import Any
|
|
10
|
-
|
|
11
|
-
from litellm.exceptions import (
|
|
12
|
-
APIConnectionError,
|
|
13
|
-
APIError,
|
|
14
|
-
AuthenticationError,
|
|
15
|
-
BadRequestError,
|
|
16
|
-
ContextWindowExceededError,
|
|
17
|
-
InternalServerError,
|
|
18
|
-
NotFoundError,
|
|
19
|
-
PermissionDeniedError,
|
|
20
|
-
RateLimitError,
|
|
21
|
-
Timeout,
|
|
22
|
-
UnprocessableEntityError,
|
|
23
|
-
UnsupportedParamsError,
|
|
24
|
-
)
|
|
9
|
+
from typing import TYPE_CHECKING, Any
|
|
10
|
+
|
|
25
11
|
from pydantic import BaseModel
|
|
26
12
|
|
|
27
13
|
from data_designer.engine.errors import DataDesignerError
|
|
14
|
+
from data_designer.lazy_heavy_imports import litellm
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
import litellm
|
|
28
18
|
|
|
29
19
|
logger = logging.getLogger(__name__)
|
|
30
20
|
|
|
@@ -132,10 +122,10 @@ def handle_llm_exceptions(
|
|
|
132
122
|
err_msg_parser = DownstreamLLMExceptionMessageParser(model_name, model_provider_name, purpose)
|
|
133
123
|
match exception:
|
|
134
124
|
# Common errors that can come from LiteLLM
|
|
135
|
-
case APIError():
|
|
125
|
+
case litellm.exceptions.APIError():
|
|
136
126
|
raise err_msg_parser.parse_api_error(exception, authentication_error) from None
|
|
137
127
|
|
|
138
|
-
case APIConnectionError():
|
|
128
|
+
case litellm.exceptions.APIConnectionError():
|
|
139
129
|
raise ModelAPIConnectionError(
|
|
140
130
|
FormattedLLMErrorMessage(
|
|
141
131
|
cause=f"Connection to model {model_name!r} hosted on model provider {model_provider_name!r} failed while {purpose}.",
|
|
@@ -143,13 +133,13 @@ def handle_llm_exceptions(
|
|
|
143
133
|
)
|
|
144
134
|
) from None
|
|
145
135
|
|
|
146
|
-
case AuthenticationError():
|
|
136
|
+
case litellm.exceptions.AuthenticationError():
|
|
147
137
|
raise ModelAuthenticationError(authentication_error) from None
|
|
148
138
|
|
|
149
|
-
case ContextWindowExceededError():
|
|
139
|
+
case litellm.exceptions.ContextWindowExceededError():
|
|
150
140
|
raise err_msg_parser.parse_context_window_exceeded_error(exception) from None
|
|
151
141
|
|
|
152
|
-
case UnsupportedParamsError():
|
|
142
|
+
case litellm.exceptions.UnsupportedParamsError():
|
|
153
143
|
raise ModelUnsupportedParamsError(
|
|
154
144
|
FormattedLLMErrorMessage(
|
|
155
145
|
cause=f"One or more of the parameters you provided were found to be unsupported by model {model_name!r} while {purpose}.",
|
|
@@ -157,10 +147,10 @@ def handle_llm_exceptions(
|
|
|
157
147
|
)
|
|
158
148
|
) from None
|
|
159
149
|
|
|
160
|
-
case BadRequestError():
|
|
150
|
+
case litellm.exceptions.BadRequestError():
|
|
161
151
|
raise err_msg_parser.parse_bad_request_error(exception) from None
|
|
162
152
|
|
|
163
|
-
case InternalServerError():
|
|
153
|
+
case litellm.exceptions.InternalServerError():
|
|
164
154
|
raise ModelInternalServerError(
|
|
165
155
|
FormattedLLMErrorMessage(
|
|
166
156
|
cause=f"Model {model_name!r} is currently experiencing internal server issues while {purpose}.",
|
|
@@ -168,7 +158,7 @@ def handle_llm_exceptions(
|
|
|
168
158
|
)
|
|
169
159
|
) from None
|
|
170
160
|
|
|
171
|
-
case NotFoundError():
|
|
161
|
+
case litellm.exceptions.NotFoundError():
|
|
172
162
|
raise ModelNotFoundError(
|
|
173
163
|
FormattedLLMErrorMessage(
|
|
174
164
|
cause=f"The specified model {model_name!r} could not be found while {purpose}.",
|
|
@@ -176,7 +166,7 @@ def handle_llm_exceptions(
|
|
|
176
166
|
)
|
|
177
167
|
) from None
|
|
178
168
|
|
|
179
|
-
case PermissionDeniedError():
|
|
169
|
+
case litellm.exceptions.PermissionDeniedError():
|
|
180
170
|
raise ModelPermissionDeniedError(
|
|
181
171
|
FormattedLLMErrorMessage(
|
|
182
172
|
cause=f"Your API key was found to lack the necessary permissions to use model {model_name!r} while {purpose}.",
|
|
@@ -184,7 +174,7 @@ def handle_llm_exceptions(
|
|
|
184
174
|
)
|
|
185
175
|
) from None
|
|
186
176
|
|
|
187
|
-
case RateLimitError():
|
|
177
|
+
case litellm.exceptions.RateLimitError():
|
|
188
178
|
raise ModelRateLimitError(
|
|
189
179
|
FormattedLLMErrorMessage(
|
|
190
180
|
cause=f"You have exceeded the rate limit for model {model_name!r} while {purpose}.",
|
|
@@ -192,7 +182,7 @@ def handle_llm_exceptions(
|
|
|
192
182
|
)
|
|
193
183
|
) from None
|
|
194
184
|
|
|
195
|
-
case Timeout():
|
|
185
|
+
case litellm.exceptions.Timeout():
|
|
196
186
|
raise ModelTimeoutError(
|
|
197
187
|
FormattedLLMErrorMessage(
|
|
198
188
|
cause=f"The request to model {model_name!r} timed out while {purpose}.",
|
|
@@ -200,7 +190,7 @@ def handle_llm_exceptions(
|
|
|
200
190
|
)
|
|
201
191
|
) from None
|
|
202
192
|
|
|
203
|
-
case UnprocessableEntityError():
|
|
193
|
+
case litellm.exceptions.UnprocessableEntityError():
|
|
204
194
|
raise ModelUnprocessableEntityError(
|
|
205
195
|
FormattedLLMErrorMessage(
|
|
206
196
|
cause=f"The request to model {model_name!r} failed despite correct request format while {purpose}.",
|
|
@@ -264,7 +254,7 @@ class DownstreamLLMExceptionMessageParser:
|
|
|
264
254
|
self.model_provider_name = model_provider_name
|
|
265
255
|
self.purpose = purpose
|
|
266
256
|
|
|
267
|
-
def parse_bad_request_error(self, exception: BadRequestError) -> DataDesignerError:
|
|
257
|
+
def parse_bad_request_error(self, exception: litellm.exceptions.BadRequestError) -> DataDesignerError:
|
|
268
258
|
err_msg = FormattedLLMErrorMessage(
|
|
269
259
|
cause=f"The request for model {self.model_name!r} was found to be malformed or missing required parameters while {self.purpose}.",
|
|
270
260
|
solution="Check your request parameters and try again.",
|
|
@@ -276,7 +266,9 @@ class DownstreamLLMExceptionMessageParser:
|
|
|
276
266
|
)
|
|
277
267
|
return ModelBadRequestError(err_msg)
|
|
278
268
|
|
|
279
|
-
def parse_context_window_exceeded_error(
|
|
269
|
+
def parse_context_window_exceeded_error(
|
|
270
|
+
self, exception: litellm.exceptions.ContextWindowExceededError
|
|
271
|
+
) -> DataDesignerError:
|
|
280
272
|
cause = f"The input data for model '{self.model_name}' was found to exceed its supported context width while {self.purpose}."
|
|
281
273
|
try:
|
|
282
274
|
if "OpenAIException - This model's maximum context length is " in str(exception):
|
|
@@ -295,7 +287,7 @@ class DownstreamLLMExceptionMessageParser:
|
|
|
295
287
|
)
|
|
296
288
|
|
|
297
289
|
def parse_api_error(
|
|
298
|
-
self, exception: InternalServerError, auth_error_msg: FormattedLLMErrorMessage
|
|
290
|
+
self, exception: litellm.exceptions.InternalServerError, auth_error_msg: FormattedLLMErrorMessage
|
|
299
291
|
) -> DataDesignerError:
|
|
300
292
|
if "Error code: 403" in str(exception):
|
|
301
293
|
return ModelAuthenticationError(auth_error_msg)
|
|
@@ -6,10 +6,7 @@ from __future__ import annotations
|
|
|
6
6
|
import logging
|
|
7
7
|
from collections.abc import Callable
|
|
8
8
|
from copy import deepcopy
|
|
9
|
-
from typing import Any
|
|
10
|
-
|
|
11
|
-
from litellm.types.router import DeploymentTypedDict, LiteLLM_Params
|
|
12
|
-
from litellm.types.utils import EmbeddingResponse, ModelResponse
|
|
9
|
+
from typing import TYPE_CHECKING, Any
|
|
13
10
|
|
|
14
11
|
from data_designer.config.models import GenerationType, ModelConfig, ModelProvider
|
|
15
12
|
from data_designer.engine.model_provider import ModelProviderRegistry
|
|
@@ -23,6 +20,10 @@ from data_designer.engine.models.parsers.errors import ParserException
|
|
|
23
20
|
from data_designer.engine.models.usage import ModelUsageStats, RequestUsageStats, TokenUsageStats
|
|
24
21
|
from data_designer.engine.models.utils import prompt_to_messages, str_to_message
|
|
25
22
|
from data_designer.engine.secret_resolver import SecretResolver
|
|
23
|
+
from data_designer.lazy_heavy_imports import litellm
|
|
24
|
+
|
|
25
|
+
if TYPE_CHECKING:
|
|
26
|
+
import litellm
|
|
26
27
|
|
|
27
28
|
logger = logging.getLogger(__name__)
|
|
28
29
|
|
|
@@ -65,7 +66,9 @@ class ModelFacade:
|
|
|
65
66
|
def usage_stats(self) -> ModelUsageStats:
|
|
66
67
|
return self._usage_stats
|
|
67
68
|
|
|
68
|
-
def completion(
|
|
69
|
+
def completion(
|
|
70
|
+
self, messages: list[dict[str, str]], skip_usage_tracking: bool = False, **kwargs
|
|
71
|
+
) -> litellm.ModelResponse:
|
|
69
72
|
logger.debug(
|
|
70
73
|
f"Prompting model {self.model_name!r}...",
|
|
71
74
|
extra={"model": self.model_name, "messages": messages},
|
|
@@ -236,14 +239,14 @@ class ModelFacade:
|
|
|
236
239
|
) from exc
|
|
237
240
|
return output_obj, reasoning_trace
|
|
238
241
|
|
|
239
|
-
def _get_litellm_deployment(self, model_config: ModelConfig) -> DeploymentTypedDict:
|
|
242
|
+
def _get_litellm_deployment(self, model_config: ModelConfig) -> litellm.DeploymentTypedDict:
|
|
240
243
|
provider = self._model_provider_registry.get_provider(model_config.provider)
|
|
241
244
|
api_key = None
|
|
242
245
|
if provider.api_key:
|
|
243
246
|
api_key = self._secret_resolver.resolve(provider.api_key)
|
|
244
247
|
api_key = api_key or "not-used-but-required"
|
|
245
248
|
|
|
246
|
-
litellm_params = LiteLLM_Params(
|
|
249
|
+
litellm_params = litellm.LiteLLM_Params(
|
|
247
250
|
model=f"{provider.provider_type}/{model_config.model}",
|
|
248
251
|
api_base=provider.endpoint,
|
|
249
252
|
api_key=api_key,
|
|
@@ -253,7 +256,7 @@ class ModelFacade:
|
|
|
253
256
|
"litellm_params": litellm_params.model_dump(),
|
|
254
257
|
}
|
|
255
258
|
|
|
256
|
-
def _track_usage(self, response: ModelResponse | None) -> None:
|
|
259
|
+
def _track_usage(self, response: litellm.types.utils.ModelResponse | None) -> None:
|
|
257
260
|
if response is None:
|
|
258
261
|
self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1))
|
|
259
262
|
return
|
|
@@ -270,7 +273,7 @@ class ModelFacade:
|
|
|
270
273
|
request_usage=RequestUsageStats(successful_requests=1, failed_requests=0),
|
|
271
274
|
)
|
|
272
275
|
|
|
273
|
-
def _track_usage_from_embedding(self, response: EmbeddingResponse | None) -> None:
|
|
276
|
+
def _track_usage_from_embedding(self, response: litellm.types.utils.EmbeddingResponse | None) -> None:
|
|
274
277
|
if response is None:
|
|
275
278
|
self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1))
|
|
276
279
|
return
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
from data_designer.config.models import ModelConfig
|
|
9
|
+
from data_designer.engine.model_provider import ModelProviderRegistry
|
|
10
|
+
from data_designer.engine.secret_resolver import SecretResolver
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from data_designer.engine.models.registry import ModelRegistry
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def create_model_registry(
|
|
17
|
+
*,
|
|
18
|
+
model_configs: list[ModelConfig] | None = None,
|
|
19
|
+
secret_resolver: SecretResolver,
|
|
20
|
+
model_provider_registry: ModelProviderRegistry,
|
|
21
|
+
) -> ModelRegistry:
|
|
22
|
+
"""Factory function for creating a ModelRegistry instance.
|
|
23
|
+
|
|
24
|
+
Heavy dependencies (litellm, httpx) are deferred until this function is called.
|
|
25
|
+
This is a factory function pattern - imports inside factories are idiomatic Python
|
|
26
|
+
for lazy initialization.
|
|
27
|
+
"""
|
|
28
|
+
from data_designer.engine.models.facade import ModelFacade
|
|
29
|
+
from data_designer.engine.models.litellm_overrides import apply_litellm_patches
|
|
30
|
+
from data_designer.engine.models.registry import ModelRegistry
|
|
31
|
+
|
|
32
|
+
apply_litellm_patches()
|
|
33
|
+
|
|
34
|
+
def model_facade_factory(model_config, secret_resolver, model_provider_registry):
|
|
35
|
+
return ModelFacade(model_config, secret_resolver, model_provider_registry)
|
|
36
|
+
|
|
37
|
+
return ModelRegistry(
|
|
38
|
+
model_configs=model_configs,
|
|
39
|
+
secret_resolver=secret_resolver,
|
|
40
|
+
model_provider_registry=model_provider_registry,
|
|
41
|
+
model_facade_factory=model_facade_factory,
|
|
42
|
+
)
|