data-designer-config 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/config/__init__.py +149 -0
- data_designer/config/_version.py +34 -0
- data_designer/config/analysis/__init__.py +2 -0
- data_designer/config/analysis/column_profilers.py +159 -0
- data_designer/config/analysis/column_statistics.py +421 -0
- data_designer/config/analysis/dataset_profiler.py +84 -0
- data_designer/config/analysis/utils/errors.py +10 -0
- data_designer/config/analysis/utils/reporting.py +192 -0
- data_designer/config/base.py +69 -0
- data_designer/config/column_configs.py +476 -0
- data_designer/config/column_types.py +141 -0
- data_designer/config/config_builder.py +595 -0
- data_designer/config/data_designer_config.py +40 -0
- data_designer/config/dataset_builders.py +13 -0
- data_designer/config/dataset_metadata.py +18 -0
- data_designer/config/default_model_settings.py +129 -0
- data_designer/config/errors.py +24 -0
- data_designer/config/interface.py +55 -0
- data_designer/config/models.py +486 -0
- data_designer/config/preview_results.py +41 -0
- data_designer/config/processors.py +148 -0
- data_designer/config/run_config.py +56 -0
- data_designer/config/sampler_constraints.py +52 -0
- data_designer/config/sampler_params.py +639 -0
- data_designer/config/seed.py +116 -0
- data_designer/config/seed_source.py +84 -0
- data_designer/config/seed_source_types.py +19 -0
- data_designer/config/testing/__init__.py +6 -0
- data_designer/config/testing/fixtures.py +308 -0
- data_designer/config/utils/code_lang.py +93 -0
- data_designer/config/utils/constants.py +365 -0
- data_designer/config/utils/errors.py +21 -0
- data_designer/config/utils/info.py +94 -0
- data_designer/config/utils/io_helpers.py +258 -0
- data_designer/config/utils/misc.py +78 -0
- data_designer/config/utils/numerical_helpers.py +30 -0
- data_designer/config/utils/type_helpers.py +106 -0
- data_designer/config/utils/visualization.py +482 -0
- data_designer/config/validator_params.py +94 -0
- data_designer/errors.py +7 -0
- data_designer/lazy_heavy_imports.py +56 -0
- data_designer/logging.py +180 -0
- data_designer/plugin_manager.py +78 -0
- data_designer/plugins/__init__.py +8 -0
- data_designer/plugins/errors.py +15 -0
- data_designer/plugins/plugin.py +141 -0
- data_designer/plugins/registry.py +88 -0
- data_designer_config-0.4.0.dist-info/METADATA +75 -0
- data_designer_config-0.4.0.dist-info/RECORD +50 -0
- data_designer_config-0.4.0.dist-info/WHEEL +4 -0
data_designer/logging.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
import random
|
|
8
|
+
import sys
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from typing import TextIO
|
|
12
|
+
|
|
13
|
+
from pythonjsonlogger import jsonlogger
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
class LoggerConfig:
    """Desired level for a single named logger.

    Attributes:
        name: Logger name passed to ``logging.getLogger`` (e.g. ``"data_designer"``).
        level: Level name handed to ``Logger.setLevel`` (e.g. ``"INFO"``, ``"DEBUG"``).
    """

    # Identifies which logger in the ``logging`` hierarchy to configure.
    name: str
    # String level name; ``Logger.setLevel`` accepts names as well as ints.
    level: str
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
|
|
23
|
+
class OutputConfig:
|
|
24
|
+
destination: TextIO | Path
|
|
25
|
+
structured: bool
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass
class LoggingConfig:
    """Complete logging setup consumed by ``configure_logging``.

    Attributes:
        logger_configs: Per-logger level overrides applied after the root level.
        output_configs: One handler is created per entry and attached to the
            root logger.
        root_level: Level applied to the root logger.
        to_silence: Names of loggers that are quieted to WARNING with their
            own handlers removed.
    """

    logger_configs: list[LoggerConfig]
    output_configs: list[OutputConfig]
    root_level: str = "INFO"
    # Copy the module-level default so that mutating one config's list can
    # never leak into the shared default or into other LoggingConfig instances.
    to_silence: list[str] = field(default_factory=lambda: list(_DEFAULT_NOISY_LOGGERS))

    @classmethod
    def default(cls) -> LoggingConfig:
        """Return the standard setup: ``data_designer`` at INFO, plain text to stderr."""
        return cls(
            logger_configs=[LoggerConfig(name="data_designer", level="INFO")],
            output_configs=[OutputConfig(destination=sys.stderr, structured=False)],
        )

    @classmethod
    def debug(cls) -> LoggingConfig:
        """Return the verbose setup: ``data_designer`` at DEBUG, plain text to stderr."""
        return cls(
            logger_configs=[LoggerConfig(name="data_designer", level="DEBUG")],
            output_configs=[OutputConfig(destination=sys.stderr, structured=False)],
        )
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class RandomEmoji:
    """A generator for various themed emoji collections.

    Each themed static method returns one emoji chosen at random from a fixed
    list. An instance also picks one progress style at construction time, so
    successive ``progress`` calls on the same instance draw from the same
    sequence.
    """

    def __init__(self) -> None:
        self._progress_style = random.choice(_PROGRESS_STYLES)

    def progress(self, percent: float) -> str:
        """Get a progress emoji based on completion percentage (0-100).

        The range is split into 25-point phases. Values outside 0-100 are
        clamped to the first or last emoji of the chosen style.

        Args:
            percent: Completion percentage.

        Returns:
            The emoji for the phase ``percent`` falls into.
        """
        last_idx = len(self._progress_style) - 1
        # Clamp below as well as above: a percentage <= -25 would otherwise
        # produce a negative index and wrap around to the "complete" emoji.
        phase_idx = max(0, min(int(percent / 25), last_idx))
        return self._progress_style[phase_idx]

    @staticmethod
    def cooking() -> str:
        """Get a random cooking or food preparation emoji."""
        return random.choice(["👨🍳", "👩🍳", "🍳", "🥘", "🍲", "🔪", "🥄", "🍴", "⏲️", "🥗"])

    @staticmethod
    def data() -> str:
        """Get a random data or analytics emoji."""
        return random.choice(["📊", "📈", "📉", "💾", "💿", "📀", "🗄️", "📁", "📂", "🗃️"])

    @staticmethod
    def generating() -> str:
        """Get a random generating or creating emoji."""
        return random.choice(["🏭", "⚙️", "🔨", "🛠️", "🏗️", "🎨", "✍️", "📝", "🔧", "⚒️"])

    @staticmethod
    def loading() -> str:
        """Get a random loading or waiting emoji."""
        return random.choice(["⏳", "⌛", "🔄", "♻️", "🔃", "⏰", "⏱️", "⏲️", "📡", "🌀"])

    @staticmethod
    def magic() -> str:
        """Get a random magical or special effect emoji."""
        return random.choice(["✨", "⭐", "🌟", "💫", "🪄", "🔮", "🎩", "🌈", "💎", "🦄"])

    @staticmethod
    def previewing() -> str:
        """Get a random previewing or looking ahead emoji."""
        return random.choice(["👀", "📺", "🔁", "👁️", "🔭", "🕵️", "🧐", "📸", "🎥", "🖼️"])

    @staticmethod
    def speed() -> str:
        """Get a random speed or fast emoji."""
        return random.choice(["⚡", "💨", "🏃", "🏎️", "🚄", "✈️", "💥", "⏩", "🏃♂️", "🏃♀️"])

    @staticmethod
    def start() -> str:
        """Get a random emoji representing starting or launching something."""
        return random.choice(["🚀", "▶️", "🎬", "🌅", "🏁", "🎯", "🚦", "🔔", "📣", "🎺"])

    @staticmethod
    def success() -> str:
        """Get a random success or celebration emoji."""
        return random.choice(["🎉", "🎊", "👏", "🙌", "🎆", "🍾", "☀️", "🏆", "✅", "🥳"])

    @staticmethod
    def thinking() -> str:
        """Get a random thinking or processing emoji."""
        return random.choice(["🤔", "💭", "🧠", "💡", "🔍", "🔎", "🤨", "🧐", "📝", "🧮"])

    @staticmethod
    def working() -> str:
        """Get a random working or in-progress emoji."""
        return random.choice(["⚙️", "🔧", "🔨", "⚒️", "🛠️", "💼", "👷", "🏗️", "🪛", "👨💻"])
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def configure_logging(config: LoggingConfig | None = None) -> None:
    """Install handlers and levels on the standard-library logging tree.

    All handlers currently on the root logger are dropped and replaced by one
    handler per ``config.output_configs`` entry. Then the root level and each
    listed logger's level are applied. Finally, every logger named in
    ``config.to_silence`` is passed to ``quiet_noisy_logger``.

    Args:
        config: Logging setup to apply. When ``None`` (or otherwise falsy),
            ``LoggingConfig.default()`` is used.
    """
    if not config:
        config = LoggingConfig.default()

    root = logging.getLogger()

    # Start from a clean slate on the root logger.
    # NOTE(review): the cleared handlers are not closed, so a FileHandler left
    # over from a previous call keeps its file open.
    root.handlers.clear()

    # Build every handler before attaching any of them.
    new_handlers = list(map(_create_handler, config.output_configs))
    for new_handler in new_handlers:
        root.addHandler(new_handler)

    root.setLevel(config.root_level)
    for logger_config in config.logger_configs:
        logging.getLogger(logger_config.name).setLevel(logger_config.level)

    for noisy_name in config.to_silence:
        quiet_noisy_logger(noisy_name)
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def quiet_noisy_logger(name: str) -> None:
    """Quiet a chatty logger by raising its level to WARNING.

    Handlers attached directly to the logger are removed, and its level is set
    to ``logging.WARNING``, so records below WARNING are dropped at that logger.

    Args:
        name: Name of the logger to quiet (e.g. ``"httpx"``).
    """
    target = logging.getLogger(name)
    del target.handlers[:]
    target.setLevel(logging.WARNING)
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def _create_handler(output_config: OutputConfig) -> logging.Handler:
    """Build a formatted handler for one output destination.

    Args:
        output_config: A ``Path`` destination yields a ``FileHandler`` writing to
            that file. Any other destination is treated as a text stream and is
            wrapped in a ``StreamHandler``. ``structured`` selects the JSON
            formatter instead of the plain-text one.

    Returns:
        The configured handler, ready to be attached to a logger.
    """
    destination = output_config.destination
    if isinstance(destination, Path):
        handler: logging.Handler = logging.FileHandler(str(destination))
    else:
        # Pass the stream explicitly: a bare StreamHandler() always writes to
        # sys.stderr and would silently ignore e.g. destination=sys.stdout.
        handler = logging.StreamHandler(destination)

    formatter = _make_json_formatter() if output_config.structured else _make_stream_formatter()
    handler.setFormatter(formatter)
    return handler
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def _make_json_formatter() -> logging.Formatter:
    """Return a python-json-logger formatter for structured output.

    The %-style format string selects the fields carried by each JSON record:
    timestamp, level name, logger name, and message.
    """
    return jsonlogger.JsonFormatter("%(asctime)s %(levelname)s %(name)s %(message)s")
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def _make_stream_formatter() -> logging.Formatter:
|
|
168
|
+
log_format = "[%(asctime)s] [%(levelname)s] %(message)s"
|
|
169
|
+
time_format = "%H:%M:%S"
|
|
170
|
+
return logging.Formatter(log_format, time_format)
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
# Third-party loggers that LoggingConfig silences by default via its
# ``to_silence`` field; configure_logging passes each to quiet_noisy_logger.
_DEFAULT_NOISY_LOGGERS = ["httpx", "matplotlib"]
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
# Five-step emoji sequences for RandomEmoji.progress. Each RandomEmoji instance
# picks one sequence at random and indexes it by 25-point completion phases.
_PROGRESS_STYLES: list[list[str]] = [
    ["🌑", "🌘", "🌗", "🌖", "🌕"],  # Moon phases
    ["🌧️", "🌦️", "⛅", "🌤️", "☀️"],  # Weather (storm to sun)
    ["🥚", "🐣", "🐥", "🐤", "🐔"],  # Hatching (egg to chicken)
]
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from enum import Enum
|
|
7
|
+
from typing import TYPE_CHECKING, TypeAlias
|
|
8
|
+
|
|
9
|
+
from data_designer.plugins.plugin import PluginType
|
|
10
|
+
from data_designer.plugins.registry import PluginRegistry
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from data_designer.plugins.plugin import Plugin
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class PluginManager:
    """Config-layer facade over the singleton ``PluginRegistry``.

    Provides the lookups the config package needs:
    - listing column generator plugins;
    - resolving one of them by name;
    - extending the column-config and seed-source type unions with
      plugin-provided config classes.
    """

    def __init__(self):
        self._plugin_registry = PluginRegistry()

    def get_column_generator_plugins(self) -> list[Plugin]:
        """Get all column generator plugins.

        Returns:
            A list of all column generator plugins.
        """
        return self._plugin_registry.get_plugins(PluginType.COLUMN_GENERATOR)

    def get_column_generator_plugin_if_exists(self, plugin_name: str) -> Plugin | None:
        """Get a column generator plugin by name if it exists.

        Args:
            plugin_name: The name of the plugin to retrieve.

        Returns:
            The plugin if a column generator plugin with that name is registered,
            otherwise None. A plugin of another type (e.g. a seed reader) that
            shares the name is not returned.
        """
        if not self._plugin_registry.plugin_exists(plugin_name):
            return None
        plugin = self._plugin_registry.get_plugin(plugin_name)
        # The registry keeps every plugin type in one name-keyed mapping, so the
        # type must be checked here to honor this method's contract.
        if plugin.plugin_type != PluginType.COLUMN_GENERATOR:
            return None
        return plugin

    def get_plugin_column_types(self, enum_type: type[Enum]) -> list[Enum]:
        """Get a list of plugin column types.

        Args:
            enum_type: The enum type to use for plugin entries.

        Returns:
            A list of plugin column types, one ``enum_type(plugin.name)`` per
            column generator plugin.
        """
        return [enum_type(plugin.name) for plugin in self.get_column_generator_plugins()]

    def inject_into_column_config_type_union(self, column_config_type: type[TypeAlias]) -> type[TypeAlias]:
        """Inject plugins into the column config type.

        Args:
            column_config_type: The column config type to inject plugins into.

        Returns:
            The column config type with plugins injected.
        """
        return self._plugin_registry.add_plugin_types_to_union(column_config_type, PluginType.COLUMN_GENERATOR)

    def inject_into_seed_source_type_union(self, seed_source_type: type[TypeAlias]) -> type[TypeAlias]:
        """Inject plugins into the seed source type.

        Args:
            seed_source_type: The seed source type to inject plugins into.

        Returns:
            The seed source type with plugins injected.
        """
        return self._plugin_registry.add_plugin_types_to_union(seed_source_type, PluginType.SEED_READER)
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from data_designer.plugins.plugin import Plugin, PluginType
|
|
7
|
+
|
|
8
|
+
# Public API of ``data_designer.plugins``: the plugin model and its type enum,
# re-exported from ``data_designer.plugins.plugin``.
__all__ = ["Plugin", "PluginType"]
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from data_designer.errors import DataDesignerError
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class PluginLoadError(DataDesignerError):
    """Raised when a plugin's module or class cannot be located, read, or loaded."""
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class PluginRegistrationError(DataDesignerError):
    """Plugin-registration error category derived from ``DataDesignerError``."""

    # NOTE(review): no raise site is visible in the plugin modules shown here;
    # presumably used by registration code elsewhere - confirm.
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class PluginNotFoundError(DataDesignerError):
    """Raised when a plugin is looked up by a name that is not registered."""
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import ast
|
|
7
|
+
import importlib
|
|
8
|
+
import importlib.util
|
|
9
|
+
from enum import Enum
|
|
10
|
+
from functools import cached_property
|
|
11
|
+
from typing import Literal, get_origin
|
|
12
|
+
|
|
13
|
+
from pydantic import BaseModel, Field, field_validator, model_validator
|
|
14
|
+
from typing_extensions import Self
|
|
15
|
+
|
|
16
|
+
from data_designer.config.base import ConfigBase
|
|
17
|
+
from data_designer.plugins.errors import PluginLoadError
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class PluginType(str, Enum):
    """Kinds of plugin that can be registered through entry points."""

    COLUMN_GENERATOR = "column-generator"
    SEED_READER = "seed-reader"

    @property
    def discriminator_field(self) -> str:
        """Name of the config field whose ``Literal`` default identifies plugins of this type."""
        field_by_type = {
            PluginType.COLUMN_GENERATOR: "column_type",
            PluginType.SEED_READER: "seed_type",
        }
        field_name = field_by_type.get(self)
        if field_name is None:
            raise ValueError(f"Invalid plugin type: {self.value}")
        return field_name

    @property
    def display_name(self) -> str:
        """Human-readable label: the value with hyphens turned into spaces."""
        return " ".join(self.value.split("-"))
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _get_module_and_object_names(fully_qualified_object: str) -> tuple[str, str]:
|
|
39
|
+
try:
|
|
40
|
+
module_name, object_name = fully_qualified_object.rsplit(".", 1)
|
|
41
|
+
except ValueError:
|
|
42
|
+
# If fully_qualified_object does not have any periods, the rsplit call will return
|
|
43
|
+
# a list of length 1 and the variable assignment above will raise ValueError
|
|
44
|
+
raise PluginLoadError("Expected a fully-qualified object name, e.g. 'my_plugin.config.MyConfig'")
|
|
45
|
+
|
|
46
|
+
return module_name, object_name
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _check_class_exists_in_file(filepath: str, class_name: str) -> None:
|
|
50
|
+
try:
|
|
51
|
+
with open(filepath, "r") as file:
|
|
52
|
+
source = file.read()
|
|
53
|
+
except FileNotFoundError:
|
|
54
|
+
raise PluginLoadError(f"Could not read source code at {filepath!r}")
|
|
55
|
+
|
|
56
|
+
tree = ast.parse(source)
|
|
57
|
+
for node in ast.walk(tree):
|
|
58
|
+
if isinstance(node, ast.ClassDef) and node.name == class_name:
|
|
59
|
+
return None
|
|
60
|
+
|
|
61
|
+
raise PluginLoadError(f"Could not find class named {class_name!r} in {filepath!r}")
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class Plugin(BaseModel):
    """Declarative description of a Data Designer plugin.

    A plugin pairs an implementation class with a config class, each referenced
    by its fully-qualified dotted name. Field validation confirms each named
    class is defined in its module's source: the module is located with
    ``importlib.util.find_spec`` and its source is parsed with ``ast``. The model
    validator then imports the config class to check its discriminator field.
    The implementation class is imported lazily via ``impl_cls``.
    """

    impl_qualified_name: str = Field(
        ...,
        description="The fully-qualified name of the implementation class object, e.g. 'my_plugin.generator.MyColumnGenerator'",
    )
    config_qualified_name: str = Field(
        ..., description="The fully-qualified name of the config class object, e.g. 'my_plugin.config.MyConfig'"
    )
    plugin_type: PluginType = Field(..., description="The type of plugin")

    @property
    def config_type_as_class_name(self) -> str:
        """``enum_key_name`` converted to PascalCase, e.g. ``MY_COLUMN`` -> ``MyColumn``."""
        return self.enum_key_name.title().replace("_", "")

    @property
    def enum_key_name(self) -> str:
        """``name`` upper-cased with hyphens replaced by underscores, e.g. ``my-column`` -> ``MY_COLUMN``."""
        return self.name.replace("-", "_").upper()

    @property
    def name(self) -> str:
        """Default value of the config class's discriminator field (e.g. its ``column_type``)."""
        return self.config_cls.model_fields[self.discriminator_field].default

    @property
    def discriminator_field(self) -> str:
        """Name of the discriminator field dictated by ``plugin_type``."""
        return self.plugin_type.discriminator_field

    @field_validator("impl_qualified_name", "config_qualified_name", mode="after")
    @classmethod
    def validate_class_name(cls, value: str) -> str:
        """Ensure ``value`` names a class defined in a locatable module's source file.

        Raises:
            PluginLoadError: If the name is malformed, the module cannot be found
                or has no source origin, or the class is not defined in that source.
        """
        module_name, object_name = _get_module_and_object_names(value)
        try:
            spec = importlib.util.find_spec(module_name)
        except Exception as exc:
            # find_spec imports the parent packages of dotted names, so arbitrary
            # import-time errors can surface here. Catch Exception rather than
            # using a bare ``except:`` so KeyboardInterrupt/SystemExit propagate.
            raise PluginLoadError(f"Could not find module {module_name!r}") from exc

        if spec is None or spec.origin is None:
            raise PluginLoadError(f"Error finding source for module {module_name!r}")

        _check_class_exists_in_file(spec.origin, object_name)

        return value

    @model_validator(mode="after")
    def validate_discriminator_field(self) -> Self:
        """Check that the config class declares a usable discriminator field.

        The field must exist and be annotated as a ``Literal``. Its default must
        be a string that becomes a valid Python identifier once hyphens are
        replaced with underscores and it is upper-cased, because that string is
        used as ``enum_key_name``.
        """
        _, cfg = _get_module_and_object_names(self.config_qualified_name)
        field_name = self.plugin_type.discriminator_field
        if field_name not in self.config_cls.model_fields:
            raise ValueError(f"Discriminator field {field_name!r} not found in config class {cfg!r}")
        field_info = self.config_cls.model_fields[field_name]
        if get_origin(field_info.annotation) is not Literal:
            raise ValueError(f"Field {field_name!r} of {cfg!r} must be a Literal type, not {field_info.annotation!r}.")
        if not isinstance(field_info.default, str):
            raise ValueError(f"The default of {field_name!r} must be a string, not {type(field_info.default)!r}.")
        enum_key = field_info.default.replace("-", "_").upper()
        if not enum_key.isidentifier():
            raise ValueError(
                f"The default value {field_info.default!r} for discriminator field {field_name!r} "
                f"cannot be converted to a valid enum key. The converted key {enum_key!r} "
                f"must be a valid Python identifier."
            )
        return self

    @cached_property
    def config_cls(self) -> type[ConfigBase]:
        """The imported config class (loaded once, then cached on the instance)."""
        return self._load(self.config_qualified_name)

    @cached_property
    def impl_cls(self) -> type:
        """The imported implementation class (loaded once, then cached on the instance)."""
        return self._load(self.impl_qualified_name)

    @staticmethod
    def _load(fully_qualified_object: str) -> type:
        """Import the module of ``fully_qualified_object`` and return the named attribute.

        Raises:
            PluginLoadError: If the module has no attribute with that name.
        """
        module_name, object_name = _get_module_and_object_names(fully_qualified_object)
        module = importlib.import_module(module_name)
        try:
            return getattr(module, object_name)
        except AttributeError:
            raise PluginLoadError(f"Could not find class {object_name!r} in module {module_name!r}")
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
import os
|
|
8
|
+
import threading
|
|
9
|
+
from importlib.metadata import entry_points
|
|
10
|
+
from typing import TypeAlias
|
|
11
|
+
|
|
12
|
+
from typing_extensions import Self
|
|
13
|
+
|
|
14
|
+
from data_designer.plugins.errors import PluginNotFoundError
|
|
15
|
+
from data_designer.plugins.plugin import Plugin, PluginType
|
|
16
|
+
|
|
17
|
+
logger = logging.getLogger(__name__)


