data-designer 0.1.5__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. data_designer/_version.py +2 -2
  2. data_designer/cli/README.md +15 -1
  3. data_designer/cli/commands/download.py +56 -0
  4. data_designer/cli/commands/list.py +4 -18
  5. data_designer/cli/controllers/__init__.py +2 -1
  6. data_designer/cli/controllers/download_controller.py +217 -0
  7. data_designer/cli/controllers/model_controller.py +4 -3
  8. data_designer/cli/forms/field.py +65 -19
  9. data_designer/cli/forms/model_builder.py +251 -44
  10. data_designer/cli/main.py +11 -1
  11. data_designer/cli/repositories/persona_repository.py +88 -0
  12. data_designer/cli/services/__init__.py +2 -1
  13. data_designer/cli/services/download_service.py +97 -0
  14. data_designer/cli/ui.py +131 -0
  15. data_designer/cli/utils.py +34 -0
  16. data_designer/config/analysis/__init__.py +2 -0
  17. data_designer/config/analysis/column_profilers.py +75 -7
  18. data_designer/config/analysis/column_statistics.py +192 -48
  19. data_designer/config/analysis/dataset_profiler.py +23 -5
  20. data_designer/config/analysis/utils/reporting.py +3 -3
  21. data_designer/config/base.py +3 -3
  22. data_designer/config/column_configs.py +27 -6
  23. data_designer/config/column_types.py +24 -17
  24. data_designer/config/config_builder.py +34 -26
  25. data_designer/config/data_designer_config.py +7 -7
  26. data_designer/config/datastore.py +6 -6
  27. data_designer/config/default_model_settings.py +27 -34
  28. data_designer/config/exports.py +8 -0
  29. data_designer/config/models.py +155 -29
  30. data_designer/config/preview_results.py +6 -8
  31. data_designer/config/processors.py +63 -2
  32. data_designer/config/sampler_constraints.py +1 -2
  33. data_designer/config/sampler_params.py +31 -31
  34. data_designer/config/seed.py +1 -2
  35. data_designer/config/utils/code_lang.py +4 -5
  36. data_designer/config/utils/constants.py +31 -8
  37. data_designer/config/utils/io_helpers.py +5 -5
  38. data_designer/config/utils/misc.py +1 -4
  39. data_designer/config/utils/numerical_helpers.py +2 -2
  40. data_designer/config/utils/type_helpers.py +3 -3
  41. data_designer/config/utils/validation.py +7 -8
  42. data_designer/config/utils/visualization.py +32 -17
  43. data_designer/config/validator_params.py +4 -8
  44. data_designer/engine/analysis/column_profilers/base.py +0 -7
  45. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +2 -3
  46. data_designer/engine/analysis/column_statistics.py +16 -16
  47. data_designer/engine/analysis/dataset_profiler.py +25 -4
  48. data_designer/engine/analysis/utils/column_statistics_calculations.py +71 -49
  49. data_designer/engine/analysis/utils/judge_score_processing.py +5 -5
  50. data_designer/engine/column_generators/generators/base.py +34 -0
  51. data_designer/engine/column_generators/generators/embedding.py +45 -0
  52. data_designer/engine/column_generators/generators/{llm_generators.py → llm_completion.py} +17 -49
  53. data_designer/engine/column_generators/registry.py +4 -2
  54. data_designer/engine/column_generators/utils/judge_score_factory.py +5 -6
  55. data_designer/engine/configurable_task.py +2 -2
  56. data_designer/engine/dataset_builders/artifact_storage.py +1 -2
  57. data_designer/engine/dataset_builders/column_wise_builder.py +11 -10
  58. data_designer/engine/dataset_builders/utils/concurrency.py +6 -6
  59. data_designer/engine/models/facade.py +66 -9
  60. data_designer/engine/models/litellm_overrides.py +5 -6
  61. data_designer/engine/models/parsers/errors.py +2 -4
  62. data_designer/engine/models/parsers/parser.py +2 -3
  63. data_designer/engine/models/parsers/postprocessors.py +3 -4
  64. data_designer/engine/models/parsers/types.py +4 -4
  65. data_designer/engine/models/registry.py +20 -11
  66. data_designer/engine/models/usage.py +7 -9
  67. data_designer/engine/processing/ginja/ast.py +1 -2
  68. data_designer/engine/processing/utils.py +40 -2
  69. data_designer/engine/registry/base.py +12 -12
  70. data_designer/engine/sampling_gen/constraints.py +1 -2
  71. data_designer/engine/sampling_gen/data_sources/base.py +14 -14
  72. data_designer/engine/sampling_gen/entities/phone_number.py +1 -2
  73. data_designer/engine/sampling_gen/people_gen.py +3 -7
  74. data_designer/engine/validators/base.py +2 -2
  75. data_designer/logging.py +2 -2
  76. data_designer/plugin_manager.py +3 -3
  77. data_designer/plugins/plugin.py +3 -3
  78. data_designer/plugins/registry.py +2 -2
  79. {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/METADATA +1 -1
  80. {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/RECORD +83 -77
  81. {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/WHEEL +0 -0
  82. {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/entry_points.txt +0 -0
  83. {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/licenses/LICENSE +0 -0
data_designer/config/exports.py

@@ -3,6 +3,7 @@
 
 from data_designer.config.analysis.column_profilers import JudgeScoreProfilerConfig
 from data_designer.config.column_configs import (
+    EmbeddingColumnConfig,
     ExpressionColumnConfig,
     LLMCodeColumnConfig,
     LLMJudgeColumnConfig,
@@ -19,6 +20,9 @@ from data_designer.config.data_designer_config import DataDesignerConfig
 from data_designer.config.dataset_builders import BuildStage
 from data_designer.config.datastore import DatastoreSettings
 from data_designer.config.models import (
+    ChatCompletionInferenceParams,
+    EmbeddingInferenceParams,
+    GenerationType,
     ImageContext,
     ImageFormat,
     InferenceParameters,
@@ -81,6 +85,7 @@ def get_config_exports() -> list[str]:
         CodeLang.__name__,
         CodeValidatorParams.__name__,
         ColumnInequalityConstraint.__name__,
+        ChatCompletionInferenceParams.__name__,
         DataDesignerColumnType.__name__,
         DataDesignerConfig.__name__,
         DataDesignerConfigBuilder.__name__,
@@ -89,8 +94,11 @@ def get_config_exports() -> list[str]:
         DatastoreSettings.__name__,
         DatetimeSamplerParams.__name__,
         DropColumnsProcessorConfig.__name__,
+        EmbeddingColumnConfig.__name__,
+        EmbeddingInferenceParams.__name__,
         ExpressionColumnConfig.__name__,
         GaussianSamplerParams.__name__,
+        GenerationType.__name__,
         IndexRange.__name__,
         InfoType.__name__,
         ImageContext.__name__,
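
The additions above register the new 0.2.0 names in the public export list. A minimal sanity-check sketch, assuming `get_config_exports` is importable from `data_designer.config.exports` as shown in the hunk headers:

```python
# Minimal sketch: confirm the names added in 0.2.0 appear in the export list.
from data_designer.config.exports import get_config_exports

new_names = {
    "ChatCompletionInferenceParams",
    "EmbeddingColumnConfig",
    "EmbeddingInferenceParams",
    "GenerationType",
}
missing = new_names - set(get_config_exports())
assert not missing, f"missing exports: {missing}"
```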
data_designer/config/models.py

@@ -5,10 +5,10 @@ import logging
 from abc import ABC, abstractmethod
 from enum import Enum
 from pathlib import Path
-from typing import Any, Generic, List, Optional, TypeVar, Union
+from typing import Any, Generic, Literal, TypeVar
 
 import numpy as np
-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, Field, field_validator, model_validator
 from typing_extensions import Self, TypeAlias
 
 from data_designer.config.base import ConfigBase
@@ -74,7 +74,7 @@ class ImageContext(ModalityContext):
     """
 
     modality: Modality = Modality.IMAGE
-    image_format: Optional[ImageFormat] = None
+    image_format: ImageFormat | None = None
 
     def get_context(self, record: dict) -> dict[str, Any]:
         """Get the context for the image modality.
@@ -122,8 +122,8 @@ class ManualDistributionParams(ConfigBase):
         weights: Optional list of weights for each value. If not provided, all values have equal probability.
     """
 
-    values: List[float] = Field(min_length=1)
-    weights: Optional[List[float]] = None
+    values: list[float] = Field(min_length=1)
+    weights: list[float] | None = None
 
     @model_validator(mode="after")
     def _normalize_weights(self) -> Self:
@@ -149,7 +149,7 @@ class ManualDistribution(Distribution[ManualDistributionParams]):
         params: Distribution parameters (values, weights).
     """
 
-    distribution_type: Optional[DistributionType] = "manual"
+    distribution_type: DistributionType | None = "manual"
     params: ManualDistributionParams
 
     def sample(self) -> float:
@@ -190,7 +190,7 @@ class UniformDistribution(Distribution[UniformDistributionParams]):
         params: Distribution parameters (low, high).
     """
 
-    distribution_type: Optional[DistributionType] = "uniform"
+    distribution_type: DistributionType | None = "uniform"
     params: UniformDistributionParams
 
     def sample(self) -> float:
@@ -202,36 +202,93 @@ class UniformDistribution(Distribution[UniformDistributionParams]):
         return float(np.random.uniform(low=self.params.low, high=self.params.high, size=1)[0])
 
 
-DistributionT: TypeAlias = Union[UniformDistribution, ManualDistribution]
+DistributionT: TypeAlias = UniformDistribution | ManualDistribution
 
 
-class InferenceParameters(ConfigBase):
-    """Configuration for LLM inference parameters.
+class GenerationType(str, Enum):
+    CHAT_COMPLETION = "chat-completion"
+    EMBEDDING = "embedding"
+
+
+class BaseInferenceParams(ConfigBase, ABC):
+    """Base configuration for inference parameters.
 
     Attributes:
-        temperature: Sampling temperature (0.0-2.0). Can be a fixed value or a distribution for dynamic sampling.
-        top_p: Nucleus sampling probability (0.0-1.0). Can be a fixed value or a distribution for dynamic sampling.
-        max_tokens: Maximum number of tokens (includes both input and output tokens).
+        generation_type: Type of generation (chat-completion or embedding). Acts as discriminator.
         max_parallel_requests: Maximum number of parallel requests to the model API.
         timeout: Timeout in seconds for each request.
         extra_body: Additional parameters to pass to the model API.
     """
 
-    temperature: Optional[Union[float, DistributionT]] = None
-    top_p: Optional[Union[float, DistributionT]] = None
-    max_tokens: Optional[int] = Field(default=None, ge=1)
+    generation_type: GenerationType
     max_parallel_requests: int = Field(default=4, ge=1)
-    timeout: Optional[int] = Field(default=None, ge=1)
-    extra_body: Optional[dict[str, Any]] = None
+    timeout: int | None = Field(default=None, ge=1)
+    extra_body: dict[str, Any] | None = None
 
     @property
-    def generate_kwargs(self) -> dict[str, Union[float, int]]:
+    def generate_kwargs(self) -> dict[str, Any]:
         """Get the generate kwargs for the inference parameters.
 
         Returns:
             A dictionary of the generate kwargs.
         """
         result = {}
+        if self.timeout is not None:
+            result["timeout"] = self.timeout
+        if self.extra_body is not None and self.extra_body != {}:
+            result["extra_body"] = self.extra_body
+        return result
+
+    def format_for_display(self) -> str:
+        """Format inference parameters for display.
+
+        Returns:
+            Formatted string of inference parameters
+        """
+        params_dict = self.model_dump(exclude_none=True, mode="json")
+
+        if not params_dict:
+            return "(none)"
+
+        parts = []
+        for key, value in params_dict.items():
+            formatted_value = self._format_value(key, value)
+            parts.append(f"{key}={formatted_value}")
+        return ", ".join(parts)
+
+    def _format_value(self, key: str, value: Any) -> str:
+        """Format a single parameter value. Override in subclasses for custom formatting.
+
+        Args:
+            key: Parameter name
+            value: Parameter value
+
+        Returns:
+            Formatted string representation of the value
+        """
+        if isinstance(value, float):
+            return f"{value:.2f}"
+        return str(value)
+
+
+class ChatCompletionInferenceParams(BaseInferenceParams):
+    """Configuration for LLM inference parameters.
+
+    Attributes:
+        generation_type: Type of generation, always "chat-completion" for this class.
+        temperature: Sampling temperature (0.0-2.0). Can be a fixed value or a distribution for dynamic sampling.
+        top_p: Nucleus sampling probability (0.0-1.0). Can be a fixed value or a distribution for dynamic sampling.
+        max_tokens: Maximum number of tokens (includes both input and output tokens).
+    """
+
+    generation_type: Literal[GenerationType.CHAT_COMPLETION] = GenerationType.CHAT_COMPLETION
+    temperature: float | DistributionT | None = None
+    top_p: float | DistributionT | None = None
+    max_tokens: int | None = Field(default=None, ge=1)
+
+    @property
+    def generate_kwargs(self) -> dict[str, Any]:
+        result = super().generate_kwargs
         if self.temperature is not None:
             result["temperature"] = (
                 self.temperature.sample() if hasattr(self.temperature, "sample") else self.temperature
@@ -240,10 +297,6 @@ class InferenceParameters(ConfigBase):
             result["top_p"] = self.top_p.sample() if hasattr(self.top_p, "sample") else self.top_p
         if self.max_tokens is not None:
             result["max_tokens"] = self.max_tokens
-        if self.timeout is not None:
-            result["timeout"] = self.timeout
-        if self.extra_body is not None and self.extra_body != {}:
-            result["extra_body"] = self.extra_body
         return result
 
     @model_validator(mode="after")
@@ -266,7 +319,7 @@ class InferenceParameters(ConfigBase):
 
     def _run_validation(
         self,
-        value: Union[float, DistributionT, None],
+        value: float | DistributionT | None,
         param_name: str,
         min_value: float,
         max_value: float,
@@ -289,6 +342,61 @@
     def _is_value_in_range(self, value: float, min_value: float, max_value: float) -> bool:
         return min_value <= value <= max_value
 
+    def _format_value(self, key: str, value: Any) -> str:
+        """Format chat completion parameter values, including distributions.
+
+        Args:
+            key: Parameter name
+            value: Parameter value
+
+        Returns:
+            Formatted string representation of the value
+        """
+        if isinstance(value, dict) and "distribution_type" in value:
+            return "dist"
+        return super()._format_value(key, value)
+
+
+# Maintain backwards compatibility with a deprecation warning
+class InferenceParameters(ChatCompletionInferenceParams):
+    """
+    Deprecated: Use ChatCompletionInferenceParams instead.
+    This alias will be removed in a future version.
+    """
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        logger.warning(
+            "InferenceParameters is deprecated and will be removed in a future version. "
+            "Use ChatCompletionInferenceParams instead."
+        )
+        super().__init__(*args, **kwargs)
+
+
+class EmbeddingInferenceParams(BaseInferenceParams):
+    """Configuration for embedding generation parameters.
+
+    Attributes:
+        generation_type: Type of generation, always "embedding" for this class.
+        encoding_format: Format of the embedding encoding ("float" or "base64").
+        dimensions: Number of dimensions for the embedding.
+    """
+
+    generation_type: Literal[GenerationType.EMBEDDING] = GenerationType.EMBEDDING
+    encoding_format: Literal["float", "base64"] = "float"
+    dimensions: int | None = None
+
+    @property
+    def generate_kwargs(self) -> dict[str, float | int]:
+        result = super().generate_kwargs
+        if self.encoding_format is not None:
+            result["encoding_format"] = self.encoding_format
+        if self.dimensions is not None:
+            result["dimensions"] = self.dimensions
+        return result
+
+
+InferenceParamsT: TypeAlias = ChatCompletionInferenceParams | EmbeddingInferenceParams | InferenceParameters
+
 
 class ModelConfig(ConfigBase):
     """Configuration for a model used for generation.
@@ -297,13 +405,31 @@ class ModelConfig(ConfigBase):
         alias: User-defined alias to reference in column configurations.
         model: Model identifier (e.g., from build.nvidia.com or other providers).
         inference_parameters: Inference parameters for the model (temperature, top_p, max_tokens, etc.).
+            The generation_type is determined by the type of inference_parameters.
         provider: Optional model provider name if using custom providers.
     """
 
     alias: str
     model: str
-    inference_parameters: InferenceParameters = Field(default_factory=InferenceParameters)
-    provider: Optional[str] = None
+    inference_parameters: InferenceParamsT = Field(default_factory=ChatCompletionInferenceParams)
+    provider: str | None = None
+
+    @property
+    def generation_type(self) -> GenerationType:
+        """Get the generation type from the inference parameters."""
+        return self.inference_parameters.generation_type
+
+    @field_validator("inference_parameters", mode="before")
+    @classmethod
+    def _convert_inference_parameters(cls, value: Any) -> Any:
+        """Convert raw dict to appropriate inference parameters type based on field presence."""
+        if isinstance(value, dict):
+            # Infer type from presence of embedding-specific fields
+            if "encoding_format" in value or "dimensions" in value:
+                return EmbeddingInferenceParams(**value)
+            else:
+                return ChatCompletionInferenceParams(**value)
+        return value
 
 
 class ModelProvider(ConfigBase):
@@ -320,11 +446,11 @@
     name: str
     endpoint: str
     provider_type: str = "openai"
-    api_key: Optional[str] = None
-    extra_body: Optional[dict[str, Any]] = None
+    api_key: str | None = None
+    extra_body: dict[str, Any] | None = None
 
 
-def load_model_configs(model_configs: Union[list[ModelConfig], str, Path]) -> list[ModelConfig]:
+def load_model_configs(model_configs: list[ModelConfig] | str | Path) -> list[ModelConfig]:
     if isinstance(model_configs, list) and all(isinstance(mc, ModelConfig) for mc in model_configs):
         return model_configs
     json_config = smart_load_yaml(model_configs)
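
`ModelConfig` now accepts any member of `InferenceParamsT` and, per the `_convert_inference_parameters` validator above, coerces a plain dict by key presence: embedding-specific keys select `EmbeddingInferenceParams`, anything else becomes `ChatCompletionInferenceParams`. A minimal sketch, with illustrative model identifiers:

```python
from data_designer.config.models import GenerationType, ModelConfig

# A dict with only chat-style keys coerces to ChatCompletionInferenceParams.
chat_model = ModelConfig(
    alias="writer",
    model="meta/llama-3.1-8b-instruct",  # illustrative model id
    inference_parameters={"temperature": 0.7, "max_tokens": 512},
)
assert chat_model.generation_type is GenerationType.CHAT_COMPLETION

# Presence of "dimensions" (or "encoding_format") selects EmbeddingInferenceParams.
embed_model = ModelConfig(
    alias="embedder",
    model="nvidia/nv-embedqa-e5-v5",  # illustrative model id
    inference_parameters={"dimensions": 1024},
)
assert embed_model.generation_type is GenerationType.EMBEDDING
```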
data_designer/config/preview_results.py

@@ -3,8 +3,6 @@
 
 from __future__ import annotations
 
-from typing import Optional, Union
-
 import pandas as pd
 
 from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults
@@ -17,9 +15,9 @@ class PreviewResults(WithRecordSamplerMixin):
         self,
         *,
         config_builder: DataDesignerConfigBuilder,
-        dataset: Optional[pd.DataFrame] = None,
-        analysis: Optional[DatasetProfilerResults] = None,
-        processor_artifacts: Optional[dict[str, Union[list[str], str]]] = None,
+        dataset: pd.DataFrame | None = None,
+        analysis: DatasetProfilerResults | None = None,
+        processor_artifacts: dict[str, list[str] | str] | None = None,
     ):
         """Creates a new instance with results from a Data Designer preview run.
 
@@ -29,7 +27,7 @@ class PreviewResults(WithRecordSamplerMixin):
             analysis: Analysis of the preview run.
             processor_artifacts: Artifacts generated by the processors.
         """
-        self.dataset: Optional[pd.DataFrame] = dataset
-        self.analysis: Optional[DatasetProfilerResults] = analysis
-        self.processor_artifacts: Optional[dict[str, Union[list[str], str]]] = processor_artifacts
+        self.dataset: pd.DataFrame | None = dataset
+        self.analysis: DatasetProfilerResults | None = analysis
+        self.processor_artifacts: dict[str, list[str] | str] | None = processor_artifacts
         self._config_builder = config_builder
data_designer/config/processors.py

@@ -7,6 +7,7 @@ from enum import Enum
 from typing import Any, Literal
 
 from pydantic import Field, field_validator
+from typing_extensions import TypeAlias
 
 from data_designer.config.base import ConfigBase
 from data_designer.config.dataset_builders import BuildStage
@@ -16,11 +17,30 @@ SUPPORTED_STAGES = [BuildStage.POST_BATCH]
 
 
 class ProcessorType(str, Enum):
+    """Enumeration of available processor types.
+
+    Attributes:
+        DROP_COLUMNS: Processor that removes specified columns from the output dataset.
+        SCHEMA_TRANSFORM: Processor that creates a new dataset with a transformed schema using Jinja2 templates.
+    """
+
     DROP_COLUMNS = "drop_columns"
     SCHEMA_TRANSFORM = "schema_transform"
 
 
 class ProcessorConfig(ConfigBase, ABC):
+    """Abstract base class for all processor configuration types.
+
+    Processors are transformations that run before or after columns are generated.
+    They can modify, reshape, or augment the dataset before it's saved.
+
+    Attributes:
+        name: Unique name of the processor, used to identify the processor in results
+            and to name output artifacts on disk.
+        build_stage: The stage at which the processor runs. Currently only `POST_BATCH`
+            is supported, meaning processors run after each batch of columns is generated.
+    """
+
     name: str = Field(
         description="The name of the processor, used to identify the processor in the results and to write the artifacts to disk.",
     )
@@ -28,6 +48,7 @@ class ProcessorConfig(ConfigBase, ABC):
         default=BuildStage.POST_BATCH,
         description=f"The stage at which the processor will run. Supported stages: {', '.join(SUPPORTED_STAGES)}",
     )
+    processor_type: str
 
     @field_validator("build_stage")
     def validate_build_stage(cls, v: BuildStage) -> BuildStage:
@@ -38,7 +59,16 @@ class ProcessorConfig(ConfigBase, ABC):
         return v
 
 
-def get_processor_config_from_kwargs(processor_type: ProcessorType, **kwargs) -> ProcessorConfig:
+def get_processor_config_from_kwargs(processor_type: ProcessorType, **kwargs: Any) -> ProcessorConfig:
+    """Create a processor configuration from a processor type and keyword arguments.
+
+    Args:
+        processor_type: The type of processor to create.
+        **kwargs: Additional keyword arguments passed to the processor constructor.
+
+    Returns:
+        A processor configuration object of the specified type.
+    """
     if processor_type == ProcessorType.DROP_COLUMNS:
         return DropColumnsProcessorConfig(**kwargs)
     elif processor_type == ProcessorType.SCHEMA_TRANSFORM:
@@ -46,11 +76,39 @@ def get_processor_config_from_kwargs(processor_type: ProcessorType, **kwargs) ->
 
 
 class DropColumnsProcessorConfig(ProcessorConfig):
-    column_names: list[str]
+    """Configuration for dropping columns from the output dataset.
+
+    This processor removes specified columns from the generated dataset. The dropped
+    columns are saved separately in a `dropped-columns` directory for reference.
+    When this processor is added via the config builder, the corresponding column
+    configs are automatically marked with `drop = True`.
+
+    Alternatively, you can set `drop = True` when configuring a column.
+
+    Attributes:
+        column_names: List of column names to remove from the output dataset.
+        processor_type: Discriminator field, always `ProcessorType.DROP_COLUMNS` for this configuration type.
+    """
+
+    column_names: list[str] = Field(description="List of column names to drop from the output dataset.")
     processor_type: Literal[ProcessorType.DROP_COLUMNS] = ProcessorType.DROP_COLUMNS
 
 
 class SchemaTransformProcessorConfig(ProcessorConfig):
+    """Configuration for transforming the dataset schema using Jinja2 templates.
+
+    This processor creates a new dataset with a transformed schema. Each key in the
+    template becomes a column in the output, and values are Jinja2 templates that
+    can reference any column in the batch. The transformed dataset is written to
+    a `processors-outputs/{processor_name}/` directory alongside the main dataset.
+
+    Attributes:
+        template: Dictionary defining the output schema. Keys are new column names,
+            values are Jinja2 templates (strings, lists, or nested structures).
+            Must be JSON-serializable.
+        processor_type: Discriminator field, always `ProcessorType.SCHEMA_TRANSFORM` for this configuration type.
+    """
+
     template: dict[str, Any] = Field(
         ...,
         description="""
@@ -83,3 +141,6 @@ class SchemaTransformProcessorConfig(ProcessorConfig):
             if "not JSON serializable" in str(e):
                 raise InvalidConfigError("Template must be JSON serializable")
             return v
+
+
+ProcessorConfigT: TypeAlias = DropColumnsProcessorConfig | SchemaTransformProcessorConfig
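
A minimal sketch of the two processor configurations documented above, assuming both classes and `ProcessorType` are importable from `data_designer.config.processors`; the column names and templates are illustrative:

```python
from data_designer.config.processors import (
    DropColumnsProcessorConfig,
    ProcessorType,
    SchemaTransformProcessorConfig,
)

# Remove scratch columns from the final dataset; per the docstring above, the
# dropped columns are kept in a separate dropped-columns directory.
drop = DropColumnsProcessorConfig(
    name="drop-scratch",
    column_names=["scratchpad", "raw_response"],  # illustrative column names
)

# Reshape each batch into a new schema; values are Jinja2 templates that can
# reference any generated column.
transform = SchemaTransformProcessorConfig(
    name="to-chat-format",
    template={
        "prompt": "{{ question }}",   # illustrative source columns
        "response": "{{ answer }}",
    },
)
assert transform.processor_type is ProcessorType.SCHEMA_TRANSFORM
```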
data_designer/config/sampler_constraints.py

@@ -3,7 +3,6 @@
 
 from abc import ABC, abstractmethod
 from enum import Enum
-from typing import Union
 
 from typing_extensions import TypeAlias
 
@@ -48,4 +47,4 @@ class ColumnInequalityConstraint(Constraint):
         return ConstraintType.COLUMN_INEQUALITY
 
 
-ColumnConstraintT: TypeAlias = Union[ScalarInequalityConstraint, ColumnInequalityConstraint]
+ColumnConstraintT: TypeAlias = ScalarInequalityConstraint | ColumnInequalityConstraint
data_designer/config/sampler_params.py

@@ -2,7 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from enum import Enum
-from typing import Literal, Optional, Union
+from typing import Literal
 
 import pandas as pd
 from pydantic import Field, field_validator, model_validator
@@ -54,12 +54,12 @@ class CategorySamplerParams(ConfigBase):
             Larger weights result in higher sampling probability for the corresponding value.
     """
 
-    values: list[Union[str, int, float]] = Field(
+    values: list[str | int | float] = Field(
         ...,
         min_length=1,
         description="List of possible categorical values that can be sampled from.",
     )
-    weights: Optional[list[float]] = Field(
+    weights: list[float] | None = Field(
         default=None,
         description=(
             "List of unnormalized probability weights to assigned to each value, in order. "
@@ -134,7 +134,7 @@ class SubcategorySamplerParams(ConfigBase):
     """
 
     category: str = Field(..., description="Name of parent category to this subcategory.")
-    values: dict[str, list[Union[str, int, float]]] = Field(
+    values: dict[str, list[str | int | float]] = Field(
         ...,
         description="Mapping from each value of parent category to a list of subcategory values.",
     )
@@ -214,7 +214,7 @@ class UUIDSamplerParams(ConfigBase):
         lowercase UUIDs.
     """
 
-    prefix: Optional[str] = Field(default=None, description="String prepended to the front of the UUID.")
+    prefix: str | None = Field(default=None, description="String prepended to the front of the UUID.")
     short_form: bool = Field(
         default=False,
         description="If true, all UUIDs sampled will be truncated at 8 characters.",
@@ -259,7 +259,7 @@ class ScipySamplerParams(ConfigBase):
         ...,
         description="Parameters of the scipy.stats distribution given in `dist_name`.",
     )
-    decimal_places: Optional[int] = Field(
+    decimal_places: int | None = Field(
         default=None, description="Number of decimal places to round the sampled values to."
     )
     sampler_type: Literal[SamplerType.SCIPY] = SamplerType.SCIPY
@@ -356,7 +356,7 @@ class GaussianSamplerParams(ConfigBase):
 
     mean: float = Field(..., description="Mean of the Gaussian distribution")
     stddev: float = Field(..., description="Standard deviation of the Gaussian distribution")
-    decimal_places: Optional[int] = Field(
+    decimal_places: int | None = Field(
         default=None, description="Number of decimal places to round the sampled values to."
     )
     sampler_type: Literal[SamplerType.GAUSSIAN] = SamplerType.GAUSSIAN
@@ -398,7 +398,7 @@ class UniformSamplerParams(ConfigBase):
 
     low: float = Field(..., description="Lower bound of the uniform distribution, inclusive.")
     high: float = Field(..., description="Upper bound of the uniform distribution, inclusive.")
-    decimal_places: Optional[int] = Field(
+    decimal_places: int | None = Field(
         default=None, description="Number of decimal places to round the sampled values to."
    )
     sampler_type: Literal[SamplerType.UNIFORM] = SamplerType.UNIFORM
@@ -421,8 +421,8 @@ class PersonSamplerParams(ConfigBase):
 
     Attributes:
         locale: Locale string determining the language and geographic region for synthetic people.
-            Format: language_COUNTRY (e.g., "en_US", "en_GB", "fr_FR", "de_DE", "es_ES", "ja_JP").
-            Defaults to "en_US".
+            Must be a locale supported by a managed Nemotron Personas dataset. The dataset must
+            be downloaded and available in the managed assets directory.
         sex: If specified, filters to only sample people of the specified sex. Options: "Male" or
             "Female". If None, samples both sexes.
         city: If specified, filters to only sample people from the specified city or cities. Can be
@@ -447,11 +447,11 @@ class PersonSamplerParams(ConfigBase):
             f"{', '.join(LOCALES_WITH_MANAGED_DATASETS)}."
         ),
     )
-    sex: Optional[SexT] = Field(
+    sex: SexT | None = Field(
         default=None,
         description="If specified, then only synthetic people of the specified sex will be sampled.",
     )
-    city: Optional[Union[str, list[str]]] = Field(
+    city: str | list[str] | None = Field(
         default=None,
         description="If specified, then only synthetic people from these cities will be sampled.",
     )
@@ -461,7 +461,7 @@ class PersonSamplerParams(ConfigBase):
         min_length=2,
         max_length=2,
     )
-    select_field_values: Optional[dict[str, list[str]]] = Field(
+    select_field_values: dict[str, list[str]] | None = Field(
         default=None,
         description=(
             "Sample synthetic people with the specified field values. This is meant to be a flexible argument for "
@@ -529,11 +529,11 @@ class PersonFromFakerSamplerParams(ConfigBase):
             "that a synthetic person will be sampled from. E.g, en_US, en_GB, fr_FR, ..."
         ),
     )
-    sex: Optional[SexT] = Field(
+    sex: SexT | None = Field(
         default=None,
         description="If specified, then only synthetic people of the specified sex will be sampled.",
     )
-    city: Optional[Union[str, list[str]]] = Field(
+    city: str | list[str] | None = Field(
         default=None,
         description="If specified, then only synthetic people from these cities will be sampled.",
     )
@@ -585,22 +585,22 @@ class PersonFromFakerSamplerParams(ConfigBase):
         return value
 
 
-SamplerParamsT: TypeAlias = Union[
-    SubcategorySamplerParams,
-    CategorySamplerParams,
-    DatetimeSamplerParams,
-    PersonSamplerParams,
-    PersonFromFakerSamplerParams,
-    TimeDeltaSamplerParams,
-    UUIDSamplerParams,
-    BernoulliSamplerParams,
-    BernoulliMixtureSamplerParams,
-    BinomialSamplerParams,
-    GaussianSamplerParams,
-    PoissonSamplerParams,
-    UniformSamplerParams,
-    ScipySamplerParams,
-]
+SamplerParamsT: TypeAlias = (
+    SubcategorySamplerParams
+    | CategorySamplerParams
+    | DatetimeSamplerParams
+    | PersonSamplerParams
+    | PersonFromFakerSamplerParams
+    | TimeDeltaSamplerParams
+    | UUIDSamplerParams
+    | BernoulliSamplerParams
+    | BernoulliMixtureSamplerParams
+    | BinomialSamplerParams
+    | GaussianSamplerParams
+    | PoissonSamplerParams
+    | UniformSamplerParams
+    | ScipySamplerParams
+)
 
 
 def is_numerical_sampler_type(sampler_type: SamplerType) -> bool:
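
A minimal sketch of a few of the sampler parameter types above, assuming they are importable from `data_designer.config.sampler_params`; the values are illustrative:

```python
from data_designer.config.sampler_params import (
    CategorySamplerParams,
    GaussianSamplerParams,
    PersonSamplerParams,
)

# Unnormalized weights are allowed; "red" is sampled twice as often as the others.
color = CategorySamplerParams(values=["red", "green", "blue"], weights=[2.0, 1.0, 1.0])

# decimal_places rounds sampled values; 0 yields whole numbers.
age = GaussianSamplerParams(mean=35.0, stddev=5.0, decimal_places=0)

# Per the updated docstring, the locale must have a managed Nemotron Personas
# dataset downloaded locally (see the new download command added to the CLI
# in this release).
person = PersonSamplerParams(locale="en_US")
```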
data_designer/config/seed.py

@@ -3,7 +3,6 @@
 
 from abc import ABC
 from enum import Enum
-from typing import Optional, Union
 
 from pydantic import Field, field_validator, model_validator
 from typing_extensions import Self
@@ -112,7 +111,7 @@ class SeedConfig(ConfigBase):
 
     dataset: str
     sampling_strategy: SamplingStrategy = SamplingStrategy.ORDERED
-    selection_strategy: Optional[Union[IndexRange, PartitionBlock]] = None
+    selection_strategy: IndexRange | PartitionBlock | None = None
 
 
 class SeedDatasetReference(ABC, ConfigBase):