data-designer 0.2.3__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
Files changed (37)
  1. data_designer/_version.py +2 -2
  2. data_designer/cli/forms/model_builder.py +2 -2
  3. data_designer/config/config_builder.py +30 -113
  4. data_designer/config/errors.py +3 -0
  5. data_designer/config/exports.py +8 -6
  6. data_designer/config/models.py +7 -18
  7. data_designer/config/run_config.py +34 -0
  8. data_designer/config/seed.py +16 -46
  9. data_designer/config/seed_source.py +84 -0
  10. data_designer/config/utils/constants.py +27 -2
  11. data_designer/config/utils/io_helpers.py +0 -20
  12. data_designer/engine/column_generators/generators/seed_dataset.py +5 -5
  13. data_designer/engine/column_generators/generators/validation.py +3 -0
  14. data_designer/engine/column_generators/registry.py +1 -1
  15. data_designer/engine/compiler.py +69 -0
  16. data_designer/engine/dataset_builders/column_wise_builder.py +3 -0
  17. data_designer/engine/dataset_builders/utils/config_compiler.py +1 -1
  18. data_designer/engine/models/facade.py +2 -0
  19. data_designer/engine/processing/gsonschema/validators.py +55 -0
  20. data_designer/engine/resources/resource_provider.py +17 -5
  21. data_designer/engine/resources/seed_reader.py +149 -0
  22. data_designer/essentials/__init__.py +2 -0
  23. data_designer/interface/data_designer.py +72 -62
  24. data_designer/plugin_manager.py +1 -1
  25. data_designer/plugins/errors.py +3 -0
  26. data_designer/plugins/plugin.py +82 -12
  27. data_designer/plugins/testing/__init__.py +8 -0
  28. data_designer/plugins/testing/stubs.py +145 -0
  29. data_designer/plugins/testing/utils.py +11 -0
  30. {data_designer-0.2.3.dist-info → data_designer-0.3.1.dist-info}/METADATA +3 -3
  31. {data_designer-0.2.3.dist-info → data_designer-0.3.1.dist-info}/RECORD +35 -30
  32. data_designer/config/datastore.py +0 -187
  33. data_designer/engine/resources/seed_dataset_data_store.py +0 -84
  34. /data_designer/{config/utils → engine}/validation.py +0 -0
  35. {data_designer-0.2.3.dist-info → data_designer-0.3.1.dist-info}/WHEEL +0 -0
  36. {data_designer-0.2.3.dist-info → data_designer-0.3.1.dist-info}/entry_points.txt +0 -0
  37. {data_designer-0.2.3.dist-info → data_designer-0.3.1.dist-info}/licenses/LICENSE +0 -0
data_designer/config/utils/constants.py
@@ -282,6 +282,10 @@ OPENAI_PROVIDER_NAME = "openai"
 
 OPENAI_API_KEY_ENV_VAR_NAME = "OPENAI_API_KEY"
 
+OPENROUTER_PROVIDER_NAME = "openrouter"
+
+OPENROUTER_API_KEY_ENV_VAR_NAME = "OPENROUTER_API_KEY"
+
 PREDEFINED_PROVIDERS = [
     {
         "name": NVIDIA_PROVIDER_NAME,
@@ -295,6 +299,12 @@ PREDEFINED_PROVIDERS = [
         "provider_type": "openai",
         "api_key": OPENAI_API_KEY_ENV_VAR_NAME,
     },
+    {
+        "name": OPENROUTER_PROVIDER_NAME,
+        "endpoint": "https://openrouter.ai/api/v1",
+        "provider_type": "openai",
+        "api_key": OPENROUTER_API_KEY_ENV_VAR_NAME,
+    },
 ]
 
 
@@ -302,11 +312,14 @@ DEFAULT_TEXT_INFERENCE_PARAMS = {"temperature": 0.85, "top_p": 0.95}
 DEFAULT_REASONING_INFERENCE_PARAMS = {"temperature": 0.35, "top_p": 0.95}
 DEFAULT_VISION_INFERENCE_PARAMS = {"temperature": 0.85, "top_p": 0.95}
 DEFAULT_EMBEDDING_INFERENCE_PARAMS = {"encoding_format": "float"}
-
+NEMOTRON_3_NANO_30B_A3B_INFERENCE_PARAMS = {"temperature": 1.0, "top_p": 1.0}
 
 PREDEFINED_PROVIDERS_MODEL_MAP = {
     NVIDIA_PROVIDER_NAME: {
-        "text": {"model": "nvidia/nemotron-3-nano-30b-a3b", "inference_parameters": {"temperature": 1.0, "top_p": 1.0}},
+        "text": {
+            "model": "nvidia/nemotron-3-nano-30b-a3b",
+            "inference_parameters": NEMOTRON_3_NANO_30B_A3B_INFERENCE_PARAMS,
+        },
         "reasoning": {"model": "openai/gpt-oss-20b", "inference_parameters": DEFAULT_REASONING_INFERENCE_PARAMS},
         "vision": {"model": "nvidia/nemotron-nano-12b-v2-vl", "inference_parameters": DEFAULT_VISION_INFERENCE_PARAMS},
         "embedding": {
@@ -320,6 +333,18 @@ PREDEFINED_PROVIDERS_MODEL_MAP = {
         "vision": {"model": "gpt-5", "inference_parameters": DEFAULT_VISION_INFERENCE_PARAMS},
         "embedding": {"model": "text-embedding-3-large", "inference_parameters": DEFAULT_EMBEDDING_INFERENCE_PARAMS},
     },
+    OPENROUTER_PROVIDER_NAME: {
+        "text": {
+            "model": "nvidia/nemotron-3-nano-30b-a3b",
+            "inference_parameters": NEMOTRON_3_NANO_30B_A3B_INFERENCE_PARAMS,
+        },
+        "reasoning": {"model": "openai/gpt-oss-20b", "inference_parameters": DEFAULT_REASONING_INFERENCE_PARAMS},
+        "vision": {"model": "nvidia/nemotron-nano-12b-v2-vl", "inference_parameters": DEFAULT_VISION_INFERENCE_PARAMS},
+        "embedding": {
+            "model": "openai/text-embedding-3-large",
+            "inference_parameters": DEFAULT_EMBEDDING_INFERENCE_PARAMS,
+        },
+    },
 }
 
 # Persona locale metadata - used by the CLI and the person sampler.
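The OpenRouter entry reuses the OpenAI-compatible provider type, so it only needs the endpoint above plus the OPENROUTER_API_KEY environment variable. For illustration only (not part of the diff, and not how data-designer wires providers internally), a minimal sketch that points the openai package at the same endpoint and default text model:

    import os

    from openai import OpenAI  # assumes the `openai` package is installed

    # Mirrors the new PREDEFINED_PROVIDERS entry: an OpenAI-compatible API at openrouter.ai,
    # authenticated via the OPENROUTER_API_KEY environment variable.
    client = OpenAI(
        base_url="https://openrouter.ai/api/v1",
        api_key=os.environ["OPENROUTER_API_KEY"],
    )

    response = client.chat.completions.create(
        model="nvidia/nemotron-3-nano-30b-a3b",  # default "text" model for this provider
        messages=[{"role": "user", "content": "Say hello."}],
        temperature=1.0,  # NEMOTRON_3_NANO_30B_A3B_INFERENCE_PARAMS
        top_p=1.0,
    )
    print(response.choices[0].message.content)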
data_designer/config/utils/io_helpers.py
@@ -108,26 +108,6 @@ def read_parquet_dataset(path: Path) -> pd.DataFrame:
         raise e
 
 
-def write_seed_dataset(dataframe: pd.DataFrame, file_path: Path) -> None:
-    """Write a seed dataset to a file in the specified format.
-
-    Supported file extensions: .parquet, .csv, .json, .jsonl
-
-    Args:
-        dataframe: The pandas DataFrame to write.
-        file_path: The path where the dataset should be saved.
-            Format is inferred from the file extension.
-    """
-    file_path = validate_dataset_file_path(file_path, should_exist=False)
-    logger.info(f"💾 Saving seed dataset to {file_path}")
-    if file_path.suffix.lower() == ".parquet":
-        dataframe.to_parquet(file_path, index=False)
-    elif file_path.suffix.lower() == ".csv":
-        dataframe.to_csv(file_path, index=False)
-    elif file_path.suffix.lower() in {".json", ".jsonl"}:
-        dataframe.to_json(file_path, orient="records", lines=True)
-
-
 def validate_dataset_file_path(file_path: str | Path, should_exist: bool = True) -> Path:
     """Validate that a dataset file path has a valid extension and optionally exists.
 
data_designer/engine/column_generators/generators/seed_dataset.py
@@ -30,7 +30,7 @@ class SeedDatasetColumnGenerator(FromScratchColumnGenerator[SeedDatasetMultiColu
         name="seed_dataset_column_generator",
         description="Sample columns from a seed dataset.",
         generation_strategy=GenerationStrategy.FULL_COLUMN,
-        required_resources=[ResourceType.DATASTORE],
+        required_resources=[ResourceType.SEED_READER],
     )
 
     @property
@@ -39,10 +39,10 @@ class SeedDatasetColumnGenerator(FromScratchColumnGenerator[SeedDatasetMultiColu
 
     @functools.cached_property
    def duckdb_conn(self) -> duckdb.DuckDBPyConnection:
-        return self.resource_provider.datastore.create_duckdb_connection()
+        return self.resource_provider.seed_reader.create_duckdb_connection()
 
-    def generate(self, dataset: pd.DataFrame) -> pd.DataFrame:
-        return concat_datasets([self.generate_from_scratch(len(dataset)), dataset])
+    def generate(self, data: pd.DataFrame) -> pd.DataFrame:
+        return concat_datasets([self.generate_from_scratch(len(data)), data])
 
     def generate_from_scratch(self, num_records: int) -> pd.DataFrame:
         if num_records <= 0:
@@ -57,7 +57,7 @@ class SeedDatasetColumnGenerator(FromScratchColumnGenerator[SeedDatasetMultiColu
         self._num_records_sampled = 0
         self._batch_reader = None
         self._df_remaining = None
-        self._dataset_uri = self.resource_provider.datastore.get_dataset_uri(self.config.dataset)
+        self._dataset_uri = self.resource_provider.seed_reader.get_dataset_uri()
         self._seed_dataset_size = self.duckdb_conn.execute(f"SELECT COUNT(*) FROM '{self._dataset_uri}'").fetchone()[0]
         self._index_range = self._resolve_index_range()
 
data_designer/engine/column_generators/generators/validation.py
@@ -123,11 +123,14 @@ class ValidationColumnGenerator(ColumnGenerator[ValidationColumnConfig]):
         def error_callback(error: Exception, context: dict):
             outputs[context["index"]] = ValidationResult.empty(size=len(batched_records[context["index"]]))
 
+        settings = self.resource_provider.run_config
         with ConcurrentThreadExecutor(
             max_workers=self.config.validator_params.max_parallel_requests,
             column_name=self.config.name,
             result_callback=result_callback,
             error_callback=error_callback,
+            shutdown_error_rate=settings.shutdown_error_rate,
+            shutdown_error_window=settings.shutdown_error_window,
         ) as executor:
             for i, batch in enumerate(batched_records):
                 executor.submit(lambda batch: self._validate_batch(validator, batch), batch, context={"index": i})
data_designer/engine/column_generators/registry.py
@@ -51,7 +51,7 @@ def create_default_column_generator_registry(with_plugins: bool = True) -> Colum
     for plugin in PluginRegistry().get_plugins(PluginType.COLUMN_GENERATOR):
         registry.register(
             DataDesignerColumnType(plugin.name),
-            plugin.task_cls,
+            plugin.impl_cls,
             plugin.config_cls,
         )
 
data_designer/engine/compiler.py
@@ -0,0 +1,69 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+
+from data_designer.config.column_configs import SeedDatasetColumnConfig
+from data_designer.config.config_builder import DataDesignerConfigBuilder
+from data_designer.config.data_designer_config import DataDesignerConfig
+from data_designer.config.errors import InvalidConfigError
+from data_designer.engine.resources.resource_provider import ResourceProvider
+from data_designer.engine.resources.seed_reader import SeedReader
+from data_designer.engine.validation import ViolationLevel, rich_print_violations, validate_data_designer_config
+
+logger = logging.getLogger(__name__)
+
+
+def compile_data_designer_config(
+    config_builder: DataDesignerConfigBuilder, resource_provider: ResourceProvider
+) -> DataDesignerConfig:
+    config = config_builder.build()
+    _resolve_and_add_seed_columns(config, resource_provider.seed_reader)
+    _validate(config)
+
+    return config
+
+
+def _resolve_and_add_seed_columns(config: DataDesignerConfig, seed_reader: SeedReader | None) -> None:
+    """Fetches the seed dataset column names, ensures there are no conflicts
+    with other columns, and adds seed column configs to the DataDesignerConfig.
+    """
+
+    if not seed_reader:
+        return
+
+    seed_col_names = seed_reader.get_column_names()
+    existing_columns = {column.name for column in config.columns}
+    colliding_columns = {name for name in seed_col_names if name in existing_columns}
+    if colliding_columns:
+        raise InvalidConfigError(
+            f"🛑 Seed dataset column(s) {colliding_columns} collide with existing column(s). "
+            "Please remove the conflicting columns or use a seed dataset with different column names."
+        )
+
+    config.columns.extend([SeedDatasetColumnConfig(name=col_name) for col_name in seed_col_names])
+
+
+def _validate(config: DataDesignerConfig) -> None:
+    allowed_references = _get_allowed_references(config)
+    violations = validate_data_designer_config(
+        columns=config.columns,
+        processor_configs=config.processors or [],
+        allowed_references=allowed_references,
+    )
+    rich_print_violations(violations)
+    if len([v for v in violations if v.level == ViolationLevel.ERROR]) > 0:
+        raise InvalidConfigError(
+            "🛑 Your configuration contains validation errors. Please address the indicated issues and try again."
+        )
+    if len(violations) == 0:
+        logger.info("✅ Validation passed")
+
+
+def _get_allowed_references(config: DataDesignerConfig) -> list[str]:
+    refs = set[str]()
+    for column_config in config.columns:
+        refs.add(column_config.name)
+        for side_effect_column in column_config.side_effect_columns:
+            refs.add(side_effect_column)
+    return list(refs)
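compile_data_designer_config is the new single entry point that builds the config, folds seed-dataset columns in, and validates before generation. A hedged usage sketch (only the function and InvalidConfigError come from this diff; the builder and resource provider are assumed to be constructed elsewhere):

    from data_designer.config.errors import InvalidConfigError
    from data_designer.engine.compiler import compile_data_designer_config


    # Hypothetical call site: `builder` is a DataDesignerConfigBuilder and `resources` a
    # ResourceProvider (optionally carrying a seed_reader); their setup is not shown here.
    def compile_or_report(builder, resources):
        try:
            return compile_data_designer_config(builder, resources)
        except InvalidConfigError as err:
            # Raised when seed columns collide with configured columns, or when
            # validation reports ERROR-level violations.
            print(f"Config rejected: {err}")
            raise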
data_designer/engine/dataset_builders/column_wise_builder.py
@@ -217,11 +217,14 @@ class ColumnWiseDatasetBuilder:
             f"🐙 Processing {generator.config.column_type} column '{generator.config.name}' "
             f"with {max_workers} concurrent workers"
         )
+        settings = self._resource_provider.run_config
         with ConcurrentThreadExecutor(
             max_workers=max_workers,
             column_name=generator.config.name,
             result_callback=self._worker_result_callback,
             error_callback=self._worker_error_callback,
+            shutdown_error_rate=settings.shutdown_error_rate,
+            shutdown_error_window=settings.shutdown_error_window,
         ) as executor:
             for i, record in self.batch_manager.iter_current_batch():
                 executor.submit(lambda record: generator.generate(record), record, context={"index": i})
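Both this builder and the validation column generator above now read shutdown thresholds from resource_provider.run_config. The new data_designer/config/run_config.py (+34 lines) is not shown in this diff; a plausible sketch of the two fields referenced here, with the base class and defaults assumed:

    from pydantic import BaseModel


    # Hypothetical reconstruction; only the two field names are confirmed by this diff.
    class RunConfig(BaseModel):
        # Fraction of recently failed submissions that triggers an early executor shutdown.
        shutdown_error_rate: float = 0.5
        # Number of most recent submissions considered when computing that error rate.
        shutdown_error_window: int = 20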
data_designer/engine/dataset_builders/utils/config_compiler.py
@@ -34,7 +34,7 @@ def compile_dataset_builder_column_configs(config: DataDesignerConfig) -> list[D
         compiled_column_configs.append(
             SeedDatasetMultiColumnConfig(
                 columns=seed_column_configs,
-                dataset=config.seed_config.dataset,
+                source=config.seed_config.source,
                 sampling_strategy=config.seed_config.sampling_strategy,
                 selection_strategy=config.seed_config.selection_strategy,
             )
data_designer/engine/models/facade.py
@@ -96,6 +96,8 @@ class ModelFacade:
         kwargs = {**self._model_config.inference_parameters.generate_kwargs, **kwargs}
         if self.model_provider.extra_body:
             kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body}
+        if self.model_provider.extra_headers:
+            kwargs["extra_headers"] = self.model_provider.extra_headers
         return kwargs
 
     @catch_llm_exceptions
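With this change, provider-level extra_headers are forwarded to the underlying client alongside extra_body. A standalone sketch of the same merge with stand-in values (illustrative only; the surrounding facade code is unchanged):

    # Stand-in provider settings; real values come from the model provider config.
    extra_body = {"provider": {"order": ["nvidia"]}}
    extra_headers = {"HTTP-Referer": "https://example.com"}

    kwargs = {"temperature": 0.85, "top_p": 0.95}
    if extra_body:
        kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **extra_body}
    if extra_headers:
        kwargs["extra_headers"] = extra_headers

    print(kwargs)  # both payload extensions and per-request headers reach the OpenAI-compatible client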
data_designer/engine/processing/gsonschema/validators.py
@@ -2,7 +2,9 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import logging
+import re
 from copy import deepcopy
+from decimal import ROUND_HALF_UP, Decimal
 from typing import Any, overload
 
 from jsonschema import Draft202012Validator, ValidationError, validators
@@ -70,6 +72,57 @@ def extend_jsonschema_validator_with_pruning(validator):
     return validators.extend(validator, {"additionalProperties": prune_additional_properties})
 
 
+def _get_decimal_info_from_anyof(schema: dict) -> tuple[bool, int | None]:
+    """Check if schema is a Decimal anyOf and extract decimal places.
+
+    Returns (is_decimal, decimal_places) where decimal_places is None if no constraint.
+    """
+    any_of = schema.get("anyOf")
+    if not isinstance(any_of, list):
+        return False, None
+
+    has_number = any(item.get("type") == "number" for item in any_of)
+    if not has_number:
+        return False, None
+
+    for item in any_of:
+        if item.get("type") == "string" and "pattern" in item:
+            match = re.search(r"\\d\{0,(\d+)\}", item["pattern"])
+            if match:
+                return True, int(match.group(1))
+            return True, None  # Decimal without precision constraint
+    return False, None
+
+
+def normalize_decimal_fields(obj: DataObjectT, schema: JSONSchemaT) -> DataObjectT:
+    """Normalize Decimal-like anyOf fields to floats with proper precision."""
+    if not isinstance(obj, dict):
+        return obj
+
+    defs = schema.get("$defs", {})
+    obj_schema = defs.get(schema.get("$ref", "")[len("#/$defs/") :], schema)
+    props = obj_schema.get("properties", {})
+
+    for key, value in obj.items():
+        field_schema = props.get(key, {})
+        if "$ref" in field_schema:
+            field_schema = defs.get(field_schema["$ref"][len("#/$defs/") :], {})
+
+        if isinstance(value, dict):
+            obj[key] = normalize_decimal_fields(value, schema)
+        elif isinstance(value, list):
+            obj[key] = [normalize_decimal_fields(v, schema) if isinstance(v, dict) else v for v in value]
+        elif isinstance(value, (int, float, str)) and not isinstance(value, bool):
+            is_decimal, decimal_places = _get_decimal_info_from_anyof(field_schema)
+            if is_decimal:
+                d = Decimal(str(value))
+                if decimal_places is not None:
+                    d = d.quantize(Decimal(f"0.{'0' * decimal_places}"), rounding=ROUND_HALF_UP)
+                obj[key] = float(d)
+
+    return obj
+
+
 ## We don't expect the outer data type (e.g. dict, list, or const) to be
 ## modified by the pruning action.
 @overload
@@ -140,4 +193,6 @@ def validate(
     except ValidationError as exc:
         raise JSONSchemaValidationError(str(exc)) from exc
 
+    final_object = normalize_decimal_fields(final_object, schema)
+
     return final_object
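A worked example of the new decimal normalization: Pydantic typically renders a Decimal field as an anyOf of a number and a pattern-constrained string, and the \d{0,N} group in that pattern supplies the rounding precision. The schema below is an illustrative assumption about that shape, not taken from this diff:

    from data_designer.engine.processing.gsonschema.validators import normalize_decimal_fields

    # Roughly the JSON schema emitted for a two-decimal-place Decimal field (assumed shape).
    schema = {
        "type": "object",
        "properties": {
            "price": {
                "anyOf": [
                    {"type": "number"},
                    {"type": "string", "pattern": "^\\d+(\\.\\d{0,2})?$"},
                ]
            }
        },
    }

    # 19.996 is quantized to two places with ROUND_HALF_UP and returned as a float.
    print(normalize_decimal_fields({"price": 19.996}, schema))  # {'price': 20.0}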
data_designer/engine/resources/resource_provider.py
@@ -3,26 +3,29 @@
 
 from data_designer.config.base import ConfigBase
 from data_designer.config.models import ModelConfig
+from data_designer.config.run_config import RunConfig
+from data_designer.config.seed_source import SeedSource
 from data_designer.config.utils.type_helpers import StrEnum
 from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
 from data_designer.engine.model_provider import ModelProviderRegistry
 from data_designer.engine.models.registry import ModelRegistry, create_model_registry
 from data_designer.engine.resources.managed_storage import ManagedBlobStorage, init_managed_blob_storage
-from data_designer.engine.resources.seed_dataset_data_store import SeedDatasetDataStore
+from data_designer.engine.resources.seed_reader import SeedReader, SeedReaderRegistry
 from data_designer.engine.secret_resolver import SecretResolver
 
 
 class ResourceType(StrEnum):
     BLOB_STORAGE = "blob_storage"
-    DATASTORE = "datastore"
     MODEL_REGISTRY = "model_registry"
+    SEED_READER = "seed_reader"
 
 
 class ResourceProvider(ConfigBase):
     artifact_storage: ArtifactStorage
     blob_storage: ManagedBlobStorage | None = None
-    datastore: SeedDatasetDataStore | None = None
     model_registry: ModelRegistry | None = None
+    run_config: RunConfig = RunConfig()
+    seed_reader: SeedReader | None = None
 
 
 def create_resource_provider(
@@ -31,16 +34,25 @@
     model_configs: list[ModelConfig],
     secret_resolver: SecretResolver,
     model_provider_registry: ModelProviderRegistry,
-    datastore: SeedDatasetDataStore | None = None,
+    seed_reader_registry: SeedReaderRegistry,
     blob_storage: ManagedBlobStorage | None = None,
+    seed_dataset_source: SeedSource | None = None,
+    run_config: RunConfig | None = None,
 ) -> ResourceProvider:
+    seed_reader = None
+    if seed_dataset_source:
+        seed_reader = seed_reader_registry.get_reader(
+            seed_dataset_source,
+            secret_resolver,
+        )
     return ResourceProvider(
         artifact_storage=artifact_storage,
-        datastore=datastore,
         model_registry=create_model_registry(
            model_configs=model_configs,
            secret_resolver=secret_resolver,
            model_provider_registry=model_provider_registry,
        ),
        blob_storage=blob_storage or init_managed_blob_storage(),
+        seed_reader=seed_reader,
+        run_config=run_config or RunConfig(),
     )
data_designer/engine/resources/seed_reader.py
@@ -0,0 +1,149 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from abc import ABC, abstractmethod
+from collections.abc import Sequence
+from typing import Generic, TypeVar, get_args, get_origin
+
+import duckdb
+from huggingface_hub import HfFileSystem
+from typing_extensions import Self
+
+from data_designer.config.seed_source import (
+    DataFrameSeedSource,
+    HuggingFaceSeedSource,
+    LocalFileSeedSource,
+    SeedSource,
+)
+from data_designer.engine.secret_resolver import SecretResolver
+from data_designer.errors import DataDesignerError
+
+
+class SeedReaderError(DataDesignerError): ...
+
+
+SourceT = TypeVar("ConfigT", bound=SeedSource)
+
+
+class SeedReader(ABC, Generic[SourceT]):
+    """Base class for reading a seed dataset.
+
+    Seeds are read using duckdb. Reader implementations define duckdb connection setup details
+    and how to get a URI that can be queried with duckdb (i.e. "... FROM <uri> ...").
+
+    The Data Designer engine automatically supplies the appropriate SeedSource
+    and a SecretResolver to use for any secret fields in the config.
+    """
+
+    source: SourceT
+    secret_resolver: SecretResolver
+
+    @abstractmethod
+    def get_dataset_uri(self) -> str: ...
+
+    @abstractmethod
+    def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection: ...
+
+    def attach(self, source: SourceT, secret_resolver: SecretResolver):
+        """Attach a source and secret resolver to the instance.
+
+        This is called internally by the engine so that these objects do not
+        need to be provided in the reader's constructor.
+        """
+        self.source = source
+        self.secret_resolver = secret_resolver
+
+    def get_column_names(self) -> list[str]:
+        """Returns the seed dataset's column names"""
+        conn = self.create_duckdb_connection()
+        describe_query = f"DESCRIBE SELECT * FROM '{self.get_dataset_uri()}'"
+        column_descriptions = conn.execute(describe_query).fetchall()
+        return [col[0] for col in column_descriptions]
+
+    def get_seed_type(self) -> str:
+        """Return the seed_type of the source class this reader is generic over."""
+        # Get the generic type arguments from the reader class
+        # Check __orig_bases__ for the generic base class
+        for base in getattr(type(self), "__orig_bases__", []):
+            origin = get_origin(base)
+            if origin is SeedReader:
+                args = get_args(base)
+                if args:
+                    source_cls = args[0]
+                    # Extract seed_type from the source class
+                    if hasattr(source_cls, "model_fields") and "seed_type" in source_cls.model_fields:
+                        field = source_cls.model_fields["seed_type"]
+                        default_value = field.default
+                        if isinstance(default_value, str):
+                            return default_value
+
+        raise SeedReaderError("Reader does not have a valid generic source type with seed_type")
+
+
+class LocalFileSeedReader(SeedReader[LocalFileSeedSource]):
+    def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection:
+        return duckdb.connect()
+
+    def get_dataset_uri(self) -> str:
+        return self.source.path
+
+
+class HuggingFaceSeedReader(SeedReader[HuggingFaceSeedSource]):
+    def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection:
+        token = self.secret_resolver.resolve(self.source.token) if self.source.token else None
+
+        # Use skip_instance_cache to avoid fsspec-level caching
+        hffs = HfFileSystem(endpoint=self.source.endpoint, token=token, skip_instance_cache=True)
+
+        # Clear all internal caches to avoid stale metadata issues
+        # HfFileSystem caches file metadata (size, etc.) which can become stale when files are re-uploaded
+        if hasattr(hffs, "dircache"):
+            hffs.dircache.clear()
+
+        conn = duckdb.connect()
+        conn.register_filesystem(hffs)
+        return conn
+
+    def get_dataset_uri(self) -> str:
+        return f"hf://{self.source.path}"
+
+
+class DataFrameSeedReader(SeedReader[DataFrameSeedSource]):
+    # This is a "magic string" that gets registered in the duckdb connection to make the dataframe directly queryable.
+    _table_name = "df"
+
+    def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection:
+        conn = duckdb.connect()
+        conn.register(self._table_name, self.source.df)
+        return conn
+
+    def get_dataset_uri(self) -> str:
+        return self._table_name
+
+
+class SeedReaderRegistry:
+    def __init__(self, readers: Sequence[SeedReader]):
+        self._readers: dict[str, SeedReader] = {}
+        for reader in readers:
+            self.add_reader(reader)
+
+    def add_reader(self, reader: SeedReader) -> Self:
+        seed_type = reader.get_seed_type()
+
+        if seed_type in self._readers:
+            raise SeedReaderError(f"A reader for seed_type {seed_type!r} already exists")
+
+        self._readers[seed_type] = reader
+        return self
+
+    def get_reader(self, seed_dataset_source: SeedSource, secret_resolver: SecretResolver) -> SeedReader:
+        reader = self._get_reader_for_source(seed_dataset_source)
+        reader.attach(seed_dataset_source, secret_resolver)
+        return reader
+
+    def _get_reader_for_source(self, seed_dataset_source: SeedSource) -> SeedReader:
+        seed_type = seed_dataset_source.seed_type
+        try:
+            return self._readers[seed_type]
+        except KeyError:
+            raise SeedReaderError(f"No reader found for seed_type {seed_type!r}")
data_designer/essentials/__init__.py
@@ -3,6 +3,7 @@
 
 from data_designer.config.default_model_settings import resolve_seed_default_model_settings
 from data_designer.config.exports import *  # noqa: F403
+from data_designer.config.run_config import RunConfig
 from data_designer.config.validator_params import LocalCallableValidatorParams
 from data_designer.interface.data_designer import DataDesigner
 from data_designer.logging import LoggingConfig, configure_logging
@@ -21,6 +22,7 @@ def get_essentials_exports() -> list[str]:
     local = [
         DataDesigner.__name__,
         LocalCallableValidatorParams.__name__,
+        RunConfig.__name__,
     ]
 
     return logging + local + get_config_exports()  # noqa: F405