data-designer 0.2.3__py3-none-any.whl → 0.3.1__py3-none-any.whl
- data_designer/_version.py +2 -2
- data_designer/cli/forms/model_builder.py +2 -2
- data_designer/config/config_builder.py +30 -113
- data_designer/config/errors.py +3 -0
- data_designer/config/exports.py +8 -6
- data_designer/config/models.py +7 -18
- data_designer/config/run_config.py +34 -0
- data_designer/config/seed.py +16 -46
- data_designer/config/seed_source.py +84 -0
- data_designer/config/utils/constants.py +27 -2
- data_designer/config/utils/io_helpers.py +0 -20
- data_designer/engine/column_generators/generators/seed_dataset.py +5 -5
- data_designer/engine/column_generators/generators/validation.py +3 -0
- data_designer/engine/column_generators/registry.py +1 -1
- data_designer/engine/compiler.py +69 -0
- data_designer/engine/dataset_builders/column_wise_builder.py +3 -0
- data_designer/engine/dataset_builders/utils/config_compiler.py +1 -1
- data_designer/engine/models/facade.py +2 -0
- data_designer/engine/processing/gsonschema/validators.py +55 -0
- data_designer/engine/resources/resource_provider.py +17 -5
- data_designer/engine/resources/seed_reader.py +149 -0
- data_designer/essentials/__init__.py +2 -0
- data_designer/interface/data_designer.py +72 -62
- data_designer/plugin_manager.py +1 -1
- data_designer/plugins/errors.py +3 -0
- data_designer/plugins/plugin.py +82 -12
- data_designer/plugins/testing/__init__.py +8 -0
- data_designer/plugins/testing/stubs.py +145 -0
- data_designer/plugins/testing/utils.py +11 -0
- {data_designer-0.2.3.dist-info → data_designer-0.3.1.dist-info}/METADATA +3 -3
- {data_designer-0.2.3.dist-info → data_designer-0.3.1.dist-info}/RECORD +35 -30
- data_designer/config/datastore.py +0 -187
- data_designer/engine/resources/seed_dataset_data_store.py +0 -84
- /data_designer/{config/utils → engine}/validation.py +0 -0
- {data_designer-0.2.3.dist-info → data_designer-0.3.1.dist-info}/WHEEL +0 -0
- {data_designer-0.2.3.dist-info → data_designer-0.3.1.dist-info}/entry_points.txt +0 -0
- {data_designer-0.2.3.dist-info → data_designer-0.3.1.dist-info}/licenses/LICENSE +0 -0
data_designer/config/utils/constants.py

@@ -282,6 +282,10 @@ OPENAI_PROVIDER_NAME = "openai"
 
 OPENAI_API_KEY_ENV_VAR_NAME = "OPENAI_API_KEY"
 
+OPENROUTER_PROVIDER_NAME = "openrouter"
+
+OPENROUTER_API_KEY_ENV_VAR_NAME = "OPENROUTER_API_KEY"
+
 PREDEFINED_PROVIDERS = [
     {
         "name": NVIDIA_PROVIDER_NAME,
@@ -295,6 +299,12 @@ PREDEFINED_PROVIDERS = [
         "provider_type": "openai",
         "api_key": OPENAI_API_KEY_ENV_VAR_NAME,
     },
+    {
+        "name": OPENROUTER_PROVIDER_NAME,
+        "endpoint": "https://openrouter.ai/api/v1",
+        "provider_type": "openai",
+        "api_key": OPENROUTER_API_KEY_ENV_VAR_NAME,
+    },
 ]
 
 
@@ -302,11 +312,14 @@ DEFAULT_TEXT_INFERENCE_PARAMS = {"temperature": 0.85, "top_p": 0.95}
 DEFAULT_REASONING_INFERENCE_PARAMS = {"temperature": 0.35, "top_p": 0.95}
 DEFAULT_VISION_INFERENCE_PARAMS = {"temperature": 0.85, "top_p": 0.95}
 DEFAULT_EMBEDDING_INFERENCE_PARAMS = {"encoding_format": "float"}
-
+NEMOTRON_3_NANO_30B_A3B_INFERENCE_PARAMS = {"temperature": 1.0, "top_p": 1.0}
 
 PREDEFINED_PROVIDERS_MODEL_MAP = {
     NVIDIA_PROVIDER_NAME: {
-        "text": {
+        "text": {
+            "model": "nvidia/nemotron-3-nano-30b-a3b",
+            "inference_parameters": NEMOTRON_3_NANO_30B_A3B_INFERENCE_PARAMS,
+        },
         "reasoning": {"model": "openai/gpt-oss-20b", "inference_parameters": DEFAULT_REASONING_INFERENCE_PARAMS},
         "vision": {"model": "nvidia/nemotron-nano-12b-v2-vl", "inference_parameters": DEFAULT_VISION_INFERENCE_PARAMS},
         "embedding": {
@@ -320,6 +333,18 @@ PREDEFINED_PROVIDERS_MODEL_MAP = {
         "vision": {"model": "gpt-5", "inference_parameters": DEFAULT_VISION_INFERENCE_PARAMS},
         "embedding": {"model": "text-embedding-3-large", "inference_parameters": DEFAULT_EMBEDDING_INFERENCE_PARAMS},
     },
+    OPENROUTER_PROVIDER_NAME: {
+        "text": {
+            "model": "nvidia/nemotron-3-nano-30b-a3b",
+            "inference_parameters": NEMOTRON_3_NANO_30B_A3B_INFERENCE_PARAMS,
+        },
+        "reasoning": {"model": "openai/gpt-oss-20b", "inference_parameters": DEFAULT_REASONING_INFERENCE_PARAMS},
+        "vision": {"model": "nvidia/nemotron-nano-12b-v2-vl", "inference_parameters": DEFAULT_VISION_INFERENCE_PARAMS},
+        "embedding": {
+            "model": "openai/text-embedding-3-large",
+            "inference_parameters": DEFAULT_EMBEDDING_INFERENCE_PARAMS,
+        },
+    },
 }
 
 # Persona locale metadata - used by the CLI and the person sampler.
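The new OpenRouter entry is an OpenAI-compatible provider (`provider_type: "openai"`) that authenticates via the `OPENROUTER_API_KEY` environment variable, and its default model aliases live in the same nested dict as the other providers. A minimal sketch of the lookup path, assuming these constants can be imported directly from `data_designer.config.utils.constants` (the module path comes from the file list above; the direct import is an assumption):

```python
# Sketch: resolve the default text model for the new OpenRouter provider.
# The dict shape and values mirror the additions above.
from data_designer.config.utils.constants import (
    OPENROUTER_PROVIDER_NAME,
    PREDEFINED_PROVIDERS_MODEL_MAP,
)

text_defaults = PREDEFINED_PROVIDERS_MODEL_MAP[OPENROUTER_PROVIDER_NAME]["text"]
print(text_defaults["model"])                 # "nvidia/nemotron-3-nano-30b-a3b"
print(text_defaults["inference_parameters"])  # {"temperature": 1.0, "top_p": 1.0}
```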
data_designer/config/utils/io_helpers.py

@@ -108,26 +108,6 @@ def read_parquet_dataset(path: Path) -> pd.DataFrame:
         raise e
 
 
-def write_seed_dataset(dataframe: pd.DataFrame, file_path: Path) -> None:
-    """Write a seed dataset to a file in the specified format.
-
-    Supported file extensions: .parquet, .csv, .json, .jsonl
-
-    Args:
-        dataframe: The pandas DataFrame to write.
-        file_path: The path where the dataset should be saved.
-            Format is inferred from the file extension.
-    """
-    file_path = validate_dataset_file_path(file_path, should_exist=False)
-    logger.info(f"💾 Saving seed dataset to {file_path}")
-    if file_path.suffix.lower() == ".parquet":
-        dataframe.to_parquet(file_path, index=False)
-    elif file_path.suffix.lower() == ".csv":
-        dataframe.to_csv(file_path, index=False)
-    elif file_path.suffix.lower() in {".json", ".jsonl"}:
-        dataframe.to_json(file_path, orient="records", lines=True)
-
-
 def validate_dataset_file_path(file_path: str | Path, should_exist: bool = True) -> Path:
     """Validate that a dataset file path has a valid extension and optionally exists.
 
data_designer/engine/column_generators/generators/seed_dataset.py

@@ -30,7 +30,7 @@ class SeedDatasetColumnGenerator(FromScratchColumnGenerator[SeedDatasetMultiColu
             name="seed_dataset_column_generator",
             description="Sample columns from a seed dataset.",
             generation_strategy=GenerationStrategy.FULL_COLUMN,
-            required_resources=[ResourceType.
+            required_resources=[ResourceType.SEED_READER],
         )
 
     @property
@@ -39,10 +39,10 @@ class SeedDatasetColumnGenerator(FromScratchColumnGenerator[SeedDatasetMultiColu
 
     @functools.cached_property
     def duckdb_conn(self) -> duckdb.DuckDBPyConnection:
-        return self.resource_provider.
+        return self.resource_provider.seed_reader.create_duckdb_connection()
 
-    def generate(self,
-        return concat_datasets([self.generate_from_scratch(len(
+    def generate(self, data: pd.DataFrame) -> pd.DataFrame:
+        return concat_datasets([self.generate_from_scratch(len(data)), data])
 
     def generate_from_scratch(self, num_records: int) -> pd.DataFrame:
         if num_records <= 0:
@@ -57,7 +57,7 @@ class SeedDatasetColumnGenerator(FromScratchColumnGenerator[SeedDatasetMultiColu
         self._num_records_sampled = 0
        self._batch_reader = None
         self._df_remaining = None
-        self._dataset_uri = self.resource_provider.
+        self._dataset_uri = self.resource_provider.seed_reader.get_dataset_uri()
         self._seed_dataset_size = self.duckdb_conn.execute(f"SELECT COUNT(*) FROM '{self._dataset_uri}'").fetchone()[0]
         self._index_range = self._resolve_index_range()
data_designer/engine/column_generators/generators/validation.py

@@ -123,11 +123,14 @@ class ValidationColumnGenerator(ColumnGenerator[ValidationColumnConfig]):
         def error_callback(error: Exception, context: dict):
             outputs[context["index"]] = ValidationResult.empty(size=len(batched_records[context["index"]]))
 
+        settings = self.resource_provider.run_config
         with ConcurrentThreadExecutor(
             max_workers=self.config.validator_params.max_parallel_requests,
             column_name=self.config.name,
             result_callback=result_callback,
             error_callback=error_callback,
+            shutdown_error_rate=settings.shutdown_error_rate,
+            shutdown_error_window=settings.shutdown_error_window,
         ) as executor:
             for i, batch in enumerate(batched_records):
                 executor.submit(lambda batch: self._validate_batch(validator, batch), batch, context={"index": i})
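Both executor call sites (this one and the one in `column_wise_builder.py` further down) now read their shutdown thresholds from the run config attached to the resource provider. A hedged sketch of configuring those thresholds, assuming `RunConfig` (from `data_designer.config.run_config`, per the imports added later in this diff) accepts the two fields as keyword arguments; the field names come from the attribute access above, and the values here are purely illustrative:

```python
# Sketch only: keyword arguments and values are assumptions based on the attributes
# read by the executors (settings.shutdown_error_rate / settings.shutdown_error_window).
from data_designer.config.run_config import RunConfig

run_config = RunConfig(
    shutdown_error_rate=0.5,    # illustrative: stop when half of recent submissions fail
    shutdown_error_window=20,   # illustrative: evaluate that rate over the last 20 results
)
```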
data_designer/engine/column_generators/registry.py

@@ -51,7 +51,7 @@ def create_default_column_generator_registry(with_plugins: bool = True) -> Colum
     for plugin in PluginRegistry().get_plugins(PluginType.COLUMN_GENERATOR):
         registry.register(
             DataDesignerColumnType(plugin.name),
-            plugin.
+            plugin.impl_cls,
             plugin.config_cls,
         )
data_designer/engine/compiler.py (new file)

@@ -0,0 +1,69 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+
+from data_designer.config.column_configs import SeedDatasetColumnConfig
+from data_designer.config.config_builder import DataDesignerConfigBuilder
+from data_designer.config.data_designer_config import DataDesignerConfig
+from data_designer.config.errors import InvalidConfigError
+from data_designer.engine.resources.resource_provider import ResourceProvider
+from data_designer.engine.resources.seed_reader import SeedReader
+from data_designer.engine.validation import ViolationLevel, rich_print_violations, validate_data_designer_config
+
+logger = logging.getLogger(__name__)
+
+
+def compile_data_designer_config(
+    config_builder: DataDesignerConfigBuilder, resource_provider: ResourceProvider
+) -> DataDesignerConfig:
+    config = config_builder.build()
+    _resolve_and_add_seed_columns(config, resource_provider.seed_reader)
+    _validate(config)
+
+    return config
+
+
+def _resolve_and_add_seed_columns(config: DataDesignerConfig, seed_reader: SeedReader | None) -> None:
+    """Fetches the seed dataset column names, ensures there are no conflicts
+    with other columns, and adds seed column configs to the DataDesignerConfig.
+    """
+
+    if not seed_reader:
+        return
+
+    seed_col_names = seed_reader.get_column_names()
+    existing_columns = {column.name for column in config.columns}
+    colliding_columns = {name for name in seed_col_names if name in existing_columns}
+    if colliding_columns:
+        raise InvalidConfigError(
+            f"🛑 Seed dataset column(s) {colliding_columns} collide with existing column(s). "
+            "Please remove the conflicting columns or use a seed dataset with different column names."
+        )
+
+    config.columns.extend([SeedDatasetColumnConfig(name=col_name) for col_name in seed_col_names])
+
+
+def _validate(config: DataDesignerConfig) -> None:
+    allowed_references = _get_allowed_references(config)
+    violations = validate_data_designer_config(
+        columns=config.columns,
+        processor_configs=config.processors or [],
+        allowed_references=allowed_references,
+    )
+    rich_print_violations(violations)
+    if len([v for v in violations if v.level == ViolationLevel.ERROR]) > 0:
+        raise InvalidConfigError(
+            "🛑 Your configuration contains validation errors. Please address the indicated issues and try again."
+        )
+    if len(violations) == 0:
+        logger.info("✅ Validation passed")
+
+
+def _get_allowed_references(config: DataDesignerConfig) -> list[str]:
+    refs = set[str]()
+    for column_config in config.columns:
+        refs.add(column_config.name)
+        for side_effect_column in column_config.side_effect_columns:
+            refs.add(side_effect_column)
+    return list(refs)
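The new module centralizes the build → seed-column resolution → validation sequence in one place. A hedged usage sketch, assuming a `DataDesignerConfigBuilder` and a `ResourceProvider` have already been constructed elsewhere (normally the engine wires these together itself):

```python
# Sketch of the compile step added in this release; the function and module names
# come directly from the new file above.
from data_designer.engine.compiler import compile_data_designer_config

config = compile_data_designer_config(config_builder, resource_provider)
# If a SeedReader is attached, its columns are appended to config.columns;
# name collisions or validation errors raise InvalidConfigError.
```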
data_designer/engine/dataset_builders/column_wise_builder.py

@@ -217,11 +217,14 @@ class ColumnWiseDatasetBuilder:
             f"🐙 Processing {generator.config.column_type} column '{generator.config.name}' "
             f"with {max_workers} concurrent workers"
         )
+        settings = self._resource_provider.run_config
         with ConcurrentThreadExecutor(
             max_workers=max_workers,
             column_name=generator.config.name,
             result_callback=self._worker_result_callback,
             error_callback=self._worker_error_callback,
+            shutdown_error_rate=settings.shutdown_error_rate,
+            shutdown_error_window=settings.shutdown_error_window,
         ) as executor:
             for i, record in self.batch_manager.iter_current_batch():
                 executor.submit(lambda record: generator.generate(record), record, context={"index": i})
data_designer/engine/dataset_builders/utils/config_compiler.py

@@ -34,7 +34,7 @@ def compile_dataset_builder_column_configs(config: DataDesignerConfig) -> list[D
         compiled_column_configs.append(
             SeedDatasetMultiColumnConfig(
                 columns=seed_column_configs,
-
+                source=config.seed_config.source,
                 sampling_strategy=config.seed_config.sampling_strategy,
                 selection_strategy=config.seed_config.selection_strategy,
             )
data_designer/engine/models/facade.py

@@ -96,6 +96,8 @@ class ModelFacade:
         kwargs = {**self._model_config.inference_parameters.generate_kwargs, **kwargs}
         if self.model_provider.extra_body:
             kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body}
+        if self.model_provider.extra_headers:
+            kwargs["extra_headers"] = self.model_provider.extra_headers
         return kwargs
 
     @catch_llm_exceptions
data_designer/engine/processing/gsonschema/validators.py

@@ -2,7 +2,9 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import logging
+import re
 from copy import deepcopy
+from decimal import ROUND_HALF_UP, Decimal
 from typing import Any, overload
 
 from jsonschema import Draft202012Validator, ValidationError, validators
@@ -70,6 +72,57 @@ def extend_jsonschema_validator_with_pruning(validator):
     return validators.extend(validator, {"additionalProperties": prune_additional_properties})
 
 
+def _get_decimal_info_from_anyof(schema: dict) -> tuple[bool, int | None]:
+    """Check if schema is a Decimal anyOf and extract decimal places.
+
+    Returns (is_decimal, decimal_places) where decimal_places is None if no constraint.
+    """
+    any_of = schema.get("anyOf")
+    if not isinstance(any_of, list):
+        return False, None
+
+    has_number = any(item.get("type") == "number" for item in any_of)
+    if not has_number:
+        return False, None
+
+    for item in any_of:
+        if item.get("type") == "string" and "pattern" in item:
+            match = re.search(r"\\d\{0,(\d+)\}", item["pattern"])
+            if match:
+                return True, int(match.group(1))
+            return True, None  # Decimal without precision constraint
+    return False, None
+
+
+def normalize_decimal_fields(obj: DataObjectT, schema: JSONSchemaT) -> DataObjectT:
+    """Normalize Decimal-like anyOf fields to floats with proper precision."""
+    if not isinstance(obj, dict):
+        return obj
+
+    defs = schema.get("$defs", {})
+    obj_schema = defs.get(schema.get("$ref", "")[len("#/$defs/") :], schema)
+    props = obj_schema.get("properties", {})
+
+    for key, value in obj.items():
+        field_schema = props.get(key, {})
+        if "$ref" in field_schema:
+            field_schema = defs.get(field_schema["$ref"][len("#/$defs/") :], {})
+
+        if isinstance(value, dict):
+            obj[key] = normalize_decimal_fields(value, schema)
+        elif isinstance(value, list):
+            obj[key] = [normalize_decimal_fields(v, schema) if isinstance(v, dict) else v for v in value]
+        elif isinstance(value, (int, float, str)) and not isinstance(value, bool):
+            is_decimal, decimal_places = _get_decimal_info_from_anyof(field_schema)
+            if is_decimal:
+                d = Decimal(str(value))
+                if decimal_places is not None:
+                    d = d.quantize(Decimal(f"0.{'0' * decimal_places}"), rounding=ROUND_HALF_UP)
+                obj[key] = float(d)
+
+    return obj
+
+
 ## We don't expect the outer data type (e.g. dict, list, or const) to be
 ## modified by the pruning action.
 @overload
@@ -140,4 +193,6 @@ def validate(
     except ValidationError as exc:
         raise JSONSchemaValidationError(str(exc)) from exc
 
+    final_object = normalize_decimal_fields(final_object, schema)
+
     return final_object
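The net effect of `normalize_decimal_fields` is that values matching a Decimal-style `anyOf` schema are rounded half-up to the precision encoded in the string pattern and returned as floats. A standalone illustration of that rounding step (stdlib only, made-up values):

```python
from decimal import ROUND_HALF_UP, Decimal

value, decimal_places = "19.955", 2
d = Decimal(value).quantize(Decimal(f"0.{'0' * decimal_places}"), rounding=ROUND_HALF_UP)
print(float(d))  # 19.96 -- half-up rounding on the exact decimal value, then cast to float
```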
data_designer/engine/resources/resource_provider.py

@@ -3,26 +3,29 @@
 
 from data_designer.config.base import ConfigBase
 from data_designer.config.models import ModelConfig
+from data_designer.config.run_config import RunConfig
+from data_designer.config.seed_source import SeedSource
 from data_designer.config.utils.type_helpers import StrEnum
 from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
 from data_designer.engine.model_provider import ModelProviderRegistry
 from data_designer.engine.models.registry import ModelRegistry, create_model_registry
 from data_designer.engine.resources.managed_storage import ManagedBlobStorage, init_managed_blob_storage
-from data_designer.engine.resources.
+from data_designer.engine.resources.seed_reader import SeedReader, SeedReaderRegistry
 from data_designer.engine.secret_resolver import SecretResolver
 
 
 class ResourceType(StrEnum):
     BLOB_STORAGE = "blob_storage"
-    DATASTORE = "datastore"
     MODEL_REGISTRY = "model_registry"
+    SEED_READER = "seed_reader"
 
 
 class ResourceProvider(ConfigBase):
     artifact_storage: ArtifactStorage
     blob_storage: ManagedBlobStorage | None = None
-    datastore: SeedDatasetDataStore | None = None
     model_registry: ModelRegistry | None = None
+    run_config: RunConfig = RunConfig()
+    seed_reader: SeedReader | None = None
 
 
 def create_resource_provider(
@@ -31,16 +34,25 @@ def create_resource_provider(
     model_configs: list[ModelConfig],
     secret_resolver: SecretResolver,
     model_provider_registry: ModelProviderRegistry,
-
+    seed_reader_registry: SeedReaderRegistry,
     blob_storage: ManagedBlobStorage | None = None,
+    seed_dataset_source: SeedSource | None = None,
+    run_config: RunConfig | None = None,
 ) -> ResourceProvider:
+    seed_reader = None
+    if seed_dataset_source:
+        seed_reader = seed_reader_registry.get_reader(
+            seed_dataset_source,
+            secret_resolver,
+        )
     return ResourceProvider(
         artifact_storage=artifact_storage,
-        datastore=datastore,
         model_registry=create_model_registry(
            model_configs=model_configs,
             secret_resolver=secret_resolver,
             model_provider_registry=model_provider_registry,
         ),
         blob_storage=blob_storage or init_managed_blob_storage(),
+        seed_reader=seed_reader,
+        run_config=run_config or RunConfig(),
     )
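Relative to 0.2.3, the datastore argument is gone and seed data now arrives as a `SeedSource` resolved through a `SeedReaderRegistry`. A hedged sketch of the new seed-related keyword arguments only, assuming the remaining arguments are built the same way as before and that `LocalFileSeedSource` takes a `path` field (its reader reads `self.source.path` in the module shown below):

```python
# Sketch only: the reader classes come from the new seed_reader module below;
# LocalFileSeedSource(path=...) and this keyword wiring are assumptions.
provider = create_resource_provider(
    artifact_storage=artifact_storage,
    model_configs=model_configs,
    secret_resolver=secret_resolver,
    model_provider_registry=model_provider_registry,
    seed_reader_registry=SeedReaderRegistry(
        readers=[LocalFileSeedReader(), HuggingFaceSeedReader(), DataFrameSeedReader()]
    ),
    seed_dataset_source=LocalFileSeedSource(path="seed_data.parquet"),
    run_config=RunConfig(),
)
```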
data_designer/engine/resources/seed_reader.py (new file)

@@ -0,0 +1,149 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from abc import ABC, abstractmethod
+from collections.abc import Sequence
+from typing import Generic, TypeVar, get_args, get_origin
+
+import duckdb
+from huggingface_hub import HfFileSystem
+from typing_extensions import Self
+
+from data_designer.config.seed_source import (
+    DataFrameSeedSource,
+    HuggingFaceSeedSource,
+    LocalFileSeedSource,
+    SeedSource,
+)
+from data_designer.engine.secret_resolver import SecretResolver
+from data_designer.errors import DataDesignerError
+
+
+class SeedReaderError(DataDesignerError): ...
+
+
+SourceT = TypeVar("ConfigT", bound=SeedSource)
+
+
+class SeedReader(ABC, Generic[SourceT]):
+    """Base class for reading a seed dataset.
+
+    Seeds are read using duckdb. Reader implementations define duckdb connection setup details
+    and how to get a URI that can be queried with duckdb (i.e. "... FROM <uri> ...").
+
+    The Data Designer engine automatically supplies the appropriate SeedSource
+    and a SecretResolver to use for any secret fields in the config.
+    """
+
+    source: SourceT
+    secret_resolver: SecretResolver
+
+    @abstractmethod
+    def get_dataset_uri(self) -> str: ...
+
+    @abstractmethod
+    def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection: ...
+
+    def attach(self, source: SourceT, secret_resolver: SecretResolver):
+        """Attach a source and secret resolver to the instance.
+
+        This is called internally by the engine so that these objects do not
+        need to be provided in the reader's constructor.
+        """
+        self.source = source
+        self.secret_resolver = secret_resolver
+
+    def get_column_names(self) -> list[str]:
+        """Returns the seed dataset's column names"""
+        conn = self.create_duckdb_connection()
+        describe_query = f"DESCRIBE SELECT * FROM '{self.get_dataset_uri()}'"
+        column_descriptions = conn.execute(describe_query).fetchall()
+        return [col[0] for col in column_descriptions]
+
+    def get_seed_type(self) -> str:
+        """Return the seed_type of the source class this reader is generic over."""
+        # Get the generic type arguments from the reader class
+        # Check __orig_bases__ for the generic base class
+        for base in getattr(type(self), "__orig_bases__", []):
+            origin = get_origin(base)
+            if origin is SeedReader:
+                args = get_args(base)
+                if args:
+                    source_cls = args[0]
+                    # Extract seed_type from the source class
+                    if hasattr(source_cls, "model_fields") and "seed_type" in source_cls.model_fields:
+                        field = source_cls.model_fields["seed_type"]
+                        default_value = field.default
+                        if isinstance(default_value, str):
+                            return default_value
+
+        raise SeedReaderError("Reader does not have a valid generic source type with seed_type")
+
+
+class LocalFileSeedReader(SeedReader[LocalFileSeedSource]):
+    def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection:
+        return duckdb.connect()
+
+    def get_dataset_uri(self) -> str:
+        return self.source.path
+
+
+class HuggingFaceSeedReader(SeedReader[HuggingFaceSeedSource]):
+    def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection:
+        token = self.secret_resolver.resolve(self.source.token) if self.source.token else None
+
+        # Use skip_instance_cache to avoid fsspec-level caching
+        hffs = HfFileSystem(endpoint=self.source.endpoint, token=token, skip_instance_cache=True)
+
+        # Clear all internal caches to avoid stale metadata issues
+        # HfFileSystem caches file metadata (size, etc.) which can become stale when files are re-uploaded
+        if hasattr(hffs, "dircache"):
+            hffs.dircache.clear()
+
+        conn = duckdb.connect()
+        conn.register_filesystem(hffs)
+        return conn
+
+    def get_dataset_uri(self) -> str:
+        return f"hf://{self.source.path}"
+
+
+class DataFrameSeedReader(SeedReader[DataFrameSeedSource]):
+    # This is a "magic string" that gets registered in the duckdb connection to make the dataframe directly queryable.
+    _table_name = "df"
+
+    def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection:
+        conn = duckdb.connect()
+        conn.register(self._table_name, self.source.df)
+        return conn
+
+    def get_dataset_uri(self) -> str:
+        return self._table_name
+
+
+class SeedReaderRegistry:
+    def __init__(self, readers: Sequence[SeedReader]):
+        self._readers: dict[str, SeedReader] = {}
+        for reader in readers:
+            self.add_reader(reader)
+
+    def add_reader(self, reader: SeedReader) -> Self:
+        seed_type = reader.get_seed_type()
+
+        if seed_type in self._readers:
+            raise SeedReaderError(f"A reader for seed_type {seed_type!r} already exists")
+
+        self._readers[seed_type] = reader
+        return self
+
+    def get_reader(self, seed_dataset_source: SeedSource, secret_resolver: SecretResolver) -> SeedReader:
+        reader = self._get_reader_for_source(seed_dataset_source)
+        reader.attach(seed_dataset_source, secret_resolver)
+        return reader
+
+    def _get_reader_for_source(self, seed_dataset_source: SeedSource) -> SeedReader:
+        seed_type = seed_dataset_source.seed_type
+        try:
+            return self._readers[seed_type]
+        except KeyError:
+            raise SeedReaderError(f"No reader found for seed_type {seed_type!r}")
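Readers are keyed by the `seed_type` discriminator on their source class, and `get_reader` attaches the source and secret resolver before handing the reader back. A minimal sketch of that flow using the in-memory DataFrame reader, assuming `DataFrameSeedSource` accepts `df=` (its reader registers `self.source.df`) and that a `SecretResolver` instance named `secret_resolver` is already available:

```python
import pandas as pd

# Sketch only: the constructor keyword and the pre-built secret_resolver are assumptions;
# the registry and reader calls mirror the module above.
registry = SeedReaderRegistry(readers=[DataFrameSeedReader()])
source = DataFrameSeedSource(df=pd.DataFrame({"topic": ["sports", "science"]}))

reader = registry.get_reader(source, secret_resolver)  # attaches source + secret resolver
print(reader.get_column_names())  # ["topic"], discovered via DESCRIBE over the duckdb connection
```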
data_designer/essentials/__init__.py

@@ -3,6 +3,7 @@
 
 from data_designer.config.default_model_settings import resolve_seed_default_model_settings
 from data_designer.config.exports import * # noqa: F403
+from data_designer.config.run_config import RunConfig
 from data_designer.config.validator_params import LocalCallableValidatorParams
 from data_designer.interface.data_designer import DataDesigner
 from data_designer.logging import LoggingConfig, configure_logging
@@ -21,6 +22,7 @@ def get_essentials_exports() -> list[str]:
     local = [
         DataDesigner.__name__,
         LocalCallableValidatorParams.__name__,
+        RunConfig.__name__,
     ]
 
     return logging + local + get_config_exports() # noqa: F405