data-designer 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/_version.py +2 -2
- data_designer/cli/README.md +15 -1
- data_designer/cli/commands/download.py +56 -0
- data_designer/cli/commands/list.py +4 -18
- data_designer/cli/controllers/__init__.py +2 -1
- data_designer/cli/controllers/download_controller.py +217 -0
- data_designer/cli/controllers/model_controller.py +4 -3
- data_designer/cli/forms/field.py +65 -19
- data_designer/cli/forms/model_builder.py +251 -44
- data_designer/cli/main.py +11 -1
- data_designer/cli/repositories/persona_repository.py +88 -0
- data_designer/cli/services/__init__.py +2 -1
- data_designer/cli/services/download_service.py +97 -0
- data_designer/cli/ui.py +131 -0
- data_designer/cli/utils.py +34 -0
- data_designer/config/analysis/__init__.py +2 -0
- data_designer/config/analysis/column_profilers.py +75 -7
- data_designer/config/analysis/column_statistics.py +192 -48
- data_designer/config/analysis/dataset_profiler.py +23 -5
- data_designer/config/analysis/utils/reporting.py +3 -3
- data_designer/config/base.py +3 -3
- data_designer/config/column_configs.py +27 -6
- data_designer/config/column_types.py +24 -17
- data_designer/config/config_builder.py +34 -26
- data_designer/config/data_designer_config.py +7 -7
- data_designer/config/datastore.py +6 -6
- data_designer/config/default_model_settings.py +27 -34
- data_designer/config/exports.py +14 -1
- data_designer/config/models.py +155 -29
- data_designer/config/preview_results.py +5 -4
- data_designer/config/processors.py +109 -4
- data_designer/config/sampler_constraints.py +1 -2
- data_designer/config/sampler_params.py +31 -31
- data_designer/config/seed.py +1 -2
- data_designer/config/utils/code_lang.py +4 -5
- data_designer/config/utils/constants.py +31 -8
- data_designer/config/utils/io_helpers.py +5 -5
- data_designer/config/utils/misc.py +1 -4
- data_designer/config/utils/numerical_helpers.py +2 -2
- data_designer/config/utils/type_helpers.py +3 -3
- data_designer/config/utils/validation.py +39 -9
- data_designer/config/utils/visualization.py +62 -15
- data_designer/config/validator_params.py +4 -8
- data_designer/engine/analysis/column_profilers/base.py +0 -7
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +2 -3
- data_designer/engine/analysis/column_statistics.py +16 -16
- data_designer/engine/analysis/dataset_profiler.py +25 -4
- data_designer/engine/analysis/utils/column_statistics_calculations.py +71 -49
- data_designer/engine/analysis/utils/judge_score_processing.py +5 -5
- data_designer/engine/column_generators/generators/base.py +34 -0
- data_designer/engine/column_generators/generators/embedding.py +45 -0
- data_designer/engine/column_generators/generators/{llm_generators.py → llm_completion.py} +17 -49
- data_designer/engine/column_generators/registry.py +4 -2
- data_designer/engine/column_generators/utils/judge_score_factory.py +5 -6
- data_designer/engine/configurable_task.py +2 -2
- data_designer/engine/dataset_builders/artifact_storage.py +14 -5
- data_designer/engine/dataset_builders/column_wise_builder.py +12 -8
- data_designer/engine/dataset_builders/utils/concurrency.py +6 -6
- data_designer/engine/models/facade.py +66 -9
- data_designer/engine/models/litellm_overrides.py +5 -6
- data_designer/engine/models/parsers/errors.py +2 -4
- data_designer/engine/models/parsers/parser.py +2 -3
- data_designer/engine/models/parsers/postprocessors.py +3 -4
- data_designer/engine/models/parsers/types.py +4 -4
- data_designer/engine/models/registry.py +20 -11
- data_designer/engine/models/usage.py +7 -9
- data_designer/engine/processing/ginja/ast.py +1 -2
- data_designer/engine/processing/processors/drop_columns.py +1 -1
- data_designer/engine/processing/processors/registry.py +3 -0
- data_designer/engine/processing/processors/schema_transform.py +53 -0
- data_designer/engine/processing/utils.py +40 -2
- data_designer/engine/registry/base.py +12 -12
- data_designer/engine/sampling_gen/constraints.py +1 -2
- data_designer/engine/sampling_gen/data_sources/base.py +14 -14
- data_designer/engine/sampling_gen/entities/phone_number.py +1 -2
- data_designer/engine/sampling_gen/people_gen.py +3 -7
- data_designer/engine/validators/base.py +2 -2
- data_designer/interface/data_designer.py +12 -0
- data_designer/interface/results.py +36 -0
- data_designer/logging.py +2 -2
- data_designer/plugin_manager.py +3 -3
- data_designer/plugins/plugin.py +3 -3
- data_designer/plugins/registry.py +2 -2
- {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/METADATA +9 -9
- {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/RECORD +88 -81
- {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/WHEEL +0 -0
- {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/entry_points.txt +0 -0
- {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -10,43 +10,41 @@ from data_designer.config.column_configs import (
|
|
|
10
10
|
LLMStructuredColumnConfig,
|
|
11
11
|
LLMTextColumnConfig,
|
|
12
12
|
)
|
|
13
|
-
from data_designer.config.column_types import COLUMN_TYPE_EMOJI_MAP
|
|
14
|
-
from data_designer.config.models import InferenceParameters, ModelConfig
|
|
15
13
|
from data_designer.config.utils.constants import REASONING_TRACE_COLUMN_POSTFIX
|
|
16
14
|
from data_designer.engine.column_generators.generators.base import (
|
|
17
15
|
ColumnGenerator,
|
|
18
16
|
GenerationStrategy,
|
|
19
17
|
GeneratorMetadata,
|
|
18
|
+
WithModelGeneration,
|
|
20
19
|
)
|
|
21
20
|
from data_designer.engine.column_generators.utils.prompt_renderer import (
|
|
22
21
|
PromptType,
|
|
23
22
|
RecordBasedPromptRenderer,
|
|
24
23
|
create_response_recipe,
|
|
25
24
|
)
|
|
26
|
-
from data_designer.engine.models.facade import ModelFacade
|
|
27
25
|
from data_designer.engine.models.recipes.base import ResponseRecipe
|
|
28
26
|
from data_designer.engine.processing.utils import deserialize_json_values
|
|
29
27
|
from data_designer.engine.resources.resource_provider import ResourceType
|
|
30
28
|
|
|
31
|
-
|
|
32
|
-
DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS = 0
|
|
29
|
+
logger = logging.getLogger(__name__)
|
|
33
30
|
|
|
34
31
|
|
|
35
|
-
|
|
32
|
+
DEFAULT_MAX_CONVERSATION_RESTARTS = 5
|
|
33
|
+
DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS = 0
|
|
36
34
|
|
|
37
35
|
|
|
38
|
-
class
|
|
36
|
+
class WithChatCompletionGeneration(WithModelGeneration):
|
|
39
37
|
@functools.cached_property
|
|
40
|
-
def
|
|
41
|
-
return
|
|
38
|
+
def response_recipe(self) -> ResponseRecipe:
|
|
39
|
+
return create_response_recipe(self.config, self.model_config)
|
|
42
40
|
|
|
43
|
-
@
|
|
44
|
-
def
|
|
45
|
-
return
|
|
41
|
+
@property
|
|
42
|
+
def max_conversation_correction_steps(self) -> int:
|
|
43
|
+
return DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS
|
|
46
44
|
|
|
47
|
-
@
|
|
48
|
-
def
|
|
49
|
-
return
|
|
45
|
+
@property
|
|
46
|
+
def max_conversation_restarts(self) -> int:
|
|
47
|
+
return DEFAULT_MAX_CONVERSATION_RESTARTS
|
|
50
48
|
|
|
51
49
|
@functools.cached_property
|
|
52
50
|
def prompt_renderer(self) -> RecordBasedPromptRenderer:
|
|
@@ -59,18 +57,6 @@ class WithLLMGeneration:
|
|
|
59
57
|
},
|
|
60
58
|
)
|
|
61
59
|
|
|
62
|
-
@functools.cached_property
|
|
63
|
-
def response_recipe(self) -> ResponseRecipe:
|
|
64
|
-
return create_response_recipe(self.config, self.model_config)
|
|
65
|
-
|
|
66
|
-
@property
|
|
67
|
-
def max_conversation_correction_steps(self) -> int:
|
|
68
|
-
return DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS
|
|
69
|
-
|
|
70
|
-
@property
|
|
71
|
-
def max_conversation_restarts(self) -> int:
|
|
72
|
-
return DEFAULT_MAX_CONVERSATION_RESTARTS
|
|
73
|
-
|
|
74
60
|
def generate(self, data: dict) -> dict:
|
|
75
61
|
deserialized_record = deserialize_json_values(data)
|
|
76
62
|
|
|
@@ -96,7 +82,6 @@ class WithLLMGeneration:
|
|
|
96
82
|
max_correction_steps=self.max_conversation_correction_steps,
|
|
97
83
|
max_conversation_restarts=self.max_conversation_restarts,
|
|
98
84
|
purpose=f"running generation for column '{self.config.name}'",
|
|
99
|
-
**self.inference_parameters.generate_kwargs,
|
|
100
85
|
)
|
|
101
86
|
|
|
102
87
|
data[self.config.name] = deserialize_json_values(self.response_recipe.serialize_output(response))
|
|
@@ -106,21 +91,8 @@ class WithLLMGeneration:
|
|
|
106
91
|
|
|
107
92
|
return data
|
|
108
93
|
|
|
109
|
-
def log_pre_generation(self) -> None:
|
|
110
|
-
emoji = COLUMN_TYPE_EMOJI_MAP[self.config.column_type]
|
|
111
|
-
logger.info(f"{emoji} Preparing {self.config.column_type} column generation")
|
|
112
|
-
logger.info(f" |-- column name: {self.config.name!r}")
|
|
113
|
-
logger.info(f" |-- model config:\n{self.model_config.model_dump_json(indent=4)}")
|
|
114
|
-
if self.model_config.provider is None:
|
|
115
|
-
logger.info(f" |-- default model provider: {self._get_provider_name()!r}")
|
|
116
|
-
|
|
117
|
-
def _get_provider_name(self) -> str:
|
|
118
|
-
model_alias = self.model_config.alias
|
|
119
|
-
provider = self.resource_provider.model_registry.get_model_provider(model_alias=model_alias)
|
|
120
|
-
return provider.name
|
|
121
|
-
|
|
122
94
|
|
|
123
|
-
class LLMTextCellGenerator(
|
|
95
|
+
class LLMTextCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMTextColumnConfig]):
|
|
124
96
|
@staticmethod
|
|
125
97
|
def metadata() -> GeneratorMetadata:
|
|
126
98
|
return GeneratorMetadata(
|
|
@@ -131,7 +103,7 @@ class LLMTextCellGenerator(WithLLMGeneration, ColumnGenerator[LLMTextColumnConfi
|
|
|
131
103
|
)
|
|
132
104
|
|
|
133
105
|
|
|
134
|
-
class LLMCodeCellGenerator(
|
|
106
|
+
class LLMCodeCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMCodeColumnConfig]):
|
|
135
107
|
@staticmethod
|
|
136
108
|
def metadata() -> GeneratorMetadata:
|
|
137
109
|
return GeneratorMetadata(
|
|
@@ -142,7 +114,7 @@ class LLMCodeCellGenerator(WithLLMGeneration, ColumnGenerator[LLMCodeColumnConfi
|
|
|
142
114
|
)
|
|
143
115
|
|
|
144
116
|
|
|
145
|
-
class LLMStructuredCellGenerator(
|
|
117
|
+
class LLMStructuredCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMStructuredColumnConfig]):
|
|
146
118
|
@staticmethod
|
|
147
119
|
def metadata() -> GeneratorMetadata:
|
|
148
120
|
return GeneratorMetadata(
|
|
@@ -153,7 +125,7 @@ class LLMStructuredCellGenerator(WithLLMGeneration, ColumnGenerator[LLMStructure
|
|
|
153
125
|
)
|
|
154
126
|
|
|
155
127
|
|
|
156
|
-
class LLMJudgeCellGenerator(
|
|
128
|
+
class LLMJudgeCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMJudgeColumnConfig]):
|
|
157
129
|
@staticmethod
|
|
158
130
|
def metadata() -> GeneratorMetadata:
|
|
159
131
|
return GeneratorMetadata(
|
|
@@ -163,10 +135,6 @@ class LLMJudgeCellGenerator(WithLLMGeneration, ColumnGenerator[LLMJudgeColumnCon
|
|
|
163
135
|
required_resources=[ResourceType.MODEL_REGISTRY],
|
|
164
136
|
)
|
|
165
137
|
|
|
166
|
-
@property
|
|
167
|
-
def max_conversation_correction_steps(self) -> int:
|
|
168
|
-
return DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS
|
|
169
|
-
|
|
170
138
|
@property
|
|
171
139
|
def max_conversation_restarts(self) -> int:
|
|
172
140
|
return 2 * DEFAULT_MAX_CONVERSATION_RESTARTS
|
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
|
|
4
4
|
from data_designer.config.base import ConfigBase
|
|
5
5
|
from data_designer.config.column_configs import (
|
|
6
|
+
EmbeddingColumnConfig,
|
|
6
7
|
ExpressionColumnConfig,
|
|
7
8
|
LLMCodeColumnConfig,
|
|
8
9
|
LLMJudgeColumnConfig,
|
|
@@ -12,8 +13,9 @@ from data_designer.config.column_configs import (
|
|
|
12
13
|
)
|
|
13
14
|
from data_designer.config.column_types import DataDesignerColumnType
|
|
14
15
|
from data_designer.engine.column_generators.generators.base import ColumnGenerator
|
|
16
|
+
from data_designer.engine.column_generators.generators.embedding import EmbeddingCellGenerator
|
|
15
17
|
from data_designer.engine.column_generators.generators.expression import ExpressionColumnGenerator
|
|
16
|
-
from data_designer.engine.column_generators.generators.
|
|
18
|
+
from data_designer.engine.column_generators.generators.llm_completion import (
|
|
17
19
|
LLMCodeCellGenerator,
|
|
18
20
|
LLMJudgeCellGenerator,
|
|
19
21
|
LLMStructuredCellGenerator,
|
|
@@ -40,11 +42,11 @@ def create_default_column_generator_registry(with_plugins: bool = True) -> Colum
|
|
|
40
42
|
registry.register(DataDesignerColumnType.LLM_CODE, LLMCodeCellGenerator, LLMCodeColumnConfig)
|
|
41
43
|
registry.register(DataDesignerColumnType.LLM_JUDGE, LLMJudgeCellGenerator, LLMJudgeColumnConfig)
|
|
42
44
|
registry.register(DataDesignerColumnType.EXPRESSION, ExpressionColumnGenerator, ExpressionColumnConfig)
|
|
45
|
+
registry.register(DataDesignerColumnType.EMBEDDING, EmbeddingCellGenerator, EmbeddingColumnConfig)
|
|
43
46
|
registry.register(DataDesignerColumnType.SAMPLER, SamplerColumnGenerator, SamplerMultiColumnConfig)
|
|
44
47
|
registry.register(DataDesignerColumnType.SEED_DATASET, SeedDatasetColumnGenerator, SeedDatasetMultiColumnConfig)
|
|
45
48
|
registry.register(DataDesignerColumnType.VALIDATION, ValidationColumnGenerator, ValidationColumnConfig)
|
|
46
49
|
registry.register(DataDesignerColumnType.LLM_STRUCTURED, LLMStructuredCellGenerator, LLMStructuredColumnConfig)
|
|
47
|
-
|
|
48
50
|
if with_plugins:
|
|
49
51
|
for plugin in PluginRegistry().get_plugins(PluginType.COLUMN_GENERATOR):
|
|
50
52
|
registry.register(
|
|
@@ -2,7 +2,6 @@
|
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
4
|
from enum import Enum
|
|
5
|
-
from typing import Type
|
|
6
5
|
|
|
7
6
|
from pydantic import BaseModel, ConfigDict, Field, create_model
|
|
8
7
|
|
|
@@ -19,7 +18,7 @@ class BaseJudgeResponse(BaseModel):
|
|
|
19
18
|
reasoning: str = Field(..., description="Reasoning for the assigned score.")
|
|
20
19
|
|
|
21
20
|
|
|
22
|
-
def _stringify_scoring(options: dict, enum_type:
|
|
21
|
+
def _stringify_scoring(options: dict, enum_type: type[Enum]) -> str:
|
|
23
22
|
"""Convert score descriptions into a single text block."""
|
|
24
23
|
list_block = "\n".join(
|
|
25
24
|
[SCORING_FORMAT.format(score=score, description=description) for score, description in options.items()]
|
|
@@ -27,7 +26,7 @@ def _stringify_scoring(options: dict, enum_type: Type[Enum]) -> str:
|
|
|
27
26
|
return SCORE_FIELD_DESCRIPTION_FORMAT.format(enum_name=enum_type.__name__, scoring=list_block)
|
|
28
27
|
|
|
29
28
|
|
|
30
|
-
def create_judge_response_model(score: Score) ->
|
|
29
|
+
def create_judge_response_model(score: Score) -> type[BaseJudgeResponse]:
|
|
31
30
|
"""Create a JudgeResponse data type."""
|
|
32
31
|
enum_members = {}
|
|
33
32
|
for option in score.options.keys():
|
|
@@ -46,12 +45,12 @@ def create_judge_response_model(score: Score) -> Type[BaseJudgeResponse]:
|
|
|
46
45
|
|
|
47
46
|
|
|
48
47
|
def create_judge_structured_output_model(
|
|
49
|
-
judge_responses: list[
|
|
50
|
-
) ->
|
|
48
|
+
judge_responses: list[type[BaseJudgeResponse]],
|
|
49
|
+
) -> type[BaseModel]:
|
|
51
50
|
"""Create a JudgeStructuredOutput class dynamically."""
|
|
52
51
|
return create_model(
|
|
53
52
|
"JudgeStructuredOutput",
|
|
54
53
|
__doc__=f"Response schema for scores with the following names: {[response.__name__ for response in judge_responses]}.",
|
|
55
54
|
__base__=BaseModel,
|
|
56
|
-
**{response.__name__
|
|
55
|
+
**{response.__name__: (response, ...) for response in judge_responses},
|
|
57
56
|
)
|
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
|
|
4
4
|
from abc import ABC, abstractmethod
|
|
5
5
|
from pathlib import Path
|
|
6
|
-
from typing import Generic,
|
|
6
|
+
from typing import Generic, TypeVar, get_origin
|
|
7
7
|
|
|
8
8
|
import pandas as pd
|
|
9
9
|
|
|
@@ -30,7 +30,7 @@ class ConfigurableTask(ABC, Generic[TaskConfigT]):
|
|
|
30
30
|
self._initialize()
|
|
31
31
|
|
|
32
32
|
@classmethod
|
|
33
|
-
def get_config_type(cls) ->
|
|
33
|
+
def get_config_type(cls) -> type[TaskConfigT]:
|
|
34
34
|
for base in cls.__orig_bases__:
|
|
35
35
|
if hasattr(base, "__args__") and len(base.__args__) == 1:
|
|
36
36
|
arg = base.__args__[0]
|
|
@@ -7,7 +7,6 @@ import shutil
|
|
|
7
7
|
from datetime import datetime
|
|
8
8
|
from functools import cached_property
|
|
9
9
|
from pathlib import Path
|
|
10
|
-
from typing import Union
|
|
11
10
|
|
|
12
11
|
import pandas as pd
|
|
13
12
|
from pydantic import BaseModel, field_validator, model_validator
|
|
@@ -25,6 +24,7 @@ class BatchStage(StrEnum):
|
|
|
25
24
|
PARTIAL_RESULT = "partial_results_path"
|
|
26
25
|
FINAL_RESULT = "final_dataset_path"
|
|
27
26
|
DROPPED_COLUMNS = "dropped_columns_dataset_path"
|
|
27
|
+
PROCESSORS_OUTPUTS = "processors_outputs_path"
|
|
28
28
|
|
|
29
29
|
|
|
30
30
|
class ArtifactStorage(BaseModel):
|
|
@@ -33,6 +33,7 @@ class ArtifactStorage(BaseModel):
|
|
|
33
33
|
final_dataset_folder_name: str = "parquet-files"
|
|
34
34
|
partial_results_folder_name: str = "tmp-partial-parquet-files"
|
|
35
35
|
dropped_columns_folder_name: str = "dropped-columns-parquet-files"
|
|
36
|
+
processors_outputs_folder_name: str = "processors-files"
|
|
36
37
|
|
|
37
38
|
@property
|
|
38
39
|
def artifact_path_exists(self) -> bool:
|
|
@@ -70,8 +71,12 @@ class ArtifactStorage(BaseModel):
|
|
|
70
71
|
def partial_results_path(self) -> Path:
|
|
71
72
|
return self.base_dataset_path / self.partial_results_folder_name
|
|
72
73
|
|
|
74
|
+
@property
|
|
75
|
+
def processors_outputs_path(self) -> Path:
|
|
76
|
+
return self.base_dataset_path / self.processors_outputs_folder_name
|
|
77
|
+
|
|
73
78
|
@field_validator("artifact_path")
|
|
74
|
-
def validate_artifact_path(cls, v:
|
|
79
|
+
def validate_artifact_path(cls, v: Path | str) -> Path:
|
|
75
80
|
v = Path(v)
|
|
76
81
|
if not v.is_dir():
|
|
77
82
|
raise ArtifactStorageError("Artifact path must exist and be a directory")
|
|
@@ -84,6 +89,7 @@ class ArtifactStorage(BaseModel):
|
|
|
84
89
|
self.final_dataset_folder_name,
|
|
85
90
|
self.partial_results_folder_name,
|
|
86
91
|
self.dropped_columns_folder_name,
|
|
92
|
+
self.processors_outputs_folder_name,
|
|
87
93
|
]
|
|
88
94
|
|
|
89
95
|
for name in folder_names:
|
|
@@ -169,9 +175,10 @@ class ArtifactStorage(BaseModel):
|
|
|
169
175
|
batch_number: int,
|
|
170
176
|
dataframe: pd.DataFrame,
|
|
171
177
|
batch_stage: BatchStage,
|
|
178
|
+
subfolder: str | None = None,
|
|
172
179
|
) -> Path:
|
|
173
180
|
file_path = self.create_batch_file_path(batch_number, batch_stage=batch_stage)
|
|
174
|
-
self.write_parquet_file(file_path.name, dataframe, batch_stage)
|
|
181
|
+
self.write_parquet_file(file_path.name, dataframe, batch_stage, subfolder=subfolder)
|
|
175
182
|
return file_path
|
|
176
183
|
|
|
177
184
|
def write_parquet_file(
|
|
@@ -179,9 +186,11 @@ class ArtifactStorage(BaseModel):
|
|
|
179
186
|
parquet_file_name: str,
|
|
180
187
|
dataframe: pd.DataFrame,
|
|
181
188
|
batch_stage: BatchStage,
|
|
189
|
+
subfolder: str | None = None,
|
|
182
190
|
) -> Path:
|
|
183
|
-
|
|
184
|
-
|
|
191
|
+
subfolder = subfolder or ""
|
|
192
|
+
self.mkdir_if_needed(self._get_stage_path(batch_stage) / subfolder)
|
|
193
|
+
file_path = self._get_stage_path(batch_stage) / subfolder / parquet_file_name
|
|
185
194
|
dataframe.to_parquet(file_path, index=False)
|
|
186
195
|
return file_path
|
|
187
196
|
|
|
@@ -10,15 +10,18 @@ from typing import Callable
|
|
|
10
10
|
|
|
11
11
|
import pandas as pd
|
|
12
12
|
|
|
13
|
-
from data_designer.config.column_types import ColumnConfigT,
|
|
13
|
+
from data_designer.config.column_types import ColumnConfigT, column_type_is_model_generated
|
|
14
14
|
from data_designer.config.dataset_builders import BuildStage
|
|
15
15
|
from data_designer.config.processors import (
|
|
16
16
|
DropColumnsProcessorConfig,
|
|
17
17
|
ProcessorConfig,
|
|
18
18
|
ProcessorType,
|
|
19
19
|
)
|
|
20
|
-
from data_designer.engine.column_generators.generators.base import
|
|
21
|
-
|
|
20
|
+
from data_designer.engine.column_generators.generators.base import (
|
|
21
|
+
ColumnGenerator,
|
|
22
|
+
GenerationStrategy,
|
|
23
|
+
WithModelGeneration,
|
|
24
|
+
)
|
|
22
25
|
from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
|
|
23
26
|
from data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError
|
|
24
27
|
from data_designer.engine.dataset_builders.multi_column_configs import (
|
|
@@ -72,7 +75,7 @@ class ColumnWiseDatasetBuilder:
|
|
|
72
75
|
|
|
73
76
|
@functools.cached_property
|
|
74
77
|
def llm_generated_column_configs(self) -> list[ColumnConfigT]:
|
|
75
|
-
return [config for config in self.single_column_configs if
|
|
78
|
+
return [config for config in self.single_column_configs if column_type_is_model_generated(config.column_type)]
|
|
76
79
|
|
|
77
80
|
def build(
|
|
78
81
|
self,
|
|
@@ -169,7 +172,7 @@ class ColumnWiseDatasetBuilder:
|
|
|
169
172
|
|
|
170
173
|
def _run_cell_by_cell_generator(self, generator: ColumnGenerator) -> None:
|
|
171
174
|
max_workers = MAX_CONCURRENCY_PER_NON_LLM_GENERATOR
|
|
172
|
-
if isinstance(generator,
|
|
175
|
+
if isinstance(generator, WithModelGeneration):
|
|
173
176
|
max_workers = generator.inference_parameters.max_parallel_requests
|
|
174
177
|
self._fan_out_with_threads(generator, max_workers=max_workers)
|
|
175
178
|
|
|
@@ -178,12 +181,12 @@ class ColumnWiseDatasetBuilder:
|
|
|
178
181
|
self.batch_manager.update_records(df.to_dict(orient="records"))
|
|
179
182
|
|
|
180
183
|
def _run_model_health_check_if_needed(self) -> bool:
|
|
181
|
-
if any(
|
|
184
|
+
if any(column_type_is_model_generated(config.column_type) for config in self.single_column_configs):
|
|
182
185
|
self._resource_provider.model_registry.run_health_check(
|
|
183
|
-
set(config.model_alias for config in self.llm_generated_column_configs)
|
|
186
|
+
list(set(config.model_alias for config in self.llm_generated_column_configs))
|
|
184
187
|
)
|
|
185
188
|
|
|
186
|
-
def _fan_out_with_threads(self, generator:
|
|
189
|
+
def _fan_out_with_threads(self, generator: WithModelGeneration, max_workers: int) -> None:
|
|
187
190
|
if generator.generation_strategy != GenerationStrategy.CELL_BY_CELL:
|
|
188
191
|
raise DatasetGenerationError(
|
|
189
192
|
f"Generator {generator.metadata().name} is not a {GenerationStrategy.CELL_BY_CELL} "
|
|
@@ -244,6 +247,7 @@ class ColumnWiseDatasetBuilder:
|
|
|
244
247
|
processors[BuildStage.POST_BATCH].append( # as post-batch by default
|
|
245
248
|
DropColumnsProcessor(
|
|
246
249
|
config=DropColumnsProcessorConfig(
|
|
250
|
+
name="default_drop_columns_processor",
|
|
247
251
|
column_names=columns_to_drop,
|
|
248
252
|
build_stage=BuildStage.POST_BATCH,
|
|
249
253
|
),
|
|
@@ -8,7 +8,7 @@ import json
|
|
|
8
8
|
import logging
|
|
9
9
|
from concurrent.futures import Future, ThreadPoolExecutor
|
|
10
10
|
from threading import Lock, Semaphore
|
|
11
|
-
from typing import Any,
|
|
11
|
+
from typing import Any, Protocol
|
|
12
12
|
|
|
13
13
|
from pydantic import BaseModel, Field
|
|
14
14
|
|
|
@@ -46,13 +46,13 @@ class ExecutorResults(BaseModel):
|
|
|
46
46
|
class CallbackWithContext(Protocol):
|
|
47
47
|
"""Executor callback functions must accept a context kw argument."""
|
|
48
48
|
|
|
49
|
-
def __call__(self, result: Any, *, context:
|
|
49
|
+
def __call__(self, result: Any, *, context: dict | None = None) -> Any: ...
|
|
50
50
|
|
|
51
51
|
|
|
52
52
|
class ErrorCallbackWithContext(Protocol):
|
|
53
53
|
"""Error callbacks take the Exception instance and context."""
|
|
54
54
|
|
|
55
|
-
def __call__(self, exc: Exception, *, context:
|
|
55
|
+
def __call__(self, exc: Exception, *, context: dict | None = None) -> Any: ...
|
|
56
56
|
|
|
57
57
|
|
|
58
58
|
class ConcurrentThreadExecutor:
|
|
@@ -92,8 +92,8 @@ class ConcurrentThreadExecutor:
|
|
|
92
92
|
*,
|
|
93
93
|
max_workers: int,
|
|
94
94
|
column_name: str,
|
|
95
|
-
result_callback:
|
|
96
|
-
error_callback:
|
|
95
|
+
result_callback: CallbackWithContext | None = None,
|
|
96
|
+
error_callback: ErrorCallbackWithContext | None = None,
|
|
97
97
|
shutdown_error_rate: float = 0.50,
|
|
98
98
|
shutdown_error_window: int = 10,
|
|
99
99
|
):
|
|
@@ -136,7 +136,7 @@ class ConcurrentThreadExecutor:
|
|
|
136
136
|
)
|
|
137
137
|
)
|
|
138
138
|
|
|
139
|
-
def submit(self, fn, *args, context:
|
|
139
|
+
def submit(self, fn, *args, context: dict | None = None, **kwargs) -> None:
|
|
140
140
|
if self._executor is None:
|
|
141
141
|
raise RuntimeError("Executor is not initialized, this class should be used as a context manager.")
|
|
142
142
|
|
|
@@ -9,9 +9,9 @@ from copy import deepcopy
|
|
|
9
9
|
from typing import Any
|
|
10
10
|
|
|
11
11
|
from litellm.types.router import DeploymentTypedDict, LiteLLM_Params
|
|
12
|
-
from litellm.types.utils import ModelResponse
|
|
12
|
+
from litellm.types.utils import EmbeddingResponse, ModelResponse
|
|
13
13
|
|
|
14
|
-
from data_designer.config.models import ModelConfig, ModelProvider
|
|
14
|
+
from data_designer.config.models import GenerationType, ModelConfig, ModelProvider
|
|
15
15
|
from data_designer.engine.model_provider import ModelProviderRegistry
|
|
16
16
|
from data_designer.engine.models.errors import (
|
|
17
17
|
GenerationValidationFailureError,
|
|
@@ -49,6 +49,10 @@ class ModelFacade:
|
|
|
49
49
|
def model_provider(self) -> ModelProvider:
|
|
50
50
|
return self._model_provider_registry.get_provider(self._model_config.provider)
|
|
51
51
|
|
|
52
|
+
@property
|
|
53
|
+
def model_generation_type(self) -> GenerationType:
|
|
54
|
+
return self._model_config.generation_type
|
|
55
|
+
|
|
52
56
|
@property
|
|
53
57
|
def model_provider_name(self) -> str:
|
|
54
58
|
return self.model_provider.name
|
|
@@ -64,13 +68,12 @@ class ModelFacade:
|
|
|
64
68
|
def completion(self, messages: list[dict[str, str]], skip_usage_tracking: bool = False, **kwargs) -> ModelResponse:
|
|
65
69
|
logger.debug(
|
|
66
70
|
f"Prompting model {self.model_name!r}...",
|
|
67
|
-
extra={"model": self.model_name, "messages": messages
|
|
71
|
+
extra={"model": self.model_name, "messages": messages},
|
|
68
72
|
)
|
|
69
73
|
response = None
|
|
70
|
-
|
|
71
|
-
kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body}
|
|
74
|
+
kwargs = self.consolidate_kwargs(**kwargs)
|
|
72
75
|
try:
|
|
73
|
-
response = self._router.completion(self.model_name, messages, **kwargs)
|
|
76
|
+
response = self._router.completion(model=self.model_name, messages=messages, **kwargs)
|
|
74
77
|
logger.debug(
|
|
75
78
|
f"Received completion from model {self.model_name!r}",
|
|
76
79
|
extra={
|
|
@@ -84,9 +87,50 @@ class ModelFacade:
|
|
|
84
87
|
except Exception as e:
|
|
85
88
|
raise e
|
|
86
89
|
finally:
|
|
87
|
-
if not skip_usage_tracking:
|
|
90
|
+
if not skip_usage_tracking and response is not None:
|
|
88
91
|
self._track_usage(response)
|
|
89
92
|
|
|
93
|
+
def consolidate_kwargs(self, **kwargs) -> dict[str, Any]:
|
|
94
|
+
# Remove purpose from kwargs to avoid passing it to the model
|
|
95
|
+
kwargs.pop("purpose", None)
|
|
96
|
+
kwargs = {**self._model_config.inference_parameters.generate_kwargs, **kwargs}
|
|
97
|
+
if self.model_provider.extra_body:
|
|
98
|
+
kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body}
|
|
99
|
+
return kwargs
|
|
100
|
+
|
|
101
|
+
@catch_llm_exceptions
|
|
102
|
+
def generate_text_embeddings(
|
|
103
|
+
self, input_texts: list[str], skip_usage_tracking: bool = False, **kwargs
|
|
104
|
+
) -> list[list[float]]:
|
|
105
|
+
logger.debug(
|
|
106
|
+
f"Generating embeddings with model {self.model_name!r}...",
|
|
107
|
+
extra={
|
|
108
|
+
"model": self.model_name,
|
|
109
|
+
"input_count": len(input_texts),
|
|
110
|
+
},
|
|
111
|
+
)
|
|
112
|
+
kwargs = self.consolidate_kwargs(**kwargs)
|
|
113
|
+
response = None
|
|
114
|
+
try:
|
|
115
|
+
response = self._router.embedding(model=self.model_name, input=input_texts, **kwargs)
|
|
116
|
+
logger.debug(
|
|
117
|
+
f"Received embeddings from model {self.model_name!r}",
|
|
118
|
+
extra={
|
|
119
|
+
"model": self.model_name,
|
|
120
|
+
"embedding_count": len(response.data) if response.data else 0,
|
|
121
|
+
"usage": self._usage_stats.model_dump(),
|
|
122
|
+
},
|
|
123
|
+
)
|
|
124
|
+
if response.data and len(response.data) == len(input_texts):
|
|
125
|
+
return [data["embedding"] for data in response.data]
|
|
126
|
+
else:
|
|
127
|
+
raise ValueError(f"Expected {len(input_texts)} embeddings, but received {len(response.data)}")
|
|
128
|
+
except Exception as e:
|
|
129
|
+
raise e
|
|
130
|
+
finally:
|
|
131
|
+
if not skip_usage_tracking and response is not None:
|
|
132
|
+
self._track_usage_from_embedding(response)
|
|
133
|
+
|
|
90
134
|
@catch_llm_exceptions
|
|
91
135
|
def generate(
|
|
92
136
|
self,
|
|
@@ -218,8 +262,21 @@ class ModelFacade:
|
|
|
218
262
|
):
|
|
219
263
|
self._usage_stats.extend(
|
|
220
264
|
token_usage=TokenUsageStats(
|
|
221
|
-
|
|
222
|
-
|
|
265
|
+
input_tokens=response.usage.prompt_tokens,
|
|
266
|
+
output_tokens=response.usage.completion_tokens,
|
|
267
|
+
),
|
|
268
|
+
request_usage=RequestUsageStats(successful_requests=1, failed_requests=0),
|
|
269
|
+
)
|
|
270
|
+
|
|
271
|
+
def _track_usage_from_embedding(self, response: EmbeddingResponse | None) -> None:
|
|
272
|
+
if response is None:
|
|
273
|
+
self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1))
|
|
274
|
+
return
|
|
275
|
+
if response.usage is not None and response.usage.prompt_tokens is not None:
|
|
276
|
+
self._usage_stats.extend(
|
|
277
|
+
token_usage=TokenUsageStats(
|
|
278
|
+
input_tokens=response.usage.prompt_tokens,
|
|
279
|
+
output_tokens=0,
|
|
223
280
|
),
|
|
224
281
|
request_usage=RequestUsageStats(successful_requests=1, failed_requests=0),
|
|
225
282
|
)
|
|
@@ -5,7 +5,6 @@ from __future__ import annotations
|
|
|
5
5
|
|
|
6
6
|
import random
|
|
7
7
|
import threading
|
|
8
|
-
from typing import Optional, Union
|
|
9
8
|
|
|
10
9
|
import httpx
|
|
11
10
|
import litellm
|
|
@@ -90,7 +89,7 @@ class CustomRouter(Router):
|
|
|
90
89
|
self._initial_retry_after_s = initial_retry_after_s
|
|
91
90
|
self._jitter_pct = jitter_pct
|
|
92
91
|
|
|
93
|
-
def _extract_retry_delay_from_headers(self, e: Exception) ->
|
|
92
|
+
def _extract_retry_delay_from_headers(self, e: Exception) -> int | float | None:
|
|
94
93
|
"""
|
|
95
94
|
Most of this code logic was extracted directly from the parent
|
|
96
95
|
`Router`'s `_time_to_sleep_before_retry` function. Our override
|
|
@@ -99,7 +98,7 @@ class CustomRouter(Router):
|
|
|
99
98
|
return this info, we'll simply use that retry value returned here.
|
|
100
99
|
"""
|
|
101
100
|
|
|
102
|
-
response_headers:
|
|
101
|
+
response_headers: httpx.Headers | None = None
|
|
103
102
|
if hasattr(e, "response") and hasattr(e.response, "headers"): # type: ignore
|
|
104
103
|
response_headers = e.response.headers # type: ignore
|
|
105
104
|
if hasattr(e, "litellm_response_headers"):
|
|
@@ -119,9 +118,9 @@ class CustomRouter(Router):
|
|
|
119
118
|
e: Exception,
|
|
120
119
|
remaining_retries: int,
|
|
121
120
|
num_retries: int,
|
|
122
|
-
healthy_deployments:
|
|
123
|
-
all_deployments:
|
|
124
|
-
) ->
|
|
121
|
+
healthy_deployments: list | None = None,
|
|
122
|
+
all_deployments: list | None = None,
|
|
123
|
+
) -> int | float:
|
|
125
124
|
"""
|
|
126
125
|
Implements exponential backoff for retries.
|
|
127
126
|
|
|
@@ -1,8 +1,6 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
-
from typing import Optional
|
|
5
|
-
|
|
6
4
|
|
|
7
5
|
class ParserException(Exception):
|
|
8
6
|
"""Identifies errors resulting from generic parser errors.
|
|
@@ -12,7 +10,7 @@ class ParserException(Exception):
|
|
|
12
10
|
attempted to parse.
|
|
13
11
|
"""
|
|
14
12
|
|
|
15
|
-
source:
|
|
13
|
+
source: str | None
|
|
16
14
|
|
|
17
15
|
@staticmethod
|
|
18
16
|
def _log_format(source: str) -> str:
|
|
@@ -24,7 +22,7 @@ class ParserException(Exception):
|
|
|
24
22
|
# return f"<source>{source}</source>"
|
|
25
23
|
return ""
|
|
26
24
|
|
|
27
|
-
def __init__(self, msg:
|
|
25
|
+
def __init__(self, msg: str | None = None, source: str | None = None):
|
|
28
26
|
msg = "" if msg is None else msg.strip()
|
|
29
27
|
|
|
30
28
|
if source is not None:
|
|
@@ -2,7 +2,6 @@
|
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
4
|
from functools import reduce
|
|
5
|
-
from typing import Optional
|
|
6
5
|
|
|
7
6
|
import marko
|
|
8
7
|
from lxml import etree
|
|
@@ -105,8 +104,8 @@ class LLMResponseParser:
|
|
|
105
104
|
|
|
106
105
|
def __init__(
|
|
107
106
|
self,
|
|
108
|
-
tag_parsers:
|
|
109
|
-
postprocessors:
|
|
107
|
+
tag_parsers: dict[str, TagParser] | None = None,
|
|
108
|
+
postprocessors: list[PostProcessor] | None = None,
|
|
110
109
|
):
|
|
111
110
|
"""
|
|
112
111
|
Initializes the LLMResponseParser with optional tag parsers and post-processors.
|
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
-
from typing import Optional, Type
|
|
5
4
|
|
|
6
5
|
import json_repair
|
|
7
6
|
from pydantic import BaseModel, ValidationError
|
|
@@ -60,12 +59,12 @@ def deserialize_json_code(
|
|
|
60
59
|
|
|
61
60
|
|
|
62
61
|
class RealizePydanticTypes:
|
|
63
|
-
types: list[
|
|
62
|
+
types: list[type[BaseModel]]
|
|
64
63
|
|
|
65
|
-
def __init__(self, types: list[
|
|
64
|
+
def __init__(self, types: list[type[BaseModel]]):
|
|
66
65
|
self.types = types
|
|
67
66
|
|
|
68
|
-
def _fit_types(self, obj: dict) ->
|
|
67
|
+
def _fit_types(self, obj: dict) -> BaseModel | None:
|
|
69
68
|
final_obj = None
|
|
70
69
|
|
|
71
70
|
for t in self.types:
|