data-designer 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. data_designer/_version.py +2 -2
  2. data_designer/cli/README.md +15 -1
  3. data_designer/cli/commands/download.py +56 -0
  4. data_designer/cli/commands/list.py +4 -18
  5. data_designer/cli/controllers/__init__.py +2 -1
  6. data_designer/cli/controllers/download_controller.py +217 -0
  7. data_designer/cli/controllers/model_controller.py +4 -3
  8. data_designer/cli/forms/field.py +65 -19
  9. data_designer/cli/forms/model_builder.py +251 -44
  10. data_designer/cli/main.py +11 -1
  11. data_designer/cli/repositories/persona_repository.py +88 -0
  12. data_designer/cli/services/__init__.py +2 -1
  13. data_designer/cli/services/download_service.py +97 -0
  14. data_designer/cli/ui.py +131 -0
  15. data_designer/cli/utils.py +34 -0
  16. data_designer/config/analysis/__init__.py +2 -0
  17. data_designer/config/analysis/column_profilers.py +75 -7
  18. data_designer/config/analysis/column_statistics.py +192 -48
  19. data_designer/config/analysis/dataset_profiler.py +23 -5
  20. data_designer/config/analysis/utils/reporting.py +3 -3
  21. data_designer/config/base.py +3 -3
  22. data_designer/config/column_configs.py +27 -6
  23. data_designer/config/column_types.py +24 -17
  24. data_designer/config/config_builder.py +34 -26
  25. data_designer/config/data_designer_config.py +7 -7
  26. data_designer/config/datastore.py +6 -6
  27. data_designer/config/default_model_settings.py +27 -34
  28. data_designer/config/exports.py +14 -1
  29. data_designer/config/models.py +155 -29
  30. data_designer/config/preview_results.py +5 -4
  31. data_designer/config/processors.py +109 -4
  32. data_designer/config/sampler_constraints.py +1 -2
  33. data_designer/config/sampler_params.py +31 -31
  34. data_designer/config/seed.py +1 -2
  35. data_designer/config/utils/code_lang.py +4 -5
  36. data_designer/config/utils/constants.py +31 -8
  37. data_designer/config/utils/io_helpers.py +5 -5
  38. data_designer/config/utils/misc.py +1 -4
  39. data_designer/config/utils/numerical_helpers.py +2 -2
  40. data_designer/config/utils/type_helpers.py +3 -3
  41. data_designer/config/utils/validation.py +39 -9
  42. data_designer/config/utils/visualization.py +62 -15
  43. data_designer/config/validator_params.py +4 -8
  44. data_designer/engine/analysis/column_profilers/base.py +0 -7
  45. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +2 -3
  46. data_designer/engine/analysis/column_statistics.py +16 -16
  47. data_designer/engine/analysis/dataset_profiler.py +25 -4
  48. data_designer/engine/analysis/utils/column_statistics_calculations.py +71 -49
  49. data_designer/engine/analysis/utils/judge_score_processing.py +5 -5
  50. data_designer/engine/column_generators/generators/base.py +34 -0
  51. data_designer/engine/column_generators/generators/embedding.py +45 -0
  52. data_designer/engine/column_generators/generators/{llm_generators.py → llm_completion.py} +17 -49
  53. data_designer/engine/column_generators/registry.py +4 -2
  54. data_designer/engine/column_generators/utils/judge_score_factory.py +5 -6
  55. data_designer/engine/configurable_task.py +2 -2
  56. data_designer/engine/dataset_builders/artifact_storage.py +14 -5
  57. data_designer/engine/dataset_builders/column_wise_builder.py +12 -8
  58. data_designer/engine/dataset_builders/utils/concurrency.py +6 -6
  59. data_designer/engine/models/facade.py +66 -9
  60. data_designer/engine/models/litellm_overrides.py +5 -6
  61. data_designer/engine/models/parsers/errors.py +2 -4
  62. data_designer/engine/models/parsers/parser.py +2 -3
  63. data_designer/engine/models/parsers/postprocessors.py +3 -4
  64. data_designer/engine/models/parsers/types.py +4 -4
  65. data_designer/engine/models/registry.py +20 -11
  66. data_designer/engine/models/usage.py +7 -9
  67. data_designer/engine/processing/ginja/ast.py +1 -2
  68. data_designer/engine/processing/processors/drop_columns.py +1 -1
  69. data_designer/engine/processing/processors/registry.py +3 -0
  70. data_designer/engine/processing/processors/schema_transform.py +53 -0
  71. data_designer/engine/processing/utils.py +40 -2
  72. data_designer/engine/registry/base.py +12 -12
  73. data_designer/engine/sampling_gen/constraints.py +1 -2
  74. data_designer/engine/sampling_gen/data_sources/base.py +14 -14
  75. data_designer/engine/sampling_gen/entities/phone_number.py +1 -2
  76. data_designer/engine/sampling_gen/people_gen.py +3 -7
  77. data_designer/engine/validators/base.py +2 -2
  78. data_designer/interface/data_designer.py +12 -0
  79. data_designer/interface/results.py +36 -0
  80. data_designer/logging.py +2 -2
  81. data_designer/plugin_manager.py +3 -3
  82. data_designer/plugins/plugin.py +3 -3
  83. data_designer/plugins/registry.py +2 -2
  84. {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/METADATA +9 -9
  85. {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/RECORD +88 -81
  86. {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/WHEEL +0 -0
  87. {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/entry_points.txt +0 -0
  88. {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/licenses/LICENSE +0 -0
@@ -10,43 +10,41 @@ from data_designer.config.column_configs import (
10
10
  LLMStructuredColumnConfig,
11
11
  LLMTextColumnConfig,
12
12
  )
13
- from data_designer.config.column_types import COLUMN_TYPE_EMOJI_MAP
14
- from data_designer.config.models import InferenceParameters, ModelConfig
15
13
  from data_designer.config.utils.constants import REASONING_TRACE_COLUMN_POSTFIX
16
14
  from data_designer.engine.column_generators.generators.base import (
17
15
  ColumnGenerator,
18
16
  GenerationStrategy,
19
17
  GeneratorMetadata,
18
+ WithModelGeneration,
20
19
  )
21
20
  from data_designer.engine.column_generators.utils.prompt_renderer import (
22
21
  PromptType,
23
22
  RecordBasedPromptRenderer,
24
23
  create_response_recipe,
25
24
  )
26
- from data_designer.engine.models.facade import ModelFacade
27
25
  from data_designer.engine.models.recipes.base import ResponseRecipe
28
26
  from data_designer.engine.processing.utils import deserialize_json_values
29
27
  from data_designer.engine.resources.resource_provider import ResourceType
30
28
 
31
- DEFAULT_MAX_CONVERSATION_RESTARTS = 5
32
- DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS = 0
29
+ logger = logging.getLogger(__name__)
33
30
 
34
31
 
35
- logger = logging.getLogger(__name__)
32
+ DEFAULT_MAX_CONVERSATION_RESTARTS = 5
33
+ DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS = 0
36
34
 
37
35
 
38
- class WithLLMGeneration:
36
+ class WithChatCompletionGeneration(WithModelGeneration):
39
37
  @functools.cached_property
40
- def model(self) -> ModelFacade:
41
- return self.resource_provider.model_registry.get_model(model_alias=self.config.model_alias)
38
+ def response_recipe(self) -> ResponseRecipe:
39
+ return create_response_recipe(self.config, self.model_config)
42
40
 
43
- @functools.cached_property
44
- def model_config(self) -> ModelConfig:
45
- return self.resource_provider.model_registry.get_model_config(model_alias=self.config.model_alias)
41
+ @property
42
+ def max_conversation_correction_steps(self) -> int:
43
+ return DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS
46
44
 
47
- @functools.cached_property
48
- def inference_parameters(self) -> InferenceParameters:
49
- return self.model_config.inference_parameters
45
+ @property
46
+ def max_conversation_restarts(self) -> int:
47
+ return DEFAULT_MAX_CONVERSATION_RESTARTS
50
48
 
51
49
  @functools.cached_property
52
50
  def prompt_renderer(self) -> RecordBasedPromptRenderer:
@@ -59,18 +57,6 @@ class WithLLMGeneration:
59
57
  },
60
58
  )
61
59
 
62
- @functools.cached_property
63
- def response_recipe(self) -> ResponseRecipe:
64
- return create_response_recipe(self.config, self.model_config)
65
-
66
- @property
67
- def max_conversation_correction_steps(self) -> int:
68
- return DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS
69
-
70
- @property
71
- def max_conversation_restarts(self) -> int:
72
- return DEFAULT_MAX_CONVERSATION_RESTARTS
73
-
74
60
  def generate(self, data: dict) -> dict:
75
61
  deserialized_record = deserialize_json_values(data)
76
62
 
@@ -96,7 +82,6 @@ class WithLLMGeneration:
96
82
  max_correction_steps=self.max_conversation_correction_steps,
97
83
  max_conversation_restarts=self.max_conversation_restarts,
98
84
  purpose=f"running generation for column '{self.config.name}'",
99
- **self.inference_parameters.generate_kwargs,
100
85
  )
101
86
 
102
87
  data[self.config.name] = deserialize_json_values(self.response_recipe.serialize_output(response))
@@ -106,21 +91,8 @@ class WithLLMGeneration:
106
91
 
107
92
  return data
108
93
 
109
- def log_pre_generation(self) -> None:
110
- emoji = COLUMN_TYPE_EMOJI_MAP[self.config.column_type]
111
- logger.info(f"{emoji} Preparing {self.config.column_type} column generation")
112
- logger.info(f" |-- column name: {self.config.name!r}")
113
- logger.info(f" |-- model config:\n{self.model_config.model_dump_json(indent=4)}")
114
- if self.model_config.provider is None:
115
- logger.info(f" |-- default model provider: {self._get_provider_name()!r}")
116
-
117
- def _get_provider_name(self) -> str:
118
- model_alias = self.model_config.alias
119
- provider = self.resource_provider.model_registry.get_model_provider(model_alias=model_alias)
120
- return provider.name
121
-
122
94
 
123
- class LLMTextCellGenerator(WithLLMGeneration, ColumnGenerator[LLMTextColumnConfig]):
95
+ class LLMTextCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMTextColumnConfig]):
124
96
  @staticmethod
