data-designer 0.2.3__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. data_designer/_version.py +2 -2
  2. data_designer/cli/forms/model_builder.py +2 -2
  3. data_designer/config/config_builder.py +30 -113
  4. data_designer/config/errors.py +3 -0
  5. data_designer/config/exports.py +8 -6
  6. data_designer/config/models.py +7 -18
  7. data_designer/config/run_config.py +34 -0
  8. data_designer/config/seed.py +16 -46
  9. data_designer/config/seed_source.py +73 -0
  10. data_designer/config/utils/constants.py +27 -2
  11. data_designer/config/utils/io_helpers.py +0 -20
  12. data_designer/engine/column_generators/generators/seed_dataset.py +5 -5
  13. data_designer/engine/column_generators/generators/validation.py +3 -0
  14. data_designer/engine/column_generators/registry.py +1 -1
  15. data_designer/engine/compiler.py +69 -0
  16. data_designer/engine/dataset_builders/column_wise_builder.py +3 -0
  17. data_designer/engine/dataset_builders/utils/config_compiler.py +1 -1
  18. data_designer/engine/models/facade.py +2 -0
  19. data_designer/engine/processing/gsonschema/validators.py +55 -0
  20. data_designer/engine/resources/resource_provider.py +17 -5
  21. data_designer/engine/resources/seed_reader.py +149 -0
  22. data_designer/essentials/__init__.py +2 -0
  23. data_designer/interface/data_designer.py +72 -62
  24. data_designer/plugin_manager.py +1 -1
  25. data_designer/plugins/errors.py +3 -0
  26. data_designer/plugins/plugin.py +82 -12
  27. data_designer/plugins/testing/__init__.py +8 -0
  28. data_designer/plugins/testing/stubs.py +145 -0
  29. data_designer/plugins/testing/utils.py +11 -0
  30. {data_designer-0.2.3.dist-info → data_designer-0.3.0.dist-info}/METADATA +3 -3
  31. {data_designer-0.2.3.dist-info → data_designer-0.3.0.dist-info}/RECORD +35 -30
  32. data_designer/config/datastore.py +0 -187
  33. data_designer/engine/resources/seed_dataset_data_store.py +0 -84
  34. /data_designer/{config/utils → engine}/validation.py +0 -0
  35. {data_designer-0.2.3.dist-info → data_designer-0.3.0.dist-info}/WHEEL +0 -0
  36. {data_designer-0.2.3.dist-info → data_designer-0.3.0.dist-info}/entry_points.txt +0 -0
  37. {data_designer-0.2.3.dist-info → data_designer-0.3.0.dist-info}/licenses/LICENSE +0 -0
data_designer/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.2.3'
32
- __version_tuple__ = version_tuple = (0, 2, 3)
31
+ __version__ = version = '0.3.0'
32
+ __version_tuple__ = version_tuple = (0, 3, 0)
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -125,11 +125,11 @@ class ModelFormBuilder(FormBuilder[ModelConfig]):
125
125
  fields.append(
126
126
  NumericField(
127
127
  "max_tokens",
128
- "Max tokens <dim>(maximum total tokens including input and output)</dim>",
128
+ "Max tokens <dim>(maximum tokens to generate in response)</dim>",
129
129
  default=initial_params.get("max_tokens"),
130
130
  min_value=1.0,
131
131
  required=False,
132
- help_text="Maximum number of tokens including both input prompt and generated response",
132
+ help_text="Maximum number of tokens to generate in the response",
133
133
  )
134
134
  )
135
135
 
@@ -24,9 +24,8 @@ from data_designer.config.column_types import (
24
24
  )
25
25
  from data_designer.config.data_designer_config import DataDesignerConfig
26
26
  from data_designer.config.dataset_builders import BuildStage
27
- from data_designer.config.datastore import DatastoreSettings, fetch_seed_dataset_column_names
28
27
  from data_designer.config.default_model_settings import get_default_model_configs
29
- from data_designer.config.errors import BuilderConfigurationError, InvalidColumnTypeError, InvalidConfigError
28
+ from data_designer.config.errors import BuilderConfigurationError, BuilderSerializationError, InvalidColumnTypeError
30
29
  from data_designer.config.models import ModelConfig, load_model_configs
31
30
  from data_designer.config.processors import ProcessorConfigT, ProcessorType, get_processor_config_from_kwargs
32
31
  from data_designer.config.sampler_constraints import (
@@ -36,20 +35,17 @@ from data_designer.config.sampler_constraints import (
36
35
  ScalarInequalityConstraint,
37
36
  )
38
37
  from data_designer.config.seed import (
39
- DatastoreSeedDatasetReference,
40
38
  IndexRange,
41
- LocalSeedDatasetReference,
42
39
  PartitionBlock,
43
40
  SamplingStrategy,
44
41
  SeedConfig,
45
- SeedDatasetReference,
46
42
  )
43
+ from data_designer.config.seed_source import DataFrameSeedSource, SeedSource
47
44
  from data_designer.config.utils.constants import DEFAULT_REPR_HTML_STYLE, REPR_HTML_TEMPLATE
48
45
  from data_designer.config.utils.info import ConfigBuilderInfo
49
46
  from data_designer.config.utils.io_helpers import serialize_data, smart_load_yaml
50
47
  from data_designer.config.utils.misc import can_run_data_designer_locally, json_indent_list_of_strings, kebab_to_snake
51
48
  from data_designer.config.utils.type_helpers import resolve_string_enum
52
- from data_designer.config.utils.validation import ViolationLevel, rich_print_violations, validate_data_designer_config
53
49
 
54
50
  logger = logging.getLogger(__name__)
55
51
 
@@ -63,12 +59,9 @@ class BuilderConfig(ExportableConfigBase):
63
59
  Attributes:
64
60
  data_designer: The main Data Designer configuration containing columns,
65
61
  constraints, profilers, and other settings.
66
- datastore_settings: Optional datastore settings for accessing external
67
- datasets.
68
62
  """
69
63
 
70
64
  data_designer: DataDesignerConfig
71
- datastore_settings: DatastoreSettings | None
72
65
 
73
66
 
74
67
  class DataDesignerConfigBuilder:
@@ -101,30 +94,19 @@ class DataDesignerConfigBuilder:
101
94
  builder_config = BuilderConfig.model_validate(json_config)
102
95
 
103
96
  builder = cls(model_configs=builder_config.data_designer.model_configs)
104
- config = builder_config.data_designer
97
+ data_designer_config = builder_config.data_designer
105
98
 
106
- for col in config.columns:
99
+ for col in data_designer_config.columns:
107
100
  builder.add_column(col)
108
101
 
109
- for constraint in config.constraints or []:
102
+ for constraint in data_designer_config.constraints or []:
110
103
  builder.add_constraint(constraint=constraint)
111
104
 
112
- if config.seed_config:
113
- if builder_config.datastore_settings is None:
114
- if can_run_data_designer_locally():
115
- seed_dataset_reference = LocalSeedDatasetReference(dataset=config.seed_config.dataset)
116
- else:
117
- raise BuilderConfigurationError("🛑 Datastore settings are required.")
118
- else:
119
- seed_dataset_reference = DatastoreSeedDatasetReference(
120
- dataset=config.seed_config.dataset,
121
- datastore_settings=builder_config.datastore_settings,
122
- )
123
- builder.set_seed_datastore_settings(builder_config.datastore_settings)
105
+ if (seed_config := data_designer_config.seed_config) is not None:
124
106
  builder.with_seed_dataset(
125
- seed_dataset_reference,
126
- sampling_strategy=config.seed_config.sampling_strategy,
127
- selection_strategy=config.seed_config.selection_strategy,
107
+ seed_config.source,
108
+ sampling_strategy=seed_config.sampling_strategy,
109
+ selection_strategy=seed_config.selection_strategy,
128
110
  )
129
111
 
130
112
  return builder
@@ -144,7 +126,6 @@ class DataDesignerConfigBuilder:
144
126
  self._seed_config: SeedConfig | None = None
145
127
  self._constraints: list[ColumnConstraintT] = []
146
128
  self._profilers: list[ColumnProfilerConfigT] = []
147
- self._datastore_settings: DatastoreSettings | None = None
148
129
 
149
130
  @property
150
131
  def model_configs(self) -> list[ModelConfig]:
@@ -243,13 +224,6 @@ class DataDesignerConfigBuilder:
243
224
  f"{', '.join([t.__name__ for t in allowed_column_configs])}"
244
225
  )
245
226
 
246
- existing_config = self._column_configs.get(column_config.name)
247
- if existing_config is not None and isinstance(existing_config, SeedDatasetColumnConfig):
248
- raise BuilderConfigurationError(
249
- f"🛑 Column {column_config.name!r} already exists as a seed dataset column. "
250
- "Please use a different column name or update the seed dataset."
251
- )
252
-
253
227
  self._column_configs[column_config.name] = column_config
254
228
  return self
255
229
 
@@ -371,19 +345,12 @@ class DataDesignerConfigBuilder:
371
345
  """
372
346
  return self._profilers
373
347
 
374
- def build(self, *, skip_validation: bool = False, raise_exceptions: bool = False) -> DataDesignerConfig:
348
+ def build(self) -> DataDesignerConfig:
375
349
  """Build a DataDesignerConfig instance based on the current builder configuration.
376
350
 
377
- Args:
378
- skip_validation: Whether to skip validation of the configuration.
379
- raise_exceptions: Whether to raise an exception if the configuration is invalid.
380
-
381
351
  Returns:
382
352
  The current Data Designer config object.
383
353
  """
384
- if not skip_validation:
385
- self.validate(raise_exceptions=raise_exceptions)
386
-
387
354
  return DataDesignerConfig(
388
355
  model_configs=self._model_configs,
389
356
  seed_config=self._seed_config,
@@ -512,14 +479,6 @@ class DataDesignerConfigBuilder:
512
479
  """
513
480
  return self._seed_config
514
481
 
515
- def get_seed_datastore_settings(self) -> DatastoreSettings | None:
516
- """Get most recent datastore settings for the current Data Designer configuration.
517
-
518
- Returns:
519
- The datastore settings if configured, None otherwise.
520
- """
521
- return None if not self._datastore_settings else DatastoreSettings.model_validate(self._datastore_settings)
522
-
523
482
  def num_columns_of_type(self, column_type: DataDesignerColumnType) -> int:
524
483
  """Get the count of columns of the specified type.
525
484
 
@@ -531,85 +490,33 @@ class DataDesignerConfigBuilder:
531
490
  """
532
491
  return len(self.get_columns_of_type(column_type))
533
492
 
534
- def set_seed_datastore_settings(self, datastore_settings: DatastoreSettings | None) -> Self:
535
- """Set the datastore settings for the seed dataset.
536
-
537
- Args:
538
- datastore_settings: The datastore settings to use for the seed dataset.
539
- """
540
- self._datastore_settings = datastore_settings
541
- return self
542
-
543
- def validate(self, *, raise_exceptions: bool = False) -> Self:
544
- """Validate the current Data Designer configuration.
545
-
546
- Args:
547
- raise_exceptions: Whether to raise an exception if the configuration is invalid.
548
-
549
- Returns:
550
- The current Data Designer config builder instance.
551
-
552
- Raises:
553
- InvalidConfigError: If the configuration is invalid and raise_exceptions is True.
554
- """
555
-
556
- violations = validate_data_designer_config(
557
- columns=list(self._column_configs.values()),
558
- processor_configs=self._processor_configs,
559
- allowed_references=self.allowed_references,
560
- )
561
- rich_print_violations(violations)
562
- if raise_exceptions and len([v for v in violations if v.level == ViolationLevel.ERROR]) > 0:
563
- raise InvalidConfigError(
564
- "🛑 Your configuration contains validation errors. Please address the indicated issues and try again."
565
- )
566
- if len(violations) == 0:
567
- logger.info("✅ Validation passed")
568
- return self
569
-
570
493
  def with_seed_dataset(
571
494
  self,
572
- dataset_reference: SeedDatasetReference,
495
+ seed_source: SeedSource,
573
496
  *,
574
497
  sampling_strategy: SamplingStrategy = SamplingStrategy.ORDERED,
575
498
  selection_strategy: IndexRange | PartitionBlock | None = None,
576
499
  ) -> Self:
577
500
  """Add a seed dataset to the current Data Designer configuration.
578
501
 
579
- This method sets the seed dataset for the configuration and automatically creates
580
- SeedDatasetColumnConfig objects for each column found in the dataset. The column
581
- names are fetched from the dataset source, which can be the Hugging Face Hub, the
582
- NeMo Microservices Datastore, or in the case of direct library usage, a local file.
502
+ This method sets the seed dataset for the configuration, but columns are not resolved until
503
+ compilation (including validation) is performed by the engine using a SeedReader.
583
504
 
584
505
  Args:
585
- dataset_reference: Seed dataset reference for fetching from the datastore.
506
+ seed_source: The pointer to the seed dataset.
586
507
  sampling_strategy: The sampling strategy to use when generating data from the seed dataset.
587
508
  Defaults to ORDERED sampling.
509
+ selection_strategy: An optional selection strategy to use when generating data from the seed dataset.
510
+ Defaults to None.
588
511
 
589
512
  Returns:
590
513
  The current Data Designer config builder instance.
591
-
592
- Raises:
593
- BuilderConfigurationError: If any seed dataset column name collides with an existing column.
594
514
  """
595
- seed_column_names = fetch_seed_dataset_column_names(dataset_reference)
596
- colliding_columns = [name for name in seed_column_names if name in self._column_configs]
597
- if colliding_columns:
598
- raise BuilderConfigurationError(
599
- f"🛑 Seed dataset column(s) {colliding_columns} collide with existing column(s). "
600
- "Please remove the conflicting columns or use a seed dataset with different column names."
601
- )
602
-
603
515
  self._seed_config = SeedConfig(
604
- dataset=dataset_reference.dataset,
516
+ source=seed_source,
605
517
  sampling_strategy=sampling_strategy,
606
518
  selection_strategy=selection_strategy,
607
519
  )
608
- self.set_seed_datastore_settings(
609
- dataset_reference.datastore_settings if hasattr(dataset_reference, "datastore_settings") else None
610
- )
611
- for column_name in seed_column_names:
612
- self._column_configs[column_name] = SeedDatasetColumnConfig(name=column_name)
613
520
  return self
614
521
 
615
522
  def write_config(self, path: str | Path, indent: int | None = 2, **kwargs) -> None:
@@ -622,7 +529,17 @@ class DataDesignerConfigBuilder:
622
529
 
623
530
  Raises:
624
531
  BuilderConfigurationError: If the file format is unsupported.
625
- """
532
+ BuilderSerializationError: If the configuration cannot be serialized.
533
+ """
534
+ if (seed_config := self.get_seed_config()) is not None and isinstance(seed_config.source, DataFrameSeedSource):
535
+ raise BuilderSerializationError(
536
+ "This builder was configured with a DataFrame seed dataset. "
537
+ "DataFrame seeds cannot be serialized to config files. "
538
+ "To serialize this configuration, change your seed dataset to a more persistent, serializable source format. "
539
+ "For example, you could make a local file seed source from the dataframe:\n\n"
540
+ "LocalFileSeedSource.from_dataframe(my_dataframe, '/path/to/data.parquet')"
541
+ )
542
+
626
543
  cfg = self.get_builder_config()
627
544
  suffix = Path(path).suffix
628
545
  if suffix in {".yaml", ".yml"}:
@@ -638,7 +555,7 @@ class DataDesignerConfigBuilder:
638
555
  Returns:
639
556
  The builder config.
640
557
  """
641
- return BuilderConfig(data_designer=self.build(), datastore_settings=self._datastore_settings)
558
+ return BuilderConfig(data_designer=self.build())
642
559
 
643
560
  def __repr__(self) -> str:
644
561
  """Generates a string representation of the DataDesignerConfigBuilder instance.
@@ -650,7 +567,7 @@ class DataDesignerConfigBuilder:
650
567
  return f"{self.__class__.__name__}()"
651
568
 
652
569
  props_to_repr = {
653
- "seed_dataset": (None if self._seed_config is None else f"'{self._seed_config.dataset}'"),
570
+ "seed_dataset": (None if self._seed_config is None else f"{self._seed_config.source.seed_type} seed"),
654
571
  }
655
572
 
656
573
  for column_type in get_column_display_order():
@@ -7,6 +7,9 @@ from data_designer.errors import DataDesignerError
7
7
  class BuilderConfigurationError(DataDesignerError): ...
8
8
 
9
9
 
10
+ class BuilderSerializationError(DataDesignerError): ...
11
+
12
+
10
13
  class InvalidColumnTypeError(DataDesignerError): ...
11
14
 
12
15
 
@@ -18,14 +18,12 @@ from data_designer.config.column_types import DataDesignerColumnType
18
18
  from data_designer.config.config_builder import DataDesignerConfigBuilder
19
19
  from data_designer.config.data_designer_config import DataDesignerConfig
20
20
  from data_designer.config.dataset_builders import BuildStage
21
- from data_designer.config.datastore import DatastoreSettings
22
21
  from data_designer.config.models import (
23
22
  ChatCompletionInferenceParams,
24
23
  EmbeddingInferenceParams,
25
24
  GenerationType,
26
25
  ImageContext,
27
26
  ImageFormat,
28
- InferenceParameters,
29
27
  ManualDistribution,
30
28
  ManualDistributionParams,
31
29
  Modality,
@@ -60,12 +58,16 @@ from data_designer.config.sampler_params import (
60
58
  UUIDSamplerParams,
61
59
  )
62
60
  from data_designer.config.seed import (
63
- DatastoreSeedDatasetReference,
64
61
  IndexRange,
65
62
  PartitionBlock,
66
63
  SamplingStrategy,
67
64
  SeedConfig,
68
65
  )
66
+ from data_designer.config.seed_source import (
67
+ DataFrameSeedSource,
68
+ HuggingFaceSeedSource,
69
+ LocalFileSeedSource,
70
+ )
69
71
  from data_designer.config.utils.code_lang import CodeLang
70
72
  from data_designer.config.utils.info import InfoType
71
73
  from data_designer.config.validator_params import (
@@ -89,9 +91,8 @@ def get_config_exports() -> list[str]:
89
91
  DataDesignerColumnType.__name__,
90
92
  DataDesignerConfig.__name__,
91
93
  DataDesignerConfigBuilder.__name__,
94
+ DataFrameSeedSource.__name__,
92
95
  BuildStage.__name__,
93
- DatastoreSeedDatasetReference.__name__,
94
- DatastoreSettings.__name__,
95
96
  DatetimeSamplerParams.__name__,
96
97
  DropColumnsProcessorConfig.__name__,
97
98
  EmbeddingColumnConfig.__name__,
@@ -99,16 +100,17 @@ def get_config_exports() -> list[str]:
99
100
  ExpressionColumnConfig.__name__,
100
101
  GaussianSamplerParams.__name__,
101
102
  GenerationType.__name__,
103
+ HuggingFaceSeedSource.__name__,
102
104
  IndexRange.__name__,
103
105
  InfoType.__name__,
104
106
  ImageContext.__name__,
105
107
  ImageFormat.__name__,
106
- InferenceParameters.__name__,
107
108
  JudgeScoreProfilerConfig.__name__,
108
109
  LLMCodeColumnConfig.__name__,
109
110
  LLMJudgeColumnConfig.__name__,
110
111
  LLMStructuredColumnConfig.__name__,
111
112
  LLMTextColumnConfig.__name__,
113
+ LocalFileSeedSource.__name__,
112
114
  ManualDistribution.__name__,
113
115
  ManualDistributionParams.__name__,
114
116
  Modality.__name__,
@@ -5,7 +5,7 @@ import logging
5
5
  from abc import ABC, abstractmethod
6
6
  from enum import Enum
7
7
  from pathlib import Path
8
- from typing import Any, Generic, Literal, TypeVar
8
+ from typing import Annotated, Any, Generic, Literal, TypeVar
9
9
 
10
10
  import numpy as np
11
11
  from pydantic import BaseModel, Field, field_validator, model_validator
@@ -278,7 +278,7 @@ class ChatCompletionInferenceParams(BaseInferenceParams):
278
278
  generation_type: Type of generation, always "chat-completion" for this class.
279
279
  temperature: Sampling temperature (0.0-2.0). Can be a fixed value or a distribution for dynamic sampling.
280
280
  top_p: Nucleus sampling probability (0.0-1.0). Can be a fixed value or a distribution for dynamic sampling.
281
- max_tokens: Maximum number of tokens (includes both input and output tokens).
281
+ max_tokens: Maximum number of tokens to generate in the response.
282
282
  """
283
283
 
284
284
  generation_type: Literal[GenerationType.CHAT_COMPLETION] = GenerationType.CHAT_COMPLETION
@@ -357,21 +357,6 @@ class ChatCompletionInferenceParams(BaseInferenceParams):
357
357
  return super()._format_value(key, value)
358
358
 
359
359
 
360
- # Maintain backwards compatibility with a deprecation warning
361
- class InferenceParameters(ChatCompletionInferenceParams):
362
- """
363
- Deprecated: Use ChatCompletionInferenceParams instead.
364
- This alias will be removed in a future version.
365
- """
366
-
367
- def __init__(self, *args: Any, **kwargs: Any) -> None:
368
- logger.warning(
369
- "InferenceParameters is deprecated and will be removed in a future version. "
370
- "Use ChatCompletionInferenceParams instead."
371
- )
372
- super().__init__(*args, **kwargs)
373
-
374
-
375
360
  class EmbeddingInferenceParams(BaseInferenceParams):
376
361
  """Configuration for embedding generation parameters.
377
362
 
@@ -395,7 +380,9 @@ class EmbeddingInferenceParams(BaseInferenceParams):
395
380
  return result
396
381
 
397
382
 
398
- InferenceParamsT: TypeAlias = ChatCompletionInferenceParams | EmbeddingInferenceParams | InferenceParameters
383
+ InferenceParamsT: TypeAlias = Annotated[
384
+ ChatCompletionInferenceParams | EmbeddingInferenceParams, Field(discriminator="generation_type")
385
+ ]
399
386
 
400
387
 
401
388
  class ModelConfig(ConfigBase):
@@ -441,6 +428,7 @@ class ModelProvider(ConfigBase):
441
428
  provider_type: Provider type (default: "openai"). Determines the API format to use.
442
429
  api_key: Optional API key for authentication.
443
430
  extra_body: Additional parameters to pass in API requests.
431
+ extra_headers: Additional headers to pass in API requests.
444
432
  """
445
433
 
446
434
  name: str
@@ -448,6 +436,7 @@ class ModelProvider(ConfigBase):
448
436
  provider_type: str = "openai"
449
437
  api_key: str | None = None
450
438
  extra_body: dict[str, Any] | None = None
439
+ extra_headers: dict[str, str] | None = None
451
440
 
452
441
 
453
442
  def load_model_configs(model_configs: list[ModelConfig] | str | Path) -> list[ModelConfig]:
@@ -0,0 +1,34 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from pydantic import Field, model_validator
5
+ from typing_extensions import Self
6
+
7
+ from data_designer.config.base import ConfigBase
8
+
9
+
10
+ class RunConfig(ConfigBase):
11
+ """Runtime configuration for dataset generation.
12
+
13
+ Groups configuration options that control generation behavior but aren't
14
+ part of the dataset configuration itself.
15
+
16
+ Attributes:
17
+ disable_early_shutdown: If True, disables early shutdown entirely. Generation
18
+ will continue regardless of error rate. Default is False.
19
+ shutdown_error_rate: Error rate threshold (0.0-1.0) that triggers early shutdown.
20
+ When early shutdown is disabled, this value is normalized to 1.0. Default is 0.5.
21
+ shutdown_error_window: Minimum number of completed tasks before error rate
22
+ monitoring begins. Must be >= 0. Default is 10.
23
+ """
24
+
25
+ disable_early_shutdown: bool = False
26
+ shutdown_error_rate: float = Field(default=0.5, ge=0.0, le=1.0)
27
+ shutdown_error_window: int = Field(default=10, ge=0)
28
+
29
+ @model_validator(mode="after")
30
+ def normalize_shutdown_settings(self) -> Self:
31
+ """Set shutdown_error_rate to 1.0 when early shutdown is disabled."""
32
+ if self.disable_early_shutdown:
33
+ self.shutdown_error_rate = 1.0
34
+ return self
@@ -1,19 +1,13 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
- from abc import ABC
5
4
  from enum import Enum
6
5
 
7
- from pydantic import Field, field_validator, model_validator
6
+ from pydantic import Field, model_validator
8
7
  from typing_extensions import Self
9
8
 
10
9
  from data_designer.config.base import ConfigBase
11
- from data_designer.config.datastore import DatastoreSettings
12
- from data_designer.config.utils.io_helpers import (
13
- VALID_DATASET_FILE_EXTENSIONS,
14
- validate_dataset_file_path,
15
- validate_path_contains_files_of_type,
16
- )
10
+ from data_designer.config.seed_source import SeedSourceT
17
11
 
18
12
 
19
13
  class SamplingStrategy(str, Enum):
@@ -62,7 +56,7 @@ class SeedConfig(ConfigBase):
62
56
  """Configuration for sampling data from a seed dataset.
63
57
 
64
58
  Args:
65
- dataset: Path or identifier for the seed dataset.
59
+ source: A SeedSource defining where the seed data exists
66
60
  sampling_strategy: Strategy for how to sample rows from the dataset.
67
61
  - ORDERED: Read rows sequentially in their original order.
68
62
  - SHUFFLE: Randomly shuffle rows before sampling. When used with
@@ -75,70 +69,46 @@ class SeedConfig(ConfigBase):
75
69
 
76
70
  Examples:
77
71
  Read rows sequentially from start to end:
78
- SeedConfig(dataset="my_data.parquet", sampling_strategy=SamplingStrategy.ORDERED)
72
+ SeedConfig(
73
+ source=LocalFileSeedSource(path="my_data.parquet"),
74
+ sampling_strategy=SamplingStrategy.ORDERED
75
+ )
79
76
 
80
77
  Read rows in random order:
81
- SeedConfig(dataset="my_data.parquet", sampling_strategy=SamplingStrategy.SHUFFLE)
78
+ SeedConfig(
79
+ source=LocalFileSeedSource(path="my_data.parquet"),
80
+ sampling_strategy=SamplingStrategy.SHUFFLE
81
+ )
82
82
 
83
83
  Read specific index range (rows 100-199):
84
84
  SeedConfig(
85
- dataset="my_data.parquet",
85
+ source=LocalFileSeedSource(path="my_data.parquet"),
86
86
  sampling_strategy=SamplingStrategy.ORDERED,
87
87
  selection_strategy=IndexRange(start=100, end=199)
88
88
  )
89
89
 
90
90
  Read random rows from a specific index range (shuffles within rows 100-199):
91
91
  SeedConfig(
92
- dataset="my_data.parquet",
92
+ source=LocalFileSeedSource(path="my_data.parquet"),
93
93
  sampling_strategy=SamplingStrategy.SHUFFLE,
94
94
  selection_strategy=IndexRange(start=100, end=199)
95
95
  )
96
96
 
97
97
  Read from partition 2 (3rd partition, zero-based) of 5 partitions (20% of dataset):
98
98
  SeedConfig(
99
- dataset="my_data.parquet",
99
+ source=LocalFileSeedSource(path="my_data.parquet"),
100
100
  sampling_strategy=SamplingStrategy.ORDERED,
101
101
  selection_strategy=PartitionBlock(index=2, num_partitions=5)
102
102
  )
103
103
 
104
104
  Read shuffled rows from partition 0 of 10 partitions (shuffles within the partition):
105
105
  SeedConfig(
106
- dataset="my_data.parquet",
106
+ source=LocalFileSeedSource(path="my_data.parquet"),
107
107
  sampling_strategy=SamplingStrategy.SHUFFLE,
108
108
  selection_strategy=PartitionBlock(index=0, num_partitions=10)
109
109
  )
110
110
  """
111
111
 
112
- dataset: str
112
+ source: SeedSourceT
113
113
  sampling_strategy: SamplingStrategy = SamplingStrategy.ORDERED
114
114
  selection_strategy: IndexRange | PartitionBlock | None = None
115
-
116
-
117
- class SeedDatasetReference(ABC, ConfigBase):
118
- dataset: str
119
-
120
-
121
- class DatastoreSeedDatasetReference(SeedDatasetReference):
122
- datastore_settings: DatastoreSettings
123
-
124
- @property
125
- def repo_id(self) -> str:
126
- return "/".join(self.dataset.split("/")[:-1])
127
-
128
- @property
129
- def filename(self) -> str:
130
- return self.dataset.split("/")[-1]
131
-
132
-
133
- class LocalSeedDatasetReference(SeedDatasetReference):
134
- @field_validator("dataset", mode="after")
135
- def validate_dataset_is_file(cls, v: str) -> str:
136
- valid_wild_card_versions = {f"*{ext}" for ext in VALID_DATASET_FILE_EXTENSIONS}
137
- if any(v.endswith(wildcard) for wildcard in valid_wild_card_versions):
138
- parts = v.split("*.")
139
- file_path = parts[0]
140
- file_extension = parts[-1]
141
- validate_path_contains_files_of_type(file_path, file_extension)
142
- else:
143
- validate_dataset_file_path(v)
144
- return v
@@ -0,0 +1,73 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from abc import ABC
5
+ from typing import Annotated, Literal
6
+
7
+ import pandas as pd
8
+ from pydantic import BaseModel, ConfigDict, Field, field_validator
9
+ from typing_extensions import Self
10
+
11
+ from data_designer.config.utils.io_helpers import (
12
+ VALID_DATASET_FILE_EXTENSIONS,
13
+ validate_dataset_file_path,
14
+ validate_path_contains_files_of_type,
15
+ )
16
+
17
+
18
+ class SeedSource(BaseModel, ABC):
19
+ """Base class for seed dataset configurations.
20
+
21
+ All subclasses must define a `seed_type` field with a Literal value.
22
+ This serves as a discriminated union discriminator.
23
+ """
24
+
25
+ seed_type: str
26
+
27
+
28
+ class LocalFileSeedSource(SeedSource):
29
+ seed_type: Literal["local"] = "local"
30
+
31
+ path: str
32
+
33
+ @field_validator("path", mode="after")
34
+ def validate_path(cls, v: str) -> str:
35
+ valid_wild_card_versions = {f"*{ext}" for ext in VALID_DATASET_FILE_EXTENSIONS}
36
+ if any(v.endswith(wildcard) for wildcard in valid_wild_card_versions):
37
+ parts = v.split("*.")
38
+ file_path = parts[0]
39
+ file_extension = parts[-1]
40
+ validate_path_contains_files_of_type(file_path, file_extension)
41
+ else:
42
+ validate_dataset_file_path(v)
43
+ return v
44
+
45
+ @classmethod
46
+ def from_dataframe(cls, df: pd.DataFrame, path: str) -> Self:
47
+ df.to_parquet(path, index=False)
48
+ return cls(path=path)
49
+
50
+
51
+ class HuggingFaceSeedSource(SeedSource):
52
+ seed_type: Literal["hf"] = "hf"
53
+
54
+ path: str = Field(
55
+ ...,
56
+ description="Path to the seed data in HuggingFace. Wildcards are allowed. Examples include 'datasets/my-username/my-dataset/data/000_00000.parquet', 'datasets/my-username/my-dataset/data/*.parquet', 'datasets/my-username/my-dataset/**/*.parquet'",
57
+ )
58
+ token: str | None = None
59
+ endpoint: str = "https://huggingface.co"
60
+
61
+
62
+ class DataFrameSeedSource(SeedSource):
63
+ seed_type: Literal["df"] = "df"
64
+
65
+ model_config = ConfigDict(arbitrary_types_allowed=True)
66
+
67
+ df: pd.DataFrame
68
+
69
+
70
+ SeedSourceT = Annotated[
71
+ LocalFileSeedSource | HuggingFaceSeedSource | DataFrameSeedSource,
72
+ Field(discriminator="seed_type"),
73
+ ]