data-designer 0.1.0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177)
  1. data_designer/__init__.py +15 -0
  2. data_designer/_version.py +34 -0
  3. data_designer/cli/README.md +236 -0
  4. data_designer/cli/__init__.py +6 -0
  5. data_designer/cli/commands/__init__.py +2 -0
  6. data_designer/cli/commands/list.py +130 -0
  7. data_designer/cli/commands/models.py +10 -0
  8. data_designer/cli/commands/providers.py +11 -0
  9. data_designer/cli/commands/reset.py +100 -0
  10. data_designer/cli/controllers/__init__.py +7 -0
  11. data_designer/cli/controllers/model_controller.py +246 -0
  12. data_designer/cli/controllers/provider_controller.py +317 -0
  13. data_designer/cli/forms/__init__.py +20 -0
  14. data_designer/cli/forms/builder.py +51 -0
  15. data_designer/cli/forms/field.py +180 -0
  16. data_designer/cli/forms/form.py +59 -0
  17. data_designer/cli/forms/model_builder.py +125 -0
  18. data_designer/cli/forms/provider_builder.py +76 -0
  19. data_designer/cli/main.py +44 -0
  20. data_designer/cli/repositories/__init__.py +8 -0
  21. data_designer/cli/repositories/base.py +39 -0
  22. data_designer/cli/repositories/model_repository.py +42 -0
  23. data_designer/cli/repositories/provider_repository.py +43 -0
  24. data_designer/cli/services/__init__.py +7 -0
  25. data_designer/cli/services/model_service.py +116 -0
  26. data_designer/cli/services/provider_service.py +111 -0
  27. data_designer/cli/ui.py +448 -0
  28. data_designer/cli/utils.py +47 -0
  29. data_designer/config/__init__.py +2 -0
  30. data_designer/config/analysis/column_profilers.py +89 -0
  31. data_designer/config/analysis/column_statistics.py +274 -0
  32. data_designer/config/analysis/dataset_profiler.py +60 -0
  33. data_designer/config/analysis/utils/errors.py +8 -0
  34. data_designer/config/analysis/utils/reporting.py +188 -0
  35. data_designer/config/base.py +68 -0
  36. data_designer/config/column_configs.py +354 -0
  37. data_designer/config/column_types.py +168 -0
  38. data_designer/config/config_builder.py +660 -0
  39. data_designer/config/data_designer_config.py +40 -0
  40. data_designer/config/dataset_builders.py +11 -0
  41. data_designer/config/datastore.py +151 -0
  42. data_designer/config/default_model_settings.py +123 -0
  43. data_designer/config/errors.py +19 -0
  44. data_designer/config/interface.py +54 -0
  45. data_designer/config/models.py +231 -0
  46. data_designer/config/preview_results.py +32 -0
  47. data_designer/config/processors.py +41 -0
  48. data_designer/config/sampler_constraints.py +51 -0
  49. data_designer/config/sampler_params.py +604 -0
  50. data_designer/config/seed.py +145 -0
  51. data_designer/config/utils/code_lang.py +83 -0
  52. data_designer/config/utils/constants.py +313 -0
  53. data_designer/config/utils/errors.py +19 -0
  54. data_designer/config/utils/info.py +88 -0
  55. data_designer/config/utils/io_helpers.py +273 -0
  56. data_designer/config/utils/misc.py +81 -0
  57. data_designer/config/utils/numerical_helpers.py +28 -0
  58. data_designer/config/utils/type_helpers.py +100 -0
  59. data_designer/config/utils/validation.py +336 -0
  60. data_designer/config/utils/visualization.py +427 -0
  61. data_designer/config/validator_params.py +96 -0
  62. data_designer/engine/__init__.py +2 -0
  63. data_designer/engine/analysis/column_profilers/base.py +55 -0
  64. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +160 -0
  65. data_designer/engine/analysis/column_profilers/registry.py +20 -0
  66. data_designer/engine/analysis/column_statistics.py +142 -0
  67. data_designer/engine/analysis/dataset_profiler.py +125 -0
  68. data_designer/engine/analysis/errors.py +7 -0
  69. data_designer/engine/analysis/utils/column_statistics_calculations.py +209 -0
  70. data_designer/engine/analysis/utils/judge_score_processing.py +128 -0
  71. data_designer/engine/column_generators/__init__.py +2 -0
  72. data_designer/engine/column_generators/generators/__init__.py +2 -0
  73. data_designer/engine/column_generators/generators/base.py +61 -0
  74. data_designer/engine/column_generators/generators/expression.py +63 -0
  75. data_designer/engine/column_generators/generators/llm_generators.py +172 -0
  76. data_designer/engine/column_generators/generators/samplers.py +75 -0
  77. data_designer/engine/column_generators/generators/seed_dataset.py +149 -0
  78. data_designer/engine/column_generators/generators/validation.py +147 -0
  79. data_designer/engine/column_generators/registry.py +56 -0
  80. data_designer/engine/column_generators/utils/errors.py +13 -0
  81. data_designer/engine/column_generators/utils/judge_score_factory.py +57 -0
  82. data_designer/engine/column_generators/utils/prompt_renderer.py +98 -0
  83. data_designer/engine/configurable_task.py +82 -0
  84. data_designer/engine/dataset_builders/artifact_storage.py +181 -0
  85. data_designer/engine/dataset_builders/column_wise_builder.py +287 -0
  86. data_designer/engine/dataset_builders/errors.py +13 -0
  87. data_designer/engine/dataset_builders/multi_column_configs.py +44 -0
  88. data_designer/engine/dataset_builders/utils/__init__.py +2 -0
  89. data_designer/engine/dataset_builders/utils/concurrency.py +184 -0
  90. data_designer/engine/dataset_builders/utils/config_compiler.py +60 -0
  91. data_designer/engine/dataset_builders/utils/dag.py +56 -0
  92. data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +190 -0
  93. data_designer/engine/dataset_builders/utils/errors.py +13 -0
  94. data_designer/engine/errors.py +49 -0
  95. data_designer/engine/model_provider.py +75 -0
  96. data_designer/engine/models/__init__.py +2 -0
  97. data_designer/engine/models/errors.py +308 -0
  98. data_designer/engine/models/facade.py +225 -0
  99. data_designer/engine/models/litellm_overrides.py +162 -0
  100. data_designer/engine/models/parsers/__init__.py +2 -0
  101. data_designer/engine/models/parsers/errors.py +34 -0
  102. data_designer/engine/models/parsers/parser.py +236 -0
  103. data_designer/engine/models/parsers/postprocessors.py +93 -0
  104. data_designer/engine/models/parsers/tag_parsers.py +60 -0
  105. data_designer/engine/models/parsers/types.py +82 -0
  106. data_designer/engine/models/recipes/base.py +79 -0
  107. data_designer/engine/models/recipes/response_recipes.py +291 -0
  108. data_designer/engine/models/registry.py +118 -0
  109. data_designer/engine/models/usage.py +75 -0
  110. data_designer/engine/models/utils.py +38 -0
  111. data_designer/engine/processing/ginja/__init__.py +2 -0
  112. data_designer/engine/processing/ginja/ast.py +64 -0
  113. data_designer/engine/processing/ginja/environment.py +461 -0
  114. data_designer/engine/processing/ginja/exceptions.py +54 -0
  115. data_designer/engine/processing/ginja/record.py +30 -0
  116. data_designer/engine/processing/gsonschema/__init__.py +2 -0
  117. data_designer/engine/processing/gsonschema/exceptions.py +8 -0
  118. data_designer/engine/processing/gsonschema/schema_transformers.py +81 -0
  119. data_designer/engine/processing/gsonschema/types.py +8 -0
  120. data_designer/engine/processing/gsonschema/validators.py +143 -0
  121. data_designer/engine/processing/processors/base.py +15 -0
  122. data_designer/engine/processing/processors/drop_columns.py +46 -0
  123. data_designer/engine/processing/processors/registry.py +20 -0
  124. data_designer/engine/processing/utils.py +120 -0
  125. data_designer/engine/registry/base.py +97 -0
  126. data_designer/engine/registry/data_designer_registry.py +37 -0
  127. data_designer/engine/registry/errors.py +10 -0
  128. data_designer/engine/resources/managed_dataset_generator.py +35 -0
  129. data_designer/engine/resources/managed_dataset_repository.py +194 -0
  130. data_designer/engine/resources/managed_storage.py +63 -0
  131. data_designer/engine/resources/resource_provider.py +46 -0
  132. data_designer/engine/resources/seed_dataset_data_store.py +66 -0
  133. data_designer/engine/sampling_gen/column.py +89 -0
  134. data_designer/engine/sampling_gen/constraints.py +95 -0
  135. data_designer/engine/sampling_gen/data_sources/base.py +214 -0
  136. data_designer/engine/sampling_gen/data_sources/errors.py +10 -0
  137. data_designer/engine/sampling_gen/data_sources/sources.py +342 -0
  138. data_designer/engine/sampling_gen/entities/__init__.py +2 -0
  139. data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
  140. data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +64 -0
  141. data_designer/engine/sampling_gen/entities/email_address_utils.py +169 -0
  142. data_designer/engine/sampling_gen/entities/errors.py +8 -0
  143. data_designer/engine/sampling_gen/entities/national_id_utils.py +100 -0
  144. data_designer/engine/sampling_gen/entities/person.py +142 -0
  145. data_designer/engine/sampling_gen/entities/phone_number.py +122 -0
  146. data_designer/engine/sampling_gen/errors.py +24 -0
  147. data_designer/engine/sampling_gen/generator.py +121 -0
  148. data_designer/engine/sampling_gen/jinja_utils.py +60 -0
  149. data_designer/engine/sampling_gen/people_gen.py +203 -0
  150. data_designer/engine/sampling_gen/person_constants.py +54 -0
  151. data_designer/engine/sampling_gen/schema.py +143 -0
  152. data_designer/engine/sampling_gen/schema_builder.py +59 -0
  153. data_designer/engine/sampling_gen/utils.py +40 -0
  154. data_designer/engine/secret_resolver.py +80 -0
  155. data_designer/engine/validators/__init__.py +17 -0
  156. data_designer/engine/validators/base.py +36 -0
  157. data_designer/engine/validators/local_callable.py +34 -0
  158. data_designer/engine/validators/python.py +245 -0
  159. data_designer/engine/validators/remote.py +83 -0
  160. data_designer/engine/validators/sql.py +60 -0
  161. data_designer/errors.py +5 -0
  162. data_designer/essentials/__init__.py +137 -0
  163. data_designer/interface/__init__.py +2 -0
  164. data_designer/interface/data_designer.py +351 -0
  165. data_designer/interface/errors.py +16 -0
  166. data_designer/interface/results.py +55 -0
  167. data_designer/logging.py +161 -0
  168. data_designer/plugin_manager.py +83 -0
  169. data_designer/plugins/__init__.py +6 -0
  170. data_designer/plugins/errors.py +10 -0
  171. data_designer/plugins/plugin.py +69 -0
  172. data_designer/plugins/registry.py +86 -0
  173. data_designer-0.1.0.dist-info/METADATA +173 -0
  174. data_designer-0.1.0.dist-info/RECORD +177 -0
  175. data_designer-0.1.0.dist-info/WHEEL +4 -0
  176. data_designer-0.1.0.dist-info/entry_points.txt +2 -0
  177. data_designer-0.1.0.dist-info/licenses/LICENSE +201 -0
data_designer/engine/resources/managed_dataset_repository.py
@@ -0,0 +1,194 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass
+ from functools import cached_property
+ import logging
+ from pathlib import Path
+ import tempfile
+ import threading
+ import time
+ from typing import Any
+
+ import duckdb
+ import pandas as pd
+
+ from data_designer.config.utils.constants import LOCALES_WITH_MANAGED_DATASETS
+ from data_designer.engine.resources.managed_storage import LocalBlobStorageProvider, ManagedBlobStorage
+
+ logger = logging.getLogger(__name__)
+
+ DATASETS_ROOT = "datasets"
+ """
+ Path in object storage to managed datasets.
+ """
+
+
+ @dataclass
+ class Table:
+     """
+     Managed datasets are organized by dataset and table under a root
+     path in object storage.
+     """
+
+     source: str
+     """
+     Table source path.
+     """
+
+     schema: str = "main"
+     """
+     Specifies the schema to use when registering the table.
+
+     Note: this is not the schema of the table, but rather the _database_
+     schema to associate with the table.
+     """
+
+     @cached_property
+     def name(self) -> str:
+         return Path(self.source).stem
+
+
+ DataCatalog = list[Table]
+
+
+ # For now we hardcode the remote data catalog in code. This makes it easier
+ # to initialize the data catalog. Eventually we can make this work more
+ # dynamically once this data catalog pattern becomes more widely adopted.
+ DEFAULT_DATA_CATALOG: DataCatalog = [Table(f"{locale}.parquet") for locale in LOCALES_WITH_MANAGED_DATASETS]
+
+
+ class ManagedDatasetRepository(ABC):
+     @abstractmethod
+     def query(self, sql: str, parameters: list[Any]) -> pd.DataFrame: ...
+
+     @property
+     @abstractmethod
+     def data_catalog(self) -> DataCatalog: ...
+
+
+ class DuckDBDatasetRepository(ManagedDatasetRepository):
+     """
+     Provides a DuckDB-based SQL interface over Gretel managed datasets.
+     """
+
+     _default_config = {"threads": 2, "memory_limit": "4 gb"}
+
+     def __init__(
+         self,
+         blob_storage: ManagedBlobStorage,
+         config: dict | None = None,
+         data_catalog: DataCatalog = DEFAULT_DATA_CATALOG,
+         datasets_root: str = DATASETS_ROOT,
+         use_cache: bool = True,
+     ):
+         """
+         Create a new DuckDB-backed dataset repository.
+
+         Args:
+             blob_storage: A managed blob storage provider
+             config: DuckDB configuration options; see
+                 https://duckdb.org/docs/configuration/overview.html#configuration-reference
+             data_catalog: A list of tables to register with the DuckDB instance
+             datasets_root: The root path in blob storage to managed datasets
+             use_cache: Whether to cache datasets locally. Trades off disk space
+                 and startup time for faster queries.
+         """
+         self._data_catalog = data_catalog
+         self._data_sets_root = datasets_root
+         self._blob_storage = blob_storage
+         self._config = self._default_config if config is None else config
+         self._use_cache = use_cache
+
+         # Configure database and register tables
+         self.db = duckdb.connect(config=self._config)
+
+         # Dataset registration completion is tracked with an event. Consumers can
+         # wait on this event to ensure the catalog is ready.
+         self._registration_event = threading.Event()
+         self._register_lock = threading.Lock()
+
+         # Kick off dataset registration in a background thread so that IO-heavy
+         # caching and view creation can run asynchronously without blocking the
+         # caller that constructs this repository instance.
+         self._register_thread = threading.Thread(target=self._register_datasets, daemon=True)
+         self._register_thread.start()
+
+     def _register_datasets(self):
+         # Just in case this method gets called from multiple threads:
+         # registration isn't thread-safe by default, so we
+         # synchronize the registration process.
+         if self._registration_event.is_set():
+             return
+         with self._register_lock:
+             # Check once more whether the catalog is ready; it's possible a
+             # previous thread already registered the datasets.
+             if self._registration_event.is_set():
+                 return
+             try:
+                 for table in self.data_catalog:
+                     key = table.source if table.schema == "main" else f"{table.schema}/{table.source}"
+                     if self._use_cache:
+                         tmp_root = Path(tempfile.gettempdir()) / "dd_cache"
+                         local_path = tmp_root / key
+                         local_path.parent.mkdir(parents=True, exist_ok=True)
+                         if not local_path.exists():
+                             start = time.time()
+                             logger.debug("Caching database %s to %s", table.name, local_path)
+                             with self._blob_storage.get_blob(f"{self._data_sets_root}/{key}") as src_fd:
+                                 with open(local_path, "wb") as dst_fd:
+                                     dst_fd.write(src_fd.read())
+                             logger.debug(
+                                 "Cached database %s in %.2f s",
+                                 table.name,
+                                 time.time() - start,
+                             )
+                         data_path = local_path.as_posix()
+                     else:
+                         data_path = self._blob_storage.uri_for_key(f"{self._data_sets_root}/{key}")
+                     if table.schema != "main":
+                         self.db.sql(f"CREATE SCHEMA IF NOT EXISTS {table.schema}")
+                     logger.debug(f"Registering dataset {table.name} from {data_path}")
+                     self.db.sql(f"CREATE VIEW {table.schema}.{table.name} AS FROM '{data_path}'")
+
+                 logger.debug("DuckDBDatasetRepository registration complete")
+
+             except Exception as e:
+                 logger.exception(f"Failed to register datasets: {e}")
+
+             finally:
+                 # Signal that registration is complete so any waiting queries can proceed.
+                 self._registration_event.set()
+
+     def query(self, sql: str, parameters: list[Any]) -> pd.DataFrame:
+         # Ensure dataset registration has completed. Possible future optimization:
+         # pull datasets in parallel and only wait here if the query requires a
+         # table that isn't cached.
+         if not self._registration_event.is_set():
+             logger.debug("Waiting for dataset caching and registration to finish...")
+             self._registration_event.wait()
+
+         # The duckdb connection isn't thread-safe, so we create a new
+         # connection per query using cursor().
+         # More details here: https://duckdb.org/docs/stable/guides/python/multiple_threads.html
+         cursor = self.db.cursor()
+         try:
+             df = cursor.execute(sql, parameters).df()
+         finally:
+             cursor.close()
+         return df
+
+     @property
+     def data_catalog(self) -> DataCatalog:
+         return self._data_catalog
+
+
+ def load_managed_dataset_repository(blob_storage: ManagedBlobStorage, locales: list[str]) -> ManagedDatasetRepository:
+     return DuckDBDatasetRepository(
+         blob_storage,
+         config={"threads": 1, "memory_limit": "2 gb"},
+         data_catalog=[Table(f"{locale}.parquet") for locale in locales],
+         # Only cache if not using local storage.
+         use_cache=not isinstance(blob_storage, LocalBlobStorageProvider),
+     )
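For context, a minimal usage sketch of the repository above, assuming this wheel is installed; the directory path and "US" locale are illustrative, not part of the package:

from pathlib import Path

from data_designer.engine.resources.managed_dataset_repository import load_managed_dataset_repository
from data_designer.engine.resources.managed_storage import LocalBlobStorageProvider

# Assumed layout: /tmp/managed-assets/datasets/US.parquet (path and locale are made up).
storage = LocalBlobStorageProvider(Path("/tmp/managed-assets"))
repo = load_managed_dataset_repository(storage, locales=["US"])

# query() blocks until the background registration thread signals completion,
# then runs on a per-query cursor because DuckDB connections aren't thread-safe.
df = repo.query("SELECT COUNT(*) AS n FROM main.US", [])
print(df)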
data_designer/engine/resources/managed_storage.py
@@ -0,0 +1,63 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from abc import ABC, abstractmethod
+ from collections.abc import Iterator
+ from contextlib import contextmanager
+ import logging
+ from pathlib import Path
+ from typing import IO
+
+ logger = logging.getLogger(__name__)
+
+
+ class ManagedBlobStorage(ABC):
+     """
+     Provides a low-level interface for accessing objects in blob storage. This interface
+     can be used to access model weights, raw datasets, or any artifact in blob
+     storage.
+
+     If you want a high-level interface for accessing datasets, use the `ManagedDatasetRepository`,
+     which provides a high-level SQL interface over each dataset.
+     """
+
+     @abstractmethod
+     @contextmanager
+     def get_blob(self, blob_key: str) -> Iterator[IO]: ...
+
+     @abstractmethod
+     def _key_uri_builder(self, key: str) -> str: ...
+
+     def uri_for_key(self, key: str) -> str:
+         """
+         Returns a qualified storage URI for a given key. `key` is
+         normalized to ensure that any leading path separators ("/") are removed.
+         """
+         return self._key_uri_builder(key.lstrip("/"))
+
+
+ class LocalBlobStorageProvider(ManagedBlobStorage):
+     """
+     Provides a local blob storage service. Useful for running
+     tests that don't require access to external infrastructure.
+     """
+
+     def __init__(self, root_path: Path) -> None:
+         self._root_path = root_path
+
+     @contextmanager
+     def get_blob(self, blob_key: str) -> Iterator[IO]:
+         with open(self._key_uri_builder(blob_key), "rb") as fd:
+             yield fd
+
+     def _key_uri_builder(self, key: str) -> str:
+         return f"{self._root_path}/{key}"
+
+
+ def init_managed_blob_storage(assets_storage: str) -> ManagedBlobStorage:
+     path = Path(assets_storage)
+     if not path.exists():
+         raise RuntimeError(f"Local storage path {assets_storage!r} does not exist.")
+
+     logger.debug(f"Using local storage for managed datasets: {assets_storage!r}")
+     return LocalBlobStorageProvider(path)
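A sketch of a custom provider against the `ManagedBlobStorage` contract above: subclasses supply a `get_blob` context manager and a `_key_uri_builder`, and inherit key normalization from `uri_for_key`. The in-memory class below is hypothetical, useful as a test double, and assumes the wheel is installed:

from collections.abc import Iterator
from contextlib import contextmanager
from io import BytesIO
from typing import IO

from data_designer.engine.resources.managed_storage import ManagedBlobStorage


class InMemoryBlobStorage(ManagedBlobStorage):
    """Hypothetical provider that serves blobs from a dict (illustration only)."""

    def __init__(self, blobs: dict[str, bytes]) -> None:
        self._blobs = blobs

    @contextmanager
    def get_blob(self, blob_key: str) -> Iterator[IO]:
        # Serve the blob as a readable binary stream.
        yield BytesIO(self._blobs[blob_key.lstrip("/")])

    def _key_uri_builder(self, key: str) -> str:
        return f"memory://{key}"


storage = InMemoryBlobStorage({"datasets/US.parquet": b"..."})
# uri_for_key strips leading "/" before delegating to _key_uri_builder.
assert storage.uri_for_key("/datasets/US.parquet") == "memory://datasets/US.parquet"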
data_designer/engine/resources/resource_provider.py
@@ -0,0 +1,46 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from data_designer.config.base import ConfigBase
+ from data_designer.config.models import ModelConfig
+ from data_designer.config.utils.type_helpers import StrEnum
+ from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
+ from data_designer.engine.model_provider import ModelProviderRegistry
+ from data_designer.engine.models.registry import ModelRegistry, create_model_registry
+ from data_designer.engine.resources.managed_storage import ManagedBlobStorage, init_managed_blob_storage
+ from data_designer.engine.resources.seed_dataset_data_store import SeedDatasetDataStore
+ from data_designer.engine.secret_resolver import SecretResolver
+
+
+ class ResourceType(StrEnum):
+     BLOB_STORAGE = "blob_storage"
+     DATASTORE = "datastore"
+     MODEL_REGISTRY = "model_registry"
+
+
+ class ResourceProvider(ConfigBase):
+     artifact_storage: ArtifactStorage
+     blob_storage: ManagedBlobStorage | None = None
+     datastore: SeedDatasetDataStore | None = None
+     model_registry: ModelRegistry | None = None
+
+
+ def create_resource_provider(
+     *,
+     artifact_storage: ArtifactStorage,
+     model_configs: list[ModelConfig],
+     secret_resolver: SecretResolver,
+     model_provider_registry: ModelProviderRegistry,
+     datastore: SeedDatasetDataStore | None = None,
+     blob_storage: ManagedBlobStorage | None = None,
+ ) -> ResourceProvider:
+     return ResourceProvider(
+         artifact_storage=artifact_storage,
+         datastore=datastore,
+         model_registry=create_model_registry(
+             model_configs=model_configs,
+             secret_resolver=secret_resolver,
+             model_provider_registry=model_provider_registry,
+         ),
+         blob_storage=blob_storage or init_managed_blob_storage(),
+     )
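One note on the `StrEnum`-based `ResourceType`: string-backed enum members compare and round-trip as their raw values, which keeps resource lookups string-friendly. A small sketch, assuming the wheel is installed and that `type_helpers.StrEnum` behaves like the stdlib `enum.StrEnum` (an assumption; its definition is not shown here):

from data_designer.engine.resources.resource_provider import ResourceType

assert ResourceType.MODEL_REGISTRY == "model_registry"  # members compare as plain strings
assert ResourceType("datastore") is ResourceType.DATASTORE  # and round-trip from raw values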
data_designer/engine/resources/seed_dataset_data_store.py
@@ -0,0 +1,66 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from abc import ABC, abstractmethod
+
+ import duckdb
+ from huggingface_hub import HfApi, HfFileSystem
+
+ from data_designer.logging import quiet_noisy_logger
+
+ quiet_noisy_logger("httpx")
+
+ _HF_DATASETS_PREFIX = "hf://datasets/"
+
+
+ class MalformedFileIdError(Exception):
+     """Raised when the file_id format is invalid."""
+
+
+ class SeedDatasetDataStore(ABC):
+     """Abstract base class for dataset storage implementations."""
+
+     @abstractmethod
+     def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection: ...
+
+     @abstractmethod
+     def get_dataset_uri(self, file_id: str) -> str: ...
+
+
+ class LocalSeedDatasetDataStore(SeedDatasetDataStore):
+     """Local filesystem-based dataset storage."""
+
+     def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection:
+         return duckdb.connect()
+
+     def get_dataset_uri(self, file_id: str) -> str:
+         return file_id
+
+
+ class HfHubSeedDatasetDataStore(SeedDatasetDataStore):
+     """Hugging Face Hub and Data Store dataset storage."""
+
+     def __init__(self, endpoint: str, token: str | None):
+         self.hfapi = HfApi(endpoint=endpoint, token=token)
+         self.hffs = HfFileSystem(endpoint=endpoint, token=token)
+
+     def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection:
+         conn = duckdb.connect()
+         conn.register_filesystem(self.hffs)
+         return conn
+
+     def get_dataset_uri(self, file_id: str) -> str:
+         identifier = file_id.removeprefix(_HF_DATASETS_PREFIX)
+         repo_id, filename = self._get_repo_id_and_filename(identifier)
+         return f"{_HF_DATASETS_PREFIX}{repo_id}/{filename}"
+
+     def _get_repo_id_and_filename(self, identifier: str) -> tuple[str, str]:
+         """Extract repo_id and filename from the identifier."""
+         parts = identifier.split("/", 2)
+         if len(parts) < 3:
+             raise MalformedFileIdError(
+                 "Could not extract repo id and filename from file_id, "
+                 "expected 'hf://datasets/{repo-namespace}/{repo-name}/{filename}'"
+             )
+         repo_ns, repo_name, filename = parts
+         return f"{repo_ns}/{repo_name}", filename
data_designer/engine/sampling_gen/column.py
@@ -0,0 +1,89 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from typing import Any
+
+ from pydantic import field_serializer, model_validator
+ from typing_extensions import Self
+
+ from data_designer.config.column_configs import SamplerColumnConfig
+ from data_designer.config.sampler_params import SamplerParamsT, SamplerType
+ from data_designer.engine.sampling_gen.data_sources.base import DataSource
+ from data_designer.engine.sampling_gen.data_sources.sources import SamplerRegistry
+ from data_designer.engine.sampling_gen.jinja_utils import extract_column_names_from_expression
+
+
+ class ConditionalDataColumn(SamplerColumnConfig):
+     @property
+     def _negative_condition(self) -> str:
+         conditions = list(self.conditional_params.keys())
+         return "not (" + " or ".join([f"({c})" for c in conditions]) + ")"
+
+     @property
+     def conditions(self) -> list[str]:
+         c = list(self.conditional_params.keys())
+         return (c + [self._negative_condition]) if len(c) > 0 else ["..."]
+
+     @property
+     def conditional_column_names(self) -> set[str]:
+         names = set()
+         for condition in self.conditional_params.keys():
+             names.update(extract_column_names_from_expression(condition))
+         return names
+
+     @field_serializer("sampler_type")
+     def serialize_sampler_type(self, sampler_type: SamplerType) -> str:
+         return SamplerType(sampler_type).value
+
+     @field_serializer("params")
+     def serialize_params(self, params: SamplerParamsT) -> dict:
+         return params.model_dump()
+
+     @field_serializer("conditional_params")
+     def serialize_conditional_params(self, conditional_params: dict[str, SamplerParamsT]) -> dict:
+         for condition, params in conditional_params.items():
+             conditional_params[condition] = params.model_dump()
+         return conditional_params
+
+     @model_validator(mode="before")
+     @classmethod
+     def validate_params_with_type(cls, data: Any) -> Any:
+         if not isinstance(data, dict) or "sampler_type" not in data:
+             return data
+         if isinstance(data["sampler_type"], str):
+             if not SamplerRegistry.is_registered(data["sampler_type"]):
+                 raise ValueError(
+                     f"Invalid sampler type: {data['sampler_type']}. Available samplers: {[s.value for s in SamplerType]}"
+                 )
+         if "params" in data:
+             data["params"] = SamplerRegistry.get_sampler(data["sampler_type"])(params=data["params"]).params
+         if "conditional_params" in data:
+             for condition, params in data["conditional_params"].items():
+                 data["conditional_params"][condition] = SamplerRegistry.get_sampler(data["sampler_type"])(
+                     params=params
+                 ).params
+         return data
+
+     @model_validator(mode="after")
+     def validate_params(self) -> Self:
+         self.params = SamplerRegistry.validate_sampler_type(self.sampler_type)(params=self.params).params
+         return self
+
+     @model_validator(mode="after")
+     def validate_data_conversion(self) -> Self:
+         self.get_default_sampler().validate_data_conversion(self.convert_to)
+         return self
+
+     @model_validator(mode="after")
+     def validate_conditional_params(self) -> Self:
+         for condition, params in self.conditional_params.items():
+             self.conditional_params[condition] = SamplerRegistry.get_sampler(self.sampler_type)(params=params).params
+         return self
+
+     def get_default_sampler(self, **kwargs) -> DataSource:
+         return self.get_sampler("...", **kwargs)
+
+     def get_sampler(self, condition: str, **kwargs) -> DataSource:
+         if condition in ["...", self._negative_condition]:
+             return SamplerRegistry.get_sampler(self.sampler_type)(self.params, **kwargs)
+         return SamplerRegistry.get_sampler(self.sampler_type)(self.conditional_params[condition], **kwargs)
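The `_negative_condition` property above is the catch-all branch: it negates the union of all user-supplied conditions so every row falls through to exactly one set of sampler params. A standalone re-implementation of just that string construction, with illustrative condition expressions:

def negative_condition(conditions: list[str]) -> str:
    # Mirrors ConditionalDataColumn._negative_condition: rows matching none
    # of the explicit conditions fall through to the default params.
    return "not (" + " or ".join(f"({c})" for c in conditions) + ")"

print(negative_condition(["age < 18", "age > 65"]))
# not ((age < 18) or (age > 65))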
data_designer/engine/sampling_gen/constraints.py
@@ -0,0 +1,95 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from abc import ABC, abstractmethod
+ from typing import Type
+
+ import numpy as np
+ from numpy.typing import NDArray
+ import pandas as pd
+
+ from data_designer.config.base import ConfigBase
+ from data_designer.config.sampler_constraints import (
+     ColumnInequalityConstraint,
+     Constraint,
+     ConstraintType,
+     InequalityOperator,
+     ScalarInequalityConstraint,
+ )
+
+
+ class ConstraintChecker(ConfigBase, ABC):
+     constraint: Constraint
+
+     def get_required_column_names(self) -> tuple[str, ...]:
+         return (self.constraint.target_column,)
+
+     @abstractmethod
+     def check(self, dataframe: pd.DataFrame) -> NDArray[np.bool_]: ...
+
+
+ class WithCompareMixin:
+     @property
+     def lhs(self) -> str:
+         return self.constraint.target_column
+
+     def compare(self, lhs: float | int | NDArray, rhs: float | int | NDArray) -> bool | NDArray[np.bool_]:
+         operator = {
+             InequalityOperator.LT: np.less,
+             InequalityOperator.LE: np.less_equal,
+             InequalityOperator.GT: np.greater,
+             InequalityOperator.GE: np.greater_equal,
+         }[InequalityOperator(self.constraint.operator)]
+         return operator(lhs, rhs)
+
+
+ class ScalarInequalityChecker(ConstraintChecker, WithCompareMixin):
+     """Compare a column to a scalar value.
+
+     Args:
+         column_name: Name of the constrained column. Will be
+             used as the left-hand side (lhs) of the comparison.
+         operator: Comparison operator.
+         rhs: Scalar value to compare against.
+     """
+
+     constraint: ScalarInequalityConstraint
+
+     def check(self, dataframe: pd.DataFrame) -> NDArray[np.bool_]:
+         return self.compare(dataframe[self.lhs].values, self.constraint.rhs)
+
+
+ class ColumnInequalityChecker(ConstraintChecker, WithCompareMixin):
+     """Compare the values of two columns.
+
+     Args:
+         column_name: Name of the constrained column. Will be
+             used as the left-hand side (lhs) of the comparison.
+         operator: Comparison operator.
+         rhs: Name of the column to compare against.
+     """
+
+     constraint: ColumnInequalityConstraint
+
+     def get_required_column_names(self) -> tuple[str, ...]:
+         """Return the names of columns required for the constraint.
+
+         Note that order matters. Edges in the DAG are created as column_names[1], column_names[0].
+         """
+         return (self.lhs, self.constraint.rhs)
+
+     def check(self, dataframe: pd.DataFrame) -> NDArray[np.bool_]:
+         return self.compare(
+             dataframe[self.lhs].values,
+             dataframe[self.constraint.rhs].values,
+         )
+
+
+ CONSTRAINT_TYPE_TO_CHECKER = {
+     ConstraintType.SCALAR_INEQUALITY: ScalarInequalityChecker,
+     ConstraintType.COLUMN_INEQUALITY: ColumnInequalityChecker,
+ }
+
+
+ def get_constraint_checker(constraint_type: ConstraintType) -> Type[ConstraintChecker]:
+     return CONSTRAINT_TYPE_TO_CHECKER[ConstraintType(constraint_type)]
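The checkers above reduce each constraint to a vectorized NumPy comparison over DataFrame columns. A self-contained sketch of the same dispatch pattern; the operator keys, column names, and data here are illustrative, not the package's API:

import numpy as np
import pandas as pd

# Operator lookup mirroring WithCompareMixin.compare.
OPS = {"lt": np.less, "le": np.less_equal, "gt": np.greater, "ge": np.greater_equal}

df = pd.DataFrame({"start": [1, 5, 9], "end": [3, 4, 12]})

# Column inequality start < end, evaluated element-wise like ColumnInequalityChecker.check:
print(OPS["lt"](df["start"].values, df["end"].values))  # [ True False  True]

# Scalar inequality end <= 10, like ScalarInequalityChecker.check:
print(OPS["le"](df["end"].values, 10))  # [ True  True False]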