125
97
  def metadata() -> GeneratorMetadata:
126
98
  return GeneratorMetadata(
@@ -131,7 +103,7 @@ class LLMTextCellGenerator(WithLLMGeneration, ColumnGenerator[LLMTextColumnConfi
131
103
  )
132
104
 
133
105
 
134
- class LLMCodeCellGenerator(WithLLMGeneration, ColumnGenerator[LLMCodeColumnConfig]):
106
+ class LLMCodeCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMCodeColumnConfig]):
135
107
  @staticmethod
136
108
  def metadata() -> GeneratorMetadata:
137
109
  return GeneratorMetadata(
@@ -142,7 +114,7 @@ class LLMCodeCellGenerator(WithLLMGeneration, ColumnGenerator[LLMCodeColumnConfi
142
114
  )
143
115
 
144
116
 
145
- class LLMStructuredCellGenerator(WithLLMGeneration, ColumnGenerator[LLMStructuredColumnConfig]):
117
+ class LLMStructuredCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMStructuredColumnConfig]):
146
118
  @staticmethod
147
119
  def metadata() -> GeneratorMetadata:
148
120
  return GeneratorMetadata(
@@ -153,7 +125,7 @@ class LLMStructuredCellGenerator(WithLLMGeneration, ColumnGenerator[LLMStructure
153
125
  )
154
126
 
155
127
 
156
- class LLMJudgeCellGenerator(WithLLMGeneration, ColumnGenerator[LLMJudgeColumnConfig]):
128
+ class LLMJudgeCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMJudgeColumnConfig]):
157
129
  @staticmethod
