data-designer 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/_version.py +2 -2
- data_designer/config/analysis/column_profilers.py +4 -4
- data_designer/config/analysis/column_statistics.py +5 -5
- data_designer/config/analysis/dataset_profiler.py +6 -6
- data_designer/config/analysis/utils/errors.py +1 -1
- data_designer/config/analysis/utils/reporting.py +5 -5
- data_designer/config/base.py +2 -2
- data_designer/config/column_configs.py +8 -8
- data_designer/config/column_types.py +9 -5
- data_designer/config/config_builder.py +32 -27
- data_designer/config/data_designer_config.py +7 -7
- data_designer/config/datastore.py +4 -4
- data_designer/config/default_model_settings.py +4 -4
- data_designer/config/errors.py +1 -1
- data_designer/config/exports.py +128 -0
- data_designer/config/interface.py +6 -6
- data_designer/config/models.py +109 -5
- data_designer/config/preview_results.py +3 -3
- data_designer/config/processors.py +2 -2
- data_designer/config/sampler_constraints.py +1 -1
- data_designer/config/sampler_params.py +2 -5
- data_designer/config/seed.py +3 -3
- data_designer/config/utils/constants.py +1 -1
- data_designer/config/utils/errors.py +1 -1
- data_designer/config/utils/info.py +8 -4
- data_designer/config/utils/io_helpers.py +5 -5
- data_designer/config/utils/misc.py +3 -3
- data_designer/config/utils/numerical_helpers.py +1 -1
- data_designer/config/utils/type_helpers.py +7 -3
- data_designer/config/utils/validation.py +5 -5
- data_designer/config/utils/visualization.py +10 -10
- data_designer/config/validator_params.py +2 -2
- data_designer/engine/analysis/column_profilers/base.py +1 -1
- data_designer/engine/analysis/dataset_profiler.py +1 -1
- data_designer/engine/analysis/utils/judge_score_processing.py +1 -1
- data_designer/engine/column_generators/generators/samplers.py +1 -1
- data_designer/engine/dataset_builders/artifact_storage.py +16 -2
- data_designer/engine/dataset_builders/column_wise_builder.py +3 -3
- data_designer/engine/dataset_builders/utils/concurrency.py +1 -1
- data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +1 -1
- data_designer/engine/errors.py +1 -1
- data_designer/engine/models/errors.py +1 -1
- data_designer/engine/models/facade.py +1 -1
- data_designer/engine/models/parsers/parser.py +2 -2
- data_designer/engine/models/recipes/response_recipes.py +1 -1
- data_designer/engine/processing/ginja/environment.py +1 -1
- data_designer/engine/processing/gsonschema/validators.py +1 -1
- data_designer/engine/resources/managed_dataset_repository.py +4 -4
- data_designer/engine/resources/managed_storage.py +1 -1
- data_designer/engine/sampling_gen/constraints.py +1 -1
- data_designer/engine/sampling_gen/data_sources/base.py +1 -1
- data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +31 -9
- data_designer/engine/sampling_gen/entities/email_address_utils.py +1 -1
- data_designer/engine/sampling_gen/entities/national_id_utils.py +1 -1
- data_designer/engine/sampling_gen/entities/person.py +1 -1
- data_designer/engine/sampling_gen/entities/phone_number.py +1 -1
- data_designer/engine/sampling_gen/people_gen.py +3 -3
- data_designer/engine/secret_resolver.py +1 -1
- data_designer/engine/validators/python.py +2 -2
- data_designer/essentials/__init__.py +20 -128
- data_designer/interface/data_designer.py +16 -20
- data_designer/logging.py +2 -2
- data_designer/plugin_manager.py +14 -26
- data_designer/plugins/registry.py +1 -1
- {data_designer-0.1.2.dist-info → data_designer-0.1.4.dist-info}/METADATA +2 -2
- {data_designer-0.1.2.dist-info → data_designer-0.1.4.dist-info}/RECORD +69 -68
- {data_designer-0.1.2.dist-info → data_designer-0.1.4.dist-info}/WHEEL +1 -1
- {data_designer-0.1.2.dist-info → data_designer-0.1.4.dist-info}/entry_points.txt +0 -0
- {data_designer-0.1.2.dist-info → data_designer-0.1.4.dist-info}/licenses/LICENSE +0 -0
data_designer/config/models.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
import logging
|
|
4
5
|
from abc import ABC, abstractmethod
|
|
5
6
|
from enum import Enum
|
|
6
|
-
import logging
|
|
7
7
|
from pathlib import Path
|
|
8
8
|
from typing import Any, Generic, List, Optional, TypeVar, Union
|
|
9
9
|
|
|
@@ -11,29 +11,35 @@ import numpy as np
|
|
|
11
11
|
from pydantic import BaseModel, Field, model_validator
|
|
12
12
|
from typing_extensions import Self, TypeAlias
|
|
13
13
|
|
|
14
|
-
from .base import ConfigBase
|
|
15
|
-
from .errors import InvalidConfigError
|
|
16
|
-
from .utils.constants import (
|
|
14
|
+
from data_designer.config.base import ConfigBase
|
|
15
|
+
from data_designer.config.errors import InvalidConfigError
|
|
16
|
+
from data_designer.config.utils.constants import (
|
|
17
17
|
MAX_TEMPERATURE,
|
|
18
18
|
MAX_TOP_P,
|
|
19
19
|
MIN_TEMPERATURE,
|
|
20
20
|
MIN_TOP_P,
|
|
21
21
|
)
|
|
22
|
-
from .utils.io_helpers import smart_load_yaml
|
|
22
|
+
from data_designer.config.utils.io_helpers import smart_load_yaml
|
|
23
23
|
|
|
24
24
|
logger = logging.getLogger(__name__)
|
|
25
25
|
|
|
26
26
|
|
|
27
27
|
class Modality(str, Enum):
|
|
28
|
+
"""Supported modality types for multimodal model data."""
|
|
29
|
+
|
|
28
30
|
IMAGE = "image"
|
|
29
31
|
|
|
30
32
|
|
|
31
33
|
class ModalityDataType(str, Enum):
|
|
34
|
+
"""Data type formats for multimodal data."""
|
|
35
|
+
|
|
32
36
|
URL = "url"
|
|
33
37
|
BASE64 = "base64"
|
|
34
38
|
|
|
35
39
|
|
|
36
40
|
class ImageFormat(str, Enum):
|
|
41
|
+
"""Supported image formats for image modality."""
|
|
42
|
+
|
|
37
43
|
PNG = "png"
|
|
38
44
|
JPG = "jpg"
|
|
39
45
|
JPEG = "jpeg"
|
|
@@ -42,6 +48,8 @@ class ImageFormat(str, Enum):
|
|
|
42
48
|
|
|
43
49
|
|
|
44
50
|
class DistributionType(str, Enum):
|
|
51
|
+
"""Types of distributions for sampling inference parameters."""
|
|
52
|
+
|
|
45
53
|
UNIFORM = "uniform"
|
|
46
54
|
MANUAL = "manual"
|
|
47
55
|
|
|
@@ -56,10 +64,27 @@ class ModalityContext(ABC, BaseModel):
|
|
|
56
64
|
|
|
57
65
|
|
|
58
66
|
class ImageContext(ModalityContext):
|
|
67
|
+
"""Configuration for providing image context to multimodal models.
|
|
68
|
+
|
|
69
|
+
Attributes:
|
|
70
|
+
modality: The modality type (always "image").
|
|
71
|
+
column_name: Name of the column containing image data.
|
|
72
|
+
data_type: Format of the image data ("url" or "base64").
|
|
73
|
+
image_format: Image format (required for base64 data).
|
|
74
|
+
"""
|
|
75
|
+
|
|
59
76
|
modality: Modality = Modality.IMAGE
|
|
60
77
|
image_format: Optional[ImageFormat] = None
|
|
61
78
|
|
|
62
79
|
def get_context(self, record: dict) -> dict[str, Any]:
|
|
80
|
+
"""Get the context for the image modality.
|
|
81
|
+
|
|
82
|
+
Args:
|
|
83
|
+
record: The record containing the image data.
|
|
84
|
+
|
|
85
|
+
Returns:
|
|
86
|
+
The context for the image modality.
|
|
87
|
+
"""
|
|
63
88
|
context = dict(type="image_url")
|
|
64
89
|
context_value = record[self.column_name]
|
|
65
90
|
if self.data_type == ModalityDataType.URL:
|
|
@@ -90,6 +115,13 @@ class Distribution(ABC, ConfigBase, Generic[DistributionParamsT]):
|
|
|
90
115
|
|
|
91
116
|
|
|
92
117
|
class ManualDistributionParams(ConfigBase):
|
|
118
|
+
"""Parameters for manual distribution sampling.
|
|
119
|
+
|
|
120
|
+
Attributes:
|
|
121
|
+
values: List of possible values to sample from.
|
|
122
|
+
weights: Optional list of weights for each value. If not provided, all values have equal probability.
|
|
123
|
+
"""
|
|
124
|
+
|
|
93
125
|
values: List[float] = Field(min_length=1)
|
|
94
126
|
weights: Optional[List[float]] = None
|
|
95
127
|
|
|
@@ -107,14 +139,36 @@ class ManualDistributionParams(ConfigBase):
|
|
|
107
139
|
|
|
108
140
|
|
|
109
141
|
class ManualDistribution(Distribution[ManualDistributionParams]):
|
|
142
|
+
"""Manual (discrete) distribution for sampling inference parameters.
|
|
143
|
+
|
|
144
|
+
Samples from a discrete set of values with optional weights. Useful for testing
|
|
145
|
+
specific values or creating custom probability distributions for temperature or top_p.
|
|
146
|
+
|
|
147
|
+
Attributes:
|
|
148
|
+
distribution_type: Type of distribution ("manual").
|
|
149
|
+
params: Distribution parameters (values, weights).
|
|
150
|
+
"""
|
|
151
|
+
|
|
110
152
|
distribution_type: Optional[DistributionType] = "manual"
|
|
111
153
|
params: ManualDistributionParams
|
|
112
154
|
|
|
113
155
|
def sample(self) -> float:
|
|
156
|
+
"""Sample a value from the manual distribution.
|
|
157
|
+
|
|
158
|
+
Returns:
|
|
159
|
+
A float value sampled from the manual distribution.
|
|
160
|
+
"""
|
|
114
161
|
return float(np.random.choice(self.params.values, p=self.params.weights))
|
|
115
162
|
|
|
116
163
|
|
|
117
164
|
class UniformDistributionParams(ConfigBase):
|
|
165
|
+
"""Parameters for uniform distribution sampling.
|
|
166
|
+
|
|
167
|
+
Attributes:
|
|
168
|
+
low: Lower bound (inclusive).
|
|
169
|
+
high: Upper bound (exclusive).
|
|
170
|
+
"""
|
|
171
|
+
|
|
118
172
|
low: float
|
|
119
173
|
high: float
|
|
120
174
|
|
|
@@ -126,10 +180,25 @@ class UniformDistributionParams(ConfigBase):
|
|
|
126
180
|
|
|
127
181
|
|
|
128
182
|
class UniformDistribution(Distribution[UniformDistributionParams]):
|
|
183
|
+
"""Uniform distribution for sampling inference parameters.
|
|
184
|
+
|
|
185
|
+
Samples values uniformly between low and high bounds. Useful for exploring
|
|
186
|
+
a continuous range of values for temperature or top_p.
|
|
187
|
+
|
|
188
|
+
Attributes:
|
|
189
|
+
distribution_type: Type of distribution ("uniform").
|
|
190
|
+
params: Distribution parameters (low, high).
|
|
191
|
+
"""
|
|
192
|
+
|
|
129
193
|
distribution_type: Optional[DistributionType] = "uniform"
|
|
130
194
|
params: UniformDistributionParams
|
|
131
195
|
|
|
132
196
|
def sample(self) -> float:
|
|
197
|
+
"""Sample a value from the uniform distribution.
|
|
198
|
+
|
|
199
|
+
Returns:
|
|
200
|
+
A float value sampled from the uniform distribution.
|
|
201
|
+
"""
|
|
133
202
|
return float(np.random.uniform(low=self.params.low, high=self.params.high, size=1)[0])
|
|
134
203
|
|
|
135
204
|
|
|
@@ -137,6 +206,17 @@ DistributionT: TypeAlias = Union[UniformDistribution, ManualDistribution]
|
|
|
137
206
|
|
|
138
207
|
|
|
139
208
|
class InferenceParameters(ConfigBase):
|
|
209
|
+
"""Configuration for LLM inference parameters.
|
|
210
|
+
|
|
211
|
+
Attributes:
|
|
212
|
+
temperature: Sampling temperature (0.0-2.0). Can be a fixed value or a distribution for dynamic sampling.
|
|
213
|
+
top_p: Nucleus sampling probability (0.0-1.0). Can be a fixed value or a distribution for dynamic sampling.
|
|
214
|
+
max_tokens: Maximum number of tokens (includes both input and output tokens).
|
|
215
|
+
max_parallel_requests: Maximum number of parallel requests to the model API.
|
|
216
|
+
timeout: Timeout in seconds for each request.
|
|
217
|
+
extra_body: Additional parameters to pass to the model API.
|
|
218
|
+
"""
|
|
219
|
+
|
|
140
220
|
temperature: Optional[Union[float, DistributionT]] = None
|
|
141
221
|
top_p: Optional[Union[float, DistributionT]] = None
|
|
142
222
|
max_tokens: Optional[int] = Field(default=None, ge=1)
|
|
@@ -146,6 +226,11 @@ class InferenceParameters(ConfigBase):
|
|
|
146
226
|
|
|
147
227
|
@property
|
|
148
228
|
def generate_kwargs(self) -> dict[str, Union[float, int]]:
|
|
229
|
+
"""Get the generate kwargs for the inference parameters.
|
|
230
|
+
|
|
231
|
+
Returns:
|
|
232
|
+
A dictionary of the generate kwargs.
|
|
233
|
+
"""
|
|
149
234
|
result = {}
|
|
150
235
|
if self.temperature is not None:
|
|
151
236
|
result["temperature"] = (
|
|
@@ -206,6 +291,15 @@ class InferenceParameters(ConfigBase):
|
|
|
206
291
|
|
|
207
292
|
|
|
208
293
|
class ModelConfig(ConfigBase):
|
|
294
|
+
"""Configuration for a model used for generation.
|
|
295
|
+
|
|
296
|
+
Attributes:
|
|
297
|
+
alias: User-defined alias to reference in column configurations.
|
|
298
|
+
model: Model identifier (e.g., from build.nvidia.com or other providers).
|
|
299
|
+
inference_parameters: Inference parameters for the model (temperature, top_p, max_tokens, etc.).
|
|
300
|
+
provider: Optional model provider name if using custom providers.
|
|
301
|
+
"""
|
|
302
|
+
|
|
209
303
|
alias: str
|
|
210
304
|
model: str
|
|
211
305
|
inference_parameters: InferenceParameters = Field(default_factory=InferenceParameters)
|
|
@@ -213,6 +307,16 @@ class ModelConfig(ConfigBase):
|
|
|
213
307
|
|
|
214
308
|
|
|
215
309
|
class ModelProvider(ConfigBase):
|
|
310
|
+
"""Configuration for a custom model provider.
|
|
311
|
+
|
|
312
|
+
Attributes:
|
|
313
|
+
name: Name of the model provider.
|
|
314
|
+
endpoint: API endpoint URL for the provider.
|
|
315
|
+
provider_type: Provider type (default: "openai"). Determines the API format to use.
|
|
316
|
+
api_key: Optional API key for authentication.
|
|
317
|
+
extra_body: Additional parameters to pass in API requests.
|
|
318
|
+
"""
|
|
319
|
+
|
|
216
320
|
name: str
|
|
217
321
|
endpoint: str
|
|
218
322
|
provider_type: str = "openai"
|
|
@@ -7,9 +7,9 @@ from typing import Optional
|
|
|
7
7
|
|
|
8
8
|
import pandas as pd
|
|
9
9
|
|
|
10
|
-
from .analysis.dataset_profiler import DatasetProfilerResults
|
|
11
|
-
from .config_builder import DataDesignerConfigBuilder
|
|
12
|
-
from .utils.visualization import WithRecordSamplerMixin
|
|
10
|
+
from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults
|
|
11
|
+
from data_designer.config.config_builder import DataDesignerConfigBuilder
|
|
12
|
+
from data_designer.config.utils.visualization import WithRecordSamplerMixin
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
class PreviewResults(WithRecordSamplerMixin):
|
|
@@ -7,8 +7,8 @@ from typing import Literal
|
|
|
7
7
|
|
|
8
8
|
from pydantic import Field, field_validator
|
|
9
9
|
|
|
10
|
-
from .base import ConfigBase
|
|
11
|
-
from .dataset_builders import BuildStage
|
|
10
|
+
from data_designer.config.base import ConfigBase
|
|
11
|
+
from data_designer.config.dataset_builders import BuildStage
|
|
12
12
|
|
|
13
13
|
SUPPORTED_STAGES = [BuildStage.POST_BATCH]
|
|
14
14
|
|
|
@@ -8,8 +8,8 @@ import pandas as pd
|
|
|
8
8
|
from pydantic import Field, field_validator, model_validator
|
|
9
9
|
from typing_extensions import Self, TypeAlias
|
|
10
10
|
|
|
11
|
-
from .base import ConfigBase
|
|
12
|
-
from .utils.constants import (
|
|
11
|
+
from data_designer.config.base import ConfigBase
|
|
12
|
+
from data_designer.config.utils.constants import (
|
|
13
13
|
AVAILABLE_LOCALES,
|
|
14
14
|
DEFAULT_AGE_RANGE,
|
|
15
15
|
LOCALES_WITH_MANAGED_DATASETS,
|
|
@@ -430,9 +430,6 @@ class PersonSamplerParams(ConfigBase):
|
|
|
430
430
|
age_range: Two-element list [min_age, max_age] specifying the age range to sample from
|
|
431
431
|
(inclusive). Defaults to a standard age range. Both values must be between minimum and
|
|
432
432
|
maximum allowed ages.
|
|
433
|
-
state: Only supported for "en_US" locale. Filters to sample people from specified US state(s).
|
|
434
|
-
Must be provided as two-letter state abbreviations (e.g., "CA", "NY", "TX"). Can be a
|
|
435
|
-
single state or a list of states.
|
|
436
433
|
with_synthetic_personas: If True, appends additional synthetic persona columns including
|
|
437
434
|
personality traits, interests, and background descriptions. Only supported for certain
|
|
438
435
|
locales with managed datasets.
|
data_designer/config/seed.py
CHANGED
|
@@ -8,9 +8,9 @@ from typing import Optional, Union
|
|
|
8
8
|
from pydantic import Field, field_validator, model_validator
|
|
9
9
|
from typing_extensions import Self
|
|
10
10
|
|
|
11
|
-
from .base import ConfigBase
|
|
12
|
-
from .datastore import DatastoreSettings
|
|
13
|
-
from .utils.io_helpers import (
|
|
11
|
+
from data_designer.config.base import ConfigBase
|
|
12
|
+
from data_designer.config.datastore import DatastoreSettings
|
|
13
|
+
from data_designer.config.utils.io_helpers import (
|
|
14
14
|
VALID_DATASET_FILE_EXTENSIONS,
|
|
15
15
|
validate_dataset_file_path,
|
|
16
16
|
validate_path_contains_files_of_type,
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
-
from
|
|
4
|
+
from data_designer.errors import DataDesignerError
|
|
5
5
|
|
|
6
6
|
|
|
7
7
|
class UserJinjaTemplateSyntaxError(DataDesignerError): ...
|
|
@@ -5,10 +5,14 @@ from abc import ABC, abstractmethod
|
|
|
5
5
|
from enum import Enum
|
|
6
6
|
from typing import Literal, TypeVar
|
|
7
7
|
|
|
8
|
-
from
|
|
9
|
-
from
|
|
10
|
-
from .type_helpers import get_sampler_params
|
|
11
|
-
from .visualization import
|
|
8
|
+
from data_designer.config.models import ModelConfig, ModelProvider
|
|
9
|
+
from data_designer.config.sampler_params import SamplerType
|
|
10
|
+
from data_designer.config.utils.type_helpers import get_sampler_params
|
|
11
|
+
from data_designer.config.utils.visualization import (
|
|
12
|
+
display_model_configs_table,
|
|
13
|
+
display_model_providers_table,
|
|
14
|
+
display_sampler_table,
|
|
15
|
+
)
|
|
12
16
|
|
|
13
17
|
|
|
14
18
|
class InfoType(str, Enum):
|
|
@@ -1,12 +1,12 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
-
from datetime import date, datetime, timedelta
|
|
5
|
-
from decimal import Decimal
|
|
6
4
|
import json
|
|
7
5
|
import logging
|
|
8
|
-
from numbers import Number
|
|
9
6
|
import os
|
|
7
|
+
from datetime import date, datetime, timedelta
|
|
8
|
+
from decimal import Decimal
|
|
9
|
+
from numbers import Number
|
|
10
10
|
from pathlib import Path
|
|
11
11
|
from typing import Any, Union
|
|
12
12
|
|
|
@@ -14,7 +14,7 @@ import numpy as np
|
|
|
14
14
|
import pandas as pd
|
|
15
15
|
import yaml
|
|
16
16
|
|
|
17
|
-
from
|
|
17
|
+
from data_designer.config.errors import InvalidFileFormatError, InvalidFilePathError
|
|
18
18
|
|
|
19
19
|
logger = logging.getLogger(__name__)
|
|
20
20
|
|
|
@@ -44,7 +44,7 @@ def load_config_file(file_path: Path) -> dict:
|
|
|
44
44
|
InvalidFileFormatError: If YAML is malformed
|
|
45
45
|
InvalidConfigError: If file is empty
|
|
46
46
|
"""
|
|
47
|
-
from
|
|
47
|
+
from data_designer.config.errors import InvalidConfigError
|
|
48
48
|
|
|
49
49
|
if not file_path.exists():
|
|
50
50
|
raise InvalidFilePathError(f"Configuration file not found: {file_path}")
|
|
@@ -3,14 +3,14 @@
|
|
|
3
3
|
|
|
4
4
|
from __future__ import annotations
|
|
5
5
|
|
|
6
|
-
from contextlib import contextmanager
|
|
7
6
|
import json
|
|
7
|
+
from contextlib import contextmanager
|
|
8
8
|
from typing import Optional, Union
|
|
9
9
|
|
|
10
10
|
from jinja2 import TemplateSyntaxError, meta
|
|
11
11
|
from jinja2.sandbox import ImmutableSandboxedEnvironment
|
|
12
12
|
|
|
13
|
-
from .errors import UserJinjaTemplateSyntaxError
|
|
13
|
+
from data_designer.config.utils.errors import UserJinjaTemplateSyntaxError
|
|
14
14
|
|
|
15
15
|
REPR_LIST_LENGTH_USE_JSON = 4
|
|
16
16
|
|
|
@@ -43,7 +43,7 @@ def assert_valid_jinja2_template(template: str) -> None:
|
|
|
43
43
|
def can_run_data_designer_locally() -> bool:
|
|
44
44
|
"""Returns True if Data Designer can be run locally, False otherwise."""
|
|
45
45
|
try:
|
|
46
|
-
from ... import engine # noqa: F401
|
|
46
|
+
from ... import engine # noqa: F401, TID252
|
|
47
47
|
except ImportError:
|
|
48
48
|
return False
|
|
49
49
|
return True
|
|
@@ -1,14 +1,18 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
-
from enum import Enum
|
|
5
4
|
import inspect
|
|
5
|
+
from enum import Enum
|
|
6
6
|
from typing import Any, Literal, Type, get_args, get_origin
|
|
7
7
|
|
|
8
8
|
from pydantic import BaseModel
|
|
9
9
|
|
|
10
|
-
from
|
|
11
|
-
from .errors import
|
|
10
|
+
from data_designer.config import sampler_params
|
|
11
|
+
from data_designer.config.utils.errors import (
|
|
12
|
+
InvalidDiscriminatorFieldError,
|
|
13
|
+
InvalidEnumValueError,
|
|
14
|
+
InvalidTypeUnionError,
|
|
15
|
+
)
|
|
12
16
|
|
|
13
17
|
|
|
14
18
|
class StrEnum(str, Enum):
|
|
@@ -15,11 +15,11 @@ from rich.console import Console, Group
|
|
|
15
15
|
from rich.padding import Padding
|
|
16
16
|
from rich.panel import Panel
|
|
17
17
|
|
|
18
|
-
from
|
|
19
|
-
from
|
|
20
|
-
from
|
|
21
|
-
from .
|
|
22
|
-
from .
|
|
18
|
+
from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType, column_type_is_llm_generated
|
|
19
|
+
from data_designer.config.processors import ProcessorConfig, ProcessorType
|
|
20
|
+
from data_designer.config.utils.constants import RICH_CONSOLE_THEME
|
|
21
|
+
from data_designer.config.utils.misc import can_run_data_designer_locally
|
|
22
|
+
from data_designer.config.validator_params import ValidatorType
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
class ViolationType(str, Enum):
|
|
@@ -3,11 +3,11 @@
|
|
|
3
3
|
|
|
4
4
|
from __future__ import annotations
|
|
5
5
|
|
|
6
|
+
import json
|
|
7
|
+
import os
|
|
6
8
|
from collections import OrderedDict
|
|
7
9
|
from enum import Enum
|
|
8
10
|
from functools import cached_property
|
|
9
|
-
import json
|
|
10
|
-
import os
|
|
11
11
|
from typing import TYPE_CHECKING, Optional, Union
|
|
12
12
|
|
|
13
13
|
import numpy as np
|
|
@@ -21,16 +21,16 @@ from rich.syntax import Syntax
|
|
|
21
21
|
from rich.table import Table
|
|
22
22
|
from rich.text import Text
|
|
23
23
|
|
|
24
|
-
from
|
|
25
|
-
from
|
|
26
|
-
from
|
|
27
|
-
from
|
|
28
|
-
from .code_lang import code_lang_to_syntax_lexer
|
|
29
|
-
from .constants import NVIDIA_API_KEY_ENV_VAR_NAME, OPENAI_API_KEY_ENV_VAR_NAME
|
|
30
|
-
from .errors import DatasetSampleDisplayError
|
|
24
|
+
from data_designer.config.base import ConfigBase
|
|
25
|
+
from data_designer.config.column_types import DataDesignerColumnType
|
|
26
|
+
from data_designer.config.models import ModelConfig, ModelProvider
|
|
27
|
+
from data_designer.config.sampler_params import SamplerType
|
|
28
|
+
from data_designer.config.utils.code_lang import code_lang_to_syntax_lexer
|
|
29
|
+
from data_designer.config.utils.constants import NVIDIA_API_KEY_ENV_VAR_NAME, OPENAI_API_KEY_ENV_VAR_NAME
|
|
30
|
+
from data_designer.config.utils.errors import DatasetSampleDisplayError
|
|
31
31
|
|
|
32
32
|
if TYPE_CHECKING:
|
|
33
|
-
from
|
|
33
|
+
from data_designer.config.config_builder import DataDesignerConfigBuilder
|
|
34
34
|
|
|
35
35
|
|
|
36
36
|
console = Console()
|
|
@@ -7,8 +7,8 @@ from typing import Any, Optional, Union
|
|
|
7
7
|
from pydantic import Field, field_serializer, model_validator
|
|
8
8
|
from typing_extensions import Self, TypeAlias
|
|
9
9
|
|
|
10
|
-
from .base import ConfigBase
|
|
11
|
-
from .utils.code_lang import SQL_DIALECTS, CodeLang
|
|
10
|
+
from data_designer.config.base import ConfigBase
|
|
11
|
+
from data_designer.config.utils.code_lang import SQL_DIALECTS, CodeLang
|
|
12
12
|
|
|
13
13
|
SUPPORTED_CODE_LANGUAGES = {CodeLang.PYTHON, *SQL_DIALECTS}
|
|
14
14
|
|
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
import logging
|
|
4
5
|
from collections.abc import Sequence
|
|
5
6
|
from functools import cached_property
|
|
6
|
-
import logging
|
|
7
7
|
|
|
8
8
|
import pandas as pd
|
|
9
9
|
from pydantic import Field, field_validator
|
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
-
from collections import defaultdict
|
|
5
4
|
import logging
|
|
5
|
+
from collections import defaultdict
|
|
6
6
|
from typing import Any, Optional, Union
|
|
7
7
|
|
|
8
8
|
import pandas as pd
|
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
-
from functools import partial
|
|
5
4
|
import logging
|
|
6
5
|
import random
|
|
6
|
+
from functools import partial
|
|
7
7
|
from typing import Callable
|
|
8
8
|
|
|
9
9
|
import pandas as pd
|
|
@@ -3,8 +3,10 @@
|
|
|
3
3
|
|
|
4
4
|
import json
|
|
5
5
|
import logging
|
|
6
|
-
from pathlib import Path
|
|
7
6
|
import shutil
|
|
7
|
+
from datetime import datetime
|
|
8
|
+
from functools import cached_property
|
|
9
|
+
from pathlib import Path
|
|
8
10
|
from typing import Union
|
|
9
11
|
|
|
10
12
|
import pandas as pd
|
|
@@ -36,9 +38,21 @@ class ArtifactStorage(BaseModel):
|
|
|
36
38
|
def artifact_path_exists(self) -> bool:
|
|
37
39
|
return self.artifact_path.exists()
|
|
38
40
|
|
|
41
|
+
@cached_property
|
|
42
|
+
def resolved_dataset_name(self) -> str:
|
|
43
|
+
dataset_path = self.artifact_path / self.dataset_name
|
|
44
|
+
if dataset_path.exists() and len(list(dataset_path.iterdir())) > 0:
|
|
45
|
+
new_dataset_name = f"{self.dataset_name}_{datetime.now().strftime('%m-%d-%Y_%H%M%S')}"
|
|
46
|
+
logger.info(
|
|
47
|
+
f"📂 Dataset path {str(dataset_path)!r} already exists. Dataset from this session"
|
|
48
|
+
f"\n\t\t will be saved to {str(self.artifact_path / new_dataset_name)!r} instead."
|
|
49
|
+
)
|
|
50
|
+
return new_dataset_name
|
|
51
|
+
return self.dataset_name
|
|
52
|
+
|
|
39
53
|
@property
|
|
40
54
|
def base_dataset_path(self) -> Path:
|
|
41
|
-
return self.artifact_path / self.
|
|
55
|
+
return self.artifact_path / self.resolved_dataset_name
|
|
42
56
|
|
|
43
57
|
@property
|
|
44
58
|
def dropped_columns_dataset_path(self) -> Path:
|
|
@@ -4,8 +4,8 @@
|
|
|
4
4
|
import functools
|
|
5
5
|
import json
|
|
6
6
|
import logging
|
|
7
|
-
from pathlib import Path
|
|
8
7
|
import time
|
|
8
|
+
from pathlib import Path
|
|
9
9
|
from typing import Callable
|
|
10
10
|
|
|
11
11
|
import pandas as pd
|
|
@@ -88,8 +88,8 @@ class ColumnWiseDatasetBuilder:
|
|
|
88
88
|
start_time = time.perf_counter()
|
|
89
89
|
|
|
90
90
|
self.batch_manager.start(num_records=num_records, buffer_size=buffer_size)
|
|
91
|
-
for batch_idx in range(
|
|
92
|
-
logger.info(f"⏳ Processing batch {batch_idx} of {self.batch_manager.num_batches}")
|
|
91
|
+
for batch_idx in range(self.batch_manager.num_batches):
|
|
92
|
+
logger.info(f"⏳ Processing batch {batch_idx + 1} of {self.batch_manager.num_batches}")
|
|
93
93
|
self._run_batch(generators)
|
|
94
94
|
df_batch = self._run_processors(
|
|
95
95
|
stage=BuildStage.POST_BATCH,
|
|
@@ -3,10 +3,10 @@
|
|
|
3
3
|
|
|
4
4
|
from __future__ import annotations
|
|
5
5
|
|
|
6
|
-
from concurrent.futures import Future, ThreadPoolExecutor
|
|
7
6
|
import contextvars
|
|
8
7
|
import json
|
|
9
8
|
import logging
|
|
9
|
+
from concurrent.futures import Future, ThreadPoolExecutor
|
|
10
10
|
from threading import Lock, Semaphore
|
|
11
11
|
from typing import Any, Optional, Protocol
|
|
12
12
|
|
data_designer/engine/errors.py
CHANGED