data-designer 0.1.5__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. data_designer/_version.py +2 -2
  2. data_designer/cli/README.md +15 -1
  3. data_designer/cli/commands/download.py +56 -0
  4. data_designer/cli/commands/list.py +4 -18
  5. data_designer/cli/controllers/__init__.py +2 -1
  6. data_designer/cli/controllers/download_controller.py +217 -0
  7. data_designer/cli/controllers/model_controller.py +4 -3
  8. data_designer/cli/forms/field.py +65 -19
  9. data_designer/cli/forms/model_builder.py +251 -44
  10. data_designer/cli/main.py +11 -1
  11. data_designer/cli/repositories/persona_repository.py +88 -0
  12. data_designer/cli/services/__init__.py +2 -1
  13. data_designer/cli/services/download_service.py +97 -0
  14. data_designer/cli/ui.py +131 -0
  15. data_designer/cli/utils.py +34 -0
  16. data_designer/config/analysis/__init__.py +2 -0
  17. data_designer/config/analysis/column_profilers.py +75 -7
  18. data_designer/config/analysis/column_statistics.py +192 -48
  19. data_designer/config/analysis/dataset_profiler.py +23 -5
  20. data_designer/config/analysis/utils/reporting.py +3 -3
  21. data_designer/config/base.py +3 -3
  22. data_designer/config/column_configs.py +27 -6
  23. data_designer/config/column_types.py +24 -17
  24. data_designer/config/config_builder.py +34 -26
  25. data_designer/config/data_designer_config.py +7 -7
  26. data_designer/config/datastore.py +6 -6
  27. data_designer/config/default_model_settings.py +27 -34
  28. data_designer/config/exports.py +8 -0
  29. data_designer/config/models.py +155 -29
  30. data_designer/config/preview_results.py +6 -8
  31. data_designer/config/processors.py +63 -2
  32. data_designer/config/sampler_constraints.py +1 -2
  33. data_designer/config/sampler_params.py +31 -31
  34. data_designer/config/seed.py +1 -2
  35. data_designer/config/utils/code_lang.py +4 -5
  36. data_designer/config/utils/constants.py +31 -8
  37. data_designer/config/utils/io_helpers.py +5 -5
  38. data_designer/config/utils/misc.py +1 -4
  39. data_designer/config/utils/numerical_helpers.py +2 -2
  40. data_designer/config/utils/type_helpers.py +3 -3
  41. data_designer/config/utils/validation.py +7 -8
  42. data_designer/config/utils/visualization.py +32 -17
  43. data_designer/config/validator_params.py +4 -8
  44. data_designer/engine/analysis/column_profilers/base.py +0 -7
  45. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +2 -3
  46. data_designer/engine/analysis/column_statistics.py +16 -16
  47. data_designer/engine/analysis/dataset_profiler.py +25 -4
  48. data_designer/engine/analysis/utils/column_statistics_calculations.py +71 -49
  49. data_designer/engine/analysis/utils/judge_score_processing.py +5 -5
  50. data_designer/engine/column_generators/generators/base.py +34 -0
  51. data_designer/engine/column_generators/generators/embedding.py +45 -0
  52. data_designer/engine/column_generators/generators/{llm_generators.py → llm_completion.py} +17 -49
  53. data_designer/engine/column_generators/registry.py +4 -2
  54. data_designer/engine/column_generators/utils/judge_score_factory.py +5 -6
  55. data_designer/engine/configurable_task.py +2 -2
  56. data_designer/engine/dataset_builders/artifact_storage.py +1 -2
  57. data_designer/engine/dataset_builders/column_wise_builder.py +11 -10
  58. data_designer/engine/dataset_builders/utils/concurrency.py +6 -6
  59. data_designer/engine/models/facade.py +66 -9
  60. data_designer/engine/models/litellm_overrides.py +5 -6
  61. data_designer/engine/models/parsers/errors.py +2 -4
  62. data_designer/engine/models/parsers/parser.py +2 -3
  63. data_designer/engine/models/parsers/postprocessors.py +3 -4
  64. data_designer/engine/models/parsers/types.py +4 -4
  65. data_designer/engine/models/registry.py +20 -11
  66. data_designer/engine/models/usage.py +7 -9
  67. data_designer/engine/processing/ginja/ast.py +1 -2
  68. data_designer/engine/processing/utils.py +40 -2
  69. data_designer/engine/registry/base.py +12 -12
  70. data_designer/engine/sampling_gen/constraints.py +1 -2
  71. data_designer/engine/sampling_gen/data_sources/base.py +14 -14
  72. data_designer/engine/sampling_gen/entities/phone_number.py +1 -2
  73. data_designer/engine/sampling_gen/people_gen.py +3 -7
  74. data_designer/engine/validators/base.py +2 -2
  75. data_designer/logging.py +2 -2
  76. data_designer/plugin_manager.py +3 -3
  77. data_designer/plugins/plugin.py +3 -3
  78. data_designer/plugins/registry.py +2 -2
  79. {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/METADATA +1 -1
  80. {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/RECORD +83 -77
  81. {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/WHEEL +0 -0
  82. {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/entry_points.txt +0 -0
  83. {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/licenses/LICENSE +0 -0
data_designer/engine/analysis/utils/column_statistics_calculations.py

@@ -20,10 +20,8 @@ from data_designer.config.analysis.column_statistics import (
 )
 from data_designer.config.column_configs import (
     LLMTextColumnConfig,
-    SingleColumnConfig,
-    ValidationColumnConfig,
 )
-from data_designer.engine.column_generators.generators.llm_generators import (
+from data_designer.engine.column_generators.utils.prompt_renderer import (
     PromptType,
     RecordBasedPromptRenderer,
     create_response_recipe,
@@ -39,41 +37,54 @@ logger = logging.getLogger(__name__)
 
 
 def calculate_column_distribution(
-    column_config: SingleColumnConfig, df: pd.DataFrame, distribution_type: ColumnDistributionType
+    column_name: str, df: pd.DataFrame, distribution_type: ColumnDistributionType
 ) -> dict[str, CategoricalDistribution | NumericalDistribution | MissingValue | None]:
     distribution_type = ColumnDistributionType(distribution_type)
     try:
         if distribution_type == ColumnDistributionType.CATEGORICAL:
             return {
                 "distribution_type": ColumnDistributionType.CATEGORICAL,
-                "distribution": CategoricalDistribution.from_series(df[column_config.name]),
+                "distribution": CategoricalDistribution.from_series(df[column_name]),
             }
 
         if distribution_type == ColumnDistributionType.NUMERICAL:
             return {
                 "distribution_type": ColumnDistributionType.NUMERICAL,
-                "distribution": NumericalDistribution.from_series(df[column_config.name]),
+                "distribution": NumericalDistribution.from_series(df[column_name]),
             }
     except Exception as e:
-        logger.warning(f"{WARNING_PREFIX} failed to calculate column distribution for '{column_config.name}' {e}")
+        logger.warning(f"{WARNING_PREFIX} failed to calculate column distribution for '{column_name}' {e}")
     return {
         "distribution_type": ColumnDistributionType.UNKNOWN,
        "distribution": MissingValue.CALCULATION_FAILED,
    }
 
 
-def calculate_general_column_info(column_config: SingleColumnConfig, df: pd.DataFrame) -> dict[str, Any]:
+def calculate_general_column_info(column_name: str, df: pd.DataFrame) -> dict[str, Any]:
     try:
-        _df = pd.DataFrame(df[column_config.name].apply(ensure_hashable))
+        _df = pd.DataFrame(df[column_name].apply(ensure_hashable))
+
+        if has_pyarrow_backend(df):
+            pyarrow_dtype = str(df[column_name].dtype.pyarrow_dtype)
+            simple_dtype = convert_pyarrow_dtype_to_simple_dtype(df[column_name].dtype.pyarrow_dtype)
+        else:
+            # We do not log a warning at the column-level because it would be too noisy.
+            # However, there is a logged warning at the dataset-profiler level.
+            try:
+                simple_dtype = get_column_data_type_from_first_non_null_value(column_name, df)
+            except Exception:
+                simple_dtype = MissingValue.CALCULATION_FAILED
+            pyarrow_dtype = "n/a"
+
         return {
-            "pyarrow_dtype": str(df[column_config.name].dtype.pyarrow_dtype),
-            "simple_dtype": convert_pyarrow_dtype_to_simple_dtype(df[column_config.name].dtype.pyarrow_dtype),
-            "num_records": len(_df[column_config.name]),
-            "num_null": _df[column_config.name].isnull().sum(),
-            "num_unique": _df[column_config.name].nunique(),
+            "pyarrow_dtype": pyarrow_dtype,
+            "simple_dtype": simple_dtype,
+            "num_records": len(_df[column_name]),
+            "num_null": _df[column_name].isnull().sum(),
+            "num_unique": _df[column_name].nunique(),
         }
     except Exception as e:
-        logger.warning(f"{WARNING_PREFIX} failed to calculate general column info for '{column_config.name}': {e}")
+        logger.warning(f"{WARNING_PREFIX} failed to calculate general column info for '{column_name}': {e}")
         return {
             "pyarrow_dtype": MissingValue.CALCULATION_FAILED,
             "simple_dtype": MissingValue.CALCULATION_FAILED,
@@ -83,7 +94,7 @@ def calculate_general_column_info(column_config: SingleColumnConfig, df: pd.DataFrame) -> dict[str, Any]:
         }
 
 
-def calculate_prompt_token_stats(
+def calculate_input_token_stats(
     column_config: LLMTextColumnConfig, df: pd.DataFrame
 ) -> dict[str, float | MissingValue]:
     try:
@@ -100,22 +111,20 @@ def calculate_prompt_token_stats(
             concatenated_prompt = str(system_prompt + "\n\n" + prompt)
             num_tokens.append(len(TOKENIZER.encode(concatenated_prompt, disallowed_special=())))
     except Exception as e:
-        logger.warning(
-            f"{WARNING_PREFIX} failed to calculate prompt token stats for column {column_config.name!r}: {e}"
-        )
+        logger.warning(f"{WARNING_PREFIX} failed to calculate input token stats for column {column_config.name!r}: {e}")
         return {
-            "prompt_tokens_mean": MissingValue.CALCULATION_FAILED,
-            "prompt_tokens_median": MissingValue.CALCULATION_FAILED,
-            "prompt_tokens_stddev": MissingValue.CALCULATION_FAILED,
+            "input_tokens_mean": MissingValue.CALCULATION_FAILED,
+            "input_tokens_median": MissingValue.CALCULATION_FAILED,
+            "input_tokens_stddev": MissingValue.CALCULATION_FAILED,
         }
     return {
-        "prompt_tokens_mean": np.mean(num_tokens),
-        "prompt_tokens_median": np.median(num_tokens),
-        "prompt_tokens_stddev": np.std(num_tokens),
+        "input_tokens_mean": np.mean(num_tokens),
+        "input_tokens_median": np.median(num_tokens),
+        "input_tokens_stddev": np.std(num_tokens),
     }
 
 
-def calculate_completion_token_stats(
+def calculate_output_token_stats(
     column_config: LLMTextColumnConfig, df: pd.DataFrame
 ) -> dict[str, float | MissingValue]:
     try:
@@ -123,34 +132,32 @@ def calculate_completion_token_stats(
             lambda value: len(TOKENIZER.encode(str(value), disallowed_special=()))
         )
         return {
-            "completion_tokens_mean": tokens_per_record.mean(),
-            "completion_tokens_median": tokens_per_record.median(),
-            "completion_tokens_stddev": tokens_per_record.std(),
+            "output_tokens_mean": tokens_per_record.mean(),
+            "output_tokens_median": tokens_per_record.median(),
+            "output_tokens_stddev": tokens_per_record.std(),
         }
     except Exception as e:
-        logger.warning(
-            f"{WARNING_PREFIX} failed to calculate completion token stats for column {column_config.name}: {e}"
-        )
+        logger.warning(f"{WARNING_PREFIX} failed to calculate output token stats for column {column_config.name}: {e}")
        return {
-            "completion_tokens_mean": MissingValue.CALCULATION_FAILED,
-            "completion_tokens_median": MissingValue.CALCULATION_FAILED,
-            "completion_tokens_stddev": MissingValue.CALCULATION_FAILED,
+            "output_tokens_mean": MissingValue.CALCULATION_FAILED,
+            "output_tokens_median": MissingValue.CALCULATION_FAILED,
+            "output_tokens_stddev": MissingValue.CALCULATION_FAILED,
         }
 
 
 def calculate_token_stats(column_config: LLMTextColumnConfig, df: pd.DataFrame) -> dict[str, float | MissingValue]:
     return {
-        **calculate_prompt_token_stats(column_config, df),
-        **calculate_completion_token_stats(column_config, df),
+        **calculate_input_token_stats(column_config, df),
+        **calculate_output_token_stats(column_config, df),
     }
 
 
-def calculate_validation_column_info(column_config: ValidationColumnConfig, df: pd.DataFrame) -> dict[str, Any]:
+def calculate_validation_column_info(column_name: str, df: pd.DataFrame) -> dict[str, Any]:
     try:
-        return {"num_valid_records": df[column_config.name].apply(lambda x: ensure_boolean(x["is_valid"])).sum()}
+        return {"num_valid_records": df[column_name].apply(lambda x: ensure_boolean(x["is_valid"])).sum()}
     except Exception as e:
         logger.warning(
-            f"{WARNING_PREFIX} failed to calculate code validation column info for column {column_config.name}: {e}"
+            f"{WARNING_PREFIX} failed to calculate code validation column info for column {column_name}: {e}"
         )
         return {"num_valid_records": MissingValue.CALCULATION_FAILED}
 
@@ -160,22 +167,33 @@ def convert_pyarrow_dtype_to_simple_dtype(pyarrow_dtype: pa.DataType) -> str:
         return f"list[{convert_pyarrow_dtype_to_simple_dtype(pyarrow_dtype.value_type)}]"
     if isinstance(pyarrow_dtype, pa.StructType):
         return "dict"
-    pyarrow_dtype_str = str(pyarrow_dtype)
-    if "int" in pyarrow_dtype_str:
+    return convert_to_simple_dtype(str(pyarrow_dtype))
+
+
+def convert_to_simple_dtype(dtype: str) -> str:
+    if "int" in dtype:
         return "int"
-    if "double" in pyarrow_dtype_str:
+    if "double" in dtype:
         return "float"
-    if "float" in pyarrow_dtype_str:
+    if "float" in dtype:
         return "float"
-    if "string" in pyarrow_dtype_str:
+    if "str" in dtype:
         return "string"
-    if "timestamp" in pyarrow_dtype_str:
+    if "timestamp" in dtype:
         return "timestamp"
-    if "time" in pyarrow_dtype_str:
+    if "time" in dtype:
         return "time"
-    if "date" in pyarrow_dtype_str:
+    if "date" in dtype:
         return "date"
-    return pyarrow_dtype_str
+    return dtype
+
+
+def get_column_data_type_from_first_non_null_value(column_name: str, df: pd.DataFrame) -> str:
+    df_no_nulls = df[column_name].dropna()
+    if len(df_no_nulls) == 0:
+        return MissingValue.CALCULATION_FAILED
+    dtype = type(df_no_nulls.iloc[0]).__name__
+    return convert_to_simple_dtype(dtype)
 
 
 def ensure_hashable(x: Any) -> str:
@@ -207,3 +225,7 @@ def ensure_boolean(v: bool | str | int | None) -> bool:
     if v is None:
         return False
     raise ValueError(f"Invalid boolean value: {v}")
+
+
+def has_pyarrow_backend(df: pd.DataFrame) -> bool:
+    return all(isinstance(dtype, pd.ArrowDtype) for dtype in df.dtypes)
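
Note on the dtype fallback above: the profiler no longer assumes an Arrow-backed DataFrame; when any column is NumPy-backed, it infers a simple dtype from the first non-null value instead. A minimal, self-contained sketch of that behavior (reimplementing the two helpers; the printed results are assumptions based on the diff, not captured package output):

import pandas as pd

def has_pyarrow_backend(df: pd.DataFrame) -> bool:
    # True only when every column uses the pandas ArrowDtype backend.
    return all(isinstance(dtype, pd.ArrowDtype) for dtype in df.dtypes)

def convert_to_simple_dtype(dtype: str) -> str:
    # Substring matching lets one function handle both pyarrow dtype strings
    # ("int64", "string") and Python type names ("int", "str").
    for needle, simple in [
        ("int", "int"), ("double", "float"), ("float", "float"),
        ("str", "string"), ("timestamp", "timestamp"),
        ("time", "time"), ("date", "date"),
    ]:
        if needle in dtype:
            return simple
    return dtype

df = pd.DataFrame({"a": [1, 2, None]})                     # NumPy-backed: None becomes NaN
print(has_pyarrow_backend(df))                             # False -> fall back to value types
print(convert_to_simple_dtype(type(df["a"].dropna().iloc[0]).__name__))  # "float"

df_arrow = df.convert_dtypes(dtype_backend="pyarrow")      # Arrow-backed
print(has_pyarrow_backend(df_arrow))                       # True
print(convert_to_simple_dtype(str(df_arrow["a"].dtype.pyarrow_dtype)))   # "int"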
data_designer/engine/analysis/utils/judge_score_processing.py

@@ -3,7 +3,7 @@
 
 import logging
 from collections import defaultdict
-from typing import Any, Optional, Union
+from typing import Any
 
 import pandas as pd
 
@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
 
 def extract_judge_score_distributions(
     column_config: LLMJudgeColumnConfig, df: pd.DataFrame
-) -> Union[JudgeScoreDistributions, MissingValue]:
+) -> JudgeScoreDistributions | MissingValue:
     scores = defaultdict(list)
     reasoning = defaultdict(list)
 
@@ -32,7 +32,7 @@ def extract_judge_score_distributions(
 
     for score in column_config.scores:
         is_numerical = True
-        name = score.name.lower()
+        name = score.name
         for results in df[column_config.name]:
             try:
                 score = results[name].get("score", None)
@@ -79,10 +79,10 @@ def extract_judge_score_distributions(
 
 
 def sample_scores_and_reasoning(
-    scores: list[Union[int, str]],
+    scores: list[int | str],
     reasoning: list[str],
     num_samples: int,
-    random_seed: Optional[int] = None,
+    random_seed: int | None = None,
 ) -> list[JudgeScoreSample]:
     if len(scores) != len(reasoning):
         raise ValueError("scores and reasoning must have the same length")
data_designer/engine/column_generators/generators/base.py

@@ -1,13 +1,20 @@
 # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 
+import functools
+import logging
 from abc import ABC, abstractmethod
 from typing import overload
 
 import pandas as pd
 
+from data_designer.config.column_types import COLUMN_TYPE_EMOJI_MAP
+from data_designer.config.models import BaseInferenceParams, ModelConfig
 from data_designer.config.utils.type_helpers import StrEnum
 from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata, DataT, TaskConfigT
+from data_designer.engine.models.facade import ModelFacade
+
+logger = logging.getLogger(__name__)
 
 
 class GenerationStrategy(StrEnum):
@@ -59,3 +66,30 @@ class FromScratchColumnGenerator(ColumnGenerator[TaskConfigT], ABC):
 
     @abstractmethod
     def generate_from_scratch(self, num_records: int) -> pd.DataFrame: ...
+
+
+class WithModelGeneration:
+    @functools.cached_property
+    def model(self) -> ModelFacade:
+        return self.resource_provider.model_registry.get_model(model_alias=self.config.model_alias)
+
+    @functools.cached_property
+    def model_config(self) -> ModelConfig:
+        return self.resource_provider.model_registry.get_model_config(model_alias=self.config.model_alias)
+
+    @functools.cached_property
+    def inference_parameters(self) -> BaseInferenceParams:
+        return self.model_config.inference_parameters
+
+    def log_pre_generation(self) -> None:
+        emoji = COLUMN_TYPE_EMOJI_MAP[self.config.column_type]
+        logger.info(f"{emoji} Preparing {self.config.column_type} column generation")
+        logger.info(f" |-- column name: {self.config.name!r}")
+        logger.info(f" |-- model config:\n{self.model_config.model_dump_json(indent=4)}")
+        if self.model_config.provider is None:
+            logger.info(f" |-- default model provider: {self._get_provider_name()!r}")
+
+    def _get_provider_name(self) -> str:
+        model_alias = self.model_config.alias
+        provider = self.resource_provider.model_registry.get_model_provider(model_alias=model_alias)
+        return provider.name
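
WithModelGeneration is a mixin: it assumes the concrete ColumnGenerator host supplies resource_provider and config, and functools.cached_property makes each registry lookup happen once per generator instance. A toy illustration of that pattern (all names here are hypothetical, standing in for the registry and config plumbing):

import functools

class ExpensiveClient:
    def __init__(self, alias: str) -> None:
        print(f"connecting to {alias}...")  # runs once per Generator instance
        self.alias = alias

class WithClient:
    # Mixin: relies on the host class providing `self.alias`.
    @functools.cached_property
    def client(self) -> ExpensiveClient:
        return ExpensiveClient(self.alias)

class Generator(WithClient):
    def __init__(self, alias: str) -> None:
        self.alias = alias

g = Generator("gpt-large")
g.client  # prints "connecting to gpt-large..."
g.client  # cached: no second connection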
data_designer/engine/column_generators/generators/embedding.py (new file)

@@ -0,0 +1,45 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+
+from pydantic import BaseModel, computed_field
+
+from data_designer.config.column_configs import EmbeddingColumnConfig
+from data_designer.engine.column_generators.generators.base import (
+    ColumnGenerator,
+    GenerationStrategy,
+    GeneratorMetadata,
+    WithModelGeneration,
+)
+from data_designer.engine.processing.utils import deserialize_json_values, parse_list_string
+from data_designer.engine.resources.resource_provider import ResourceType
+
+
+class EmbeddingGenerationResult(BaseModel):
+    embeddings: list[list[float]]
+
+    @computed_field
+    def num_embeddings(self) -> int:
+        return len(self.embeddings)
+
+    @computed_field
+    def dimension(self) -> int:
+        return len(self.embeddings[0]) if len(self.embeddings) > 0 else 0
+
+
+class EmbeddingCellGenerator(WithModelGeneration, ColumnGenerator[EmbeddingColumnConfig]):
+    @staticmethod
+    def metadata() -> GeneratorMetadata:
+        return GeneratorMetadata(
+            name="embedding_cell_generator",
+            description="Generate embeddings for a text column.",
+            generation_strategy=GenerationStrategy.CELL_BY_CELL,
+            required_resources=[ResourceType.MODEL_REGISTRY],
+        )
+
+    def generate(self, data: dict) -> dict:
+        deserialized_record = deserialize_json_values(data)
+        input_texts = parse_list_string(deserialized_record[self.config.target_column])
+        embeddings = self.model.generate_text_embeddings(input_texts=input_texts)
+        data[self.config.name] = EmbeddingGenerationResult(embeddings=embeddings).model_dump(mode="json")
+        return data
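
Because num_embeddings and dimension are pydantic computed fields, they are serialized along with the raw vectors, so each generated cell carries its own shape metadata. A self-contained check of that mechanism (pydantic v2):

from pydantic import BaseModel, computed_field

class EmbeddingGenerationResult(BaseModel):
    embeddings: list[list[float]]

    @computed_field
    def num_embeddings(self) -> int:
        return len(self.embeddings)

    @computed_field
    def dimension(self) -> int:
        return len(self.embeddings[0]) if len(self.embeddings) > 0 else 0

result = EmbeddingGenerationResult(embeddings=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
# Computed fields appear in the dump alongside the data:
print(result.model_dump(mode="json"))
# {'embeddings': [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], 'num_embeddings': 2, 'dimension': 3}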
data_designer/engine/column_generators/generators/{llm_generators.py → llm_completion.py}

@@ -10,43 +10,41 @@ from data_designer.config.column_configs import (
     LLMStructuredColumnConfig,
     LLMTextColumnConfig,
 )
-from data_designer.config.column_types import COLUMN_TYPE_EMOJI_MAP
-from data_designer.config.models import InferenceParameters, ModelConfig
 from data_designer.config.utils.constants import REASONING_TRACE_COLUMN_POSTFIX
 from data_designer.engine.column_generators.generators.base import (
     ColumnGenerator,
     GenerationStrategy,
     GeneratorMetadata,
+    WithModelGeneration,
 )
 from data_designer.engine.column_generators.utils.prompt_renderer import (
     PromptType,
     RecordBasedPromptRenderer,
     create_response_recipe,
 )
-from data_designer.engine.models.facade import ModelFacade
 from data_designer.engine.models.recipes.base import ResponseRecipe
 from data_designer.engine.processing.utils import deserialize_json_values
 from data_designer.engine.resources.resource_provider import ResourceType
 
-DEFAULT_MAX_CONVERSATION_RESTARTS = 5
-DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS = 0
+logger = logging.getLogger(__name__)
 
 
-logger = logging.getLogger(__name__)
+DEFAULT_MAX_CONVERSATION_RESTARTS = 5
+DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS = 0
 
 
-class WithLLMGeneration:
+class WithChatCompletionGeneration(WithModelGeneration):
     @functools.cached_property
-    def model(self) -> ModelFacade:
-        return self.resource_provider.model_registry.get_model(model_alias=self.config.model_alias)
+    def response_recipe(self) -> ResponseRecipe:
+        return create_response_recipe(self.config, self.model_config)
 
-    @functools.cached_property
-    def model_config(self) -> ModelConfig:
-        return self.resource_provider.model_registry.get_model_config(model_alias=self.config.model_alias)
+    @property
+    def max_conversation_correction_steps(self) -> int:
+        return DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS
 
-    @functools.cached_property
-    def inference_parameters(self) -> InferenceParameters:
-        return self.model_config.inference_parameters
+    @property
+    def max_conversation_restarts(self) -> int:
+        return DEFAULT_MAX_CONVERSATION_RESTARTS
 
     @functools.cached_property
     def prompt_renderer(self) -> RecordBasedPromptRenderer:
@@ -59,18 +57,6 @@ class WithLLMGeneration:
             },
         )
 
-    @functools.cached_property
-    def response_recipe(self) -> ResponseRecipe:
-        return create_response_recipe(self.config, self.model_config)
-
-    @property
-    def max_conversation_correction_steps(self) -> int:
-        return DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS
-
-    @property
-    def max_conversation_restarts(self) -> int:
-        return DEFAULT_MAX_CONVERSATION_RESTARTS
-
     def generate(self, data: dict) -> dict:
         deserialized_record = deserialize_json_values(data)
 
@@ -96,7 +82,6 @@ class WithLLMGeneration:
             max_correction_steps=self.max_conversation_correction_steps,
             max_conversation_restarts=self.max_conversation_restarts,
             purpose=f"running generation for column '{self.config.name}'",
-            **self.inference_parameters.generate_kwargs,
         )
 
         data[self.config.name] = deserialize_json_values(self.response_recipe.serialize_output(response))
@@ -106,21 +91,8 @@ class WithLLMGeneration:
 
         return data
 
-    def log_pre_generation(self) -> None:
-        emoji = COLUMN_TYPE_EMOJI_MAP[self.config.column_type]
-        logger.info(f"{emoji} Preparing {self.config.column_type} column generation")
-        logger.info(f" |-- column name: {self.config.name!r}")
-        logger.info(f" |-- model config:\n{self.model_config.model_dump_json(indent=4)}")
-        if self.model_config.provider is None:
-            logger.info(f" |-- default model provider: {self._get_provider_name()!r}")
-
-    def _get_provider_name(self) -> str:
-        model_alias = self.model_config.alias
-        provider = self.resource_provider.model_registry.get_model_provider(model_alias=model_alias)
-        return provider.name
-
 
-class LLMTextCellGenerator(WithLLMGeneration, ColumnGenerator[LLMTextColumnConfig]):
+class LLMTextCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMTextColumnConfig]):
     @staticmethod
     def metadata() -> GeneratorMetadata:
         return GeneratorMetadata(
@@ -131,7 +103,7 @@ class LLMTextCellGenerator(WithLLMGeneration, ColumnGenerator[LLMTextColumnConfig]):
         )
 
 
-class LLMCodeCellGenerator(WithLLMGeneration, ColumnGenerator[LLMCodeColumnConfig]):
+class LLMCodeCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMCodeColumnConfig]):
     @staticmethod
     def metadata() -> GeneratorMetadata:
         return GeneratorMetadata(
@@ -142,7 +114,7 @@ class LLMCodeCellGenerator(WithLLMGeneration, ColumnGenerator[LLMCodeColumnConfig]):
         )
 
 
-class LLMStructuredCellGenerator(WithLLMGeneration, ColumnGenerator[LLMStructuredColumnConfig]):
+class LLMStructuredCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMStructuredColumnConfig]):
     @staticmethod
     def metadata() -> GeneratorMetadata:
         return GeneratorMetadata(
@@ -153,7 +125,7 @@ class LLMStructuredCellGenerator(WithLLMGeneration, ColumnGenerator[LLMStructuredColumnConfig]):
         )
 
 
-class LLMJudgeCellGenerator(WithLLMGeneration, ColumnGenerator[LLMJudgeColumnConfig]):
+class LLMJudgeCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMJudgeColumnConfig]):
     @staticmethod
     def metadata() -> GeneratorMetadata:
         return GeneratorMetadata(
@@ -163,10 +135,6 @@ class LLMJudgeCellGenerator(WithLLMGeneration, ColumnGenerator[LLMJudgeColumnConfig]):
             required_resources=[ResourceType.MODEL_REGISTRY],
         )
 
-    @property
-    def max_conversation_correction_steps(self) -> int:
-        return DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS
-
     @property
     def max_conversation_restarts(self) -> int:
         return 2 * DEFAULT_MAX_CONVERSATION_RESTARTS
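
Retry budgets are exposed as read-only properties on the mixin, so individual generators can tune them per column type; LLMJudgeCellGenerator doubles the restart budget above. A stripped-down illustration of that override pattern (names copied from the diff, bodies simplified):

DEFAULT_MAX_CONVERSATION_RESTARTS = 5

class WithChatCompletionGeneration:
    @property
    def max_conversation_restarts(self) -> int:
        return DEFAULT_MAX_CONVERSATION_RESTARTS

class JudgeGenerator(WithChatCompletionGeneration):
    # Judges retry twice as hard before giving up, as in the diff above.
    @property
    def max_conversation_restarts(self) -> int:
        return 2 * DEFAULT_MAX_CONVERSATION_RESTARTS

print(WithChatCompletionGeneration().max_conversation_restarts)  # 5
print(JudgeGenerator().max_conversation_restarts)                # 10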
data_designer/engine/column_generators/registry.py

@@ -3,6 +3,7 @@
 
 from data_designer.config.base import ConfigBase
 from data_designer.config.column_configs import (
+    EmbeddingColumnConfig,
     ExpressionColumnConfig,
     LLMCodeColumnConfig,
     LLMJudgeColumnConfig,
@@ -12,8 +13,9 @@ from data_designer.config.column_configs import (
 )
 from data_designer.config.column_types import DataDesignerColumnType
 from data_designer.engine.column_generators.generators.base import ColumnGenerator
+from data_designer.engine.column_generators.generators.embedding import EmbeddingCellGenerator
 from data_designer.engine.column_generators.generators.expression import ExpressionColumnGenerator
-from data_designer.engine.column_generators.generators.llm_generators import (
+from data_designer.engine.column_generators.generators.llm_completion import (
     LLMCodeCellGenerator,
     LLMJudgeCellGenerator,
     LLMStructuredCellGenerator,
@@ -40,11 +42,11 @@ def create_default_column_generator_registry(with_plugins: bool = True) -> ColumnGeneratorRegistry:
     registry.register(DataDesignerColumnType.LLM_CODE, LLMCodeCellGenerator, LLMCodeColumnConfig)
     registry.register(DataDesignerColumnType.LLM_JUDGE, LLMJudgeCellGenerator, LLMJudgeColumnConfig)
     registry.register(DataDesignerColumnType.EXPRESSION, ExpressionColumnGenerator, ExpressionColumnConfig)
+    registry.register(DataDesignerColumnType.EMBEDDING, EmbeddingCellGenerator, EmbeddingColumnConfig)
     registry.register(DataDesignerColumnType.SAMPLER, SamplerColumnGenerator, SamplerMultiColumnConfig)
     registry.register(DataDesignerColumnType.SEED_DATASET, SeedDatasetColumnGenerator, SeedDatasetMultiColumnConfig)
     registry.register(DataDesignerColumnType.VALIDATION, ValidationColumnGenerator, ValidationColumnConfig)
     registry.register(DataDesignerColumnType.LLM_STRUCTURED, LLMStructuredCellGenerator, LLMStructuredColumnConfig)
-
     if with_plugins:
         for plugin in PluginRegistry().get_plugins(PluginType.COLUMN_GENERATOR):
             registry.register(
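
Registering the new EMBEDDING column type follows the existing convention: each column type maps to a (generator class, config class) pair, so the builder can look up both at once. A hypothetical miniature of that registry shape (not the package's ColumnGeneratorRegistry):

from dataclasses import dataclass, field
from enum import Enum

class ColumnType(str, Enum):
    EMBEDDING = "embedding"

@dataclass
class Registry:
    _entries: dict[ColumnType, tuple[type, type]] = field(default_factory=dict)

    def register(self, column_type: ColumnType, generator_cls: type, config_cls: type) -> None:
        self._entries[column_type] = (generator_cls, config_cls)

    def lookup(self, column_type: ColumnType) -> tuple[type, type]:
        return self._entries[column_type]

class EmbeddingCellGenerator: ...
class EmbeddingColumnConfig: ...

registry = Registry()
registry.register(ColumnType.EMBEDDING, EmbeddingCellGenerator, EmbeddingColumnConfig)
print(registry.lookup(ColumnType.EMBEDDING))  # (generator class, config class)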
data_designer/engine/column_generators/utils/judge_score_factory.py

@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from enum import Enum
-from typing import Type
 
 from pydantic import BaseModel, ConfigDict, Field, create_model
 
@@ -19,7 +18,7 @@ class BaseJudgeResponse(BaseModel):
     reasoning: str = Field(..., description="Reasoning for the assigned score.")
 
 
-def _stringify_scoring(options: dict, enum_type: Type[Enum]) -> str:
+def _stringify_scoring(options: dict, enum_type: type[Enum]) -> str:
     """Convert score descriptions into a single text block."""
     list_block = "\n".join(
         [SCORING_FORMAT.format(score=score, description=description) for score, description in options.items()]
@@ -27,7 +26,7 @@ def _stringify_scoring(options: dict, enum_type: Type[Enum]) -> str:
     return SCORE_FIELD_DESCRIPTION_FORMAT.format(enum_name=enum_type.__name__, scoring=list_block)
 
 
-def create_judge_response_model(score: Score) -> Type[BaseJudgeResponse]:
+def create_judge_response_model(score: Score) -> type[BaseJudgeResponse]:
     """Create a JudgeResponse data type."""
     enum_members = {}
     for option in score.options.keys():
@@ -46,12 +45,12 @@ def create_judge_response_model(score: Score) -> Type[BaseJudgeResponse]:
 
 
 def create_judge_structured_output_model(
-    judge_responses: list[Type[BaseJudgeResponse]],
-) -> Type[BaseModel]:
+    judge_responses: list[type[BaseJudgeResponse]],
+) -> type[BaseModel]:
     """Create a JudgeStructuredOutput class dynamically."""
     return create_model(
         "JudgeStructuredOutput",
         __doc__=f"Response schema for scores with the following names: {[response.__name__ for response in judge_responses]}.",
         __base__=BaseModel,
-        **{response.__name__.lower(): (response, ...) for response in judge_responses},
+        **{response.__name__: (response, ...) for response in judge_responses},
     )
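
The last change above means the dynamically built JudgeStructuredOutput keeps each response class's original casing as its field name instead of lowercasing it (consistent with dropping `score.name.lower()` in judge_score_processing.py). A self-contained pydantic v2 demo of the create_model technique (the two score classes are made-up stand-ins):

from pydantic import BaseModel, Field, create_model

class Clarity(BaseModel):
    score: int = Field(..., description="Assigned score.")
    reasoning: str = Field(..., description="Reasoning for the assigned score.")

class Accuracy(BaseModel):
    score: int = Field(..., description="Assigned score.")
    reasoning: str = Field(..., description="Reasoning for the assigned score.")

JudgeStructuredOutput = create_model(
    "JudgeStructuredOutput",
    __base__=BaseModel,
    # Field names keep the original casing ("Clarity", not "clarity"),
    # matching the change in the diff above.
    **{response.__name__: (response, ...) for response in [Clarity, Accuracy]},
)

print(list(JudgeStructuredOutput.model_fields))  # ['Clarity', 'Accuracy']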
data_designer/engine/configurable_task.py

@@ -3,7 +3,7 @@
 
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Generic, Type, TypeVar, get_origin
+from typing import Generic, TypeVar, get_origin
 
 import pandas as pd
 
@@ -30,7 +30,7 @@ class ConfigurableTask(ABC, Generic[TaskConfigT]):
         self._initialize()
 
     @classmethod
-    def get_config_type(cls) -> Type[TaskConfigT]:
+    def get_config_type(cls) -> type[TaskConfigT]:
         for base in cls.__orig_bases__:
             if hasattr(base, "__args__") and len(base.__args__) == 1:
                 arg = base.__args__[0]
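
get_config_type walks cls.__orig_bases__ to recover the config class a subclass was parameterized with, so each task knows its own config type without redeclaring it. A minimal sketch of the same introspection trick (hypothetical names):

from typing import Generic, TypeVar

ConfigT = TypeVar("ConfigT")

class Task(Generic[ConfigT]):
    @classmethod
    def get_config_type(cls) -> type:
        # __orig_bases__ preserves the parameterized base, e.g. Task[MyConfig].
        for base in cls.__orig_bases__:
            if hasattr(base, "__args__") and len(base.__args__) == 1:
                return base.__args__[0]
        raise TypeError(f"{cls.__name__} does not parameterize Task")

class MyConfig: ...

class MyTask(Task[MyConfig]): ...

print(MyTask.get_config_type())  # <class '__main__.MyConfig'>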
data_designer/engine/dataset_builders/artifact_storage.py

@@ -7,7 +7,6 @@ import shutil
 from datetime import datetime
 from functools import cached_property
 from pathlib import Path
-from typing import Union
 
 import pandas as pd
 from pydantic import BaseModel, field_validator, model_validator
@@ -77,7 +76,7 @@ class ArtifactStorage(BaseModel):
         return self.base_dataset_path / self.processors_outputs_folder_name
 
     @field_validator("artifact_path")
-    def validate_artifact_path(cls, v: Union[Path, str]) -> Path:
+    def validate_artifact_path(cls, v: Path | str) -> Path:
         v = Path(v)
         if not v.is_dir():
             raise ArtifactStorageError("Artifact path must exist and be a directory")
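
The validator both coerces and checks: callers may pass Path | str, but the model always stores a Path and rejects non-directories. A minimal sketch of that coercing-validator pattern, with a plain ValueError standing in for ArtifactStorageError:

from pathlib import Path
from pydantic import BaseModel, field_validator

class Storage(BaseModel):
    artifact_path: Path

    @field_validator("artifact_path", mode="before")
    @classmethod
    def validate_artifact_path(cls, v: Path | str) -> Path:
        v = Path(v)  # accept str or Path, always store a Path
        if not v.is_dir():
            raise ValueError("Artifact path must exist and be a directory")
        return v

print(Storage(artifact_path=str(Path.cwd())).artifact_path)  # coerced to Path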
data_designer/engine/dataset_builders/column_wise_builder.py

@@ -10,15 +10,18 @@ from typing import Callable
 
 import pandas as pd
 
-from data_designer.config.column_types import ColumnConfigT, column_type_is_llm_generated
+from data_designer.config.column_types import ColumnConfigT, column_type_is_model_generated
 from data_designer.config.dataset_builders import BuildStage
 from data_designer.config.processors import (
     DropColumnsProcessorConfig,
     ProcessorConfig,
     ProcessorType,
 )
-from data_designer.engine.column_generators.generators.base import ColumnGenerator, GenerationStrategy
-from data_designer.engine.column_generators.generators.llm_generators import WithLLMGeneration
+from data_designer.engine.column_generators.generators.base import (
+    ColumnGenerator,
+    GenerationStrategy,
+    WithModelGeneration,
+)
 from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
 from data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError
 from data_designer.engine.dataset_builders.multi_column_configs import (
@@ -72,7 +75,7 @@ class ColumnWiseDatasetBuilder:
 
     @functools.cached_property
     def llm_generated_column_configs(self) -> list[ColumnConfigT]:
-        return [config for config in self.single_column_configs if column_type_is_llm_generated(config.column_type)]
+        return [config for config in self.single_column_configs if column_type_is_model_generated(config.column_type)]
 
     def build(
         self,
@@ -169,10 +172,8 @@ class ColumnWiseDatasetBuilder:
 
     def _run_cell_by_cell_generator(self, generator: ColumnGenerator) -> None:
        max_workers = MAX_CONCURRENCY_PER_NON_LLM_GENERATOR
-        if isinstance(generator, WithLLMGeneration):
+        if isinstance(generator, WithModelGeneration):
            max_workers = generator.inference_parameters.max_parallel_requests
-        elif hasattr(generator.config, "max_parallel_requests"):
-            max_workers = generator.config.max_parallel_requests
         self._fan_out_with_threads(generator, max_workers=max_workers)
 
     def _run_full_column_generator(self, generator: ColumnGenerator) -> None:
@@ -180,12 +181,12 @@ class ColumnWiseDatasetBuilder:
         self.batch_manager.update_records(df.to_dict(orient="records"))
 
     def _run_model_health_check_if_needed(self) -> bool:
-        if any(column_type_is_llm_generated(config.column_type) for config in self.single_column_configs):
+        if any(column_type_is_model_generated(config.column_type) for config in self.single_column_configs):
             self._resource_provider.model_registry.run_health_check(
-                set(config.model_alias for config in self.llm_generated_column_configs)
+                list(set(config.model_alias for config in self.llm_generated_column_configs))
             )
 
-    def _fan_out_with_threads(self, generator: WithLLMGeneration, max_workers: int) -> None:
+    def _fan_out_with_threads(self, generator: WithModelGeneration, max_workers: int) -> None:
         if generator.generation_strategy != GenerationStrategy.CELL_BY_CELL:
             raise DatasetGenerationError(
                 f"Generator {generator.metadata().name} is not a {GenerationStrategy.CELL_BY_CELL} "