158
130
  def metadata() -> GeneratorMetadata:
159
131
  return GeneratorMetadata(
@@ -163,10 +135,6 @@ class LLMJudgeCellGenerator(WithLLMGeneration, ColumnGenerator[LLMJudgeColumnCon
163
135
  required_resources=[ResourceType.MODEL_REGISTRY],
164
136
  )
165
137
 
166
- @property
167
- def max_conversation_correction_steps(self) -> int:
168
- return DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS
169
-
170
138
  @property
171
139
  def max_conversation_restarts(self) -> int:
172
140
  return 2 * DEFAULT_MAX_CONVERSATION_RESTARTS
@@ -3,6 +3,7 @@
3
3
 
4
4
  from data_designer.config.base import ConfigBase
5
5
  from data_designer.config.column_configs import (
6
+ EmbeddingColumnConfig,
6
7
  ExpressionColumnConfig,
7
8
  LLMCodeColumnConfig,
8
9
  LLMJudgeColumnConfig,
@@ -12,8 +13,9 @@ from data_designer.config.column_configs import (
12
13
  )
13
14
  from data_designer.config.column_types import DataDesignerColumnType
14
15
  from data_designer.engine.column_generators.generators.base import ColumnGenerator
16
+ from data_designer.engine.column_generators.generators.embedding import EmbeddingCellGenerator
15
17
  from data_designer.engine.column_generators.generators.expression import ExpressionColumnGenerator
16
- from data_designer.engine.column_generators.generators.llm_generators import (
18
+ from data_designer.engine.column_generators.generators.llm_completion import (
17
19
  LLMCodeCellGenerator,
18
20
  LLMJudgeCellGenerator,
19
21
  LLMStructuredCellGenerator,
@@ -40,11 +42,11 @@ def create_default_column_generator_registry(with_plugins: bool = True) -> Colum
40
42
  registry.register(DataDesignerColumnType.LLM_CODE, LLMCodeCellGenerator, LLMCodeColumnConfig)
41
43
  registry.register(DataDesignerColumnType.LLM_JUDGE, LLMJudgeCellGenerator, LLMJudgeColumnConfig)
42
44
  registry.register(DataDesignerColumnType.EXPRESSION, ExpressionColumnGenerator, ExpressionColumnConfig)
45
+ registry.register(DataDesignerColumnType.EMBEDDING, EmbeddingCellGenerator, EmbeddingColumnConfig)
43
46
  registry.register(DataDesignerColumnType.SAMPLER, SamplerColumnGenerator, SamplerMultiColumnConfig)
44
47
  registry.register(DataDesignerColumnType.SEED_DATASET, SeedDatasetColumnGenerator, SeedDatasetMultiColumnConfig)
45
48
  registry.register(DataDesignerColumnType.VALIDATION, ValidationColumnGenerator, ValidationColumnConfig)
46
49
  registry.register(DataDesignerColumnType.LLM_STRUCTURED, LLMStructuredCellGenerator, LLMStructuredColumnConfig)
47
-
48
50
  if with_plugins:
49
51
  for plugin in PluginRegistry().get_plugins(PluginType.COLUMN_GENERATOR):
50
52
  registry.register(
@@ -2,7 +2,6 @@
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
4
  from enum import Enum
5
- from typing import Type
6
5
 
7
6
  from pydantic import BaseModel, ConfigDict, Field, create_model
8
7
 
@@ -19,7 +18,7 @@ class BaseJudgeResponse(BaseModel):
19
18
  reasoning: str = Field(..., description="Reasoning for the assigned score.")
20
19
 
21
20
 
22
- def _stringify_scoring(options: dict, enum_type: Type[Enum]) -> str:
21
+ def _stringify_scoring(options: dict, enum_type: type[Enum]) -> str:
23
22
  """Convert score descriptions into a single text block."""
24
23
  list_block = "\n".join(
25
24
  [SCORING_FORMAT.format(score=score, description=description) for score, description in options.items()]
@@ -27,7 +26,7 @@ def _stringify_scoring(options: dict, enum_type: Type[Enum]) -> str:
27
26
  return SCORE_FIELD_DESCRIPTION_FORMAT.format(enum_name=enum_type.__name__, scoring=list_block)
28
27
 
29
28
 
30
- def create_judge_response_model(score: Score) -> Type[BaseJudgeResponse]:
29
+ def create_judge_response_model(score: Score) -> type[BaseJudgeResponse]:
31
30
  """Create a JudgeResponse data type."""
32
31
  enum_members = {}
33
32
  for option in score.options.keys():
@@ -46,12 +45,12 @@ def create_judge_response_model(score: Score) -> Type[BaseJudgeResponse]:
46
45
 
47
46
 
48
47
  def create_judge_structured_output_model(
49
- judge_responses: list[Type[BaseJudgeResponse]],
50
- ) -> Type[BaseModel]:
48
+ judge_responses: list[type[BaseJudgeResponse]],
49
+ ) -> type[BaseModel]:
51
50
  """Create a JudgeStructuredOutput class dynamically."""
52
51
  return create_model(
53
52
  "JudgeStructuredOutput",
54
53
  __doc__=f"Response schema for scores with the following names: {[response.__name__ for response in judge_responses]}.",
55
54
  __base__=BaseModel,
56
- **{response.__name__.lower(): (response, ...) for response in judge_responses},
55
+ **{response.__name__: (response, ...) for response in judge_responses},
57
56
  )
@@ -3,7 +3,7 @@
3
3
 
4
4
  from abc import ABC, abstractmethod
5
5
  from pathlib import Path
6
- from typing import Generic, Type, TypeVar, get_origin
6
+ from typing import Generic, TypeVar, get_origin
7
7
 
8
8
  import pandas as pd
9
9
 
@@ -30,7 +30,7 @@ class ConfigurableTask(ABC, Generic[TaskConfigT]):
30
30
  self._initialize()
31
31
 
32
32
  @classmethod
33
- def get_config_type(cls) -> Type[TaskConfigT]:
33
+ def get_config_type(cls) -> type[TaskConfigT]:
34
34
  for base in cls.__orig_bases__:
35
35
  if hasattr(base, "__args__") and len(base.__args__) == 1:
36
36
  arg = base.__args__[0]
@@ -7,7 +7,6 @@ import shutil
7
7
  from datetime import datetime
8
8
  from functools import cached_property
9
9
  from pathlib import Path
10
- from typing import Union
11
10
 
12
11
  import pandas as pd
13
12
  from pydantic import BaseModel, field_validator, model_validator
@@ -25,6 +24,7 @@ class BatchStage(StrEnum):
25
24
  PARTIAL_RESULT = "partial_results_path"
26
25
  FINAL_RESULT = "final_dataset_path"
27
26
  DROPPED_COLUMNS = "dropped_columns_dataset_path"
27
+ PROCESSORS_OUTPUTS = "processors_outputs_path"
28
28
 
29
29
 
30
30
  class ArtifactStorage(BaseModel):
@@ -33,6 +33,7 @@ class ArtifactStorage(BaseModel):
33
33
  final_dataset_folder_name: str = "parquet-files"
34
34
  partial_results_folder_name: str = "tmp-partial-parquet-files"
35
35
  dropped_columns_folder_name: str = "dropped-columns-parquet-files"
36
+ processors_outputs_folder_name: str = "processors-files"
36
37
 
37
38
  @property
38
39
  def artifact_path_exists(self) -> bool:
@@ -70,8 +71,12 @@ class ArtifactStorage(BaseModel):
70
71
  def partial_results_path(self) -> Path:
71
72
  return self.base_dataset_path / self.partial_results_folder_name
72
73
 
74
+ @property
75
+ def processors_outputs_path(self) -> Path:
76
+ return self.base_dataset_path / self.processors_outputs_folder_name
77
+
73
78
  @field_validator("artifact_path")
74
- def validate_artifact_path(cls, v: Union[Path, str]) -> Path:
79
+ def validate_artifact_path(cls, v: Path | str) -> Path:
75
80
  v = Path(v)
76
81
  if not v.is_dir():
77
82
  raise ArtifactStorageError("Artifact path must exist and be a directory")
@@ -84,6 +89,7 @@ class ArtifactStorage(BaseModel):
84
89
  self.final_dataset_folder_name,
85
90
  self.partial_results_folder_name,
86
91
  self.dropped_columns_folder_name,
92
+ self.processors_outputs_folder_name,
87
93
  ]
88
94
 
89
95
  for name in folder_names:
@@ -169,9 +175,10 @@ class ArtifactStorage(BaseModel):
169
175
  batch_number: int,
170
176
  dataframe: pd.DataFrame,
171
177
  batch_stage: BatchStage,
178
+ subfolder: str | None = None,
172
179
  ) -> Path:
173
180
  file_path = self.create_batch_file_path(batch_number, batch_stage=batch_stage)
174
- self.write_parquet_file(file_path.name, dataframe, batch_stage)
181
+ self.write_parquet_file(file_path.name, dataframe, batch_stage, subfolder=subfolder)
175
182
  return file_path
176
183
 
177
184
  def write_parquet_file(
@@ -179,9 +186,11 @@ class ArtifactStorage(BaseModel):
179
186
  parquet_file_name: str,
180
187
  dataframe: pd.DataFrame,
181
188
  batch_stage: BatchStage,
189
+ subfolder: str | None = None,
182
190
  ) -> Path:
183
- self.mkdir_if_needed(self._get_stage_path(batch_stage))
184
- file_path = self._get_stage_path(batch_stage) / parquet_file_name
191
+ subfolder = subfolder or ""
192
+ self.mkdir_if_needed(self._get_stage_path(batch_stage) / subfolder)
193
+ file_path = self._get_stage_path(batch_stage) / subfolder / parquet_file_name
185
194
  dataframe.to_parquet(file_path, index=False)
186
195
  return file_path
187
196
 
@@ -10,15 +10,18 @@ from typing import Callable
10
10
 
11
11
  import pandas as pd
12
12
 
13
- from data_designer.config.column_types import ColumnConfigT, column_type_is_llm_generated
13
+ from data_designer.config.column_types import ColumnConfigT, column_type_is_model_generated
14
14
  from data_designer.config.dataset_builders import BuildStage
15
15
  from data_designer.config.processors import (
16
16
  DropColumnsProcessorConfig,
17
17
  ProcessorConfig,
18
18
  ProcessorType,
19
19
  )
20
- from data_designer.engine.column_generators.generators.base import ColumnGenerator, GenerationStrategy
21
- from data_designer.engine.column_generators.generators.llm_generators import WithLLMGeneration
20
+ from data_designer.engine.column_generators.generators.base import (
21
+ ColumnGenerator,
22
+ GenerationStrategy,
23
+ WithModelGeneration,
24
+ )
22
25
  from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
23
26
  from data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError
24
27
  from data_designer.engine.dataset_builders.multi_column_configs import (
@@ -72,7 +75,7 @@ class ColumnWiseDatasetBuilder:
72
75
 
73
76
  @functools.cached_property
74
77
  def llm_generated_column_configs(self) -> list[ColumnConfigT]:
75
- return [config for config in self.single_column_configs if column_type_is_llm_generated(config.column_type)]
78
+ return [config for config in self.single_column_configs if column_type_is_model_generated(config.column_type)]
76
79
 
77
80
  def build(
78
81
  self,
@@ -169,7 +172,7 @@ class ColumnWiseDatasetBuilder:
169
172
 
170
173
  def _run_cell_by_cell_generator(self, generator: ColumnGenerator) -> None:
171
174
  max_workers = MAX_CONCURRENCY_PER_NON_LLM_GENERATOR
172
- if isinstance(generator, WithLLMGeneration):
175
+ if isinstance(generator, WithModelGeneration):
173
176
  max_workers = generator.inference_parameters.max_parallel_requests
174
177
  self._fan_out_with_threads(generator, max_workers=max_workers)
175
178
 
@@ -178,12 +181,12 @@ class ColumnWiseDatasetBuilder:
178
181
  self.batch_manager.update_records(df.to_dict(orient="records"))
179
182
 
180
183
  def _run_model_health_check_if_needed(self) -> bool:
181
- if any(column_type_is_llm_generated(config.column_type) for config in self.single_column_configs):
184
+ if any(column_type_is_model_generated(config.column_type) for config in self.single_column_configs):
182
185
  self._resource_provider.model_registry.run_health_check(
183
- set(config.model_alias for config in self.llm_generated_column_configs)
186
+ list(set(config.model_alias for config in self.llm_generated_column_configs))
184
187
  )
185
188
 
186
- def _fan_out_with_threads(self, generator: WithLLMGeneration, max_workers: int) -> None:
189
+ def _fan_out_with_threads(self, generator: WithModelGeneration, max_workers: int) -> None:
187
190
  if generator.generation_strategy != GenerationStrategy.CELL_BY_CELL:
188
191
  raise DatasetGenerationError(
189
192
  f"Generator {generator.metadata().name} is not a {GenerationStrategy.CELL_BY_CELL} "
@@ -244,6 +247,7 @@ class ColumnWiseDatasetBuilder:
244
247
  processors[BuildStage.POST_BATCH].append( # as post-batch by default
245
248
  DropColumnsProcessor(
246
249
  config=DropColumnsProcessorConfig(
250
+ name="default_drop_columns_processor",
247
251
  column_names=columns_to_drop,
248
252
  build_stage=BuildStage.POST_BATCH,
249
253
  ),
@@ -8,7 +8,7 @@ import json
8
8
  import logging
9
9
  from concurrent.futures import Future, ThreadPoolExecutor
10
10
  from threading import Lock, Semaphore
11
- from typing import Any, Optional, Protocol
11
+ from typing import Any, Protocol
12
12
 
13
13
  from pydantic import BaseModel, Field
14
14
 
@@ -46,13 +46,13 @@ class ExecutorResults(BaseModel):
46
46
  class CallbackWithContext(Protocol):
47
47
  """Executor callback functions must accept a context kw argument."""
48
48
 
49
- def __call__(self, result: Any, *, context: Optional[dict] = None) -> Any: ...
49
+ def __call__(self, result: Any, *, context: dict | None = None) -> Any: ...
50
50
 
51
51
 
52
52
  class ErrorCallbackWithContext(Protocol):
53
53
  """Error callbacks take the Exception instance and context."""
54
54
 
55
- def __call__(self, exc: Exception, *, context: Optional[dict] = None) -> Any: ...
55
+ def __call__(self, exc: Exception, *, context: dict | None = None) -> Any: ...
56
56
 
57
57
 
58
58
  class ConcurrentThreadExecutor:
@@ -92,8 +92,8 @@ class ConcurrentThreadExecutor:
92
92
  *,
93
93
  max_workers: int,
94
94
  column_name: str,
95
- result_callback: Optional[CallbackWithContext] = None,
96
- error_callback: Optional[ErrorCallbackWithContext] = None,
95
+ result_callback: CallbackWithContext | None = None,
96
+ error_callback: ErrorCallbackWithContext | None = None,
97
97
  shutdown_error_rate: float = 0.50,
98
98
  shutdown_error_window: int = 10,
99
99
  ):
@@ -136,7 +136,7 @@ class ConcurrentThreadExecutor:
136
136
  )
137
137
  )
138
138
 
139
- def submit(self, fn, *args, context: Optional[dict] = None, **kwargs) -> None:
139
+ def submit(self, fn, *args, context: dict | None = None, **kwargs) -> None:
140
140
  if self._executor is None:
141
141
  raise RuntimeError("Executor is not initialized, this class should be used as a context manager.")
142
142
 
@@ -9,9 +9,9 @@ from copy import deepcopy
9
9
  from typing import Any
10
10
 
11
11
  from litellm.types.router import DeploymentTypedDict, LiteLLM_Params
12
- from litellm.types.utils import ModelResponse
12
+ from litellm.types.utils import EmbeddingResponse, ModelResponse
13
13
 
14
- from data_designer.config.models import ModelConfig, ModelProvider
14
+ from data_designer.config.models import GenerationType, ModelConfig, ModelProvider
15
15
  from data_designer.engine.model_provider import ModelProviderRegistry
16
16
  from data_designer.engine.models.errors import (
17
17
  GenerationValidationFailureError,
@@ -49,6 +49,10 @@ class ModelFacade:
49
49
  def model_provider(self) -> ModelProvider:
50
50
  return self._model_provider_registry.get_provider(self._model_config.provider)
51
51
 
52
+ @property
53
+ def model_generation_type(self) -> GenerationType:
54
+ return self._model_config.generation_type
55
+
52
56
  @property
53
57
  def model_provider_name(self) -> str:
54
58
  return self.model_provider.name
@@ -64,13 +68,12 @@ class ModelFacade:
64
68
  def completion(self, messages: list[dict[str, str]], skip_usage_tracking: bool = False, **kwargs) -> ModelResponse:
65
69
  logger.debug(
66
70
  f"Prompting model {self.model_name!r}...",
67
- extra={"model": self.model_name, "messages": messages, "sensitive": True},
71
+ extra={"model": self.model_name, "messages": messages},
68
72
  )
69
73
  response = None
70
- if self.model_provider.extra_body:
71
- kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body}
74
+ kwargs = self.consolidate_kwargs(**kwargs)
72
75
  try:
73
- response = self._router.completion(self.model_name, messages, **kwargs)
76
+ response = self._router.completion(model=self.model_name, messages=messages, **kwargs)
74
77
  logger.debug(
75
78
  f"Received completion from model {self.model_name!r}",
76
79
  extra={
@@ -84,9 +87,50 @@ class ModelFacade:
84
87
  except Exception as e:
85
88
  raise e
86
89
  finally:
87
- if not skip_usage_tracking:
90
+ if not skip_usage_tracking and response is not None:
88
91
  self._track_usage(response)
89
92
 
93
+ def consolidate_kwargs(self, **kwargs) -> dict[str, Any]:
94
+ # Remove purpose from kwargs to avoid passing it to the model
95
+ kwargs.pop("purpose", None)
96
+ kwargs = {**self._model_config.inference_parameters.generate_kwargs, **kwargs}
97
+ if self.model_provider.extra_body:
98
+ kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body}
99
+ return kwargs
100
+
101
+ @catch_llm_exceptions
102
+ def generate_text_embeddings(
103
+ self, input_texts: list[str], skip_usage_tracking: bool = False, **kwargs
104
+ ) -> list[list[float]]:
105
+ logger.debug(
106
+ f"Generating embeddings with model {self.model_name!r}...",
107
+ extra={
108
+ "model": self.model_name,
109
+ "input_count": len(input_texts),
110
+ },
111
+ )
112
+ kwargs = self.consolidate_kwargs(**kwargs)
113
+ response = None
114
+ try:
115
+ response = self._router.embedding(model=self.model_name, input=input_texts, **kwargs)
116
+ logger.debug(
117
+ f"Received embeddings from model {self.model_name!r}",
118
+ extra={
119
+ "model": self.model_name,
120
+ "embedding_count": len(response.data) if response.data else 0,
121
+ "usage": self._usage_stats.model_dump(),
122
+ },
123
+ )
124
+ if response.data and len(response.data) == len(input_texts):
125
+ return [data["embedding"] for data in response.data]
126
+ else:
127
+ raise ValueError(f"Expected {len(input_texts)} embeddings, but received {len(response.data)}")
128
+ except Exception as e:
129
+ raise e
130
+ finally:
131
+ if not skip_usage_tracking and response is not None:
132
+ self._track_usage_from_embedding(response)
133
+
90
134
  @catch_llm_exceptions
91
135
  def generate(
92
136
  self,
@@ -218,8 +262,21 @@ class ModelFacade:
218
262
  ):
219
263
  self._usage_stats.extend(
220
264
  token_usage=TokenUsageStats(
221
- prompt_tokens=response.usage.prompt_tokens,
222
- completion_tokens=response.usage.completion_tokens,
265
+ input_tokens=response.usage.prompt_tokens,
266
+ output_tokens=response.usage.completion_tokens,
267
+ ),
268
+ request_usage=RequestUsageStats(successful_requests=1, failed_requests=0),
269
+ )
270
+
271
+ def _track_usage_from_embedding(self, response: EmbeddingResponse | None) -> None:
272
+ if response is None:
273
+ self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1))
274
+ return
275
+ if response.usage is not None and response.usage.prompt_tokens is not None:
276
+ self._usage_stats.extend(
277
+ token_usage=TokenUsageStats(
278
+ input_tokens=response.usage.prompt_tokens,
279
+ output_tokens=0,
223
280
  ),
224
281
  request_usage=RequestUsageStats(successful_requests=1, failed_requests=0),
225
282
  )
@@ -5,7 +5,6 @@ from __future__ import annotations
5
5
 
6
6
  import random
7
7
  import threading
8
- from typing import Optional, Union
9
8
 
10
9
  import httpx
11
10
  import litellm
@@ -90,7 +89,7 @@ class CustomRouter(Router):
90
89
  self._initial_retry_after_s = initial_retry_after_s
91
90
  self._jitter_pct = jitter_pct
92
91
 
93
- def _extract_retry_delay_from_headers(self, e: Exception) -> Optional[Union[int, float]]:
92
+ def _extract_retry_delay_from_headers(self, e: Exception) -> int | float | None:
94
93
  """
95
94
  Most of this code logic was extracted directly from the parent
96
95
  `Router`'s `_time_to_sleep_before_retry` function. Our override
@@ -99,7 +98,7 @@ class CustomRouter(Router):
99
98
  return this info, we'll simply use that retry value returned here.
100
99
  """
101
100
 
102
- response_headers: Optional[httpx.Headers] = None
101
+ response_headers: httpx.Headers | None = None
103
102
  if hasattr(e, "response") and hasattr(e.response, "headers"): # type: ignore
104
103
  response_headers = e.response.headers # type: ignore
105
104
  if hasattr(e, "litellm_response_headers"):
@@ -119,9 +118,9 @@ class CustomRouter(Router):
119
118
  e: Exception,
120
119
  remaining_retries: int,
121
120
  num_retries: int,
122
- healthy_deployments: Optional[list] = None,
123
- all_deployments: Optional[list] = None,
124
- ) -> Union[int, float]:
121
+ healthy_deployments: list | None = None,
122
+ all_deployments: list | None = None,
123
+ ) -> int | float:
125
124
  """
126
125
  Implements exponential backoff for retries.
127
126
 
@@ -1,8 +1,6 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
- from typing import Optional
5
-
6
4
 
7
5
  class ParserException(Exception):
8
6
  """Identifies errors resulting from generic parser errors.
@@ -12,7 +10,7 @@ class ParserException(Exception):
12
10
  attempted to parse.
13
11
  """
14
12
 
15
- source: Optional[str]
13
+ source: str | None
16
14
 
17
15
  @staticmethod
18
16
  def _log_format(source: str) -> str:
@@ -24,7 +22,7 @@ class ParserException(Exception):
24
22
  # return f"<source>{source}</source>"
25
23
  return ""
26
24
 
27
- def __init__(self, msg: Optional[str] = None, source: Optional[str] = None):
25
+ def __init__(self, msg: str | None = None, source: str | None = None):
28
26
  msg = "" if msg is None else msg.strip()
29
27
 
30
28
  if source is not None:
@@ -2,7 +2,6 @@
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
4
  from functools import reduce
5
- from typing import Optional
6
5
 
7
6
  import marko
8
7
  from lxml import etree
@@ -105,8 +104,8 @@ class LLMResponseParser:
105
104
 
106
105
  def __init__(
107
106
  self,
108
- tag_parsers: Optional[dict[str, TagParser]] = None,
109
- postprocessors: Optional[list[PostProcessor]] = None,
107
+ tag_parsers: dict[str, TagParser] | None = None,
108
+ postprocessors: list[PostProcessor] | None = None,
110
109
  ):
111
110
  """
112
111
  Initializes the LLMResponseParser with optional tag parsers and post-processors.
@@ -1,7 +1,6 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
- from typing import Optional, Type
5
4
 
6
5
  import json_repair
7
6
  from pydantic import BaseModel, ValidationError
@@ -60,12 +59,12 @@ def deserialize_json_code(
60
59
 
61
60
 
62
61
  class RealizePydanticTypes:
63
- types: list[Type[BaseModel]]
62
+ types: list[type[BaseModel]]
64
63
 
65
- def __init__(self, types: list[Type[BaseModel]]):
64
+ def __init__(self, types: list[type[BaseModel]]):
66
65
  self.types = types
67
66
 
68
- def _fit_types(self, obj: dict) -> Optional[BaseModel]:
67
+ def _fit_types(self, obj: dict) -> BaseModel | None:
69
68
  final_obj = None
70
69
 
71
70
  for t in self.types: