data-designer 0.1.5__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/_version.py +2 -2
- data_designer/cli/README.md +15 -1
- data_designer/cli/commands/download.py +56 -0
- data_designer/cli/commands/list.py +4 -18
- data_designer/cli/controllers/__init__.py +2 -1
- data_designer/cli/controllers/download_controller.py +217 -0
- data_designer/cli/controllers/model_controller.py +4 -3
- data_designer/cli/forms/field.py +65 -19
- data_designer/cli/forms/model_builder.py +251 -44
- data_designer/cli/main.py +11 -1
- data_designer/cli/repositories/persona_repository.py +88 -0
- data_designer/cli/services/__init__.py +2 -1
- data_designer/cli/services/download_service.py +97 -0
- data_designer/cli/ui.py +131 -0
- data_designer/cli/utils.py +34 -0
- data_designer/config/analysis/__init__.py +2 -0
- data_designer/config/analysis/column_profilers.py +75 -7
- data_designer/config/analysis/column_statistics.py +192 -48
- data_designer/config/analysis/dataset_profiler.py +23 -5
- data_designer/config/analysis/utils/reporting.py +3 -3
- data_designer/config/base.py +3 -3
- data_designer/config/column_configs.py +27 -6
- data_designer/config/column_types.py +24 -17
- data_designer/config/config_builder.py +34 -26
- data_designer/config/data_designer_config.py +7 -7
- data_designer/config/datastore.py +6 -6
- data_designer/config/default_model_settings.py +27 -34
- data_designer/config/exports.py +8 -0
- data_designer/config/models.py +155 -29
- data_designer/config/preview_results.py +6 -8
- data_designer/config/processors.py +63 -2
- data_designer/config/sampler_constraints.py +1 -2
- data_designer/config/sampler_params.py +31 -31
- data_designer/config/seed.py +1 -2
- data_designer/config/utils/code_lang.py +4 -5
- data_designer/config/utils/constants.py +31 -8
- data_designer/config/utils/io_helpers.py +5 -5
- data_designer/config/utils/misc.py +1 -4
- data_designer/config/utils/numerical_helpers.py +2 -2
- data_designer/config/utils/type_helpers.py +3 -3
- data_designer/config/utils/validation.py +7 -8
- data_designer/config/utils/visualization.py +32 -17
- data_designer/config/validator_params.py +4 -8
- data_designer/engine/analysis/column_profilers/base.py +0 -7
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +2 -3
- data_designer/engine/analysis/column_statistics.py +16 -16
- data_designer/engine/analysis/dataset_profiler.py +25 -4
- data_designer/engine/analysis/utils/column_statistics_calculations.py +71 -49
- data_designer/engine/analysis/utils/judge_score_processing.py +5 -5
- data_designer/engine/column_generators/generators/base.py +34 -0
- data_designer/engine/column_generators/generators/embedding.py +45 -0
- data_designer/engine/column_generators/generators/{llm_generators.py → llm_completion.py} +17 -49
- data_designer/engine/column_generators/registry.py +4 -2
- data_designer/engine/column_generators/utils/judge_score_factory.py +5 -6
- data_designer/engine/configurable_task.py +2 -2
- data_designer/engine/dataset_builders/artifact_storage.py +1 -2
- data_designer/engine/dataset_builders/column_wise_builder.py +11 -10
- data_designer/engine/dataset_builders/utils/concurrency.py +6 -6
- data_designer/engine/models/facade.py +66 -9
- data_designer/engine/models/litellm_overrides.py +5 -6
- data_designer/engine/models/parsers/errors.py +2 -4
- data_designer/engine/models/parsers/parser.py +2 -3
- data_designer/engine/models/parsers/postprocessors.py +3 -4
- data_designer/engine/models/parsers/types.py +4 -4
- data_designer/engine/models/registry.py +20 -11
- data_designer/engine/models/usage.py +7 -9
- data_designer/engine/processing/ginja/ast.py +1 -2
- data_designer/engine/processing/utils.py +40 -2
- data_designer/engine/registry/base.py +12 -12
- data_designer/engine/sampling_gen/constraints.py +1 -2
- data_designer/engine/sampling_gen/data_sources/base.py +14 -14
- data_designer/engine/sampling_gen/entities/phone_number.py +1 -2
- data_designer/engine/sampling_gen/people_gen.py +3 -7
- data_designer/engine/validators/base.py +2 -2
- data_designer/logging.py +2 -2
- data_designer/plugin_manager.py +3 -3
- data_designer/plugins/plugin.py +3 -3
- data_designer/plugins/registry.py +2 -2
- {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/METADATA +1 -1
- {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/RECORD +83 -77
- {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/WHEEL +0 -0
- {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/entry_points.txt +0 -0
- {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -20,10 +20,8 @@ from data_designer.config.analysis.column_statistics import (
|
|
|
20
20
|
)
|
|
21
21
|
from data_designer.config.column_configs import (
|
|
22
22
|
LLMTextColumnConfig,
|
|
23
|
-
SingleColumnConfig,
|
|
24
|
-
ValidationColumnConfig,
|
|
25
23
|
)
|
|
26
|
-
from data_designer.engine.column_generators.
|
|
24
|
+
from data_designer.engine.column_generators.utils.prompt_renderer import (
|
|
27
25
|
PromptType,
|
|
28
26
|
RecordBasedPromptRenderer,
|
|
29
27
|
create_response_recipe,
|
|
@@ -39,41 +37,54 @@ logger = logging.getLogger(__name__)
|
|
|
39
37
|
|
|
40
38
|
|
|
41
39
|
def calculate_column_distribution(
|
|
42
|
-
|
|
40
|
+
column_name: str, df: pd.DataFrame, distribution_type: ColumnDistributionType
|
|
43
41
|
) -> dict[str, CategoricalDistribution | NumericalDistribution | MissingValue | None]:
|
|
44
42
|
distribution_type = ColumnDistributionType(distribution_type)
|
|
45
43
|
try:
|
|
46
44
|
if distribution_type == ColumnDistributionType.CATEGORICAL:
|
|
47
45
|
return {
|
|
48
46
|
"distribution_type": ColumnDistributionType.CATEGORICAL,
|
|
49
|
-
"distribution": CategoricalDistribution.from_series(df[
|
|
47
|
+
"distribution": CategoricalDistribution.from_series(df[column_name]),
|
|
50
48
|
}
|
|
51
49
|
|
|
52
50
|
if distribution_type == ColumnDistributionType.NUMERICAL:
|
|
53
51
|
return {
|
|
54
52
|
"distribution_type": ColumnDistributionType.NUMERICAL,
|
|
55
|
-
"distribution": NumericalDistribution.from_series(df[
|
|
53
|
+
"distribution": NumericalDistribution.from_series(df[column_name]),
|
|
56
54
|
}
|
|
57
55
|
except Exception as e:
|
|
58
|
-
logger.warning(f"{WARNING_PREFIX} failed to calculate column distribution for '{
|
|
56
|
+
logger.warning(f"{WARNING_PREFIX} failed to calculate column distribution for '{column_name}' {e}")
|
|
59
57
|
return {
|
|
60
58
|
"distribution_type": ColumnDistributionType.UNKNOWN,
|
|
61
59
|
"distribution": MissingValue.CALCULATION_FAILED,
|
|
62
60
|
}
|
|
63
61
|
|
|
64
62
|
|
|
65
|
-
def calculate_general_column_info(
|
|
63
|
+
def calculate_general_column_info(column_name: str, df: pd.DataFrame) -> dict[str, Any]:
|
|
66
64
|
try:
|
|
67
|
-
_df = pd.DataFrame(df[
|
|
65
|
+
_df = pd.DataFrame(df[column_name].apply(ensure_hashable))
|
|
66
|
+
|
|
67
|
+
if has_pyarrow_backend(df):
|
|
68
|
+
pyarrow_dtype = str(df[column_name].dtype.pyarrow_dtype)
|
|
69
|
+
simple_dtype = convert_pyarrow_dtype_to_simple_dtype(df[column_name].dtype.pyarrow_dtype)
|
|
70
|
+
else:
|
|
71
|
+
# We do not log a warning at the column-level because it would be too noisy.
|
|
72
|
+
# However, there is a logged warning at the dataset-profiler level.
|
|
73
|
+
try:
|
|
74
|
+
simple_dtype = get_column_data_type_from_first_non_null_value(column_name, df)
|
|
75
|
+
except Exception:
|
|
76
|
+
simple_dtype = MissingValue.CALCULATION_FAILED
|
|
77
|
+
pyarrow_dtype = "n/a"
|
|
78
|
+
|
|
68
79
|
return {
|
|
69
|
-
"pyarrow_dtype":
|
|
70
|
-
"simple_dtype":
|
|
71
|
-
"num_records": len(_df[
|
|
72
|
-
"num_null": _df[
|
|
73
|
-
"num_unique": _df[
|
|
80
|
+
"pyarrow_dtype": pyarrow_dtype,
|
|
81
|
+
"simple_dtype": simple_dtype,
|
|
82
|
+
"num_records": len(_df[column_name]),
|
|
83
|
+
"num_null": _df[column_name].isnull().sum(),
|
|
84
|
+
"num_unique": _df[column_name].nunique(),
|
|
74
85
|
}
|
|
75
86
|
except Exception as e:
|
|
76
|
-
logger.warning(f"{WARNING_PREFIX} failed to calculate general column info for '{
|
|
87
|
+
logger.warning(f"{WARNING_PREFIX} failed to calculate general column info for '{column_name}': {e}")
|
|
77
88
|
return {
|
|
78
89
|
"pyarrow_dtype": MissingValue.CALCULATION_FAILED,
|
|
79
90
|
"simple_dtype": MissingValue.CALCULATION_FAILED,
|
|
@@ -83,7 +94,7 @@ def calculate_general_column_info(column_config: SingleColumnConfig, df: pd.Data
|
|
|
83
94
|
}
|
|
84
95
|
|
|
85
96
|
|
|
86
|
-
def
|
|
97
|
+
def calculate_input_token_stats(
|
|
87
98
|
column_config: LLMTextColumnConfig, df: pd.DataFrame
|
|
88
99
|
) -> dict[str, float | MissingValue]:
|
|
89
100
|
try:
|
|
@@ -100,22 +111,20 @@ def calculate_prompt_token_stats(
|
|
|
100
111
|
concatenated_prompt = str(system_prompt + "\n\n" + prompt)
|
|
101
112
|
num_tokens.append(len(TOKENIZER.encode(concatenated_prompt, disallowed_special=())))
|
|
102
113
|
except Exception as e:
|
|
103
|
-
logger.warning(
|
|
104
|
-
f"{WARNING_PREFIX} failed to calculate prompt token stats for column {column_config.name!r}: {e}"
|
|
105
|
-
)
|
|
114
|
+
logger.warning(f"{WARNING_PREFIX} failed to calculate input token stats for column {column_config.name!r}: {e}")
|
|
106
115
|
return {
|
|
107
|
-
"
|
|
108
|
-
"
|
|
109
|
-
"
|
|
116
|
+
"input_tokens_mean": MissingValue.CALCULATION_FAILED,
|
|
117
|
+
"input_tokens_median": MissingValue.CALCULATION_FAILED,
|
|
118
|
+
"input_tokens_stddev": MissingValue.CALCULATION_FAILED,
|
|
110
119
|
}
|
|
111
120
|
return {
|
|
112
|
-
"
|
|
113
|
-
"
|
|
114
|
-
"
|
|
121
|
+
"input_tokens_mean": np.mean(num_tokens),
|
|
122
|
+
"input_tokens_median": np.median(num_tokens),
|
|
123
|
+
"input_tokens_stddev": np.std(num_tokens),
|
|
115
124
|
}
|
|
116
125
|
|
|
117
126
|
|
|
118
|
-
def
|
|
127
|
+
def calculate_output_token_stats(
|
|
119
128
|
column_config: LLMTextColumnConfig, df: pd.DataFrame
|
|
120
129
|
) -> dict[str, float | MissingValue]:
|
|
121
130
|
try:
|
|
@@ -123,34 +132,32 @@ def calculate_completion_token_stats(
|
|
|
123
132
|
lambda value: len(TOKENIZER.encode(str(value), disallowed_special=()))
|
|
124
133
|
)
|
|
125
134
|
return {
|
|
126
|
-
"
|
|
127
|
-
"
|
|
128
|
-
"
|
|
135
|
+
"output_tokens_mean": tokens_per_record.mean(),
|
|
136
|
+
"output_tokens_median": tokens_per_record.median(),
|
|
137
|
+
"output_tokens_stddev": tokens_per_record.std(),
|
|
129
138
|
}
|
|
130
139
|
except Exception as e:
|
|
131
|
-
logger.warning(
|
|
132
|
-
f"{WARNING_PREFIX} failed to calculate completion token stats for column {column_config.name}: {e}"
|
|
133
|
-
)
|
|
140
|
+
logger.warning(f"{WARNING_PREFIX} failed to calculate output token stats for column {column_config.name}: {e}")
|
|
134
141
|
return {
|
|
135
|
-
"
|
|
136
|
-
"
|
|
137
|
-
"
|
|
142
|
+
"output_tokens_mean": MissingValue.CALCULATION_FAILED,
|
|
143
|
+
"output_tokens_median": MissingValue.CALCULATION_FAILED,
|
|
144
|
+
"output_tokens_stddev": MissingValue.CALCULATION_FAILED,
|
|
138
145
|
}
|
|
139
146
|
|
|
140
147
|
|
|
141
148
|
def calculate_token_stats(column_config: LLMTextColumnConfig, df: pd.DataFrame) -> dict[str, float | MissingValue]:
|
|
142
149
|
return {
|
|
143
|
-
**
|
|
144
|
-
**
|
|
150
|
+
**calculate_input_token_stats(column_config, df),
|
|
151
|
+
**calculate_output_token_stats(column_config, df),
|
|
145
152
|
}
|
|
146
153
|
|
|
147
154
|
|
|
148
|
-
def calculate_validation_column_info(
|
|
155
|
+
def calculate_validation_column_info(column_name: str, df: pd.DataFrame) -> dict[str, Any]:
|
|
149
156
|
try:
|
|
150
|
-
return {"num_valid_records": df[
|
|
157
|
+
return {"num_valid_records": df[column_name].apply(lambda x: ensure_boolean(x["is_valid"])).sum()}
|
|
151
158
|
except Exception as e:
|
|
152
159
|
logger.warning(
|
|
153
|
-
f"{WARNING_PREFIX} failed to calculate code validation column info for column {
|
|
160
|
+
f"{WARNING_PREFIX} failed to calculate code validation column info for column {column_name}: {e}"
|
|
154
161
|
)
|
|
155
162
|
return {"num_valid_records": MissingValue.CALCULATION_FAILED}
|
|
156
163
|
|
|
@@ -160,22 +167,33 @@ def convert_pyarrow_dtype_to_simple_dtype(pyarrow_dtype: pa.DataType) -> str:
|
|
|
160
167
|
return f"list[{convert_pyarrow_dtype_to_simple_dtype(pyarrow_dtype.value_type)}]"
|
|
161
168
|
if isinstance(pyarrow_dtype, pa.StructType):
|
|
162
169
|
return "dict"
|
|
163
|
-
|
|
164
|
-
|
|
170
|
+
return convert_to_simple_dtype(str(pyarrow_dtype))
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def convert_to_simple_dtype(dtype: str) -> str:
|
|
174
|
+
if "int" in dtype:
|
|
165
175
|
return "int"
|
|
166
|
-
if "double" in
|
|
176
|
+
if "double" in dtype:
|
|
167
177
|
return "float"
|
|
168
|
-
if "float" in
|
|
178
|
+
if "float" in dtype:
|
|
169
179
|
return "float"
|
|
170
|
-
if "
|
|
180
|
+
if "str" in dtype:
|
|
171
181
|
return "string"
|
|
172
|
-
if "timestamp" in
|
|
182
|
+
if "timestamp" in dtype:
|
|
173
183
|
return "timestamp"
|
|
174
|
-
if "time" in
|
|
184
|
+
if "time" in dtype:
|
|
175
185
|
return "time"
|
|
176
|
-
if "date" in
|
|
186
|
+
if "date" in dtype:
|
|
177
187
|
return "date"
|
|
178
|
-
return
|
|
188
|
+
return dtype
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def get_column_data_type_from_first_non_null_value(column_name: str, df: pd.DataFrame) -> str:
|
|
192
|
+
df_no_nulls = df[column_name].dropna()
|
|
193
|
+
if len(df_no_nulls) == 0:
|
|
194
|
+
return MissingValue.CALCULATION_FAILED
|
|
195
|
+
dtype = type(df_no_nulls.iloc[0]).__name__
|
|
196
|
+
return convert_to_simple_dtype(dtype)
|
|
179
197
|
|
|
180
198
|
|
|
181
199
|
def ensure_hashable(x: Any) -> str:
|
|
@@ -207,3 +225,7 @@ def ensure_boolean(v: bool | str | int | None) -> bool:
|
|
|
207
225
|
if v is None:
|
|
208
226
|
return False
|
|
209
227
|
raise ValueError(f"Invalid boolean value: {v}")
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def has_pyarrow_backend(df: pd.DataFrame) -> bool:
|
|
231
|
+
return all(isinstance(dtype, pd.ArrowDtype) for dtype in df.dtypes)
|
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
|
|
4
4
|
import logging
|
|
5
5
|
from collections import defaultdict
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
7
|
|
|
8
8
|
import pandas as pd
|
|
9
9
|
|
|
@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
|
|
|
21
21
|
|
|
22
22
|
def extract_judge_score_distributions(
|
|
23
23
|
column_config: LLMJudgeColumnConfig, df: pd.DataFrame
|
|
24
|
-
) ->
|
|
24
|
+
) -> JudgeScoreDistributions | MissingValue:
|
|
25
25
|
scores = defaultdict(list)
|
|
26
26
|
reasoning = defaultdict(list)
|
|
27
27
|
|
|
@@ -32,7 +32,7 @@ def extract_judge_score_distributions(
|
|
|
32
32
|
|
|
33
33
|
for score in column_config.scores:
|
|
34
34
|
is_numerical = True
|
|
35
|
-
name = score.name
|
|
35
|
+
name = score.name
|
|
36
36
|
for results in df[column_config.name]:
|
|
37
37
|
try:
|
|
38
38
|
score = results[name].get("score", None)
|
|
@@ -79,10 +79,10 @@ def extract_judge_score_distributions(
|
|
|
79
79
|
|
|
80
80
|
|
|
81
81
|
def sample_scores_and_reasoning(
|
|
82
|
-
scores: list[
|
|
82
|
+
scores: list[int | str],
|
|
83
83
|
reasoning: list[str],
|
|
84
84
|
num_samples: int,
|
|
85
|
-
random_seed:
|
|
85
|
+
random_seed: int | None = None,
|
|
86
86
|
) -> list[JudgeScoreSample]:
|
|
87
87
|
if len(scores) != len(reasoning):
|
|
88
88
|
raise ValueError("scores and reasoning must have the same length")
|
|
@@ -1,13 +1,20 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
import functools
|
|
5
|
+
import logging
|
|
4
6
|
from abc import ABC, abstractmethod
|
|
5
7
|
from typing import overload
|
|
6
8
|
|
|
7
9
|
import pandas as pd
|
|
8
10
|
|
|
11
|
+
from data_designer.config.column_types import COLUMN_TYPE_EMOJI_MAP
|
|
12
|
+
from data_designer.config.models import BaseInferenceParams, ModelConfig
|
|
9
13
|
from data_designer.config.utils.type_helpers import StrEnum
|
|
10
14
|
from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata, DataT, TaskConfigT
|
|
15
|
+
from data_designer.engine.models.facade import ModelFacade
|
|
16
|
+
|
|
17
|
+
logger = logging.getLogger(__name__)
|
|
11
18
|
|
|
12
19
|
|
|
13
20
|
class GenerationStrategy(StrEnum):
|
|
@@ -59,3 +66,30 @@ class FromScratchColumnGenerator(ColumnGenerator[TaskConfigT], ABC):
|
|
|
59
66
|
|
|
60
67
|
@abstractmethod
|
|
61
68
|
def generate_from_scratch(self, num_records: int) -> pd.DataFrame: ...
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class WithModelGeneration:
|
|
72
|
+
@functools.cached_property
|
|
73
|
+
def model(self) -> ModelFacade:
|
|
74
|
+
return self.resource_provider.model_registry.get_model(model_alias=self.config.model_alias)
|
|
75
|
+
|
|
76
|
+
@functools.cached_property
|
|
77
|
+
def model_config(self) -> ModelConfig:
|
|
78
|
+
return self.resource_provider.model_registry.get_model_config(model_alias=self.config.model_alias)
|
|
79
|
+
|
|
80
|
+
@functools.cached_property
|
|
81
|
+
def inference_parameters(self) -> BaseInferenceParams:
|
|
82
|
+
return self.model_config.inference_parameters
|
|
83
|
+
|
|
84
|
+
def log_pre_generation(self) -> None:
|
|
85
|
+
emoji = COLUMN_TYPE_EMOJI_MAP[self.config.column_type]
|
|
86
|
+
logger.info(f"{emoji} Preparing {self.config.column_type} column generation")
|
|
87
|
+
logger.info(f" |-- column name: {self.config.name!r}")
|
|
88
|
+
logger.info(f" |-- model config:\n{self.model_config.model_dump_json(indent=4)}")
|
|
89
|
+
if self.model_config.provider is None:
|
|
90
|
+
logger.info(f" |-- default model provider: {self._get_provider_name()!r}")
|
|
91
|
+
|
|
92
|
+
def _get_provider_name(self) -> str:
|
|
93
|
+
model_alias = self.model_config.alias
|
|
94
|
+
provider = self.resource_provider.model_registry.get_model_provider(model_alias=model_alias)
|
|
95
|
+
return provider.name
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, computed_field
|
|
6
|
+
|
|
7
|
+
from data_designer.config.column_configs import EmbeddingColumnConfig
|
|
8
|
+
from data_designer.engine.column_generators.generators.base import (
|
|
9
|
+
ColumnGenerator,
|
|
10
|
+
GenerationStrategy,
|
|
11
|
+
GeneratorMetadata,
|
|
12
|
+
WithModelGeneration,
|
|
13
|
+
)
|
|
14
|
+
from data_designer.engine.processing.utils import deserialize_json_values, parse_list_string
|
|
15
|
+
from data_designer.engine.resources.resource_provider import ResourceType
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class EmbeddingGenerationResult(BaseModel):
|
|
19
|
+
embeddings: list[list[float]]
|
|
20
|
+
|
|
21
|
+
@computed_field
|
|
22
|
+
def num_embeddings(self) -> int:
|
|
23
|
+
return len(self.embeddings)
|
|
24
|
+
|
|
25
|
+
@computed_field
|
|
26
|
+
def dimension(self) -> int:
|
|
27
|
+
return len(self.embeddings[0]) if len(self.embeddings) > 0 else 0
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class EmbeddingCellGenerator(WithModelGeneration, ColumnGenerator[EmbeddingColumnConfig]):
|
|
31
|
+
@staticmethod
|
|
32
|
+
def metadata() -> GeneratorMetadata:
|
|
33
|
+
return GeneratorMetadata(
|
|
34
|
+
name="embedding_cell_generator",
|
|
35
|
+
description="Generate embeddings for a text column.",
|
|
36
|
+
generation_strategy=GenerationStrategy.CELL_BY_CELL,
|
|
37
|
+
required_resources=[ResourceType.MODEL_REGISTRY],
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
def generate(self, data: dict) -> dict:
|
|
41
|
+
deserialized_record = deserialize_json_values(data)
|
|
42
|
+
input_texts = parse_list_string(deserialized_record[self.config.target_column])
|
|
43
|
+
embeddings = self.model.generate_text_embeddings(input_texts=input_texts)
|
|
44
|
+
data[self.config.name] = EmbeddingGenerationResult(embeddings=embeddings).model_dump(mode="json")
|
|
45
|
+
return data
|
|
@@ -10,43 +10,41 @@ from data_designer.config.column_configs import (
|
|
|
10
10
|
LLMStructuredColumnConfig,
|
|
11
11
|
LLMTextColumnConfig,
|
|
12
12
|
)
|
|
13
|
-
from data_designer.config.column_types import COLUMN_TYPE_EMOJI_MAP
|
|
14
|
-
from data_designer.config.models import InferenceParameters, ModelConfig
|
|
15
13
|
from data_designer.config.utils.constants import REASONING_TRACE_COLUMN_POSTFIX
|
|
16
14
|
from data_designer.engine.column_generators.generators.base import (
|
|
17
15
|
ColumnGenerator,
|
|
18
16
|
GenerationStrategy,
|
|
19
17
|
GeneratorMetadata,
|
|
18
|
+
WithModelGeneration,
|
|
20
19
|
)
|
|
21
20
|
from data_designer.engine.column_generators.utils.prompt_renderer import (
|
|
22
21
|
PromptType,
|
|
23
22
|
RecordBasedPromptRenderer,
|
|
24
23
|
create_response_recipe,
|
|
25
24
|
)
|
|
26
|
-
from data_designer.engine.models.facade import ModelFacade
|
|
27
25
|
from data_designer.engine.models.recipes.base import ResponseRecipe
|
|
28
26
|
from data_designer.engine.processing.utils import deserialize_json_values
|
|
29
27
|
from data_designer.engine.resources.resource_provider import ResourceType
|
|
30
28
|
|
|
31
|
-
|
|
32
|
-
DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS = 0
|
|
29
|
+
logger = logging.getLogger(__name__)
|
|
33
30
|
|
|
34
31
|
|
|
35
|
-
|
|
32
|
+
DEFAULT_MAX_CONVERSATION_RESTARTS = 5
|
|
33
|
+
DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS = 0
|
|
36
34
|
|
|
37
35
|
|
|
38
|
-
class
|
|
36
|
+
class WithChatCompletionGeneration(WithModelGeneration):
|
|
39
37
|
@functools.cached_property
|
|
40
|
-
def
|
|
41
|
-
return
|
|
38
|
+
def response_recipe(self) -> ResponseRecipe:
|
|
39
|
+
return create_response_recipe(self.config, self.model_config)
|
|
42
40
|
|
|
43
|
-
@
|
|
44
|
-
def
|
|
45
|
-
return
|
|
41
|
+
@property
|
|
42
|
+
def max_conversation_correction_steps(self) -> int:
|
|
43
|
+
return DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS
|
|
46
44
|
|
|
47
|
-
@
|
|
48
|
-
def
|
|
49
|
-
return
|
|
45
|
+
@property
|
|
46
|
+
def max_conversation_restarts(self) -> int:
|
|
47
|
+
return DEFAULT_MAX_CONVERSATION_RESTARTS
|
|
50
48
|
|
|
51
49
|
@functools.cached_property
|
|
52
50
|
def prompt_renderer(self) -> RecordBasedPromptRenderer:
|
|
@@ -59,18 +57,6 @@ class WithLLMGeneration:
|
|
|
59
57
|
},
|
|
60
58
|
)
|
|
61
59
|
|
|
62
|
-
@functools.cached_property
|
|
63
|
-
def response_recipe(self) -> ResponseRecipe:
|
|
64
|
-
return create_response_recipe(self.config, self.model_config)
|
|
65
|
-
|
|
66
|
-
@property
|
|
67
|
-
def max_conversation_correction_steps(self) -> int:
|
|
68
|
-
return DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS
|
|
69
|
-
|
|
70
|
-
@property
|
|
71
|
-
def max_conversation_restarts(self) -> int:
|
|
72
|
-
return DEFAULT_MAX_CONVERSATION_RESTARTS
|
|
73
|
-
|
|
74
60
|
def generate(self, data: dict) -> dict:
|
|
75
61
|
deserialized_record = deserialize_json_values(data)
|
|
76
62
|
|
|
@@ -96,7 +82,6 @@ class WithLLMGeneration:
|
|
|
96
82
|
max_correction_steps=self.max_conversation_correction_steps,
|
|
97
83
|
max_conversation_restarts=self.max_conversation_restarts,
|
|
98
84
|
purpose=f"running generation for column '{self.config.name}'",
|
|
99
|
-
**self.inference_parameters.generate_kwargs,
|
|
100
85
|
)
|
|
101
86
|
|
|
102
87
|
data[self.config.name] = deserialize_json_values(self.response_recipe.serialize_output(response))
|
|
@@ -106,21 +91,8 @@ class WithLLMGeneration:
|
|
|
106
91
|
|
|
107
92
|
return data
|
|
108
93
|
|
|
109
|
-
def log_pre_generation(self) -> None:
|
|
110
|
-
emoji = COLUMN_TYPE_EMOJI_MAP[self.config.column_type]
|
|
111
|
-
logger.info(f"{emoji} Preparing {self.config.column_type} column generation")
|
|
112
|
-
logger.info(f" |-- column name: {self.config.name!r}")
|
|
113
|
-
logger.info(f" |-- model config:\n{self.model_config.model_dump_json(indent=4)}")
|
|
114
|
-
if self.model_config.provider is None:
|
|
115
|
-
logger.info(f" |-- default model provider: {self._get_provider_name()!r}")
|
|
116
|
-
|
|
117
|
-
def _get_provider_name(self) -> str:
|
|
118
|
-
model_alias = self.model_config.alias
|
|
119
|
-
provider = self.resource_provider.model_registry.get_model_provider(model_alias=model_alias)
|
|
120
|
-
return provider.name
|
|
121
|
-
|
|
122
94
|
|
|
123
|
-
class LLMTextCellGenerator(
|
|
95
|
+
class LLMTextCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMTextColumnConfig]):
|
|
124
96
|
@staticmethod
|
|
125
97
|
def metadata() -> GeneratorMetadata:
|
|
126
98
|
return GeneratorMetadata(
|
|
@@ -131,7 +103,7 @@ class LLMTextCellGenerator(WithLLMGeneration, ColumnGenerator[LLMTextColumnConfi
|
|
|
131
103
|
)
|
|
132
104
|
|
|
133
105
|
|
|
134
|
-
class LLMCodeCellGenerator(
|
|
106
|
+
class LLMCodeCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMCodeColumnConfig]):
|
|
135
107
|
@staticmethod
|
|
136
108
|
def metadata() -> GeneratorMetadata:
|
|
137
109
|
return GeneratorMetadata(
|
|
@@ -142,7 +114,7 @@ class LLMCodeCellGenerator(WithLLMGeneration, ColumnGenerator[LLMCodeColumnConfi
|
|
|
142
114
|
)
|
|
143
115
|
|
|
144
116
|
|
|
145
|
-
class LLMStructuredCellGenerator(
|
|
117
|
+
class LLMStructuredCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMStructuredColumnConfig]):
|
|
146
118
|
@staticmethod
|
|
147
119
|
def metadata() -> GeneratorMetadata:
|
|
148
120
|
return GeneratorMetadata(
|
|
@@ -153,7 +125,7 @@ class LLMStructuredCellGenerator(WithLLMGeneration, ColumnGenerator[LLMStructure
|
|
|
153
125
|
)
|
|
154
126
|
|
|
155
127
|
|
|
156
|
-
class LLMJudgeCellGenerator(
|
|
128
|
+
class LLMJudgeCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMJudgeColumnConfig]):
|
|
157
129
|
@staticmethod
|
|
158
130
|
def metadata() -> GeneratorMetadata:
|
|
159
131
|
return GeneratorMetadata(
|
|
@@ -163,10 +135,6 @@ class LLMJudgeCellGenerator(WithLLMGeneration, ColumnGenerator[LLMJudgeColumnCon
|
|
|
163
135
|
required_resources=[ResourceType.MODEL_REGISTRY],
|
|
164
136
|
)
|
|
165
137
|
|
|
166
|
-
@property
|
|
167
|
-
def max_conversation_correction_steps(self) -> int:
|
|
168
|
-
return DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS
|
|
169
|
-
|
|
170
138
|
@property
|
|
171
139
|
def max_conversation_restarts(self) -> int:
|
|
172
140
|
return 2 * DEFAULT_MAX_CONVERSATION_RESTARTS
|
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
|
|
4
4
|
from data_designer.config.base import ConfigBase
|
|
5
5
|
from data_designer.config.column_configs import (
|
|
6
|
+
EmbeddingColumnConfig,
|
|
6
7
|
ExpressionColumnConfig,
|
|
7
8
|
LLMCodeColumnConfig,
|
|
8
9
|
LLMJudgeColumnConfig,
|
|
@@ -12,8 +13,9 @@ from data_designer.config.column_configs import (
|
|
|
12
13
|
)
|
|
13
14
|
from data_designer.config.column_types import DataDesignerColumnType
|
|
14
15
|
from data_designer.engine.column_generators.generators.base import ColumnGenerator
|
|
16
|
+
from data_designer.engine.column_generators.generators.embedding import EmbeddingCellGenerator
|
|
15
17
|
from data_designer.engine.column_generators.generators.expression import ExpressionColumnGenerator
|
|
16
|
-
from data_designer.engine.column_generators.generators.
|
|
18
|
+
from data_designer.engine.column_generators.generators.llm_completion import (
|
|
17
19
|
LLMCodeCellGenerator,
|
|
18
20
|
LLMJudgeCellGenerator,
|
|
19
21
|
LLMStructuredCellGenerator,
|
|
@@ -40,11 +42,11 @@ def create_default_column_generator_registry(with_plugins: bool = True) -> Colum
|
|
|
40
42
|
registry.register(DataDesignerColumnType.LLM_CODE, LLMCodeCellGenerator, LLMCodeColumnConfig)
|
|
41
43
|
registry.register(DataDesignerColumnType.LLM_JUDGE, LLMJudgeCellGenerator, LLMJudgeColumnConfig)
|
|
42
44
|
registry.register(DataDesignerColumnType.EXPRESSION, ExpressionColumnGenerator, ExpressionColumnConfig)
|
|
45
|
+
registry.register(DataDesignerColumnType.EMBEDDING, EmbeddingCellGenerator, EmbeddingColumnConfig)
|
|
43
46
|
registry.register(DataDesignerColumnType.SAMPLER, SamplerColumnGenerator, SamplerMultiColumnConfig)
|
|
44
47
|
registry.register(DataDesignerColumnType.SEED_DATASET, SeedDatasetColumnGenerator, SeedDatasetMultiColumnConfig)
|
|
45
48
|
registry.register(DataDesignerColumnType.VALIDATION, ValidationColumnGenerator, ValidationColumnConfig)
|
|
46
49
|
registry.register(DataDesignerColumnType.LLM_STRUCTURED, LLMStructuredCellGenerator, LLMStructuredColumnConfig)
|
|
47
|
-
|
|
48
50
|
if with_plugins:
|
|
49
51
|
for plugin in PluginRegistry().get_plugins(PluginType.COLUMN_GENERATOR):
|
|
50
52
|
registry.register(
|
|
@@ -2,7 +2,6 @@
|
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
4
|
from enum import Enum
|
|
5
|
-
from typing import Type
|
|
6
5
|
|
|
7
6
|
from pydantic import BaseModel, ConfigDict, Field, create_model
|
|
8
7
|
|
|
@@ -19,7 +18,7 @@ class BaseJudgeResponse(BaseModel):
|
|
|
19
18
|
reasoning: str = Field(..., description="Reasoning for the assigned score.")
|
|
20
19
|
|
|
21
20
|
|
|
22
|
-
def _stringify_scoring(options: dict, enum_type:
|
|
21
|
+
def _stringify_scoring(options: dict, enum_type: type[Enum]) -> str:
|
|
23
22
|
"""Convert score descriptions into a single text block."""
|
|
24
23
|
list_block = "\n".join(
|
|
25
24
|
[SCORING_FORMAT.format(score=score, description=description) for score, description in options.items()]
|
|
@@ -27,7 +26,7 @@ def _stringify_scoring(options: dict, enum_type: Type[Enum]) -> str:
|
|
|
27
26
|
return SCORE_FIELD_DESCRIPTION_FORMAT.format(enum_name=enum_type.__name__, scoring=list_block)
|
|
28
27
|
|
|
29
28
|
|
|
30
|
-
def create_judge_response_model(score: Score) ->
|
|
29
|
+
def create_judge_response_model(score: Score) -> type[BaseJudgeResponse]:
|
|
31
30
|
"""Create a JudgeResponse data type."""
|
|
32
31
|
enum_members = {}
|
|
33
32
|
for option in score.options.keys():
|
|
@@ -46,12 +45,12 @@ def create_judge_response_model(score: Score) -> Type[BaseJudgeResponse]:
|
|
|
46
45
|
|
|
47
46
|
|
|
48
47
|
def create_judge_structured_output_model(
|
|
49
|
-
judge_responses: list[
|
|
50
|
-
) ->
|
|
48
|
+
judge_responses: list[type[BaseJudgeResponse]],
|
|
49
|
+
) -> type[BaseModel]:
|
|
51
50
|
"""Create a JudgeStructuredOutput class dynamically."""
|
|
52
51
|
return create_model(
|
|
53
52
|
"JudgeStructuredOutput",
|
|
54
53
|
__doc__=f"Response schema for scores with the following names: {[response.__name__ for response in judge_responses]}.",
|
|
55
54
|
__base__=BaseModel,
|
|
56
|
-
**{response.__name__
|
|
55
|
+
**{response.__name__: (response, ...) for response in judge_responses},
|
|
57
56
|
)
|
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
|
|
4
4
|
from abc import ABC, abstractmethod
|
|
5
5
|
from pathlib import Path
|
|
6
|
-
from typing import Generic,
|
|
6
|
+
from typing import Generic, TypeVar, get_origin
|
|
7
7
|
|
|
8
8
|
import pandas as pd
|
|
9
9
|
|
|
@@ -30,7 +30,7 @@ class ConfigurableTask(ABC, Generic[TaskConfigT]):
|
|
|
30
30
|
self._initialize()
|
|
31
31
|
|
|
32
32
|
@classmethod
|
|
33
|
-
def get_config_type(cls) ->
|
|
33
|
+
def get_config_type(cls) -> type[TaskConfigT]:
|
|
34
34
|
for base in cls.__orig_bases__:
|
|
35
35
|
if hasattr(base, "__args__") and len(base.__args__) == 1:
|
|
36
36
|
arg = base.__args__[0]
|
|
@@ -7,7 +7,6 @@ import shutil
|
|
|
7
7
|
from datetime import datetime
|
|
8
8
|
from functools import cached_property
|
|
9
9
|
from pathlib import Path
|
|
10
|
-
from typing import Union
|
|
11
10
|
|
|
12
11
|
import pandas as pd
|
|
13
12
|
from pydantic import BaseModel, field_validator, model_validator
|
|
@@ -77,7 +76,7 @@ class ArtifactStorage(BaseModel):
|
|
|
77
76
|
return self.base_dataset_path / self.processors_outputs_folder_name
|
|
78
77
|
|
|
79
78
|
@field_validator("artifact_path")
|
|
80
|
-
def validate_artifact_path(cls, v:
|
|
79
|
+
def validate_artifact_path(cls, v: Path | str) -> Path:
|
|
81
80
|
v = Path(v)
|
|
82
81
|
if not v.is_dir():
|
|
83
82
|
raise ArtifactStorageError("Artifact path must exist and be a directory")
|
|
@@ -10,15 +10,18 @@ from typing import Callable
|
|
|
10
10
|
|
|
11
11
|
import pandas as pd
|
|
12
12
|
|
|
13
|
-
from data_designer.config.column_types import ColumnConfigT,
|
|
13
|
+
from data_designer.config.column_types import ColumnConfigT, column_type_is_model_generated
|
|
14
14
|
from data_designer.config.dataset_builders import BuildStage
|
|
15
15
|
from data_designer.config.processors import (
|
|
16
16
|
DropColumnsProcessorConfig,
|
|
17
17
|
ProcessorConfig,
|
|
18
18
|
ProcessorType,
|
|
19
19
|
)
|
|
20
|
-
from data_designer.engine.column_generators.generators.base import
|
|
21
|
-
|
|
20
|
+
from data_designer.engine.column_generators.generators.base import (
|
|
21
|
+
ColumnGenerator,
|
|
22
|
+
GenerationStrategy,
|
|
23
|
+
WithModelGeneration,
|
|
24
|
+
)
|
|
22
25
|
from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
|
|
23
26
|
from data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError
|
|
24
27
|
from data_designer.engine.dataset_builders.multi_column_configs import (
|
|
@@ -72,7 +75,7 @@ class ColumnWiseDatasetBuilder:
|
|
|
72
75
|
|
|
73
76
|
@functools.cached_property
|
|
74
77
|
def llm_generated_column_configs(self) -> list[ColumnConfigT]:
|
|
75
|
-
return [config for config in self.single_column_configs if
|
|
78
|
+
return [config for config in self.single_column_configs if column_type_is_model_generated(config.column_type)]
|
|
76
79
|
|
|
77
80
|
def build(
|
|
78
81
|
self,
|
|
@@ -169,10 +172,8 @@ class ColumnWiseDatasetBuilder:
|
|
|
169
172
|
|
|
170
173
|
def _run_cell_by_cell_generator(self, generator: ColumnGenerator) -> None:
|
|
171
174
|
max_workers = MAX_CONCURRENCY_PER_NON_LLM_GENERATOR
|
|
172
|
-
if isinstance(generator,
|
|
175
|
+
if isinstance(generator, WithModelGeneration):
|
|
173
176
|
max_workers = generator.inference_parameters.max_parallel_requests
|
|
174
|
-
elif hasattr(generator.config, "max_parallel_requests"):
|
|
175
|
-
max_workers = generator.config.max_parallel_requests
|
|
176
177
|
self._fan_out_with_threads(generator, max_workers=max_workers)
|
|
177
178
|
|
|
178
179
|
def _run_full_column_generator(self, generator: ColumnGenerator) -> None:
|
|
@@ -180,12 +181,12 @@ class ColumnWiseDatasetBuilder:
|
|
|
180
181
|
self.batch_manager.update_records(df.to_dict(orient="records"))
|
|
181
182
|
|
|
182
183
|
def _run_model_health_check_if_needed(self) -> bool:
|
|
183
|
-
if any(
|
|
184
|
+
if any(column_type_is_model_generated(config.column_type) for config in self.single_column_configs):
|
|
184
185
|
self._resource_provider.model_registry.run_health_check(
|
|
185
|
-
set(config.model_alias for config in self.llm_generated_column_configs)
|
|
186
|
+
list(set(config.model_alias for config in self.llm_generated_column_configs))
|
|
186
187
|
)
|
|
187
188
|
|
|
188
|
-
def _fan_out_with_threads(self, generator:
|
|
189
|
+
def _fan_out_with_threads(self, generator: WithModelGeneration, max_workers: int) -> None:
|
|
189
190
|
if generator.generation_strategy != GenerationStrategy.CELL_BY_CELL:
|
|
190
191
|
raise DatasetGenerationError(
|
|
191
192
|
f"Generator {generator.metadata().name} is not a {GenerationStrategy.CELL_BY_CELL} "
|