data-designer 0.2.2__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38):
  1. data_designer/_version.py +2 -2
  2. data_designer/cli/forms/model_builder.py +2 -2
  3. data_designer/config/config_builder.py +30 -113
  4. data_designer/config/errors.py +3 -0
  5. data_designer/config/exports.py +8 -6
  6. data_designer/config/models.py +7 -18
  7. data_designer/config/run_config.py +34 -0
  8. data_designer/config/seed.py +16 -46
  9. data_designer/config/seed_source.py +73 -0
  10. data_designer/config/utils/constants.py +27 -2
  11. data_designer/config/utils/io_helpers.py +0 -20
  12. data_designer/engine/column_generators/generators/seed_dataset.py +5 -5
  13. data_designer/engine/column_generators/generators/validation.py +3 -0
  14. data_designer/engine/column_generators/registry.py +1 -1
  15. data_designer/engine/compiler.py +69 -0
  16. data_designer/engine/dataset_builders/column_wise_builder.py +3 -0
  17. data_designer/engine/dataset_builders/utils/config_compiler.py +1 -1
  18. data_designer/engine/models/facade.py +2 -0
  19. data_designer/engine/models/litellm_overrides.py +2 -1
  20. data_designer/engine/processing/gsonschema/validators.py +55 -0
  21. data_designer/engine/resources/resource_provider.py +17 -5
  22. data_designer/engine/resources/seed_reader.py +149 -0
  23. data_designer/essentials/__init__.py +2 -0
  24. data_designer/interface/data_designer.py +72 -62
  25. data_designer/plugin_manager.py +1 -1
  26. data_designer/plugins/errors.py +3 -0
  27. data_designer/plugins/plugin.py +82 -12
  28. data_designer/plugins/testing/__init__.py +8 -0
  29. data_designer/plugins/testing/stubs.py +145 -0
  30. data_designer/plugins/testing/utils.py +11 -0
  31. {data_designer-0.2.2.dist-info → data_designer-0.3.0.dist-info}/METADATA +3 -3
  32. {data_designer-0.2.2.dist-info → data_designer-0.3.0.dist-info}/RECORD +36 -31
  33. data_designer/config/datastore.py +0 -187
  34. data_designer/engine/resources/seed_dataset_data_store.py +0 -84
  35. /data_designer/{config/utils → engine}/validation.py +0 -0
  36. {data_designer-0.2.2.dist-info → data_designer-0.3.0.dist-info}/WHEEL +0 -0
  37. {data_designer-0.2.2.dist-info → data_designer-0.3.0.dist-info}/entry_points.txt +0 -0
  38. {data_designer-0.2.2.dist-info → data_designer-0.3.0.dist-info}/licenses/LICENSE +0 -0
@@ -20,7 +20,7 @@ from data_designer.config.models import (
20
20
  ModelProvider,
21
21
  )
22
22
  from data_designer.config.preview_results import PreviewResults
