data-designer 0.1.5__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. data_designer/_version.py +2 -2
  2. data_designer/cli/README.md +15 -1
  3. data_designer/cli/commands/download.py +56 -0
  4. data_designer/cli/commands/list.py +4 -18
  5. data_designer/cli/controllers/__init__.py +2 -1
  6. data_designer/cli/controllers/download_controller.py +217 -0
  7. data_designer/cli/controllers/model_controller.py +4 -3
  8. data_designer/cli/forms/field.py +65 -19
  9. data_designer/cli/forms/model_builder.py +251 -44
  10. data_designer/cli/main.py +11 -1
  11. data_designer/cli/repositories/persona_repository.py +88 -0
  12. data_designer/cli/services/__init__.py +2 -1
  13. data_designer/cli/services/download_service.py +97 -0
  14. data_designer/cli/ui.py +131 -0
  15. data_designer/cli/utils.py +34 -0
  16. data_designer/config/analysis/__init__.py +2 -0
  17. data_designer/config/analysis/column_profilers.py +75 -7
  18. data_designer/config/analysis/column_statistics.py +192 -48
  19. data_designer/config/analysis/dataset_profiler.py +23 -5
  20. data_designer/config/analysis/utils/reporting.py +3 -3
  21. data_designer/config/base.py +3 -3
  22. data_designer/config/column_configs.py +27 -6
  23. data_designer/config/column_types.py +24 -17
  24. data_designer/config/config_builder.py +36 -27
  25. data_designer/config/data_designer_config.py +7 -7
  26. data_designer/config/datastore.py +6 -6
  27. data_designer/config/default_model_settings.py +27 -34
  28. data_designer/config/exports.py +8 -0
  29. data_designer/config/models.py +155 -29
  30. data_designer/config/preview_results.py +6 -8
  31. data_designer/config/processors.py +63 -2
  32. data_designer/config/sampler_constraints.py +1 -2
  33. data_designer/config/sampler_params.py +50 -31
  34. data_designer/config/seed.py +1 -2
  35. data_designer/config/utils/code_lang.py +4 -5
  36. data_designer/config/utils/constants.py +31 -8
  37. data_designer/config/utils/io_helpers.py +5 -5
  38. data_designer/config/utils/misc.py +1 -4
  39. data_designer/config/utils/numerical_helpers.py +2 -2
  40. data_designer/config/utils/type_helpers.py +3 -3
  41. data_designer/config/utils/validation.py +7 -8
  42. data_designer/config/utils/visualization.py +32 -17
  43. data_designer/config/validator_params.py +4 -8
  44. data_designer/engine/analysis/column_profilers/base.py +0 -7
  45. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +2 -3
  46. data_designer/engine/analysis/column_statistics.py +16 -16
  47. data_designer/engine/analysis/dataset_profiler.py +25 -4
  48. data_designer/engine/analysis/utils/column_statistics_calculations.py +71 -49
  49. data_designer/engine/analysis/utils/judge_score_processing.py +5 -5
  50. data_designer/engine/column_generators/generators/base.py +34 -0
  51. data_designer/engine/column_generators/generators/embedding.py +45 -0
  52. data_designer/engine/column_generators/generators/{llm_generators.py → llm_completion.py} +17 -49
  53. data_designer/engine/column_generators/registry.py +4 -2
  54. data_designer/engine/column_generators/utils/judge_score_factory.py +5 -6
  55. data_designer/engine/configurable_task.py +2 -2
  56. data_designer/engine/dataset_builders/artifact_storage.py +1 -2
  57. data_designer/engine/dataset_builders/column_wise_builder.py +58 -15
  58. data_designer/engine/dataset_builders/utils/concurrency.py +6 -6
  59. data_designer/engine/models/facade.py +66 -9
  60. data_designer/engine/models/litellm_overrides.py +5 -6
  61. data_designer/engine/models/parsers/errors.py +2 -4
  62. data_designer/engine/models/parsers/parser.py +2 -3
  63. data_designer/engine/models/parsers/postprocessors.py +3 -4
  64. data_designer/engine/models/parsers/types.py +4 -4
  65. data_designer/engine/models/registry.py +47 -12
  66. data_designer/engine/models/telemetry.py +355 -0
  67. data_designer/engine/models/usage.py +7 -9
  68. data_designer/engine/processing/ginja/ast.py +1 -2
  69. data_designer/engine/processing/utils.py +40 -2
  70. data_designer/engine/registry/base.py +12 -12
  71. data_designer/engine/sampling_gen/constraints.py +1 -2
  72. data_designer/engine/sampling_gen/data_sources/base.py +14 -14
  73. data_designer/engine/sampling_gen/entities/phone_number.py +1 -2
  74. data_designer/engine/sampling_gen/people_gen.py +3 -7
  75. data_designer/engine/validators/base.py +2 -2
  76. data_designer/logging.py +2 -2
  77. data_designer/plugin_manager.py +3 -3
  78. data_designer/plugins/plugin.py +3 -3
  79. data_designer/plugins/registry.py +2 -2
  80. {data_designer-0.1.5.dist-info → data_designer-0.2.1.dist-info}/METADATA +32 -1
  81. {data_designer-0.1.5.dist-info → data_designer-0.2.1.dist-info}/RECORD +84 -77
  82. {data_designer-0.1.5.dist-info → data_designer-0.2.1.dist-info}/WHEEL +0 -0
  83. {data_designer-0.1.5.dist-info → data_designer-0.2.1.dist-info}/entry_points.txt +0 -0
  84. {data_designer-0.1.5.dist-info → data_designer-0.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -6,9 +6,15 @@ import logging
6
6
  import os
7
7
  from functools import lru_cache
8
8
  from pathlib import Path
9
- from typing import Any, Literal, Optional
10
-
11
- from data_designer.config.models import InferenceParameters, ModelConfig, ModelProvider
9
+ from typing import Any, Literal
10
+
11
+ from data_designer.config.models import (
12
+ ChatCompletionInferenceParams,
13
+ EmbeddingInferenceParams,
14
+ InferenceParamsT,
15
+ ModelConfig,
16
+ ModelProvider,
17
+ )
12
18
  from data_designer.config.utils.constants import (
13
19
  MANAGED_ASSETS_PATH,
14
20
  MODEL_CONFIGS_FILE_PATH,
@@ -21,46 +27,32 @@ from data_designer.config.utils.io_helpers import load_config_file, save_config_
21
27
  logger = logging.getLogger(__name__)
22
28
 
23
29
 
24
- def get_default_text_alias_inference_parameters() -> InferenceParameters:
25
- return InferenceParameters(
26
- temperature=0.85,
27
- top_p=0.95,
28
- )
29
-
30
-
31
- def get_default_reasoning_alias_inference_parameters() -> InferenceParameters:
32
- return InferenceParameters(
33
- temperature=0.35,
34
- top_p=0.95,
35
- )
36
-
37
-
38
- def get_default_vision_alias_inference_parameters() -> InferenceParameters:
39
- return InferenceParameters(
40
- temperature=0.85,
41
- top_p=0.95,
42
- )
43
-
44
-
45
- def get_default_inference_parameters(model_alias: Literal["text", "reasoning", "vision"]) -> InferenceParameters:
30
+ def get_default_inference_parameters(
31
+ model_alias: Literal["text", "reasoning", "vision", "embedding"],
32
+ inference_parameters: dict[str, Any],
33
+ ) -> InferenceParamsT:
46
34
  if model_alias == "reasoning":
47
- return get_default_reasoning_alias_inference_parameters()
35
+ return ChatCompletionInferenceParams(**inference_parameters)
48
36
  elif model_alias == "vision":
49
- return get_default_vision_alias_inference_parameters()
37
+ return ChatCompletionInferenceParams(**inference_parameters)
38
+ elif model_alias == "embedding":
39
+ return EmbeddingInferenceParams(**inference_parameters)
50
40
  else:
51
- return get_default_text_alias_inference_parameters()
41
+ return ChatCompletionInferenceParams(**inference_parameters)
52
42
 
53
43
 
54
44
  def get_builtin_model_configs() -> list[ModelConfig]:
55
45
  model_configs = []
56
46
  for provider, model_alias_map in PREDEFINED_PROVIDERS_MODEL_MAP.items():
57
- for model_alias, model_id in model_alias_map.items():
47
+ for model_alias, settings in model_alias_map.items():
58
48
  model_configs.append(
59
49
  ModelConfig(
60
50
  alias=f"{provider}-{model_alias}",
61
- model=model_id,
51
+ model=settings["model"],
62
52
  provider=provider,
63
- inference_parameters=get_default_inference_parameters(model_alias),
53
+ inference_parameters=get_default_inference_parameters(
54
+ model_alias, settings["inference_parameters"]
55
+ ),
64
56
  )
65
57
  )
66
58
  return model_configs
@@ -93,7 +85,7 @@ def get_default_providers() -> list[ModelProvider]:
93
85
  return []
94
86
 
95
87
 
96
- def get_default_provider_name() -> Optional[str]:
88
+ def get_default_provider_name() -> str | None:
97
89
  return _get_default_providers_file_content(MODEL_PROVIDERS_FILE_PATH).get("default")
98
90
 
99
91
 
@@ -103,7 +95,8 @@ def resolve_seed_default_model_settings() -> None:
103
95
  f"🍾 Default model configs were not found, so writing the following to {str(MODEL_CONFIGS_FILE_PATH)!r}"
104
96
  )
105
97
  save_config_file(
106
- MODEL_CONFIGS_FILE_PATH, {"model_configs": [mc.model_dump() for mc in get_builtin_model_configs()]}
98
+ MODEL_CONFIGS_FILE_PATH,
99
+ {"model_configs": [mc.model_dump(mode="json") for mc in get_builtin_model_configs()]},
107
100
  )
108
101
 
109
102
  if not MODEL_PROVIDERS_FILE_PATH.exists():
@@ -111,7 +104,7 @@ def resolve_seed_default_model_settings() -> None:
111
104
  f"🪄 Default model providers were not found, so writing the following to {str(MODEL_PROVIDERS_FILE_PATH)!r}"
112
105
  )
