data-designer 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
- data_designer/_version.py +2 -2
- data_designer/cli/README.md +15 -1
- data_designer/cli/commands/download.py +56 -0
- data_designer/cli/commands/list.py +4 -18
- data_designer/cli/controllers/__init__.py +2 -1
- data_designer/cli/controllers/download_controller.py +217 -0
- data_designer/cli/controllers/model_controller.py +4 -3
- data_designer/cli/forms/field.py +65 -19
- data_designer/cli/forms/model_builder.py +251 -44
- data_designer/cli/main.py +11 -1
- data_designer/cli/repositories/persona_repository.py +88 -0
- data_designer/cli/services/__init__.py +2 -1
- data_designer/cli/services/download_service.py +97 -0
- data_designer/cli/ui.py +131 -0
- data_designer/cli/utils.py +34 -0
- data_designer/config/analysis/__init__.py +2 -0
- data_designer/config/analysis/column_profilers.py +75 -7
- data_designer/config/analysis/column_statistics.py +192 -48
- data_designer/config/analysis/dataset_profiler.py +23 -5
- data_designer/config/analysis/utils/reporting.py +3 -3
- data_designer/config/base.py +3 -3
- data_designer/config/column_configs.py +27 -6
- data_designer/config/column_types.py +24 -17
- data_designer/config/config_builder.py +34 -26
- data_designer/config/data_designer_config.py +7 -7
- data_designer/config/datastore.py +6 -6
- data_designer/config/default_model_settings.py +27 -34
- data_designer/config/exports.py +14 -1
- data_designer/config/models.py +155 -29
- data_designer/config/preview_results.py +5 -4
- data_designer/config/processors.py +109 -4
- data_designer/config/sampler_constraints.py +1 -2
- data_designer/config/sampler_params.py +31 -31
- data_designer/config/seed.py +1 -2
- data_designer/config/utils/code_lang.py +4 -5
- data_designer/config/utils/constants.py +31 -8
- data_designer/config/utils/io_helpers.py +5 -5
- data_designer/config/utils/misc.py +1 -4
- data_designer/config/utils/numerical_helpers.py +2 -2
- data_designer/config/utils/type_helpers.py +3 -3
- data_designer/config/utils/validation.py +39 -9
- data_designer/config/utils/visualization.py +62 -15
- data_designer/config/validator_params.py +4 -8
- data_designer/engine/analysis/column_profilers/base.py +0 -7
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +2 -3
- data_designer/engine/analysis/column_statistics.py +16 -16
- data_designer/engine/analysis/dataset_profiler.py +25 -4
- data_designer/engine/analysis/utils/column_statistics_calculations.py +71 -49
- data_designer/engine/analysis/utils/judge_score_processing.py +5 -5
- data_designer/engine/column_generators/generators/base.py +34 -0
- data_designer/engine/column_generators/generators/embedding.py +45 -0
- data_designer/engine/column_generators/generators/{llm_generators.py → llm_completion.py} +17 -49
- data_designer/engine/column_generators/registry.py +4 -2
- data_designer/engine/column_generators/utils/judge_score_factory.py +5 -6
- data_designer/engine/configurable_task.py +2 -2
- data_designer/engine/dataset_builders/artifact_storage.py +14 -5
- data_designer/engine/dataset_builders/column_wise_builder.py +12 -8
- data_designer/engine/dataset_builders/utils/concurrency.py +6 -6
- data_designer/engine/models/facade.py +66 -9
- data_designer/engine/models/litellm_overrides.py +5 -6
- data_designer/engine/models/parsers/errors.py +2 -4
- data_designer/engine/models/parsers/parser.py +2 -3
- data_designer/engine/models/parsers/postprocessors.py +3 -4
- data_designer/engine/models/parsers/types.py +4 -4
- data_designer/engine/models/registry.py +20 -11
- data_designer/engine/models/usage.py +7 -9
- data_designer/engine/processing/ginja/ast.py +1 -2
- data_designer/engine/processing/processors/drop_columns.py +1 -1
- data_designer/engine/processing/processors/registry.py +3 -0
- data_designer/engine/processing/processors/schema_transform.py +53 -0
- data_designer/engine/processing/utils.py +40 -2
- data_designer/engine/registry/base.py +12 -12
- data_designer/engine/sampling_gen/constraints.py +1 -2
- data_designer/engine/sampling_gen/data_sources/base.py +14 -14
- data_designer/engine/sampling_gen/entities/phone_number.py +1 -2
- data_designer/engine/sampling_gen/people_gen.py +3 -7
- data_designer/engine/validators/base.py +2 -2
- data_designer/interface/data_designer.py +12 -0
- data_designer/interface/results.py +36 -0
- data_designer/logging.py +2 -2
- data_designer/plugin_manager.py +3 -3
- data_designer/plugins/plugin.py +3 -3
- data_designer/plugins/registry.py +2 -2
- {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/METADATA +9 -9
- {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/RECORD +88 -81
- {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/WHEEL +0 -0
- {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/entry_points.txt +0 -0
- {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/licenses/LICENSE +0 -0
data_designer/config/exports.py
CHANGED
```diff
@@ -3,6 +3,7 @@
 
 from data_designer.config.analysis.column_profilers import JudgeScoreProfilerConfig
 from data_designer.config.column_configs import (
+    EmbeddingColumnConfig,
     ExpressionColumnConfig,
     LLMCodeColumnConfig,
     LLMJudgeColumnConfig,
@@ -19,6 +20,9 @@ from data_designer.config.data_designer_config import DataDesignerConfig
 from data_designer.config.dataset_builders import BuildStage
 from data_designer.config.datastore import DatastoreSettings
 from data_designer.config.models import (
+    ChatCompletionInferenceParams,
+    EmbeddingInferenceParams,
+    GenerationType,
     ImageContext,
     ImageFormat,
     InferenceParameters,
@@ -32,7 +36,11 @@ from data_designer.config.models import (
     UniformDistribution,
     UniformDistributionParams,
 )
-from data_designer.config.processors import
+from data_designer.config.processors import (
+    DropColumnsProcessorConfig,
+    ProcessorType,
+    SchemaTransformProcessorConfig,
+)
 from data_designer.config.sampler_constraints import ColumnInequalityConstraint, ScalarInequalityConstraint
 from data_designer.config.sampler_params import (
     BernoulliMixtureSamplerParams,
@@ -69,6 +77,7 @@ from data_designer.config.validator_params import (
 
 def get_config_exports() -> list[str]:
     return [
+        SchemaTransformProcessorConfig.__name__,
         BernoulliMixtureSamplerParams.__name__,
         BernoulliSamplerParams.__name__,
         BinomialSamplerParams.__name__,
@@ -76,6 +85,7 @@ def get_config_exports() -> list[str]:
         CodeLang.__name__,
         CodeValidatorParams.__name__,
         ColumnInequalityConstraint.__name__,
+        ChatCompletionInferenceParams.__name__,
         DataDesignerColumnType.__name__,
         DataDesignerConfig.__name__,
         DataDesignerConfigBuilder.__name__,
@@ -84,8 +94,11 @@ def get_config_exports() -> list[str]:
         DatastoreSettings.__name__,
         DatetimeSamplerParams.__name__,
         DropColumnsProcessorConfig.__name__,
+        EmbeddingColumnConfig.__name__,
+        EmbeddingInferenceParams.__name__,
         ExpressionColumnConfig.__name__,
         GaussianSamplerParams.__name__,
+        GenerationType.__name__,
         IndexRange.__name__,
         InfoType.__name__,
         ImageContext.__name__,
```
data_designer/config/models.py
CHANGED
```diff
@@ -5,10 +5,10 @@ import logging
 from abc import ABC, abstractmethod
 from enum import Enum
 from pathlib import Path
-from typing import Any, Generic,
+from typing import Any, Generic, Literal, TypeVar
 
 import numpy as np
-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, Field, field_validator, model_validator
 from typing_extensions import Self, TypeAlias
 
 from data_designer.config.base import ConfigBase
@@ -74,7 +74,7 @@ class ImageContext(ModalityContext):
     """
 
     modality: Modality = Modality.IMAGE
-    image_format:
+    image_format: ImageFormat | None = None
 
     def get_context(self, record: dict) -> dict[str, Any]:
         """Get the context for the image modality.
@@ -122,8 +122,8 @@ class ManualDistributionParams(ConfigBase):
         weights: Optional list of weights for each value. If not provided, all values have equal probability.
     """
 
-    values:
-    weights:
+    values: list[float] = Field(min_length=1)
+    weights: list[float] | None = None
 
     @model_validator(mode="after")
     def _normalize_weights(self) -> Self:
@@ -149,7 +149,7 @@ class ManualDistribution(Distribution[ManualDistributionParams]):
         params: Distribution parameters (values, weights).
     """
 
-    distribution_type:
+    distribution_type: DistributionType | None = "manual"
     params: ManualDistributionParams
 
     def sample(self) -> float:
@@ -190,7 +190,7 @@ class UniformDistribution(Distribution[UniformDistributionParams]):
         params: Distribution parameters (low, high).
     """
 
-    distribution_type:
+    distribution_type: DistributionType | None = "uniform"
     params: UniformDistributionParams
 
     def sample(self) -> float:
@@ -202,36 +202,93 @@ class UniformDistribution(Distribution[UniformDistributionParams]):
         return float(np.random.uniform(low=self.params.low, high=self.params.high, size=1)[0])
 
 
-DistributionT: TypeAlias =
+DistributionT: TypeAlias = UniformDistribution | ManualDistribution
 
 
-class
-""
+class GenerationType(str, Enum):
+    CHAT_COMPLETION = "chat-completion"
+    EMBEDDING = "embedding"
+
+
+class BaseInferenceParams(ConfigBase, ABC):
+    """Base configuration for inference parameters.
 
     Attributes:
-
-        top_p: Nucleus sampling probability (0.0-1.0). Can be a fixed value or a distribution for dynamic sampling.
-        max_tokens: Maximum number of tokens (includes both input and output tokens).
+        generation_type: Type of generation (chat-completion or embedding). Acts as discriminator.
         max_parallel_requests: Maximum number of parallel requests to the model API.
         timeout: Timeout in seconds for each request.
         extra_body: Additional parameters to pass to the model API.
     """
 
-
-    top_p: Optional[Union[float, DistributionT]] = None
-    max_tokens: Optional[int] = Field(default=None, ge=1)
+    generation_type: GenerationType
     max_parallel_requests: int = Field(default=4, ge=1)
-    timeout:
-    extra_body:
+    timeout: int | None = Field(default=None, ge=1)
+    extra_body: dict[str, Any] | None = None
 
     @property
-    def generate_kwargs(self) -> dict[str,
+    def generate_kwargs(self) -> dict[str, Any]:
         """Get the generate kwargs for the inference parameters.
 
         Returns:
             A dictionary of the generate kwargs.
         """
         result = {}
+        if self.timeout is not None:
+            result["timeout"] = self.timeout
+        if self.extra_body is not None and self.extra_body != {}:
+            result["extra_body"] = self.extra_body
+        return result
+
+    def format_for_display(self) -> str:
+        """Format inference parameters for display.
+
+        Returns:
+            Formatted string of inference parameters
+        """
+        params_dict = self.model_dump(exclude_none=True, mode="json")
+
+        if not params_dict:
+            return "(none)"
+
+        parts = []
+        for key, value in params_dict.items():
+            formatted_value = self._format_value(key, value)
+            parts.append(f"{key}={formatted_value}")
+        return ", ".join(parts)
+
+    def _format_value(self, key: str, value: Any) -> str:
+        """Format a single parameter value. Override in subclasses for custom formatting.
+
+        Args:
+            key: Parameter name
+            value: Parameter value
+
+        Returns:
+            Formatted string representation of the value
+        """
+        if isinstance(value, float):
+            return f"{value:.2f}"
+        return str(value)
+
+
+class ChatCompletionInferenceParams(BaseInferenceParams):
+    """Configuration for LLM inference parameters.
+
+    Attributes:
+        generation_type: Type of generation, always "chat-completion" for this class.
+        temperature: Sampling temperature (0.0-2.0). Can be a fixed value or a distribution for dynamic sampling.
+        top_p: Nucleus sampling probability (0.0-1.0). Can be a fixed value or a distribution for dynamic sampling.
+        max_tokens: Maximum number of tokens (includes both input and output tokens).
+    """
+
+    generation_type: Literal[GenerationType.CHAT_COMPLETION] = GenerationType.CHAT_COMPLETION
+    temperature: float | DistributionT | None = None
+    top_p: float | DistributionT | None = None
+    max_tokens: int | None = Field(default=None, ge=1)
+
+    @property
+    def generate_kwargs(self) -> dict[str, Any]:
+        result = super().generate_kwargs
         if self.temperature is not None:
             result["temperature"] = (
                 self.temperature.sample() if hasattr(self.temperature, "sample") else self.temperature
@@ -240,10 +297,6 @@ class InferenceParameters(ConfigBase):
             result["top_p"] = self.top_p.sample() if hasattr(self.top_p, "sample") else self.top_p
         if self.max_tokens is not None:
             result["max_tokens"] = self.max_tokens
-        if self.timeout is not None:
-            result["timeout"] = self.timeout
-        if self.extra_body is not None and self.extra_body != {}:
-            result["extra_body"] = self.extra_body
         return result
 
     @model_validator(mode="after")
@@ -266,7 +319,7 @@ class InferenceParameters(ConfigBase):
 
     def _run_validation(
         self,
-        value:
+        value: float | DistributionT | None,
         param_name: str,
         min_value: float,
         max_value: float,
@@ -289,6 +342,61 @@ class InferenceParameters(ConfigBase):
     def _is_value_in_range(self, value: float, min_value: float, max_value: float) -> bool:
         return min_value <= value <= max_value
 
+    def _format_value(self, key: str, value: Any) -> str:
+        """Format chat completion parameter values, including distributions.
+
+        Args:
+            key: Parameter name
+            value: Parameter value
+
+        Returns:
+            Formatted string representation of the value
+        """
+        if isinstance(value, dict) and "distribution_type" in value:
+            return "dist"
+        return super()._format_value(key, value)
+
+
+# Maintain backwards compatibility with a deprecation warning
+class InferenceParameters(ChatCompletionInferenceParams):
+    """
+    Deprecated: Use ChatCompletionInferenceParams instead.
+    This alias will be removed in a future version.
+    """
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        logger.warning(
+            "InferenceParameters is deprecated and will be removed in a future version. "
+            "Use ChatCompletionInferenceParams instead."
+        )
+        super().__init__(*args, **kwargs)
+
+
+class EmbeddingInferenceParams(BaseInferenceParams):
+    """Configuration for embedding generation parameters.
+
+    Attributes:
+        generation_type: Type of generation, always "embedding" for this class.
+        encoding_format: Format of the embedding encoding ("float" or "base64").
+        dimensions: Number of dimensions for the embedding.
+    """
+
+    generation_type: Literal[GenerationType.EMBEDDING] = GenerationType.EMBEDDING
+    encoding_format: Literal["float", "base64"] = "float"
+    dimensions: int | None = None
+
+    @property
+    def generate_kwargs(self) -> dict[str, float | int]:
+        result = super().generate_kwargs
+        if self.encoding_format is not None:
+            result["encoding_format"] = self.encoding_format
+        if self.dimensions is not None:
+            result["dimensions"] = self.dimensions
+        return result
+
+
+InferenceParamsT: TypeAlias = ChatCompletionInferenceParams | EmbeddingInferenceParams | InferenceParameters
+
 
 class ModelConfig(ConfigBase):
     """Configuration for a model used for generation.
@@ -297,13 +405,31 @@ class ModelConfig(ConfigBase):
         alias: User-defined alias to reference in column configurations.
         model: Model identifier (e.g., from build.nvidia.com or other providers).
         inference_parameters: Inference parameters for the model (temperature, top_p, max_tokens, etc.).
+            The generation_type is determined by the type of inference_parameters.
         provider: Optional model provider name if using custom providers.
     """
 
     alias: str
     model: str
-    inference_parameters:
-    provider:
+    inference_parameters: InferenceParamsT = Field(default_factory=ChatCompletionInferenceParams)
+    provider: str | None = None
+
+    @property
+    def generation_type(self) -> GenerationType:
+        """Get the generation type from the inference parameters."""
+        return self.inference_parameters.generation_type
+
+    @field_validator("inference_parameters", mode="before")
+    @classmethod
+    def _convert_inference_parameters(cls, value: Any) -> Any:
+        """Convert raw dict to appropriate inference parameters type based on field presence."""
+        if isinstance(value, dict):
+            # Infer type from presence of embedding-specific fields
+            if "encoding_format" in value or "dimensions" in value:
+                return EmbeddingInferenceParams(**value)
+            else:
+                return ChatCompletionInferenceParams(**value)
+        return value
 
 
 class ModelProvider(ConfigBase):
@@ -320,11 +446,11 @@ class ModelProvider(ConfigBase):
     name: str
     endpoint: str
     provider_type: str = "openai"
-    api_key:
-    extra_body:
+    api_key: str | None = None
+    extra_body: dict[str, Any] | None = None
 
 
-def load_model_configs(model_configs:
+def load_model_configs(model_configs: list[ModelConfig] | str | Path) -> list[ModelConfig]:
     if isinstance(model_configs, list) and all(isinstance(mc, ModelConfig) for mc in model_configs):
         return model_configs
     json_config = smart_load_yaml(model_configs)
```
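The change above splits the old `InferenceParameters` into a `BaseInferenceParams` hierarchy discriminated by `generation_type`, with a deprecation shim preserving the old name. Below is a minimal sketch of how the new types compose with `ModelConfig`; only the classes and the dict-coercion validator shown in the diff are assumed, and the model identifier strings are placeholders.

```python
# Minimal sketch, not an official example: classes come from the diff above,
# model identifiers are placeholders.
from data_designer.config.models import (
    ChatCompletionInferenceParams,
    EmbeddingInferenceParams,
    GenerationType,
    ModelConfig,
)

# Chat-completion model: sampling params live on ChatCompletionInferenceParams.
chat_model = ModelConfig(
    alias="writer",
    model="example/chat-model",  # placeholder identifier
    inference_parameters=ChatCompletionInferenceParams(temperature=0.7, max_tokens=1024),
)
assert chat_model.generation_type is GenerationType.CHAT_COMPLETION

# Embedding model: encoding_format/dimensions replace the sampling params.
embed_model = ModelConfig(
    alias="embedder",
    model="example/embedding-model",  # placeholder identifier
    inference_parameters=EmbeddingInferenceParams(dimensions=512),
)
assert embed_model.generation_type is GenerationType.EMBEDDING

# Plain dicts are coerced by the mode="before" validator: embedding-specific
# keys ("encoding_format"/"dimensions") select EmbeddingInferenceParams,
# anything else falls back to ChatCompletionInferenceParams.
from_dict = ModelConfig(
    alias="embedder-2",
    model="example/embedding-model",  # placeholder identifier
    inference_parameters={"dimensions": 512},
)
assert isinstance(from_dict.inference_parameters, EmbeddingInferenceParams)
```

Note that because the validator keys off embedding-specific fields, a dict carrying only chat-style keys (or no keys at all) still resolves to `ChatCompletionInferenceParams`, which matches the field's default.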
data_designer/config/preview_results.py
CHANGED
```diff
@@ -3,8 +3,6 @@
 
 from __future__ import annotations
 
-from typing import Optional
-
 import pandas as pd
 
 from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults
@@ -17,8 +15,9 @@ class PreviewResults(WithRecordSamplerMixin):
         self,
         *,
         config_builder: DataDesignerConfigBuilder,
-        dataset:
-        analysis:
+        dataset: pd.DataFrame | None = None,
+        analysis: DatasetProfilerResults | None = None,
+        processor_artifacts: dict[str, list[str] | str] | None = None,
     ):
         """Creates a new instance with results from a Data Designer preview run.
 
@@ -26,7 +25,9 @@ class PreviewResults(WithRecordSamplerMixin):
             config_builder: Data Designer configuration builder.
             dataset: Dataset of the preview run.
             analysis: Analysis of the preview run.
+            processor_artifacts: Artifacts generated by the processors.
         """
         self.dataset: pd.DataFrame | None = dataset
         self.analysis: DatasetProfilerResults | None = analysis
+        self.processor_artifacts: dict[str, list[str] | str] | None = processor_artifacts
         self._config_builder = config_builder
```
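A small sketch of consuming the new `processor_artifacts` field on `PreviewResults`; how the results object is produced is outside this diff, so only the attribute itself is taken from the change.

```python
# Sketch of reading the new processor_artifacts attribute; obtaining the
# PreviewResults instance (e.g., from a preview run) is assumed, not shown.
from data_designer.config.preview_results import PreviewResults


def print_processor_artifacts(preview: PreviewResults) -> None:
    # Maps each processor's name to the artifact path(s) it wrote;
    # None when no processors produced output.
    for name, paths in (preview.processor_artifacts or {}).items():
        print(f"{name}: {paths}")
```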
data_designer/config/processors.py
CHANGED
````diff
@@ -1,26 +1,54 @@
 # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 
+import json
 from abc import ABC
 from enum import Enum
-from typing import Literal
+from typing import Any, Literal
 
 from pydantic import Field, field_validator
+from typing_extensions import TypeAlias
 
 from data_designer.config.base import ConfigBase
 from data_designer.config.dataset_builders import BuildStage
+from data_designer.config.errors import InvalidConfigError
 
 SUPPORTED_STAGES = [BuildStage.POST_BATCH]
 
 
 class ProcessorType(str, Enum):
+    """Enumeration of available processor types.
+
+    Attributes:
+        DROP_COLUMNS: Processor that removes specified columns from the output dataset.
+        SCHEMA_TRANSFORM: Processor that creates a new dataset with a transformed schema using Jinja2 templates.
+    """
+
     DROP_COLUMNS = "drop_columns"
+    SCHEMA_TRANSFORM = "schema_transform"
 
 
 class ProcessorConfig(ConfigBase, ABC):
+    """Abstract base class for all processor configuration types.
+
+    Processors are transformations that run before or after columns are generated.
+    They can modify, reshape, or augment the dataset before it's saved.
+
+    Attributes:
+        name: Unique name of the processor, used to identify the processor in results
+            and to name output artifacts on disk.
+        build_stage: The stage at which the processor runs. Currently only `POST_BATCH`
+            is supported, meaning processors run after each batch of columns is generated.
+    """
+
+    name: str = Field(
+        description="The name of the processor, used to identify the processor in the results and to write the artifacts to disk.",
+    )
     build_stage: BuildStage = Field(
-
+        default=BuildStage.POST_BATCH,
+        description=f"The stage at which the processor will run. Supported stages: {', '.join(SUPPORTED_STAGES)}",
     )
+    processor_type: str
 
     @field_validator("build_stage")
     def validate_build_stage(cls, v: BuildStage) -> BuildStage:
@@ -31,11 +59,88 @@ class ProcessorConfig(ConfigBase, ABC):
         return v
 
 
-def get_processor_config_from_kwargs(processor_type: ProcessorType, **kwargs) -> ProcessorConfig:
+def get_processor_config_from_kwargs(processor_type: ProcessorType, **kwargs: Any) -> ProcessorConfig:
+    """Create a processor configuration from a processor type and keyword arguments.
+
+    Args:
+        processor_type: The type of processor to create.
+        **kwargs: Additional keyword arguments passed to the processor constructor.
+
+    Returns:
+        A processor configuration object of the specified type.
+    """
     if processor_type == ProcessorType.DROP_COLUMNS:
         return DropColumnsProcessorConfig(**kwargs)
+    elif processor_type == ProcessorType.SCHEMA_TRANSFORM:
+        return SchemaTransformProcessorConfig(**kwargs)
 
 
 class DropColumnsProcessorConfig(ProcessorConfig):
-
+    """Configuration for dropping columns from the output dataset.
+
+    This processor removes specified columns from the generated dataset. The dropped
+    columns are saved separately in a `dropped-columns` directory for reference.
+    When this processor is added via the config builder, the corresponding column
+    configs are automatically marked with `drop = True`.
+
+    Alternatively, you can set `drop = True` when configuring a column.
+
+    Attributes:
+        column_names: List of column names to remove from the output dataset.
+        processor_type: Discriminator field, always `ProcessorType.DROP_COLUMNS` for this configuration type.
+    """
+
+    column_names: list[str] = Field(description="List of column names to drop from the output dataset.")
     processor_type: Literal[ProcessorType.DROP_COLUMNS] = ProcessorType.DROP_COLUMNS
+
+
+class SchemaTransformProcessorConfig(ProcessorConfig):
+    """Configuration for transforming the dataset schema using Jinja2 templates.
+
+    This processor creates a new dataset with a transformed schema. Each key in the
+    template becomes a column in the output, and values are Jinja2 templates that
+    can reference any column in the batch. The transformed dataset is written to
+    a `processors-outputs/{processor_name}/` directory alongside the main dataset.
+
+    Attributes:
+        template: Dictionary defining the output schema. Keys are new column names,
+            values are Jinja2 templates (strings, lists, or nested structures).
+            Must be JSON-serializable.
+        processor_type: Discriminator field, always `ProcessorType.SCHEMA_TRANSFORM` for this configuration type.
+    """
+
+    template: dict[str, Any] = Field(
+        ...,
+        description="""
+        Dictionary specifying columns and templates to use in the new dataset with transformed schema.
+
+        Each key is a new column name, and each value is an object containing Jinja2 templates - for instance, a string or a list of strings.
+        Values must be JSON-serializable.
+
+        Example:
+
+        ```python
+        template = {
+            "list_of_strings": ["{{ col1 }}", "{{ col2 }}"],
+            "uppercase_string": "{{ col1 | upper }}",
+            "lowercase_string": "{{ col2 | lower }}",
+        }
+        ```
+
+        The above templates will create an new dataset with three columns: "list_of_strings", "uppercase_string", and "lowercase_string".
+        References to columns "col1" and "col2" in the templates will be replaced with the actual values of the columns in the dataset.
+        """,
+    )
+    processor_type: Literal[ProcessorType.SCHEMA_TRANSFORM] = ProcessorType.SCHEMA_TRANSFORM
+
+    @field_validator("template")
+    def validate_template(cls, v: dict[str, Any]) -> dict[str, Any]:
+        try:
+            json.dumps(v)
+        except TypeError as e:
+            if "not JSON serializable" in str(e):
+                raise InvalidConfigError("Template must be JSON serializable")
+        return v
+
+
+ProcessorConfigT: TypeAlias = DropColumnsProcessorConfig | SchemaTransformProcessorConfig
````
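To make the new processor concrete, here is a hedged sketch of building a `SchemaTransformProcessorConfig`, both directly and through `get_processor_config_from_kwargs`; the source column names `question` and `answer` are hypothetical.

```python
# Hedged sketch of the new schema_transform processor; "question" and "answer"
# are hypothetical column names from an imagined generated dataset.
from data_designer.config.processors import (
    ProcessorType,
    SchemaTransformProcessorConfig,
    get_processor_config_from_kwargs,
)

# Each template key becomes a column in the transformed dataset; values are
# Jinja2 templates (here, a nested chat structure) and must be JSON-serializable.
transform = SchemaTransformProcessorConfig(
    name="chat-format",  # also names the processors-outputs/chat-format/ directory
    template={
        "messages": [
            {"role": "user", "content": "{{ question }}"},
            {"role": "assistant", "content": "{{ answer }}"},
        ]
    },
)

# The kwargs factory from the diff produces an equivalent config.
same = get_processor_config_from_kwargs(
    ProcessorType.SCHEMA_TRANSFORM,
    name="chat-format",
    template=transform.template,
)
assert isinstance(same, SchemaTransformProcessorConfig)
```

Since `build_stage` defaults to `BuildStage.POST_BATCH` (the only supported stage), the transform runs after each batch of columns is generated.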
data_designer/config/sampler_constraints.py
CHANGED
```diff
@@ -3,7 +3,6 @@
 
 from abc import ABC, abstractmethod
 from enum import Enum
-from typing import Union
 
 from typing_extensions import TypeAlias
 
@@ -48,4 +47,4 @@ class ColumnInequalityConstraint(Constraint):
         return ConstraintType.COLUMN_INEQUALITY
 
 
-ColumnConstraintT: TypeAlias =
+ColumnConstraintT: TypeAlias = ScalarInequalityConstraint | ColumnInequalityConstraint
```
|