23
- from data_designer.config.seed import LocalSeedDatasetReference
23
+ from data_designer.config.run_config import RunConfig
24
24
  from data_designer.config.utils.constants import (
25
25
  DEFAULT_NUM_RECORDS,
26
26
  MANAGED_ASSETS_PATH,
@@ -29,21 +29,23 @@ from data_designer.config.utils.constants import (
29
29
  PREDEFINED_PROVIDERS,
30
30
  )
31
31
  from data_designer.config.utils.info import InfoType, InterfaceInfo
32
- from data_designer.config.utils.io_helpers import write_seed_dataset
33
32
  from data_designer.engine.analysis.dataset_profiler import (
34
33
  DataDesignerDatasetProfiler,
35
34
  DatasetProfilerConfig,
36
35
  )
36
+ from data_designer.engine.compiler import compile_data_designer_config
37
37
  from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
38
38
  from data_designer.engine.dataset_builders.column_wise_builder import ColumnWiseDatasetBuilder
39
39
  from data_designer.engine.dataset_builders.utils.config_compiler import compile_dataset_builder_column_configs
40
40
  from data_designer.engine.model_provider import resolve_model_provider_registry
41
- from data_designer.engine.models.registry import create_model_registry
42
41
  from data_designer.engine.resources.managed_storage import init_managed_blob_storage
43
- from data_designer.engine.resources.resource_provider import ResourceProvider
44
- from data_designer.engine.resources.seed_dataset_data_store import (
45
- HfHubSeedDatasetDataStore,
46
- LocalSeedDatasetDataStore,
42
+ from data_designer.engine.resources.resource_provider import ResourceProvider, create_resource_provider
43
+ from data_designer.engine.resources.seed_reader import (
44
+ DataFrameSeedReader,
45
+ HuggingFaceSeedReader,
46
+ LocalFileSeedReader,
47
+ SeedReader,
48
+ SeedReaderRegistry,
47
49
  )
48
50
  from data_designer.engine.secret_resolver import (
49
51
  CompositeResolver,
@@ -61,6 +63,14 @@ from data_designer.logging import RandomEmoji
61
63
 
62
64
  DEFAULT_BUFFER_SIZE = 1000
63
65
 
66
+ DEFAULT_SECRET_RESOLVER = CompositeResolver([EnvironmentResolver(), PlaintextResolver()])
67
+
68
+ DEFAULT_SEED_READERS = [
69
+ HuggingFaceSeedReader(),
70
+ LocalFileSeedReader(),
71
+ DataFrameSeedReader(),
72
+ ]
73
+
64
74
  logger = logging.getLogger(__name__)
65
75
 
66
76
 
@@ -79,6 +89,7 @@ class DataDesigner(DataDesignerInterface[DatasetCreationResults]):
79
89
  uses default providers.
80
90
  secret_resolver: Resolver for handling secrets and credentials. Defaults to
81
91
  EnvironmentResolver which reads secrets from environment variables.
92
+ seed_readers: Optional list of seed readers. If None, uses default readers.
82
93
  managed_assets_path: Path to the managed assets directory. This is used to point
83
94
  to the location of managed datasets and other assets used during dataset generation.
84
95
  If not provided, will check for an environment variable called DATA_DESIGNER_MANAGED_ASSETS_PATH.
@@ -92,52 +103,19 @@ class DataDesigner(DataDesignerInterface[DatasetCreationResults]):
92
103
  *,
93
104
  model_providers: list[ModelProvider] | None = None,
94
105
  secret_resolver: SecretResolver | None = None,
106
+ seed_readers: list[SeedReader] | None = None,
95
107
  managed_assets_path: Path | str | None = None,
96
108
  ):
97
- self._secret_resolver = secret_resolver or CompositeResolver([EnvironmentResolver(), PlaintextResolver()])
109
+ self._secret_resolver = secret_resolver or DEFAULT_SECRET_RESOLVER
98
110
  self._artifact_path = Path(artifact_path) if artifact_path is not None else Path.cwd() / "artifacts"
99
111
  self._buffer_size = DEFAULT_BUFFER_SIZE
112
+ self._run_config = RunConfig()
100
113
  self._managed_assets_path = Path(managed_assets_path or MANAGED_ASSETS_PATH)
101
114
  self._model_providers = self._resolve_model_providers(model_providers)
102
115
  self._model_provider_registry = resolve_model_provider_registry(
103
116
  self._model_providers, get_default_provider_name()
104
117
  )
105
-
106
- @staticmethod
107
- def make_seed_reference_from_file(file_path: str | Path) -> LocalSeedDatasetReference:
108
- """Create a seed dataset reference from an existing file.
109
-
110
- Supported file extensions: .parquet (recommended), .csv, .json, .jsonl
111
-
112
- Args:
113
- file_path: Path to an existing dataset file.
114
-
115
- Returns:
116
- A LocalSeedDatasetReference pointing to the specified file.
117
- """
118
- return LocalSeedDatasetReference(dataset=str(file_path))
119
-
120
- @classmethod
121
- def make_seed_reference_from_dataframe(
122
- cls, dataframe: pd.DataFrame, file_path: str | Path
123
- ) -> LocalSeedDatasetReference:
124
- """Create a seed dataset reference from a pandas DataFrame.
125
-
126
- This method writes the DataFrame to disk and returns a reference that can
127
- be passed to the config builder's `with_seed_dataset` method. If the file
128
- already exists, it will be overwritten.
129
-
130
- Supported file extensions: .parquet (recommended), .csv, .json, .jsonl
131
-
132
- Args:
133
- dataframe: Pandas DataFrame to use as seed data.
134
- file_path: Path where to save dataset.
135
-
136
- Returns:
137
- A LocalSeedDatasetReference pointing to the written file.
138
- """
139
- write_seed_dataset(dataframe, Path(file_path))
140
- return cls.make_seed_reference_from_file(file_path)
118
+ self._seed_reader_registry = SeedReaderRegistry(readers=seed_readers or DEFAULT_SEED_READERS)
141
119
 
142
120
  @property
143
121
  def info(self) -> InterfaceInfo:
@@ -274,6 +252,23 @@ class DataDesigner(DataDesignerInterface[DatasetCreationResults]):
274
252
  config_builder=config_builder,
275
253
  )
276
254
 
255
+ def validate(self, config_builder: DataDesignerConfigBuilder) -> None:
256
+ """Validate the Data Designer configuration as defined by the DataDesignerConfigBuilder
257
+ with the configured engine components (SecretResolver, SeedReaders, etc.).
258
+
259
+ Args:
260
+ config_builder: The DataDesignerConfigBuilder containing the dataset
261
+ configuration (columns, constraints, seed data, etc.).
262
+
263
+ Returns:
264
+ None if the configuration is valid.
265
+
266
+ Raises:
267
+ InvalidConfigError: If the configuration is invalid.
268
+ """
269
+ resource_provider = self._create_resource_provider("validate-configuration", config_builder)
270
+ compile_data_designer_config(config_builder, resource_provider)
271
+
277
272
  def get_default_model_configs(self) -> list[ModelConfig]:
278
273
  """Get the default model configurations.
279
274
 
@@ -318,6 +313,20 @@ class DataDesigner(DataDesignerInterface[DatasetCreationResults]):
318
313
  raise InvalidBufferValueError("Buffer size must be greater than 0.")
319
314
  self._buffer_size = buffer_size
320
315
 
316
+ def set_run_config(self, run_config: RunConfig) -> None:
317
+ """Set the runtime configuration for dataset generation.
318
+
319
+ Args:
320
+ run_config: A RunConfig instance containing runtime settings such as
321
+ early shutdown behavior. Import RunConfig from data_designer.essentials.
322
+
323
+ Example:
324
+ >>> from data_designer.essentials import DataDesigner, RunConfig
325
+ >>> dd = DataDesigner()
326
+ >>> dd.set_run_config(RunConfig(disable_early_shutdown=True))
327
+ """
328
+ self._run_config = run_config
329
+
321
330
  def _resolve_model_providers(self, model_providers: list[ModelProvider] | None) -> list[ModelProvider]:
322
331
  if model_providers is None:
323
332
  model_providers = get_default_providers()
@@ -334,11 +343,15 @@ class DataDesigner(DataDesignerInterface[DatasetCreationResults]):
334
343
  return model_providers or []
335
344
 
336
345
  def _create_dataset_builder(
337
- self, config_builder: DataDesignerConfigBuilder, resource_provider: ResourceProvider
346
+ self,
347
+ config_builder: DataDesignerConfigBuilder,
348
+ resource_provider: ResourceProvider,
338
349
  ) -> ColumnWiseDatasetBuilder:
350
+ config = compile_data_designer_config(config_builder, resource_provider)
351
+
339
352
  return ColumnWiseDatasetBuilder(
340
- column_configs=compile_dataset_builder_column_configs(config_builder.build(raise_exceptions=True)),
341
- processor_configs=config_builder.get_processor_configs(),
353
+ column_configs=compile_dataset_builder_column_configs(config),
354
+ processor_configs=config.processors or [],
342
355
  resource_provider=resource_provider,
343
356
  )
344
357
 
@@ -356,24 +369,21 @@ class DataDesigner(DataDesignerInterface[DatasetCreationResults]):
356
369
  def _create_resource_provider(
357
370
  self, dataset_name: str, config_builder: DataDesignerConfigBuilder
358
371
  ) -> ResourceProvider:
359
- model_configs = config_builder.model_configs
360
372
  ArtifactStorage.mkdir_if_needed(self._artifact_path)
361
- return ResourceProvider(
373
+
374
+ seed_dataset_source = None
375
+ if (seed_config := config_builder.get_seed_config()) is not None:
376
+ seed_dataset_source = seed_config.source
377
+
378
+ return create_resource_provider(
362
379
  artifact_storage=ArtifactStorage(artifact_path=self._artifact_path, dataset_name=dataset_name),
363
- model_registry=create_model_registry(
364
- model_configs=model_configs,
365
- model_provider_registry=self._model_provider_registry,
366
- secret_resolver=self._secret_resolver,
367
- ),
380
+ model_configs=config_builder.model_configs,
381
+ secret_resolver=self._secret_resolver,
382
+ model_provider_registry=self._model_provider_registry,
368
383
  blob_storage=init_managed_blob_storage(str(self._managed_assets_path)),
369
- datastore=(
370
- LocalSeedDatasetDataStore()
371
- if (settings := config_builder.get_seed_datastore_settings()) is None
372
- else HfHubSeedDatasetDataStore(
373
- endpoint=settings.endpoint,
374
- token=settings.token,
375
- )
376
- ),
384
+ seed_dataset_source=seed_dataset_source,
385
+ seed_reader_registry=self._seed_reader_registry,
386
+ run_config=self._run_config,
377
387
  )
378
388
 
379
389
  def _get_interface_info(self, model_providers: list[ModelProvider]) -> InterfaceInfo:
@@ -50,7 +50,7 @@ class PluginManager:
50
50
  type_list = []
51
51
  for plugin in self._plugin_registry.get_plugins(PluginType.COLUMN_GENERATOR):
52
52
  if required_resources:
53
- task_required_resources = plugin.task_cls.metadata().required_resources or []
53
+ task_required_resources = plugin.impl_cls.metadata().required_resources or []
54
54
  if not all(resource in task_required_resources for resource in required_resources):
55
55
  continue
56
56
  type_list.append(enum_type(plugin.name))
@@ -4,6 +4,9 @@
4
4
  from data_designer.errors import DataDesignerError
5
5
 
6
6
 
7
+ class PluginLoadError(DataDesignerError): ...
8
+
9
+
7
10
  class PluginRegistrationError(DataDesignerError): ...
8
11
 
9
12
 
@@ -1,14 +1,20 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ from __future__ import annotations
5
+
6
+ import ast
7
+ import importlib
8
+ import importlib.util
4
9
  from enum import Enum
10
+ from functools import cached_property
5
11
  from typing import Literal, get_origin
6
12
 
7
- from pydantic import BaseModel, model_validator
13
+ from pydantic import BaseModel, Field, field_validator, model_validator
8
14
  from typing_extensions import Self
9
15
 
10
16
  from data_designer.config.base import ConfigBase
11
- from data_designer.engine.configurable_task import ConfigurableTask
17
+ from data_designer.plugins.errors import PluginLoadError
12
18
 
13
19
 
14
20
  class PluginType(str, Enum):
@@ -26,11 +32,42 @@ class PluginType(str, Enum):
26
32
  return self.value.replace("-", " ")
27
33
 
28
34
 
35
+ def _get_module_and_object_names(fully_qualified_object: str) -> tuple[str, str]:
36
+ try:
37
+ module_name, object_name = fully_qualified_object.rsplit(".", 1)
38
+ except ValueError:
39
+ # If fully_qualified_object does not have any periods, the rsplit call will return
40
+ # a list of length 1 and the variable assignment above will raise ValueError
41
+ raise PluginLoadError("Expected a fully-qualified object name, e.g. 'my_plugin.config.MyConfig'")
42
+
43
+ return module_name, object_name
44
+
45
+
46
+ def _check_class_exists_in_file(filepath: str, class_name: str) -> None:
47
+ try:
48
+ with open(filepath, "r") as file:
49
+ source = file.read()
50
+ except FileNotFoundError:
51
+ raise PluginLoadError(f"Could not read source code at {filepath!r}")
52
+
53
+ tree = ast.parse(source)
54
+ for node in ast.walk(tree):
55
+ if isinstance(node, ast.ClassDef) and node.name == class_name:
56
+ return None
57
+
58
+ raise PluginLoadError(f"Could not find class named {class_name!r} in {filepath!r}")
59
+
60
+
29
61
  class Plugin(BaseModel):
30
- task_cls: type[ConfigurableTask]
31
- config_cls: type[ConfigBase]
32
- plugin_type: PluginType
33
- emoji: str = "🔌"
62
+ impl_qualified_name: str = Field(
63
+ ...,
64
+ description="The fully-qualified name of the implementation class object, e.g. 'my_plugin.generator.MyColumnGenerator'",
65
+ )
66
+ config_qualified_name: str = Field(
67
+ ..., description="The fully-qualified name o the config class object, e.g. 'my_plugin.config.MyConfig'"
68
+ )
69
+ plugin_type: PluginType = Field(..., description="The type of plugin")
70
+ emoji: str = Field(default="🔌", description="The emoji to use in logs related to the plugin")
34
71
 
35
72
  @property
36
73
  def config_type_as_class_name(self) -> str:
@@ -48,22 +85,55 @@ class Plugin(BaseModel):
48
85
  def discriminator_field(self) -> str:
49
86
  return self.plugin_type.discriminator_field
50
87
 
88
+ @field_validator("impl_qualified_name", "config_qualified_name", mode="after")
89
+ @classmethod
90
+ def validate_class_name(cls, value: str) -> str:
91
+ module_name, object_name = _get_module_and_object_names(value)
92
+ try:
93
+ spec = importlib.util.find_spec(module_name)
94
+ except:
95
+ raise PluginLoadError(f"Could not find module {module_name!r}")
96
+
97
+ if spec is None or spec.origin is None:
98
+ raise PluginLoadError(f"Error finding source for module {module_name!r}")
99
+
100
+ _check_class_exists_in_file(spec.origin, object_name)
101
+
102
+ return value
103
+
51
104
  @model_validator(mode="after")
52
105
  def validate_discriminator_field(self) -> Self:
53
- cfg = self.config_cls.__name__
106
+ _, cfg = _get_module_and_object_names(self.config_qualified_name)
54
107
  field = self.plugin_type.discriminator_field
55
108
  if field not in self.config_cls.model_fields:
56
- raise ValueError(f"Discriminator field '{field}' not found in config class {cfg}")
109
+ raise ValueError(f"Discriminator field {field!r} not found in config class {cfg!r}")
57
110
  field_info = self.config_cls.model_fields[field]
58
111
  if get_origin(field_info.annotation) is not Literal:
59
- raise ValueError(f"Field '{field}' of {cfg} must be a Literal type, not {field_info.annotation}.")
112
+ raise ValueError(f"Field {field!r} of {cfg!r} must be a Literal type, not {field_info.annotation!r}.")
60
113
  if not isinstance(field_info.default, str):
61
- raise ValueError(f"The default of '{field}' must be a string, not {type(field_info.default)}.")
114
+ raise ValueError(f"The default of {field!r} must be a string, not {type(field_info.default)!r}.")
62
115
  enum_key = field_info.default.replace("-", "_").upper()
63
116
  if not enum_key.isidentifier():
64
117
  raise ValueError(
65
- f"The default value '{field_info.default}' for discriminator field '{field}' "
66
- f"cannot be converted to a valid enum key. The converted key '{enum_key}' "
118
+ f"The default value {field_info.default!r} for discriminator field {field!r} "
119
+ f"cannot be converted to a valid enum key. The converted key {enum_key!r} "
67
120
  f"must be a valid Python identifier."
68
121
  )
69
122
  return self
123
+
124
+ @cached_property
125
+ def config_cls(self) -> type[ConfigBase]:
126
+ return self._load(self.config_qualified_name)
127
+
128
+ @cached_property
129
+ def impl_cls(self) -> type:
130
+ return self._load(self.impl_qualified_name)
131
+
132
+ @staticmethod
133
+ def _load(fully_qualified_object: str) -> type:
134
+ module_name, object_name = _get_module_and_object_names(fully_qualified_object)
135
+ module = importlib.import_module(module_name)
136
+ try:
137
+ return getattr(module, object_name)
138
+ except AttributeError:
139
+ raise PluginLoadError(f"Could not find class {object_name!r} in module {module_name!r}")
@@ -0,0 +1,8 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from data_designer.plugins.testing.utils import assert_valid_plugin
5
+
6
+ __all__ = [
7
+ assert_valid_plugin.__name__,
8
+ ]
@@ -0,0 +1,145 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Literal
5
+
6
+ from data_designer.config.base import ConfigBase
7
+ from data_designer.config.column_configs import SingleColumnConfig
8
+ from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata
9
+ from data_designer.engine.resources.resource_provider import ResourceType
10
+ from data_designer.plugins.plugin import Plugin, PluginType
11
+
12
+ MODULE_NAME = __name__
13
+
14
+
15
+ class ValidTestConfig(SingleColumnConfig):
16
+ """Valid config for testing plugin creation."""
17
+
18
+ column_type: Literal["test-generator"] = "test-generator"
19
+ name: str
20
+
21
+
22
+ class ValidTestTask(ConfigurableTask[ValidTestConfig]):
23
+ """Valid task for testing plugin creation."""
24
+
25
+ @staticmethod
26
+ def metadata() -> ConfigurableTaskMetadata:
27
+ return ConfigurableTaskMetadata(
28
+ name="test_generator",
29
+ description="Test generator",
30
+ required_resources=None,
31
+ )
32
+
33
+
34
+ class ConfigWithoutDiscriminator(ConfigBase):
35
+ some_field: str
36
+
37
+
38
+ class ConfigWithStringField(ConfigBase):
39
+ column_type: str = "test-generator"
40
+
41
+
42
+ class ConfigWithNonStringDefault(ConfigBase):
43
+ column_type: Literal["test-generator"] = 123 # type: ignore
44
+
45
+
46
+ class ConfigWithInvalidKey(ConfigBase):
47
+ column_type: Literal["invalid-key-!@#"] = "invalid-key-!@#"
48
+
49
+
50
+ class StubPluginConfigA(SingleColumnConfig):
51
+ column_type: Literal["test-plugin-a"] = "test-plugin-a"
52
+
53
+
54
+ class StubPluginConfigB(SingleColumnConfig):
55
+ column_type: Literal["test-plugin-b"] = "test-plugin-b"
56
+
57
+
58
+ class StubPluginTaskA(ConfigurableTask[StubPluginConfigA]):
59
+ @staticmethod
60
+ def metadata() -> ConfigurableTaskMetadata:
61
+ return ConfigurableTaskMetadata(
62
+ name="test_plugin_a",
63
+ description="Test plugin A",
64
+ required_resources=None,
65
+ )
66
+
67
+
68
+ class StubPluginTaskB(ConfigurableTask[StubPluginConfigB]):
69
+ @staticmethod
70
+ def metadata() -> ConfigurableTaskMetadata:
71
+ return ConfigurableTaskMetadata(
72
+ name="test_plugin_b",
73
+ description="Test plugin B",
74
+ required_resources=None,
75
+ )
76
+
77
+
78
+ # Stub plugins requiring different combinations of resources
79
+
80
+
81
+ class StubPluginConfigModels(SingleColumnConfig):
82
+ column_type: Literal["test-plugin-models"] = "test-plugin-models"
83
+
84
+
85
+ class StubPluginConfigModelsAndBlobs(SingleColumnConfig):
86
+ column_type: Literal["test-plugin-models-and-blobs"] = "test-plugin-models-and-blobs"
87
+
88
+
89
+ class StubPluginConfigBlobsAndSeeds(SingleColumnConfig):
90
+ column_type: Literal["test-plugin-blobs-and-seeds"] = "test-plugin-blobs-and-seeds"
91
+
92
+
93
+ class StubPluginTaskModels(ConfigurableTask[StubPluginConfigModels]):
94
+ @staticmethod
95
+ def metadata() -> ConfigurableTaskMetadata:
96
+ return ConfigurableTaskMetadata(
97
+ name="test_plugin_models",
98
+ description="Test plugin requiring models",
99
+ required_resources=[ResourceType.MODEL_REGISTRY],
100
+ )
101
+
102
+
103
+ class StubPluginTaskModelsAndBlobs(ConfigurableTask[StubPluginConfigModelsAndBlobs]):
104
+ @staticmethod
105
+ def metadata() -> ConfigurableTaskMetadata:
106
+ return ConfigurableTaskMetadata(
107
+ name="test_plugin_models_and_blobs",
108
+ description="Test plugin requiring models and blobs",
109
+ required_resources=[ResourceType.MODEL_REGISTRY, ResourceType.BLOB_STORAGE],
110
+ )
111
+
112
+
113
+ class StubPluginTaskBlobsAndSeeds(ConfigurableTask[StubPluginConfigBlobsAndSeeds]):
114
+ @staticmethod
115
+ def metadata() -> ConfigurableTaskMetadata:
116
+ return ConfigurableTaskMetadata(
117
+ name="test_plugin_blobs_and_seeds",
118
+ description="Test plugin requiring blobs and seeds",
119
+ required_resources=[ResourceType.BLOB_STORAGE, ResourceType.SEED_READER],
120
+ )
121
+
122
+
123
+ plugin_none = Plugin(
124
+ config_qualified_name=f"{MODULE_NAME}.StubPluginConfigA",
125
+ impl_qualified_name=f"{MODULE_NAME}.StubPluginTaskA",
126
+ plugin_type=PluginType.COLUMN_GENERATOR,
127
+ )
128
+
129
+ plugin_models = Plugin(
130
+ config_qualified_name=f"{MODULE_NAME}.StubPluginConfigModels",
131
+ impl_qualified_name=f"{MODULE_NAME}.StubPluginTaskModels",
132
+ plugin_type=PluginType.COLUMN_GENERATOR,
133
+ )
134
+
135
+ plugin_models_and_blobs = Plugin(
136
+ config_qualified_name=f"{MODULE_NAME}.StubPluginConfigModelsAndBlobs",
137
+ impl_qualified_name=f"{MODULE_NAME}.StubPluginTaskModelsAndBlobs",
138
+ plugin_type=PluginType.COLUMN_GENERATOR,
139
+ )
140
+
141
+ plugin_blobs_and_seeds = Plugin(
142
+ config_qualified_name=f"{MODULE_NAME}.StubPluginConfigBlobsAndSeeds",
143
+ impl_qualified_name=f"{MODULE_NAME}.StubPluginTaskBlobsAndSeeds",
144
+ plugin_type=PluginType.COLUMN_GENERATOR,
145
+ )
@@ -0,0 +1,11 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from data_designer.config.base import ConfigBase
5
+ from data_designer.engine.configurable_task import ConfigurableTask
6
+ from data_designer.plugins.plugin import Plugin
7
+
8
+
9
+ def assert_valid_plugin(plugin: Plugin) -> None:
10
+ assert issubclass(plugin.config_cls, ConfigBase), "Plugin config class is not a subclass of ConfigBase"
11
+ assert issubclass(plugin.impl_cls, ConfigurableTask), "Plugin impl class is not a subclass of ConfigurableTask"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: data-designer
3
- Version: 0.2.2
3
+ Version: 0.3.0
4
4
  Summary: General framework for synthetic data generation
5
5
  License-Expression: Apache-2.0
6
6
  License-File: LICENSE
@@ -24,7 +24,7 @@ Requires-Dist: huggingface-hub<2,>=1.0.1
24
24
  Requires-Dist: jinja2<4,>=3.1.6
25
25
  Requires-Dist: json-repair<1,>=0.48.0
26
26
  Requires-Dist: jsonpath-rust-bindings<2,>=1.0
27
- Requires-Dist: litellm<2,>=1.73.6
27
+ Requires-Dist: litellm<1.80.12,>=1.73.6
28
28
  Requires-Dist: lxml<7,>=6.0.2
29
29
  Requires-Dist: marko<3,>=2.1.2
30
30
  Requires-Dist: networkx<4,>=3.0
@@ -181,7 +181,7 @@ ModelConfig(
181
181
  alias="nv-reasoning",
182
182
  model="openai/gpt-oss-20b",
183
183
  provider="nvidia",
184
- inference_parameters=InferenceParameters(
184
+ inference_parameters=ChatCompletionInferenceParams(
185
185
  temperature=0.3,
186
186
  top_p=0.9,
187
187
  max_tokens=4096,