data-designer 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. data_designer/_version.py +2 -2
  2. data_designer/cli/README.md +15 -1
  3. data_designer/cli/commands/download.py +56 -0
  4. data_designer/cli/commands/list.py +4 -18
  5. data_designer/cli/controllers/__init__.py +2 -1
  6. data_designer/cli/controllers/download_controller.py +217 -0
  7. data_designer/cli/controllers/model_controller.py +4 -3
  8. data_designer/cli/forms/field.py +65 -19
  9. data_designer/cli/forms/model_builder.py +251 -44
  10. data_designer/cli/main.py +11 -1
  11. data_designer/cli/repositories/persona_repository.py +88 -0
  12. data_designer/cli/services/__init__.py +2 -1
  13. data_designer/cli/services/download_service.py +97 -0
  14. data_designer/cli/ui.py +131 -0
  15. data_designer/cli/utils.py +34 -0
  16. data_designer/config/analysis/__init__.py +2 -0
  17. data_designer/config/analysis/column_profilers.py +75 -7
  18. data_designer/config/analysis/column_statistics.py +192 -48
  19. data_designer/config/analysis/dataset_profiler.py +23 -5
  20. data_designer/config/analysis/utils/reporting.py +3 -3
  21. data_designer/config/base.py +3 -3
  22. data_designer/config/column_configs.py +27 -6
  23. data_designer/config/column_types.py +24 -17
  24. data_designer/config/config_builder.py +34 -26
  25. data_designer/config/data_designer_config.py +7 -7
  26. data_designer/config/datastore.py +6 -6
  27. data_designer/config/default_model_settings.py +27 -34
  28. data_designer/config/exports.py +14 -1
  29. data_designer/config/models.py +155 -29
  30. data_designer/config/preview_results.py +5 -4
  31. data_designer/config/processors.py +109 -4
  32. data_designer/config/sampler_constraints.py +1 -2
  33. data_designer/config/sampler_params.py +31 -31
  34. data_designer/config/seed.py +1 -2
  35. data_designer/config/utils/code_lang.py +4 -5
  36. data_designer/config/utils/constants.py +31 -8
  37. data_designer/config/utils/io_helpers.py +5 -5
  38. data_designer/config/utils/misc.py +1 -4
  39. data_designer/config/utils/numerical_helpers.py +2 -2
  40. data_designer/config/utils/type_helpers.py +3 -3
  41. data_designer/config/utils/validation.py +39 -9
  42. data_designer/config/utils/visualization.py +62 -15
  43. data_designer/config/validator_params.py +4 -8
  44. data_designer/engine/analysis/column_profilers/base.py +0 -7
  45. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +2 -3
  46. data_designer/engine/analysis/column_statistics.py +16 -16
  47. data_designer/engine/analysis/dataset_profiler.py +25 -4
  48. data_designer/engine/analysis/utils/column_statistics_calculations.py +71 -49
  49. data_designer/engine/analysis/utils/judge_score_processing.py +5 -5
  50. data_designer/engine/column_generators/generators/base.py +34 -0
  51. data_designer/engine/column_generators/generators/embedding.py +45 -0
  52. data_designer/engine/column_generators/generators/{llm_generators.py → llm_completion.py} +17 -49
  53. data_designer/engine/column_generators/registry.py +4 -2
  54. data_designer/engine/column_generators/utils/judge_score_factory.py +5 -6
  55. data_designer/engine/configurable_task.py +2 -2
  56. data_designer/engine/dataset_builders/artifact_storage.py +14 -5
  57. data_designer/engine/dataset_builders/column_wise_builder.py +12 -8
  58. data_designer/engine/dataset_builders/utils/concurrency.py +6 -6
  59. data_designer/engine/models/facade.py +66 -9
  60. data_designer/engine/models/litellm_overrides.py +5 -6
  61. data_designer/engine/models/parsers/errors.py +2 -4
  62. data_designer/engine/models/parsers/parser.py +2 -3
  63. data_designer/engine/models/parsers/postprocessors.py +3 -4
  64. data_designer/engine/models/parsers/types.py +4 -4
  65. data_designer/engine/models/registry.py +20 -11
  66. data_designer/engine/models/usage.py +7 -9
  67. data_designer/engine/processing/ginja/ast.py +1 -2
  68. data_designer/engine/processing/processors/drop_columns.py +1 -1
  69. data_designer/engine/processing/processors/registry.py +3 -0
  70. data_designer/engine/processing/processors/schema_transform.py +53 -0
  71. data_designer/engine/processing/utils.py +40 -2
  72. data_designer/engine/registry/base.py +12 -12
  73. data_designer/engine/sampling_gen/constraints.py +1 -2
  74. data_designer/engine/sampling_gen/data_sources/base.py +14 -14
  75. data_designer/engine/sampling_gen/entities/phone_number.py +1 -2
  76. data_designer/engine/sampling_gen/people_gen.py +3 -7
  77. data_designer/engine/validators/base.py +2 -2
  78. data_designer/interface/data_designer.py +12 -0
  79. data_designer/interface/results.py +36 -0
  80. data_designer/logging.py +2 -2
  81. data_designer/plugin_manager.py +3 -3
  82. data_designer/plugins/plugin.py +3 -3
  83. data_designer/plugins/registry.py +2 -2
  84. {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/METADATA +9 -9
  85. {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/RECORD +88 -81
  86. {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/WHEEL +0 -0
  87. {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/entry_points.txt +0 -0
  88. {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/licenses/LICENSE +0 -0
@@ -3,7 +3,7 @@
3
3
 
4
4
  from functools import cached_property
5
5
  from pathlib import Path
6
- from typing import Annotated, Optional, Union
6
+ from typing import Annotated
7
7
 
8
8
  from pydantic import BaseModel, Field, field_validator
9
9
 
@@ -16,11 +16,26 @@ from data_designer.config.utils.numerical_helpers import prepare_number_for_repo
16
16
 
17
17
 
18
18
  class DatasetProfilerResults(BaseModel):
19
+ """Container for complete dataset profiling and analysis results.
20
+
21
+ Stores profiling results for a generated dataset, including statistics for all columns,
22
+ dataset-level metadata, and optional advanced profiler results. Provides methods for
23
+ computing derived metrics and generating formatted reports.
24
+
25
+ Attributes:
26
+ num_records: Actual number of records successfully generated in the dataset.
27
+ target_num_records: Target number of records that were requested to be generated.
28
+ column_statistics: List of statistics objects for all columns in the dataset. Each
29
+ column has statistics appropriate to its type. Must contain at least one column.
30
+ side_effect_column_names: Column names that were generated as side effects of other columns.
31
+ column_profiles: Column profiler results for specific columns when configured.
32
+ """
33
+
19
34
  num_records: int
20
35
  target_num_records: int
21
36
  column_statistics: list[Annotated[ColumnStatisticsT, Field(discriminator="column_type")]] = Field(..., min_length=1)
22
- side_effect_column_names: Optional[list[str]] = None
23
- column_profiles: Optional[list[ColumnProfilerResultsT]] = None
37
+ side_effect_column_names: list[str] | None = None
38
+ column_profiles: list[ColumnProfilerResultsT] | None = None
24
39
 
25
40
  @field_validator("num_records", "target_num_records", mode="before")
26
41
  def ensure_python_integers(cls, v: int) -> int:
@@ -28,10 +43,12 @@ class DatasetProfilerResults(BaseModel):
28
43
 
29
44
  @property
30
45
  def percent_complete(self) -> float:
46
+ """Returns the completion percentage of the dataset."""
31
47
  return 100 * self.num_records / (self.target_num_records + EPSILON)
32
48
 
33
49
  @cached_property
34
50
  def column_types(self) -> list[str]:
51
+ """Returns a sorted list of unique column types present in the dataset."""
35
52
  display_order = get_column_display_order()
36
53
  return sorted(
37
54
  list(set([c.column_type for c in self.column_statistics])),
@@ -39,12 +56,13 @@ class DatasetProfilerResults(BaseModel):
39
56
  )
40
57
 
41
58
  def get_column_statistics_by_type(self, column_type: DataDesignerColumnType) -> list[ColumnStatisticsT]:
59
+ """Filters column statistics to return only those of the specified type."""
42
60
  return [c for c in self.column_statistics if c.column_type == column_type]
43
61
 
44
62
  def to_report(
45
63
  self,
46
- save_path: Optional[Union[str, Path]] = None,
47
- include_sections: Optional[list[Union[ReportSection, DataDesignerColumnType]]] = None,
64
+ save_path: str | Path | None = None,
65
+ include_sections: list[ReportSection | DataDesignerColumnType] | None = None,
48
66
  ) -> None:
49
67
  """Generate and print an analysis report based on the dataset profiling results.
50
68
 
@@ -5,7 +5,7 @@ from __future__ import annotations
5
5
 
6
6
  from enum import Enum
7
7
  from pathlib import Path
8
- from typing import TYPE_CHECKING, Optional, Union
8
+ from typing import TYPE_CHECKING
9
9
 
10
10
  from rich.align import Align
11
11
  from rich.console import Console, Group
@@ -48,8 +48,8 @@ DEFAULT_INCLUDE_SECTIONS = [
48
48
 
49
49
  def generate_analysis_report(
50
50
  analysis: DatasetProfilerResults,
51
- save_path: Optional[Union[str, Path]] = None,
52
- include_sections: Optional[list[Union[ReportSection, DataDesignerColumnType]]] = None,
51
+ save_path: str | Path | None = None,
52
+ include_sections: list[ReportSection | DataDesignerColumnType] | None = None,
53
53
  ) -> None:
54
54
  """Generate an analysis report for dataset profiling results.
55
55
 
@@ -4,7 +4,7 @@
4
4
  from __future__ import annotations
5
5
 
6
6
  from pathlib import Path
7
- from typing import Any, Optional, Union
7
+ from typing import Any
8
8
 
9
9
  import yaml
10
10
  from pydantic import BaseModel, ConfigDict
@@ -31,7 +31,7 @@ class ExportableConfigBase(ConfigBase):
31
31
  """
32
32
  return self.model_dump(mode="json")
33
33
 
34
- def to_yaml(self, path: Optional[Union[str, Path]] = None, *, indent: Optional[int] = 2, **kwargs) -> Optional[str]:
34
+ def to_yaml(self, path: str | Path | None = None, *, indent: int | None = 2, **kwargs) -> str | None:
35
35
  """Convert the configuration to a YAML string or file.
36
36
 
37
37
  Args:
@@ -49,7 +49,7 @@ class ExportableConfigBase(ConfigBase):
49
49
  with open(path, "w") as f:
50
50
  f.write(yaml_str)
51
51
 
52
- def to_json(self, path: Optional[Union[str, Path]] = None, *, indent: Optional[int] = 2, **kwargs) -> Optional[str]:
52
+ def to_json(self, path: str | Path | None = None, *, indent: int | None = 2, **kwargs) -> str | None:
53
53
  """Convert the configuration to a JSON string or file.
54
54
 
55
55
  Args:
@@ -2,7 +2,7 @@
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
4
  from abc import ABC
5
- from typing import Annotated, Literal, Optional, Type, Union
5
+ from typing import Annotated, Literal
6
6
 
7
7
  from pydantic import BaseModel, Discriminator, Field, model_validator
8
8
  from typing_extensions import Self
@@ -91,7 +91,7 @@ class SamplerColumnConfig(SingleColumnConfig):
91
91
  sampler_type: SamplerType
92
92
  params: Annotated[SamplerParamsT, Discriminator("sampler_type")]
93
93
  conditional_params: dict[str, Annotated[SamplerParamsT, Discriminator("sampler_type")]] = {}
94
- convert_to: Optional[str] = None
94
+ convert_to: str | None = None
95
95
  column_type: Literal["sampler"] = "sampler"
96
96
 
97
97
  @model_validator(mode="before")
@@ -146,8 +146,8 @@ class LLMTextColumnConfig(SingleColumnConfig):
146
146
 
147
147
  prompt: str
148
148
  model_alias: str
149
- system_prompt: Optional[str] = None
150
- multi_modal_context: Optional[list[ImageContext]] = None
149
+ system_prompt: str | None = None
150
+ multi_modal_context: list[ImageContext] | None = None
151
151
  column_type: Literal["llm-text"] = "llm-text"
152
152
 
153
153
  @property
@@ -222,7 +222,7 @@ class LLMStructuredColumnConfig(LLMTextColumnConfig):
222
222
  column_type: Discriminator field, always "llm-structured" for this configuration type.
223
223
  """
224
224
 
225
- output_format: Union[dict, Type[BaseModel]]
225
+ output_format: dict | type[BaseModel]
226
226
  column_type: Literal["llm-structured"] = "llm-structured"
227
227
 
228
228
  @model_validator(mode="after")
@@ -255,7 +255,7 @@ class Score(ConfigBase):
255
255
 
256
256
  name: str = Field(..., description="A clear name for this score.")
257
257
  description: str = Field(..., description="An informative and detailed assessment guide for using this score.")
258
- options: dict[Union[int, str], str] = Field(..., description="Score options in the format of {score: description}.")
258
+ options: dict[int | str, str] = Field(..., description="Score options in the format of {score: description}.")
259
259
 
260
260
 
261
261
  class LLMJudgeColumnConfig(LLMTextColumnConfig):
@@ -377,3 +377,24 @@ class SeedDatasetColumnConfig(SingleColumnConfig):
377
377
  """
378
378
 
379
379
  column_type: Literal["seed-dataset"] = "seed-dataset"
380
+
381
+
382
+ class EmbeddingColumnConfig(SingleColumnConfig):
383
+ """Configuration for embedding generation columns.
384
+
385
+ Embedding columns generate embeddings for text input using a specified model.
386
+
387
+ Attributes:
388
+ target_column: The column to generate embeddings for. The column could be a single text string or a list of text strings in stringified JSON format.
389
+ If it is a list of text strings in stringified JSON format, the embeddings will be generated for each text string.
390
+ model_alias: The model to use for embedding generation.
391
+ column_type: Discriminator field, always "embedding" for this configuration type.
392
+ """
393
+
394
+ target_column: str
395
+ model_alias: str
396
+ column_type: Literal["embedding"] = "embedding"
397
+
398
+ @property
399
+ def required_columns(self) -> list[str]:
400
+ return [self.target_column]
@@ -1,11 +1,11 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
- from typing import Union
5
4
 
6
5
  from typing_extensions import TypeAlias
7
6
 
8
7
  from data_designer.config.column_configs import (
8
+ EmbeddingColumnConfig,
9
9
  ExpressionColumnConfig,
10
10
  LLMCodeColumnConfig,
11
11
  LLMJudgeColumnConfig,
@@ -26,16 +26,17 @@ from data_designer.plugin_manager import PluginManager
26
26
 
27
27
  plugin_manager = PluginManager()
28
28
 
29
- ColumnConfigT: TypeAlias = Union[
30
- ExpressionColumnConfig,
31
- LLMCodeColumnConfig,
32
- LLMJudgeColumnConfig,
33
- LLMStructuredColumnConfig,
34
- LLMTextColumnConfig,
35
- SamplerColumnConfig,
36
- SeedDatasetColumnConfig,
37
- ValidationColumnConfig,
38
- ]
29
+ ColumnConfigT: TypeAlias = (
30
+ ExpressionColumnConfig
31
+ | LLMCodeColumnConfig
32
+ | LLMJudgeColumnConfig
33
+ | LLMStructuredColumnConfig
34
+ | LLMTextColumnConfig
35
+ | SamplerColumnConfig
36
+ | SeedDatasetColumnConfig
37
+ | ValidationColumnConfig
38
+ | EmbeddingColumnConfig
39
+ )
39
40
  ColumnConfigT = plugin_manager.inject_into_column_config_type_union(ColumnConfigT)
40
41
 
41
42
  DataDesignerColumnType = create_str_enum_from_discriminated_type_union(
@@ -54,13 +55,14 @@ COLUMN_TYPE_EMOJI_MAP = {
54
55
  DataDesignerColumnType.SEED_DATASET: "🌱",
55
56
  DataDesignerColumnType.SAMPLER: "🎲",
56
57
  DataDesignerColumnType.VALIDATION: "🔍",
58
+ DataDesignerColumnType.EMBEDDING: "🧬",
57
59
  }
58
60
  COLUMN_TYPE_EMOJI_MAP.update(
59
61
  {DataDesignerColumnType(p.name): p.emoji for p in plugin_manager.get_column_generator_plugins()}
60
62
  )
61
63
 
62
64
 
63
- def column_type_used_in_execution_dag(column_type: Union[str, DataDesignerColumnType]) -> bool:
65
+ def column_type_used_in_execution_dag(column_type: str | DataDesignerColumnType) -> bool:
64
66
  """Return True if the column type is used in the workflow execution DAG."""
65
67
  column_type = resolve_string_enum(column_type, DataDesignerColumnType)
66
68
  dag_column_types = {
@@ -70,27 +72,29 @@ def column_type_used_in_execution_dag(column_type: Union[str, DataDesignerColumn
70
72
  DataDesignerColumnType.LLM_STRUCTURED,
71
73
  DataDesignerColumnType.LLM_TEXT,
72
74
  DataDesignerColumnType.VALIDATION,
75
+ DataDesignerColumnType.EMBEDDING,
73
76
  }
74
77
  dag_column_types.update(plugin_manager.get_plugin_column_types(DataDesignerColumnType))
75
78
  return column_type in dag_column_types
76
79
 
77
80
 
78
- def column_type_is_llm_generated(column_type: Union[str, DataDesignerColumnType]) -> bool:
79
- """Return True if the column type is an LLM-generated column."""
81
+ def column_type_is_model_generated(column_type: str | DataDesignerColumnType) -> bool:
82
+ """Return True if the column type is a model-generated column."""
80
83
  column_type = resolve_string_enum(column_type, DataDesignerColumnType)
81
- llm_generated_column_types = {
84
+ model_generated_column_types = {
82
85
  DataDesignerColumnType.LLM_TEXT,
83
86
  DataDesignerColumnType.LLM_CODE,
84
87
  DataDesignerColumnType.LLM_STRUCTURED,
85
88
  DataDesignerColumnType.LLM_JUDGE,
89
+ DataDesignerColumnType.EMBEDDING,
86
90
  }
87
- llm_generated_column_types.update(
91
+ model_generated_column_types.update(
88
92
  plugin_manager.get_plugin_column_types(
89
93
  DataDesignerColumnType,
90
94
  required_resources=["model_registry"],
91
95
  )
92
96
  )
93
- return column_type in llm_generated_column_types
97
+ return column_type in model_generated_column_types
94
98
 
95
99
 
96
100
  def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType, **kwargs) -> ColumnConfigT:
@@ -121,6 +125,8 @@ def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType
121
125
  return SamplerColumnConfig(name=name, **_resolve_sampler_kwargs(name, kwargs))
122
126
  if column_type == DataDesignerColumnType.SEED_DATASET:
123
127
  return SeedDatasetColumnConfig(name=name, **kwargs)
128
+ if column_type == DataDesignerColumnType.EMBEDDING:
129
+ return EmbeddingColumnConfig(name=name, **kwargs)
124
130
  if plugin := plugin_manager.get_column_generator_plugin_if_exists(column_type.value):
125
131
  return plugin.config_cls(name=name, **kwargs)
126
132
  raise InvalidColumnTypeError(f"🛑 {column_type} is not a valid column type.") # pragma: no cover
@@ -135,6 +141,7 @@ def get_column_display_order() -> list[DataDesignerColumnType]:
135
141
  DataDesignerColumnType.LLM_CODE,
136
142
  DataDesignerColumnType.LLM_STRUCTURED,
137
143
  DataDesignerColumnType.LLM_JUDGE,
144
+ DataDesignerColumnType.EMBEDDING,
138
145
  DataDesignerColumnType.VALIDATION,
139
146
  DataDesignerColumnType.EXPRESSION,
140
147
  ]
@@ -6,7 +6,6 @@ from __future__ import annotations
6
6
  import json
7
7
  import logging
8
8
  from pathlib import Path
9
- from typing import Optional, Union
10
9
 
11
10
  from pygments import highlight
12
11
  from pygments.formatters import HtmlFormatter
@@ -19,7 +18,7 @@ from data_designer.config.column_configs import SeedDatasetColumnConfig
19
18
  from data_designer.config.column_types import (
20
19
  ColumnConfigT,
21
20
  DataDesignerColumnType,
22
- column_type_is_llm_generated,
21
+ column_type_is_model_generated,
23
22
  get_column_config_from_kwargs,
24
23
  get_column_display_order,
25
24
  )
@@ -29,7 +28,7 @@ from data_designer.config.datastore import DatastoreSettings, fetch_seed_dataset
29
28
  from data_designer.config.default_model_settings import get_default_model_configs
30
29
  from data_designer.config.errors import BuilderConfigurationError, InvalidColumnTypeError, InvalidConfigError
31
30
  from data_designer.config.models import ModelConfig, load_model_configs
32
- from data_designer.config.processors import ProcessorConfig, ProcessorType, get_processor_config_from_kwargs
31
+ from data_designer.config.processors import ProcessorConfigT, ProcessorType, get_processor_config_from_kwargs
33
32
  from data_designer.config.sampler_constraints import (
34
33
  ColumnConstraintT,
35
34
  ColumnInequalityConstraint,
@@ -69,7 +68,7 @@ class BuilderConfig(ExportableConfigBase):
69
68
  """
70
69
 
71
70
  data_designer: DataDesignerConfig
72
- datastore_settings: Optional[DatastoreSettings]
71
+ datastore_settings: DatastoreSettings | None
73
72
 
74
73
 
75
74
  class DataDesignerConfigBuilder:
@@ -79,7 +78,7 @@ class DataDesignerConfigBuilder:
79
78
  """
80
79
 
81
80
  @classmethod
82
- def from_config(cls, config: Union[dict, str, Path, BuilderConfig]) -> Self:
81
+ def from_config(cls, config: dict | str | Path | BuilderConfig) -> Self:
83
82
  """Create a DataDesignerConfigBuilder from an existing configuration.
84
83
 
85
84
  Args:
@@ -130,7 +129,7 @@ class DataDesignerConfigBuilder:
130
129
 
131
130
  return builder
132
131
 
133
- def __init__(self, model_configs: Optional[Union[list[ModelConfig], str, Path]] = None):
132
+ def __init__(self, model_configs: list[ModelConfig] | str | Path | None = None):
134
133
  """Initialize a new DataDesignerConfigBuilder instance.
135
134
 
136
135
  Args:
@@ -141,11 +140,11 @@ class DataDesignerConfigBuilder:
141
140
  """
142
141
  self._column_configs = {}
143
142
  self._model_configs = _load_model_configs(model_configs)
144
- self._processor_configs: list[ProcessorConfig] = []
145
- self._seed_config: Optional[SeedConfig] = None
143
+ self._processor_configs: list[ProcessorConfigT] = []
144
+ self._seed_config: SeedConfig | None = None
146
145
  self._constraints: list[ColumnConstraintT] = []
147
146
  self._profilers: list[ColumnProfilerConfigT] = []
148
- self._datastore_settings: Optional[DatastoreSettings] = None
147
+ self._datastore_settings: DatastoreSettings | None = None
149
148
 
150
149
  @property
151
150
  def model_configs(self) -> list[ModelConfig]:
@@ -206,10 +205,10 @@ class DataDesignerConfigBuilder:
206
205
 
207
206
  def add_column(
208
207
  self,
209
- column_config: Optional[ColumnConfigT] = None,
208
+ column_config: ColumnConfigT | None = None,
210
209
  *,
211
- name: Optional[str] = None,
212
- column_type: Optional[DataDesignerColumnType] = None,
210
+ name: str | None = None,
211
+ column_type: DataDesignerColumnType | None = None,
213
212
  **kwargs,
214
213
  ) -> Self:
215
214
  """Add a Data Designer column configuration to the current Data Designer configuration.
@@ -246,9 +245,9 @@ class DataDesignerConfigBuilder:
246
245
 
247
246
  def add_constraint(
248
247
  self,
249
- constraint: Optional[ColumnConstraintT] = None,
248
+ constraint: ColumnConstraintT | None = None,
250
249
  *,
251
- constraint_type: Optional[ConstraintType] = None,
250
+ constraint_type: ConstraintType | None = None,
252
251
  **kwargs,
253
252
  ) -> Self:
254
253
  """Add a constraint to the current Data Designer configuration.
@@ -298,9 +297,9 @@ class DataDesignerConfigBuilder:
298
297
 
299
298
  def add_processor(
300
299
  self,
301
- processor_config: Optional[ProcessorConfig] = None,
300
+ processor_config: ProcessorConfigT | None = None,
302
301
  *,
303
- processor_type: Optional[ProcessorType] = None,
302
+ processor_type: ProcessorType | None = None,
304
303
  **kwargs,
305
304
  ) -> Self:
306
305
  """Add a processor to the current Data Designer configuration.
@@ -447,12 +446,21 @@ class DataDesignerConfigBuilder:
447
446
  return [c for c in self._constraints if c.target_column == target_column]
448
447
 
449
448
  def get_llm_gen_columns(self) -> list[ColumnConfigT]:
450
- """Get all LLM-generated column configurations.
449
+ """Get all model-generated column configurations.
451
450
 
452
451
  Returns:
453
- A list of column configurations that use LLM generation.
452
+ A list of column configurations that use model generation.
454
453
  """
455
- return [c for c in self._column_configs.values() if column_type_is_llm_generated(c.column_type)]
454
+ logger.warning("get_llm_gen_columns is deprecated. Use get_model_gen_columns instead.")
455
+ return self.get_model_gen_columns()
456
+
457
+ def get_model_gen_columns(self) -> list[ColumnConfigT]:
458
+ """Get all model-generated column configurations.
459
+
460
+ Returns:
461
+ A list of column configurations that use model generation.
462
+ """
463
+ return [c for c in self._column_configs.values() if column_type_is_model_generated(c.column_type)]
456
464
 
457
465
  def get_columns_of_type(self, column_type: DataDesignerColumnType) -> list[ColumnConfigT]:
458
466
  """Get all column configurations of the specified type.
@@ -478,7 +486,7 @@ class DataDesignerConfigBuilder:
478
486
  column_type = resolve_string_enum(column_type, DataDesignerColumnType)
479
487
  return [c for c in self._column_configs.values() if c.column_type != column_type]
480
488
 
481
- def get_processor_configs(self) -> dict[BuildStage, list[ProcessorConfig]]:
489
+ def get_processor_configs(self) -> dict[BuildStage, list[ProcessorConfigT]]:
482
490
  """Get processor configuration objects.
483
491
 
484
492
  Returns:
@@ -486,7 +494,7 @@ class DataDesignerConfigBuilder:
486
494
  """
487
495
  return self._processor_configs
488
496
 
489
- def get_seed_config(self) -> Optional[SeedConfig]:
497
+ def get_seed_config(self) -> SeedConfig | None:
490
498
  """Get the seed config for the current Data Designer configuration.
491
499
 
492
500
  Returns:
@@ -494,7 +502,7 @@ class DataDesignerConfigBuilder:
494
502
  """
495
503
  return self._seed_config
496
504
 
497
- def get_seed_datastore_settings(self) -> Optional[DatastoreSettings]:
505
+ def get_seed_datastore_settings(self) -> DatastoreSettings | None:
498
506
  """Get most recent datastore settings for the current Data Designer configuration.
499
507
 
500
508
  Returns:
@@ -513,7 +521,7 @@ class DataDesignerConfigBuilder:
513
521
  """
514
522
  return len(self.get_columns_of_type(column_type))
515
523
 
516
- def set_seed_datastore_settings(self, datastore_settings: Optional[DatastoreSettings]) -> Self:
524
+ def set_seed_datastore_settings(self, datastore_settings: DatastoreSettings | None) -> Self:
517
525
  """Set the datastore settings for the seed dataset.
518
526
 
519
527
  Args:
@@ -554,7 +562,7 @@ class DataDesignerConfigBuilder:
554
562
  dataset_reference: SeedDatasetReference,
555
563
  *,
556
564
  sampling_strategy: SamplingStrategy = SamplingStrategy.ORDERED,
557
- selection_strategy: Optional[Union[IndexRange, PartitionBlock]] = None,
565
+ selection_strategy: IndexRange | PartitionBlock | None = None,
558
566
  ) -> Self:
559
567
  """Add a seed dataset to the current Data Designer configuration.
560
568
 
@@ -582,7 +590,7 @@ class DataDesignerConfigBuilder:
582
590
  self._column_configs[column_name] = SeedDatasetColumnConfig(name=column_name)
583
591
  return self
584
592
 
585
- def write_config(self, path: Union[str, Path], indent: Optional[int] = 2, **kwargs) -> None:
593
+ def write_config(self, path: str | Path, indent: int | None = 2, **kwargs) -> None:
586
594
  """Write the current configuration to a file.
587
595
 
588
596
  Args:
@@ -653,7 +661,7 @@ class DataDesignerConfigBuilder:
653
661
  return REPR_HTML_TEMPLATE.format(css=css, highlighted_html=highlighted_html)
654
662
 
655
663
 
656
- def _load_model_configs(model_configs: Optional[Union[list[ModelConfig], str, Path]] = None) -> list[ModelConfig]:
664
+ def _load_model_configs(model_configs: list[ModelConfig] | str | Path | None = None) -> list[ModelConfig]:
657
665
  """Resolves the provided model_configs, which may be a string or Path to a model configuration file.
658
666
  If None or empty, returns default model configurations if possible, otherwise raises an error.
659
667
  """
@@ -3,7 +3,7 @@
3
3
 
4
4
  from __future__ import annotations
5
5
 
6
- from typing import Annotated, Optional
6
+ from typing import Annotated
7
7
 
8
8
  from pydantic import Field
9
9
 
@@ -11,7 +11,7 @@ from data_designer.config.analysis.column_profilers import ColumnProfilerConfigT
11
11
  from data_designer.config.base import ExportableConfigBase
12
12
  from data_designer.config.column_types import ColumnConfigT
13
13
  from data_designer.config.models import ModelConfig
14
- from data_designer.config.processors import ProcessorConfig
14
+ from data_designer.config.processors import ProcessorConfigT
15
15
  from data_designer.config.sampler_constraints import ColumnConstraintT
16
16
  from data_designer.config.seed import SeedConfig
17
17
 
@@ -33,8 +33,8 @@ class DataDesignerConfig(ExportableConfigBase):
33
33
  """
34
34
 
35
35
  columns: list[Annotated[ColumnConfigT, Field(discriminator="column_type")]] = Field(min_length=1)
36
- model_configs: Optional[list[ModelConfig]] = None
37
- seed_config: Optional[SeedConfig] = None
38
- constraints: Optional[list[ColumnConstraintT]] = None
39
- profilers: Optional[list[ColumnProfilerConfigT]] = None
40
- processors: Optional[list[ProcessorConfig]] = None
36
+ model_configs: list[ModelConfig] | None = None
37
+ seed_config: SeedConfig | None = None
38
+ constraints: list[ColumnConstraintT] | None = None
39
+ profilers: list[ColumnProfilerConfigT] | None = None
40
+ processors: list[Annotated[ProcessorConfigT, Field(discriminator="processor_type")]] | None = None
@@ -5,7 +5,7 @@ from __future__ import annotations
5
5
 
6
6
  import logging
7
7
  from pathlib import Path
8
- from typing import TYPE_CHECKING, Optional, Union
8
+ from typing import TYPE_CHECKING
9
9
 
10
10
  import pandas as pd
11
11
  import pyarrow.parquet as pq
@@ -28,10 +28,10 @@ class DatastoreSettings(BaseModel):
28
28
  ...,
29
29
  description="Datastore endpoint. Use 'https://huggingface.co' for the Hugging Face Hub.",
30
30
  )
31
- token: Optional[str] = Field(default=None, description="If needed, token to use for authentication.")
31
+ token: str | None = Field(default=None, description="If needed, token to use for authentication.")
32
32
 
33
33
 
34
- def get_file_column_names(file_reference: Union[str, Path, HfFileSystem], file_type: str) -> list[str]:
34
+ def get_file_column_names(file_reference: str | Path | HfFileSystem, file_type: str) -> list[str]:
35
35
  """Get column names from a dataset file.
36
36
 
37
37
  Args:
@@ -80,7 +80,7 @@ def fetch_seed_dataset_column_names(seed_dataset_reference: SeedDatasetReference
80
80
  def fetch_seed_dataset_column_names_from_datastore(
81
81
  repo_id: str,
82
82
  filename: str,
83
- datastore_settings: Optional[Union[DatastoreSettings, dict]] = None,
83
+ datastore_settings: DatastoreSettings | dict | None = None,
84
84
  ) -> list[str]:
85
85
  file_type = filename.split(".")[-1]
86
86
  if f".{file_type}" not in VALID_DATASET_FILE_EXTENSIONS:
@@ -115,7 +115,7 @@ def resolve_datastore_settings(datastore_settings: DatastoreSettings | dict | No
115
115
 
116
116
 
117
117
  def upload_to_hf_hub(
118
- dataset_path: Union[str, Path],
118
+ dataset_path: str | Path,
119
119
  filename: str,
120
120
  repo_id: str,
121
121
  datastore_settings: DatastoreSettings,
@@ -171,7 +171,7 @@ def _extract_single_file_path_from_glob_pattern_if_present(
171
171
  return matching_files[0]
172
172
 
173
173
 
174
- def _validate_dataset_path(dataset_path: Union[str, Path], allow_glob_pattern: bool = False) -> Path:
174
+ def _validate_dataset_path(dataset_path: str | Path, allow_glob_pattern: bool = False) -> Path:
175
175
  if allow_glob_pattern and "*" in str(dataset_path):
176
176
  parts = str(dataset_path).split("*.")
177
177
  file_path = parts[0]
@@ -6,9 +6,15 @@ import logging
6
6
  import os
7
7
  from functools import lru_cache
8
8
  from pathlib import Path
9
- from typing import Any, Literal, Optional
10
-
11
- from data_designer.config.models import InferenceParameters, ModelConfig, ModelProvider
9
+ from typing import Any, Literal
10
+
11
+ from data_designer.config.models import (
12
+ ChatCompletionInferenceParams,
13
+ EmbeddingInferenceParams,
14
+ InferenceParamsT,
15
+ ModelConfig,
16
+ ModelProvider,
17
+ )
12
18
  from data_designer.config.utils.constants import (
13
19
  MANAGED_ASSETS_PATH,
14
20
  MODEL_CONFIGS_FILE_PATH,
@@ -21,46 +27,32 @@ from data_designer.config.utils.io_helpers import load_config_file, save_config_
21
27
  logger = logging.getLogger(__name__)
22
28
 
23
29
 
24
- def get_default_text_alias_inference_parameters() -> InferenceParameters:
25
- return InferenceParameters(
26
- temperature=0.85,
27
- top_p=0.95,
28
- )
29
-
30
-
31
- def get_default_reasoning_alias_inference_parameters() -> InferenceParameters:
32
- return InferenceParameters(
33
- temperature=0.35,
34
- top_p=0.95,
35
- )
36
-
37
-
38
- def get_default_vision_alias_inference_parameters() -> InferenceParameters:
39
- return InferenceParameters(
40
- temperature=0.85,
41
- top_p=0.95,
42
- )
43
-
44
-
45
- def get_default_inference_parameters(model_alias: Literal["text", "reasoning", "vision"]) -> InferenceParameters:
30
+ def get_default_inference_parameters(
31
+ model_alias: Literal["text", "reasoning", "vision", "embedding"],
32
+ inference_parameters: dict[str, Any],
33
+ ) -> InferenceParamsT:
46
34
  if model_alias == "reasoning":
47
- return get_default_reasoning_alias_inference_parameters()
35
+ return ChatCompletionInferenceParams(**inference_parameters)
48
36
  elif model_alias == "vision":
49
- return get_default_vision_alias_inference_parameters()
37
+ return ChatCompletionInferenceParams(**inference_parameters)
38
+ elif model_alias == "embedding":
39
+ return EmbeddingInferenceParams(**inference_parameters)
50
40
  else:
51
- return get_default_text_alias_inference_parameters()
41
+ return ChatCompletionInferenceParams(**inference_parameters)
52
42
 
53
43
 
54
44
  def get_builtin_model_configs() -> list[ModelConfig]:
55
45
  model_configs = []
56
46
  for provider, model_alias_map in PREDEFINED_PROVIDERS_MODEL_MAP.items():
57
- for model_alias, model_id in model_alias_map.items():
47
+ for model_alias, settings in model_alias_map.items():
58
48
  model_configs.append(
59
49
  ModelConfig(
60
50
  alias=f"{provider}-{model_alias}",
61
- model=model_id,
51
+ model=settings["model"],
62
52
  provider=provider,
63
- inference_parameters=get_default_inference_parameters(model_alias),
53
+ inference_parameters=get_default_inference_parameters(
54
+ model_alias, settings["inference_parameters"]
55
+ ),
64
56
  )
65
57
  )
66
58
  return model_configs
@@ -93,7 +85,7 @@ def get_default_providers() -> list[ModelProvider]:
93
85
  return []
94
86
 
95
87
 
96
- def get_default_provider_name() -> Optional[str]:
88
+ def get_default_provider_name() -> str | None:
97
89
  return _get_default_providers_file_content(MODEL_PROVIDERS_FILE_PATH).get("default")
98
90
 
99
91
 
@@ -103,7 +95,8 @@ def resolve_seed_default_model_settings() -> None:
103
95
  f"🍾 Default model configs were not found, so writing the following to {str(MODEL_CONFIGS_FILE_PATH)!r}"
104
96
  )
105
97
  save_config_file(
106
- MODEL_CONFIGS_FILE_PATH, {"model_configs": [mc.model_dump() for mc in get_builtin_model_configs()]}
98
+ MODEL_CONFIGS_FILE_PATH,
99
+ {"model_configs": [mc.model_dump(mode="json") for mc in get_builtin_model_configs()]},
107
100
  )
108
101
 
109
102
  if not MODEL_PROVIDERS_FILE_PATH.exists():
@@ -111,7 +104,7 @@ def resolve_seed_default_model_settings() -> None:
111
104
  f"🪄 Default model providers were not found, so writing the following to {str(MODEL_PROVIDERS_FILE_PATH)!r}"
112
105
  )
113
106
  save_config_file(
114
- MODEL_PROVIDERS_FILE_PATH, {"providers": [p.model_dump() for p in get_builtin_model_providers()]}
107
+ MODEL_PROVIDERS_FILE_PATH, {"providers": [p.model_dump(mode="json") for p in get_builtin_model_providers()]}
115
108
  )
116
109
 
117
110
  if not MANAGED_ASSETS_PATH.exists():