113
106
  save_config_file(
114
- MODEL_PROVIDERS_FILE_PATH, {"providers": [p.model_dump() for p in get_builtin_model_providers()]}
107
+ MODEL_PROVIDERS_FILE_PATH, {"providers": [p.model_dump(mode="json") for p in get_builtin_model_providers()]}
115
108
  )
116
109
 
117
110
  if not MANAGED_ASSETS_PATH.exists():
@@ -3,6 +3,7 @@
3
3
 
4
4
  from data_designer.config.analysis.column_profilers import JudgeScoreProfilerConfig
5
5
  from data_designer.config.column_configs import (
6
+ EmbeddingColumnConfig,
6
7
  ExpressionColumnConfig,
7
8
  LLMCodeColumnConfig,
8
9
  LLMJudgeColumnConfig,
@@ -19,6 +20,9 @@ from data_designer.config.data_designer_config import DataDesignerConfig
19
20
  from data_designer.config.dataset_builders import BuildStage
20
21
  from data_designer.config.datastore import DatastoreSettings
21
22
  from data_designer.config.models import (
23
+ ChatCompletionInferenceParams,
24
+ EmbeddingInferenceParams,
25
+ GenerationType,
22
26
  ImageContext,
23
27
  ImageFormat,
24
28
  InferenceParameters,
@@ -81,6 +85,7 @@ def get_config_exports() -> list[str]:
81
85
  CodeLang.__name__,
82
86
  CodeValidatorParams.__name__,
83
87
  ColumnInequalityConstraint.__name__,
88
+ ChatCompletionInferenceParams.__name__,
84
89
  DataDesignerColumnType.__name__,
85
90
  DataDesignerConfig.__name__,
86
91
  DataDesignerConfigBuilder.__name__,
@@ -89,8 +94,11 @@ def get_config_exports() -> list[str]:
89
94
  DatastoreSettings.__name__,
90
95
  DatetimeSamplerParams.__name__,
91
96
  DropColumnsProcessorConfig.__name__,
97
+ EmbeddingColumnConfig.__name__,
98
+ EmbeddingInferenceParams.__name__,
92
99
  ExpressionColumnConfig.__name__,
93
100
  GaussianSamplerParams.__name__,
101
+ GenerationType.__name__,
94
102
  IndexRange.__name__,
95
103
  InfoType.__name__,
96
104
  ImageContext.__name__,
@@ -5,10 +5,10 @@ import logging
5
5
  from abc import ABC, abstractmethod
6
6
  from enum import Enum
7
7
  from pathlib import Path
8
- from typing import Any, Generic, List, Optional, TypeVar, Union
8
+ from typing import Any, Generic, Literal, TypeVar
9
9
 
10
10
  import numpy as np
11
- from pydantic import BaseModel, Field, model_validator
11
+ from pydantic import BaseModel, Field, field_validator, model_validator
12
12
  from typing_extensions import Self, TypeAlias
13
13
 
14
14
  from data_designer.config.base import ConfigBase
@@ -74,7 +74,7 @@ class ImageContext(ModalityContext):
74
74
  """
75
75
 
76
76
  modality: Modality = Modality.IMAGE
77
- image_format: Optional[ImageFormat] = None
77
+ image_format: ImageFormat | None = None
78
78
 
79
79
  def get_context(self, record: dict) -> dict[str, Any]:
80
80
  """Get the context for the image modality.
@@ -122,8 +122,8 @@ class ManualDistributionParams(ConfigBase):
122
122
  weights: Optional list of weights for each value. If not provided, all values have equal probability.
123
123
  """
124
124
 
125
- values: List[float] = Field(min_length=1)
126
- weights: Optional[List[float]] = None
125
+ values: list[float] = Field(min_length=1)
126
+ weights: list[float] | None = None
127
127
 
128
128
  @model_validator(mode="after")
129
129
  def _normalize_weights(self) -> Self:
@@ -149,7 +149,7 @@ class ManualDistribution(Distribution[ManualDistributionParams]):
149
149
  params: Distribution parameters (values, weights).
150
150
  """
151
151
 
152
- distribution_type: Optional[DistributionType] = "manual"
152
+ distribution_type: DistributionType | None = "manual"
153
153
  params: ManualDistributionParams
154
154
 
155
155
  def sample(self) -> float:
@@ -190,7 +190,7 @@ class UniformDistribution(Distribution[UniformDistributionParams]):
190
190
  params: Distribution parameters (low, high).
191
191
  """
192
192
 
193
- distribution_type: Optional[DistributionType] = "uniform"
193
+ distribution_type: DistributionType | None = "uniform"
194
194
  params: UniformDistributionParams
195
195
 
196
196
  def sample(self) -> float:
@@ -202,36 +202,93 @@ class UniformDistribution(Distribution[UniformDistributionParams]):
202
202
  return float(np.random.uniform(low=self.params.low, high=self.params.high, size=1)[0])
203
203
 
204
204
 
205
- DistributionT: TypeAlias = Union[UniformDistribution, ManualDistribution]
205
+ DistributionT: TypeAlias = UniformDistribution | ManualDistribution
206
206
 
207
207
 
208
- class InferenceParameters(ConfigBase):
209
- """Configuration for LLM inference parameters.
208
+ class GenerationType(str, Enum):
209
+ CHAT_COMPLETION = "chat-completion"
210
+ EMBEDDING = "embedding"
211
+
212
+
213
+ class BaseInferenceParams(ConfigBase, ABC):
214
+ """Base configuration for inference parameters.
210
215
 
211
216
  Attributes:
212
- temperature: Sampling temperature (0.0-2.0). Can be a fixed value or a distribution for dynamic sampling.
213
- top_p: Nucleus sampling probability (0.0-1.0). Can be a fixed value or a distribution for dynamic sampling.
214
- max_tokens: Maximum number of tokens (includes both input and output tokens).
217
+ generation_type: Type of generation (chat-completion or embedding). Acts as discriminator.
215
218
  max_parallel_requests: Maximum number of parallel requests to the model API.
216
219
  timeout: Timeout in seconds for each request.
217
220
  extra_body: Additional parameters to pass to the model API.
218
221
  """
219
222
 
220
- temperature: Optional[Union[float, DistributionT]] = None
221
- top_p: Optional[Union[float, DistributionT]] = None
222
- max_tokens: Optional[int] = Field(default=None, ge=1)
223
+ generation_type: GenerationType
223
224
  max_parallel_requests: int = Field(default=4, ge=1)
224
- timeout: Optional[int] = Field(default=None, ge=1)
225
- extra_body: Optional[dict[str, Any]] = None
225
+ timeout: int | None = Field(default=None, ge=1)
226
+ extra_body: dict[str, Any] | None = None
226
227
 
227
228
  @property
228
- def generate_kwargs(self) -> dict[str, Union[float, int]]:
229
+ def generate_kwargs(self) -> dict[str, Any]:
229
230
  """Get the generate kwargs for the inference parameters.
230
231
 
231
232
  Returns:
232
233
  A dictionary of the generate kwargs.
233
234
  """
234
235
  result = {}
236
+ if self.timeout is not None:
237
+ result["timeout"] = self.timeout
238
+ if self.extra_body is not None and self.extra_body != {}:
239
+ result["extra_body"] = self.extra_body
240
+ return result
241
+
242
+ def format_for_display(self) -> str:
243
+ """Format inference parameters for display.
244
+
245
+ Returns:
246
+ Formatted string of inference parameters
247
+ """
248
+ params_dict = self.model_dump(exclude_none=True, mode="json")
249
+
250
+ if not params_dict:
251
+ return "(none)"
252
+
253
+ parts = []
254
+ for key, value in params_dict.items():
255
+ formatted_value = self._format_value(key, value)
256
+ parts.append(f"{key}={formatted_value}")
257
+ return ", ".join(parts)
258
+
259
+ def _format_value(self, key: str, value: Any) -> str:
260
+ """Format a single parameter value. Override in subclasses for custom formatting.
261
+
262
+ Args:
263
+ key: Parameter name
264
+ value: Parameter value
265
+
266
+ Returns:
267
+ Formatted string representation of the value
268
+ """
269
+ if isinstance(value, float):
270
+ return f"{value:.2f}"
271
+ return str(value)
272
+
273
+
274
+ class ChatCompletionInferenceParams(BaseInferenceParams):
275
+ """Configuration for LLM inference parameters.
276
+
277
+ Attributes:
278
+ generation_type: Type of generation, always "chat-completion" for this class.
279
+ temperature: Sampling temperature (0.0-2.0). Can be a fixed value or a distribution for dynamic sampling.
280
+ top_p: Nucleus sampling probability (0.0-1.0). Can be a fixed value or a distribution for dynamic sampling.
281
+ max_tokens: Maximum number of tokens (includes both input and output tokens).
282
+ """
283
+
284
+ generation_type: Literal[GenerationType.CHAT_COMPLETION] = GenerationType.CHAT_COMPLETION
285
+ temperature: float | DistributionT | None = None
286
+ top_p: float | DistributionT | None = None
287
+ max_tokens: int | None = Field(default=None, ge=1)
288
+
289
+ @property
290
+ def generate_kwargs(self) -> dict[str, Any]:
291
+ result = super().generate_kwargs
235
292
  if self.temperature is not None:
236
293
  result["temperature"] = (
237
294
  self.temperature.sample() if hasattr(self.temperature, "sample") else self.temperature
@@ -240,10 +297,6 @@ class InferenceParameters(ConfigBase):
240
297
  result["top_p"] = self.top_p.sample() if hasattr(self.top_p, "sample") else self.top_p
241
298
  if self.max_tokens is not None:
242
299
  result["max_tokens"] = self.max_tokens
243
- if self.timeout is not None:
244
- result["timeout"] = self.timeout
245
- if self.extra_body is not None and self.extra_body != {}:
246
- result["extra_body"] = self.extra_body
247
300
  return result
248
301
 
249
302
  @model_validator(mode="after")
@@ -266,7 +319,7 @@ class InferenceParameters(ConfigBase):
266
319
 
267
320
  def _run_validation(
268
321
  self,
269
- value: Union[float, DistributionT, None],
322
+ value: float | DistributionT | None,
270
323
  param_name: str,
271
324
  min_value: float,
272
325
  max_value: float,
@@ -289,6 +342,61 @@ class InferenceParameters(ConfigBase):
289
342
  def _is_value_in_range(self, value: float, min_value: float, max_value: float) -> bool:
290
343
  return min_value <= value <= max_value
291
344
 
345
+ def _format_value(self, key: str, value: Any) -> str:
346
+ """Format chat completion parameter values, including distributions.
347
+
348
+ Args:
349
+ key: Parameter name
350
+ value: Parameter value
351
+
352
+ Returns:
353
+ Formatted string representation of the value
354
+ """
355
+ if isinstance(value, dict) and "distribution_type" in value:
356
+ return "dist"
357
+ return super()._format_value(key, value)
358
+
359
+
360
+ # Maintain backwards compatibility with a deprecation warning
361
+ class InferenceParameters(ChatCompletionInferenceParams):
362
+ """
363
+ Deprecated: Use ChatCompletionInferenceParams instead.
364
+ This alias will be removed in a future version.
365
+ """
366
+
367
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
368
+ logger.warning(
369
+ "InferenceParameters is deprecated and will be removed in a future version. "
370
+ "Use ChatCompletionInferenceParams instead."
371
+ )
372
+ super().__init__(*args, **kwargs)
373
+
374
+
375
+ class EmbeddingInferenceParams(BaseInferenceParams):
376
+ """Configuration for embedding generation parameters.
377
+
378
+ Attributes:
379
+ generation_type: Type of generation, always "embedding" for this class.
380
+ encoding_format: Format of the embedding encoding ("float" or "base64").
381
+ dimensions: Number of dimensions for the embedding.
382
+ """
383
+
384
+ generation_type: Literal[GenerationType.EMBEDDING] = GenerationType.EMBEDDING
385
+ encoding_format: Literal["float", "base64"] = "float"
386
+ dimensions: int | None = None
387
+
388
+ @property
389
+ def generate_kwargs(self) -> dict[str, float | int]:
390
+ result = super().generate_kwargs
391
+ if self.encoding_format is not None:
392
+ result["encoding_format"] = self.encoding_format
393
+ if self.dimensions is not None:
394
+ result["dimensions"] = self.dimensions
395
+ return result
396
+
397
+
398
+ InferenceParamsT: TypeAlias = ChatCompletionInferenceParams | EmbeddingInferenceParams | InferenceParameters
399
+
292
400
 
293
401
  class ModelConfig(ConfigBase):
294
402
  """Configuration for a model used for generation.
@@ -297,13 +405,31 @@ class ModelConfig(ConfigBase):
297
405
  alias: User-defined alias to reference in column configurations.
298
406
  model: Model identifier (e.g., from build.nvidia.com or other providers).
299
407
  inference_parameters: Inference parameters for the model (temperature, top_p, max_tokens, etc.).
408
+ The generation_type is determined by the type of inference_parameters.
300
409
  provider: Optional model provider name if using custom providers.
301
410
  """
302
411
 
303
412
  alias: str
304
413
  model: str
305
- inference_parameters: InferenceParameters = Field(default_factory=InferenceParameters)
306
- provider: Optional[str] = None
414
+ inference_parameters: InferenceParamsT = Field(default_factory=ChatCompletionInferenceParams)
415
+ provider: str | None = None
416
+
417
+ @property
418
+ def generation_type(self) -> GenerationType:
419
+ """Get the generation type from the inference parameters."""
420
+ return self.inference_parameters.generation_type
421
+
422
+ @field_validator("inference_parameters", mode="before")
423
+ @classmethod
424
+ def _convert_inference_parameters(cls, value: Any) -> Any:
425
+ """Convert raw dict to appropriate inference parameters type based on field presence."""
426
+ if isinstance(value, dict):
427
+ # Infer type from presence of embedding-specific fields
428
+ if "encoding_format" in value or "dimensions" in value:
429
+ return EmbeddingInferenceParams(**value)
430
+ else:
431
+ return ChatCompletionInferenceParams(**value)
432
+ return value
307
433
 
308
434
 
309
435
  class ModelProvider(ConfigBase):
@@ -320,11 +446,11 @@ class ModelProvider(ConfigBase):
320
446
  name: str
321
447
  endpoint: str
322
448
  provider_type: str = "openai"
323
- api_key: Optional[str] = None
324
- extra_body: Optional[dict[str, Any]] = None
449
+ api_key: str | None = None
450
+ extra_body: dict[str, Any] | None = None
325
451
 
326
452
 
327
- def load_model_configs(model_configs: Union[list[ModelConfig], str, Path]) -> list[ModelConfig]:
453
+ def load_model_configs(model_configs: list[ModelConfig] | str | Path) -> list[ModelConfig]:
328
454
  if isinstance(model_configs, list) and all(isinstance(mc, ModelConfig) for mc in model_configs):
329
455
  return model_configs
330
456
  json_config = smart_load_yaml(model_configs)
@@ -3,8 +3,6 @@
3
3
 
4
4
  from __future__ import annotations
5
5
 
6
- from typing import Optional, Union
7
-
8
6
  import pandas as pd
9
7
 
10
8
  from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults
@@ -17,9 +15,9 @@ class PreviewResults(WithRecordSamplerMixin):
17
15
  self,
18
16
  *,
19
17
  config_builder: DataDesignerConfigBuilder,
20
- dataset: Optional[pd.DataFrame] = None,
21
- analysis: Optional[DatasetProfilerResults] = None,
22
- processor_artifacts: Optional[dict[str, Union[list[str], str]]] = None,
18
+ dataset: pd.DataFrame | None = None,
19
+ analysis: DatasetProfilerResults | None = None,
20
+ processor_artifacts: dict[str, list[str] | str] | None = None,
23
21
  ):
24
22
  """Creates a new instance with results from a Data Designer preview run.
25
23
 
@@ -29,7 +27,7 @@ class PreviewResults(WithRecordSamplerMixin):
29
27
  analysis: Analysis of the preview run.
30
28
  processor_artifacts: Artifacts generated by the processors.
31
29
  """
32
- self.dataset: Optional[pd.DataFrame] = dataset
33
- self.analysis: Optional[DatasetProfilerResults] = analysis
34
- self.processor_artifacts: Optional[dict[str, Union[list[str], str]]] = processor_artifacts
30
+ self.dataset: pd.DataFrame | None = dataset
31
+ self.analysis: DatasetProfilerResults | None = analysis
32
+ self.processor_artifacts: dict[str, list[str] | str] | None = processor_artifacts
35
33
  self._config_builder = config_builder
@@ -7,6 +7,7 @@ from enum import Enum
7
7
  from typing import Any, Literal
8
8
 
9
9
  from pydantic import Field, field_validator
10
+ from typing_extensions import TypeAlias
10
11
 
11
12
  from data_designer.config.base import ConfigBase
12
13
  from data_designer.config.dataset_builders import BuildStage
@@ -16,11 +17,30 @@ SUPPORTED_STAGES = [BuildStage.POST_BATCH]
16
17
 
17
18
 
18
19
  class ProcessorType(str, Enum):
20
+ """Enumeration of available processor types.
21
+
22
+ Attributes:
23
+ DROP_COLUMNS: Processor that removes specified columns from the output dataset.
24
+ SCHEMA_TRANSFORM: Processor that creates a new dataset with a transformed schema using Jinja2 templates.
25
+ """
26
+
19
27
  DROP_COLUMNS = "drop_columns"
20
28
  SCHEMA_TRANSFORM = "schema_transform"
21
29
 
22
30
 
23
31
  class ProcessorConfig(ConfigBase, ABC):
32
+ """Abstract base class for all processor configuration types.
33
+
34
+ Processors are transformations that run before or after columns are generated.
35
+ They can modify, reshape, or augment the dataset before it's saved.
36
+
37
+ Attributes:
38
+ name: Unique name of the processor, used to identify the processor in results
39
+ and to name output artifacts on disk.
40
+ build_stage: The stage at which the processor runs. Currently only `POST_BATCH`
41
+ is supported, meaning processors run after each batch of columns is generated.
42
+ """
43
+
24
44
  name: str = Field(
25
45
  description="The name of the processor, used to identify the processor in the results and to write the artifacts to disk.",
26
46
  )
@@ -28,6 +48,7 @@ class ProcessorConfig(ConfigBase, ABC):
28
48
  default=BuildStage.POST_BATCH,
29
49
  description=f"The stage at which the processor will run. Supported stages: {', '.join(SUPPORTED_STAGES)}",
30
50
  )
51
+ processor_type: str
31
52
 
32
53
  @field_validator("build_stage")
33
54
  def validate_build_stage(cls, v: BuildStage) -> BuildStage:
@@ -38,7 +59,16 @@ class ProcessorConfig(ConfigBase, ABC):
38
59
  return v
39
60
 
40
61
 
41
- def get_processor_config_from_kwargs(processor_type: ProcessorType, **kwargs) -> ProcessorConfig:
62
+ def get_processor_config_from_kwargs(processor_type: ProcessorType, **kwargs: Any) -> ProcessorConfig:
63
+ """Create a processor configuration from a processor type and keyword arguments.
64
+
65
+ Args:
66
+ processor_type: The type of processor to create.
67
+ **kwargs: Additional keyword arguments passed to the processor constructor.
68
+
69
+ Returns:
70
+ A processor configuration object of the specified type.
71
+ """
42
72
  if processor_type == ProcessorType.DROP_COLUMNS:
43
73
  return DropColumnsProcessorConfig(**kwargs)
44
74
  elif processor_type == ProcessorType.SCHEMA_TRANSFORM:
@@ -46,11 +76,39 @@ def get_processor_config_from_kwargs(processor_type: ProcessorType, **kwargs) ->
46
76
 
47
77
 
48
78
  class DropColumnsProcessorConfig(ProcessorConfig):
49
- column_names: list[str]
79
+ """Configuration for dropping columns from the output dataset.
80
+
81
+ This processor removes specified columns from the generated dataset. The dropped
82
+ columns are saved separately in a `dropped-columns` directory for reference.
83
+ When this processor is added via the config builder, the corresponding column
84
+ configs are automatically marked with `drop = True`.
85
+
86
+ Alternatively, you can set `drop = True` when configuring a column.
87
+
88
+ Attributes:
89
+ column_names: List of column names to remove from the output dataset.
90
+ processor_type: Discriminator field, always `ProcessorType.DROP_COLUMNS` for this configuration type.
91
+ """
92
+
93
+ column_names: list[str] = Field(description="List of column names to drop from the output dataset.")
50
94
  processor_type: Literal[ProcessorType.DROP_COLUMNS] = ProcessorType.DROP_COLUMNS
51
95
 
52
96
 
53
97
  class SchemaTransformProcessorConfig(ProcessorConfig):
98
+ """Configuration for transforming the dataset schema using Jinja2 templates.
99
+
100
+ This processor creates a new dataset with a transformed schema. Each key in the
101
+ template becomes a column in the output, and values are Jinja2 templates that
102
+ can reference any column in the batch. The transformed dataset is written to
103
+ a `processors-outputs/{processor_name}/` directory alongside the main dataset.
104
+
105
+ Attributes:
106
+ template: Dictionary defining the output schema. Keys are new column names,
107
+ values are Jinja2 templates (strings, lists, or nested structures).
108
+ Must be JSON-serializable.
109
+ processor_type: Discriminator field, always `ProcessorType.SCHEMA_TRANSFORM` for this configuration type.
110
+ """
111
+
54
112
  template: dict[str, Any] = Field(
55
113
  ...,
56
114
  description="""
@@ -83,3 +141,6 @@ class SchemaTransformProcessorConfig(ProcessorConfig):
83
141
  if "not JSON serializable" in str(e):
84
142
  raise InvalidConfigError("Template must be JSON serializable")
85
143
  return v
144
+
145
+
146
+ ProcessorConfigT: TypeAlias = DropColumnsProcessorConfig | SchemaTransformProcessorConfig
@@ -3,7 +3,6 @@
3
3
 
4
4
  from abc import ABC, abstractmethod
5
5
  from enum import Enum
6
- from typing import Union
7
6
 
8
7
  from typing_extensions import TypeAlias
9
8
 
@@ -48,4 +47,4 @@ class ColumnInequalityConstraint(Constraint):
48
47
  return ConstraintType.COLUMN_INEQUALITY
49
48
 
50
49
 
51
- ColumnConstraintT: TypeAlias = Union[ScalarInequalityConstraint, ColumnInequalityConstraint]
50
+ ColumnConstraintT: TypeAlias = ScalarInequalityConstraint | ColumnInequalityConstraint