data-designer 0.3.8rc2-py3-none-any.whl → 0.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/cli/commands/__init__.py +1 -1
- data_designer/interface/__init__.py +21 -1
- data_designer/{_version.py → interface/_version.py} +2 -2
- data_designer/interface/data_designer.py +1 -7
- {data_designer-0.3.8rc2.dist-info → data_designer-0.4.0.dist-info}/METADATA +10 -42
- data_designer-0.4.0.dist-info/RECORD +39 -0
- data_designer/__init__.py +0 -17
- data_designer/config/__init__.py +0 -2
- data_designer/config/analysis/__init__.py +0 -2
- data_designer/config/analysis/column_profilers.py +0 -159
- data_designer/config/analysis/column_statistics.py +0 -421
- data_designer/config/analysis/dataset_profiler.py +0 -84
- data_designer/config/analysis/utils/errors.py +0 -10
- data_designer/config/analysis/utils/reporting.py +0 -192
- data_designer/config/base.py +0 -69
- data_designer/config/column_configs.py +0 -470
- data_designer/config/column_types.py +0 -141
- data_designer/config/config_builder.py +0 -595
- data_designer/config/data_designer_config.py +0 -40
- data_designer/config/dataset_builders.py +0 -13
- data_designer/config/dataset_metadata.py +0 -18
- data_designer/config/default_model_settings.py +0 -129
- data_designer/config/errors.py +0 -24
- data_designer/config/exports.py +0 -145
- data_designer/config/interface.py +0 -55
- data_designer/config/models.py +0 -455
- data_designer/config/preview_results.py +0 -41
- data_designer/config/processors.py +0 -148
- data_designer/config/run_config.py +0 -51
- data_designer/config/sampler_constraints.py +0 -52
- data_designer/config/sampler_params.py +0 -639
- data_designer/config/seed.py +0 -116
- data_designer/config/seed_source.py +0 -84
- data_designer/config/seed_source_types.py +0 -19
- data_designer/config/utils/code_lang.py +0 -82
- data_designer/config/utils/constants.py +0 -363
- data_designer/config/utils/errors.py +0 -21
- data_designer/config/utils/info.py +0 -94
- data_designer/config/utils/io_helpers.py +0 -258
- data_designer/config/utils/misc.py +0 -78
- data_designer/config/utils/numerical_helpers.py +0 -30
- data_designer/config/utils/type_helpers.py +0 -106
- data_designer/config/utils/visualization.py +0 -482
- data_designer/config/validator_params.py +0 -94
- data_designer/engine/__init__.py +0 -2
- data_designer/engine/analysis/column_profilers/base.py +0 -49
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +0 -153
- data_designer/engine/analysis/column_profilers/registry.py +0 -22
- data_designer/engine/analysis/column_statistics.py +0 -145
- data_designer/engine/analysis/dataset_profiler.py +0 -149
- data_designer/engine/analysis/errors.py +0 -9
- data_designer/engine/analysis/utils/column_statistics_calculations.py +0 -234
- data_designer/engine/analysis/utils/judge_score_processing.py +0 -132
- data_designer/engine/column_generators/__init__.py +0 -2
- data_designer/engine/column_generators/generators/__init__.py +0 -2
- data_designer/engine/column_generators/generators/base.py +0 -122
- data_designer/engine/column_generators/generators/embedding.py +0 -35
- data_designer/engine/column_generators/generators/expression.py +0 -55
- data_designer/engine/column_generators/generators/llm_completion.py +0 -113
- data_designer/engine/column_generators/generators/samplers.py +0 -69
- data_designer/engine/column_generators/generators/seed_dataset.py +0 -144
- data_designer/engine/column_generators/generators/validation.py +0 -140
- data_designer/engine/column_generators/registry.py +0 -60
- data_designer/engine/column_generators/utils/errors.py +0 -15
- data_designer/engine/column_generators/utils/generator_classification.py +0 -43
- data_designer/engine/column_generators/utils/judge_score_factory.py +0 -58
- data_designer/engine/column_generators/utils/prompt_renderer.py +0 -100
- data_designer/engine/compiler.py +0 -97
- data_designer/engine/configurable_task.py +0 -71
- data_designer/engine/dataset_builders/artifact_storage.py +0 -283
- data_designer/engine/dataset_builders/column_wise_builder.py +0 -335
- data_designer/engine/dataset_builders/errors.py +0 -15
- data_designer/engine/dataset_builders/multi_column_configs.py +0 -46
- data_designer/engine/dataset_builders/utils/__init__.py +0 -2
- data_designer/engine/dataset_builders/utils/concurrency.py +0 -212
- data_designer/engine/dataset_builders/utils/config_compiler.py +0 -62
- data_designer/engine/dataset_builders/utils/dag.py +0 -62
- data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +0 -200
- data_designer/engine/dataset_builders/utils/errors.py +0 -15
- data_designer/engine/errors.py +0 -51
- data_designer/engine/model_provider.py +0 -77
- data_designer/engine/models/__init__.py +0 -2
- data_designer/engine/models/errors.py +0 -300
- data_designer/engine/models/facade.py +0 -287
- data_designer/engine/models/factory.py +0 -42
- data_designer/engine/models/litellm_overrides.py +0 -179
- data_designer/engine/models/parsers/__init__.py +0 -2
- data_designer/engine/models/parsers/errors.py +0 -34
- data_designer/engine/models/parsers/parser.py +0 -235
- data_designer/engine/models/parsers/postprocessors.py +0 -93
- data_designer/engine/models/parsers/tag_parsers.py +0 -62
- data_designer/engine/models/parsers/types.py +0 -84
- data_designer/engine/models/recipes/base.py +0 -81
- data_designer/engine/models/recipes/response_recipes.py +0 -293
- data_designer/engine/models/registry.py +0 -146
- data_designer/engine/models/telemetry.py +0 -359
- data_designer/engine/models/usage.py +0 -73
- data_designer/engine/models/utils.py +0 -38
- data_designer/engine/processing/ginja/__init__.py +0 -2
- data_designer/engine/processing/ginja/ast.py +0 -65
- data_designer/engine/processing/ginja/environment.py +0 -463
- data_designer/engine/processing/ginja/exceptions.py +0 -56
- data_designer/engine/processing/ginja/record.py +0 -32
- data_designer/engine/processing/gsonschema/__init__.py +0 -2
- data_designer/engine/processing/gsonschema/exceptions.py +0 -15
- data_designer/engine/processing/gsonschema/schema_transformers.py +0 -83
- data_designer/engine/processing/gsonschema/types.py +0 -10
- data_designer/engine/processing/gsonschema/validators.py +0 -202
- data_designer/engine/processing/processors/base.py +0 -13
- data_designer/engine/processing/processors/drop_columns.py +0 -42
- data_designer/engine/processing/processors/registry.py +0 -25
- data_designer/engine/processing/processors/schema_transform.py +0 -49
- data_designer/engine/processing/utils.py +0 -169
- data_designer/engine/registry/base.py +0 -99
- data_designer/engine/registry/data_designer_registry.py +0 -39
- data_designer/engine/registry/errors.py +0 -12
- data_designer/engine/resources/managed_dataset_generator.py +0 -39
- data_designer/engine/resources/managed_dataset_repository.py +0 -197
- data_designer/engine/resources/managed_storage.py +0 -65
- data_designer/engine/resources/resource_provider.py +0 -77
- data_designer/engine/resources/seed_reader.py +0 -154
- data_designer/engine/sampling_gen/column.py +0 -91
- data_designer/engine/sampling_gen/constraints.py +0 -100
- data_designer/engine/sampling_gen/data_sources/base.py +0 -217
- data_designer/engine/sampling_gen/data_sources/errors.py +0 -12
- data_designer/engine/sampling_gen/data_sources/sources.py +0 -347
- data_designer/engine/sampling_gen/entities/__init__.py +0 -2
- data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
- data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +0 -86
- data_designer/engine/sampling_gen/entities/email_address_utils.py +0 -171
- data_designer/engine/sampling_gen/entities/errors.py +0 -10
- data_designer/engine/sampling_gen/entities/national_id_utils.py +0 -102
- data_designer/engine/sampling_gen/entities/person.py +0 -144
- data_designer/engine/sampling_gen/entities/phone_number.py +0 -128
- data_designer/engine/sampling_gen/errors.py +0 -26
- data_designer/engine/sampling_gen/generator.py +0 -122
- data_designer/engine/sampling_gen/jinja_utils.py +0 -64
- data_designer/engine/sampling_gen/people_gen.py +0 -199
- data_designer/engine/sampling_gen/person_constants.py +0 -56
- data_designer/engine/sampling_gen/schema.py +0 -147
- data_designer/engine/sampling_gen/schema_builder.py +0 -61
- data_designer/engine/sampling_gen/utils.py +0 -46
- data_designer/engine/secret_resolver.py +0 -82
- data_designer/engine/validation.py +0 -367
- data_designer/engine/validators/__init__.py +0 -19
- data_designer/engine/validators/base.py +0 -38
- data_designer/engine/validators/local_callable.py +0 -39
- data_designer/engine/validators/python.py +0 -254
- data_designer/engine/validators/remote.py +0 -89
- data_designer/engine/validators/sql.py +0 -65
- data_designer/errors.py +0 -7
- data_designer/essentials/__init__.py +0 -33
- data_designer/lazy_heavy_imports.py +0 -54
- data_designer/logging.py +0 -163
- data_designer/plugin_manager.py +0 -78
- data_designer/plugins/__init__.py +0 -8
- data_designer/plugins/errors.py +0 -15
- data_designer/plugins/plugin.py +0 -141
- data_designer/plugins/registry.py +0 -88
- data_designer/plugins/testing/__init__.py +0 -10
- data_designer/plugins/testing/stubs.py +0 -116
- data_designer/plugins/testing/utils.py +0 -20
- data_designer-0.3.8rc2.dist-info/RECORD +0 -196
- data_designer-0.3.8rc2.dist-info/licenses/LICENSE +0 -201
- {data_designer-0.3.8rc2.dist-info → data_designer-0.4.0.dist-info}/WHEEL +0 -0
- {data_designer-0.3.8rc2.dist-info → data_designer-0.4.0.dist-info}/entry_points.txt +0 -0
data_designer/engine/resources/managed_dataset_generator.py
@@ -1,39 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
-from __future__ import annotations
-
-from typing import TYPE_CHECKING, Any
-
-from data_designer.engine.resources.managed_dataset_repository import ManagedDatasetRepository
-from data_designer.lazy_heavy_imports import pd
-
-if TYPE_CHECKING:
-    import pandas as pd
-
-
-class ManagedDatasetGenerator:
-    def __init__(self, managed_datasets: ManagedDatasetRepository, dataset_name: str):
-        self.managed_datasets = managed_datasets
-        self.dataset_name = dataset_name
-
-    def generate_samples(
-        self,
-        size: int = 1,
-        evidence: dict[str, Any | list[Any]] = {},
-    ) -> pd.DataFrame:
-        parameters = []
-        query = f"select * from {self.dataset_name}"
-        if evidence:
-            where_conditions = []
-            for column, values in evidence.items():
-                if values:
-                    values = values if isinstance(values, list) else [values]
-                    formatted_values = ["?"] * len(values)
-                    condition = f"{column} IN ({', '.join(formatted_values)})"
-                    where_conditions.append(condition)
-                    parameters.extend(values)
-            if where_conditions:
-                query += " where " + " and ".join(where_conditions)
-        query += f" order by random() limit {size}"
-        return self.managed_datasets.query(query, parameters)
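The removed `ManagedDatasetGenerator.generate_samples` assembles a parameterized query: one `?` placeholder per evidence value, `IN` clauses joined with `and`, then random ordering and a row limit. Below is a minimal standalone sketch of that query-assembly logic; the stub repository and the table/column names (`en_US`, `state`) are hypothetical, not part of the package.

# Standalone sketch of the removed generate_samples() query assembly.
# FakeRepo is a hypothetical stand-in for ManagedDatasetRepository.
from typing import Any


class FakeRepo:
    def query(self, sql: str, parameters: list[Any]) -> None:
        # A real repository would run this against DuckDB and return a DataFrame.
        print(sql, parameters)


def build_query(dataset_name: str, size: int, evidence: dict[str, Any]) -> tuple[str, list[Any]]:
    parameters: list[Any] = []
    query = f"select * from {dataset_name}"
    where_conditions = []
    for column, values in evidence.items():
        if values:
            values = values if isinstance(values, list) else [values]
            # One "?" placeholder per value keeps the query parameterized.
            where_conditions.append(f"{column} IN ({', '.join(['?'] * len(values))})")
            parameters.extend(values)
    if where_conditions:
        query += " where " + " and ".join(where_conditions)
    query += f" order by random() limit {size}"
    return query, parameters


sql, params = build_query("en_US", size=3, evidence={"state": ["CA", "NY"]})
FakeRepo().query(sql, params)
# select * from en_US where state IN (?, ?) order by random() limit 3 ['CA', 'NY']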
data_designer/engine/resources/managed_dataset_repository.py
@@ -1,197 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
-from __future__ import annotations
-
-import logging
-import tempfile
-import threading
-import time
-from abc import ABC, abstractmethod
-from dataclasses import dataclass
-from functools import cached_property
-from pathlib import Path
-from typing import TYPE_CHECKING, Any
-
-from data_designer.config.utils.constants import LOCALES_WITH_MANAGED_DATASETS
-from data_designer.engine.resources.managed_storage import LocalBlobStorageProvider, ManagedBlobStorage
-from data_designer.lazy_heavy_imports import duckdb, pd
-
-if TYPE_CHECKING:
-    import duckdb
-    import pandas as pd
-
-logger = logging.getLogger(__name__)
-
-DATASETS_ROOT = "datasets"
-"""
-Path in object storage to managed datasets
-"""
-
-
-@dataclass
-class Table:
-    """
-    Managed datasets are organized by dataset by table under a root
-    table path in object storage.
-    """
-
-    source: str
-    """
-    Table source path
-    """
-
-    schema: str = "main"
-    """
-    Specifies the schema to use when registering the table.
-
-    Note: this is not the schema of the table, but rather the _database_
-    schema to associated with the table.
-    """
-
-    @cached_property
-    def name(self) -> str:
-        return Path(self.source).stem
-
-
-DataCatalog = list[Table]
-
-# For now we hardcode the remote data catalog in code. This make it easier
-# initialize the data catalog. Eventually we can make this work more
-# dynamically once this data catalog pattern becomes more widely adopted.
-DEFAULT_DATA_CATALOG: DataCatalog = [Table(f"{locale}.parquet") for locale in LOCALES_WITH_MANAGED_DATASETS]
-
-
-class ManagedDatasetRepository(ABC):
-    @abstractmethod
-    def query(self, sql: str, parameters: list[Any]) -> pd.DataFrame: ...
-
-    @property
-    @abstractmethod
-    def data_catalog(self) -> DataCatalog: ...
-
-
-class DuckDBDatasetRepository(ManagedDatasetRepository):
-    """
-    Provides a duckdb based sql interface over Gretel managed datasets.
-    """
-
-    _default_config = {"threads": 2, "memory_limit": "4 gb"}
-
-    def __init__(
-        self,
-        blob_storage: ManagedBlobStorage,
-        config: dict | None = None,
-        data_catalog: DataCatalog = DEFAULT_DATA_CATALOG,
-        datasets_root: str = DATASETS_ROOT,
-        use_cache: bool = True,
-    ):
-        """
-        Create a new DuckDB backed dataset repository
-
-        Args:
-            blob_storage: A managed blob storage provider
-            config: DuckDB configuration options,
-                https://duckdb.org/docs/configuration/overview.html#configuration-reference
-            data_catalog: A list of tables to register with the DuckDB instance
-            datasets_root: The root path in blob storage to managed datasets
-            use_cache: Whether to cache datasets locally. Trades off disk memory
-                and startup time for faster queries.
-        """
-        self._data_catalog = data_catalog
-        self._data_sets_root = datasets_root
-        self._blob_storage = blob_storage
-        self._config = self._default_config if config is None else config
-        self._use_cache = use_cache
-
-        # Configure database and register tables
-        self.db = duckdb.connect(config=self._config)
-
-        # Dataset registration completion is tracked with an event. Consumers can
-        # wait on this event to ensure the catalog is ready.
-        self._registration_event = threading.Event()
-        self._register_lock = threading.Lock()
-
-        # Kick off dataset registration in a background thread so that IO-heavy
-        # caching and view creation can run asynchronously without blocking the
-        # caller that constructs this repository instance.
-        self._register_thread = threading.Thread(target=self._register_datasets, daemon=True)
-        self._register_thread.start()
-
-    def _register_datasets(self):
-        # Just in case this method gets called from inside a thread.
-        # This operation isn't thread-safe by default, so we
-        # synchronize the registration process.
-        if self._registration_event.is_set():
-            return
-        with self._register_lock:
-            # check once more to see if the catalog is ready it's possible a
-            # previous thread already registered the dataset.
-            if self._registration_event.is_set():
-                return
-            try:
-                for table in self.data_catalog:
-                    key = table.source if table.schema == "main" else f"{table.schema}/{table.source}"
-                    if self._use_cache:
-                        tmp_root = Path(tempfile.gettempdir()) / "dd_cache"
-                        local_path = tmp_root / key
-                        local_path.parent.mkdir(parents=True, exist_ok=True)
-                        if not local_path.exists():
-                            start = time.time()
-                            logger.debug("Caching database %s to %s", table.name, local_path)
-                            with self._blob_storage.get_blob(f"{self._data_sets_root}/{key}") as src_fd:
-                                with open(local_path, "wb") as dst_fd:
-                                    dst_fd.write(src_fd.read())
-                            logger.debug(
-                                "Cached database %s in %.2f s",
-                                table.name,
-                                time.time() - start,
-                            )
-                        data_path = local_path.as_posix()
-                    else:
-                        data_path = self._blob_storage.uri_for_key(f"{self._data_sets_root}/{key}")
-                    if table.schema != "main":
-                        self.db.sql(f"CREATE SCHEMA IF NOT EXISTS {table.schema}")
-                    logger.debug(f"Registering dataset {table.name} from {data_path}")
-                    self.db.sql(f"CREATE VIEW {table.schema}.{table.name} AS FROM '{data_path}'")
-
-                logger.debug("DuckDBDatasetRepository registration complete")
-
-            except Exception as e:
-                logger.exception(f"Failed to register datasets: {str(e)}")
-
-            finally:
-                # Signal that registration is complete so any waiting queries can proceed.
-                self._registration_event.set()
-
-    def query(self, sql: str, parameters: list[Any]) -> pd.DataFrame:
-        # Ensure dataset registration has completed. Possible future optimization:
-        # pull datasets in parallel and only wait here if the query requires a
-        # table that isn't cached.
-        if not self._registration_event.is_set():
-            logger.debug("Waiting for dataset caching and registration to finish...")
-            self._registration_event.wait()
-
-        # the duckdb connection isn't thread-safe, so we create a new
-        # connection per query using cursor().
-        # more details here: https://duckdb.org/docs/stable/guides/python/multiple_threads.html
-        cursor = self.db.cursor()
-        try:
-            df = cursor.execute(sql, parameters).df()
-        finally:
-            cursor.close()
-        return df
-
-    @property
-    def data_catalog(self) -> DataCatalog:
-        return self._data_catalog
-
-
-def load_managed_dataset_repository(blob_storage: ManagedBlobStorage, locales: list[str]) -> ManagedDatasetRepository:
-    return DuckDBDatasetRepository(
-        blob_storage,
-        config={"threads": 1, "memory_limit": "2 gb"},
-        data_catalog=[Table(f"{locale}.parquet") for locale in locales],
-        # Only cache if not using local storage.
-        use_cache=not isinstance(blob_storage, LocalBlobStorageProvider),
-    )
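Construction of the removed `DuckDBDatasetRepository` returns immediately: registration runs on a daemon thread, `query()` blocks on a `threading.Event`, and a lock with a double-check keeps registration from running twice. A minimal sketch of just that synchronization pattern, with DuckDB and blob storage replaced by a hypothetical sleep:

# Sketch of the register-in-background / wait-on-event pattern.
import threading
import time


class BackgroundRegistry:
    def __init__(self) -> None:
        self._ready = threading.Event()
        self._lock = threading.Lock()
        # Construction returns immediately; registration proceeds in the background.
        threading.Thread(target=self._register, daemon=True).start()

    def _register(self) -> None:
        if self._ready.is_set():
            return
        with self._lock:
            # Double-check under the lock: another thread may have finished already.
            if self._ready.is_set():
                return
            try:
                time.sleep(0.1)  # stands in for caching files and creating views
            finally:
                # Always signal, even on failure, so waiting queries never hang.
                self._ready.set()

    def query(self, sql: str) -> str:
        if not self._ready.is_set():
            self._ready.wait()  # block until registration completes
        return f"executed: {sql}"


print(BackgroundRegistry().query("select 1"))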
data_designer/engine/resources/managed_storage.py
@@ -1,65 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
-from __future__ import annotations
-
-import logging
-from abc import ABC, abstractmethod
-from collections.abc import Iterator
-from contextlib import contextmanager
-from pathlib import Path
-from typing import IO
-
-logger = logging.getLogger(__name__)
-
-
-class ManagedBlobStorage(ABC):
-    """
-    Provides a low-level interface for access object in blob storage. This interface
-    can be used to access model weights, raw datasets, or any artifact in blob
-    storage.
-
-    If you want a high-level interface for accessing datasets, use the `ManagedDatasetRepository`
-    which provides a high-level SQL interface over each dataset.
-    """
-
-    @abstractmethod
-    @contextmanager
-    def get_blob(self, blob_key: str) -> Iterator[IO]: ...
-
-    @abstractmethod
-    def _key_uri_builder(self, key: str) -> str: ...
-
-    def uri_for_key(self, key: str) -> str:
-        """
-        Returns a qualified storage URI for a given a key. `key` is
-        normalized to ensure that and leading path components ("/") are removed.
-        """
-        return self._key_uri_builder(key.lstrip("/"))
-
-
-class LocalBlobStorageProvider(ManagedBlobStorage):
-    """
-    Provide a local blob storage service. Useful for running
-    tests that don't require access to external infrastructure
-    """
-
-    def __init__(self, root_path: Path) -> None:
-        self._root_path = root_path
-
-    @contextmanager
-    def get_blob(self, blob_key: str) -> Iterator[IO]:
-        with open(self._key_uri_builder(blob_key), "rb") as fd:
-            yield fd
-
-    def _key_uri_builder(self, key: str) -> str:
-        return f"{self._root_path}/{key}"
-
-
-def init_managed_blob_storage(assets_storage: str) -> ManagedBlobStorage:
-    path = Path(assets_storage)
-    if not path.exists():
-        raise RuntimeError(f"Local storage path {assets_storage!r} does not exist.")
-
-    logger.debug(f"Using local storage for managed datasets: {assets_storage!r}")
-    return LocalBlobStorageProvider(Path(assets_storage))
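The deleted `LocalBlobStorageProvider` maps keys to paths under a root directory and hands out open file handles through a context manager, with `uri_for_key` stripping leading `/` components. A self-contained usage sketch; the class body is copied from the deleted module above, trimmed to essentials, and the file name is made up:

# Round-trip through a local blob store rooted in a temp directory.
import tempfile
from collections.abc import Iterator
from contextlib import contextmanager
from pathlib import Path
from typing import IO


class LocalBlobStorageProvider:
    def __init__(self, root_path: Path) -> None:
        self._root_path = root_path

    @contextmanager
    def get_blob(self, blob_key: str) -> Iterator[IO]:
        with open(self._key_uri_builder(blob_key), "rb") as fd:
            yield fd

    def _key_uri_builder(self, key: str) -> str:
        return f"{self._root_path}/{key}"

    def uri_for_key(self, key: str) -> str:
        # Leading "/" components are stripped before building the URI.
        return self._key_uri_builder(key.lstrip("/"))


root = Path(tempfile.mkdtemp())
(root / "datasets").mkdir()
(root / "datasets" / "en_US.parquet").write_bytes(b"not really parquet")

storage = LocalBlobStorageProvider(root)
print(storage.uri_for_key("/datasets/en_US.parquet"))  # leading "/" stripped
with storage.get_blob("datasets/en_US.parquet") as fd:
    print(fd.read())  # b'not really parquet'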
data_designer/engine/resources/resource_provider.py
@@ -1,77 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
-from __future__ import annotations
-
-from data_designer.config.base import ConfigBase
-from data_designer.config.dataset_metadata import DatasetMetadata
-from data_designer.config.models import ModelConfig
-from data_designer.config.run_config import RunConfig
-from data_designer.config.seed_source import SeedSource
-from data_designer.config.utils.type_helpers import StrEnum
-from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
-from data_designer.engine.model_provider import ModelProviderRegistry
-from data_designer.engine.models.factory import create_model_registry
-from data_designer.engine.models.registry import ModelRegistry
-from data_designer.engine.resources.managed_storage import ManagedBlobStorage, init_managed_blob_storage
-from data_designer.engine.resources.seed_reader import SeedReader, SeedReaderRegistry
-from data_designer.engine.secret_resolver import SecretResolver
-
-
-class ResourceType(StrEnum):
-    BLOB_STORAGE = "blob_storage"
-    MODEL_REGISTRY = "model_registry"
-    SEED_READER = "seed_reader"
-
-
-class ResourceProvider(ConfigBase):
-    artifact_storage: ArtifactStorage
-    blob_storage: ManagedBlobStorage | None = None
-    model_registry: ModelRegistry | None = None
-    run_config: RunConfig = RunConfig()
-    seed_reader: SeedReader | None = None
-
-    def get_dataset_metadata(self) -> DatasetMetadata:
-        """Get metadata about the dataset being generated.
-
-        Returns:
-            DatasetMetadata with seed column names and other metadata.
-        """
-        seed_column_names = []
-        if self.seed_reader is not None:
-            seed_column_names = self.seed_reader.get_column_names()
-        return DatasetMetadata(seed_column_names=seed_column_names)
-
-
-def create_resource_provider(
-    *,
-    artifact_storage: ArtifactStorage,
-    model_configs: list[ModelConfig],
-    secret_resolver: SecretResolver,
-    model_provider_registry: ModelProviderRegistry,
-    seed_reader_registry: SeedReaderRegistry,
-    blob_storage: ManagedBlobStorage | None = None,
-    seed_dataset_source: SeedSource | None = None,
-    run_config: RunConfig | None = None,
-) -> ResourceProvider:
-    """Factory function for creating a ResourceProvider instance.
-    This function triggers lazy loading of heavy dependencies like litellm.
-    """
-    seed_reader = None
-    if seed_dataset_source:
-        seed_reader = seed_reader_registry.get_reader(
-            seed_dataset_source,
-            secret_resolver,
-        )
-
-    return ResourceProvider(
-        artifact_storage=artifact_storage,
-        model_registry=create_model_registry(
-            model_configs=model_configs,
-            secret_resolver=secret_resolver,
-            model_provider_registry=model_provider_registry,
-        ),
-        blob_storage=blob_storage or init_managed_blob_storage(),
-        seed_reader=seed_reader,
-        run_config=run_config or RunConfig(),
-    )
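The deleted `create_resource_provider` centralizes wiring: optional inputs fall back to defaults (`run_config or RunConfig()`), and the seed reader is built only when a seed source is supplied. A minimal sketch of that keyword-only factory pattern; every type here is a hypothetical stand-in, not the package's API:

# Keyword-only factory with lazy/optional resource construction.
from dataclasses import dataclass, field


@dataclass
class RunConfig:
    max_workers: int = 4


@dataclass
class Resources:
    artifact_dir: str
    seed_reader: object | None = None
    run_config: RunConfig = field(default_factory=RunConfig)


def create_resources(
    *,
    artifact_dir: str,
    seed_source: str | None = None,
    run_config: RunConfig | None = None,
) -> Resources:
    # Only construct the reader if a seed dataset was actually configured.
    seed_reader = f"reader-for:{seed_source}" if seed_source else None
    return Resources(
        artifact_dir=artifact_dir,
        seed_reader=seed_reader,
        run_config=run_config or RunConfig(),
    )


print(create_resources(artifact_dir="/tmp/artifacts", seed_source="seeds.csv"))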
data_designer/engine/resources/seed_reader.py
@@ -1,154 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
-from __future__ import annotations
-
-from abc import ABC, abstractmethod
-from collections.abc import Sequence
-from typing import TYPE_CHECKING, Generic, TypeVar, get_args, get_origin
-
-from huggingface_hub import HfFileSystem
-from typing_extensions import Self
-
-from data_designer.config.seed_source import (
-    DataFrameSeedSource,
-    HuggingFaceSeedSource,
-    LocalFileSeedSource,
-    SeedSource,
-)
-from data_designer.engine.secret_resolver import SecretResolver
-from data_designer.errors import DataDesignerError
-from data_designer.lazy_heavy_imports import duckdb
-
-if TYPE_CHECKING:
-    import duckdb
-
-
-class SeedReaderError(DataDesignerError): ...
-
-
-SourceT = TypeVar("ConfigT", bound=SeedSource)
-
-
-class SeedReader(ABC, Generic[SourceT]):
-    """Base class for reading a seed dataset.
-
-    Seeds are read using duckdb. Reader implementations define duckdb connection setup details
-    and how to get a URI that can be queried with duckdb (i.e. "... FROM <uri> ...").
-
-    The Data Designer engine automatically supplies the appropriate SeedSource
-    and a SecretResolver to use for any secret fields in the config.
-    """
-
-    source: SourceT
-    secret_resolver: SecretResolver
-
-    @abstractmethod
-    def get_dataset_uri(self) -> str: ...
-
-    @abstractmethod
-    def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection: ...
-
-    def attach(self, source: SourceT, secret_resolver: SecretResolver):
-        """Attach a source and secret resolver to the instance.
-
-        This is called internally by the engine so that these objects do not
-        need to be provided in the reader's constructor.
-        """
-        self.source = source
-        self.secret_resolver = secret_resolver
-
-    def get_column_names(self) -> list[str]:
-        """Returns the seed dataset's column names"""
-        conn = self.create_duckdb_connection()
-        describe_query = f"DESCRIBE SELECT * FROM '{self.get_dataset_uri()}'"
-        column_descriptions = conn.execute(describe_query).fetchall()
-        return [col[0] for col in column_descriptions]
-
-    def get_seed_type(self) -> str:
-        """Return the seed_type of the source class this reader is generic over."""
-        # Get the generic type arguments from the reader class
-        # Check __orig_bases__ for the generic base class
-        for base in getattr(type(self), "__orig_bases__", []):
-            origin = get_origin(base)
-            if origin is SeedReader:
-                args = get_args(base)
-                if args:
-                    source_cls = args[0]
-                    # Extract seed_type from the source class
-                    if hasattr(source_cls, "model_fields") and "seed_type" in source_cls.model_fields:
-                        field = source_cls.model_fields["seed_type"]
-                        default_value = field.default
-                        if isinstance(default_value, str):
-                            return default_value
-
-        raise SeedReaderError("Reader does not have a valid generic source type with seed_type")
-
-
-class LocalFileSeedReader(SeedReader[LocalFileSeedSource]):
-    def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection:
-        return duckdb.connect()
-
-    def get_dataset_uri(self) -> str:
-        return self.source.path
-
-
-class HuggingFaceSeedReader(SeedReader[HuggingFaceSeedSource]):
-    def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection:
-        token = self.secret_resolver.resolve(self.source.token) if self.source.token else None
-
-        # Use skip_instance_cache to avoid fsspec-level caching
-        hffs = HfFileSystem(endpoint=self.source.endpoint, token=token, skip_instance_cache=True)
-
-        # Clear all internal caches to avoid stale metadata issues
-        # HfFileSystem caches file metadata (size, etc.) which can become stale when files are re-uploaded
-        if hasattr(hffs, "dircache"):
-            hffs.dircache.clear()
-
-        conn = duckdb.connect()
-        conn.register_filesystem(hffs)
-        return conn
-
-    def get_dataset_uri(self) -> str:
-        return f"hf://{self.source.path}"
-
-
-class DataFrameSeedReader(SeedReader[DataFrameSeedSource]):
-    # This is a "magic string" that gets registered in the duckdb connection to make the dataframe directly queryable.
-    _table_name = "df"
-
-    def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection:
-        conn = duckdb.connect()
-        conn.register(self._table_name, self.source.df)
-        return conn
-
-    def get_dataset_uri(self) -> str:
-        return self._table_name
-
-
-class SeedReaderRegistry:
-    def __init__(self, readers: Sequence[SeedReader]):
-        self._readers: dict[str, SeedReader] = {}
-        for reader in readers:
-            self.add_reader(reader)
-
-    def add_reader(self, reader: SeedReader) -> Self:
-        seed_type = reader.get_seed_type()
-
-        if seed_type in self._readers:
-            raise SeedReaderError(f"A reader for seed_type {seed_type!r} already exists")
-
-        self._readers[seed_type] = reader
-        return self
-
-    def get_reader(self, seed_dataset_source: SeedSource, secret_resolver: SecretResolver) -> SeedReader:
-        reader = self._get_reader_for_source(seed_dataset_source)
-        reader.attach(seed_dataset_source, secret_resolver)
-        return reader
-
-    def _get_reader_for_source(self, seed_dataset_source: SeedSource) -> SeedReader:
-        seed_type = seed_dataset_source.seed_type
-        try:
-            return self._readers[seed_type]
-        except KeyError:
-            raise SeedReaderError(f"No reader found for seed_type {seed_type!r}")
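The trickiest part of the deleted `seed_reader.py` is `get_seed_type()`, which inspects the reader's generic parameter via `__orig_bases__` and reads the default of that source model's `seed_type` field. (Note the deleted module binds its TypeVar under a mismatched name, `TypeVar("ConfigT")` assigned to `SourceT`; the sketch below uses a matching name.) A standalone sketch, assuming pydantic v2 and a simplified source model whose `seed_type` value of "local_file" is invented:

# Recover a class-level default from a reader's generic type parameter.
from typing import Generic, TypeVar, get_args, get_origin

from pydantic import BaseModel


class SeedSource(BaseModel):
    seed_type: str


class LocalFileSeedSource(SeedSource):
    seed_type: str = "local_file"
    path: str = ""


SourceT = TypeVar("SourceT", bound=SeedSource)


class SeedReader(Generic[SourceT]):
    def get_seed_type(self) -> str:
        # Walk the generic bases to find SeedReader[SomeSource], then read the
        # default value of that source's "seed_type" field.
        for base in getattr(type(self), "__orig_bases__", []):
            if get_origin(base) is SeedReader and (args := get_args(base)):
                field = args[0].model_fields["seed_type"]
                if isinstance(field.default, str):
                    return field.default
        raise TypeError("Reader does not declare a source with a seed_type default")


class LocalFileSeedReader(SeedReader[LocalFileSeedSource]):
    pass


print(LocalFileSeedReader().get_seed_type())  # local_file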
data_designer/engine/sampling_gen/column.py
@@ -1,91 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
-from __future__ import annotations
-
-from typing import Any
-
-from pydantic import field_serializer, model_validator
-from typing_extensions import Self
-
-from data_designer.config.column_configs import SamplerColumnConfig
-from data_designer.config.sampler_params import SamplerParamsT, SamplerType
-from data_designer.engine.sampling_gen.data_sources.base import DataSource
-from data_designer.engine.sampling_gen.data_sources.sources import SamplerRegistry
-from data_designer.engine.sampling_gen.jinja_utils import extract_column_names_from_expression
-
-
-class ConditionalDataColumn(SamplerColumnConfig):
-    @property
-    def _negative_condition(self) -> str:
-        conditions = list(self.conditional_params.keys())
-        return "not (" + " or ".join([f"({c})" for c in conditions]) + ")"
-
-    @property
-    def conditions(self) -> list[str]:
-        c = list(self.conditional_params.keys())
-        return c + [self._negative_condition] if len(c) > 0 else ["..."]
-
-    @property
-    def conditional_column_names(self) -> set[str]:
-        names = set()
-        for condition in self.conditional_params.keys():
-            names.update(extract_column_names_from_expression(condition))
-        return names
-
-    @field_serializer("sampler_type")
-    def serialize_sampler_type(self, sampler_type: SamplerType) -> str:
-        return SamplerType(sampler_type).value
-
-    @field_serializer("params")
-    def serialize_params(self, params: SamplerParamsT) -> dict:
-        return params.model_dump()
-
-    @field_serializer("conditional_params")
-    def serialize_conditional_params(self, conditional_params: dict[str, SamplerParamsT]) -> dict:
-        for condition, params in conditional_params.items():
-            conditional_params[condition] = params.model_dump()
-        return conditional_params
-
-    @model_validator(mode="before")
-    @classmethod
-    def validate_params_with_type(cls, data: Any) -> Any:
-        if not isinstance(data, dict) or "sampler_type" not in data:
-            return data
-        if isinstance(data["sampler_type"], str):
-            if not SamplerRegistry.is_registered(data["sampler_type"]):
-                raise ValueError(
-                    f"Invalid sampler type: {data['sampler_type']}. Available samplers: {[s.value for s in SamplerType]}"
-                )
-        if "params" in data:
-            data["params"] = SamplerRegistry.get_sampler(data["sampler_type"])(params=data["params"]).params
-        if "conditional_params" in data:
-            for condition, params in data["conditional_params"].items():
-                data["conditional_params"][condition] = SamplerRegistry.get_sampler(data["sampler_type"])(
-                    params=params
-                ).params
-        return data
-
-    @model_validator(mode="after")
-    def validate_params(self) -> Self:
-        self.params = SamplerRegistry.validate_sampler_type(self.sampler_type)(params=self.params).params
-        return self
-
-    @model_validator(mode="after")
-    def validate_data_conversion(self) -> Self:
-        self.get_default_sampler().validate_data_conversion(self.convert_to)
-        return self
-
-    @model_validator(mode="after")
-    def validate_conditional_params(self) -> Self:
-        for condition, params in self.conditional_params.items():
-            self.conditional_params[condition] = SamplerRegistry.get_sampler(self.sampler_type)(params=params).params
-        return self
-
-    def get_default_sampler(self, **kwargs) -> DataSource:
-        return self.get_sampler("...", **kwargs)
-
-    def get_sampler(self, condition: str, **kwargs) -> DataSource:
-        if condition in ["...", self._negative_condition]:
-            return SamplerRegistry.get_sampler(self.sampler_type)(self.params, **kwargs)
-        return SamplerRegistry.get_sampler(self.sampler_type)(self.conditional_params[condition], **kwargs)
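The deleted `ConditionalDataColumn` keeps explicit conditions plus a derived catch-all: `_negative_condition` negates the disjunction of all declared conditions, and "..." serves as the default key when none are declared. A standalone sketch of that bookkeeping, with hypothetical condition strings and params:

# Explicit conditions plus a derived "matches none of the above" condition.
def negative_condition(conditions: list[str]) -> str:
    return "not (" + " or ".join(f"({c})" for c in conditions) + ")"


def all_conditions(conditional_params: dict[str, dict]) -> list[str]:
    conditions = list(conditional_params)
    # "..." is the default/fallback key used when no conditions are defined.
    return conditions + [negative_condition(conditions)] if conditions else ["..."]


params = {"age >= 18": {"low": 30, "high": 80}, "age < 18": {"low": 0, "high": 17}}
print(all_conditions(params))
# ['age >= 18', 'age < 18', 'not ((age >= 18) or (age < 18))']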