# Setting DISABLE_DATA_DESIGNER_PLUGINS=true (case-insensitive) skips entry-point
# plugin discovery. The variable is read once, when this module is imported.
PLUGINS_DISABLED = os.environ.get("DISABLE_DATA_DESIGNER_PLUGINS", "false").lower() == "true"
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class PluginRegistry:
    """Singleton holding plugins discovered from the ``data_designer.plugins`` entry points.

    Discovery runs the first time the registry is constructed, unless it is
    disabled via ``PLUGINS_DISABLED``. Plugins are keyed by ``Plugin.name``
    across all plugin types. The instance, the discovered flag, and the plugin
    mapping all live on the class, so ``reset()`` forces the next construction
    to rediscover from scratch.
    """

    _instance = None
    _plugins_discovered = False
    _lock = threading.Lock()

    _plugins: dict[str, Plugin] = {}

    def __init__(self):
        with self._lock:
            if not self._plugins_discovered:
                self._discover()

    @classmethod
    def reset(cls) -> None:
        """Forget the singleton and all discovered plugins."""
        with cls._lock:
            cls._instance = None
            cls._plugins_discovered = False
            cls._plugins = {}

    def add_plugin_types_to_union(self, type_union: type[TypeAlias], plugin_type: PluginType) -> type[TypeAlias]:
        """Return ``type_union`` extended with the config class of each plugin of ``plugin_type``.

        Config classes already present among ``type_union.__args__`` are not added again.
        """
        for plugin in self.get_plugins(plugin_type):
            if plugin.config_cls not in type_union.__args__:
                type_union |= plugin.config_cls
        return type_union

    def get_plugin(self, plugin_name: str) -> Plugin:
        """Return the plugin registered under ``plugin_name``.

        Raises:
            PluginNotFoundError: If no plugin with that name is registered.
        """
        if plugin_name not in self._plugins:
            raise PluginNotFoundError(f"Plugin {plugin_name!r} not found.")
        return self._plugins[plugin_name]

    def get_plugins(self, plugin_type: PluginType) -> list[Plugin]:
        """Return all registered plugins of ``plugin_type``, in registration order."""
        return [plugin for plugin in self._plugins.values() if plugin.plugin_type == plugin_type]

    def get_plugin_names(self, plugin_type: PluginType) -> list[str]:
        """Return the names of all registered plugins of ``plugin_type``."""
        return [plugin.name for plugin in self.get_plugins(plugin_type)]

    def num_plugins(self, plugin_type: PluginType) -> int:
        """Return how many plugins of ``plugin_type`` are registered."""
        return len(self.get_plugins(plugin_type))

    def plugin_exists(self, plugin_name: str) -> bool:
        """Return whether any plugin (of any type) is registered under ``plugin_name``."""
        return plugin_name in self._plugins

    def _discover(self) -> Self:
        """Load every ``data_designer.plugins`` entry point into the registry.

        Called from ``__init__`` with ``_lock`` held. A failing entry point is
        logged and skipped, so one broken plugin cannot prevent the others from
        loading.
        """
        if PLUGINS_DISABLED:
            return self
        for ep in entry_points(group="data_designer.plugins"):
            try:
                plugin = ep.load()
                if not isinstance(plugin, Plugin):
                    # Surface misconfigured entry points instead of ignoring them silently.
                    logger.warning(
                        f"🛑 Entry point {ep.name!r} did not resolve to a Plugin instance "
                        f"(got {type(plugin).__name__}); skipping."
                    )
                    continue
                if plugin.name in self._plugins:
                    # Names are shared across plugin types; the later plugin wins.
                    logger.warning(
                        f"⚠️ Plugin name {plugin.name!r} from entry point {ep.name!r} "
                        f"replaces an already-registered plugin."
                    )
                logger.info(
                    f"🔌 Plugin discovered ➜ {plugin.plugin_type.display_name} "
                    f"{plugin.enum_key_name} is now available ⚡️"
                )
                self._plugins[plugin.name] = plugin
            except Exception as e:
                logger.warning(f"🛑 Failed to load plugin from entry point {ep.name!r}: {e}")
        # Record on the class (not the instance) so the flag lives where reset() clears it.
        type(self)._plugins_discovered = True
        return self

    def __new__(cls, *args, **kwargs):
        """Plugin registry is a singleton."""
        if not cls._instance:
            with cls._lock:
                if not cls._instance:
                    cls._instance = super().__new__(cls)
        return cls._instance
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: data-designer-config
|
|
3
|
+
Version: 0.4.0
|
|
4
|
+
Summary: Configuration layer for DataDesigner synthetic data generation
|
|
5
|
+
License-Expression: Apache-2.0
|
|
6
|
+
Classifier: Development Status :: 4 - Beta
|
|
7
|
+
Classifier: Intended Audience :: Developers
|
|
8
|
+
Classifier: Intended Audience :: Science/Research
|
|
9
|
+
Classifier: License :: OSI Approved :: Apache Software License
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
14
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
15
|
+
Requires-Python: >=3.10
|
|
16
|
+
Requires-Dist: jinja2<4,>=3.1.6
|
|
17
|
+
Requires-Dist: numpy<3,>=1.23.5
|
|
18
|
+
Requires-Dist: pandas<3,>=2.3.3
|
|
19
|
+
Requires-Dist: pyarrow<20,>=19.0.1
|
|
20
|
+
Requires-Dist: pydantic[email]<3,>=2.9.2
|
|
21
|
+
Requires-Dist: pygments<3,>=2.19.2
|
|
22
|
+
Requires-Dist: python-json-logger<4,>=3
|
|
23
|
+
Requires-Dist: pyyaml<7,>=6.0.1
|
|
24
|
+
Requires-Dist: rich<15,>=13.7.1
|
|
25
|
+
Description-Content-Type: text/markdown
|
|
26
|
+
|
|
27
|
+
# data-designer-config
|
|
28
|
+
|
|
29
|
+
Configuration layer for the NeMo Data Designer synthetic data generation framework.
|
|
30
|
+
|
|
31
|
+
This package provides the configuration API for defining synthetic data generation pipelines. It's a lightweight dependency that can be used standalone for configuration management.
|
|
32
|
+
|
|
33
|
+
## Installation
|
|
34
|
+
|
|
35
|
+
```bash
|
|
36
|
+
pip install data-designer-config
|
|
37
|
+
```
|
|
38
|
+
|
|
39
|
+
## Usage
|
|
40
|
+
|
|
41
|
+
```python
|
|
42
|
+
import data_designer.config as dd
|
|
43
|
+
|
|
44
|
+
# Initialize config builder with model config(s)
|
|
45
|
+
config_builder = dd.DataDesignerConfigBuilder(
|
|
46
|
+
model_configs=[
|
|
47
|
+
dd.ModelConfig(
|
|
48
|
+
alias="my-model",
|
|
49
|
+
model="nvidia/nemotron-3-nano-30b-a3b",
|
|
50
|
+
inference_parameters=dd.ChatCompletionInferenceParams(temperature=0.7),
|
|
51
|
+
),
|
|
52
|
+
]
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
# Add columns
|
|
56
|
+
config_builder.add_column(
|
|
57
|
+
dd.SamplerColumnConfig(
|
|
58
|
+
name="user_id",
|
|
59
|
+
sampler_type=dd.SamplerType.UUID,
|
|
60
|
+
params=dd.UUIDSamplerParams(prefix="user-"),
|
|
61
|
+
)
|
|
62
|
+
)
|
|
63
|
+
config_builder.add_column(
|
|
64
|
+
dd.LLMTextColumnConfig(
|
|
65
|
+
name="description",
|
|
66
|
+
prompt="Write a product description",
|
|
67
|
+
model_alias="my-model",
|
|
68
|
+
)
|
|
69
|
+
)
|
|
70
|
+
|
|
71
|
+
# Build configuration
|
|
72
|
+
config = config_builder.build()
|
|
73
|
+
```
|
|
74
|
+
|
|
75
|
+
See the main [README.md](https://github.com/NVIDIA-NeMo/DataDesigner/blob/main/README.md) for more information.
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
data_designer/errors.py,sha256=r1pBvmvRBAsPmb7oF_veubhkxZ2uPo9cGEDwykLziX4,220
|
|
2
|
+
data_designer/lazy_heavy_imports.py,sha256=5X04vUj9TYbKgfDmY2qvhzRf5-IZWKOanIpi3_u7fmM,1662
|
|
3
|
+
data_designer/logging.py,sha256=Xq2cRwxmDJ-r8_s9NWnk15efLRsrKm5iVScHy6HkjiE,6044
|
|
4
|
+
data_designer/plugin_manager.py,sha256=C2ZkZiXlcMRiaxfrrho5Shz6DKdExVeBha7ch-d4CnU,2695
|
|
5
|
+
data_designer/config/__init__.py,sha256=MWzRZhXA41sTpc0sL_xq2baA3kSlV37alT6g8RlP8dU,4919
|
|
6
|
+
data_designer/config/_version.py,sha256=2_0GUP7yBCXRus-qiJKxQD62z172WSs1sQ6DVpPsbmM,704
|
|
7
|
+
data_designer/config/base.py,sha256=IGj6sy_GnKzC94uu2rdxe12EqR_AmGJ6O3rl2MxOv6g,2449
|
|
8
|
+
data_designer/config/column_configs.py,sha256=QEHXbxljbGEfOEnzNsiR3_CRpaCukQsayBbHQyhMhbc,20720
|
|
9
|
+
data_designer/config/column_types.py,sha256=xGXuu0EBy3Y5Jd74f2VM6x5jHq72GmK9leA6qOnAz8c,5423
|
|
10
|
+
data_designer/config/config_builder.py,sha256=vuPibkodbJxbCXdaI1tt1Uyo1SVCnAOfLBAW1AmhajI,24707
|
|
11
|
+
data_designer/config/data_designer_config.py,sha256=qOojviug05vHR2S4800sjd4OmxhSVi6kB8SAFXLlPog,1891
|
|
12
|
+
data_designer/config/dataset_builders.py,sha256=jdCujJYFlKAiSkPNX2Qeyrs683GrRcCDv_m8ZZhtg64,368
|
|
13
|
+
data_designer/config/dataset_metadata.py,sha256=UTlEgnHWgjwPuc7bP95T7gaKmcr7pIhFMy9vvbUwMV4,647
|
|
14
|
+
data_designer/config/default_model_settings.py,sha256=c-llH2otfG0tMCMsxoz3ZcS1nFxIQQPfRedFXAydDbc,4868
|
|
15
|
+
data_designer/config/errors.py,sha256=JhvUYecfLmP0gZjQzqA3OmfaSs9TRlC5E-ubnV_-3gs,560
|
|
16
|
+
data_designer/config/interface.py,sha256=ikmpm_KwencTpM-yg0auo7XMgcmMSa67S75IqdpFLfk,1676
|
|
17
|
+
data_designer/config/models.py,sha256=_NctRk4brgBeb5q5V7r_hXE5OORlLh6SCVZP0eu2LGo,16721
|
|
18
|
+
data_designer/config/preview_results.py,sha256=WnPlDcHElIHNfjV_P-nLu_Dpul8D3Eyb5qyi3E173Gs,1744
|
|
19
|
+
data_designer/config/processors.py,sha256=lnyUZA1EhO9NWjjVFFioYxSgeYpoAaM1J7UzwOYkvms,6028
|
|
20
|
+
data_designer/config/run_config.py,sha256=m_rrqEmNHR533AYJ_OR5yq0a9Pegy9vPGZgyfD4x9cI,3052
|
|
21
|
+
data_designer/config/sampler_constraints.py,sha256=tQI1XLF5bS4TnyKMLo0nArvefnXI8dWCzov38r4qNCQ,1197
|
|
22
|
+
data_designer/config/sampler_params.py,sha256=Gio-53vjSYOdPhF2CEq4HSWCXCaZMy4WpGPbuFVcWOM,27965
|
|
23
|
+
data_designer/config/seed.py,sha256=eShSqOcSUzfCEZBnqY-rB0qZpRGxjeOE3fSaJAwacec,4668
|
|
24
|
+
data_designer/config/seed_source.py,sha256=ufcZdibP3aeruswC1lfh-JJcr5NjK_Ht50uY6-wnl8E,2635
|
|
25
|
+
data_designer/config/seed_source_types.py,sha256=sxu6EOVr4ChZFvv2dI1-F9AZg_9fnv8UJ0dGVbsWQ6E,715
|
|
26
|
+
data_designer/config/validator_params.py,sha256=xm5H1IgphK61aMFoH2FOu4MROlvxeL84CajI8DTPv6Y,3947
|
|
27
|
+
data_designer/config/analysis/__init__.py,sha256=ObZ6NUPeEvvpGTJ5WIGKUyIrIjaI747OM6ErweRtHxQ,137
|
|
28
|
+
data_designer/config/analysis/column_profilers.py,sha256=sgfHrHYRZgtTDrHtfkDtu6F9iZUwD3ISc3m9kka0UUE,6256
|
|
29
|
+
data_designer/config/analysis/column_statistics.py,sha256=g3ipgwHMLyTLhvJrB7St0SYhyvIe6ENfTCEJKoePetc,16885
|
|
30
|
+
data_designer/config/analysis/dataset_profiler.py,sha256=-5eX55IXivwUBMg2pI-d_3e7nbJb83a0tyxL-WzL-MY,4174
|
|
31
|
+
data_designer/config/analysis/utils/errors.py,sha256=pvmdQ_YuIlWW4NFw-cX_rOoQf-GG8y_FiQzNctB__DQ,331
|
|
32
|
+
data_designer/config/analysis/utils/reporting.py,sha256=teTzd1OHtpI4vbIinGOGsKXyNldO3F5eqbNdAztF0_s,7066
|
|
33
|
+
data_designer/config/testing/__init__.py,sha256=vxFrIOqDoDfOx-MWjC5lb_hvmB4kRKvh1QdTv--QYFM,222
|
|
34
|
+
data_designer/config/testing/fixtures.py,sha256=J1bcWjerAIoVUIZBVPbUcuvEa2laj_kspVcLS7UZMbo,10876
|
|
35
|
+
data_designer/config/utils/code_lang.py,sha256=nUeWjuzSYBVF5gwOiUE2-EsYCEDzRZaw31RIivt7GPI,2638
|
|
36
|
+
data_designer/config/utils/constants.py,sha256=lprfeF_bIzGJ_oGrZBhvHEbLVgrGfFtVbCdWJHf_6B8,8953
|
|
37
|
+
data_designer/config/utils/errors.py,sha256=HCjer0YrF0bMn5j8gmgWaLb0395LAr_hxMD1ftOsOc8,520
|
|
38
|
+
data_designer/config/utils/info.py,sha256=yOa4U8kI_CY4OfCKZxCm2okU8klAiThvyjKM5tG-F0A,3469
|
|
39
|
+
data_designer/config/utils/io_helpers.py,sha256=kzvOR7QgqijkqU-O2enIlpCWwHvzc3oRaEl4Lsjh1Do,8466
|
|
40
|
+
data_designer/config/utils/misc.py,sha256=7n_0txc78IoK6V39CwZY-65KtYcjh38WDl0Q1bQM-EA,2481
|
|
41
|
+
data_designer/config/utils/numerical_helpers.py,sha256=DIubKzc8q2_Bw7xRjyOGwxYulTV3dt3JxCdpH560dak,838
|
|
42
|
+
data_designer/config/utils/type_helpers.py,sha256=XyVup24F4Bl7uNze_yUW9oD6EzFbfsJWKhpeMN2901A,4059
|
|
43
|
+
data_designer/config/utils/visualization.py,sha256=_0Mn-jva0Oz1tVTQH1mnWSARpqZ2kh1JSzJEuikyy9s,18491
|
|
44
|
+
data_designer/plugins/__init__.py,sha256=qe1alcTEtnMSMdzknjb57vvjqKgFE5cEHXxBj8tPWMI,275
|
|
45
|
+
data_designer/plugins/errors.py,sha256=d7FMed3ueQvZHwuhwyPLzF4E34bO1mdj3aBVEw6p34o,386
|
|
46
|
+
data_designer/plugins/plugin.py,sha256=TVyyOaQBWAt0FQwUmtihTZ9MDJD85HwggrQ3L9CviPQ,5367
|
|
47
|
+
data_designer/plugins/registry.py,sha256=Cnt33Q25o9bS2v2YDbV3QPM57VNrtIBKAb4ERQRE_dY,3053
|
|
48
|
+
data_designer_config-0.4.0.dist-info/METADATA,sha256=l06rdZe6t1jKhqvgkH0ZYSTXX-UUVsjQ-ZIfwD_mwvA,2283
|
|
49
|
+
data_designer_config-0.4.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
50
|
+
data_designer_config-0.4.0.dist-info/RECORD,,
|