data-designer 0.2.3__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/_version.py +2 -2
- data_designer/cli/forms/model_builder.py +2 -2
- data_designer/config/config_builder.py +30 -113
- data_designer/config/errors.py +3 -0
- data_designer/config/exports.py +8 -6
- data_designer/config/models.py +7 -18
- data_designer/config/run_config.py +34 -0
- data_designer/config/seed.py +16 -46
- data_designer/config/seed_source.py +84 -0
- data_designer/config/utils/constants.py +27 -2
- data_designer/config/utils/io_helpers.py +0 -20
- data_designer/engine/column_generators/generators/seed_dataset.py +5 -5
- data_designer/engine/column_generators/generators/validation.py +3 -0
- data_designer/engine/column_generators/registry.py +1 -1
- data_designer/engine/compiler.py +69 -0
- data_designer/engine/dataset_builders/column_wise_builder.py +3 -0
- data_designer/engine/dataset_builders/utils/config_compiler.py +1 -1
- data_designer/engine/models/facade.py +2 -0
- data_designer/engine/processing/gsonschema/validators.py +55 -0
- data_designer/engine/resources/resource_provider.py +17 -5
- data_designer/engine/resources/seed_reader.py +149 -0
- data_designer/essentials/__init__.py +2 -0
- data_designer/interface/data_designer.py +72 -62
- data_designer/plugin_manager.py +1 -1
- data_designer/plugins/errors.py +3 -0
- data_designer/plugins/plugin.py +82 -12
- data_designer/plugins/testing/__init__.py +8 -0
- data_designer/plugins/testing/stubs.py +145 -0
- data_designer/plugins/testing/utils.py +11 -0
- {data_designer-0.2.3.dist-info → data_designer-0.3.1.dist-info}/METADATA +3 -3
- {data_designer-0.2.3.dist-info → data_designer-0.3.1.dist-info}/RECORD +35 -30
- data_designer/config/datastore.py +0 -187
- data_designer/engine/resources/seed_dataset_data_store.py +0 -84
- /data_designer/{config/utils → engine}/validation.py +0 -0
- {data_designer-0.2.3.dist-info → data_designer-0.3.1.dist-info}/WHEEL +0 -0
- {data_designer-0.2.3.dist-info → data_designer-0.3.1.dist-info}/entry_points.txt +0 -0
- {data_designer-0.2.3.dist-info → data_designer-0.3.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -20,7 +20,7 @@ from data_designer.config.models import (
|
|
|
20
20
|
ModelProvider,
|
|
21
21
|
)
|
|
22
22
|
from data_designer.config.preview_results import PreviewResults
|
|
23
|
-
from data_designer.config.
|
|
23
|
+
from data_designer.config.run_config import RunConfig
|
|
24
24
|
from data_designer.config.utils.constants import (
|
|
25
25
|
DEFAULT_NUM_RECORDS,
|
|
26
26
|
MANAGED_ASSETS_PATH,
|
|
@@ -29,21 +29,23 @@ from data_designer.config.utils.constants import (
|
|
|
29
29
|
PREDEFINED_PROVIDERS,
|
|
30
30
|
)
|
|
31
31
|
from data_designer.config.utils.info import InfoType, InterfaceInfo
|
|
32
|
-
from data_designer.config.utils.io_helpers import write_seed_dataset
|
|
33
32
|
from data_designer.engine.analysis.dataset_profiler import (
|
|
34
33
|
DataDesignerDatasetProfiler,
|
|
35
34
|
DatasetProfilerConfig,
|
|
36
35
|
)
|
|
36
|
+
from data_designer.engine.compiler import compile_data_designer_config
|
|
37
37
|
from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
|
|
38
38
|
from data_designer.engine.dataset_builders.column_wise_builder import ColumnWiseDatasetBuilder
|
|
39
39
|
from data_designer.engine.dataset_builders.utils.config_compiler import compile_dataset_builder_column_configs
|
|
40
40
|
from data_designer.engine.model_provider import resolve_model_provider_registry
|
|
41
|
-
from data_designer.engine.models.registry import create_model_registry
|
|
42
41
|
from data_designer.engine.resources.managed_storage import init_managed_blob_storage
|
|
43
|
-
from data_designer.engine.resources.resource_provider import ResourceProvider
|
|
44
|
-
from data_designer.engine.resources.
|
|
45
|
-
|
|
46
|
-
|
|
42
|
+
from data_designer.engine.resources.resource_provider import ResourceProvider, create_resource_provider
|
|
43
|
+
from data_designer.engine.resources.seed_reader import (
|
|
44
|
+
DataFrameSeedReader,
|
|
45
|
+
HuggingFaceSeedReader,
|
|
46
|
+
LocalFileSeedReader,
|
|
47
|
+
SeedReader,
|
|
48
|
+
SeedReaderRegistry,
|
|
47
49
|
)
|
|
48
50
|
from data_designer.engine.secret_resolver import (
|
|
49
51
|
CompositeResolver,
|
|
@@ -61,6 +63,14 @@ from data_designer.logging import RandomEmoji
|
|
|
61
63
|
|
|
62
64
|
DEFAULT_BUFFER_SIZE = 1000
|
|
63
65
|
|
|
66
|
+
DEFAULT_SECRET_RESOLVER = CompositeResolver([EnvironmentResolver(), PlaintextResolver()])
|
|
67
|
+
|
|
68
|
+
DEFAULT_SEED_READERS = [
|
|
69
|
+
HuggingFaceSeedReader(),
|
|
70
|
+
LocalFileSeedReader(),
|
|
71
|
+
DataFrameSeedReader(),
|
|
72
|
+
]
|
|
73
|
+
|
|
64
74
|
logger = logging.getLogger(__name__)
|
|
65
75
|
|
|
66
76
|
|
|
@@ -79,6 +89,7 @@ class DataDesigner(DataDesignerInterface[DatasetCreationResults]):
|
|
|
79
89
|
uses default providers.
|
|
80
90
|
secret_resolver: Resolver for handling secrets and credentials. Defaults to
|
|
81
91
|
a CompositeResolver that combines EnvironmentResolver (reads secrets from environment variables) and PlaintextResolver.
|
|
92
|
+
seed_readers: Optional list of seed readers. If None, uses default readers.
|
|
82
93
|
managed_assets_path: Path to the managed assets directory. This is used to point
|
|
83
94
|
to the location of managed datasets and other assets used during dataset generation.
|
|
84
95
|
If not provided, will check for an environment variable called DATA_DESIGNER_MANAGED_ASSETS_PATH.
|
|
@@ -92,52 +103,19 @@ class DataDesigner(DataDesignerInterface[DatasetCreationResults]):
|
|
|
92
103
|
*,
|
|
93
104
|
model_providers: list[ModelProvider] | None = None,
|
|
94
105
|
secret_resolver: SecretResolver | None = None,
|
|
106
|
+
seed_readers: list[SeedReader] | None = None,
|
|
95
107
|
managed_assets_path: Path | str | None = None,
|
|
96
108
|
):
|
|
97
|
-
self._secret_resolver = secret_resolver or
|
|
109
|
+
self._secret_resolver = secret_resolver or DEFAULT_SECRET_RESOLVER
|
|
98
110
|
self._artifact_path = Path(artifact_path) if artifact_path is not None else Path.cwd() / "artifacts"
|
|
99
111
|
self._buffer_size = DEFAULT_BUFFER_SIZE
|
|
112
|
+
self._run_config = RunConfig()
|
|
100
113
|
self._managed_assets_path = Path(managed_assets_path or MANAGED_ASSETS_PATH)
|
|
101
114
|
self._model_providers = self._resolve_model_providers(model_providers)
|
|
102
115
|
self._model_provider_registry = resolve_model_provider_registry(
|
|
103
116
|
self._model_providers, get_default_provider_name()
|
|
104
117
|
)
|
|
105
|
-
|
|
106
|
-
@staticmethod
|
|
107
|
-
def make_seed_reference_from_file(file_path: str | Path) -> LocalSeedDatasetReference:
|
|
108
|
-
"""Create a seed dataset reference from an existing file.
|
|
109
|
-
|
|
110
|
-
Supported file extensions: .parquet (recommended), .csv, .json, .jsonl
|
|
111
|
-
|
|
112
|
-
Args:
|
|
113
|
-
file_path: Path to an existing dataset file.
|
|
114
|
-
|
|
115
|
-
Returns:
|
|
116
|
-
A LocalSeedDatasetReference pointing to the specified file.
|
|
117
|
-
"""
|
|
118
|
-
return LocalSeedDatasetReference(dataset=str(file_path))
|
|
119
|
-
|
|
120
|
-
@classmethod
|
|
121
|
-
def make_seed_reference_from_dataframe(
|
|
122
|
-
cls, dataframe: pd.DataFrame, file_path: str | Path
|
|
123
|
-
) -> LocalSeedDatasetReference:
|
|
124
|
-
"""Create a seed dataset reference from a pandas DataFrame.
|
|
125
|
-
|
|
126
|
-
This method writes the DataFrame to disk and returns a reference that can
|
|
127
|
-
be passed to the config builder's `with_seed_dataset` method. If the file
|
|
128
|
-
already exists, it will be overwritten.
|
|
129
|
-
|
|
130
|
-
Supported file extensions: .parquet (recommended), .csv, .json, .jsonl
|
|
131
|
-
|
|
132
|
-
Args:
|
|
133
|
-
dataframe: Pandas DataFrame to use as seed data.
|
|
134
|
-
file_path: Path where to save dataset.
|
|
135
|
-
|
|
136
|
-
Returns:
|
|
137
|
-
A LocalSeedDatasetReference pointing to the written file.
|
|
138
|
-
"""
|
|
139
|
-
write_seed_dataset(dataframe, Path(file_path))
|
|
140
|
-
return cls.make_seed_reference_from_file(file_path)
|
|
118
|
+
self._seed_reader_registry = SeedReaderRegistry(readers=seed_readers or DEFAULT_SEED_READERS)
|
|
141
119
|
|
|
142
120
|
@property
|
|
143
121
|
def info(self) -> InterfaceInfo:
|
|
@@ -274,6 +252,23 @@ class DataDesigner(DataDesignerInterface[DatasetCreationResults]):
|
|
|
274
252
|
config_builder=config_builder,
|
|
275
253
|
)
|
|
276
254
|
|
|
255
|
+
def validate(self, config_builder: DataDesignerConfigBuilder) -> None:
|
|
256
|
+
"""Validate the Data Designer configuration as defined by the DataDesignerConfigBuilder
|
|
257
|
+
with the configured engine components (SecretResolver, SeedReaders, etc.).
|
|
258
|
+
|
|
259
|
+
Args:
|
|
260
|
+
config_builder: The DataDesignerConfigBuilder containing the dataset
|
|
261
|
+
configuration (columns, constraints, seed data, etc.).
|
|
262
|
+
|
|
263
|
+
Returns:
|
|
264
|
+
None if the configuration is valid.
|
|
265
|
+
|
|
266
|
+
Raises:
|
|
267
|
+
InvalidConfigError: If the configuration is invalid.
|
|
268
|
+
"""
|
|
269
|
+
resource_provider = self._create_resource_provider("validate-configuration", config_builder)
|
|
270
|
+
compile_data_designer_config(config_builder, resource_provider)
|
|
271
|
+
|
|
277
272
|
def get_default_model_configs(self) -> list[ModelConfig]:
|
|
278
273
|
"""Get the default model configurations.
|
|
279
274
|
|
|
@@ -318,6 +313,20 @@ class DataDesigner(DataDesignerInterface[DatasetCreationResults]):
|
|
|
318
313
|
raise InvalidBufferValueError("Buffer size must be greater than 0.")
|
|
319
314
|
self._buffer_size = buffer_size
|
|
320
315
|
|
|
316
|
+
def set_run_config(self, run_config: RunConfig) -> None:
|
|
317
|
+
"""Set the runtime configuration for dataset generation.
|
|
318
|
+
|
|
319
|
+
Args:
|
|
320
|
+
run_config: A RunConfig instance containing runtime settings such as
|
|
321
|
+
early shutdown behavior. Import RunConfig from data_designer.essentials.
|
|
322
|
+
|
|
323
|
+
Example:
|
|
324
|
+
>>> from data_designer.essentials import DataDesigner, RunConfig
|
|
325
|
+
>>> dd = DataDesigner()
|
|
326
|
+
>>> dd.set_run_config(RunConfig(disable_early_shutdown=True))
|
|
327
|
+
"""
|
|
328
|
+
self._run_config = run_config
|
|
329
|
+
|
|
321
330
|
def _resolve_model_providers(self, model_providers: list[ModelProvider] | None) -> list[ModelProvider]:
|
|
322
331
|
if model_providers is None:
|
|
323
332
|
model_providers = get_default_providers()
|
|
@@ -334,11 +343,15 @@ class DataDesigner(DataDesignerInterface[DatasetCreationResults]):
|
|
|
334
343
|
return model_providers or []
|
|
335
344
|
|
|
336
345
|
def _create_dataset_builder(
|
|
337
|
-
self,
|
|
346
|
+
self,
|
|
347
|
+
config_builder: DataDesignerConfigBuilder,
|
|
348
|
+
resource_provider: ResourceProvider,
|
|
338
349
|
) -> ColumnWiseDatasetBuilder:
|
|
350
|
+
config = compile_data_designer_config(config_builder, resource_provider)
|
|
351
|
+
|
|
339
352
|
return ColumnWiseDatasetBuilder(
|
|
340
|
-
column_configs=compile_dataset_builder_column_configs(
|
|
341
|
-
processor_configs=
|
|
353
|
+
column_configs=compile_dataset_builder_column_configs(config),
|
|
354
|
+
processor_configs=config.processors or [],
|
|
342
355
|
resource_provider=resource_provider,
|
|
343
356
|
)
|
|
344
357
|
|
|
@@ -356,24 +369,21 @@ class DataDesigner(DataDesignerInterface[DatasetCreationResults]):
|
|
|
356
369
|
def _create_resource_provider(
|
|
357
370
|
self, dataset_name: str, config_builder: DataDesignerConfigBuilder
|
|
358
371
|
) -> ResourceProvider:
|
|
359
|
-
model_configs = config_builder.model_configs
|
|
360
372
|
ArtifactStorage.mkdir_if_needed(self._artifact_path)
|
|
361
|
-
|
|
373
|
+
|
|
374
|
+
seed_dataset_source = None
|
|
375
|
+
if (seed_config := config_builder.get_seed_config()) is not None:
|
|
376
|
+
seed_dataset_source = seed_config.source
|
|
377
|
+
|
|
378
|
+
return create_resource_provider(
|
|
362
379
|
artifact_storage=ArtifactStorage(artifact_path=self._artifact_path, dataset_name=dataset_name),
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
secret_resolver=self._secret_resolver,
|
|
367
|
-
),
|
|
380
|
+
model_configs=config_builder.model_configs,
|
|
381
|
+
secret_resolver=self._secret_resolver,
|
|
382
|
+
model_provider_registry=self._model_provider_registry,
|
|
368
383
|
blob_storage=init_managed_blob_storage(str(self._managed_assets_path)),
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
else HfHubSeedDatasetDataStore(
|
|
373
|
-
endpoint=settings.endpoint,
|
|
374
|
-
token=settings.token,
|
|
375
|
-
)
|
|
376
|
-
),
|
|
384
|
+
seed_dataset_source=seed_dataset_source,
|
|
385
|
+
seed_reader_registry=self._seed_reader_registry,
|
|
386
|
+
run_config=self._run_config,
|
|
377
387
|
)
|
|
378
388
|
|
|
379
389
|
def _get_interface_info(self, model_providers: list[ModelProvider]) -> InterfaceInfo:
|
data_designer/plugin_manager.py
CHANGED
|
@@ -50,7 +50,7 @@ class PluginManager:
|
|
|
50
50
|
type_list = []
|
|
51
51
|
for plugin in self._plugin_registry.get_plugins(PluginType.COLUMN_GENERATOR):
|
|
52
52
|
if required_resources:
|
|
53
|
-
task_required_resources = plugin.
|
|
53
|
+
task_required_resources = plugin.impl_cls.metadata().required_resources or []
|
|
54
54
|
if not all(resource in task_required_resources for resource in required_resources):
|
|
55
55
|
continue
|
|
56
56
|
type_list.append(enum_type(plugin.name))
|
data_designer/plugins/errors.py
CHANGED
data_designer/plugins/plugin.py
CHANGED
|
@@ -1,14 +1,20 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import ast
|
|
7
|
+
import importlib
|
|
8
|
+
import importlib.util
|
|
4
9
|
from enum import Enum
|
|
10
|
+
from functools import cached_property
|
|
5
11
|
from typing import Literal, get_origin
|
|
6
12
|
|
|
7
|
-
from pydantic import BaseModel, model_validator
|
|
13
|
+
from pydantic import BaseModel, Field, field_validator, model_validator
|
|
8
14
|
from typing_extensions import Self
|
|
9
15
|
|
|
10
16
|
from data_designer.config.base import ConfigBase
|
|
11
|
-
from data_designer.
|
|
17
|
+
from data_designer.plugins.errors import PluginLoadError
|
|
12
18
|
|
|
13
19
|
|
|
14
20
|
class PluginType(str, Enum):
|
|
@@ -26,11 +32,42 @@ class PluginType(str, Enum):
|
|
|
26
32
|
return self.value.replace("-", " ")
|
|
27
33
|
|
|
28
34
|
|
|
35
|
+
def _get_module_and_object_names(fully_qualified_object: str) -> tuple[str, str]:
|
|
36
|
+
try:
|
|
37
|
+
module_name, object_name = fully_qualified_object.rsplit(".", 1)
|
|
38
|
+
except ValueError:
|
|
39
|
+
# If fully_qualified_object does not have any periods, the rsplit call will return
|
|
40
|
+
# a list of length 1 and the variable assignment above will raise ValueError
|
|
41
|
+
raise PluginLoadError("Expected a fully-qualified object name, e.g. 'my_plugin.config.MyConfig'")
|
|
42
|
+
|
|
43
|
+
return module_name, object_name
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _check_class_exists_in_file(filepath: str, class_name: str) -> None:
|
|
47
|
+
try:
|
|
48
|
+
with open(filepath, "r") as file:
|
|
49
|
+
source = file.read()
|
|
50
|
+
except FileNotFoundError:
|
|
51
|
+
raise PluginLoadError(f"Could not read source code at {filepath!r}")
|
|
52
|
+
|
|
53
|
+
tree = ast.parse(source)
|
|
54
|
+
for node in ast.walk(tree):
|
|
55
|
+
if isinstance(node, ast.ClassDef) and node.name == class_name:
|
|
56
|
+
return None
|
|
57
|
+
|
|
58
|
+
raise PluginLoadError(f"Could not find class named {class_name!r} in {filepath!r}")
|
|
59
|
+
|
|
60
|
+
|
|
29
61
|
class Plugin(BaseModel):
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
62
|
+
impl_qualified_name: str = Field(
|
|
63
|
+
...,
|
|
64
|
+
description="The fully-qualified name of the implementation class object, e.g. 'my_plugin.generator.MyColumnGenerator'",
|
|
65
|
+
)
|
|
66
|
+
config_qualified_name: str = Field(
|
|
67
|
+
..., description="The fully-qualified name of the config class object, e.g. 'my_plugin.config.MyConfig'"
|
|
68
|
+
)
|
|
69
|
+
plugin_type: PluginType = Field(..., description="The type of plugin")
|
|
70
|
+
emoji: str = Field(default="🔌", description="The emoji to use in logs related to the plugin")
|
|
34
71
|
|
|
35
72
|
@property
|
|
36
73
|
def config_type_as_class_name(self) -> str:
|
|
@@ -48,22 +85,55 @@ class Plugin(BaseModel):
|
|
|
48
85
|
def discriminator_field(self) -> str:
|
|
49
86
|
return self.plugin_type.discriminator_field
|
|
50
87
|
|
|
88
|
+
@field_validator("impl_qualified_name", "config_qualified_name", mode="after")
|
|
89
|
+
@classmethod
|
|
90
|
+
def validate_class_name(cls, value: str) -> str:
|
|
91
|
+
module_name, object_name = _get_module_and_object_names(value)
|
|
92
|
+
try:
|
|
93
|
+
spec = importlib.util.find_spec(module_name)
|
|
94
|
+
except:
|
|
95
|
+
raise PluginLoadError(f"Could not find module {module_name!r}")
|
|
96
|
+
|
|
97
|
+
if spec is None or spec.origin is None:
|
|
98
|
+
raise PluginLoadError(f"Error finding source for module {module_name!r}")
|
|
99
|
+
|
|
100
|
+
_check_class_exists_in_file(spec.origin, object_name)
|
|
101
|
+
|
|
102
|
+
return value
|
|
103
|
+
|
|
51
104
|
@model_validator(mode="after")
|
|
52
105
|
def validate_discriminator_field(self) -> Self:
|
|
53
|
-
cfg = self.
|
|
106
|
+
_, cfg = _get_module_and_object_names(self.config_qualified_name)
|
|
54
107
|
field = self.plugin_type.discriminator_field
|
|
55
108
|
if field not in self.config_cls.model_fields:
|
|
56
|
-
raise ValueError(f"Discriminator field
|
|
109
|
+
raise ValueError(f"Discriminator field {field!r} not found in config class {cfg!r}")
|
|
57
110
|
field_info = self.config_cls.model_fields[field]
|
|
58
111
|
if get_origin(field_info.annotation) is not Literal:
|
|
59
|
-
raise ValueError(f"Field
|
|
112
|
+
raise ValueError(f"Field {field!r} of {cfg!r} must be a Literal type, not {field_info.annotation!r}.")
|
|
60
113
|
if not isinstance(field_info.default, str):
|
|
61
|
-
raise ValueError(f"The default of
|
|
114
|
+
raise ValueError(f"The default of {field!r} must be a string, not {type(field_info.default)!r}.")
|
|
62
115
|
enum_key = field_info.default.replace("-", "_").upper()
|
|
63
116
|
if not enum_key.isidentifier():
|
|
64
117
|
raise ValueError(
|
|
65
|
-
f"The default value
|
|
66
|
-
f"cannot be converted to a valid enum key. The converted key
|
|
118
|
+
f"The default value {field_info.default!r} for discriminator field {field!r} "
|
|
119
|
+
f"cannot be converted to a valid enum key. The converted key {enum_key!r} "
|
|
67
120
|
f"must be a valid Python identifier."
|
|
68
121
|
)
|
|
69
122
|
return self
|
|
123
|
+
|
|
124
|
+
@cached_property
|
|
125
|
+
def config_cls(self) -> type[ConfigBase]:
|
|
126
|
+
return self._load(self.config_qualified_name)
|
|
127
|
+
|
|
128
|
+
@cached_property
|
|
129
|
+
def impl_cls(self) -> type:
|
|
130
|
+
return self._load(self.impl_qualified_name)
|
|
131
|
+
|
|
132
|
+
@staticmethod
|
|
133
|
+
def _load(fully_qualified_object: str) -> type:
|
|
134
|
+
module_name, object_name = _get_module_and_object_names(fully_qualified_object)
|
|
135
|
+
module = importlib.import_module(module_name)
|
|
136
|
+
try:
|
|
137
|
+
return getattr(module, object_name)
|
|
138
|
+
except AttributeError:
|
|
139
|
+
raise PluginLoadError(f"Could not find class {object_name!r} in module {module_name!r}")
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from typing import Literal
|
|
5
|
+
|
|
6
|
+
from data_designer.config.base import ConfigBase
|
|
7
|
+
from data_designer.config.column_configs import SingleColumnConfig
|
|
8
|
+
from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata
|
|
9
|
+
from data_designer.engine.resources.resource_provider import ResourceType
|
|
10
|
+
from data_designer.plugins.plugin import Plugin, PluginType
|
|
11
|
+
|
|
12
|
+
MODULE_NAME = __name__
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class ValidTestConfig(SingleColumnConfig):
|
|
16
|
+
"""Valid config for testing plugin creation."""
|
|
17
|
+
|
|
18
|
+
column_type: Literal["test-generator"] = "test-generator"
|
|
19
|
+
name: str
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class ValidTestTask(ConfigurableTask[ValidTestConfig]):
|
|
23
|
+
"""Valid task for testing plugin creation."""
|
|
24
|
+
|
|
25
|
+
@staticmethod
|
|
26
|
+
def metadata() -> ConfigurableTaskMetadata:
|
|
27
|
+
return ConfigurableTaskMetadata(
|
|
28
|
+
name="test_generator",
|
|
29
|
+
description="Test generator",
|
|
30
|
+
required_resources=None,
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class ConfigWithoutDiscriminator(ConfigBase):
|
|
35
|
+
some_field: str
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class ConfigWithStringField(ConfigBase):
|
|
39
|
+
column_type: str = "test-generator"
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class ConfigWithNonStringDefault(ConfigBase):
|
|
43
|
+
column_type: Literal["test-generator"] = 123 # type: ignore
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class ConfigWithInvalidKey(ConfigBase):
|
|
47
|
+
column_type: Literal["invalid-key-!@#"] = "invalid-key-!@#"
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class StubPluginConfigA(SingleColumnConfig):
|
|
51
|
+
column_type: Literal["test-plugin-a"] = "test-plugin-a"
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class StubPluginConfigB(SingleColumnConfig):
|
|
55
|
+
column_type: Literal["test-plugin-b"] = "test-plugin-b"
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class StubPluginTaskA(ConfigurableTask[StubPluginConfigA]):
|
|
59
|
+
@staticmethod
|
|
60
|
+
def metadata() -> ConfigurableTaskMetadata:
|
|
61
|
+
return ConfigurableTaskMetadata(
|
|
62
|
+
name="test_plugin_a",
|
|
63
|
+
description="Test plugin A",
|
|
64
|
+
required_resources=None,
|
|
65
|
+
)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class StubPluginTaskB(ConfigurableTask[StubPluginConfigB]):
|
|
69
|
+
@staticmethod
|
|
70
|
+
def metadata() -> ConfigurableTaskMetadata:
|
|
71
|
+
return ConfigurableTaskMetadata(
|
|
72
|
+
name="test_plugin_b",
|
|
73
|
+
description="Test plugin B",
|
|
74
|
+
required_resources=None,
|
|
75
|
+
)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
# Stub plugins requiring different combinations of resources
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class StubPluginConfigModels(SingleColumnConfig):
|
|
82
|
+
column_type: Literal["test-plugin-models"] = "test-plugin-models"
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class StubPluginConfigModelsAndBlobs(SingleColumnConfig):
|
|
86
|
+
column_type: Literal["test-plugin-models-and-blobs"] = "test-plugin-models-and-blobs"
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class StubPluginConfigBlobsAndSeeds(SingleColumnConfig):
|
|
90
|
+
column_type: Literal["test-plugin-blobs-and-seeds"] = "test-plugin-blobs-and-seeds"
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class StubPluginTaskModels(ConfigurableTask[StubPluginConfigModels]):
|
|
94
|
+
@staticmethod
|
|
95
|
+
def metadata() -> ConfigurableTaskMetadata:
|
|
96
|
+
return ConfigurableTaskMetadata(
|
|
97
|
+
name="test_plugin_models",
|
|
98
|
+
description="Test plugin requiring models",
|
|
99
|
+
required_resources=[ResourceType.MODEL_REGISTRY],
|
|
100
|
+
)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class StubPluginTaskModelsAndBlobs(ConfigurableTask[StubPluginConfigModelsAndBlobs]):
|
|
104
|
+
@staticmethod
|
|
105
|
+
def metadata() -> ConfigurableTaskMetadata:
|
|
106
|
+
return ConfigurableTaskMetadata(
|
|
107
|
+
name="test_plugin_models_and_blobs",
|
|
108
|
+
description="Test plugin requiring models and blobs",
|
|
109
|
+
required_resources=[ResourceType.MODEL_REGISTRY, ResourceType.BLOB_STORAGE],
|
|
110
|
+
)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class StubPluginTaskBlobsAndSeeds(ConfigurableTask[StubPluginConfigBlobsAndSeeds]):
|
|
114
|
+
@staticmethod
|
|
115
|
+
def metadata() -> ConfigurableTaskMetadata:
|
|
116
|
+
return ConfigurableTaskMetadata(
|
|
117
|
+
name="test_plugin_blobs_and_seeds",
|
|
118
|
+
description="Test plugin requiring blobs and seeds",
|
|
119
|
+
required_resources=[ResourceType.BLOB_STORAGE, ResourceType.SEED_READER],
|
|
120
|
+
)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
plugin_none = Plugin(
|
|
124
|
+
config_qualified_name=f"{MODULE_NAME}.StubPluginConfigA",
|
|
125
|
+
impl_qualified_name=f"{MODULE_NAME}.StubPluginTaskA",
|
|
126
|
+
plugin_type=PluginType.COLUMN_GENERATOR,
|
|
127
|
+
)
|
|
128
|
+
|
|
129
|
+
plugin_models = Plugin(
|
|
130
|
+
config_qualified_name=f"{MODULE_NAME}.StubPluginConfigModels",
|
|
131
|
+
impl_qualified_name=f"{MODULE_NAME}.StubPluginTaskModels",
|
|
132
|
+
plugin_type=PluginType.COLUMN_GENERATOR,
|
|
133
|
+
)
|
|
134
|
+
|
|
135
|
+
plugin_models_and_blobs = Plugin(
|
|
136
|
+
config_qualified_name=f"{MODULE_NAME}.StubPluginConfigModelsAndBlobs",
|
|
137
|
+
impl_qualified_name=f"{MODULE_NAME}.StubPluginTaskModelsAndBlobs",
|
|
138
|
+
plugin_type=PluginType.COLUMN_GENERATOR,
|
|
139
|
+
)
|
|
140
|
+
|
|
141
|
+
plugin_blobs_and_seeds = Plugin(
|
|
142
|
+
config_qualified_name=f"{MODULE_NAME}.StubPluginConfigBlobsAndSeeds",
|
|
143
|
+
impl_qualified_name=f"{MODULE_NAME}.StubPluginTaskBlobsAndSeeds",
|
|
144
|
+
plugin_type=PluginType.COLUMN_GENERATOR,
|
|
145
|
+
)
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from data_designer.config.base import ConfigBase
|
|
5
|
+
from data_designer.engine.configurable_task import ConfigurableTask
|
|
6
|
+
from data_designer.plugins.plugin import Plugin
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def assert_valid_plugin(plugin: Plugin) -> None:
|
|
10
|
+
assert issubclass(plugin.config_cls, ConfigBase), "Plugin config class is not a subclass of ConfigBase"
|
|
11
|
+
assert issubclass(plugin.impl_cls, ConfigurableTask), "Plugin impl class is not a subclass of ConfigurableTask"
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: data-designer
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.3.1
|
|
4
4
|
Summary: General framework for synthetic data generation
|
|
5
5
|
License-Expression: Apache-2.0
|
|
6
6
|
License-File: LICENSE
|
|
@@ -24,7 +24,7 @@ Requires-Dist: huggingface-hub<2,>=1.0.1
|
|
|
24
24
|
Requires-Dist: jinja2<4,>=3.1.6
|
|
25
25
|
Requires-Dist: json-repair<1,>=0.48.0
|
|
26
26
|
Requires-Dist: jsonpath-rust-bindings<2,>=1.0
|
|
27
|
-
Requires-Dist: litellm<
|
|
27
|
+
Requires-Dist: litellm<1.80.12,>=1.73.6
|
|
28
28
|
Requires-Dist: lxml<7,>=6.0.2
|
|
29
29
|
Requires-Dist: marko<3,>=2.1.2
|
|
30
30
|
Requires-Dist: networkx<4,>=3.0
|
|
@@ -181,7 +181,7 @@ ModelConfig(
|
|
|
181
181
|
alias="nv-reasoning",
|
|
182
182
|
model="openai/gpt-oss-20b",
|
|
183
183
|
provider="nvidia",
|
|
184
|
-
inference_parameters=
|
|
184
|
+
inference_parameters=ChatCompletionInferenceParams(
|
|
185
185
|
temperature=0.3,
|
|
186
186
|
top_p=0.9,
|
|
187
187
|
max_tokens=4096,
|