data-designer 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/_version.py +2 -2
- data_designer/config/analysis/column_profilers.py +4 -4
- data_designer/config/analysis/column_statistics.py +5 -5
- data_designer/config/analysis/dataset_profiler.py +6 -6
- data_designer/config/analysis/utils/errors.py +1 -1
- data_designer/config/analysis/utils/reporting.py +5 -5
- data_designer/config/base.py +2 -2
- data_designer/config/column_configs.py +8 -8
- data_designer/config/column_types.py +9 -5
- data_designer/config/config_builder.py +32 -27
- data_designer/config/data_designer_config.py +7 -7
- data_designer/config/datastore.py +4 -4
- data_designer/config/default_model_settings.py +4 -4
- data_designer/config/errors.py +1 -1
- data_designer/config/exports.py +133 -0
- data_designer/config/interface.py +6 -6
- data_designer/config/models.py +109 -5
- data_designer/config/preview_results.py +9 -6
- data_designer/config/processors.py +48 -4
- data_designer/config/sampler_constraints.py +1 -1
- data_designer/config/sampler_params.py +2 -2
- data_designer/config/seed.py +3 -3
- data_designer/config/utils/constants.py +1 -1
- data_designer/config/utils/errors.py +1 -1
- data_designer/config/utils/info.py +8 -4
- data_designer/config/utils/io_helpers.py +5 -5
- data_designer/config/utils/misc.py +3 -3
- data_designer/config/utils/numerical_helpers.py +1 -1
- data_designer/config/utils/type_helpers.py +7 -3
- data_designer/config/utils/validation.py +37 -6
- data_designer/config/utils/visualization.py +42 -10
- data_designer/config/validator_params.py +2 -2
- data_designer/engine/analysis/column_profilers/base.py +1 -1
- data_designer/engine/analysis/dataset_profiler.py +1 -1
- data_designer/engine/analysis/utils/judge_score_processing.py +1 -1
- data_designer/engine/column_generators/generators/samplers.py +1 -1
- data_designer/engine/dataset_builders/artifact_storage.py +16 -6
- data_designer/engine/dataset_builders/column_wise_builder.py +4 -1
- data_designer/engine/dataset_builders/utils/concurrency.py +1 -1
- data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +1 -1
- data_designer/engine/errors.py +1 -1
- data_designer/engine/models/errors.py +1 -1
- data_designer/engine/models/facade.py +1 -1
- data_designer/engine/models/parsers/parser.py +2 -2
- data_designer/engine/models/recipes/response_recipes.py +1 -1
- data_designer/engine/processing/ginja/environment.py +1 -1
- data_designer/engine/processing/gsonschema/validators.py +1 -1
- data_designer/engine/processing/processors/drop_columns.py +1 -1
- data_designer/engine/processing/processors/registry.py +3 -0
- data_designer/engine/processing/processors/schema_transform.py +53 -0
- data_designer/engine/resources/managed_dataset_repository.py +4 -4
- data_designer/engine/resources/managed_storage.py +1 -1
- data_designer/engine/sampling_gen/constraints.py +1 -1
- data_designer/engine/sampling_gen/data_sources/base.py +1 -1
- data_designer/engine/sampling_gen/entities/email_address_utils.py +1 -1
- data_designer/engine/sampling_gen/entities/national_id_utils.py +1 -1
- data_designer/engine/sampling_gen/entities/person.py +1 -1
- data_designer/engine/sampling_gen/entities/phone_number.py +1 -1
- data_designer/engine/sampling_gen/people_gen.py +3 -3
- data_designer/engine/secret_resolver.py +1 -1
- data_designer/engine/validators/python.py +2 -2
- data_designer/essentials/__init__.py +20 -128
- data_designer/interface/data_designer.py +23 -19
- data_designer/interface/results.py +36 -0
- data_designer/logging.py +2 -2
- data_designer/plugin_manager.py +14 -26
- data_designer/plugins/registry.py +1 -1
- {data_designer-0.1.3.dist-info → data_designer-0.1.5.dist-info}/METADATA +9 -9
- {data_designer-0.1.3.dist-info → data_designer-0.1.5.dist-info}/RECORD +72 -70
- {data_designer-0.1.3.dist-info → data_designer-0.1.5.dist-info}/WHEEL +0 -0
- {data_designer-0.1.3.dist-info → data_designer-0.1.5.dist-info}/entry_points.txt +0 -0
- {data_designer-0.1.3.dist-info → data_designer-0.1.5.dist-info}/licenses/LICENSE +0 -0
data_designer/config/models.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
import logging
|
|
4
5
|
from abc import ABC, abstractmethod
|
|
5
6
|
from enum import Enum
|
|
6
|
-
import logging
|
|
7
7
|
from pathlib import Path
|
|
8
8
|
from typing import Any, Generic, List, Optional, TypeVar, Union
|
|
9
9
|
|
|
@@ -11,29 +11,35 @@ import numpy as np
|
|
|
11
11
|
from pydantic import BaseModel, Field, model_validator
|
|
12
12
|
from typing_extensions import Self, TypeAlias
|
|
13
13
|
|
|
14
|
-
from .base import ConfigBase
|
|
15
|
-
from .errors import InvalidConfigError
|
|
16
|
-
from .utils.constants import (
|
|
14
|
+
from data_designer.config.base import ConfigBase
|
|
15
|
+
from data_designer.config.errors import InvalidConfigError
|
|
16
|
+
from data_designer.config.utils.constants import (
|
|
17
17
|
MAX_TEMPERATURE,
|
|
18
18
|
MAX_TOP_P,
|
|
19
19
|
MIN_TEMPERATURE,
|
|
20
20
|
MIN_TOP_P,
|
|
21
21
|
)
|
|
22
|
-
from .utils.io_helpers import smart_load_yaml
|
|
22
|
+
from data_designer.config.utils.io_helpers import smart_load_yaml
|
|
23
23
|
|
|
24
24
|
logger = logging.getLogger(__name__)
|
|
25
25
|
|
|
26
26
|
|
|
27
27
|
class Modality(str, Enum):
    """Supported modality types for multimodal model data.

    Used as the ``modality`` field value of modality contexts such as ``ImageContext``.
    """

    # Image inputs; serialized as the string "image".
    IMAGE = "image"
|
|
29
31
|
|
|
30
32
|
|
|
31
33
|
class ModalityDataType(str, Enum):
    """Data type formats for multimodal data.

    Tells a modality context how to interpret the value stored in its column
    (``ImageContext.get_context`` branches on ``ModalityDataType.URL``).
    """

    # The column value is a URL pointing at the media.
    URL = "url"
    # The column value is base64-encoded media; ImageContext documents that an
    # image_format is required in this case.
    BASE64 = "base64"
|
|
34
38
|
|
|
35
39
|
|
|
36
40
|
class ImageFormat(str, Enum):
|
|
41
|
+
"""Supported image formats for image modality."""
|
|
42
|
+
|
|
37
43
|
PNG = "png"
|
|
38
44
|
JPG = "jpg"
|
|
39
45
|
JPEG = "jpeg"
|
|
@@ -42,6 +48,8 @@ class ImageFormat(str, Enum):
|
|
|
42
48
|
|
|
43
49
|
|
|
44
50
|
class DistributionType(str, Enum):
    """Types of distributions for sampling inference parameters.

    The values match the default ``distribution_type`` of ``UniformDistribution``
    ("uniform") and ``ManualDistribution`` ("manual").
    """

    # Continuous sampling between a lower and an upper bound.
    UNIFORM = "uniform"
    # Discrete sampling from an explicit list of values, optionally weighted.
    MANUAL = "manual"
|
|
47
55
|
|
|
@@ -56,10 +64,27 @@ class ModalityContext(ABC, BaseModel):
|
|
|
56
64
|
|
|
57
65
|
|
|
58
66
|
class ImageContext(ModalityContext):
|
|
67
|
+
"""Configuration for providing image context to multimodal models.
|
|
68
|
+
|
|
69
|
+
Attributes:
|
|
70
|
+
modality: The modality type (always "image").
|
|
71
|
+
column_name: Name of the column containing image data.
|
|
72
|
+
data_type: Format of the image data ("url" or "base64").
|
|
73
|
+
image_format: Image format (required for base64 data).
|
|
74
|
+
"""
|
|
75
|
+
|
|
59
76
|
modality: Modality = Modality.IMAGE
|
|
60
77
|
image_format: Optional[ImageFormat] = None
|
|
61
78
|
|
|
62
79
|
def get_context(self, record: dict) -> dict[str, Any]:
|
|
80
|
+
"""Get the context for the image modality.
|
|
81
|
+
|
|
82
|
+
Args:
|
|
83
|
+
record: The record containing the image data.
|
|
84
|
+
|
|
85
|
+
Returns:
|
|
86
|
+
The context for the image modality.
|
|
87
|
+
"""
|
|
63
88
|
context = dict(type="image_url")
|
|
64
89
|
context_value = record[self.column_name]
|
|
65
90
|
if self.data_type == ModalityDataType.URL:
|
|
@@ -90,6 +115,13 @@ class Distribution(ABC, ConfigBase, Generic[DistributionParamsT]):
|
|
|
90
115
|
|
|
91
116
|
|
|
92
117
|
class ManualDistributionParams(ConfigBase):
|
|
118
|
+
"""Parameters for manual distribution sampling.
|
|
119
|
+
|
|
120
|
+
Attributes:
|
|
121
|
+
values: List of possible values to sample from.
|
|
122
|
+
weights: Optional list of weights for each value. If not provided, all values have equal probability.
|
|
123
|
+
"""
|
|
124
|
+
|
|
93
125
|
values: List[float] = Field(min_length=1)
|
|
94
126
|
weights: Optional[List[float]] = None
|
|
95
127
|
|
|
@@ -107,14 +139,36 @@ class ManualDistributionParams(ConfigBase):
|
|
|
107
139
|
|
|
108
140
|
|
|
109
141
|
class ManualDistribution(Distribution[ManualDistributionParams]):
    """Discrete distribution defined by an explicit list of candidate values.

    Each call to ``sample`` picks one entry of ``params.values``. The pick is weighted
    by ``params.weights`` when they are provided and uniform otherwise. This is useful
    for pinning temperature or top_p to a known set of values, or for hand-crafted
    probability mass functions.

    Attributes:
        distribution_type: Distribution kind, defaults to ``"manual"``.
        params: Candidate values and their optional weights.
    """

    distribution_type: Optional[DistributionType] = "manual"
    params: ManualDistributionParams

    def sample(self) -> float:
        """Draw one value from the configured candidates.

        Returns:
            The chosen candidate, converted to a Python ``float``.
        """
        candidates = self.params.values
        probabilities = self.params.weights
        # numpy treats p=None as a uniform distribution over the candidates.
        drawn = np.random.choice(candidates, p=probabilities)
        return float(drawn)
|
|
115
162
|
|
|
116
163
|
|
|
117
164
|
class UniformDistributionParams(ConfigBase):
|
|
165
|
+
"""Parameters for uniform distribution sampling.
|
|
166
|
+
|
|
167
|
+
Attributes:
|
|
168
|
+
low: Lower bound (inclusive).
|
|
169
|
+
high: Upper bound (exclusive).
|
|
170
|
+
"""
|
|
171
|
+
|
|
118
172
|
low: float
|
|
119
173
|
high: float
|
|
120
174
|
|
|
@@ -126,10 +180,25 @@ class UniformDistributionParams(ConfigBase):
|
|
|
126
180
|
|
|
127
181
|
|
|
128
182
|
class UniformDistribution(Distribution[UniformDistributionParams]):
    """Continuous uniform distribution over ``[params.low, params.high)``.

    Every value in the half-open interval is equally likely. Handy for exploring a
    continuous range of temperature or top_p settings.

    Attributes:
        distribution_type: Distribution kind, defaults to ``"uniform"``.
        params: Lower and upper bounds of the interval.
    """

    distribution_type: Optional[DistributionType] = "uniform"
    params: UniformDistributionParams

    def sample(self) -> float:
        """Draw one value uniformly from the configured interval.

        Returns:
            The drawn value as a Python ``float``.
        """
        lower, upper = self.params.low, self.params.high
        draws = np.random.uniform(low=lower, high=upper, size=1)
        return float(draws[0])
|
|
134
203
|
|
|
135
204
|
|
|
@@ -137,6 +206,17 @@ DistributionT: TypeAlias = Union[UniformDistribution, ManualDistribution]
|
|
|
137
206
|
|
|
138
207
|
|
|
139
208
|
class InferenceParameters(ConfigBase):
|
|
209
|
+
"""Configuration for LLM inference parameters.
|
|
210
|
+
|
|
211
|
+
Attributes:
|
|
212
|
+
temperature: Sampling temperature (0.0-2.0). Can be a fixed value or a distribution for dynamic sampling.
|
|
213
|
+
top_p: Nucleus sampling probability (0.0-1.0). Can be a fixed value or a distribution for dynamic sampling.
|
|
214
|
+
max_tokens: Maximum number of tokens (includes both input and output tokens).
|
|
215
|
+
max_parallel_requests: Maximum number of parallel requests to the model API.
|
|
216
|
+
timeout: Timeout in seconds for each request.
|
|
217
|
+
extra_body: Additional parameters to pass to the model API.
|
|
218
|
+
"""
|
|
219
|
+
|
|
140
220
|
temperature: Optional[Union[float, DistributionT]] = None
|
|
141
221
|
top_p: Optional[Union[float, DistributionT]] = None
|
|
142
222
|
max_tokens: Optional[int] = Field(default=None, ge=1)
|
|
@@ -146,6 +226,11 @@ class InferenceParameters(ConfigBase):
|
|
|
146
226
|
|
|
147
227
|
@property
|
|
148
228
|
def generate_kwargs(self) -> dict[str, Union[float, int]]:
|
|
229
|
+
"""Get the generate kwargs for the inference parameters.
|
|
230
|
+
|
|
231
|
+
Returns:
|
|
232
|
+
A dictionary of the generate kwargs.
|
|
233
|
+
"""
|
|
149
234
|
result = {}
|
|
150
235
|
if self.temperature is not None:
|
|
151
236
|
result["temperature"] = (
|
|
@@ -206,6 +291,15 @@ class InferenceParameters(ConfigBase):
|
|
|
206
291
|
|
|
207
292
|
|
|
208
293
|
class ModelConfig(ConfigBase):
|
|
294
|
+
"""Configuration for a model used for generation.
|
|
295
|
+
|
|
296
|
+
Attributes:
|
|
297
|
+
alias: User-defined alias to reference in column configurations.
|
|
298
|
+
model: Model identifier (e.g., from build.nvidia.com or other providers).
|
|
299
|
+
inference_parameters: Inference parameters for the model (temperature, top_p, max_tokens, etc.).
|
|
300
|
+
provider: Optional model provider name if using custom providers.
|
|
301
|
+
"""
|
|
302
|
+
|
|
209
303
|
alias: str
|
|
210
304
|
model: str
|
|
211
305
|
inference_parameters: InferenceParameters = Field(default_factory=InferenceParameters)
|
|
@@ -213,6 +307,16 @@ class ModelConfig(ConfigBase):
|
|
|
213
307
|
|
|
214
308
|
|
|
215
309
|
class ModelProvider(ConfigBase):
|
|
310
|
+
"""Configuration for a custom model provider.
|
|
311
|
+
|
|
312
|
+
Attributes:
|
|
313
|
+
name: Name of the model provider.
|
|
314
|
+
endpoint: API endpoint URL for the provider.
|
|
315
|
+
provider_type: Provider type (default: "openai"). Determines the API format to use.
|
|
316
|
+
api_key: Optional API key for authentication.
|
|
317
|
+
extra_body: Additional parameters to pass in API requests.
|
|
318
|
+
"""
|
|
319
|
+
|
|
216
320
|
name: str
|
|
217
321
|
endpoint: str
|
|
218
322
|
provider_type: str = "openai"
|
|
@@ -3,13 +3,13 @@
|
|
|
3
3
|
|
|
4
4
|
from __future__ import annotations
|
|
5
5
|
|
|
6
|
-
from typing import Optional
|
|
6
|
+
from typing import Optional, Union
|
|
7
7
|
|
|
8
8
|
import pandas as pd
|
|
9
9
|
|
|
10
|
-
from .analysis.dataset_profiler import DatasetProfilerResults
|
|
11
|
-
from .config_builder import DataDesignerConfigBuilder
|
|
12
|
-
from .utils.visualization import WithRecordSamplerMixin
|
|
10
|
+
from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults
|
|
11
|
+
from data_designer.config.config_builder import DataDesignerConfigBuilder
|
|
12
|
+
from data_designer.config.utils.visualization import WithRecordSamplerMixin
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
class PreviewResults(WithRecordSamplerMixin):
|
|
@@ -19,6 +19,7 @@ class PreviewResults(WithRecordSamplerMixin):
|
|
|
19
19
|
config_builder: DataDesignerConfigBuilder,
|
|
20
20
|
dataset: Optional[pd.DataFrame] = None,
|
|
21
21
|
analysis: Optional[DatasetProfilerResults] = None,
|
|
22
|
+
processor_artifacts: Optional[dict[str, Union[list[str], str]]] = None,
|
|
22
23
|
):
|
|
23
24
|
"""Creates a new instance with results from a Data Designer preview run.
|
|
24
25
|
|
|
@@ -26,7 +27,9 @@ class PreviewResults(WithRecordSamplerMixin):
|
|
|
26
27
|
config_builder: Data Designer configuration builder.
|
|
27
28
|
dataset: Dataset of the preview run.
|
|
28
29
|
analysis: Analysis of the preview run.
|
|
30
|
+
processor_artifacts: Artifacts generated by the processors.
|
|
29
31
|
"""
|
|
30
|
-
self.dataset: pd.DataFrame
|
|
31
|
-
self.analysis: DatasetProfilerResults
|
|
32
|
+
self.dataset: Optional[pd.DataFrame] = dataset
|
|
33
|
+
self.analysis: Optional[DatasetProfilerResults] = analysis
|
|
34
|
+
self.processor_artifacts: Optional[dict[str, Union[list[str], str]]] = processor_artifacts
|
|
32
35
|
self._config_builder = config_builder
|
|
@@ -1,25 +1,32 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
import json
|
|
4
5
|
from abc import ABC
|
|
5
6
|
from enum import Enum
|
|
6
|
-
from typing import Literal
|
|
7
|
+
from typing import Any, Literal
|
|
7
8
|
|
|
8
9
|
from pydantic import Field, field_validator
|
|
9
10
|
|
|
10
|
-
from .base import ConfigBase
|
|
11
|
-
from .dataset_builders import BuildStage
|
|
11
|
+
from data_designer.config.base import ConfigBase
|
|
12
|
+
from data_designer.config.dataset_builders import BuildStage
|
|
13
|
+
from data_designer.config.errors import InvalidConfigError
|
|
12
14
|
|
|
13
15
|
SUPPORTED_STAGES = [BuildStage.POST_BATCH]
|
|
14
16
|
|
|
15
17
|
|
|
16
18
|
class ProcessorType(str, Enum):
|
|
17
19
|
DROP_COLUMNS = "drop_columns"
|
|
20
|
+
SCHEMA_TRANSFORM = "schema_transform"
|
|
18
21
|
|
|
19
22
|
|
|
20
23
|
class ProcessorConfig(ConfigBase, ABC):
|
|
24
|
+
name: str = Field(
|
|
25
|
+
description="The name of the processor, used to identify the processor in the results and to write the artifacts to disk.",
|
|
26
|
+
)
|
|
21
27
|
build_stage: BuildStage = Field(
|
|
22
|
-
|
|
28
|
+
default=BuildStage.POST_BATCH,
|
|
29
|
+
description=f"The stage at which the processor will run. Supported stages: {', '.join(SUPPORTED_STAGES)}",
|
|
23
30
|
)
|
|
24
31
|
|
|
25
32
|
@field_validator("build_stage")
|
|
@@ -34,8 +41,45 @@ class ProcessorConfig(ConfigBase, ABC):
|
|
|
34
41
|
def get_processor_config_from_kwargs(processor_type: ProcessorType, **kwargs) -> ProcessorConfig:
    """Build the processor config that corresponds to ``processor_type``.

    Args:
        processor_type: The processor to configure. ``ProcessorType`` is a ``str`` enum,
            so plain strings equal to one of its values also match.
        **kwargs: Field values forwarded to the selected config class.

    Returns:
        A ``DropColumnsProcessorConfig`` or ``SchemaTransformProcessorConfig`` instance.

    Raises:
        InvalidConfigError: If ``processor_type`` matches no known processor type.
    """
    # Equality comparisons (not a dict lookup) are intentional: str-enum members hash
    # by name, so a dict keyed by members would not match plain string values.
    if processor_type == ProcessorType.DROP_COLUMNS:
        return DropColumnsProcessorConfig(**kwargs)
    if processor_type == ProcessorType.SCHEMA_TRANSFORM:
        return SchemaTransformProcessorConfig(**kwargs)
    # Fail loudly instead of implicitly returning None, which the return type forbids.
    raise InvalidConfigError(f"Unknown processor type: {processor_type!r}")
|
|
37
46
|
|
|
38
47
|
|
|
39
48
|
class DropColumnsProcessorConfig(ProcessorConfig):
    """Configuration for the ``drop_columns`` processor.

    Attributes:
        column_names: Names of the columns to drop. validate_drop_columns_processor
            reports names that are not defined columns of the dataset.
        processor_type: Fixed to ``ProcessorType.DROP_COLUMNS``.
    """

    column_names: list[str]
    processor_type: Literal[ProcessorType.DROP_COLUMNS] = ProcessorType.DROP_COLUMNS
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class SchemaTransformProcessorConfig(ProcessorConfig):
    """Configuration for a processor that builds a new dataset from Jinja2 templates.

    Attributes:
        template: Mapping from new column name to a JSON-serializable object that
            contains Jinja2 templates (e.g. a string or a list of strings).
        processor_type: Fixed to ``ProcessorType.SCHEMA_TRANSFORM``.
    """

    template: dict[str, Any] = Field(
        ...,
        description="""
        Dictionary specifying columns and templates to use in the new dataset with transformed schema.

        Each key is a new column name, and each value is an object containing Jinja2 templates - for instance, a string or a list of strings.
        Values must be JSON-serializable.

        Example:

        ```python
        template = {
            "list_of_strings": ["{{ col1 }}", "{{ col2 }}"],
            "uppercase_string": "{{ col1 | upper }}",
            "lowercase_string": "{{ col2 | lower }}",
        }
        ```

        The above templates will create an new dataset with three columns: "list_of_strings", "uppercase_string", and "lowercase_string".
        References to columns "col1" and "col2" in the templates will be replaced with the actual values of the columns in the dataset.
        """,
    )
    processor_type: Literal[ProcessorType.SCHEMA_TRANSFORM] = ProcessorType.SCHEMA_TRANSFORM

    @field_validator("template")
    @classmethod
    def validate_template(cls, v: dict[str, Any]) -> dict[str, Any]:
        """Ensure the template can be serialized with ``json.dumps``.

        Args:
            v: The template mapping to validate.

        Returns:
            The template, unchanged.

        Raises:
            InvalidConfigError: If ``json.dumps`` rejects the template. This covers
                non-serializable values, unsupported key types (e.g. tuples) and
                circular references.
        """
        try:
            json.dumps(v)
        except (TypeError, ValueError) as e:
            # Every json.dumps failure means the template is unusable. The previous code
            # only reacted to messages containing "not JSON serializable" and silently
            # accepted other errors, such as invalid key types.
            raise InvalidConfigError("Template must be JSON serializable") from e
        return v
|
|
@@ -8,8 +8,8 @@ import pandas as pd
|
|
|
8
8
|
from pydantic import Field, field_validator, model_validator
|
|
9
9
|
from typing_extensions import Self, TypeAlias
|
|
10
10
|
|
|
11
|
-
from .base import ConfigBase
|
|
12
|
-
from .utils.constants import (
|
|
11
|
+
from data_designer.config.base import ConfigBase
|
|
12
|
+
from data_designer.config.utils.constants import (
|
|
13
13
|
AVAILABLE_LOCALES,
|
|
14
14
|
DEFAULT_AGE_RANGE,
|
|
15
15
|
LOCALES_WITH_MANAGED_DATASETS,
|
data_designer/config/seed.py
CHANGED
|
@@ -8,9 +8,9 @@ from typing import Optional, Union
|
|
|
8
8
|
from pydantic import Field, field_validator, model_validator
|
|
9
9
|
from typing_extensions import Self
|
|
10
10
|
|
|
11
|
-
from .base import ConfigBase
|
|
12
|
-
from .datastore import DatastoreSettings
|
|
13
|
-
from .utils.io_helpers import (
|
|
11
|
+
from data_designer.config.base import ConfigBase
|
|
12
|
+
from data_designer.config.datastore import DatastoreSettings
|
|
13
|
+
from data_designer.config.utils.io_helpers import (
|
|
14
14
|
VALID_DATASET_FILE_EXTENSIONS,
|
|
15
15
|
validate_dataset_file_path,
|
|
16
16
|
validate_path_contains_files_of_type,
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
-
from
|
|
4
|
+
from data_designer.errors import DataDesignerError
|
|
5
5
|
|
|
6
6
|
|
|
7
7
|
# Error type for syntax errors in user-supplied Jinja2 templates. config/utils/misc.py
# imports it alongside jinja2's TemplateSyntaxError (presumably to re-raise it there).
class UserJinjaTemplateSyntaxError(DataDesignerError): ...
|
|
@@ -5,10 +5,14 @@ from abc import ABC, abstractmethod
|
|
|
5
5
|
from enum import Enum
|
|
6
6
|
from typing import Literal, TypeVar
|
|
7
7
|
|
|
8
|
-
from
|
|
9
|
-
from
|
|
10
|
-
from .type_helpers import get_sampler_params
|
|
11
|
-
from .visualization import
|
|
8
|
+
from data_designer.config.models import ModelConfig, ModelProvider
|
|
9
|
+
from data_designer.config.sampler_params import SamplerType
|
|
10
|
+
from data_designer.config.utils.type_helpers import get_sampler_params
|
|
11
|
+
from data_designer.config.utils.visualization import (
|
|
12
|
+
display_model_configs_table,
|
|
13
|
+
display_model_providers_table,
|
|
14
|
+
display_sampler_table,
|
|
15
|
+
)
|
|
12
16
|
|
|
13
17
|
|
|
14
18
|
class InfoType(str, Enum):
|
|
@@ -1,12 +1,12 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
-
from datetime import date, datetime, timedelta
|
|
5
|
-
from decimal import Decimal
|
|
6
4
|
import json
|
|
7
5
|
import logging
|
|
8
|
-
from numbers import Number
|
|
9
6
|
import os
|
|
7
|
+
from datetime import date, datetime, timedelta
|
|
8
|
+
from decimal import Decimal
|
|
9
|
+
from numbers import Number
|
|
10
10
|
from pathlib import Path
|
|
11
11
|
from typing import Any, Union
|
|
12
12
|
|
|
@@ -14,7 +14,7 @@ import numpy as np
|
|
|
14
14
|
import pandas as pd
|
|
15
15
|
import yaml
|
|
16
16
|
|
|
17
|
-
from
|
|
17
|
+
from data_designer.config.errors import InvalidFileFormatError, InvalidFilePathError
|
|
18
18
|
|
|
19
19
|
logger = logging.getLogger(__name__)
|
|
20
20
|
|
|
@@ -44,7 +44,7 @@ def load_config_file(file_path: Path) -> dict:
|
|
|
44
44
|
InvalidFileFormatError: If YAML is malformed
|
|
45
45
|
InvalidConfigError: If file is empty
|
|
46
46
|
"""
|
|
47
|
-
from
|
|
47
|
+
from data_designer.config.errors import InvalidConfigError
|
|
48
48
|
|
|
49
49
|
if not file_path.exists():
|
|
50
50
|
raise InvalidFilePathError(f"Configuration file not found: {file_path}")
|
|
@@ -3,14 +3,14 @@
|
|
|
3
3
|
|
|
4
4
|
from __future__ import annotations
|
|
5
5
|
|
|
6
|
-
from contextlib import contextmanager
|
|
7
6
|
import json
|
|
7
|
+
from contextlib import contextmanager
|
|
8
8
|
from typing import Optional, Union
|
|
9
9
|
|
|
10
10
|
from jinja2 import TemplateSyntaxError, meta
|
|
11
11
|
from jinja2.sandbox import ImmutableSandboxedEnvironment
|
|
12
12
|
|
|
13
|
-
from .errors import UserJinjaTemplateSyntaxError
|
|
13
|
+
from data_designer.config.utils.errors import UserJinjaTemplateSyntaxError
|
|
14
14
|
|
|
15
15
|
REPR_LIST_LENGTH_USE_JSON = 4
|
|
16
16
|
|
|
@@ -43,7 +43,7 @@ def assert_valid_jinja2_template(template: str) -> None:
|
|
|
43
43
|
def can_run_data_designer_locally() -> bool:
    """Report whether the local execution engine is importable.

    Returns:
        True if the sibling ``engine`` package imports cleanly. False if the import
        raises ``ImportError`` (for example, when the engine is not installed).
    """
    engine_available = True
    try:
        from ... import engine  # noqa: F401, TID252
    except ImportError:
        engine_available = False
    return engine_available
|
|
@@ -1,14 +1,18 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
-
from enum import Enum
|
|
5
4
|
import inspect
|
|
5
|
+
from enum import Enum
|
|
6
6
|
from typing import Any, Literal, Type, get_args, get_origin
|
|
7
7
|
|
|
8
8
|
from pydantic import BaseModel
|
|
9
9
|
|
|
10
|
-
from
|
|
11
|
-
from .errors import
|
|
10
|
+
from data_designer.config import sampler_params
|
|
11
|
+
from data_designer.config.utils.errors import (
|
|
12
|
+
InvalidDiscriminatorFieldError,
|
|
13
|
+
InvalidEnumValueError,
|
|
14
|
+
InvalidTypeUnionError,
|
|
15
|
+
)
|
|
12
16
|
|
|
13
17
|
|
|
14
18
|
class StrEnum(str, Enum):
|
|
@@ -15,11 +15,14 @@ from rich.console import Console, Group
|
|
|
15
15
|
from rich.padding import Padding
|
|
16
16
|
from rich.panel import Panel
|
|
17
17
|
|
|
18
|
-
from
|
|
19
|
-
from
|
|
20
|
-
from
|
|
21
|
-
from .
|
|
22
|
-
|
|
18
|
+
from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType, column_type_is_llm_generated
|
|
19
|
+
from data_designer.config.processors import ProcessorConfig, ProcessorType
|
|
20
|
+
from data_designer.config.utils.constants import RICH_CONSOLE_THEME
|
|
21
|
+
from data_designer.config.utils.misc import (
|
|
22
|
+
can_run_data_designer_locally,
|
|
23
|
+
get_prompt_template_keywords,
|
|
24
|
+
)
|
|
25
|
+
from data_designer.config.validator_params import ValidatorType
|
|
23
26
|
|
|
24
27
|
|
|
25
28
|
class ViolationType(str, Enum):
|
|
@@ -63,6 +66,7 @@ def validate_data_designer_config(
|
|
|
63
66
|
violations.extend(validate_expression_references(columns=columns, allowed_references=allowed_references))
|
|
64
67
|
violations.extend(validate_columns_not_all_dropped(columns=columns))
|
|
65
68
|
violations.extend(validate_drop_columns_processor(columns=columns, processor_configs=processor_configs))
|
|
69
|
+
violations.extend(validate_schema_transform_processor(columns=columns, processor_configs=processor_configs))
|
|
66
70
|
if not can_run_data_designer_locally():
|
|
67
71
|
violations.extend(validate_local_only_columns(columns=columns))
|
|
68
72
|
return violations
|
|
@@ -271,7 +275,7 @@ def validate_drop_columns_processor(
|
|
|
271
275
|
columns: list[ColumnConfigT],
|
|
272
276
|
processor_configs: list[ProcessorConfig],
|
|
273
277
|
) -> list[Violation]:
|
|
274
|
-
all_column_names =
|
|
278
|
+
all_column_names = {c.name for c in columns}
|
|
275
279
|
for processor_config in processor_configs:
|
|
276
280
|
if processor_config.processor_type == ProcessorType.DROP_COLUMNS:
|
|
277
281
|
invalid_columns = set(processor_config.column_names) - all_column_names
|
|
@@ -288,6 +292,33 @@ def validate_drop_columns_processor(
|
|
|
288
292
|
return []
|
|
289
293
|
|
|
290
294
|
|
|
295
|
+
def validate_schema_transform_processor(
    columns: list[ColumnConfigT],
    processor_configs: list[ProcessorConfig],
) -> list[Violation]:
    """Check that schema-transform templates only reference columns defined in the config.

    Args:
        columns: Column configurations of the dataset.
        processor_configs: All configured processors. Only schema-transform ones are checked.

    Returns:
        One ERROR-level ``INVALID_REFERENCE`` violation for each template entry that
        references undefined columns, or an empty list when every reference resolves.
    """
    violations: list[Violation] = []
    defined_columns = {c.name for c in columns}
    for processor_config in processor_configs:
        if processor_config.processor_type != ProcessorType.SCHEMA_TRANSFORM:
            continue
        for new_column, template in processor_config.template.items():
            referenced = set(get_prompt_template_keywords(template))
            # Sort so the message is reproducible: string-set iteration order changes
            # between interpreter runs because of hash randomization.
            undefined = sorted(referenced - defined_columns)
            if not undefined:
                continue
            undefined_str = ", ".join(f"'{name}'" for name in undefined)
            message = (
                f"Ancillary dataset processor attempts to reference columns {undefined_str} in the template "
                f"for '{new_column}', but the columns are not defined in the dataset."
            )
            violations.append(
                Violation(
                    column=None,
                    type=ViolationType.INVALID_REFERENCE,
                    message=message,
                    level=ViolationLevel.ERROR,
                )
            )

    return violations
|
|
320
|
+
|
|
321
|
+
|
|
291
322
|
def validate_expression_references(
|
|
292
323
|
columns: list[ColumnConfigT],
|
|
293
324
|
allowed_references: list[str],
|