data-designer 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/__init__.py +15 -0
- data_designer/_version.py +34 -0
- data_designer/cli/README.md +236 -0
- data_designer/cli/__init__.py +6 -0
- data_designer/cli/commands/__init__.py +2 -0
- data_designer/cli/commands/list.py +130 -0
- data_designer/cli/commands/models.py +10 -0
- data_designer/cli/commands/providers.py +11 -0
- data_designer/cli/commands/reset.py +100 -0
- data_designer/cli/controllers/__init__.py +7 -0
- data_designer/cli/controllers/model_controller.py +246 -0
- data_designer/cli/controllers/provider_controller.py +317 -0
- data_designer/cli/forms/__init__.py +20 -0
- data_designer/cli/forms/builder.py +51 -0
- data_designer/cli/forms/field.py +180 -0
- data_designer/cli/forms/form.py +59 -0
- data_designer/cli/forms/model_builder.py +125 -0
- data_designer/cli/forms/provider_builder.py +76 -0
- data_designer/cli/main.py +44 -0
- data_designer/cli/repositories/__init__.py +8 -0
- data_designer/cli/repositories/base.py +39 -0
- data_designer/cli/repositories/model_repository.py +42 -0
- data_designer/cli/repositories/provider_repository.py +43 -0
- data_designer/cli/services/__init__.py +7 -0
- data_designer/cli/services/model_service.py +116 -0
- data_designer/cli/services/provider_service.py +111 -0
- data_designer/cli/ui.py +448 -0
- data_designer/cli/utils.py +47 -0
- data_designer/config/__init__.py +2 -0
- data_designer/config/analysis/column_profilers.py +89 -0
- data_designer/config/analysis/column_statistics.py +274 -0
- data_designer/config/analysis/dataset_profiler.py +60 -0
- data_designer/config/analysis/utils/errors.py +8 -0
- data_designer/config/analysis/utils/reporting.py +188 -0
- data_designer/config/base.py +68 -0
- data_designer/config/column_configs.py +354 -0
- data_designer/config/column_types.py +168 -0
- data_designer/config/config_builder.py +660 -0
- data_designer/config/data_designer_config.py +40 -0
- data_designer/config/dataset_builders.py +11 -0
- data_designer/config/datastore.py +151 -0
- data_designer/config/default_model_settings.py +123 -0
- data_designer/config/errors.py +19 -0
- data_designer/config/interface.py +54 -0
- data_designer/config/models.py +231 -0
- data_designer/config/preview_results.py +32 -0
- data_designer/config/processors.py +41 -0
- data_designer/config/sampler_constraints.py +51 -0
- data_designer/config/sampler_params.py +604 -0
- data_designer/config/seed.py +145 -0
- data_designer/config/utils/code_lang.py +83 -0
- data_designer/config/utils/constants.py +313 -0
- data_designer/config/utils/errors.py +19 -0
- data_designer/config/utils/info.py +88 -0
- data_designer/config/utils/io_helpers.py +273 -0
- data_designer/config/utils/misc.py +81 -0
- data_designer/config/utils/numerical_helpers.py +28 -0
- data_designer/config/utils/type_helpers.py +100 -0
- data_designer/config/utils/validation.py +336 -0
- data_designer/config/utils/visualization.py +427 -0
- data_designer/config/validator_params.py +96 -0
- data_designer/engine/__init__.py +2 -0
- data_designer/engine/analysis/column_profilers/base.py +55 -0
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +160 -0
- data_designer/engine/analysis/column_profilers/registry.py +20 -0
- data_designer/engine/analysis/column_statistics.py +142 -0
- data_designer/engine/analysis/dataset_profiler.py +125 -0
- data_designer/engine/analysis/errors.py +7 -0
- data_designer/engine/analysis/utils/column_statistics_calculations.py +209 -0
- data_designer/engine/analysis/utils/judge_score_processing.py +128 -0
- data_designer/engine/column_generators/__init__.py +2 -0
- data_designer/engine/column_generators/generators/__init__.py +2 -0
- data_designer/engine/column_generators/generators/base.py +61 -0
- data_designer/engine/column_generators/generators/expression.py +63 -0
- data_designer/engine/column_generators/generators/llm_generators.py +172 -0
- data_designer/engine/column_generators/generators/samplers.py +75 -0
- data_designer/engine/column_generators/generators/seed_dataset.py +149 -0
- data_designer/engine/column_generators/generators/validation.py +147 -0
- data_designer/engine/column_generators/registry.py +56 -0
- data_designer/engine/column_generators/utils/errors.py +13 -0
- data_designer/engine/column_generators/utils/judge_score_factory.py +57 -0
- data_designer/engine/column_generators/utils/prompt_renderer.py +98 -0
- data_designer/engine/configurable_task.py +82 -0
- data_designer/engine/dataset_builders/artifact_storage.py +181 -0
- data_designer/engine/dataset_builders/column_wise_builder.py +287 -0
- data_designer/engine/dataset_builders/errors.py +13 -0
- data_designer/engine/dataset_builders/multi_column_configs.py +44 -0
- data_designer/engine/dataset_builders/utils/__init__.py +2 -0
- data_designer/engine/dataset_builders/utils/concurrency.py +184 -0
- data_designer/engine/dataset_builders/utils/config_compiler.py +60 -0
- data_designer/engine/dataset_builders/utils/dag.py +56 -0
- data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +190 -0
- data_designer/engine/dataset_builders/utils/errors.py +13 -0
- data_designer/engine/errors.py +49 -0
- data_designer/engine/model_provider.py +75 -0
- data_designer/engine/models/__init__.py +2 -0
- data_designer/engine/models/errors.py +308 -0
- data_designer/engine/models/facade.py +225 -0
- data_designer/engine/models/litellm_overrides.py +162 -0
- data_designer/engine/models/parsers/__init__.py +2 -0
- data_designer/engine/models/parsers/errors.py +34 -0
- data_designer/engine/models/parsers/parser.py +236 -0
- data_designer/engine/models/parsers/postprocessors.py +93 -0
- data_designer/engine/models/parsers/tag_parsers.py +60 -0
- data_designer/engine/models/parsers/types.py +82 -0
- data_designer/engine/models/recipes/base.py +79 -0
- data_designer/engine/models/recipes/response_recipes.py +291 -0
- data_designer/engine/models/registry.py +118 -0
- data_designer/engine/models/usage.py +75 -0
- data_designer/engine/models/utils.py +38 -0
- data_designer/engine/processing/ginja/__init__.py +2 -0
- data_designer/engine/processing/ginja/ast.py +64 -0
- data_designer/engine/processing/ginja/environment.py +461 -0
- data_designer/engine/processing/ginja/exceptions.py +54 -0
- data_designer/engine/processing/ginja/record.py +30 -0
- data_designer/engine/processing/gsonschema/__init__.py +2 -0
- data_designer/engine/processing/gsonschema/exceptions.py +8 -0
- data_designer/engine/processing/gsonschema/schema_transformers.py +81 -0
- data_designer/engine/processing/gsonschema/types.py +8 -0
- data_designer/engine/processing/gsonschema/validators.py +143 -0
- data_designer/engine/processing/processors/base.py +15 -0
- data_designer/engine/processing/processors/drop_columns.py +46 -0
- data_designer/engine/processing/processors/registry.py +20 -0
- data_designer/engine/processing/utils.py +120 -0
- data_designer/engine/registry/base.py +97 -0
- data_designer/engine/registry/data_designer_registry.py +37 -0
- data_designer/engine/registry/errors.py +10 -0
- data_designer/engine/resources/managed_dataset_generator.py +35 -0
- data_designer/engine/resources/managed_dataset_repository.py +194 -0
- data_designer/engine/resources/managed_storage.py +63 -0
- data_designer/engine/resources/resource_provider.py +46 -0
- data_designer/engine/resources/seed_dataset_data_store.py +66 -0
- data_designer/engine/sampling_gen/column.py +89 -0
- data_designer/engine/sampling_gen/constraints.py +95 -0
- data_designer/engine/sampling_gen/data_sources/base.py +214 -0
- data_designer/engine/sampling_gen/data_sources/errors.py +10 -0
- data_designer/engine/sampling_gen/data_sources/sources.py +342 -0
- data_designer/engine/sampling_gen/entities/__init__.py +2 -0
- data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
- data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +64 -0
- data_designer/engine/sampling_gen/entities/email_address_utils.py +169 -0
- data_designer/engine/sampling_gen/entities/errors.py +8 -0
- data_designer/engine/sampling_gen/entities/national_id_utils.py +100 -0
- data_designer/engine/sampling_gen/entities/person.py +142 -0
- data_designer/engine/sampling_gen/entities/phone_number.py +122 -0
- data_designer/engine/sampling_gen/errors.py +24 -0
- data_designer/engine/sampling_gen/generator.py +121 -0
- data_designer/engine/sampling_gen/jinja_utils.py +60 -0
- data_designer/engine/sampling_gen/people_gen.py +203 -0
- data_designer/engine/sampling_gen/person_constants.py +54 -0
- data_designer/engine/sampling_gen/schema.py +143 -0
- data_designer/engine/sampling_gen/schema_builder.py +59 -0
- data_designer/engine/sampling_gen/utils.py +40 -0
- data_designer/engine/secret_resolver.py +80 -0
- data_designer/engine/validators/__init__.py +17 -0
- data_designer/engine/validators/base.py +36 -0
- data_designer/engine/validators/local_callable.py +34 -0
- data_designer/engine/validators/python.py +245 -0
- data_designer/engine/validators/remote.py +83 -0
- data_designer/engine/validators/sql.py +60 -0
- data_designer/errors.py +5 -0
- data_designer/essentials/__init__.py +137 -0
- data_designer/interface/__init__.py +2 -0
- data_designer/interface/data_designer.py +351 -0
- data_designer/interface/errors.py +16 -0
- data_designer/interface/results.py +55 -0
- data_designer/logging.py +161 -0
- data_designer/plugin_manager.py +83 -0
- data_designer/plugins/__init__.py +6 -0
- data_designer/plugins/errors.py +10 -0
- data_designer/plugins/plugin.py +69 -0
- data_designer/plugins/registry.py +86 -0
- data_designer-0.1.0.dist-info/METADATA +173 -0
- data_designer-0.1.0.dist-info/RECORD +177 -0
- data_designer-0.1.0.dist-info/WHEEL +4 -0
- data_designer-0.1.0.dist-info/entry_points.txt +2 -0
- data_designer-0.1.0.dist-info/licenses/LICENSE +201 -0
data_designer/engine/resources/managed_dataset_repository.py
@@ -0,0 +1,194 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import cached_property
import logging
from pathlib import Path
import tempfile
import threading
import time
from typing import Any

import duckdb
import pandas as pd

from data_designer.config.utils.constants import LOCALES_WITH_MANAGED_DATASETS
from data_designer.engine.resources.managed_storage import LocalBlobStorageProvider, ManagedBlobStorage

logger = logging.getLogger(__name__)

DATASETS_ROOT = "datasets"
"""
Path in object storage to managed datasets
"""


@dataclass
class Table:
    """
    Managed datasets are organized by dataset and by table under a root
    table path in object storage.
    """

    source: str
    """
    Table source path
    """

    schema: str = "main"
    """
    Specifies the schema to use when registering the table.

    Note: this is not the schema of the table, but rather the _database_
    schema to associate with the table.
    """

    @cached_property
    def name(self) -> str:
        return Path(self.source).stem


DataCatalog = list[Table]


# For now we hardcode the remote data catalog in code. This makes it easier
# to initialize the data catalog. Eventually we can make this work more
# dynamically once this data catalog pattern becomes more widely adopted.
DEFAULT_DATA_CATALOG: DataCatalog = [Table(f"{locale}.parquet") for locale in LOCALES_WITH_MANAGED_DATASETS]


class ManagedDatasetRepository(ABC):
    @abstractmethod
    def query(self, sql: str, parameters: list[Any]) -> pd.DataFrame: ...

    @property
    @abstractmethod
    def data_catalog(self) -> DataCatalog: ...


class DuckDBDatasetRepository(ManagedDatasetRepository):
    """
    Provides a duckdb based sql interface over Gretel managed datasets.
    """

    _default_config = {"threads": 2, "memory_limit": "4 gb"}

    def __init__(
        self,
        blob_storage: ManagedBlobStorage,
        config: dict | None = None,
        data_catalog: DataCatalog = DEFAULT_DATA_CATALOG,
        datasets_root: str = DATASETS_ROOT,
        use_cache: bool = True,
    ):
        """
        Create a new DuckDB backed dataset repository.

        Args:
            blob_storage: A managed blob storage provider
            config: DuckDB configuration options,
                https://duckdb.org/docs/configuration/overview.html#configuration-reference
            data_catalog: A list of tables to register with the DuckDB instance
            datasets_root: The root path in blob storage to managed datasets
            use_cache: Whether to cache datasets locally. Trades disk space
                and startup time for faster queries.
        """
        self._data_catalog = data_catalog
        self._data_sets_root = datasets_root
        self._blob_storage = blob_storage
        self._config = self._default_config if config is None else config
        self._use_cache = use_cache

        # Configure database and register tables
        self.db = duckdb.connect(config=self._config)

        # Dataset registration completion is tracked with an event. Consumers can
        # wait on this event to ensure the catalog is ready.
        self._registration_event = threading.Event()
        self._register_lock = threading.Lock()

        # Kick off dataset registration in a background thread so that IO-heavy
        # caching and view creation can run asynchronously without blocking the
        # caller that constructs this repository instance.
        self._register_thread = threading.Thread(target=self._register_datasets, daemon=True)
        self._register_thread.start()

    def _register_datasets(self):
        # Just in case this method gets called from inside a thread.
        # This operation isn't thread-safe by default, so we
        # synchronize the registration process.
        if self._registration_event.is_set():
            return
        with self._register_lock:
            # Check once more to see if the catalog is ready; it's possible a
            # previous thread already registered the datasets.
            if self._registration_event.is_set():
                return
            try:
                for table in self.data_catalog:
                    key = table.source if table.schema == "main" else f"{table.schema}/{table.source}"
                    if self._use_cache:
                        tmp_root = Path(tempfile.gettempdir()) / "dd_cache"
                        local_path = tmp_root / key
                        local_path.parent.mkdir(parents=True, exist_ok=True)
                        if not local_path.exists():
                            start = time.time()
                            logger.debug("Caching database %s to %s", table.name, local_path)
                            with self._blob_storage.get_blob(f"{self._data_sets_root}/{key}") as src_fd:
                                with open(local_path, "wb") as dst_fd:
                                    dst_fd.write(src_fd.read())
                            logger.debug(
                                "Cached database %s in %.2f s",
                                table.name,
                                time.time() - start,
                            )
                        data_path = local_path.as_posix()
                    else:
                        data_path = self._blob_storage.uri_for_key(f"{self._data_sets_root}/{key}")
                    if table.schema != "main":
                        self.db.sql(f"CREATE SCHEMA IF NOT EXISTS {table.schema}")
                    logger.debug(f"Registering dataset {table.name} from {data_path}")
                    self.db.sql(f"CREATE VIEW {table.schema}.{table.name} AS FROM '{data_path}'")

                logger.debug("DuckDBDatasetRepository registration complete")

            except Exception as e:
                logger.exception(f"Failed to register datasets: {str(e)}")

            finally:
                # Signal that registration is complete so any waiting queries can proceed.
                self._registration_event.set()

    def query(self, sql: str, parameters: list[Any]) -> pd.DataFrame:
        # Ensure dataset registration has completed. Possible future optimization:
        # pull datasets in parallel and only wait here if the query requires a
        # table that isn't cached.
        if not self._registration_event.is_set():
            logger.debug("Waiting for dataset caching and registration to finish...")
            self._registration_event.wait()

        # The duckdb connection isn't thread-safe, so we create a new
        # connection per query using cursor().
        # More details here: https://duckdb.org/docs/stable/guides/python/multiple_threads.html
        cursor = self.db.cursor()
        try:
            df = cursor.execute(sql, parameters).df()
        finally:
            cursor.close()
        return df

    @property
    def data_catalog(self) -> DataCatalog:
        return self._data_catalog


def load_managed_dataset_repository(blob_storage: ManagedBlobStorage, locales: list[str]) -> ManagedDatasetRepository:
    return DuckDBDatasetRepository(
        blob_storage,
        config={"threads": 1, "memory_limit": "2 gb"},
        data_catalog=[Table(f"{locale}.parquet") for locale in locales],
        # Only cache if not using local storage.
        use_cache=not isinstance(blob_storage, LocalBlobStorageProvider),
    )
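Usage note: the repository registers its views on a background thread, so the first `query()` call may block while parquet files are cached. A minimal sketch, assuming a local directory laid out as `assets/datasets/en_US.parquet` (the paths and locale are illustrative, not from the package docs):

```python
from pathlib import Path

from data_designer.engine.resources.managed_dataset_repository import DuckDBDatasetRepository, Table
from data_designer.engine.resources.managed_storage import LocalBlobStorageProvider

# Assumes ./assets/datasets/en_US.parquet exists locally (illustrative path).
storage = LocalBlobStorageProvider(Path("./assets"))
repo = DuckDBDatasetRepository(
    storage,
    data_catalog=[Table("en_US.parquet")],
    use_cache=False,  # reading directly from local storage; no temp-dir copy needed
)

# query() blocks until the background registration thread has created the views.
df = repo.query("SELECT * FROM main.en_US LIMIT 5", [])
print(df.shape)
```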
data_designer/engine/resources/managed_storage.py
@@ -0,0 +1,63 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
from collections.abc import Iterator
from contextlib import contextmanager
import logging
from pathlib import Path
from typing import IO

logger = logging.getLogger(__name__)


class ManagedBlobStorage(ABC):
    """
    Provides a low-level interface for accessing objects in blob storage. This interface
    can be used to access model weights, raw datasets, or any artifact in blob
    storage.

    If you want a high-level interface for accessing datasets, use the `ManagedDatasetRepository`,
    which provides a SQL interface over each dataset.
    """

    @abstractmethod
    @contextmanager
    def get_blob(self, blob_key: str) -> Iterator[IO]: ...

    @abstractmethod
    def _key_uri_builder(self, key: str) -> str: ...

    def uri_for_key(self, key: str) -> str:
        """
        Returns a qualified storage URI for a given key. `key` is
        normalized to ensure that any leading path separators ("/") are removed.
        """
        return self._key_uri_builder(key.lstrip("/"))


class LocalBlobStorageProvider(ManagedBlobStorage):
    """
    Provides a local blob storage service. Useful for running
    tests that don't require access to external infrastructure.
    """

    def __init__(self, root_path: Path) -> None:
        self._root_path = root_path

    @contextmanager
    def get_blob(self, blob_key: str) -> Iterator[IO]:
        with open(self._key_uri_builder(blob_key), "rb") as fd:
            yield fd

    def _key_uri_builder(self, key: str) -> str:
        return f"{self._root_path}/{key}"


def init_managed_blob_storage(assets_storage: str) -> ManagedBlobStorage:
    path = Path(assets_storage)
    if not path.exists():
        raise RuntimeError(f"Local storage path {assets_storage!r} does not exist.")

    logger.debug(f"Using local storage for managed datasets: {assets_storage!r}")
    return LocalBlobStorageProvider(Path(assets_storage))
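A short sketch of the low-level interface, under the same illustrative local layout; note that `uri_for_key` strips leading slashes before delegating to `_key_uri_builder`:

```python
from pathlib import Path

from data_designer.engine.resources.managed_storage import LocalBlobStorageProvider

storage = LocalBlobStorageProvider(Path("/tmp/assets"))

# uri_for_key strips leading slashes before building the URI.
assert storage.uri_for_key("/datasets/en_US.parquet") == "/tmp/assets/datasets/en_US.parquet"

# get_blob yields a readable binary file object (the file must exist).
with storage.get_blob("datasets/en_US.parquet") as fd:
    header = fd.read(4)  # parquet files start with the magic bytes b"PAR1"
```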
data_designer/engine/resources/resource_provider.py
@@ -0,0 +1,46 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from data_designer.config.base import ConfigBase
from data_designer.config.models import ModelConfig
from data_designer.config.utils.type_helpers import StrEnum
from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
from data_designer.engine.model_provider import ModelProviderRegistry
from data_designer.engine.models.registry import ModelRegistry, create_model_registry
from data_designer.engine.resources.managed_storage import ManagedBlobStorage, init_managed_blob_storage
from data_designer.engine.resources.seed_dataset_data_store import SeedDatasetDataStore
from data_designer.engine.secret_resolver import SecretResolver


class ResourceType(StrEnum):
    BLOB_STORAGE = "blob_storage"
    DATASTORE = "datastore"
    MODEL_REGISTRY = "model_registry"


class ResourceProvider(ConfigBase):
    artifact_storage: ArtifactStorage
    blob_storage: ManagedBlobStorage | None = None
    datastore: SeedDatasetDataStore | None = None
    model_registry: ModelRegistry | None = None


def create_resource_provider(
    *,
    artifact_storage: ArtifactStorage,
    model_configs: list[ModelConfig],
    secret_resolver: SecretResolver,
    model_provider_registry: ModelProviderRegistry,
    datastore: SeedDatasetDataStore | None = None,
    blob_storage: ManagedBlobStorage | None = None,
) -> ResourceProvider:
    return ResourceProvider(
        artifact_storage=artifact_storage,
        datastore=datastore,
        model_registry=create_model_registry(
            model_configs=model_configs,
            secret_resolver=secret_resolver,
            model_provider_registry=model_provider_registry,
        ),
        blob_storage=blob_storage or init_managed_blob_storage(),
    )
data_designer/engine/resources/seed_dataset_data_store.py
@@ -0,0 +1,66 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod

import duckdb
from huggingface_hub import HfApi, HfFileSystem

from data_designer.logging import quiet_noisy_logger

quiet_noisy_logger("httpx")

_HF_DATASETS_PREFIX = "hf://datasets/"


class MalformedFileIdError(Exception):
    """Raised when file_id format is invalid."""


class SeedDatasetDataStore(ABC):
    """Abstract base class for dataset storage implementations."""

    @abstractmethod
    def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection: ...

    @abstractmethod
    def get_dataset_uri(self, file_id: str) -> str: ...


class LocalSeedDatasetDataStore(SeedDatasetDataStore):
    """Local filesystem-based dataset storage."""

    def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection:
        return duckdb.connect()

    def get_dataset_uri(self, file_id: str) -> str:
        return file_id


class HfHubSeedDatasetDataStore(SeedDatasetDataStore):
    """Hugging Face and Data Store dataset storage."""

    def __init__(self, endpoint: str, token: str | None):
        self.hfapi = HfApi(endpoint=endpoint, token=token)
        self.hffs = HfFileSystem(endpoint=endpoint, token=token)

    def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection:
        conn = duckdb.connect()
        conn.register_filesystem(self.hffs)
        return conn

    def get_dataset_uri(self, file_id: str) -> str:
        identifier = file_id.removeprefix(_HF_DATASETS_PREFIX)
        repo_id, filename = self._get_repo_id_and_filename(identifier)
        return f"{_HF_DATASETS_PREFIX}{repo_id}/{filename}"

    def _get_repo_id_and_filename(self, identifier: str) -> tuple[str, str]:
        """Extract repo_id and filename from identifier."""
        parts = identifier.split("/", 2)
        if len(parts) < 3:
            raise MalformedFileIdError(
                "Could not extract repo id and filename from file_id, "
                "expected 'hf://datasets/{repo-namespace}/{repo-name}/{filename}'"
            )
        repo_ns, repo_name, filename = parts
        return f"{repo_ns}/{repo_name}", filename
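To illustrate the expected `file_id` shape, a small sketch using illustrative repo and file names (`acme/people/train.parquet` is hypothetical):

```python
from data_designer.engine.resources.seed_dataset_data_store import (
    HfHubSeedDatasetDataStore,
    MalformedFileIdError,
)

store = HfHubSeedDatasetDataStore(endpoint="https://huggingface.co", token=None)

# Well-formed file_id: namespace/name/filename round-trips unchanged.
uri = store.get_dataset_uri("hf://datasets/acme/people/train.parquet")
assert uri == "hf://datasets/acme/people/train.parquet"

# A file_id missing the filename component is rejected.
try:
    store.get_dataset_uri("hf://datasets/acme/people")
except MalformedFileIdError as e:
    print(e)
```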
data_designer/engine/sampling_gen/column.py
@@ -0,0 +1,89 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from typing import Any

from pydantic import field_serializer, model_validator
from typing_extensions import Self

from data_designer.config.column_configs import SamplerColumnConfig
from data_designer.config.sampler_params import SamplerParamsT, SamplerType
from data_designer.engine.sampling_gen.data_sources.base import DataSource
from data_designer.engine.sampling_gen.data_sources.sources import SamplerRegistry
from data_designer.engine.sampling_gen.jinja_utils import extract_column_names_from_expression


class ConditionalDataColumn(SamplerColumnConfig):
    @property
    def _negative_condition(self) -> str:
        conditions = list(self.conditional_params.keys())
        return "not (" + " or ".join([f"({c})" for c in conditions]) + ")"

    @property
    def conditions(self) -> list[str]:
        c = list(self.conditional_params.keys())
        return c + [self._negative_condition] if len(c) > 0 else ["..."]

    @property
    def conditional_column_names(self) -> set[str]:
        names = set()
        for condition in self.conditional_params.keys():
            names.update(extract_column_names_from_expression(condition))
        return names

    @field_serializer("sampler_type")
    def serialize_sampler_type(self, sampler_type: SamplerType) -> str:
        return SamplerType(sampler_type).value

    @field_serializer("params")
    def serialize_params(self, params: SamplerParamsT) -> dict:
        return params.model_dump()

    @field_serializer("conditional_params")
    def serialize_conditional_params(self, conditional_params: dict[str, SamplerParamsT]) -> dict:
        for condition, params in conditional_params.items():
            conditional_params[condition] = params.model_dump()
        return conditional_params

    @model_validator(mode="before")
    @classmethod
    def validate_params_with_type(cls, data: Any) -> Any:
        if not isinstance(data, dict) or "sampler_type" not in data:
            return data
        if isinstance(data["sampler_type"], str):
            if not SamplerRegistry.is_registered(data["sampler_type"]):
                raise ValueError(
                    f"Invalid sampler type: {data['sampler_type']}. Available samplers: {[s.value for s in SamplerType]}"
                )
        if "params" in data:
            data["params"] = SamplerRegistry.get_sampler(data["sampler_type"])(params=data["params"]).params
        if "conditional_params" in data:
            for condition, params in data["conditional_params"].items():
                data["conditional_params"][condition] = SamplerRegistry.get_sampler(data["sampler_type"])(
                    params=params
                ).params
        return data

    @model_validator(mode="after")
    def validate_params(self) -> Self:
        self.params = SamplerRegistry.validate_sampler_type(self.sampler_type)(params=self.params).params
        return self

    @model_validator(mode="after")
    def validate_data_conversion(self) -> Self:
        self.get_default_sampler().validate_data_conversion(self.convert_to)
        return self

    @model_validator(mode="after")
    def validate_conditional_params(self) -> Self:
        for condition, params in self.conditional_params.items():
            self.conditional_params[condition] = SamplerRegistry.get_sampler(self.sampler_type)(params=params).params
        return self

    def get_default_sampler(self, **kwargs) -> DataSource:
        return self.get_sampler("...", **kwargs)

    def get_sampler(self, condition: str, **kwargs) -> DataSource:
        if condition in ["...", self._negative_condition]:
            return SamplerRegistry.get_sampler(self.sampler_type)(self.params, **kwargs)
        return SamplerRegistry.get_sampler(self.sampler_type)(self.conditional_params[condition], **kwargs)
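To make the condition bookkeeping concrete, here is a standalone sketch (plain Python, no package imports) of how `_negative_condition` folds the configured conditions into a catch-all branch; the age expressions and parameter dicts are illustrative, not real sampler params:

```python
# Standalone mimic of ConditionalDataColumn._negative_condition / .conditions
# (illustrative conditions; does not construct the pydantic config itself).
conditional_params = {
    "age < 18": {"low": 0, "high": 5},
    "age >= 65": {"low": 10, "high": 20},
}

conditions = list(conditional_params.keys())
negative_condition = "not (" + " or ".join(f"({c})" for c in conditions) + ")"
print(negative_condition)
# not ((age < 18) or (age >= 65))

# The full condition list appends the catch-all branch; with no conditions,
# the "..." sentinel selects the default sampler params instead.
all_conditions = conditions + [negative_condition] if conditions else ["..."]
```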
data_designer/engine/sampling_gen/constraints.py
@@ -0,0 +1,95 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
from typing import Type

import numpy as np
from numpy.typing import NDArray
import pandas as pd

from data_designer.config.base import ConfigBase
from data_designer.config.sampler_constraints import (
    ColumnInequalityConstraint,
    Constraint,
    ConstraintType,
    InequalityOperator,
    ScalarInequalityConstraint,
)


class ConstraintChecker(ConfigBase, ABC):
    constraint: Constraint

    def get_required_column_names(self) -> tuple[str, ...]:
        return (self.constraint.target_column,)

    @abstractmethod
    def check(self, dataframe: pd.DataFrame) -> NDArray[np.bool_]: ...


class WithCompareMixin:
    @property
    def lhs(self) -> str:
        return self.constraint.target_column

    def compare(self, lhs: float | int | NDArray, rhs: float | int | NDArray) -> bool | NDArray[np.bool_]:
        operator = {
            InequalityOperator.LT: np.less,
            InequalityOperator.LE: np.less_equal,
            InequalityOperator.GT: np.greater,
            InequalityOperator.GE: np.greater_equal,
        }[InequalityOperator(self.constraint.operator)]
        return operator(lhs, rhs)


class ScalarInequalityChecker(ConstraintChecker, WithCompareMixin):
    """Compare a column to a scalar value.

    Args:
        column_name: Name of the constrained column. Will be
            used as the left-hand side (lhs) of the comparison.
        operator: Comparison operator.
        rhs: Scalar value to compare against.
    """

    constraint: ScalarInequalityConstraint

    def check(self, dataframe: pd.DataFrame) -> NDArray[np.bool_]:
        return self.compare(dataframe[self.lhs].values, self.constraint.rhs)


class ColumnInequalityChecker(ConstraintChecker, WithCompareMixin):
    """Compare the values of two columns.

    Args:
        column_name: Name of the constrained column. Will be
            used as the left-hand side (lhs) of the comparison.
        operator: Comparison operator.
        rhs: Name of the column to compare against.
    """

    constraint: ColumnInequalityConstraint

    def get_required_column_names(self) -> tuple[str, ...]:
        """Return the names of columns required for the constraint.

        Note that order matters. Edges in the DAG are created as column_names[1], column_names[0].
        """
        return (self.lhs, self.constraint.rhs)

    def check(self, dataframe: pd.DataFrame) -> NDArray[np.bool_]:
        return self.compare(
            dataframe[self.lhs].values,
            dataframe[self.constraint.rhs].values,
        )


CONSTRAINT_TYPE_TO_CHECKER = {
    ConstraintType.SCALAR_INEQUALITY: ScalarInequalityChecker,
    ConstraintType.COLUMN_INEQUALITY: ColumnInequalityChecker,
}


def get_constraint_checker(constraint_type: ConstraintType) -> Type[ConstraintChecker]:
    return CONSTRAINT_TYPE_TO_CHECKER[ConstraintType(constraint_type)]