data-designer 0.2.3__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/_version.py +2 -2
- data_designer/cli/forms/model_builder.py +2 -2
- data_designer/config/config_builder.py +30 -113
- data_designer/config/errors.py +3 -0
- data_designer/config/exports.py +8 -6
- data_designer/config/models.py +7 -18
- data_designer/config/run_config.py +34 -0
- data_designer/config/seed.py +16 -46
- data_designer/config/seed_source.py +73 -0
- data_designer/config/utils/constants.py +27 -2
- data_designer/config/utils/io_helpers.py +0 -20
- data_designer/engine/column_generators/generators/seed_dataset.py +5 -5
- data_designer/engine/column_generators/generators/validation.py +3 -0
- data_designer/engine/column_generators/registry.py +1 -1
- data_designer/engine/compiler.py +69 -0
- data_designer/engine/dataset_builders/column_wise_builder.py +3 -0
- data_designer/engine/dataset_builders/utils/config_compiler.py +1 -1
- data_designer/engine/models/facade.py +2 -0
- data_designer/engine/processing/gsonschema/validators.py +55 -0
- data_designer/engine/resources/resource_provider.py +17 -5
- data_designer/engine/resources/seed_reader.py +149 -0
- data_designer/essentials/__init__.py +2 -0
- data_designer/interface/data_designer.py +72 -62
- data_designer/plugin_manager.py +1 -1
- data_designer/plugins/errors.py +3 -0
- data_designer/plugins/plugin.py +82 -12
- data_designer/plugins/testing/__init__.py +8 -0
- data_designer/plugins/testing/stubs.py +145 -0
- data_designer/plugins/testing/utils.py +11 -0
- {data_designer-0.2.3.dist-info → data_designer-0.3.0.dist-info}/METADATA +3 -3
- {data_designer-0.2.3.dist-info → data_designer-0.3.0.dist-info}/RECORD +35 -30
- data_designer/config/datastore.py +0 -187
- data_designer/engine/resources/seed_dataset_data_store.py +0 -84
- /data_designer/{config/utils → engine}/validation.py +0 -0
- {data_designer-0.2.3.dist-info → data_designer-0.3.0.dist-info}/WHEEL +0 -0
- {data_designer-0.2.3.dist-info → data_designer-0.3.0.dist-info}/entry_points.txt +0 -0
- {data_designer-0.2.3.dist-info → data_designer-0.3.0.dist-info}/licenses/LICENSE +0 -0
data_designer/_version.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '0.2.3'
|
|
32
|
-
__version_tuple__ = version_tuple = (0, 2, 3)
|
|
31
|
+
__version__ = version = '0.3.0'
|
|
32
|
+
__version_tuple__ = version_tuple = (0, 3, 0)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
|
data_designer/cli/forms/model_builder.py
CHANGED
|
@@ -125,11 +125,11 @@ class ModelFormBuilder(FormBuilder[ModelConfig]):
|
|
|
125
125
|
fields.append(
|
|
126
126
|
NumericField(
|
|
127
127
|
"max_tokens",
|
|
128
|
-
"Max tokens <dim>(maximum
|
|
128
|
+
"Max tokens <dim>(maximum tokens to generate in response)</dim>",
|
|
129
129
|
default=initial_params.get("max_tokens"),
|
|
130
130
|
min_value=1.0,
|
|
131
131
|
required=False,
|
|
132
|
-
help_text="Maximum number of tokens
|
|
132
|
+
help_text="Maximum number of tokens to generate in the response",
|
|
133
133
|
)
|
|
134
134
|
)
|
|
135
135
|
|
|
data_designer/config/config_builder.py
CHANGED
|
@@ -24,9 +24,8 @@ from data_designer.config.column_types import (
|
|
|
24
24
|
)
|
|
25
25
|
from data_designer.config.data_designer_config import DataDesignerConfig
|
|
26
26
|
from data_designer.config.dataset_builders import BuildStage
|
|
27
|
-
from data_designer.config.datastore import DatastoreSettings, fetch_seed_dataset_column_names
|
|
28
27
|
from data_designer.config.default_model_settings import get_default_model_configs
|
|
29
|
-
from data_designer.config.errors import BuilderConfigurationError,
|
|
28
|
+
from data_designer.config.errors import BuilderConfigurationError, BuilderSerializationError, InvalidColumnTypeError
|
|
30
29
|
from data_designer.config.models import ModelConfig, load_model_configs
|
|
31
30
|
from data_designer.config.processors import ProcessorConfigT, ProcessorType, get_processor_config_from_kwargs
|
|
32
31
|
from data_designer.config.sampler_constraints import (
|
|
@@ -36,20 +35,17 @@ from data_designer.config.sampler_constraints import (
|
|
|
36
35
|
ScalarInequalityConstraint,
|
|
37
36
|
)
|
|
38
37
|
from data_designer.config.seed import (
|
|
39
|
-
DatastoreSeedDatasetReference,
|
|
40
38
|
IndexRange,
|
|
41
|
-
LocalSeedDatasetReference,
|
|
42
39
|
PartitionBlock,
|
|
43
40
|
SamplingStrategy,
|
|
44
41
|
SeedConfig,
|
|
45
|
-
SeedDatasetReference,
|
|
46
42
|
)
|
|
43
|
+
from data_designer.config.seed_source import DataFrameSeedSource, SeedSource
|
|
47
44
|
from data_designer.config.utils.constants import DEFAULT_REPR_HTML_STYLE, REPR_HTML_TEMPLATE
|
|
48
45
|
from data_designer.config.utils.info import ConfigBuilderInfo
|
|
49
46
|
from data_designer.config.utils.io_helpers import serialize_data, smart_load_yaml
|
|
50
47
|
from data_designer.config.utils.misc import can_run_data_designer_locally, json_indent_list_of_strings, kebab_to_snake
|
|
51
48
|
from data_designer.config.utils.type_helpers import resolve_string_enum
|
|
52
|
-
from data_designer.config.utils.validation import ViolationLevel, rich_print_violations, validate_data_designer_config
|
|
53
49
|
|
|
54
50
|
logger = logging.getLogger(__name__)
|
|
55
51
|
|
|
@@ -63,12 +59,9 @@ class BuilderConfig(ExportableConfigBase):
|
|
|
63
59
|
Attributes:
|
|
64
60
|
data_designer: The main Data Designer configuration containing columns,
|
|
65
61
|
constraints, profilers, and other settings.
|
|
66
|
-
datastore_settings: Optional datastore settings for accessing external
|
|
67
|
-
datasets.
|
|
68
62
|
"""
|
|
69
63
|
|
|
70
64
|
data_designer: DataDesignerConfig
|
|
71
|
-
datastore_settings: DatastoreSettings | None
|
|
72
65
|
|
|
73
66
|
|
|
74
67
|
class DataDesignerConfigBuilder:
|
|
@@ -101,30 +94,19 @@ class DataDesignerConfigBuilder:
|
|
|
101
94
|
builder_config = BuilderConfig.model_validate(json_config)
|
|
102
95
|
|
|
103
96
|
builder = cls(model_configs=builder_config.data_designer.model_configs)
|
|
104
|
-
|
|
97
|
+
data_designer_config = builder_config.data_designer
|
|
105
98
|
|
|
106
|
-
for col in
|
|
99
|
+
for col in data_designer_config.columns:
|
|
107
100
|
builder.add_column(col)
|
|
108
101
|
|
|
109
|
-
for constraint in
|
|
102
|
+
for constraint in data_designer_config.constraints or []:
|
|
110
103
|
builder.add_constraint(constraint=constraint)
|
|
111
104
|
|
|
112
|
-
if
|
|
113
|
-
if builder_config.datastore_settings is None:
|
|
114
|
-
if can_run_data_designer_locally():
|
|
115
|
-
seed_dataset_reference = LocalSeedDatasetReference(dataset=config.seed_config.dataset)
|
|
116
|
-
else:
|
|
117
|
-
raise BuilderConfigurationError("🛑 Datastore settings are required.")
|
|
118
|
-
else:
|
|
119
|
-
seed_dataset_reference = DatastoreSeedDatasetReference(
|
|
120
|
-
dataset=config.seed_config.dataset,
|
|
121
|
-
datastore_settings=builder_config.datastore_settings,
|
|
122
|
-
)
|
|
123
|
-
builder.set_seed_datastore_settings(builder_config.datastore_settings)
|
|
105
|
+
if (seed_config := data_designer_config.seed_config) is not None:
|
|
124
106
|
builder.with_seed_dataset(
|
|
125
|
-
|
|
126
|
-
sampling_strategy=
|
|
127
|
-
selection_strategy=
|
|
107
|
+
seed_config.source,
|
|
108
|
+
sampling_strategy=seed_config.sampling_strategy,
|
|
109
|
+
selection_strategy=seed_config.selection_strategy,
|
|
128
110
|
)
|
|
129
111
|
|
|
130
112
|
return builder
|
|
@@ -144,7 +126,6 @@ class DataDesignerConfigBuilder:
|
|
|
144
126
|
self._seed_config: SeedConfig | None = None
|
|
145
127
|
self._constraints: list[ColumnConstraintT] = []
|
|
146
128
|
self._profilers: list[ColumnProfilerConfigT] = []
|
|
147
|
-
self._datastore_settings: DatastoreSettings | None = None
|
|
148
129
|
|
|
149
130
|
@property
|
|
150
131
|
def model_configs(self) -> list[ModelConfig]:
|
|
@@ -243,13 +224,6 @@ class DataDesignerConfigBuilder:
|
|
|
243
224
|
f"{', '.join([t.__name__ for t in allowed_column_configs])}"
|
|
244
225
|
)
|
|
245
226
|
|
|
246
|
-
existing_config = self._column_configs.get(column_config.name)
|
|
247
|
-
if existing_config is not None and isinstance(existing_config, SeedDatasetColumnConfig):
|
|
248
|
-
raise BuilderConfigurationError(
|
|
249
|
-
f"🛑 Column {column_config.name!r} already exists as a seed dataset column. "
|
|
250
|
-
"Please use a different column name or update the seed dataset."
|
|
251
|
-
)
|
|
252
|
-
|
|
253
227
|
self._column_configs[column_config.name] = column_config
|
|
254
228
|
return self
|
|
255
229
|
|
|
@@ -371,19 +345,12 @@ class DataDesignerConfigBuilder:
|
|
|
371
345
|
"""
|
|
372
346
|
return self._profilers
|
|
373
347
|
|
|
374
|
-
def build(self
|
|
348
|
+
def build(self) -> DataDesignerConfig:
|
|
375
349
|
"""Build a DataDesignerConfig instance based on the current builder configuration.
|
|
376
350
|
|
|
377
|
-
Args:
|
|
378
|
-
skip_validation: Whether to skip validation of the configuration.
|
|
379
|
-
raise_exceptions: Whether to raise an exception if the configuration is invalid.
|
|
380
|
-
|
|
381
351
|
Returns:
|
|
382
352
|
The current Data Designer config object.
|
|
383
353
|
"""
|
|
384
|
-
if not skip_validation:
|
|
385
|
-
self.validate(raise_exceptions=raise_exceptions)
|
|
386
|
-
|
|
387
354
|
return DataDesignerConfig(
|
|
388
355
|
model_configs=self._model_configs,
|
|
389
356
|
seed_config=self._seed_config,
|
|
@@ -512,14 +479,6 @@ class DataDesignerConfigBuilder:
|
|
|
512
479
|
"""
|
|
513
480
|
return self._seed_config
|
|
514
481
|
|
|
515
|
-
def get_seed_datastore_settings(self) -> DatastoreSettings | None:
|
|
516
|
-
"""Get most recent datastore settings for the current Data Designer configuration.
|
|
517
|
-
|
|
518
|
-
Returns:
|
|
519
|
-
The datastore settings if configured, None otherwise.
|
|
520
|
-
"""
|
|
521
|
-
return None if not self._datastore_settings else DatastoreSettings.model_validate(self._datastore_settings)
|
|
522
|
-
|
|
523
482
|
def num_columns_of_type(self, column_type: DataDesignerColumnType) -> int:
|
|
524
483
|
"""Get the count of columns of the specified type.
|
|
525
484
|
|
|
@@ -531,85 +490,33 @@ class DataDesignerConfigBuilder:
|
|
|
531
490
|
"""
|
|
532
491
|
return len(self.get_columns_of_type(column_type))
|
|
533
492
|
|
|
534
|
-
def set_seed_datastore_settings(self, datastore_settings: DatastoreSettings | None) -> Self:
|
|
535
|
-
"""Set the datastore settings for the seed dataset.
|
|
536
|
-
|
|
537
|
-
Args:
|
|
538
|
-
datastore_settings: The datastore settings to use for the seed dataset.
|
|
539
|
-
"""
|
|
540
|
-
self._datastore_settings = datastore_settings
|
|
541
|
-
return self
|
|
542
|
-
|
|
543
|
-
def validate(self, *, raise_exceptions: bool = False) -> Self:
|
|
544
|
-
"""Validate the current Data Designer configuration.
|
|
545
|
-
|
|
546
|
-
Args:
|
|
547
|
-
raise_exceptions: Whether to raise an exception if the configuration is invalid.
|
|
548
|
-
|
|
549
|
-
Returns:
|
|
550
|
-
The current Data Designer config builder instance.
|
|
551
|
-
|
|
552
|
-
Raises:
|
|
553
|
-
InvalidConfigError: If the configuration is invalid and raise_exceptions is True.
|
|
554
|
-
"""
|
|
555
|
-
|
|
556
|
-
violations = validate_data_designer_config(
|
|
557
|
-
columns=list(self._column_configs.values()),
|
|
558
|
-
processor_configs=self._processor_configs,
|
|
559
|
-
allowed_references=self.allowed_references,
|
|
560
|
-
)
|
|
561
|
-
rich_print_violations(violations)
|
|
562
|
-
if raise_exceptions and len([v for v in violations if v.level == ViolationLevel.ERROR]) > 0:
|
|
563
|
-
raise InvalidConfigError(
|
|
564
|
-
"🛑 Your configuration contains validation errors. Please address the indicated issues and try again."
|
|
565
|
-
)
|
|
566
|
-
if len(violations) == 0:
|
|
567
|
-
logger.info("✅ Validation passed")
|
|
568
|
-
return self
|
|
569
|
-
|
|
570
493
|
def with_seed_dataset(
|
|
571
494
|
self,
|
|
572
|
-
|
|
495
|
+
seed_source: SeedSource,
|
|
573
496
|
*,
|
|
574
497
|
sampling_strategy: SamplingStrategy = SamplingStrategy.ORDERED,
|
|
575
498
|
selection_strategy: IndexRange | PartitionBlock | None = None,
|
|
576
499
|
) -> Self:
|
|
577
500
|
"""Add a seed dataset to the current Data Designer configuration.
|
|
578
501
|
|
|
579
|
-
This method sets the seed dataset for the configuration
|
|
580
|
-
|
|
581
|
-
names are fetched from the dataset source, which can be the Hugging Face Hub, the
|
|
582
|
-
NeMo Microservices Datastore, or in the case of direct library usage, a local file.
|
|
502
|
+
This method sets the seed dataset for the configuration, but columns are not resolved until
|
|
503
|
+
compilation (including validation) is performed by the engine using a SeedReader.
|
|
583
504
|
|
|
584
505
|
Args:
|
|
585
|
-
|
|
506
|
+
seed_source: The pointer to the seed dataset.
|
|
586
507
|
sampling_strategy: The sampling strategy to use when generating data from the seed dataset.
|
|
587
508
|
Defaults to ORDERED sampling.
|
|
509
|
+
selection_strategy: An optional selection strategy to use when generating data from the seed dataset.
|
|
510
|
+
Defaults to None.
|
|
588
511
|
|
|
589
512
|
Returns:
|
|
590
513
|
The current Data Designer config builder instance.
|
|
591
|
-
|
|
592
|
-
Raises:
|
|
593
|
-
BuilderConfigurationError: If any seed dataset column name collides with an existing column.
|
|
594
514
|
"""
|
|
595
|
-
seed_column_names = fetch_seed_dataset_column_names(dataset_reference)
|
|
596
|
-
colliding_columns = [name for name in seed_column_names if name in self._column_configs]
|
|
597
|
-
if colliding_columns:
|
|
598
|
-
raise BuilderConfigurationError(
|
|
599
|
-
f"🛑 Seed dataset column(s) {colliding_columns} collide with existing column(s). "
|
|
600
|
-
"Please remove the conflicting columns or use a seed dataset with different column names."
|
|
601
|
-
)
|
|
602
|
-
|
|
603
515
|
self._seed_config = SeedConfig(
|
|
604
|
-
|
|
516
|
+
source=seed_source,
|
|
605
517
|
sampling_strategy=sampling_strategy,
|
|
606
518
|
selection_strategy=selection_strategy,
|
|
607
519
|
)
|
|
608
|
-
self.set_seed_datastore_settings(
|
|
609
|
-
dataset_reference.datastore_settings if hasattr(dataset_reference, "datastore_settings") else None
|
|
610
|
-
)
|
|
611
|
-
for column_name in seed_column_names:
|
|
612
|
-
self._column_configs[column_name] = SeedDatasetColumnConfig(name=column_name)
|
|
613
520
|
return self
|
|
614
521
|
|
|
615
522
|
def write_config(self, path: str | Path, indent: int | None = 2, **kwargs) -> None:
|
|
@@ -622,7 +529,17 @@ class DataDesignerConfigBuilder:
|
|
|
622
529
|
|
|
623
530
|
Raises:
|
|
624
531
|
BuilderConfigurationError: If the file format is unsupported.
|
|
625
|
-
|
|
532
|
+
BuilderSerializationError: If the configuration cannot be serialized.
|
|
533
|
+
"""
|
|
534
|
+
if (seed_config := self.get_seed_config()) is not None and isinstance(seed_config.source, DataFrameSeedSource):
|
|
535
|
+
raise BuilderSerializationError(
|
|
536
|
+
"This builder was configured with a DataFrame seed dataset. "
|
|
537
|
+
"DataFrame seeds cannot be serialized to config files. "
|
|
538
|
+
"To serialize this configuration, change your seed dataset to a more persistent, serializable source format. "
|
|
539
|
+
"For example, you could make a local file seed source from the dataframe:\n\n"
|
|
540
|
+
"LocalFileSeedSource.from_dataframe(my_dataframe, '/path/to/data.parquet')"
|
|
541
|
+
)
|
|
542
|
+
|
|
626
543
|
cfg = self.get_builder_config()
|
|
627
544
|
suffix = Path(path).suffix
|
|
628
545
|
if suffix in {".yaml", ".yml"}:
|
|
@@ -638,7 +555,7 @@ class DataDesignerConfigBuilder:
|
|
|
638
555
|
Returns:
|
|
639
556
|
The builder config.
|
|
640
557
|
"""
|
|
641
|
-
return BuilderConfig(data_designer=self.build()
|
|
558
|
+
return BuilderConfig(data_designer=self.build())
|
|
642
559
|
|
|
643
560
|
def __repr__(self) -> str:
|
|
644
561
|
"""Generates a string representation of the DataDesignerConfigBuilder instance.
|
|
@@ -650,7 +567,7 @@ class DataDesignerConfigBuilder:
|
|
|
650
567
|
return f"{self.__class__.__name__}()"
|
|
651
568
|
|
|
652
569
|
props_to_repr = {
|
|
653
|
-
"seed_dataset": (None if self._seed_config is None else f"
|
|
570
|
+
"seed_dataset": (None if self._seed_config is None else f"{self._seed_config.source.seed_type} seed"),
|
|
654
571
|
}
|
|
655
572
|
|
|
656
573
|
for column_type in get_column_display_order():
|
data_designer/config/errors.py
CHANGED
data_designer/config/exports.py
CHANGED
|
@@ -18,14 +18,12 @@ from data_designer.config.column_types import DataDesignerColumnType
|
|
|
18
18
|
from data_designer.config.config_builder import DataDesignerConfigBuilder
|
|
19
19
|
from data_designer.config.data_designer_config import DataDesignerConfig
|
|
20
20
|
from data_designer.config.dataset_builders import BuildStage
|
|
21
|
-
from data_designer.config.datastore import DatastoreSettings
|
|
22
21
|
from data_designer.config.models import (
|
|
23
22
|
ChatCompletionInferenceParams,
|
|
24
23
|
EmbeddingInferenceParams,
|
|
25
24
|
GenerationType,
|
|
26
25
|
ImageContext,
|
|
27
26
|
ImageFormat,
|
|
28
|
-
InferenceParameters,
|
|
29
27
|
ManualDistribution,
|
|
30
28
|
ManualDistributionParams,
|
|
31
29
|
Modality,
|
|
@@ -60,12 +58,16 @@ from data_designer.config.sampler_params import (
|
|
|
60
58
|
UUIDSamplerParams,
|
|
61
59
|
)
|
|
62
60
|
from data_designer.config.seed import (
|
|
63
|
-
DatastoreSeedDatasetReference,
|
|
64
61
|
IndexRange,
|
|
65
62
|
PartitionBlock,
|
|
66
63
|
SamplingStrategy,
|
|
67
64
|
SeedConfig,
|
|
68
65
|
)
|
|
66
|
+
from data_designer.config.seed_source import (
|
|
67
|
+
DataFrameSeedSource,
|
|
68
|
+
HuggingFaceSeedSource,
|
|
69
|
+
LocalFileSeedSource,
|
|
70
|
+
)
|
|
69
71
|
from data_designer.config.utils.code_lang import CodeLang
|
|
70
72
|
from data_designer.config.utils.info import InfoType
|
|
71
73
|
from data_designer.config.validator_params import (
|
|
@@ -89,9 +91,8 @@ def get_config_exports() -> list[str]:
|
|
|
89
91
|
DataDesignerColumnType.__name__,
|
|
90
92
|
DataDesignerConfig.__name__,
|
|
91
93
|
DataDesignerConfigBuilder.__name__,
|
|
94
|
+
DataFrameSeedSource.__name__,
|
|
92
95
|
BuildStage.__name__,
|
|
93
|
-
DatastoreSeedDatasetReference.__name__,
|
|
94
|
-
DatastoreSettings.__name__,
|
|
95
96
|
DatetimeSamplerParams.__name__,
|
|
96
97
|
DropColumnsProcessorConfig.__name__,
|
|
97
98
|
EmbeddingColumnConfig.__name__,
|
|
@@ -99,16 +100,17 @@ def get_config_exports() -> list[str]:
|
|
|
99
100
|
ExpressionColumnConfig.__name__,
|
|
100
101
|
GaussianSamplerParams.__name__,
|
|
101
102
|
GenerationType.__name__,
|
|
103
|
+
HuggingFaceSeedSource.__name__,
|
|
102
104
|
IndexRange.__name__,
|
|
103
105
|
InfoType.__name__,
|
|
104
106
|
ImageContext.__name__,
|
|
105
107
|
ImageFormat.__name__,
|
|
106
|
-
InferenceParameters.__name__,
|
|
107
108
|
JudgeScoreProfilerConfig.__name__,
|
|
108
109
|
LLMCodeColumnConfig.__name__,
|
|
109
110
|
LLMJudgeColumnConfig.__name__,
|
|
110
111
|
LLMStructuredColumnConfig.__name__,
|
|
111
112
|
LLMTextColumnConfig.__name__,
|
|
113
|
+
LocalFileSeedSource.__name__,
|
|
112
114
|
ManualDistribution.__name__,
|
|
113
115
|
ManualDistributionParams.__name__,
|
|
114
116
|
Modality.__name__,
|
data_designer/config/models.py
CHANGED
|
@@ -5,7 +5,7 @@ import logging
|
|
|
5
5
|
from abc import ABC, abstractmethod
|
|
6
6
|
from enum import Enum
|
|
7
7
|
from pathlib import Path
|
|
8
|
-
from typing import Any, Generic, Literal, TypeVar
|
|
8
|
+
from typing import Annotated, Any, Generic, Literal, TypeVar
|
|
9
9
|
|
|
10
10
|
import numpy as np
|
|
11
11
|
from pydantic import BaseModel, Field, field_validator, model_validator
|
|
@@ -278,7 +278,7 @@ class ChatCompletionInferenceParams(BaseInferenceParams):
|
|
|
278
278
|
generation_type: Type of generation, always "chat-completion" for this class.
|
|
279
279
|
temperature: Sampling temperature (0.0-2.0). Can be a fixed value or a distribution for dynamic sampling.
|
|
280
280
|
top_p: Nucleus sampling probability (0.0-1.0). Can be a fixed value or a distribution for dynamic sampling.
|
|
281
|
-
max_tokens: Maximum number of tokens
|
|
281
|
+
max_tokens: Maximum number of tokens to generate in the response.
|
|
282
282
|
"""
|
|
283
283
|
|
|
284
284
|
generation_type: Literal[GenerationType.CHAT_COMPLETION] = GenerationType.CHAT_COMPLETION
|
|
@@ -357,21 +357,6 @@ class ChatCompletionInferenceParams(BaseInferenceParams):
|
|
|
357
357
|
return super()._format_value(key, value)
|
|
358
358
|
|
|
359
359
|
|
|
360
|
-
# Maintain backwards compatibility with a deprecation warning
|
|
361
|
-
class InferenceParameters(ChatCompletionInferenceParams):
|
|
362
|
-
"""
|
|
363
|
-
Deprecated: Use ChatCompletionInferenceParams instead.
|
|
364
|
-
This alias will be removed in a future version.
|
|
365
|
-
"""
|
|
366
|
-
|
|
367
|
-
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
|
368
|
-
logger.warning(
|
|
369
|
-
"InferenceParameters is deprecated and will be removed in a future version. "
|
|
370
|
-
"Use ChatCompletionInferenceParams instead."
|
|
371
|
-
)
|
|
372
|
-
super().__init__(*args, **kwargs)
|
|
373
|
-
|
|
374
|
-
|
|
375
360
|
class EmbeddingInferenceParams(BaseInferenceParams):
|
|
376
361
|
"""Configuration for embedding generation parameters.
|
|
377
362
|
|
|
@@ -395,7 +380,9 @@ class EmbeddingInferenceParams(BaseInferenceParams):
|
|
|
395
380
|
return result
|
|
396
381
|
|
|
397
382
|
|
|
398
|
-
InferenceParamsT: TypeAlias =
|
|
383
|
+
InferenceParamsT: TypeAlias = Annotated[
|
|
384
|
+
ChatCompletionInferenceParams | EmbeddingInferenceParams, Field(discriminator="generation_type")
|
|
385
|
+
]
|
|
399
386
|
|
|
400
387
|
|
|
401
388
|
class ModelConfig(ConfigBase):
|
|
@@ -441,6 +428,7 @@ class ModelProvider(ConfigBase):
|
|
|
441
428
|
provider_type: Provider type (default: "openai"). Determines the API format to use.
|
|
442
429
|
api_key: Optional API key for authentication.
|
|
443
430
|
extra_body: Additional parameters to pass in API requests.
|
|
431
|
+
extra_headers: Additional headers to pass in API requests.
|
|
444
432
|
"""
|
|
445
433
|
|
|
446
434
|
name: str
|
|
@@ -448,6 +436,7 @@ class ModelProvider(ConfigBase):
|
|
|
448
436
|
provider_type: str = "openai"
|
|
449
437
|
api_key: str | None = None
|
|
450
438
|
extra_body: dict[str, Any] | None = None
|
|
439
|
+
extra_headers: dict[str, str] | None = None
|
|
451
440
|
|
|
452
441
|
|
|
453
442
|
def load_model_configs(model_configs: list[ModelConfig] | str | Path) -> list[ModelConfig]:
|
|
data_designer/config/run_config.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from pydantic import Field, model_validator
|
|
5
|
+
from typing_extensions import Self
|
|
6
|
+
|
|
7
|
+
from data_designer.config.base import ConfigBase
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class RunConfig(ConfigBase):
|
|
11
|
+
"""Runtime configuration for dataset generation.
|
|
12
|
+
|
|
13
|
+
Groups configuration options that control generation behavior but aren't
|
|
14
|
+
part of the dataset configuration itself.
|
|
15
|
+
|
|
16
|
+
Attributes:
|
|
17
|
+
disable_early_shutdown: If True, disables early shutdown entirely. Generation
|
|
18
|
+
will continue regardless of error rate. Default is False.
|
|
19
|
+
shutdown_error_rate: Error rate threshold (0.0-1.0) that triggers early shutdown.
|
|
20
|
+
When early shutdown is disabled, this value is normalized to 1.0. Default is 0.5.
|
|
21
|
+
shutdown_error_window: Minimum number of completed tasks before error rate
|
|
22
|
+
monitoring begins. Must be >= 0. Default is 10.
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
disable_early_shutdown: bool = False
|
|
26
|
+
shutdown_error_rate: float = Field(default=0.5, ge=0.0, le=1.0)
|
|
27
|
+
shutdown_error_window: int = Field(default=10, ge=0)
|
|
28
|
+
|
|
29
|
+
@model_validator(mode="after")
|
|
30
|
+
def normalize_shutdown_settings(self) -> Self:
|
|
31
|
+
"""Set shutdown_error_rate to 1.0 when early shutdown is disabled."""
|
|
32
|
+
if self.disable_early_shutdown:
|
|
33
|
+
self.shutdown_error_rate = 1.0
|
|
34
|
+
return self
|
data_designer/config/seed.py
CHANGED
|
@@ -1,19 +1,13 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
-
from abc import ABC
|
|
5
4
|
from enum import Enum
|
|
6
5
|
|
|
7
|
-
from pydantic import Field,
|
|
6
|
+
from pydantic import Field, model_validator
|
|
8
7
|
from typing_extensions import Self
|
|
9
8
|
|
|
10
9
|
from data_designer.config.base import ConfigBase
|
|
11
|
-
from data_designer.config.datastore import DatastoreSettings
|
|
12
|
-
from data_designer.config.utils.io_helpers import (
|
|
13
|
-
VALID_DATASET_FILE_EXTENSIONS,
|
|
14
|
-
validate_dataset_file_path,
|
|
15
|
-
validate_path_contains_files_of_type,
|
|
16
|
-
)
|
|
10
|
+
from data_designer.config.seed_source import SeedSourceT
|
|
17
11
|
|
|
18
12
|
|
|
19
13
|
class SamplingStrategy(str, Enum):
|
|
@@ -62,7 +56,7 @@ class SeedConfig(ConfigBase):
|
|
|
62
56
|
"""Configuration for sampling data from a seed dataset.
|
|
63
57
|
|
|
64
58
|
Args:
|
|
65
|
-
|
|
59
|
+
source: A SeedSource defining where the seed data exists
|
|
66
60
|
sampling_strategy: Strategy for how to sample rows from the dataset.
|
|
67
61
|
- ORDERED: Read rows sequentially in their original order.
|
|
68
62
|
- SHUFFLE: Randomly shuffle rows before sampling. When used with
|
|
@@ -75,70 +69,46 @@ class SeedConfig(ConfigBase):
|
|
|
75
69
|
|
|
76
70
|
Examples:
|
|
77
71
|
Read rows sequentially from start to end:
|
|
78
|
-
SeedConfig(
|
|
72
|
+
SeedConfig(
|
|
73
|
+
source=LocalFileSeedSource(path="my_data.parquet"),
|
|
74
|
+
sampling_strategy=SamplingStrategy.ORDERED
|
|
75
|
+
)
|
|
79
76
|
|
|
80
77
|
Read rows in random order:
|
|
81
|
-
SeedConfig(
|
|
78
|
+
SeedConfig(
|
|
79
|
+
source=LocalFileSeedSource(path="my_data.parquet"),
|
|
80
|
+
sampling_strategy=SamplingStrategy.SHUFFLE
|
|
81
|
+
)
|
|
82
82
|
|
|
83
83
|
Read specific index range (rows 100-199):
|
|
84
84
|
SeedConfig(
|
|
85
|
-
|
|
85
|
+
source=LocalFileSeedSource(path="my_data.parquet"),
|
|
86
86
|
sampling_strategy=SamplingStrategy.ORDERED,
|
|
87
87
|
selection_strategy=IndexRange(start=100, end=199)
|
|
88
88
|
)
|
|
89
89
|
|
|
90
90
|
Read random rows from a specific index range (shuffles within rows 100-199):
|
|
91
91
|
SeedConfig(
|
|
92
|
-
|
|
92
|
+
source=LocalFileSeedSource(path="my_data.parquet"),
|
|
93
93
|
sampling_strategy=SamplingStrategy.SHUFFLE,
|
|
94
94
|
selection_strategy=IndexRange(start=100, end=199)
|
|
95
95
|
)
|
|
96
96
|
|
|
97
97
|
Read from partition 2 (3rd partition, zero-based) of 5 partitions (20% of dataset):
|
|
98
98
|
SeedConfig(
|
|
99
|
-
|
|
99
|
+
source=LocalFileSeedSource(path="my_data.parquet"),
|
|
100
100
|
sampling_strategy=SamplingStrategy.ORDERED,
|
|
101
101
|
selection_strategy=PartitionBlock(index=2, num_partitions=5)
|
|
102
102
|
)
|
|
103
103
|
|
|
104
104
|
Read shuffled rows from partition 0 of 10 partitions (shuffles within the partition):
|
|
105
105
|
SeedConfig(
|
|
106
|
-
|
|
106
|
+
source=LocalFileSeedSource(path="my_data.parquet"),
|
|
107
107
|
sampling_strategy=SamplingStrategy.SHUFFLE,
|
|
108
108
|
selection_strategy=PartitionBlock(index=0, num_partitions=10)
|
|
109
109
|
)
|
|
110
110
|
"""
|
|
111
111
|
|
|
112
|
-
|
|
112
|
+
source: SeedSourceT
|
|
113
113
|
sampling_strategy: SamplingStrategy = SamplingStrategy.ORDERED
|
|
114
114
|
selection_strategy: IndexRange | PartitionBlock | None = None
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
class SeedDatasetReference(ABC, ConfigBase):
|
|
118
|
-
dataset: str
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
class DatastoreSeedDatasetReference(SeedDatasetReference):
|
|
122
|
-
datastore_settings: DatastoreSettings
|
|
123
|
-
|
|
124
|
-
@property
|
|
125
|
-
def repo_id(self) -> str:
|
|
126
|
-
return "/".join(self.dataset.split("/")[:-1])
|
|
127
|
-
|
|
128
|
-
@property
|
|
129
|
-
def filename(self) -> str:
|
|
130
|
-
return self.dataset.split("/")[-1]
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
class LocalSeedDatasetReference(SeedDatasetReference):
|
|
134
|
-
@field_validator("dataset", mode="after")
|
|
135
|
-
def validate_dataset_is_file(cls, v: str) -> str:
|
|
136
|
-
valid_wild_card_versions = {f"*{ext}" for ext in VALID_DATASET_FILE_EXTENSIONS}
|
|
137
|
-
if any(v.endswith(wildcard) for wildcard in valid_wild_card_versions):
|
|
138
|
-
parts = v.split("*.")
|
|
139
|
-
file_path = parts[0]
|
|
140
|
-
file_extension = parts[-1]
|
|
141
|
-
validate_path_contains_files_of_type(file_path, file_extension)
|
|
142
|
-
else:
|
|
143
|
-
validate_dataset_file_path(v)
|
|
144
|
-
return v
|
|
data_designer/config/seed_source.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from abc import ABC
|
|
5
|
+
from typing import Annotated, Literal
|
|
6
|
+
|
|
7
|
+
import pandas as pd
|
|
8
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
|
9
|
+
from typing_extensions import Self
|
|
10
|
+
|
|
11
|
+
from data_designer.config.utils.io_helpers import (
|
|
12
|
+
VALID_DATASET_FILE_EXTENSIONS,
|
|
13
|
+
validate_dataset_file_path,
|
|
14
|
+
validate_path_contains_files_of_type,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class SeedSource(BaseModel, ABC):
|
|
19
|
+
"""Base class for seed dataset configurations.
|
|
20
|
+
|
|
21
|
+
All subclasses must define a `seed_type` field with a Literal value.
|
|
22
|
+
This serves as a discriminated union discriminator.
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
seed_type: str
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class LocalFileSeedSource(SeedSource):
|
|
29
|
+
seed_type: Literal["local"] = "local"
|
|
30
|
+
|
|
31
|
+
path: str
|
|
32
|
+
|
|
33
|
+
@field_validator("path", mode="after")
|
|
34
|
+
def validate_path(cls, v: str) -> str:
|
|
35
|
+
valid_wild_card_versions = {f"*{ext}" for ext in VALID_DATASET_FILE_EXTENSIONS}
|
|
36
|
+
if any(v.endswith(wildcard) for wildcard in valid_wild_card_versions):
|
|
37
|
+
parts = v.split("*.")
|
|
38
|
+
file_path = parts[0]
|
|
39
|
+
file_extension = parts[-1]
|
|
40
|
+
validate_path_contains_files_of_type(file_path, file_extension)
|
|
41
|
+
else:
|
|
42
|
+
validate_dataset_file_path(v)
|
|
43
|
+
return v
|
|
44
|
+
|
|
45
|
+
@classmethod
|
|
46
|
+
def from_dataframe(cls, df: pd.DataFrame, path: str) -> Self:
|
|
47
|
+
df.to_parquet(path, index=False)
|
|
48
|
+
return cls(path=path)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class HuggingFaceSeedSource(SeedSource):
|
|
52
|
+
seed_type: Literal["hf"] = "hf"
|
|
53
|
+
|
|
54
|
+
path: str = Field(
|
|
55
|
+
...,
|
|
56
|
+
description="Path to the seed data in HuggingFace. Wildcards are allowed. Examples include 'datasets/my-username/my-dataset/data/000_00000.parquet', 'datasets/my-username/my-dataset/data/*.parquet', 'datasets/my-username/my-dataset/**/*.parquet'",
|
|
57
|
+
)
|
|
58
|
+
token: str | None = None
|
|
59
|
+
endpoint: str = "https://huggingface.co"
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class DataFrameSeedSource(SeedSource):
|
|
63
|
+
seed_type: Literal["df"] = "df"
|
|
64
|
+
|
|
65
|
+
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
66
|
+
|
|
67
|
+
df: pd.DataFrame
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
SeedSourceT = Annotated[
|
|
71
|
+
LocalFileSeedSource | HuggingFaceSeedSource | DataFrameSeedSource,
|
|
72
|
+
Field(discriminator="seed_type"),
|
|
73
|
+
]
|