data-designer 0.1.5__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/_version.py +2 -2
- data_designer/cli/README.md +15 -1
- data_designer/cli/commands/download.py +56 -0
- data_designer/cli/commands/list.py +4 -18
- data_designer/cli/controllers/__init__.py +2 -1
- data_designer/cli/controllers/download_controller.py +217 -0
- data_designer/cli/controllers/model_controller.py +4 -3
- data_designer/cli/forms/field.py +65 -19
- data_designer/cli/forms/model_builder.py +251 -44
- data_designer/cli/main.py +11 -1
- data_designer/cli/repositories/persona_repository.py +88 -0
- data_designer/cli/services/__init__.py +2 -1
- data_designer/cli/services/download_service.py +97 -0
- data_designer/cli/ui.py +131 -0
- data_designer/cli/utils.py +34 -0
- data_designer/config/analysis/__init__.py +2 -0
- data_designer/config/analysis/column_profilers.py +75 -7
- data_designer/config/analysis/column_statistics.py +192 -48
- data_designer/config/analysis/dataset_profiler.py +23 -5
- data_designer/config/analysis/utils/reporting.py +3 -3
- data_designer/config/base.py +3 -3
- data_designer/config/column_configs.py +27 -6
- data_designer/config/column_types.py +24 -17
- data_designer/config/config_builder.py +34 -26
- data_designer/config/data_designer_config.py +7 -7
- data_designer/config/datastore.py +6 -6
- data_designer/config/default_model_settings.py +27 -34
- data_designer/config/exports.py +8 -0
- data_designer/config/models.py +155 -29
- data_designer/config/preview_results.py +6 -8
- data_designer/config/processors.py +63 -2
- data_designer/config/sampler_constraints.py +1 -2
- data_designer/config/sampler_params.py +31 -31
- data_designer/config/seed.py +1 -2
- data_designer/config/utils/code_lang.py +4 -5
- data_designer/config/utils/constants.py +31 -8
- data_designer/config/utils/io_helpers.py +5 -5
- data_designer/config/utils/misc.py +1 -4
- data_designer/config/utils/numerical_helpers.py +2 -2
- data_designer/config/utils/type_helpers.py +3 -3
- data_designer/config/utils/validation.py +7 -8
- data_designer/config/utils/visualization.py +32 -17
- data_designer/config/validator_params.py +4 -8
- data_designer/engine/analysis/column_profilers/base.py +0 -7
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +2 -3
- data_designer/engine/analysis/column_statistics.py +16 -16
- data_designer/engine/analysis/dataset_profiler.py +25 -4
- data_designer/engine/analysis/utils/column_statistics_calculations.py +71 -49
- data_designer/engine/analysis/utils/judge_score_processing.py +5 -5
- data_designer/engine/column_generators/generators/base.py +34 -0
- data_designer/engine/column_generators/generators/embedding.py +45 -0
- data_designer/engine/column_generators/generators/{llm_generators.py → llm_completion.py} +17 -49
- data_designer/engine/column_generators/registry.py +4 -2
- data_designer/engine/column_generators/utils/judge_score_factory.py +5 -6
- data_designer/engine/configurable_task.py +2 -2
- data_designer/engine/dataset_builders/artifact_storage.py +1 -2
- data_designer/engine/dataset_builders/column_wise_builder.py +11 -10
- data_designer/engine/dataset_builders/utils/concurrency.py +6 -6
- data_designer/engine/models/facade.py +66 -9
- data_designer/engine/models/litellm_overrides.py +5 -6
- data_designer/engine/models/parsers/errors.py +2 -4
- data_designer/engine/models/parsers/parser.py +2 -3
- data_designer/engine/models/parsers/postprocessors.py +3 -4
- data_designer/engine/models/parsers/types.py +4 -4
- data_designer/engine/models/registry.py +20 -11
- data_designer/engine/models/usage.py +7 -9
- data_designer/engine/processing/ginja/ast.py +1 -2
- data_designer/engine/processing/utils.py +40 -2
- data_designer/engine/registry/base.py +12 -12
- data_designer/engine/sampling_gen/constraints.py +1 -2
- data_designer/engine/sampling_gen/data_sources/base.py +14 -14
- data_designer/engine/sampling_gen/entities/phone_number.py +1 -2
- data_designer/engine/sampling_gen/people_gen.py +3 -7
- data_designer/engine/validators/base.py +2 -2
- data_designer/logging.py +2 -2
- data_designer/plugin_manager.py +3 -3
- data_designer/plugins/plugin.py +3 -3
- data_designer/plugins/registry.py +2 -2
- {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/METADATA +1 -1
- {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/RECORD +83 -77
- {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/WHEEL +0 -0
- {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/entry_points.txt +0 -0
- {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/licenses/LICENSE +0 -0
data_designer/config/exports.py
CHANGED
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
|
|
4
4
|
from data_designer.config.analysis.column_profilers import JudgeScoreProfilerConfig
|
|
5
5
|
from data_designer.config.column_configs import (
|
|
6
|
+
EmbeddingColumnConfig,
|
|
6
7
|
ExpressionColumnConfig,
|
|
7
8
|
LLMCodeColumnConfig,
|
|
8
9
|
LLMJudgeColumnConfig,
|
|
@@ -19,6 +20,9 @@ from data_designer.config.data_designer_config import DataDesignerConfig
|
|
|
19
20
|
from data_designer.config.dataset_builders import BuildStage
|
|
20
21
|
from data_designer.config.datastore import DatastoreSettings
|
|
21
22
|
from data_designer.config.models import (
|
|
23
|
+
ChatCompletionInferenceParams,
|
|
24
|
+
EmbeddingInferenceParams,
|
|
25
|
+
GenerationType,
|
|
22
26
|
ImageContext,
|
|
23
27
|
ImageFormat,
|
|
24
28
|
InferenceParameters,
|
|
@@ -81,6 +85,7 @@ def get_config_exports() -> list[str]:
|
|
|
81
85
|
CodeLang.__name__,
|
|
82
86
|
CodeValidatorParams.__name__,
|
|
83
87
|
ColumnInequalityConstraint.__name__,
|
|
88
|
+
ChatCompletionInferenceParams.__name__,
|
|
84
89
|
DataDesignerColumnType.__name__,
|
|
85
90
|
DataDesignerConfig.__name__,
|
|
86
91
|
DataDesignerConfigBuilder.__name__,
|
|
@@ -89,8 +94,11 @@ def get_config_exports() -> list[str]:
|
|
|
89
94
|
DatastoreSettings.__name__,
|
|
90
95
|
DatetimeSamplerParams.__name__,
|
|
91
96
|
DropColumnsProcessorConfig.__name__,
|
|
97
|
+
EmbeddingColumnConfig.__name__,
|
|
98
|
+
EmbeddingInferenceParams.__name__,
|
|
92
99
|
ExpressionColumnConfig.__name__,
|
|
93
100
|
GaussianSamplerParams.__name__,
|
|
101
|
+
GenerationType.__name__,
|
|
94
102
|
IndexRange.__name__,
|
|
95
103
|
InfoType.__name__,
|
|
96
104
|
ImageContext.__name__,
|
data_designer/config/models.py
CHANGED
|
@@ -5,10 +5,10 @@ import logging
|
|
|
5
5
|
from abc import ABC, abstractmethod
|
|
6
6
|
from enum import Enum
|
|
7
7
|
from pathlib import Path
|
|
8
|
-
from typing import Any, Generic,
|
|
8
|
+
from typing import Any, Generic, Literal, TypeVar
|
|
9
9
|
|
|
10
10
|
import numpy as np
|
|
11
|
-
from pydantic import BaseModel, Field, model_validator
|
|
11
|
+
from pydantic import BaseModel, Field, field_validator, model_validator
|
|
12
12
|
from typing_extensions import Self, TypeAlias
|
|
13
13
|
|
|
14
14
|
from data_designer.config.base import ConfigBase
|
|
@@ -74,7 +74,7 @@ class ImageContext(ModalityContext):
|
|
|
74
74
|
"""
|
|
75
75
|
|
|
76
76
|
modality: Modality = Modality.IMAGE
|
|
77
|
-
image_format:
|
|
77
|
+
image_format: ImageFormat | None = None
|
|
78
78
|
|
|
79
79
|
def get_context(self, record: dict) -> dict[str, Any]:
|
|
80
80
|
"""Get the context for the image modality.
|
|
@@ -122,8 +122,8 @@ class ManualDistributionParams(ConfigBase):
|
|
|
122
122
|
weights: Optional list of weights for each value. If not provided, all values have equal probability.
|
|
123
123
|
"""
|
|
124
124
|
|
|
125
|
-
values:
|
|
126
|
-
weights:
|
|
125
|
+
values: list[float] = Field(min_length=1)
|
|
126
|
+
weights: list[float] | None = None
|
|
127
127
|
|
|
128
128
|
@model_validator(mode="after")
|
|
129
129
|
def _normalize_weights(self) -> Self:
|
|
@@ -149,7 +149,7 @@ class ManualDistribution(Distribution[ManualDistributionParams]):
|
|
|
149
149
|
params: Distribution parameters (values, weights).
|
|
150
150
|
"""
|
|
151
151
|
|
|
152
|
-
distribution_type:
|
|
152
|
+
distribution_type: DistributionType | None = "manual"
|
|
153
153
|
params: ManualDistributionParams
|
|
154
154
|
|
|
155
155
|
def sample(self) -> float:
|
|
@@ -190,7 +190,7 @@ class UniformDistribution(Distribution[UniformDistributionParams]):
|
|
|
190
190
|
params: Distribution parameters (low, high).
|
|
191
191
|
"""
|
|
192
192
|
|
|
193
|
-
distribution_type:
|
|
193
|
+
distribution_type: DistributionType | None = "uniform"
|
|
194
194
|
params: UniformDistributionParams
|
|
195
195
|
|
|
196
196
|
def sample(self) -> float:
|
|
@@ -202,36 +202,93 @@ class UniformDistribution(Distribution[UniformDistributionParams]):
|
|
|
202
202
|
return float(np.random.uniform(low=self.params.low, high=self.params.high, size=1)[0])
|
|
203
203
|
|
|
204
204
|
|
|
205
|
-
DistributionT: TypeAlias =
|
|
205
|
+
DistributionT: TypeAlias = UniformDistribution | ManualDistribution
|
|
206
206
|
|
|
207
207
|
|
|
208
|
-
class
|
|
209
|
-
""
|
|
208
|
+
class GenerationType(str, Enum):
|
|
209
|
+
CHAT_COMPLETION = "chat-completion"
|
|
210
|
+
EMBEDDING = "embedding"
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
class BaseInferenceParams(ConfigBase, ABC):
|
|
214
|
+
"""Base configuration for inference parameters.
|
|
210
215
|
|
|
211
216
|
Attributes:
|
|
212
|
-
|
|
213
|
-
top_p: Nucleus sampling probability (0.0-1.0). Can be a fixed value or a distribution for dynamic sampling.
|
|
214
|
-
max_tokens: Maximum number of tokens (includes both input and output tokens).
|
|
217
|
+
generation_type: Type of generation (chat-completion or embedding). Acts as discriminator.
|
|
215
218
|
max_parallel_requests: Maximum number of parallel requests to the model API.
|
|
216
219
|
timeout: Timeout in seconds for each request.
|
|
217
220
|
extra_body: Additional parameters to pass to the model API.
|
|
218
221
|
"""
|
|
219
222
|
|
|
220
|
-
|
|
221
|
-
top_p: Optional[Union[float, DistributionT]] = None
|
|
222
|
-
max_tokens: Optional[int] = Field(default=None, ge=1)
|
|
223
|
+
generation_type: GenerationType
|
|
223
224
|
max_parallel_requests: int = Field(default=4, ge=1)
|
|
224
|
-
timeout:
|
|
225
|
-
extra_body:
|
|
225
|
+
timeout: int | None = Field(default=None, ge=1)
|
|
226
|
+
extra_body: dict[str, Any] | None = None
|
|
226
227
|
|
|
227
228
|
@property
|
|
228
|
-
def generate_kwargs(self) -> dict[str,
|
|
229
|
+
def generate_kwargs(self) -> dict[str, Any]:
|
|
229
230
|
"""Get the generate kwargs for the inference parameters.
|
|
230
231
|
|
|
231
232
|
Returns:
|
|
232
233
|
A dictionary of the generate kwargs.
|
|
233
234
|
"""
|
|
234
235
|
result = {}
|
|
236
|
+
if self.timeout is not None:
|
|
237
|
+
result["timeout"] = self.timeout
|
|
238
|
+
if self.extra_body is not None and self.extra_body != {}:
|
|
239
|
+
result["extra_body"] = self.extra_body
|
|
240
|
+
return result
|
|
241
|
+
|
|
242
|
+
def format_for_display(self) -> str:
|
|
243
|
+
"""Format inference parameters for display.
|
|
244
|
+
|
|
245
|
+
Returns:
|
|
246
|
+
Formatted string of inference parameters
|
|
247
|
+
"""
|
|
248
|
+
params_dict = self.model_dump(exclude_none=True, mode="json")
|
|
249
|
+
|
|
250
|
+
if not params_dict:
|
|
251
|
+
return "(none)"
|
|
252
|
+
|
|
253
|
+
parts = []
|
|
254
|
+
for key, value in params_dict.items():
|
|
255
|
+
formatted_value = self._format_value(key, value)
|
|
256
|
+
parts.append(f"{key}={formatted_value}")
|
|
257
|
+
return ", ".join(parts)
|
|
258
|
+
|
|
259
|
+
def _format_value(self, key: str, value: Any) -> str:
|
|
260
|
+
"""Format a single parameter value. Override in subclasses for custom formatting.
|
|
261
|
+
|
|
262
|
+
Args:
|
|
263
|
+
key: Parameter name
|
|
264
|
+
value: Parameter value
|
|
265
|
+
|
|
266
|
+
Returns:
|
|
267
|
+
Formatted string representation of the value
|
|
268
|
+
"""
|
|
269
|
+
if isinstance(value, float):
|
|
270
|
+
return f"{value:.2f}"
|
|
271
|
+
return str(value)
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
class ChatCompletionInferenceParams(BaseInferenceParams):
|
|
275
|
+
"""Configuration for LLM inference parameters.
|
|
276
|
+
|
|
277
|
+
Attributes:
|
|
278
|
+
generation_type: Type of generation, always "chat-completion" for this class.
|
|
279
|
+
temperature: Sampling temperature (0.0-2.0). Can be a fixed value or a distribution for dynamic sampling.
|
|
280
|
+
top_p: Nucleus sampling probability (0.0-1.0). Can be a fixed value or a distribution for dynamic sampling.
|
|
281
|
+
max_tokens: Maximum number of tokens (includes both input and output tokens).
|
|
282
|
+
"""
|
|
283
|
+
|
|
284
|
+
generation_type: Literal[GenerationType.CHAT_COMPLETION] = GenerationType.CHAT_COMPLETION
|
|
285
|
+
temperature: float | DistributionT | None = None
|
|
286
|
+
top_p: float | DistributionT | None = None
|
|
287
|
+
max_tokens: int | None = Field(default=None, ge=1)
|
|
288
|
+
|
|
289
|
+
@property
|
|
290
|
+
def generate_kwargs(self) -> dict[str, Any]:
|
|
291
|
+
result = super().generate_kwargs
|
|
235
292
|
if self.temperature is not None:
|
|
236
293
|
result["temperature"] = (
|
|
237
294
|
self.temperature.sample() if hasattr(self.temperature, "sample") else self.temperature
|
|
@@ -240,10 +297,6 @@ class InferenceParameters(ConfigBase):
|
|
|
240
297
|
result["top_p"] = self.top_p.sample() if hasattr(self.top_p, "sample") else self.top_p
|
|
241
298
|
if self.max_tokens is not None:
|
|
242
299
|
result["max_tokens"] = self.max_tokens
|
|
243
|
-
if self.timeout is not None:
|
|
244
|
-
result["timeout"] = self.timeout
|
|
245
|
-
if self.extra_body is not None and self.extra_body != {}:
|
|
246
|
-
result["extra_body"] = self.extra_body
|
|
247
300
|
return result
|
|
248
301
|
|
|
249
302
|
@model_validator(mode="after")
|
|
@@ -266,7 +319,7 @@ class InferenceParameters(ConfigBase):
|
|
|
266
319
|
|
|
267
320
|
def _run_validation(
|
|
268
321
|
self,
|
|
269
|
-
value:
|
|
322
|
+
value: float | DistributionT | None,
|
|
270
323
|
param_name: str,
|
|
271
324
|
min_value: float,
|
|
272
325
|
max_value: float,
|
|
@@ -289,6 +342,61 @@ class InferenceParameters(ConfigBase):
|
|
|
289
342
|
def _is_value_in_range(self, value: float, min_value: float, max_value: float) -> bool:
|
|
290
343
|
return min_value <= value <= max_value
|
|
291
344
|
|
|
345
|
+
def _format_value(self, key: str, value: Any) -> str:
|
|
346
|
+
"""Format chat completion parameter values, including distributions.
|
|
347
|
+
|
|
348
|
+
Args:
|
|
349
|
+
key: Parameter name
|
|
350
|
+
value: Parameter value
|
|
351
|
+
|
|
352
|
+
Returns:
|
|
353
|
+
Formatted string representation of the value
|
|
354
|
+
"""
|
|
355
|
+
if isinstance(value, dict) and "distribution_type" in value:
|
|
356
|
+
return "dist"
|
|
357
|
+
return super()._format_value(key, value)
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
# Maintain backwards compatibility with a deprecation warning
|
|
361
|
+
class InferenceParameters(ChatCompletionInferenceParams):
|
|
362
|
+
"""
|
|
363
|
+
Deprecated: Use ChatCompletionInferenceParams instead.
|
|
364
|
+
This alias will be removed in a future version.
|
|
365
|
+
"""
|
|
366
|
+
|
|
367
|
+
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
|
368
|
+
logger.warning(
|
|
369
|
+
"InferenceParameters is deprecated and will be removed in a future version. "
|
|
370
|
+
"Use ChatCompletionInferenceParams instead."
|
|
371
|
+
)
|
|
372
|
+
super().__init__(*args, **kwargs)
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
class EmbeddingInferenceParams(BaseInferenceParams):
|
|
376
|
+
"""Configuration for embedding generation parameters.
|
|
377
|
+
|
|
378
|
+
Attributes:
|
|
379
|
+
generation_type: Type of generation, always "embedding" for this class.
|
|
380
|
+
encoding_format: Format of the embedding encoding ("float" or "base64").
|
|
381
|
+
dimensions: Number of dimensions for the embedding.
|
|
382
|
+
"""
|
|
383
|
+
|
|
384
|
+
generation_type: Literal[GenerationType.EMBEDDING] = GenerationType.EMBEDDING
|
|
385
|
+
encoding_format: Literal["float", "base64"] = "float"
|
|
386
|
+
dimensions: int | None = None
|
|
387
|
+
|
|
388
|
+
@property
|
|
389
|
+
def generate_kwargs(self) -> dict[str, float | int]:
|
|
390
|
+
result = super().generate_kwargs
|
|
391
|
+
if self.encoding_format is not None:
|
|
392
|
+
result["encoding_format"] = self.encoding_format
|
|
393
|
+
if self.dimensions is not None:
|
|
394
|
+
result["dimensions"] = self.dimensions
|
|
395
|
+
return result
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
InferenceParamsT: TypeAlias = ChatCompletionInferenceParams | EmbeddingInferenceParams | InferenceParameters
|
|
399
|
+
|
|
292
400
|
|
|
293
401
|
class ModelConfig(ConfigBase):
|
|
294
402
|
"""Configuration for a model used for generation.
|
|
@@ -297,13 +405,31 @@ class ModelConfig(ConfigBase):
|
|
|
297
405
|
alias: User-defined alias to reference in column configurations.
|
|
298
406
|
model: Model identifier (e.g., from build.nvidia.com or other providers).
|
|
299
407
|
inference_parameters: Inference parameters for the model (temperature, top_p, max_tokens, etc.).
|
|
408
|
+
The generation_type is determined by the type of inference_parameters.
|
|
300
409
|
provider: Optional model provider name if using custom providers.
|
|
301
410
|
"""
|
|
302
411
|
|
|
303
412
|
alias: str
|
|
304
413
|
model: str
|
|
305
|
-
inference_parameters:
|
|
306
|
-
provider:
|
|
414
|
+
inference_parameters: InferenceParamsT = Field(default_factory=ChatCompletionInferenceParams)
|
|
415
|
+
provider: str | None = None
|
|
416
|
+
|
|
417
|
+
@property
|
|
418
|
+
def generation_type(self) -> GenerationType:
|
|
419
|
+
"""Get the generation type from the inference parameters."""
|
|
420
|
+
return self.inference_parameters.generation_type
|
|
421
|
+
|
|
422
|
+
@field_validator("inference_parameters", mode="before")
|
|
423
|
+
@classmethod
|
|
424
|
+
def _convert_inference_parameters(cls, value: Any) -> Any:
|
|
425
|
+
"""Convert raw dict to appropriate inference parameters type based on field presence."""
|
|
426
|
+
if isinstance(value, dict):
|
|
427
|
+
# Infer type from presence of embedding-specific fields
|
|
428
|
+
if "encoding_format" in value or "dimensions" in value:
|
|
429
|
+
return EmbeddingInferenceParams(**value)
|
|
430
|
+
else:
|
|
431
|
+
return ChatCompletionInferenceParams(**value)
|
|
432
|
+
return value
|
|
307
433
|
|
|
308
434
|
|
|
309
435
|
class ModelProvider(ConfigBase):
|
|
@@ -320,11 +446,11 @@ class ModelProvider(ConfigBase):
|
|
|
320
446
|
name: str
|
|
321
447
|
endpoint: str
|
|
322
448
|
provider_type: str = "openai"
|
|
323
|
-
api_key:
|
|
324
|
-
extra_body:
|
|
449
|
+
api_key: str | None = None
|
|
450
|
+
extra_body: dict[str, Any] | None = None
|
|
325
451
|
|
|
326
452
|
|
|
327
|
-
def load_model_configs(model_configs:
|
|
453
|
+
def load_model_configs(model_configs: list[ModelConfig] | str | Path) -> list[ModelConfig]:
|
|
328
454
|
if isinstance(model_configs, list) and all(isinstance(mc, ModelConfig) for mc in model_configs):
|
|
329
455
|
return model_configs
|
|
330
456
|
json_config = smart_load_yaml(model_configs)
|
|
@@ -3,8 +3,6 @@
|
|
|
3
3
|
|
|
4
4
|
from __future__ import annotations
|
|
5
5
|
|
|
6
|
-
from typing import Optional, Union
|
|
7
|
-
|
|
8
6
|
import pandas as pd
|
|
9
7
|
|
|
10
8
|
from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults
|
|
@@ -17,9 +15,9 @@ class PreviewResults(WithRecordSamplerMixin):
|
|
|
17
15
|
self,
|
|
18
16
|
*,
|
|
19
17
|
config_builder: DataDesignerConfigBuilder,
|
|
20
|
-
dataset:
|
|
21
|
-
analysis:
|
|
22
|
-
processor_artifacts:
|
|
18
|
+
dataset: pd.DataFrame | None = None,
|
|
19
|
+
analysis: DatasetProfilerResults | None = None,
|
|
20
|
+
processor_artifacts: dict[str, list[str] | str] | None = None,
|
|
23
21
|
):
|
|
24
22
|
"""Creates a new instance with results from a Data Designer preview run.
|
|
25
23
|
|
|
@@ -29,7 +27,7 @@ class PreviewResults(WithRecordSamplerMixin):
|
|
|
29
27
|
analysis: Analysis of the preview run.
|
|
30
28
|
processor_artifacts: Artifacts generated by the processors.
|
|
31
29
|
"""
|
|
32
|
-
self.dataset:
|
|
33
|
-
self.analysis:
|
|
34
|
-
self.processor_artifacts:
|
|
30
|
+
self.dataset: pd.DataFrame | None = dataset
|
|
31
|
+
self.analysis: DatasetProfilerResults | None = analysis
|
|
32
|
+
self.processor_artifacts: dict[str, list[str] | str] | None = processor_artifacts
|
|
35
33
|
self._config_builder = config_builder
|
|
@@ -7,6 +7,7 @@ from enum import Enum
|
|
|
7
7
|
from typing import Any, Literal
|
|
8
8
|
|
|
9
9
|
from pydantic import Field, field_validator
|
|
10
|
+
from typing_extensions import TypeAlias
|
|
10
11
|
|
|
11
12
|
from data_designer.config.base import ConfigBase
|
|
12
13
|
from data_designer.config.dataset_builders import BuildStage
|
|
@@ -16,11 +17,30 @@ SUPPORTED_STAGES = [BuildStage.POST_BATCH]
|
|
|
16
17
|
|
|
17
18
|
|
|
18
19
|
class ProcessorType(str, Enum):
|
|
20
|
+
"""Enumeration of available processor types.
|
|
21
|
+
|
|
22
|
+
Attributes:
|
|
23
|
+
DROP_COLUMNS: Processor that removes specified columns from the output dataset.
|
|
24
|
+
SCHEMA_TRANSFORM: Processor that creates a new dataset with a transformed schema using Jinja2 templates.
|
|
25
|
+
"""
|
|
26
|
+
|
|
19
27
|
DROP_COLUMNS = "drop_columns"
|
|
20
28
|
SCHEMA_TRANSFORM = "schema_transform"
|
|
21
29
|
|
|
22
30
|
|
|
23
31
|
class ProcessorConfig(ConfigBase, ABC):
|
|
32
|
+
"""Abstract base class for all processor configuration types.
|
|
33
|
+
|
|
34
|
+
Processors are transformations that run before or after columns are generated.
|
|
35
|
+
They can modify, reshape, or augment the dataset before it's saved.
|
|
36
|
+
|
|
37
|
+
Attributes:
|
|
38
|
+
name: Unique name of the processor, used to identify the processor in results
|
|
39
|
+
and to name output artifacts on disk.
|
|
40
|
+
build_stage: The stage at which the processor runs. Currently only `POST_BATCH`
|
|
41
|
+
is supported, meaning processors run after each batch of columns is generated.
|
|
42
|
+
"""
|
|
43
|
+
|
|
24
44
|
name: str = Field(
|
|
25
45
|
description="The name of the processor, used to identify the processor in the results and to write the artifacts to disk.",
|
|
26
46
|
)
|
|
@@ -28,6 +48,7 @@ class ProcessorConfig(ConfigBase, ABC):
|
|
|
28
48
|
default=BuildStage.POST_BATCH,
|
|
29
49
|
description=f"The stage at which the processor will run. Supported stages: {', '.join(SUPPORTED_STAGES)}",
|
|
30
50
|
)
|
|
51
|
+
processor_type: str
|
|
31
52
|
|
|
32
53
|
@field_validator("build_stage")
|
|
33
54
|
def validate_build_stage(cls, v: BuildStage) -> BuildStage:
|
|
@@ -38,7 +59,16 @@ class ProcessorConfig(ConfigBase, ABC):
|
|
|
38
59
|
return v
|
|
39
60
|
|
|
40
61
|
|
|
41
|
-
def get_processor_config_from_kwargs(processor_type: ProcessorType, **kwargs) -> ProcessorConfig:
|
|
62
|
+
def get_processor_config_from_kwargs(processor_type: ProcessorType, **kwargs: Any) -> ProcessorConfig:
|
|
63
|
+
"""Create a processor configuration from a processor type and keyword arguments.
|
|
64
|
+
|
|
65
|
+
Args:
|
|
66
|
+
processor_type: The type of processor to create.
|
|
67
|
+
**kwargs: Additional keyword arguments passed to the processor constructor.
|
|
68
|
+
|
|
69
|
+
Returns:
|
|
70
|
+
A processor configuration object of the specified type.
|
|
71
|
+
"""
|
|
42
72
|
if processor_type == ProcessorType.DROP_COLUMNS:
|
|
43
73
|
return DropColumnsProcessorConfig(**kwargs)
|
|
44
74
|
elif processor_type == ProcessorType.SCHEMA_TRANSFORM:
|
|
@@ -46,11 +76,39 @@ def get_processor_config_from_kwargs(processor_type: ProcessorType, **kwargs) ->
|
|
|
46
76
|
|
|
47
77
|
|
|
48
78
|
class DropColumnsProcessorConfig(ProcessorConfig):
|
|
49
|
-
|
|
79
|
+
"""Configuration for dropping columns from the output dataset.
|
|
80
|
+
|
|
81
|
+
This processor removes specified columns from the generated dataset. The dropped
|
|
82
|
+
columns are saved separately in a `dropped-columns` directory for reference.
|
|
83
|
+
When this processor is added via the config builder, the corresponding column
|
|
84
|
+
configs are automatically marked with `drop = True`.
|
|
85
|
+
|
|
86
|
+
Alternatively, you can set `drop = True` when configuring a column.
|
|
87
|
+
|
|
88
|
+
Attributes:
|
|
89
|
+
column_names: List of column names to remove from the output dataset.
|
|
90
|
+
processor_type: Discriminator field, always `ProcessorType.DROP_COLUMNS` for this configuration type.
|
|
91
|
+
"""
|
|
92
|
+
|
|
93
|
+
column_names: list[str] = Field(description="List of column names to drop from the output dataset.")
|
|
50
94
|
processor_type: Literal[ProcessorType.DROP_COLUMNS] = ProcessorType.DROP_COLUMNS
|
|
51
95
|
|
|
52
96
|
|
|
53
97
|
class SchemaTransformProcessorConfig(ProcessorConfig):
|
|
98
|
+
"""Configuration for transforming the dataset schema using Jinja2 templates.
|
|
99
|
+
|
|
100
|
+
This processor creates a new dataset with a transformed schema. Each key in the
|
|
101
|
+
template becomes a column in the output, and values are Jinja2 templates that
|
|
102
|
+
can reference any column in the batch. The transformed dataset is written to
|
|
103
|
+
a `processors-outputs/{processor_name}/` directory alongside the main dataset.
|
|
104
|
+
|
|
105
|
+
Attributes:
|
|
106
|
+
template: Dictionary defining the output schema. Keys are new column names,
|
|
107
|
+
values are Jinja2 templates (strings, lists, or nested structures).
|
|
108
|
+
Must be JSON-serializable.
|
|
109
|
+
processor_type: Discriminator field, always `ProcessorType.SCHEMA_TRANSFORM` for this configuration type.
|
|
110
|
+
"""
|
|
111
|
+
|
|
54
112
|
template: dict[str, Any] = Field(
|
|
55
113
|
...,
|
|
56
114
|
description="""
|
|
@@ -83,3 +141,6 @@ class SchemaTransformProcessorConfig(ProcessorConfig):
|
|
|
83
141
|
if "not JSON serializable" in str(e):
|
|
84
142
|
raise InvalidConfigError("Template must be JSON serializable")
|
|
85
143
|
return v
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
ProcessorConfigT: TypeAlias = DropColumnsProcessorConfig | SchemaTransformProcessorConfig
|
|
@@ -3,7 +3,6 @@
|
|
|
3
3
|
|
|
4
4
|
from abc import ABC, abstractmethod
|
|
5
5
|
from enum import Enum
|
|
6
|
-
from typing import Union
|
|
7
6
|
|
|
8
7
|
from typing_extensions import TypeAlias
|
|
9
8
|
|
|
@@ -48,4 +47,4 @@ class ColumnInequalityConstraint(Constraint):
|
|
|
48
47
|
return ConstraintType.COLUMN_INEQUALITY
|
|
49
48
|
|
|
50
49
|
|
|
51
|
-
ColumnConstraintT: TypeAlias =
|
|
50
|
+
ColumnConstraintT: TypeAlias = ScalarInequalityConstraint | ColumnInequalityConstraint
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
4
|
from enum import Enum
|
|
5
|
-
from typing import Literal
|
|
5
|
+
from typing import Literal
|
|
6
6
|
|
|
7
7
|
import pandas as pd
|
|
8
8
|
from pydantic import Field, field_validator, model_validator
|
|
@@ -54,12 +54,12 @@ class CategorySamplerParams(ConfigBase):
|
|
|
54
54
|
Larger weights result in higher sampling probability for the corresponding value.
|
|
55
55
|
"""
|
|
56
56
|
|
|
57
|
-
values: list[
|
|
57
|
+
values: list[str | int | float] = Field(
|
|
58
58
|
...,
|
|
59
59
|
min_length=1,
|
|
60
60
|
description="List of possible categorical values that can be sampled from.",
|
|
61
61
|
)
|
|
62
|
-
weights:
|
|
62
|
+
weights: list[float] | None = Field(
|
|
63
63
|
default=None,
|
|
64
64
|
description=(
|
|
65
65
|
"List of unnormalized probability weights to assigned to each value, in order. "
|
|
@@ -134,7 +134,7 @@ class SubcategorySamplerParams(ConfigBase):
|
|
|
134
134
|
"""
|
|
135
135
|
|
|
136
136
|
category: str = Field(..., description="Name of parent category to this subcategory.")
|
|
137
|
-
values: dict[str, list[
|
|
137
|
+
values: dict[str, list[str | int | float]] = Field(
|
|
138
138
|
...,
|
|
139
139
|
description="Mapping from each value of parent category to a list of subcategory values.",
|
|
140
140
|
)
|
|
@@ -214,7 +214,7 @@ class UUIDSamplerParams(ConfigBase):
|
|
|
214
214
|
lowercase UUIDs.
|
|
215
215
|
"""
|
|
216
216
|
|
|
217
|
-
prefix:
|
|
217
|
+
prefix: str | None = Field(default=None, description="String prepended to the front of the UUID.")
|
|
218
218
|
short_form: bool = Field(
|
|
219
219
|
default=False,
|
|
220
220
|
description="If true, all UUIDs sampled will be truncated at 8 characters.",
|
|
@@ -259,7 +259,7 @@ class ScipySamplerParams(ConfigBase):
|
|
|
259
259
|
...,
|
|
260
260
|
description="Parameters of the scipy.stats distribution given in `dist_name`.",
|
|
261
261
|
)
|
|
262
|
-
decimal_places:
|
|
262
|
+
decimal_places: int | None = Field(
|
|
263
263
|
default=None, description="Number of decimal places to round the sampled values to."
|
|
264
264
|
)
|
|
265
265
|
sampler_type: Literal[SamplerType.SCIPY] = SamplerType.SCIPY
|
|
@@ -356,7 +356,7 @@ class GaussianSamplerParams(ConfigBase):
|
|
|
356
356
|
|
|
357
357
|
mean: float = Field(..., description="Mean of the Gaussian distribution")
|
|
358
358
|
stddev: float = Field(..., description="Standard deviation of the Gaussian distribution")
|
|
359
|
-
decimal_places:
|
|
359
|
+
decimal_places: int | None = Field(
|
|
360
360
|
default=None, description="Number of decimal places to round the sampled values to."
|
|
361
361
|
)
|
|
362
362
|
sampler_type: Literal[SamplerType.GAUSSIAN] = SamplerType.GAUSSIAN
|
|
@@ -398,7 +398,7 @@ class UniformSamplerParams(ConfigBase):
|
|
|
398
398
|
|
|
399
399
|
low: float = Field(..., description="Lower bound of the uniform distribution, inclusive.")
|
|
400
400
|
high: float = Field(..., description="Upper bound of the uniform distribution, inclusive.")
|
|
401
|
-
decimal_places:
|
|
401
|
+
decimal_places: int | None = Field(
|
|
402
402
|
default=None, description="Number of decimal places to round the sampled values to."
|
|
403
403
|
)
|
|
404
404
|
sampler_type: Literal[SamplerType.UNIFORM] = SamplerType.UNIFORM
|
|
@@ -421,8 +421,8 @@ class PersonSamplerParams(ConfigBase):
|
|
|
421
421
|
|
|
422
422
|
Attributes:
|
|
423
423
|
locale: Locale string determining the language and geographic region for synthetic people.
|
|
424
|
-
|
|
425
|
-
|
|
424
|
+
Must be a locale supported by a managed Nemotron Personas dataset. The dataset must
|
|
425
|
+
be downloaded and available in the managed assets directory.
|
|
426
426
|
sex: If specified, filters to only sample people of the specified sex. Options: "Male" or
|
|
427
427
|
"Female". If None, samples both sexes.
|
|
428
428
|
city: If specified, filters to only sample people from the specified city or cities. Can be
|
|
@@ -447,11 +447,11 @@ class PersonSamplerParams(ConfigBase):
|
|
|
447
447
|
f"{', '.join(LOCALES_WITH_MANAGED_DATASETS)}."
|
|
448
448
|
),
|
|
449
449
|
)
|
|
450
|
-
sex:
|
|
450
|
+
sex: SexT | None = Field(
|
|
451
451
|
default=None,
|
|
452
452
|
description="If specified, then only synthetic people of the specified sex will be sampled.",
|
|
453
453
|
)
|
|
454
|
-
city:
|
|
454
|
+
city: str | list[str] | None = Field(
|
|
455
455
|
default=None,
|
|
456
456
|
description="If specified, then only synthetic people from these cities will be sampled.",
|
|
457
457
|
)
|
|
@@ -461,7 +461,7 @@ class PersonSamplerParams(ConfigBase):
|
|
|
461
461
|
min_length=2,
|
|
462
462
|
max_length=2,
|
|
463
463
|
)
|
|
464
|
-
select_field_values:
|
|
464
|
+
select_field_values: dict[str, list[str]] | None = Field(
|
|
465
465
|
default=None,
|
|
466
466
|
description=(
|
|
467
467
|
"Sample synthetic people with the specified field values. This is meant to be a flexible argument for "
|
|
@@ -529,11 +529,11 @@ class PersonFromFakerSamplerParams(ConfigBase):
|
|
|
529
529
|
"that a synthetic person will be sampled from. E.g, en_US, en_GB, fr_FR, ..."
|
|
530
530
|
),
|
|
531
531
|
)
|
|
532
|
-
sex:
|
|
532
|
+
sex: SexT | None = Field(
|
|
533
533
|
default=None,
|
|
534
534
|
description="If specified, then only synthetic people of the specified sex will be sampled.",
|
|
535
535
|
)
|
|
536
|
-
city:
|
|
536
|
+
city: str | list[str] | None = Field(
|
|
537
537
|
default=None,
|
|
538
538
|
description="If specified, then only synthetic people from these cities will be sampled.",
|
|
539
539
|
)
|
|
@@ -585,22 +585,22 @@ class PersonFromFakerSamplerParams(ConfigBase):
|
|
|
585
585
|
return value
|
|
586
586
|
|
|
587
587
|
|
|
588
|
-
SamplerParamsT: TypeAlias =
|
|
589
|
-
SubcategorySamplerParams
|
|
590
|
-
CategorySamplerParams
|
|
591
|
-
DatetimeSamplerParams
|
|
592
|
-
PersonSamplerParams
|
|
593
|
-
PersonFromFakerSamplerParams
|
|
594
|
-
TimeDeltaSamplerParams
|
|
595
|
-
UUIDSamplerParams
|
|
596
|
-
BernoulliSamplerParams
|
|
597
|
-
BernoulliMixtureSamplerParams
|
|
598
|
-
BinomialSamplerParams
|
|
599
|
-
GaussianSamplerParams
|
|
600
|
-
PoissonSamplerParams
|
|
601
|
-
UniformSamplerParams
|
|
602
|
-
ScipySamplerParams
|
|
603
|
-
|
|
588
|
+
SamplerParamsT: TypeAlias = (
|
|
589
|
+
SubcategorySamplerParams
|
|
590
|
+
| CategorySamplerParams
|
|
591
|
+
| DatetimeSamplerParams
|
|
592
|
+
| PersonSamplerParams
|
|
593
|
+
| PersonFromFakerSamplerParams
|
|
594
|
+
| TimeDeltaSamplerParams
|
|
595
|
+
| UUIDSamplerParams
|
|
596
|
+
| BernoulliSamplerParams
|
|
597
|
+
| BernoulliMixtureSamplerParams
|
|
598
|
+
| BinomialSamplerParams
|
|
599
|
+
| GaussianSamplerParams
|
|
600
|
+
| PoissonSamplerParams
|
|
601
|
+
| UniformSamplerParams
|
|
602
|
+
| ScipySamplerParams
|
|
603
|
+
)
|
|
604
604
|
|
|
605
605
|
|
|
606
606
|
def is_numerical_sampler_type(sampler_type: SamplerType) -> bool:
|
data_designer/config/seed.py
CHANGED
|
@@ -3,7 +3,6 @@
|
|
|
3
3
|
|
|
4
4
|
from abc import ABC
|
|
5
5
|
from enum import Enum
|
|
6
|
-
from typing import Optional, Union
|
|
7
6
|
|
|
8
7
|
from pydantic import Field, field_validator, model_validator
|
|
9
8
|
from typing_extensions import Self
|
|
@@ -112,7 +111,7 @@ class SeedConfig(ConfigBase):
|
|
|
112
111
|
|
|
113
112
|
dataset: str
|
|
114
113
|
sampling_strategy: SamplingStrategy = SamplingStrategy.ORDERED
|
|
115
|
-
selection_strategy:
|
|
114
|
+
selection_strategy: IndexRange | PartitionBlock | None = None
|
|
116
115
|
|
|
117
116
|
|
|
118
117
|
class SeedDatasetReference(ABC, ConfigBase):
|