data-designer 0.1.5-py3-none-any.whl → 0.2.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. data_designer/_version.py +2 -2
  2. data_designer/cli/README.md +15 -1
  3. data_designer/cli/commands/download.py +56 -0
  4. data_designer/cli/commands/list.py +4 -18
  5. data_designer/cli/controllers/__init__.py +2 -1
  6. data_designer/cli/controllers/download_controller.py +217 -0
  7. data_designer/cli/controllers/model_controller.py +4 -3
  8. data_designer/cli/forms/field.py +65 -19
  9. data_designer/cli/forms/model_builder.py +251 -44
  10. data_designer/cli/main.py +11 -1
  11. data_designer/cli/repositories/persona_repository.py +88 -0
  12. data_designer/cli/services/__init__.py +2 -1
  13. data_designer/cli/services/download_service.py +97 -0
  14. data_designer/cli/ui.py +131 -0
  15. data_designer/cli/utils.py +34 -0
  16. data_designer/config/analysis/__init__.py +2 -0
  17. data_designer/config/analysis/column_profilers.py +75 -7
  18. data_designer/config/analysis/column_statistics.py +192 -48
  19. data_designer/config/analysis/dataset_profiler.py +23 -5
  20. data_designer/config/analysis/utils/reporting.py +3 -3
  21. data_designer/config/base.py +3 -3
  22. data_designer/config/column_configs.py +27 -6
  23. data_designer/config/column_types.py +24 -17
  24. data_designer/config/config_builder.py +36 -27
  25. data_designer/config/data_designer_config.py +7 -7
  26. data_designer/config/datastore.py +6 -6
  27. data_designer/config/default_model_settings.py +27 -34
  28. data_designer/config/exports.py +8 -0
  29. data_designer/config/models.py +155 -29
  30. data_designer/config/preview_results.py +6 -8
  31. data_designer/config/processors.py +63 -2
  32. data_designer/config/sampler_constraints.py +1 -2
  33. data_designer/config/sampler_params.py +50 -31
  34. data_designer/config/seed.py +1 -2
  35. data_designer/config/utils/code_lang.py +4 -5
  36. data_designer/config/utils/constants.py +31 -8
  37. data_designer/config/utils/io_helpers.py +5 -5
  38. data_designer/config/utils/misc.py +1 -4
  39. data_designer/config/utils/numerical_helpers.py +2 -2
  40. data_designer/config/utils/type_helpers.py +3 -3
  41. data_designer/config/utils/validation.py +7 -8
  42. data_designer/config/utils/visualization.py +32 -17
  43. data_designer/config/validator_params.py +4 -8
  44. data_designer/engine/analysis/column_profilers/base.py +0 -7
  45. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +2 -3
  46. data_designer/engine/analysis/column_statistics.py +16 -16
  47. data_designer/engine/analysis/dataset_profiler.py +25 -4
  48. data_designer/engine/analysis/utils/column_statistics_calculations.py +71 -49
  49. data_designer/engine/analysis/utils/judge_score_processing.py +5 -5
  50. data_designer/engine/column_generators/generators/base.py +34 -0
  51. data_designer/engine/column_generators/generators/embedding.py +45 -0
  52. data_designer/engine/column_generators/generators/{llm_generators.py → llm_completion.py} +17 -49
  53. data_designer/engine/column_generators/registry.py +4 -2
  54. data_designer/engine/column_generators/utils/judge_score_factory.py +5 -6
  55. data_designer/engine/configurable_task.py +2 -2
  56. data_designer/engine/dataset_builders/artifact_storage.py +1 -2
  57. data_designer/engine/dataset_builders/column_wise_builder.py +58 -15
  58. data_designer/engine/dataset_builders/utils/concurrency.py +6 -6
  59. data_designer/engine/models/facade.py +66 -9
  60. data_designer/engine/models/litellm_overrides.py +5 -6
  61. data_designer/engine/models/parsers/errors.py +2 -4
  62. data_designer/engine/models/parsers/parser.py +2 -3
  63. data_designer/engine/models/parsers/postprocessors.py +3 -4
  64. data_designer/engine/models/parsers/types.py +4 -4
  65. data_designer/engine/models/registry.py +47 -12
  66. data_designer/engine/models/telemetry.py +355 -0
  67. data_designer/engine/models/usage.py +7 -9
  68. data_designer/engine/processing/ginja/ast.py +1 -2
  69. data_designer/engine/processing/utils.py +40 -2
  70. data_designer/engine/registry/base.py +12 -12
  71. data_designer/engine/sampling_gen/constraints.py +1 -2
  72. data_designer/engine/sampling_gen/data_sources/base.py +14 -14
  73. data_designer/engine/sampling_gen/entities/phone_number.py +1 -2
  74. data_designer/engine/sampling_gen/people_gen.py +3 -7
  75. data_designer/engine/validators/base.py +2 -2
  76. data_designer/logging.py +2 -2
  77. data_designer/plugin_manager.py +3 -3
  78. data_designer/plugins/plugin.py +3 -3
  79. data_designer/plugins/registry.py +2 -2
  80. {data_designer-0.1.5.dist-info → data_designer-0.2.1.dist-info}/METADATA +32 -1
  81. {data_designer-0.1.5.dist-info → data_designer-0.2.1.dist-info}/RECORD +84 -77
  82. {data_designer-0.1.5.dist-info → data_designer-0.2.1.dist-info}/WHEEL +0 -0
  83. {data_designer-0.1.5.dist-info → data_designer-0.2.1.dist-info}/entry_points.txt +0 -0
  84. {data_designer-0.1.5.dist-info → data_designer-0.2.1.dist-info}/licenses/LICENSE +0 -0

data_designer/engine/analysis/column_profilers/judge_score_profiler.py +2 -3

@@ -5,7 +5,6 @@ from __future__ import annotations
 
 import logging
 import random
-from typing import Union
 
 from data_designer.config.analysis.column_profilers import (
     JudgeScoreProfilerConfig,
@@ -69,7 +68,7 @@ class JudgeScoreProfiler(ColumnProfiler[JudgeScoreProfilerConfig]):
         )
 
         for score in column_config.scores:
-            score_name = score.name.lower()
+            score_name = score.name
             logger.info(f"{random.choice(['👩‍⚖️', '👨‍⚖️'])} Summarizing LLM-as-judge score: '{score_name}'")
             score_sample = sample_scores_and_reasoning(
                 scores=score_distributions.scores[score_name],
@@ -96,7 +95,7 @@ class JudgeScoreProfiler(ColumnProfiler[JudgeScoreProfilerConfig]):
         name: str,
         sample: list[JudgeScoreSample],
         histogram: CategoricalHistogramData,
-        distribution: Union[CategoricalDistribution, NumericalDistribution, MissingValue],
+        distribution: CategoricalDistribution | NumericalDistribution | MissingValue,
         distribution_type: ColumnDistributionType,
     ) -> JudgeScoreSummary:
         if isinstance(distribution, MissingValue) or not sample:

data_designer/engine/analysis/column_statistics.py +16 -16

@@ -4,7 +4,7 @@
 from __future__ import annotations
 
 import logging
-from typing import Any, Type, TypeAlias, Union
+from typing import Any, TypeAlias
 
 import pandas as pd
 from pydantic import BaseModel
@@ -41,7 +41,7 @@ class GeneralColumnStatisticsCalculator(BaseModel):
         return self.column_config_with_df.df
 
     @property
-    def column_statistics_type(self) -> Type[ColumnStatisticsT]:
+    def column_statistics_type(self) -> type[ColumnStatisticsT]:
         return DEFAULT_COLUMN_STATISTICS_MAP.get(self.column_config.column_type, GeneralColumnStatistics)
 
     def calculate(self) -> Self:
@@ -59,7 +59,7 @@ class GeneralColumnStatisticsCalculator(BaseModel):
         )
 
     def calculate_general_column_info(self) -> dict[str, Any]:
-        return calculate_general_column_info(self.column_config, self.df)
+        return calculate_general_column_info(self.column_config.name, self.df)
 
     def __repr__(self) -> str:
         params = []
@@ -93,7 +93,7 @@ class SamplerColumnStatisticsCalculator(GeneralColumnStatisticsCalculator):
         return (
             {
                 "sampler_type": SamplerType(self.column_config.sampler_type),
-                **calculate_column_distribution(self.column_config, self.df, dist_type),
+                **calculate_column_distribution(self.column_config.name, self.df, dist_type),
             }
             if make_dist
             else {
@@ -109,23 +109,23 @@ class SeedDatasetColumnStatisticsCalculator(GeneralColumnStatisticsCalculator):
 
 class ValidationColumnStatisticsCalculator(GeneralColumnStatisticsCalculator):
     def calculate_validation_column_info(self) -> dict[str, Any]:
-        return calculate_validation_column_info(self.column_config, self.df)
+        return calculate_validation_column_info(self.column_config.name, self.df)
 
 
 class ExpressionColumnStatisticsCalculator(GeneralColumnStatisticsCalculator): ...
 
 
-ColumnStatisticsCalculatorT: TypeAlias = Union[
-    ExpressionColumnStatisticsCalculator,
-    ValidationColumnStatisticsCalculator,
-    GeneralColumnStatisticsCalculator,
-    LLMCodeColumnStatisticsCalculator,
-    LLMJudgedColumnStatisticsCalculator,
-    LLMStructuredColumnStatisticsCalculator,
-    LLMTextColumnStatisticsCalculator,
-    SamplerColumnStatisticsCalculator,
-    SeedDatasetColumnStatisticsCalculator,
-]
+ColumnStatisticsCalculatorT: TypeAlias = (
+    ExpressionColumnStatisticsCalculator
+    | ValidationColumnStatisticsCalculator
+    | GeneralColumnStatisticsCalculator
+    | LLMCodeColumnStatisticsCalculator
+    | LLMJudgedColumnStatisticsCalculator
+    | LLMStructuredColumnStatisticsCalculator
+    | LLMTextColumnStatisticsCalculator
+    | SamplerColumnStatisticsCalculator
+    | SeedDatasetColumnStatisticsCalculator
+)
 DEFAULT_COLUMN_STATISTICS_CALCULATOR_MAP = {
     DataDesignerColumnType.EXPRESSION: ExpressionColumnStatisticsCalculator,
     DataDesignerColumnType.VALIDATION: ValidationColumnStatisticsCalculator,
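
Across this release, `typing.Union`, `typing.Optional`, and `typing.Type` are replaced by the PEP 604/585 spellings (`X | Y`, `X | None`, `type[X]`), as in the `ColumnStatisticsCalculatorT` alias above. A minimal sketch of the equivalent forms (the class names here are placeholders, not package types):

from typing import TypeAlias


class CategoricalStats: ...


class NumericalStats: ...


# 0.1.5-style spelling: Union[CategoricalStats, NumericalStats]
# 0.2.1-style spelling: a parenthesized | union (needs Python 3.10+ when evaluated at runtime)
StatsT: TypeAlias = CategoricalStats | NumericalStats


def pick(kind: str) -> StatsT:
    # builtin type[...] replaces typing.Type[...]
    cls: type[CategoricalStats] | type[NumericalStats] = CategoricalStats if kind == "categorical" else NumericalStats
    return cls()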

data_designer/engine/analysis/dataset_profiler.py +25 -4

@@ -6,6 +6,7 @@ from collections.abc import Sequence
 from functools import cached_property
 
 import pandas as pd
+import pyarrow as pa
 from pydantic import Field, field_validator
 
 from data_designer.config.analysis.column_profilers import ColumnProfilerConfigT
@@ -19,10 +20,8 @@ from data_designer.config.column_types import (
 from data_designer.engine.analysis.column_profilers.base import ColumnConfigWithDataFrame, ColumnProfiler
 from data_designer.engine.analysis.column_statistics import get_column_statistics_calculator
 from data_designer.engine.analysis.errors import DatasetProfilerConfigurationError
-from data_designer.engine.dataset_builders.multi_column_configs import (
-    DatasetBuilderColumnConfigT,
-    MultiColumnConfig,
-)
+from data_designer.engine.analysis.utils.column_statistics_calculations import has_pyarrow_backend
+from data_designer.engine.dataset_builders.multi_column_configs import DatasetBuilderColumnConfigT, MultiColumnConfig
 from data_designer.engine.registry.data_designer_registry import DataDesignerRegistry
 from data_designer.engine.resources.resource_provider import ResourceProvider
 
@@ -68,6 +67,7 @@ class DataDesignerDatasetProfiler:
         logger.info("📐 Measuring dataset column statistics:")
 
         self._validate_schema_consistency(list(dataset.columns))
+        dataset = self._convert_to_pyarrow_backend_if_needed(dataset)
 
         column_statistics = []
         for c in self.config.column_configs:
@@ -100,6 +100,27 @@ class DataDesignerDatasetProfiler:
             column_profiles=column_profiles if column_profiles else None,
         )
 
+    def _convert_to_pyarrow_backend_if_needed(self, dataset: pd.DataFrame) -> pd.DataFrame:
+        if not has_pyarrow_backend(dataset):
+            try:
+                dataset = pa.Table.from_pandas(dataset).to_pandas(types_mapper=pd.ArrowDtype)
+            except Exception as e:
+                # For ArrowTypeError, the second arg contains the more informative message
+                if isinstance(e, pa.lib.ArrowTypeError) and len(e.args) > 1:
+                    error_msg = str(e.args[1])
+                else:
+                    error_msg = str(e)
+                for col in dataset.columns:
+                    # Make sure column names are clear in the error message
+                    error_msg = error_msg.replace(col, f"'{col}'")
+                logger.warning("⚠️ Unable to convert the dataset to a PyArrow backend")
+                logger.warning(f" |-- Conversion Error Message: {error_msg}")
+                logger.warning(" |-- This is often due to at least one column having mixed data types")
+                logger.warning(
+                    " |-- Note: Reported data types will be inferred from the first non-null value of each column"
+                )
+        return dataset
+
     def _create_column_profiler(self, profiler_config: ColumnProfilerConfigT) -> ColumnProfiler:
         return self.registry.column_profilers.get_for_config_type(type(profiler_config))(
             config=profiler_config, resource_provider=self.resource_provider
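
The new `_convert_to_pyarrow_backend_if_needed` step is the standard pandas round-trip through a `pyarrow.Table`. A standalone sketch of that conversion using only pandas and pyarrow (the sample frame is invented):

import pandas as pd
import pyarrow as pa

df = pd.DataFrame({"age": [31, 42, None], "name": ["Ada", "Grace", "Marie"]})

try:
    # NumPy-backed dtypes -> Arrow-backed dtypes (e.g. double[pyarrow], string[pyarrow])
    df_arrow = pa.Table.from_pandas(df).to_pandas(types_mapper=pd.ArrowDtype)
except pa.lib.ArrowTypeError as exc:
    # Mixed-type object columns are the usual reason this conversion fails
    df_arrow = df
    print(f"conversion failed: {exc}")

print(all(isinstance(dtype, pd.ArrowDtype) for dtype in df_arrow.dtypes))  # True on success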

data_designer/engine/analysis/utils/column_statistics_calculations.py +71 -49

@@ -20,10 +20,8 @@ from data_designer.config.analysis.column_statistics import (
 )
 from data_designer.config.column_configs import (
     LLMTextColumnConfig,
-    SingleColumnConfig,
-    ValidationColumnConfig,
 )
-from data_designer.engine.column_generators.generators.llm_generators import (
+from data_designer.engine.column_generators.utils.prompt_renderer import (
     PromptType,
     RecordBasedPromptRenderer,
     create_response_recipe,
@@ -39,41 +37,54 @@ logger = logging.getLogger(__name__)
 
 
 def calculate_column_distribution(
-    column_config: SingleColumnConfig, df: pd.DataFrame, distribution_type: ColumnDistributionType
+    column_name: str, df: pd.DataFrame, distribution_type: ColumnDistributionType
 ) -> dict[str, CategoricalDistribution | NumericalDistribution | MissingValue | None]:
     distribution_type = ColumnDistributionType(distribution_type)
     try:
         if distribution_type == ColumnDistributionType.CATEGORICAL:
             return {
                 "distribution_type": ColumnDistributionType.CATEGORICAL,
-                "distribution": CategoricalDistribution.from_series(df[column_config.name]),
+                "distribution": CategoricalDistribution.from_series(df[column_name]),
             }
 
         if distribution_type == ColumnDistributionType.NUMERICAL:
             return {
                 "distribution_type": ColumnDistributionType.NUMERICAL,
-                "distribution": NumericalDistribution.from_series(df[column_config.name]),
+                "distribution": NumericalDistribution.from_series(df[column_name]),
             }
     except Exception as e:
-        logger.warning(f"{WARNING_PREFIX} failed to calculate column distribution for '{column_config.name}' {e}")
+        logger.warning(f"{WARNING_PREFIX} failed to calculate column distribution for '{column_name}' {e}")
     return {
         "distribution_type": ColumnDistributionType.UNKNOWN,
         "distribution": MissingValue.CALCULATION_FAILED,
     }
 
 
-def calculate_general_column_info(column_config: SingleColumnConfig, df: pd.DataFrame) -> dict[str, Any]:
+def calculate_general_column_info(column_name: str, df: pd.DataFrame) -> dict[str, Any]:
     try:
-        _df = pd.DataFrame(df[column_config.name].apply(ensure_hashable))
+        _df = pd.DataFrame(df[column_name].apply(ensure_hashable))
+
+        if has_pyarrow_backend(df):
+            pyarrow_dtype = str(df[column_name].dtype.pyarrow_dtype)
+            simple_dtype = convert_pyarrow_dtype_to_simple_dtype(df[column_name].dtype.pyarrow_dtype)
+        else:
+            # We do not log a warning at the column-level because it would be too noisy.
+            # However, there is a logged warning at the dataset-profiler level.
+            try:
+                simple_dtype = get_column_data_type_from_first_non_null_value(column_name, df)
+            except Exception:
+                simple_dtype = MissingValue.CALCULATION_FAILED
+            pyarrow_dtype = "n/a"
+
         return {
-            "pyarrow_dtype": str(df[column_config.name].dtype.pyarrow_dtype),
-            "simple_dtype": convert_pyarrow_dtype_to_simple_dtype(df[column_config.name].dtype.pyarrow_dtype),
-            "num_records": len(_df[column_config.name]),
-            "num_null": _df[column_config.name].isnull().sum(),
-            "num_unique": _df[column_config.name].nunique(),
+            "pyarrow_dtype": pyarrow_dtype,
+            "simple_dtype": simple_dtype,
+            "num_records": len(_df[column_name]),
+            "num_null": _df[column_name].isnull().sum(),
+            "num_unique": _df[column_name].nunique(),
         }
     except Exception as e:
-        logger.warning(f"{WARNING_PREFIX} failed to calculate general column info for '{column_config.name}': {e}")
+        logger.warning(f"{WARNING_PREFIX} failed to calculate general column info for '{column_name}': {e}")
         return {
             "pyarrow_dtype": MissingValue.CALCULATION_FAILED,
             "simple_dtype": MissingValue.CALCULATION_FAILED,
@@ -83,7 +94,7 @@ def calculate_general_column_info(column_config: SingleColumnConfig, df: pd.Data
         }
 
 
-def calculate_prompt_token_stats(
+def calculate_input_token_stats(
     column_config: LLMTextColumnConfig, df: pd.DataFrame
 ) -> dict[str, float | MissingValue]:
     try:
@@ -100,22 +111,20 @@
             concatenated_prompt = str(system_prompt + "\n\n" + prompt)
             num_tokens.append(len(TOKENIZER.encode(concatenated_prompt, disallowed_special=())))
     except Exception as e:
-        logger.warning(
-            f"{WARNING_PREFIX} failed to calculate prompt token stats for column {column_config.name!r}: {e}"
-        )
+        logger.warning(f"{WARNING_PREFIX} failed to calculate input token stats for column {column_config.name!r}: {e}")
         return {
-            "prompt_tokens_mean": MissingValue.CALCULATION_FAILED,
-            "prompt_tokens_median": MissingValue.CALCULATION_FAILED,
-            "prompt_tokens_stddev": MissingValue.CALCULATION_FAILED,
+            "input_tokens_mean": MissingValue.CALCULATION_FAILED,
+            "input_tokens_median": MissingValue.CALCULATION_FAILED,
+            "input_tokens_stddev": MissingValue.CALCULATION_FAILED,
         }
     return {
-        "prompt_tokens_mean": np.mean(num_tokens),
-        "prompt_tokens_median": np.median(num_tokens),
-        "prompt_tokens_stddev": np.std(num_tokens),
+        "input_tokens_mean": np.mean(num_tokens),
+        "input_tokens_median": np.median(num_tokens),
+        "input_tokens_stddev": np.std(num_tokens),
     }
 
 
-def calculate_completion_token_stats(
+def calculate_output_token_stats(
     column_config: LLMTextColumnConfig, df: pd.DataFrame
 ) -> dict[str, float | MissingValue]:
     try:
@@ -123,34 +132,32 @@ def calculate_completion_token_stats(
             lambda value: len(TOKENIZER.encode(str(value), disallowed_special=()))
         )
         return {
-            "completion_tokens_mean": tokens_per_record.mean(),
-            "completion_tokens_median": tokens_per_record.median(),
-            "completion_tokens_stddev": tokens_per_record.std(),
+            "output_tokens_mean": tokens_per_record.mean(),
+            "output_tokens_median": tokens_per_record.median(),
+            "output_tokens_stddev": tokens_per_record.std(),
         }
     except Exception as e:
-        logger.warning(
-            f"{WARNING_PREFIX} failed to calculate completion token stats for column {column_config.name}: {e}"
-        )
+        logger.warning(f"{WARNING_PREFIX} failed to calculate output token stats for column {column_config.name}: {e}")
         return {
-            "completion_tokens_mean": MissingValue.CALCULATION_FAILED,
-            "completion_tokens_median": MissingValue.CALCULATION_FAILED,
-            "completion_tokens_stddev": MissingValue.CALCULATION_FAILED,
+            "output_tokens_mean": MissingValue.CALCULATION_FAILED,
+            "output_tokens_median": MissingValue.CALCULATION_FAILED,
+            "output_tokens_stddev": MissingValue.CALCULATION_FAILED,
         }
 
 
 def calculate_token_stats(column_config: LLMTextColumnConfig, df: pd.DataFrame) -> dict[str, float | MissingValue]:
     return {
-        **calculate_prompt_token_stats(column_config, df),
-        **calculate_completion_token_stats(column_config, df),
+        **calculate_input_token_stats(column_config, df),
+        **calculate_output_token_stats(column_config, df),
     }
 
 
-def calculate_validation_column_info(column_config: ValidationColumnConfig, df: pd.DataFrame) -> dict[str, Any]:
+def calculate_validation_column_info(column_name: str, df: pd.DataFrame) -> dict[str, Any]:
     try:
-        return {"num_valid_records": df[column_config.name].apply(lambda x: ensure_boolean(x["is_valid"])).sum()}
+        return {"num_valid_records": df[column_name].apply(lambda x: ensure_boolean(x["is_valid"])).sum()}
     except Exception as e:
         logger.warning(
-            f"{WARNING_PREFIX} failed to calculate code validation column info for column {column_config.name}: {e}"
+            f"{WARNING_PREFIX} failed to calculate code validation column info for column {column_name}: {e}"
         )
         return {"num_valid_records": MissingValue.CALCULATION_FAILED}
 
@@ -160,22 +167,33 @@ def convert_pyarrow_dtype_to_simple_dtype(pyarrow_dtype: pa.DataType) -> str:
         return f"list[{convert_pyarrow_dtype_to_simple_dtype(pyarrow_dtype.value_type)}]"
     if isinstance(pyarrow_dtype, pa.StructType):
         return "dict"
-    pyarrow_dtype_str = str(pyarrow_dtype)
-    if "int" in pyarrow_dtype_str:
+    return convert_to_simple_dtype(str(pyarrow_dtype))
+
+
+def convert_to_simple_dtype(dtype: str) -> str:
+    if "int" in dtype:
         return "int"
-    if "double" in pyarrow_dtype_str:
+    if "double" in dtype:
         return "float"
-    if "float" in pyarrow_dtype_str:
+    if "float" in dtype:
         return "float"
-    if "string" in pyarrow_dtype_str:
+    if "str" in dtype:
         return "string"
-    if "timestamp" in pyarrow_dtype_str:
+    if "timestamp" in dtype:
        return "timestamp"
-    if "time" in pyarrow_dtype_str:
+    if "time" in dtype:
         return "time"
-    if "date" in pyarrow_dtype_str:
+    if "date" in dtype:
         return "date"
-    return pyarrow_dtype_str
+    return dtype
+
+
+def get_column_data_type_from_first_non_null_value(column_name: str, df: pd.DataFrame) -> str:
+    df_no_nulls = df[column_name].dropna()
+    if len(df_no_nulls) == 0:
+        return MissingValue.CALCULATION_FAILED
+    dtype = type(df_no_nulls.iloc[0]).__name__
+    return convert_to_simple_dtype(dtype)
 
 
 def ensure_hashable(x: Any) -> str:
@@ -207,3 +225,7 @@ def ensure_boolean(v: bool | str | int | None) -> bool:
     if v is None:
         return False
     raise ValueError(f"Invalid boolean value: {v}")
+
+
+def has_pyarrow_backend(df: pd.DataFrame) -> bool:
+    return all(isinstance(dtype, pd.ArrowDtype) for dtype in df.dtypes)
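
`has_pyarrow_backend` and `get_column_data_type_from_first_non_null_value` give the profiler a graceful fallback when a frame keeps its NumPy-backed dtypes. A minimal standalone illustration of the same checks (the sample data is invented, and the printed names are raw `type(...).__name__` values before any `convert_to_simple_dtype`-style normalization):

import pandas as pd

df = pd.DataFrame({"age": [31, None, 52], "name": ["Ada", "Grace", None]})

# Same test as has_pyarrow_backend(): every column must use pd.ArrowDtype
print(all(isinstance(dtype, pd.ArrowDtype) for dtype in df.dtypes))  # False for this NumPy-backed frame

# Fallback dtype inference from the first non-null value, as in calculate_general_column_info
for column_name in df.columns:
    non_null = df[column_name].dropna()
    raw_name = type(non_null.iloc[0]).__name__ if len(non_null) else "unknown"
    print(column_name, raw_name)  # age -> 'float64', name -> 'str'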

data_designer/engine/analysis/utils/judge_score_processing.py +5 -5

@@ -3,7 +3,7 @@
 
 import logging
 from collections import defaultdict
-from typing import Any, Optional, Union
+from typing import Any
 
 import pandas as pd
 
@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
 
 def extract_judge_score_distributions(
     column_config: LLMJudgeColumnConfig, df: pd.DataFrame
-) -> Union[JudgeScoreDistributions, MissingValue]:
+) -> JudgeScoreDistributions | MissingValue:
     scores = defaultdict(list)
     reasoning = defaultdict(list)
 
@@ -32,7 +32,7 @@ def extract_judge_score_distributions(
 
     for score in column_config.scores:
         is_numerical = True
-        name = score.name.lower()
+        name = score.name
         for results in df[column_config.name]:
             try:
                 score = results[name].get("score", None)
@@ -79,10 +79,10 @@ def extract_judge_score_distributions(
 
 
 def sample_scores_and_reasoning(
-    scores: list[Union[int, str]],
+    scores: list[int | str],
     reasoning: list[str],
     num_samples: int,
-    random_seed: Optional[int] = None,
+    random_seed: int | None = None,
 ) -> list[JudgeScoreSample]:
     if len(scores) != len(reasoning):
         raise ValueError("scores and reasoning must have the same length")

data_designer/engine/column_generators/generators/base.py +34 -0

@@ -1,13 +1,20 @@
 # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 
+import functools
+import logging
 from abc import ABC, abstractmethod
 from typing import overload
 
 import pandas as pd
 
+from data_designer.config.column_types import COLUMN_TYPE_EMOJI_MAP
+from data_designer.config.models import BaseInferenceParams, ModelConfig
 from data_designer.config.utils.type_helpers import StrEnum
 from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata, DataT, TaskConfigT
+from data_designer.engine.models.facade import ModelFacade
+
+logger = logging.getLogger(__name__)
 
 
 class GenerationStrategy(StrEnum):
@@ -59,3 +66,30 @@ class FromScratchColumnGenerator(ColumnGenerator[TaskConfigT], ABC):
 
     @abstractmethod
     def generate_from_scratch(self, num_records: int) -> pd.DataFrame: ...
+
+
+class WithModelGeneration:
+    @functools.cached_property
+    def model(self) -> ModelFacade:
+        return self.resource_provider.model_registry.get_model(model_alias=self.config.model_alias)
+
+    @functools.cached_property
+    def model_config(self) -> ModelConfig:
+        return self.resource_provider.model_registry.get_model_config(model_alias=self.config.model_alias)
+
+    @functools.cached_property
+    def inference_parameters(self) -> BaseInferenceParams:
+        return self.model_config.inference_parameters
+
+    def log_pre_generation(self) -> None:
+        emoji = COLUMN_TYPE_EMOJI_MAP[self.config.column_type]
+        logger.info(f"{emoji} Preparing {self.config.column_type} column generation")
+        logger.info(f" |-- column name: {self.config.name!r}")
+        logger.info(f" |-- model config:\n{self.model_config.model_dump_json(indent=4)}")
+        if self.model_config.provider is None:
+            logger.info(f" |-- default model provider: {self._get_provider_name()!r}")
+
+    def _get_provider_name(self) -> str:
+        model_alias = self.model_config.alias
+        provider = self.resource_provider.model_registry.get_model_provider(model_alias=model_alias)
+        return provider.name
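
`WithModelGeneration` is a cooperative mixin: it assumes `self.config` and `self.resource_provider` are supplied by the `ColumnGenerator` it is combined with, and `functools.cached_property` makes the registry lookups one-time. A rough sketch of the pattern with stand-in classes (nothing here beyond the shape comes from the package):

import functools


class FakeModelRegistry:
    def get_model(self, model_alias: str) -> str:
        return f"handle-for-{model_alias}"


class WithModel:
    # Relies on attributes supplied by the class it is mixed into
    @functools.cached_property
    def model(self) -> str:
        return self.resource_provider.get_model(model_alias=self.config["model_alias"])


class DemoColumnGenerator(WithModel):
    def __init__(self, config: dict, resource_provider: FakeModelRegistry) -> None:
        self.config = config
        self.resource_provider = resource_provider


gen = DemoColumnGenerator({"model_alias": "my-text-model"}, FakeModelRegistry())
print(gen.model)  # looked up once, then cached on the instance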

data_designer/engine/column_generators/generators/embedding.py +45 -0

@@ -0,0 +1,45 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+
+from pydantic import BaseModel, computed_field
+
+from data_designer.config.column_configs import EmbeddingColumnConfig
+from data_designer.engine.column_generators.generators.base import (
+    ColumnGenerator,
+    GenerationStrategy,
+    GeneratorMetadata,
+    WithModelGeneration,
+)
+from data_designer.engine.processing.utils import deserialize_json_values, parse_list_string
+from data_designer.engine.resources.resource_provider import ResourceType
+
+
+class EmbeddingGenerationResult(BaseModel):
+    embeddings: list[list[float]]
+
+    @computed_field
+    def num_embeddings(self) -> int:
+        return len(self.embeddings)
+
+    @computed_field
+    def dimension(self) -> int:
+        return len(self.embeddings[0]) if len(self.embeddings) > 0 else 0
+
+
+class EmbeddingCellGenerator(WithModelGeneration, ColumnGenerator[EmbeddingColumnConfig]):
+    @staticmethod
+    def metadata() -> GeneratorMetadata:
+        return GeneratorMetadata(
+            name="embedding_cell_generator",
+            description="Generate embeddings for a text column.",
+            generation_strategy=GenerationStrategy.CELL_BY_CELL,
+            required_resources=[ResourceType.MODEL_REGISTRY],
+        )
+
+    def generate(self, data: dict) -> dict:
+        deserialized_record = deserialize_json_values(data)
+        input_texts = parse_list_string(deserialized_record[self.config.target_column])
+        embeddings = self.model.generate_text_embeddings(input_texts=input_texts)
+        data[self.config.name] = EmbeddingGenerationResult(embeddings=embeddings).model_dump(mode="json")
+        return data
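
The value stored in an embedding cell is the serialized `EmbeddingGenerationResult`, so the computed fields travel with the data. A quick self-contained check of that model (definition copied from the hunk above; the vectors are dummy values):

from pydantic import BaseModel, computed_field


class EmbeddingGenerationResult(BaseModel):
    embeddings: list[list[float]]

    @computed_field
    def num_embeddings(self) -> int:
        return len(self.embeddings)

    @computed_field
    def dimension(self) -> int:
        return len(self.embeddings[0]) if len(self.embeddings) > 0 else 0


result = EmbeddingGenerationResult(embeddings=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
print(result.model_dump(mode="json"))
# {'embeddings': [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], 'num_embeddings': 2, 'dimension': 3}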

data_designer/engine/column_generators/generators/{llm_generators.py → llm_completion.py} +17 -49

@@ -10,43 +10,41 @@ from data_designer.config.column_configs import (
     LLMStructuredColumnConfig,
     LLMTextColumnConfig,
 )
-from data_designer.config.column_types import COLUMN_TYPE_EMOJI_MAP
-from data_designer.config.models import InferenceParameters, ModelConfig
 from data_designer.config.utils.constants import REASONING_TRACE_COLUMN_POSTFIX
 from data_designer.engine.column_generators.generators.base import (
     ColumnGenerator,
     GenerationStrategy,
     GeneratorMetadata,
+    WithModelGeneration,
 )
 from data_designer.engine.column_generators.utils.prompt_renderer import (
     PromptType,
     RecordBasedPromptRenderer,
     create_response_recipe,
 )
-from data_designer.engine.models.facade import ModelFacade
 from data_designer.engine.models.recipes.base import ResponseRecipe
 from data_designer.engine.processing.utils import deserialize_json_values
 from data_designer.engine.resources.resource_provider import ResourceType
 
-DEFAULT_MAX_CONVERSATION_RESTARTS = 5
-DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS = 0
+logger = logging.getLogger(__name__)
 
 
-logger = logging.getLogger(__name__)
+DEFAULT_MAX_CONVERSATION_RESTARTS = 5
+DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS = 0
 
 
-class WithLLMGeneration:
+class WithChatCompletionGeneration(WithModelGeneration):
     @functools.cached_property
-    def model(self) -> ModelFacade:
-        return self.resource_provider.model_registry.get_model(model_alias=self.config.model_alias)
+    def response_recipe(self) -> ResponseRecipe:
+        return create_response_recipe(self.config, self.model_config)
 
-    @functools.cached_property
-    def model_config(self) -> ModelConfig:
-        return self.resource_provider.model_registry.get_model_config(model_alias=self.config.model_alias)
+    @property
+    def max_conversation_correction_steps(self) -> int:
+        return DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS
 
-    @functools.cached_property
-    def inference_parameters(self) -> InferenceParameters:
-        return self.model_config.inference_parameters
+    @property
+    def max_conversation_restarts(self) -> int:
+        return DEFAULT_MAX_CONVERSATION_RESTARTS
 
     @functools.cached_property
     def prompt_renderer(self) -> RecordBasedPromptRenderer:
@@ -59,18 +57,6 @@ class WithLLMGeneration:
             },
         )
 
-    @functools.cached_property
-    def response_recipe(self) -> ResponseRecipe:
-        return create_response_recipe(self.config, self.model_config)
-
-    @property
-    def max_conversation_correction_steps(self) -> int:
-        return DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS
-
-    @property
-    def max_conversation_restarts(self) -> int:
-        return DEFAULT_MAX_CONVERSATION_RESTARTS
-
     def generate(self, data: dict) -> dict:
         deserialized_record = deserialize_json_values(data)
 
@@ -96,7 +82,6 @@ class WithLLMGeneration:
             max_correction_steps=self.max_conversation_correction_steps,
             max_conversation_restarts=self.max_conversation_restarts,
             purpose=f"running generation for column '{self.config.name}'",
-            **self.inference_parameters.generate_kwargs,
         )
 
         data[self.config.name] = deserialize_json_values(self.response_recipe.serialize_output(response))
@@ -106,21 +91,8 @@ class WithLLMGeneration:
 
         return data
 
-    def log_pre_generation(self) -> None:
-        emoji = COLUMN_TYPE_EMOJI_MAP[self.config.column_type]
-        logger.info(f"{emoji} Preparing {self.config.column_type} column generation")
-        logger.info(f" |-- column name: {self.config.name!r}")
-        logger.info(f" |-- model config:\n{self.model_config.model_dump_json(indent=4)}")
-        if self.model_config.provider is None:
-            logger.info(f" |-- default model provider: {self._get_provider_name()!r}")
-
-    def _get_provider_name(self) -> str:
-        model_alias = self.model_config.alias
-        provider = self.resource_provider.model_registry.get_model_provider(model_alias=model_alias)
-        return provider.name
-
 
-class LLMTextCellGenerator(WithLLMGeneration, ColumnGenerator[LLMTextColumnConfig]):
+class LLMTextCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMTextColumnConfig]):
     @staticmethod
     def metadata() -> GeneratorMetadata:
         return GeneratorMetadata(
@@ -131,7 +103,7 @@ class LLMTextCellGenerator(WithLLMGeneration, ColumnGenerator[LLMTextColumnConfi
         )
 
 
-class LLMCodeCellGenerator(WithLLMGeneration, ColumnGenerator[LLMCodeColumnConfig]):
+class LLMCodeCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMCodeColumnConfig]):
     @staticmethod
     def metadata() -> GeneratorMetadata:
         return GeneratorMetadata(
@@ -142,7 +114,7 @@ class LLMCodeCellGenerator(WithLLMGeneration, ColumnGenerator[LLMCodeColumnConfi
         )
 
 
-class LLMStructuredCellGenerator(WithLLMGeneration, ColumnGenerator[LLMStructuredColumnConfig]):
+class LLMStructuredCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMStructuredColumnConfig]):
     @staticmethod
     def metadata() -> GeneratorMetadata:
         return GeneratorMetadata(
@@ -153,7 +125,7 @@ class LLMStructuredCellGenerator(WithLLMGeneration, ColumnGenerator[LLMStructure
         )
 
 
-class LLMJudgeCellGenerator(WithLLMGeneration, ColumnGenerator[LLMJudgeColumnConfig]):
+class LLMJudgeCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMJudgeColumnConfig]):
     @staticmethod
     def metadata() -> GeneratorMetadata:
         return GeneratorMetadata(
@@ -163,10 +135,6 @@ class LLMJudgeCellGenerator(WithLLMGeneration, ColumnGenerator[LLMJudgeColumnCon
             required_resources=[ResourceType.MODEL_REGISTRY],
         )
 
-    @property
-    def max_conversation_correction_steps(self) -> int:
-        return DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS
-
     @property
     def max_conversation_restarts(self) -> int:
         return 2 * DEFAULT_MAX_CONVERSATION_RESTARTS
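
The conversation limits are now plain overridable properties on `WithChatCompletionGeneration`, so `LLMJudgeCellGenerator` keeps only the override it actually changes (doubling the restart budget). A tiny sketch of that override pattern with stand-in classes (not the package's own):

DEFAULT_MAX_CONVERSATION_RESTARTS = 5


class ChatCompletionLimits:
    @property
    def max_conversation_restarts(self) -> int:
        return DEFAULT_MAX_CONVERSATION_RESTARTS


class JudgeLimits(ChatCompletionLimits):
    @property
    def max_conversation_restarts(self) -> int:
        # Judge columns get twice the restart budget
        return 2 * DEFAULT_MAX_CONVERSATION_RESTARTS


print(ChatCompletionLimits().max_conversation_restarts, JudgeLimits().max_conversation_restarts)  # 5 10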