data-designer 0.3.8rc2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. data_designer/cli/commands/__init__.py +1 -1
  2. data_designer/interface/__init__.py +21 -1
  3. data_designer/{_version.py → interface/_version.py} +2 -2
  4. data_designer/interface/data_designer.py +1 -7
  5. {data_designer-0.3.8rc2.dist-info → data_designer-0.4.0.dist-info}/METADATA +10 -42
  6. data_designer-0.4.0.dist-info/RECORD +39 -0
  7. data_designer/__init__.py +0 -17
  8. data_designer/config/__init__.py +0 -2
  9. data_designer/config/analysis/__init__.py +0 -2
  10. data_designer/config/analysis/column_profilers.py +0 -159
  11. data_designer/config/analysis/column_statistics.py +0 -421
  12. data_designer/config/analysis/dataset_profiler.py +0 -84
  13. data_designer/config/analysis/utils/errors.py +0 -10
  14. data_designer/config/analysis/utils/reporting.py +0 -192
  15. data_designer/config/base.py +0 -69
  16. data_designer/config/column_configs.py +0 -470
  17. data_designer/config/column_types.py +0 -141
  18. data_designer/config/config_builder.py +0 -595
  19. data_designer/config/data_designer_config.py +0 -40
  20. data_designer/config/dataset_builders.py +0 -13
  21. data_designer/config/dataset_metadata.py +0 -18
  22. data_designer/config/default_model_settings.py +0 -129
  23. data_designer/config/errors.py +0 -24
  24. data_designer/config/exports.py +0 -145
  25. data_designer/config/interface.py +0 -55
  26. data_designer/config/models.py +0 -455
  27. data_designer/config/preview_results.py +0 -41
  28. data_designer/config/processors.py +0 -148
  29. data_designer/config/run_config.py +0 -51
  30. data_designer/config/sampler_constraints.py +0 -52
  31. data_designer/config/sampler_params.py +0 -639
  32. data_designer/config/seed.py +0 -116
  33. data_designer/config/seed_source.py +0 -84
  34. data_designer/config/seed_source_types.py +0 -19
  35. data_designer/config/utils/code_lang.py +0 -82
  36. data_designer/config/utils/constants.py +0 -363
  37. data_designer/config/utils/errors.py +0 -21
  38. data_designer/config/utils/info.py +0 -94
  39. data_designer/config/utils/io_helpers.py +0 -258
  40. data_designer/config/utils/misc.py +0 -78
  41. data_designer/config/utils/numerical_helpers.py +0 -30
  42. data_designer/config/utils/type_helpers.py +0 -106
  43. data_designer/config/utils/visualization.py +0 -482
  44. data_designer/config/validator_params.py +0 -94
  45. data_designer/engine/__init__.py +0 -2
  46. data_designer/engine/analysis/column_profilers/base.py +0 -49
  47. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +0 -153
  48. data_designer/engine/analysis/column_profilers/registry.py +0 -22
  49. data_designer/engine/analysis/column_statistics.py +0 -145
  50. data_designer/engine/analysis/dataset_profiler.py +0 -149
  51. data_designer/engine/analysis/errors.py +0 -9
  52. data_designer/engine/analysis/utils/column_statistics_calculations.py +0 -234
  53. data_designer/engine/analysis/utils/judge_score_processing.py +0 -132
  54. data_designer/engine/column_generators/__init__.py +0 -2
  55. data_designer/engine/column_generators/generators/__init__.py +0 -2
  56. data_designer/engine/column_generators/generators/base.py +0 -122
  57. data_designer/engine/column_generators/generators/embedding.py +0 -35
  58. data_designer/engine/column_generators/generators/expression.py +0 -55
  59. data_designer/engine/column_generators/generators/llm_completion.py +0 -113
  60. data_designer/engine/column_generators/generators/samplers.py +0 -69
  61. data_designer/engine/column_generators/generators/seed_dataset.py +0 -144
  62. data_designer/engine/column_generators/generators/validation.py +0 -140
  63. data_designer/engine/column_generators/registry.py +0 -60
  64. data_designer/engine/column_generators/utils/errors.py +0 -15
  65. data_designer/engine/column_generators/utils/generator_classification.py +0 -43
  66. data_designer/engine/column_generators/utils/judge_score_factory.py +0 -58
  67. data_designer/engine/column_generators/utils/prompt_renderer.py +0 -100
  68. data_designer/engine/compiler.py +0 -97
  69. data_designer/engine/configurable_task.py +0 -71
  70. data_designer/engine/dataset_builders/artifact_storage.py +0 -283
  71. data_designer/engine/dataset_builders/column_wise_builder.py +0 -335
  72. data_designer/engine/dataset_builders/errors.py +0 -15
  73. data_designer/engine/dataset_builders/multi_column_configs.py +0 -46
  74. data_designer/engine/dataset_builders/utils/__init__.py +0 -2
  75. data_designer/engine/dataset_builders/utils/concurrency.py +0 -212
  76. data_designer/engine/dataset_builders/utils/config_compiler.py +0 -62
  77. data_designer/engine/dataset_builders/utils/dag.py +0 -62
  78. data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +0 -200
  79. data_designer/engine/dataset_builders/utils/errors.py +0 -15
  80. data_designer/engine/errors.py +0 -51
  81. data_designer/engine/model_provider.py +0 -77
  82. data_designer/engine/models/__init__.py +0 -2
  83. data_designer/engine/models/errors.py +0 -300
  84. data_designer/engine/models/facade.py +0 -287
  85. data_designer/engine/models/factory.py +0 -42
  86. data_designer/engine/models/litellm_overrides.py +0 -179
  87. data_designer/engine/models/parsers/__init__.py +0 -2
  88. data_designer/engine/models/parsers/errors.py +0 -34
  89. data_designer/engine/models/parsers/parser.py +0 -235
  90. data_designer/engine/models/parsers/postprocessors.py +0 -93
  91. data_designer/engine/models/parsers/tag_parsers.py +0 -62
  92. data_designer/engine/models/parsers/types.py +0 -84
  93. data_designer/engine/models/recipes/base.py +0 -81
  94. data_designer/engine/models/recipes/response_recipes.py +0 -293
  95. data_designer/engine/models/registry.py +0 -146
  96. data_designer/engine/models/telemetry.py +0 -359
  97. data_designer/engine/models/usage.py +0 -73
  98. data_designer/engine/models/utils.py +0 -38
  99. data_designer/engine/processing/ginja/__init__.py +0 -2
  100. data_designer/engine/processing/ginja/ast.py +0 -65
  101. data_designer/engine/processing/ginja/environment.py +0 -463
  102. data_designer/engine/processing/ginja/exceptions.py +0 -56
  103. data_designer/engine/processing/ginja/record.py +0 -32
  104. data_designer/engine/processing/gsonschema/__init__.py +0 -2
  105. data_designer/engine/processing/gsonschema/exceptions.py +0 -15
  106. data_designer/engine/processing/gsonschema/schema_transformers.py +0 -83
  107. data_designer/engine/processing/gsonschema/types.py +0 -10
  108. data_designer/engine/processing/gsonschema/validators.py +0 -202
  109. data_designer/engine/processing/processors/base.py +0 -13
  110. data_designer/engine/processing/processors/drop_columns.py +0 -42
  111. data_designer/engine/processing/processors/registry.py +0 -25
  112. data_designer/engine/processing/processors/schema_transform.py +0 -49
  113. data_designer/engine/processing/utils.py +0 -169
  114. data_designer/engine/registry/base.py +0 -99
  115. data_designer/engine/registry/data_designer_registry.py +0 -39
  116. data_designer/engine/registry/errors.py +0 -12
  117. data_designer/engine/resources/managed_dataset_generator.py +0 -39
  118. data_designer/engine/resources/managed_dataset_repository.py +0 -197
  119. data_designer/engine/resources/managed_storage.py +0 -65
  120. data_designer/engine/resources/resource_provider.py +0 -77
  121. data_designer/engine/resources/seed_reader.py +0 -154
  122. data_designer/engine/sampling_gen/column.py +0 -91
  123. data_designer/engine/sampling_gen/constraints.py +0 -100
  124. data_designer/engine/sampling_gen/data_sources/base.py +0 -217
  125. data_designer/engine/sampling_gen/data_sources/errors.py +0 -12
  126. data_designer/engine/sampling_gen/data_sources/sources.py +0 -347
  127. data_designer/engine/sampling_gen/entities/__init__.py +0 -2
  128. data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
  129. data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +0 -86
  130. data_designer/engine/sampling_gen/entities/email_address_utils.py +0 -171
  131. data_designer/engine/sampling_gen/entities/errors.py +0 -10
  132. data_designer/engine/sampling_gen/entities/national_id_utils.py +0 -102
  133. data_designer/engine/sampling_gen/entities/person.py +0 -144
  134. data_designer/engine/sampling_gen/entities/phone_number.py +0 -128
  135. data_designer/engine/sampling_gen/errors.py +0 -26
  136. data_designer/engine/sampling_gen/generator.py +0 -122
  137. data_designer/engine/sampling_gen/jinja_utils.py +0 -64
  138. data_designer/engine/sampling_gen/people_gen.py +0 -199
  139. data_designer/engine/sampling_gen/person_constants.py +0 -56
  140. data_designer/engine/sampling_gen/schema.py +0 -147
  141. data_designer/engine/sampling_gen/schema_builder.py +0 -61
  142. data_designer/engine/sampling_gen/utils.py +0 -46
  143. data_designer/engine/secret_resolver.py +0 -82
  144. data_designer/engine/validation.py +0 -367
  145. data_designer/engine/validators/__init__.py +0 -19
  146. data_designer/engine/validators/base.py +0 -38
  147. data_designer/engine/validators/local_callable.py +0 -39
  148. data_designer/engine/validators/python.py +0 -254
  149. data_designer/engine/validators/remote.py +0 -89
  150. data_designer/engine/validators/sql.py +0 -65
  151. data_designer/errors.py +0 -7
  152. data_designer/essentials/__init__.py +0 -33
  153. data_designer/lazy_heavy_imports.py +0 -54
  154. data_designer/logging.py +0 -163
  155. data_designer/plugin_manager.py +0 -78
  156. data_designer/plugins/__init__.py +0 -8
  157. data_designer/plugins/errors.py +0 -15
  158. data_designer/plugins/plugin.py +0 -141
  159. data_designer/plugins/registry.py +0 -88
  160. data_designer/plugins/testing/__init__.py +0 -10
  161. data_designer/plugins/testing/stubs.py +0 -116
  162. data_designer/plugins/testing/utils.py +0 -20
  163. data_designer-0.3.8rc2.dist-info/RECORD +0 -196
  164. data_designer-0.3.8rc2.dist-info/licenses/LICENSE +0 -201
  165. {data_designer-0.3.8rc2.dist-info → data_designer-0.4.0.dist-info}/WHEEL +0 -0
  166. {data_designer-0.3.8rc2.dist-info → data_designer-0.4.0.dist-info}/entry_points.txt +0 -0
data_designer/engine/resources/managed_dataset_generator.py
@@ -1,39 +0,0 @@
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- # SPDX-License-Identifier: Apache-2.0
-
- from __future__ import annotations
-
- from typing import TYPE_CHECKING, Any
-
- from data_designer.engine.resources.managed_dataset_repository import ManagedDatasetRepository
- from data_designer.lazy_heavy_imports import pd
-
- if TYPE_CHECKING:
-     import pandas as pd
-
-
- class ManagedDatasetGenerator:
-     def __init__(self, managed_datasets: ManagedDatasetRepository, dataset_name: str):
-         self.managed_datasets = managed_datasets
-         self.dataset_name = dataset_name
-
-     def generate_samples(
-         self,
-         size: int = 1,
-         evidence: dict[str, Any | list[Any]] = {},
-     ) -> pd.DataFrame:
-         parameters = []
-         query = f"select * from {self.dataset_name}"
-         if evidence:
-             where_conditions = []
-             for column, values in evidence.items():
-                 if values:
-                     values = values if isinstance(values, list) else [values]
-                     formatted_values = ["?"] * len(values)
-                     condition = f"{column} IN ({', '.join(formatted_values)})"
-                     where_conditions.append(condition)
-                     parameters.extend(values)
-             if where_conditions:
-                 query += " where " + " and ".join(where_conditions)
-         query += f" order by random() limit {size}"
-         return self.managed_datasets.query(query, parameters)
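Note: the removed generate_samples() assembles a parameterized query — one IN clause per evidence column, one "?" placeholder per value, and random sampling via order by random(). Below is a minimal, self-contained sketch of that construction against an in-memory DuckDB table; the table name, data, and evidence values are illustrative, not from the package.

    # Sketch of the removed query-building logic; table and data are illustrative.
    import duckdb
    import pandas as pd

    df = pd.DataFrame({"state": ["CA", "NY", "TX"], "age": [30, 40, 50]})
    con = duckdb.connect()
    con.register("people_us", df)  # stand-in for a managed dataset view

    evidence = {"state": ["CA", "NY"]}
    parameters: list = []
    query = "select * from people_us"
    where_conditions = []
    for column, values in evidence.items():
        values = values if isinstance(values, list) else [values]
        # One "?" placeholder per value keeps the query parameterized
        where_conditions.append(f"{column} IN ({', '.join(['?'] * len(values))})")
        parameters.extend(values)
    if where_conditions:
        query += " where " + " and ".join(where_conditions)
    query += " order by random() limit 2"

    print(con.execute(query, parameters).df())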
data_designer/engine/resources/managed_dataset_repository.py
@@ -1,197 +0,0 @@
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- # SPDX-License-Identifier: Apache-2.0
-
- from __future__ import annotations
-
- import logging
- import tempfile
- import threading
- import time
- from abc import ABC, abstractmethod
- from dataclasses import dataclass
- from functools import cached_property
- from pathlib import Path
- from typing import TYPE_CHECKING, Any
-
- from data_designer.config.utils.constants import LOCALES_WITH_MANAGED_DATASETS
- from data_designer.engine.resources.managed_storage import LocalBlobStorageProvider, ManagedBlobStorage
- from data_designer.lazy_heavy_imports import duckdb, pd
-
- if TYPE_CHECKING:
-     import duckdb
-     import pandas as pd
-
- logger = logging.getLogger(__name__)
-
- DATASETS_ROOT = "datasets"
- """
- Path in object storage to managed datasets
- """
-
-
- @dataclass
- class Table:
-     """
-     Managed datasets are organized by dataset and table under a root
-     path in object storage.
-     """
-
-     source: str
-     """
-     Table source path
-     """
-
-     schema: str = "main"
-     """
-     Specifies the schema to use when registering the table.
-
-     Note: this is not the schema of the table, but rather the _database_
-     schema to associate with the table.
-     """
-
-     @cached_property
-     def name(self) -> str:
-         return Path(self.source).stem
-
-
- DataCatalog = list[Table]
-
- # For now we hardcode the remote data catalog in code. This makes it easier
- # to initialize the data catalog. Eventually we can make this work more
- # dynamically once the data catalog pattern becomes more widely adopted.
- DEFAULT_DATA_CATALOG: DataCatalog = [Table(f"{locale}.parquet") for locale in LOCALES_WITH_MANAGED_DATASETS]
-
-
- class ManagedDatasetRepository(ABC):
-     @abstractmethod
-     def query(self, sql: str, parameters: list[Any]) -> pd.DataFrame: ...
-
-     @property
-     @abstractmethod
-     def data_catalog(self) -> DataCatalog: ...
-
-
- class DuckDBDatasetRepository(ManagedDatasetRepository):
-     """
-     Provides a duckdb-based SQL interface over Gretel managed datasets.
-     """
-
-     _default_config = {"threads": 2, "memory_limit": "4 gb"}
-
-     def __init__(
-         self,
-         blob_storage: ManagedBlobStorage,
-         config: dict | None = None,
-         data_catalog: DataCatalog = DEFAULT_DATA_CATALOG,
-         datasets_root: str = DATASETS_ROOT,
-         use_cache: bool = True,
-     ):
-         """
-         Create a new DuckDB-backed dataset repository.
-
-         Args:
-             blob_storage: A managed blob storage provider
-             config: DuckDB configuration options,
-                 https://duckdb.org/docs/configuration/overview.html#configuration-reference
-             data_catalog: A list of tables to register with the DuckDB instance
-             datasets_root: The root path in blob storage to managed datasets
-             use_cache: Whether to cache datasets locally. Trades off disk space
-                 and startup time for faster queries.
-         """
-         self._data_catalog = data_catalog
-         self._data_sets_root = datasets_root
-         self._blob_storage = blob_storage
-         self._config = self._default_config if config is None else config
-         self._use_cache = use_cache
-
-         # Configure database and register tables
-         self.db = duckdb.connect(config=self._config)
-
-         # Dataset registration completion is tracked with an event. Consumers can
-         # wait on this event to ensure the catalog is ready.
-         self._registration_event = threading.Event()
-         self._register_lock = threading.Lock()
-
-         # Kick off dataset registration in a background thread so that IO-heavy
-         # caching and view creation can run asynchronously without blocking the
-         # caller that constructs this repository instance.
-         self._register_thread = threading.Thread(target=self._register_datasets, daemon=True)
-         self._register_thread.start()
-
-     def _register_datasets(self):
-         # Just in case this method gets called from inside a thread:
-         # this operation isn't thread-safe by default, so we
-         # synchronize the registration process.
-         if self._registration_event.is_set():
-             return
-         with self._register_lock:
-             # Check once more to see if the catalog is ready; it's possible a
-             # previous thread already registered the datasets.
-             if self._registration_event.is_set():
-                 return
-             try:
-                 for table in self.data_catalog:
-                     key = table.source if table.schema == "main" else f"{table.schema}/{table.source}"
-                     if self._use_cache:
-                         tmp_root = Path(tempfile.gettempdir()) / "dd_cache"
-                         local_path = tmp_root / key
-                         local_path.parent.mkdir(parents=True, exist_ok=True)
-                         if not local_path.exists():
-                             start = time.time()
-                             logger.debug("Caching database %s to %s", table.name, local_path)
-                             with self._blob_storage.get_blob(f"{self._data_sets_root}/{key}") as src_fd:
-                                 with open(local_path, "wb") as dst_fd:
-                                     dst_fd.write(src_fd.read())
-                             logger.debug(
-                                 "Cached database %s in %.2f s",
-                                 table.name,
-                                 time.time() - start,
-                             )
-                         data_path = local_path.as_posix()
-                     else:
-                         data_path = self._blob_storage.uri_for_key(f"{self._data_sets_root}/{key}")
-                     if table.schema != "main":
-                         self.db.sql(f"CREATE SCHEMA IF NOT EXISTS {table.schema}")
-                     logger.debug(f"Registering dataset {table.name} from {data_path}")
-                     self.db.sql(f"CREATE VIEW {table.schema}.{table.name} AS FROM '{data_path}'")
-
-                 logger.debug("DuckDBDatasetRepository registration complete")
-
-             except Exception as e:
-                 logger.exception(f"Failed to register datasets: {str(e)}")
-
-             finally:
-                 # Signal that registration is complete so any waiting queries can proceed.
-                 self._registration_event.set()
-
-     def query(self, sql: str, parameters: list[Any]) -> pd.DataFrame:
-         # Ensure dataset registration has completed. Possible future optimization:
-         # pull datasets in parallel and only wait here if the query requires a
-         # table that isn't cached.
-         if not self._registration_event.is_set():
-             logger.debug("Waiting for dataset caching and registration to finish...")
-             self._registration_event.wait()
-
-         # The duckdb connection isn't thread-safe, so we create a new
-         # connection per query using cursor().
-         # More details here: https://duckdb.org/docs/stable/guides/python/multiple_threads.html
-         cursor = self.db.cursor()
-         try:
-             df = cursor.execute(sql, parameters).df()
-         finally:
-             cursor.close()
-         return df
-
-     @property
-     def data_catalog(self) -> DataCatalog:
-         return self._data_catalog
-
-
- def load_managed_dataset_repository(blob_storage: ManagedBlobStorage, locales: list[str]) -> ManagedDatasetRepository:
-     return DuckDBDatasetRepository(
-         blob_storage,
-         config={"threads": 1, "memory_limit": "2 gb"},
-         data_catalog=[Table(f"{locale}.parquet") for locale in locales],
-         # Only cache if not using local storage.
-         use_cache=not isinstance(blob_storage, LocalBlobStorageProvider),
-     )
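Note: two details of the removed repository are worth calling out — dataset registration runs on a daemon thread gated by a threading.Event that query() waits on, and each query goes through its own cursor because a shared DuckDB connection is not thread-safe. A minimal, runnable sketch of the cursor-per-thread pattern (table and data are illustrative):

    # Each worker queries the same database through its own cursor.
    import threading

    import duckdb

    db = duckdb.connect()
    db.sql("CREATE TABLE t AS SELECT range AS n FROM range(100)")

    def worker(results: list, limit: int) -> None:
        cursor = db.cursor()  # an independent connection over the shared database
        try:
            results.append(cursor.execute("SELECT count(*) FROM t WHERE n < ?", [limit]).fetchone()[0])
        finally:
            cursor.close()

    results: list = []
    threads = [threading.Thread(target=worker, args=(results, k)) for k in (10, 50, 90)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    print(sorted(results))  # [10, 50, 90]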
data_designer/engine/resources/managed_storage.py
@@ -1,65 +0,0 @@
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- # SPDX-License-Identifier: Apache-2.0
-
- from __future__ import annotations
-
- import logging
- from abc import ABC, abstractmethod
- from collections.abc import Iterator
- from contextlib import contextmanager
- from pathlib import Path
- from typing import IO
-
- logger = logging.getLogger(__name__)
-
-
- class ManagedBlobStorage(ABC):
-     """
-     Provides a low-level interface for accessing objects in blob storage. This interface
-     can be used to access model weights, raw datasets, or any artifact in blob
-     storage.
-
-     If you want a high-level interface for accessing datasets, use the `ManagedDatasetRepository`,
-     which provides a high-level SQL interface over each dataset.
-     """
-
-     @abstractmethod
-     @contextmanager
-     def get_blob(self, blob_key: str) -> Iterator[IO]: ...
-
-     @abstractmethod
-     def _key_uri_builder(self, key: str) -> str: ...
-
-     def uri_for_key(self, key: str) -> str:
-         """
-         Returns a qualified storage URI for a given key. `key` is
-         normalized to ensure that any leading path components ("/") are removed.
-         """
-         return self._key_uri_builder(key.lstrip("/"))
-
-
- class LocalBlobStorageProvider(ManagedBlobStorage):
-     """
-     Provides a local blob storage service. Useful for running
-     tests that don't require access to external infrastructure.
-     """
-
-     def __init__(self, root_path: Path) -> None:
-         self._root_path = root_path
-
-     @contextmanager
-     def get_blob(self, blob_key: str) -> Iterator[IO]:
-         with open(self._key_uri_builder(blob_key), "rb") as fd:
-             yield fd
-
-     def _key_uri_builder(self, key: str) -> str:
-         return f"{self._root_path}/{key}"
-
-
- def init_managed_blob_storage(assets_storage: str) -> ManagedBlobStorage:
-     path = Path(assets_storage)
-     if not path.exists():
-         raise RuntimeError(f"Local storage path {assets_storage!r} does not exist.")
-
-     logger.debug(f"Using local storage for managed datasets: {assets_storage!r}")
-     return LocalBlobStorageProvider(Path(assets_storage))
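Note: the removed storage abstraction normalizes keys by stripping leading "/" and resolves them against a root path, with get_blob() as a context manager over the opened file. A self-contained sketch mirroring that behavior (the class name and paths below are illustrative, not the package's API):

    # Mirrors the removed LocalBlobStorageProvider: key normalization plus
    # context-managed blob access under a root directory.
    import tempfile
    from contextlib import contextmanager
    from pathlib import Path

    class LocalBlobStorage:
        def __init__(self, root: Path) -> None:
            self._root = root

        def uri_for_key(self, key: str) -> str:
            # Leading "/" components are removed before building the URI
            return f"{self._root}/{key.lstrip('/')}"

        @contextmanager
        def get_blob(self, key: str):
            with open(self.uri_for_key(key), "rb") as fd:
                yield fd

    root = Path(tempfile.mkdtemp())
    (root / "datasets").mkdir()
    (root / "datasets" / "en_US.parquet").write_bytes(b"\x00")

    storage = LocalBlobStorage(root)
    print(storage.uri_for_key("/datasets/en_US.parquet"))
    with storage.get_blob("datasets/en_US.parquet") as fd:
        print(len(fd.read()))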
data_designer/engine/resources/resource_provider.py
@@ -1,77 +0,0 @@
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- # SPDX-License-Identifier: Apache-2.0
-
- from __future__ import annotations
-
- from data_designer.config.base import ConfigBase
- from data_designer.config.dataset_metadata import DatasetMetadata
- from data_designer.config.models import ModelConfig
- from data_designer.config.run_config import RunConfig
- from data_designer.config.seed_source import SeedSource
- from data_designer.config.utils.type_helpers import StrEnum
- from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
- from data_designer.engine.model_provider import ModelProviderRegistry
- from data_designer.engine.models.factory import create_model_registry
- from data_designer.engine.models.registry import ModelRegistry
- from data_designer.engine.resources.managed_storage import ManagedBlobStorage, init_managed_blob_storage
- from data_designer.engine.resources.seed_reader import SeedReader, SeedReaderRegistry
- from data_designer.engine.secret_resolver import SecretResolver
-
-
- class ResourceType(StrEnum):
-     BLOB_STORAGE = "blob_storage"
-     MODEL_REGISTRY = "model_registry"
-     SEED_READER = "seed_reader"
-
-
- class ResourceProvider(ConfigBase):
-     artifact_storage: ArtifactStorage
-     blob_storage: ManagedBlobStorage | None = None
-     model_registry: ModelRegistry | None = None
-     run_config: RunConfig = RunConfig()
-     seed_reader: SeedReader | None = None
-
-     def get_dataset_metadata(self) -> DatasetMetadata:
-         """Get metadata about the dataset being generated.
-
-         Returns:
-             DatasetMetadata with seed column names and other metadata.
-         """
-         seed_column_names = []
-         if self.seed_reader is not None:
-             seed_column_names = self.seed_reader.get_column_names()
-         return DatasetMetadata(seed_column_names=seed_column_names)
-
-
- def create_resource_provider(
-     *,
-     artifact_storage: ArtifactStorage,
-     model_configs: list[ModelConfig],
-     secret_resolver: SecretResolver,
-     model_provider_registry: ModelProviderRegistry,
-     seed_reader_registry: SeedReaderRegistry,
-     blob_storage: ManagedBlobStorage | None = None,
-     seed_dataset_source: SeedSource | None = None,
-     run_config: RunConfig | None = None,
- ) -> ResourceProvider:
-     """Factory function for creating a ResourceProvider instance.
-     This function triggers lazy loading of heavy dependencies like litellm."""
-     seed_reader = None
-     if seed_dataset_source:
-         seed_reader = seed_reader_registry.get_reader(
-             seed_dataset_source,
-             secret_resolver,
-         )
-
-     return ResourceProvider(
-         artifact_storage=artifact_storage,
-         model_registry=create_model_registry(
-             model_configs=model_configs,
-             secret_resolver=secret_resolver,
-             model_provider_registry=model_provider_registry,
-         ),
-         blob_storage=blob_storage or init_managed_blob_storage(),
-         seed_reader=seed_reader,
-         run_config=run_config or RunConfig(),
-     )
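Note: these removed modules consistently pair data_designer.lazy_heavy_imports with a TYPE_CHECKING re-import, so heavy dependencies (pandas, duckdb, litellm) load only on first use while type checkers still see the real module; create_resource_provider's docstring marks it as the point where that lazy loading is triggered. A self-contained sketch of the general pattern — the _LazyModule helper below is illustrative, not the package's implementation:

    # Heavy module loads on first attribute access; type checkers see the real import.
    from __future__ import annotations

    import importlib
    from typing import TYPE_CHECKING, Any

    class _LazyModule:
        def __init__(self, name: str) -> None:
            self._name = name
            self._module: Any = None

        def __getattr__(self, attr: str) -> Any:
            if self._module is None:
                self._module = importlib.import_module(self._name)
            return getattr(self._module, attr)

    pd = _LazyModule("pandas")

    if TYPE_CHECKING:
        import pandas as pd  # noqa: F811

    def make_frame() -> pd.DataFrame:
        # pandas is actually imported only when this line first runs
        return pd.DataFrame({"a": [1, 2]})

    print(make_frame())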
data_designer/engine/resources/seed_reader.py
@@ -1,154 +0,0 @@
- # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- # SPDX-License-Identifier: Apache-2.0
-
- from __future__ import annotations
-
- from abc import ABC, abstractmethod
- from collections.abc import Sequence
- from typing import TYPE_CHECKING, Generic, TypeVar, get_args, get_origin
-
- from huggingface_hub import HfFileSystem
- from typing_extensions import Self
-
- from data_designer.config.seed_source import (
-     DataFrameSeedSource,
-     HuggingFaceSeedSource,
-     LocalFileSeedSource,
-     SeedSource,
- )
- from data_designer.engine.secret_resolver import SecretResolver
- from data_designer.errors import DataDesignerError
- from data_designer.lazy_heavy_imports import duckdb
-
- if TYPE_CHECKING:
-     import duckdb
-
-
- class SeedReaderError(DataDesignerError): ...
-
-
- SourceT = TypeVar("SourceT", bound=SeedSource)
-
-
- class SeedReader(ABC, Generic[SourceT]):
-     """Base class for reading a seed dataset.
-
-     Seeds are read using duckdb. Reader implementations define duckdb connection setup details
-     and how to get a URI that can be queried with duckdb (i.e. "... FROM <uri> ...").
-
-     The Data Designer engine automatically supplies the appropriate SeedSource
-     and a SecretResolver to use for any secret fields in the config.
-     """
-
-     source: SourceT
-     secret_resolver: SecretResolver
-
-     @abstractmethod
-     def get_dataset_uri(self) -> str: ...
-
-     @abstractmethod
-     def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection: ...
-
-     def attach(self, source: SourceT, secret_resolver: SecretResolver):
-         """Attach a source and secret resolver to the instance.
-
-         This is called internally by the engine so that these objects do not
-         need to be provided in the reader's constructor.
-         """
-         self.source = source
-         self.secret_resolver = secret_resolver
-
-     def get_column_names(self) -> list[str]:
-         """Returns the seed dataset's column names."""
-         conn = self.create_duckdb_connection()
-         describe_query = f"DESCRIBE SELECT * FROM '{self.get_dataset_uri()}'"
-         column_descriptions = conn.execute(describe_query).fetchall()
-         return [col[0] for col in column_descriptions]
-
-     def get_seed_type(self) -> str:
-         """Return the seed_type of the source class this reader is generic over."""
-         # Get the generic type arguments from the reader class.
-         # Check __orig_bases__ for the generic base class.
-         for base in getattr(type(self), "__orig_bases__", []):
-             origin = get_origin(base)
-             if origin is SeedReader:
-                 args = get_args(base)
-                 if args:
-                     source_cls = args[0]
-                     # Extract seed_type from the source class
-                     if hasattr(source_cls, "model_fields") and "seed_type" in source_cls.model_fields:
-                         field = source_cls.model_fields["seed_type"]
-                         default_value = field.default
-                         if isinstance(default_value, str):
-                             return default_value
-
-         raise SeedReaderError("Reader does not have a valid generic source type with seed_type")
-
-
- class LocalFileSeedReader(SeedReader[LocalFileSeedSource]):
-     def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection:
-         return duckdb.connect()
-
-     def get_dataset_uri(self) -> str:
-         return self.source.path
-
-
- class HuggingFaceSeedReader(SeedReader[HuggingFaceSeedSource]):
-     def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection:
-         token = self.secret_resolver.resolve(self.source.token) if self.source.token else None
-
-         # Use skip_instance_cache to avoid fsspec-level caching
-         hffs = HfFileSystem(endpoint=self.source.endpoint, token=token, skip_instance_cache=True)
-
-         # Clear all internal caches to avoid stale metadata issues.
-         # HfFileSystem caches file metadata (size, etc.) which can become stale when files are re-uploaded.
-         if hasattr(hffs, "dircache"):
-             hffs.dircache.clear()
-
-         conn = duckdb.connect()
-         conn.register_filesystem(hffs)
-         return conn
-
-     def get_dataset_uri(self) -> str:
-         return f"hf://{self.source.path}"
-
-
- class DataFrameSeedReader(SeedReader[DataFrameSeedSource]):
-     # This is a "magic string" that gets registered in the duckdb connection to make the dataframe directly queryable.
-     _table_name = "df"
-
-     def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection:
-         conn = duckdb.connect()
-         conn.register(self._table_name, self.source.df)
-         return conn
-
-     def get_dataset_uri(self) -> str:
-         return self._table_name
-
-
- class SeedReaderRegistry:
-     def __init__(self, readers: Sequence[SeedReader]):
-         self._readers: dict[str, SeedReader] = {}
-         for reader in readers:
-             self.add_reader(reader)
-
-     def add_reader(self, reader: SeedReader) -> Self:
-         seed_type = reader.get_seed_type()
-
-         if seed_type in self._readers:
-             raise SeedReaderError(f"A reader for seed_type {seed_type!r} already exists")
-
-         self._readers[seed_type] = reader
-         return self
-
-     def get_reader(self, seed_dataset_source: SeedSource, secret_resolver: SecretResolver) -> SeedReader:
-         reader = self._get_reader_for_source(seed_dataset_source)
-         reader.attach(seed_dataset_source, secret_resolver)
-         return reader
-
-     def _get_reader_for_source(self, seed_dataset_source: SeedSource) -> SeedReader:
-         seed_type = seed_dataset_source.seed_type
-         try:
-             return self._readers[seed_type]
-         except KeyError:
-             raise SeedReaderError(f"No reader found for seed_type {seed_type!r}")
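Note: the removed get_seed_type() resolves a reader's seed type by walking __orig_bases__ to find the SeedReader[...] generic base, then reading the default of the source model's seed_type field. A self-contained sketch of that introspection; the class names below are simplified stand-ins, not the package's:

    # Generic-parameter introspection via __orig_bases__ plus pydantic model_fields.
    from typing import Generic, TypeVar, get_args, get_origin

    from pydantic import BaseModel

    class Source(BaseModel):
        seed_type: str = "local_file"
        path: str = ""

    SourceT = TypeVar("SourceT", bound=Source)

    class Reader(Generic[SourceT]):
        def seed_type(self) -> str:
            # Subclasses like Reader[Source] record their parameterized base here
            for base in getattr(type(self), "__orig_bases__", []):
                if get_origin(base) is Reader:
                    (source_cls,) = get_args(base)
                    return source_cls.model_fields["seed_type"].default
            raise TypeError("no generic source type with seed_type")

    class LocalReader(Reader[Source]):
        pass

    print(LocalReader().seed_type())  # local_file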
data_designer/engine/sampling_gen/column.py
@@ -1,91 +0,0 @@
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- # SPDX-License-Identifier: Apache-2.0
-
- from __future__ import annotations
-
- from typing import Any
-
- from pydantic import field_serializer, model_validator
- from typing_extensions import Self
-
- from data_designer.config.column_configs import SamplerColumnConfig
- from data_designer.config.sampler_params import SamplerParamsT, SamplerType
- from data_designer.engine.sampling_gen.data_sources.base import DataSource
- from data_designer.engine.sampling_gen.data_sources.sources import SamplerRegistry
- from data_designer.engine.sampling_gen.jinja_utils import extract_column_names_from_expression
-
-
- class ConditionalDataColumn(SamplerColumnConfig):
-     @property
-     def _negative_condition(self) -> str:
-         conditions = list(self.conditional_params.keys())
-         return "not (" + " or ".join([f"({c})" for c in conditions]) + ")"
-
-     @property
-     def conditions(self) -> list[str]:
-         c = list(self.conditional_params.keys())
-         return c + [self._negative_condition] if len(c) > 0 else ["..."]
-
-     @property
-     def conditional_column_names(self) -> set[str]:
-         names = set()
-         for condition in self.conditional_params.keys():
-             names.update(extract_column_names_from_expression(condition))
-         return names
-
-     @field_serializer("sampler_type")
-     def serialize_sampler_type(self, sampler_type: SamplerType) -> str:
-         return SamplerType(sampler_type).value
-
-     @field_serializer("params")
-     def serialize_params(self, params: SamplerParamsT) -> dict:
-         return params.model_dump()
-
-     @field_serializer("conditional_params")
-     def serialize_conditional_params(self, conditional_params: dict[str, SamplerParamsT]) -> dict:
-         for condition, params in conditional_params.items():
-             conditional_params[condition] = params.model_dump()
-         return conditional_params
-
-     @model_validator(mode="before")
-     @classmethod
-     def validate_params_with_type(cls, data: Any) -> Any:
-         if not isinstance(data, dict) or "sampler_type" not in data:
-             return data
-         if isinstance(data["sampler_type"], str):
-             if not SamplerRegistry.is_registered(data["sampler_type"]):
-                 raise ValueError(
-                     f"Invalid sampler type: {data['sampler_type']}. Available samplers: {[s.value for s in SamplerType]}"
-                 )
-         if "params" in data:
-             data["params"] = SamplerRegistry.get_sampler(data["sampler_type"])(params=data["params"]).params
-         if "conditional_params" in data:
-             for condition, params in data["conditional_params"].items():
-                 data["conditional_params"][condition] = SamplerRegistry.get_sampler(data["sampler_type"])(
-                     params=params
-                 ).params
-         return data
-
-     @model_validator(mode="after")
-     def validate_params(self) -> Self:
-         self.params = SamplerRegistry.validate_sampler_type(self.sampler_type)(params=self.params).params
-         return self
-
-     @model_validator(mode="after")
-     def validate_data_conversion(self) -> Self:
-         self.get_default_sampler().validate_data_conversion(self.convert_to)
-         return self
-
-     @model_validator(mode="after")
-     def validate_conditional_params(self) -> Self:
-         for condition, params in self.conditional_params.items():
-             self.conditional_params[condition] = SamplerRegistry.get_sampler(self.sampler_type)(params=params).params
-         return self
-
-     def get_default_sampler(self, **kwargs) -> DataSource:
-         return self.get_sampler("...", **kwargs)
-
-     def get_sampler(self, condition: str, **kwargs) -> DataSource:
-         if condition in ["...", self._negative_condition]:
-             return SamplerRegistry.get_sampler(self.sampler_type)(self.params, **kwargs)
-         return SamplerRegistry.get_sampler(self.sampler_type)(self.conditional_params[condition], **kwargs)
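Note: the removed _negative_condition property derives a catch-all predicate from the conditional_params keys, so rows matching none of the explicit conditions fall back to the default sampler params (addressed internally by the "..." sentinel). A minimal sketch of that string construction; the conditions shown are illustrative:

    # Building the catch-all "negative condition" from explicit conditions.
    conditional_params = {
        "age < 18": {"low": 0, "high": 10},
        "age >= 65": {"low": 50, "high": 100},
    }

    conditions = list(conditional_params.keys())
    negative = "not (" + " or ".join(f"({c})" for c in conditions) + ")"
    print(negative)              # not ((age < 18) or (age >= 65))
    print(conditions + [negative])  # all branches, catch-all last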