data-designer 0.3.8rc2__py3-none-any.whl → 0.4.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/cli/commands/__init__.py +1 -1
- data_designer/interface/__init__.py +21 -1
- data_designer/{_version.py → interface/_version.py} +2 -2
- data_designer/interface/data_designer.py +1 -7
- {data_designer-0.3.8rc2.dist-info → data_designer-0.4.0rc1.dist-info}/METADATA +10 -42
- data_designer-0.4.0rc1.dist-info/RECORD +39 -0
- data_designer/__init__.py +0 -17
- data_designer/config/__init__.py +0 -2
- data_designer/config/analysis/__init__.py +0 -2
- data_designer/config/analysis/column_profilers.py +0 -159
- data_designer/config/analysis/column_statistics.py +0 -421
- data_designer/config/analysis/dataset_profiler.py +0 -84
- data_designer/config/analysis/utils/errors.py +0 -10
- data_designer/config/analysis/utils/reporting.py +0 -192
- data_designer/config/base.py +0 -69
- data_designer/config/column_configs.py +0 -470
- data_designer/config/column_types.py +0 -141
- data_designer/config/config_builder.py +0 -595
- data_designer/config/data_designer_config.py +0 -40
- data_designer/config/dataset_builders.py +0 -13
- data_designer/config/dataset_metadata.py +0 -18
- data_designer/config/default_model_settings.py +0 -129
- data_designer/config/errors.py +0 -24
- data_designer/config/exports.py +0 -145
- data_designer/config/interface.py +0 -55
- data_designer/config/models.py +0 -455
- data_designer/config/preview_results.py +0 -41
- data_designer/config/processors.py +0 -148
- data_designer/config/run_config.py +0 -51
- data_designer/config/sampler_constraints.py +0 -52
- data_designer/config/sampler_params.py +0 -639
- data_designer/config/seed.py +0 -116
- data_designer/config/seed_source.py +0 -84
- data_designer/config/seed_source_types.py +0 -19
- data_designer/config/utils/code_lang.py +0 -82
- data_designer/config/utils/constants.py +0 -363
- data_designer/config/utils/errors.py +0 -21
- data_designer/config/utils/info.py +0 -94
- data_designer/config/utils/io_helpers.py +0 -258
- data_designer/config/utils/misc.py +0 -78
- data_designer/config/utils/numerical_helpers.py +0 -30
- data_designer/config/utils/type_helpers.py +0 -106
- data_designer/config/utils/visualization.py +0 -482
- data_designer/config/validator_params.py +0 -94
- data_designer/engine/__init__.py +0 -2
- data_designer/engine/analysis/column_profilers/base.py +0 -49
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +0 -153
- data_designer/engine/analysis/column_profilers/registry.py +0 -22
- data_designer/engine/analysis/column_statistics.py +0 -145
- data_designer/engine/analysis/dataset_profiler.py +0 -149
- data_designer/engine/analysis/errors.py +0 -9
- data_designer/engine/analysis/utils/column_statistics_calculations.py +0 -234
- data_designer/engine/analysis/utils/judge_score_processing.py +0 -132
- data_designer/engine/column_generators/__init__.py +0 -2
- data_designer/engine/column_generators/generators/__init__.py +0 -2
- data_designer/engine/column_generators/generators/base.py +0 -122
- data_designer/engine/column_generators/generators/embedding.py +0 -35
- data_designer/engine/column_generators/generators/expression.py +0 -55
- data_designer/engine/column_generators/generators/llm_completion.py +0 -113
- data_designer/engine/column_generators/generators/samplers.py +0 -69
- data_designer/engine/column_generators/generators/seed_dataset.py +0 -144
- data_designer/engine/column_generators/generators/validation.py +0 -140
- data_designer/engine/column_generators/registry.py +0 -60
- data_designer/engine/column_generators/utils/errors.py +0 -15
- data_designer/engine/column_generators/utils/generator_classification.py +0 -43
- data_designer/engine/column_generators/utils/judge_score_factory.py +0 -58
- data_designer/engine/column_generators/utils/prompt_renderer.py +0 -100
- data_designer/engine/compiler.py +0 -97
- data_designer/engine/configurable_task.py +0 -71
- data_designer/engine/dataset_builders/artifact_storage.py +0 -283
- data_designer/engine/dataset_builders/column_wise_builder.py +0 -335
- data_designer/engine/dataset_builders/errors.py +0 -15
- data_designer/engine/dataset_builders/multi_column_configs.py +0 -46
- data_designer/engine/dataset_builders/utils/__init__.py +0 -2
- data_designer/engine/dataset_builders/utils/concurrency.py +0 -212
- data_designer/engine/dataset_builders/utils/config_compiler.py +0 -62
- data_designer/engine/dataset_builders/utils/dag.py +0 -62
- data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +0 -200
- data_designer/engine/dataset_builders/utils/errors.py +0 -15
- data_designer/engine/errors.py +0 -51
- data_designer/engine/model_provider.py +0 -77
- data_designer/engine/models/__init__.py +0 -2
- data_designer/engine/models/errors.py +0 -300
- data_designer/engine/models/facade.py +0 -287
- data_designer/engine/models/factory.py +0 -42
- data_designer/engine/models/litellm_overrides.py +0 -179
- data_designer/engine/models/parsers/__init__.py +0 -2
- data_designer/engine/models/parsers/errors.py +0 -34
- data_designer/engine/models/parsers/parser.py +0 -235
- data_designer/engine/models/parsers/postprocessors.py +0 -93
- data_designer/engine/models/parsers/tag_parsers.py +0 -62
- data_designer/engine/models/parsers/types.py +0 -84
- data_designer/engine/models/recipes/base.py +0 -81
- data_designer/engine/models/recipes/response_recipes.py +0 -293
- data_designer/engine/models/registry.py +0 -146
- data_designer/engine/models/telemetry.py +0 -359
- data_designer/engine/models/usage.py +0 -73
- data_designer/engine/models/utils.py +0 -38
- data_designer/engine/processing/ginja/__init__.py +0 -2
- data_designer/engine/processing/ginja/ast.py +0 -65
- data_designer/engine/processing/ginja/environment.py +0 -463
- data_designer/engine/processing/ginja/exceptions.py +0 -56
- data_designer/engine/processing/ginja/record.py +0 -32
- data_designer/engine/processing/gsonschema/__init__.py +0 -2
- data_designer/engine/processing/gsonschema/exceptions.py +0 -15
- data_designer/engine/processing/gsonschema/schema_transformers.py +0 -83
- data_designer/engine/processing/gsonschema/types.py +0 -10
- data_designer/engine/processing/gsonschema/validators.py +0 -202
- data_designer/engine/processing/processors/base.py +0 -13
- data_designer/engine/processing/processors/drop_columns.py +0 -42
- data_designer/engine/processing/processors/registry.py +0 -25
- data_designer/engine/processing/processors/schema_transform.py +0 -49
- data_designer/engine/processing/utils.py +0 -169
- data_designer/engine/registry/base.py +0 -99
- data_designer/engine/registry/data_designer_registry.py +0 -39
- data_designer/engine/registry/errors.py +0 -12
- data_designer/engine/resources/managed_dataset_generator.py +0 -39
- data_designer/engine/resources/managed_dataset_repository.py +0 -197
- data_designer/engine/resources/managed_storage.py +0 -65
- data_designer/engine/resources/resource_provider.py +0 -77
- data_designer/engine/resources/seed_reader.py +0 -154
- data_designer/engine/sampling_gen/column.py +0 -91
- data_designer/engine/sampling_gen/constraints.py +0 -100
- data_designer/engine/sampling_gen/data_sources/base.py +0 -217
- data_designer/engine/sampling_gen/data_sources/errors.py +0 -12
- data_designer/engine/sampling_gen/data_sources/sources.py +0 -347
- data_designer/engine/sampling_gen/entities/__init__.py +0 -2
- data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
- data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +0 -86
- data_designer/engine/sampling_gen/entities/email_address_utils.py +0 -171
- data_designer/engine/sampling_gen/entities/errors.py +0 -10
- data_designer/engine/sampling_gen/entities/national_id_utils.py +0 -102
- data_designer/engine/sampling_gen/entities/person.py +0 -144
- data_designer/engine/sampling_gen/entities/phone_number.py +0 -128
- data_designer/engine/sampling_gen/errors.py +0 -26
- data_designer/engine/sampling_gen/generator.py +0 -122
- data_designer/engine/sampling_gen/jinja_utils.py +0 -64
- data_designer/engine/sampling_gen/people_gen.py +0 -199
- data_designer/engine/sampling_gen/person_constants.py +0 -56
- data_designer/engine/sampling_gen/schema.py +0 -147
- data_designer/engine/sampling_gen/schema_builder.py +0 -61
- data_designer/engine/sampling_gen/utils.py +0 -46
- data_designer/engine/secret_resolver.py +0 -82
- data_designer/engine/validation.py +0 -367
- data_designer/engine/validators/__init__.py +0 -19
- data_designer/engine/validators/base.py +0 -38
- data_designer/engine/validators/local_callable.py +0 -39
- data_designer/engine/validators/python.py +0 -254
- data_designer/engine/validators/remote.py +0 -89
- data_designer/engine/validators/sql.py +0 -65
- data_designer/errors.py +0 -7
- data_designer/essentials/__init__.py +0 -33
- data_designer/lazy_heavy_imports.py +0 -54
- data_designer/logging.py +0 -163
- data_designer/plugin_manager.py +0 -78
- data_designer/plugins/__init__.py +0 -8
- data_designer/plugins/errors.py +0 -15
- data_designer/plugins/plugin.py +0 -141
- data_designer/plugins/registry.py +0 -88
- data_designer/plugins/testing/__init__.py +0 -10
- data_designer/plugins/testing/stubs.py +0 -116
- data_designer/plugins/testing/utils.py +0 -20
- data_designer-0.3.8rc2.dist-info/RECORD +0 -196
- data_designer-0.3.8rc2.dist-info/licenses/LICENSE +0 -201
- {data_designer-0.3.8rc2.dist-info → data_designer-0.4.0rc1.dist-info}/WHEEL +0 -0
- {data_designer-0.3.8rc2.dist-info → data_designer-0.4.0rc1.dist-info}/entry_points.txt +0 -0
|
@@ -1,113 +0,0 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
-
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
-
|
|
4
|
-
from __future__ import annotations
|
|
5
|
-
|
|
6
|
-
import functools
|
|
7
|
-
import logging
|
|
8
|
-
|
|
9
|
-
from data_designer.config.column_configs import (
|
|
10
|
-
LLMCodeColumnConfig,
|
|
11
|
-
LLMJudgeColumnConfig,
|
|
12
|
-
LLMStructuredColumnConfig,
|
|
13
|
-
LLMTextColumnConfig,
|
|
14
|
-
)
|
|
15
|
-
from data_designer.config.utils.constants import REASONING_TRACE_COLUMN_POSTFIX
|
|
16
|
-
from data_designer.engine.column_generators.generators.base import ColumnGeneratorWithModel, GenerationStrategy
|
|
17
|
-
from data_designer.engine.column_generators.utils.prompt_renderer import (
|
|
18
|
-
PromptType,
|
|
19
|
-
RecordBasedPromptRenderer,
|
|
20
|
-
create_response_recipe,
|
|
21
|
-
)
|
|
22
|
-
from data_designer.engine.configurable_task import TaskConfigT
|
|
23
|
-
from data_designer.engine.models.recipes.base import ResponseRecipe
|
|
24
|
-
from data_designer.engine.processing.utils import deserialize_json_values
|
|
25
|
-
|
|
26
|
-
logger = logging.getLogger(__name__)


class ColumnGeneratorWithModelChatCompletion(ColumnGeneratorWithModel[TaskConfigT]):
    """Base generator that fills one cell per record with a chat-completion model call.

    Subclasses bind a concrete column config type. They may override
    ``_process_serialized_output`` to post-process the serialized model response
    before it is stored in the record.
    """

    @staticmethod
    def get_generation_strategy() -> GenerationStrategy:
        """Return ``CELL_BY_CELL``: every record gets its own model call."""
        return GenerationStrategy.CELL_BY_CELL

    @functools.cached_property
    def response_recipe(self) -> ResponseRecipe:
        """Response recipe derived from the column config and model config (cached per instance)."""
        return create_response_recipe(self.config, self.model_config)

    @property
    def max_conversation_correction_steps(self) -> int:
        """Correction-step budget taken from the resource provider's run config."""
        run_config = self.resource_provider.run_config
        return run_config.max_conversation_correction_steps

    @property
    def max_conversation_restarts(self) -> int:
        """Conversation-restart budget taken from the resource provider's run config."""
        run_config = self.resource_provider.run_config
        return run_config.max_conversation_restarts

    @functools.cached_property
    def prompt_renderer(self) -> RecordBasedPromptRenderer:
        """Prompt renderer for this column (cached per instance).

        The column name, column type and model alias are passed as error-message
        context to the renderer.
        """
        recipe = self.response_recipe
        error_context = {
            "column_name": self.config.name,
            "column_type": self.config.column_type,
            "model_alias": self.config.model_alias,
        }
        return RecordBasedPromptRenderer(response_recipe=recipe, error_message_context=error_context)

    def generate(self, data: dict) -> dict:
        """Generate this column's value for a single record.

        Args:
            data: The record, keyed by column name. It is updated in place with the
                generated value. If the model returned a reasoning trace, that trace is
                also stored under ``<column name> + REASONING_TRACE_COLUMN_POSTFIX``.

        Returns:
            The same ``data`` dict, after the update.
        """
        # Templates render against a JSON-deserialized copy of the record. This lets
        # values stored as JSON strings by earlier columns expose nested fields: a column
        # holding '{"key": "value"}' can be referenced as {{ prev_column.key }}.
        # deserialize_json_values builds a new dict; `data` is written only further below.
        record = deserialize_json_values(data)

        context_configs = self.config.multi_modal_context
        multi_modal_context = None
        if context_configs is not None and len(context_configs) > 0:
            multi_modal_context = [item.get_context(record) for item in context_configs]

        model = self.model
        user_prompt = self.prompt_renderer.render(
            record=record,
            prompt_template=self.config.prompt,
            prompt_type=PromptType.USER_PROMPT,
        )
        system_prompt = self.prompt_renderer.render(
            record=record,
            prompt_template=self.config.system_prompt,
            prompt_type=PromptType.SYSTEM_PROMPT,
        )
        response, reasoning_trace = model.generate(
            prompt=user_prompt,
            system_prompt=system_prompt,
            parser=self.response_recipe.parse,
            multi_modal_context=multi_modal_context,
            max_correction_steps=self.max_conversation_correction_steps,
            max_conversation_restarts=self.max_conversation_restarts,
            purpose=f"running generation for column '{self.config.name}'",
        )

        column_name = self.config.name
        serialized_output = self.response_recipe.serialize_output(response)
        data[column_name] = self._process_serialized_output(serialized_output)
        if reasoning_trace:
            data[column_name + REASONING_TRACE_COLUMN_POSTFIX] = reasoning_trace
        return data

    def _process_serialized_output(self, serialized_output: str) -> str | dict | list:
        """Subclass hook for converting the serialized output; the base returns it unchanged."""
        return serialized_output
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
class LLMTextCellGenerator(ColumnGeneratorWithModelChatCompletion[LLMTextColumnConfig]):
    """Chat-completion cell generator for LLM text columns; the serialized output is stored unchanged."""
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
class LLMCodeCellGenerator(ColumnGeneratorWithModelChatCompletion[LLMCodeColumnConfig]):
    """Chat-completion cell generator for LLM code columns; the serialized output is stored unchanged."""
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
class LLMStructuredCellGenerator(ColumnGeneratorWithModelChatCompletion[LLMStructuredColumnConfig]):
    """Chat-completion cell generator for structured-output columns."""

    def _process_serialized_output(self, serialized_output: str) -> dict | list:
        """Pass the serialized model output through ``deserialize_json_values`` before storing it."""
        deserialized = deserialize_json_values(serialized_output)
        return deserialized
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
class LLMJudgeCellGenerator(ColumnGeneratorWithModelChatCompletion[LLMJudgeColumnConfig]):
    """Chat-completion cell generator for LLM-as-judge columns."""

    def _process_serialized_output(self, serialized_output: str) -> dict | list:
        """Pass the serialized judge output through ``deserialize_json_values`` before storing it."""
        deserialized = deserialize_json_values(serialized_output)
        return deserialized
|
|
@@ -1,69 +0,0 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
-
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
-
|
|
4
|
-
from __future__ import annotations
|
|
5
|
-
|
|
6
|
-
import logging
|
|
7
|
-
import random
|
|
8
|
-
from functools import partial
|
|
9
|
-
from typing import TYPE_CHECKING, Callable
|
|
10
|
-
|
|
11
|
-
from data_designer.config.utils.constants import LOCALES_WITH_MANAGED_DATASETS
|
|
12
|
-
from data_designer.engine.column_generators.generators.base import FromScratchColumnGenerator, GenerationStrategy
|
|
13
|
-
from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig
|
|
14
|
-
from data_designer.engine.processing.utils import concat_datasets
|
|
15
|
-
from data_designer.engine.resources.managed_dataset_generator import ManagedDatasetGenerator
|
|
16
|
-
from data_designer.engine.sampling_gen.data_sources.sources import SamplerType
|
|
17
|
-
from data_designer.engine.sampling_gen.entities.person import load_person_data_sampler
|
|
18
|
-
from data_designer.engine.sampling_gen.generator import DatasetGenerator as SamplingDatasetGenerator
|
|
19
|
-
from data_designer.lazy_heavy_imports import pd
|
|
20
|
-
|
|
21
|
-
if TYPE_CHECKING:
|
|
22
|
-
import pandas as pd
|
|
23
|
-
|
|
24
|
-
logger = logging.getLogger(__name__)


class SamplerColumnGenerator(FromScratchColumnGenerator[SamplerMultiColumnConfig]):
    """Generate every sampler-based column of a multi-column config in one full-column pass."""

    @staticmethod
    def get_generation_strategy() -> GenerationStrategy:
        """Return ``FULL_COLUMN``: samplers produce whole columns at once."""
        return GenerationStrategy.FULL_COLUMN

    def generate(self, data: pd.DataFrame) -> pd.DataFrame:
        """Sample ``len(data)`` rows and combine them with ``data`` via ``concat_datasets``."""
        sampled = self.generate_from_scratch(len(data))
        return concat_datasets([data, sampled])

    def generate_from_scratch(self, num_records: int) -> pd.DataFrame:
        """Build a sampling dataset generator and draw ``num_records`` rows from it."""
        return self._prepare_for_generation(num_records).generate(num_records)

    def _person_sampler_columns(self) -> list:
        """Sampler columns of type ``SamplerType.PERSON``, in config order."""
        return [column for column in self.config.columns if column.sampler_type == SamplerType.PERSON]

    @property
    def _needs_person_generator(self) -> bool:
        """True when a PERSON sampler column uses a locale in ``LOCALES_WITH_MANAGED_DATASETS``."""
        return any(
            column.params.locale in LOCALES_WITH_MANAGED_DATASETS for column in self._person_sampler_columns()
        )

    @property
    def _person_generator_loader(self) -> Callable[[bool], ManagedDatasetGenerator]:
        """``load_person_data_sampler`` with this generator's blob storage pre-bound."""
        blob_storage = self.resource_provider.blob_storage
        return partial(load_person_data_sampler, blob_storage=blob_storage)

    def _create_sampling_dataset_generator(self) -> SamplingDatasetGenerator:
        """Create the sampling generator, passing a person loader only when one is needed."""
        loader = self._person_generator_loader if self._needs_person_generator else None
        return SamplingDatasetGenerator(sampler_columns=self.config, person_generator_loader=loader)

    def _log_person_generation_if_needed(self) -> None:
        """Log a notice when person generation from managed datasets is about to run."""
        if not self._needs_person_generator:
            return
        person_columns = self._person_sampler_columns()
        emoji = random.choice(["🧑🎨", "🙋♂️", "🙋♀️", "🧑🚀", "👩🎤", "👨🍳", "👩🔬", "👨💻", "👩💼"])
        log_msg = f"🎲 {emoji} Initializing person generation"
        if any(column.params.with_synthetic_personas for column in person_columns):
            log_msg += " ⚡️ with synthetic personas ⚡️"
        logger.info(log_msg)

    def _prepare_for_generation(self, num_records: int) -> SamplingDatasetGenerator:
        """Log what is about to be sampled, then return a fresh sampling dataset generator."""
        num_columns = len(self.config.columns)
        logger.info(f"🎲 Preparing samplers to generate {num_records} records across {num_columns} columns")
        self._log_person_generation_if_needed()
        return self._create_sampling_dataset_generator()
|
|
@@ -1,144 +0,0 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
-
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
-
|
|
4
|
-
from __future__ import annotations
|
|
5
|
-
|
|
6
|
-
import functools
|
|
7
|
-
import logging
|
|
8
|
-
from typing import TYPE_CHECKING
|
|
9
|
-
|
|
10
|
-
from data_designer.config.seed import IndexRange, PartitionBlock, SamplingStrategy
|
|
11
|
-
from data_designer.engine.column_generators.generators.base import FromScratchColumnGenerator, GenerationStrategy
|
|
12
|
-
from data_designer.engine.column_generators.utils.errors import SeedDatasetError
|
|
13
|
-
from data_designer.engine.dataset_builders.multi_column_configs import SeedDatasetMultiColumnConfig
|
|
14
|
-
from data_designer.engine.processing.utils import concat_datasets
|
|
15
|
-
from data_designer.lazy_heavy_imports import duckdb, pd
|
|
16
|
-
|
|
17
|
-
if TYPE_CHECKING:
|
|
18
|
-
import duckdb
|
|
19
|
-
import pandas as pd
|
|
20
|
-
|
|
21
|
-
MAX_ZERO_RECORD_RESPONSE_FACTOR = 2

logger = logging.getLogger(__name__)


def _quote_sql_string(value: str) -> str:
    """Return ``value`` as a single-quoted SQL string literal.

    Embedded single quotes are doubled, which is the SQL escape for a quote inside a
    string literal. Without this, a dataset path such as ``/data/o'brien.parquet`` would
    terminate the literal early and break (or alter) the generated DuckDB query.
    """
    return "'" + str(value).replace("'", "''") + "'"


class SeedDatasetColumnGenerator(FromScratchColumnGenerator[SeedDatasetMultiColumnConfig]):
    """Generate columns by reading rows from the configured seed dataset through DuckDB.

    Rows are streamed from a DuckDB record-batch reader. When the reader is exhausted it
    is rebuilt, so requests larger than the (selected) seed dataset cycle through it
    again. Rows read beyond what one call needs are kept and returned first on the next call.
    """

    @staticmethod
    def get_generation_strategy() -> GenerationStrategy:
        """Return ``FULL_COLUMN``: seed rows are produced for the whole batch at once."""
        return GenerationStrategy.FULL_COLUMN

    @property
    def num_records_sampled(self) -> int:
        """Total number of rows returned by ``generate_from_scratch`` so far."""
        return self._num_records_sampled

    @functools.cached_property
    def duckdb_conn(self) -> duckdb.DuckDBPyConnection:
        """DuckDB connection created by the seed reader (created once per instance)."""
        return self.resource_provider.seed_reader.create_duckdb_connection()

    def generate(self, data: pd.DataFrame) -> pd.DataFrame:
        """Sample ``len(data)`` seed rows and combine them with ``data`` via ``concat_datasets``."""
        return concat_datasets([self.generate_from_scratch(len(data)), data])

    def generate_from_scratch(self, num_records: int) -> pd.DataFrame:
        """Return the next ``num_records`` rows from the seed dataset.

        Raises:
            ValueError: If ``num_records`` is not positive.
        """
        if num_records <= 0:
            raise ValueError("🛑 `num_records` must be positive.")

        if self._batch_reader is None:
            self._reset_batch_reader(num_records)

        return self._sample_records(num_records)

    def _initialize(self) -> None:
        """Reset sampling state, count the seed dataset's rows and resolve the row selection."""
        self._num_records_sampled = 0
        self._batch_reader = None
        self._df_remaining = None
        self._dataset_uri = self.resource_provider.seed_reader.get_dataset_uri()
        count_query = f"SELECT COUNT(*) FROM {_quote_sql_string(self._dataset_uri)}"
        self._seed_dataset_size = self.duckdb_conn.execute(count_query).fetchone()[0]
        self._index_range = self._resolve_index_range()

    def _validate_selection_strategy(self) -> None:
        """Check that the configured selection strategy fits the seed dataset.

        Raises:
            SeedDatasetError: If an ``IndexRange`` ``end`` is not a valid row index, or a
                ``PartitionBlock`` requests more partitions than the dataset has rows.
        """
        strategy = self.config.selection_strategy
        err_msg = None
        if isinstance(strategy, IndexRange) and strategy.end >= self._seed_dataset_size:
            err_msg = f"Selection strategy 'end' index {strategy.end} is out of bounds for dataset size {self._seed_dataset_size}"
        elif isinstance(strategy, PartitionBlock) and strategy.num_partitions > self._seed_dataset_size:
            err_msg = f"Selection strategy 'num_partitions' {strategy.num_partitions} is out of bounds for dataset size {self._seed_dataset_size}"
        if err_msg is not None:
            raise SeedDatasetError(err_msg)

    def _resolve_index_range(self) -> IndexRange | None:
        """Validate the selection strategy and convert it to an ``IndexRange`` (``None`` = all rows)."""
        self._validate_selection_strategy()
        strategy = self.config.selection_strategy
        if isinstance(strategy, IndexRange):
            return strategy
        if isinstance(strategy, PartitionBlock):
            return strategy.to_index_range(self._seed_dataset_size)
        return None

    def _reset_batch_reader(self, num_records: int) -> None:
        """(Re)create the record-batch reader over the selected rows, shuffled if configured."""
        source = _quote_sql_string(self._dataset_uri)
        shuffle = self.config.sampling_strategy == SamplingStrategy.SHUFFLE
        shuffle_query = " ORDER BY RANDOM()" if shuffle else ""

        if self._index_range is not None:
            # IndexRange is 0-based and inclusive on both ends: OFFSET skips the first
            # `start` rows and LIMIT keeps `end - start + 1` rows.
            offset_value = self._index_range.start
            limit_value = self._index_range.end - self._index_range.start + 1
            read_query = f"""
            SELECT * FROM {source}
            LIMIT {limit_value} OFFSET {offset_value}
            """
            # Shuffle only after the range has been cut, so the selection itself is stable.
            read_query = f"SELECT * FROM ({read_query}){shuffle_query}"
        else:
            read_query = f"SELECT * FROM {source}{shuffle_query}"
        self._batch_reader = self.duckdb_conn.query(read_query).record_batch(batch_size=num_records)

    def _sample_records(self, num_records: int) -> pd.DataFrame:
        """Read exactly ``num_records`` rows, restarting the reader whenever it is exhausted.

        Raises:
            RuntimeError: If reads keep yielding no rows more than
                ``MAX_ZERO_RECORD_RESPONSE_FACTOR * num_records`` times.
        """
        logger.info(f"🌱 Sampling {num_records} records from seed dataset")
        logger.info(f" |-- seed dataset size: {self._seed_dataset_size} records")
        logger.info(f" |-- sampling strategy: {self.config.sampling_strategy}")
        if self._index_range is not None:
            if isinstance(self.config.selection_strategy, IndexRange):
                logger.info(f" |-- selection: rows [{self._index_range.start} to {self._index_range.end}] inclusive")
            else:
                logger.info(
                    f" |-- selection: partition {self.config.selection_strategy.index + 1} of {self.config.selection_strategy.num_partitions}"
                )
            logger.info(f" |-- seed dataset size after selection: {self._index_range.size} records")

        df_batch = pd.DataFrame()
        # Start from rows left over by the previous call, if any.
        df_sample = pd.DataFrame() if self._df_remaining is None else self._df_remaining
        num_zero_record_responses = 0

        while len(df_sample) < num_records:
            try:
                df_batch = self._batch_reader.read_next_batch().to_pandas()
                df_sample = pd.concat([df_sample, df_batch], ignore_index=True)
            except StopIteration:
                # Reader exhausted: start a new pass over the (selected) dataset.
                self._reset_batch_reader(num_records)

            # Guard against looping forever when the source keeps returning nothing.
            if len(df_batch) == 0:
                num_zero_record_responses += 1
                if num_zero_record_responses > MAX_ZERO_RECORD_RESPONSE_FACTOR * num_records:
                    raise RuntimeError(
                        "🛑 Something went wrong while reading from the datastore. "
                        "Please check your connection and try again. "
                        "If the issue persists, please contact support."
                    )

        self._df_remaining = None
        if len(df_sample) > num_records:
            self._df_remaining = df_sample.iloc[num_records:].reset_index(drop=True)
            df_sample = df_sample.iloc[:num_records]
        self._num_records_sampled += len(df_sample)

        return df_sample
|
|
@@ -1,140 +0,0 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
-
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
-
|
|
4
|
-
from __future__ import annotations
|
|
5
|
-
|
|
6
|
-
import logging
|
|
7
|
-
from typing import TYPE_CHECKING
|
|
8
|
-
|
|
9
|
-
from data_designer.config.column_configs import ValidationColumnConfig
|
|
10
|
-
from data_designer.config.errors import InvalidConfigError
|
|
11
|
-
from data_designer.config.utils.code_lang import SQL_DIALECTS, CodeLang
|
|
12
|
-
from data_designer.config.validator_params import ValidatorParamsT, ValidatorType
|
|
13
|
-
from data_designer.engine.column_generators.generators.base import ColumnGeneratorFullColumn
|
|
14
|
-
from data_designer.engine.dataset_builders.utils.concurrency import ConcurrentThreadExecutor
|
|
15
|
-
from data_designer.engine.errors import DataDesignerRuntimeError
|
|
16
|
-
from data_designer.engine.validators import (
|
|
17
|
-
BaseValidator,
|
|
18
|
-
LocalCallableValidator,
|
|
19
|
-
PythonValidator,
|
|
20
|
-
RemoteValidator,
|
|
21
|
-
SQLValidator,
|
|
22
|
-
ValidationResult,
|
|
23
|
-
)
|
|
24
|
-
from data_designer.lazy_heavy_imports import pd
|
|
25
|
-
|
|
26
|
-
if TYPE_CHECKING:
|
|
27
|
-
import pandas as pd
|
|
28
|
-
|
|
29
|
-
logger = logging.getLogger(__name__)


def get_validator_from_params(validator_type: ValidatorType, validator_params: ValidatorParamsT) -> BaseValidator:
    """Instantiate the validator implementation that matches ``validator_type``.

    CODE validators dispatch on ``validator_params.code_lang``: Python goes to
    ``PythonValidator`` and any dialect in ``SQL_DIALECTS`` goes to ``SQLValidator``.
    REMOTE maps to ``RemoteValidator``. Every other type maps to ``LocalCallableValidator``.

    Raises:
        InvalidConfigError: If ``validator_type`` is CODE and ``code_lang`` is neither
            Python nor an SQL dialect. Previously this case silently returned ``None``,
            which only failed later when the validator was used.
    """
    if validator_type == ValidatorType.CODE:
        if validator_params.code_lang == CodeLang.PYTHON:
            return PythonValidator(validator_params)
        if validator_params.code_lang in SQL_DIALECTS:
            return SQLValidator(validator_params)
        raise InvalidConfigError(f"Unsupported code language {validator_params.code_lang!r} for code validation")
    if validator_type == ValidatorType.REMOTE:
        return RemoteValidator(validator_params)
    return LocalCallableValidator(validator_params)
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
class ValidationColumnGenerator(ColumnGeneratorFullColumn[ValidationColumnConfig]):
    """Full-column generator that validates target columns and stores per-row results.

    The new column (named after the config) holds, for each row, the JSON-mode
    ``model_dump`` of that row's validation output. When a CODE validator has several
    target columns, each column is validated separately; the row value is then a dict
    keyed by target column name.
    """

    def generate(self, data: pd.DataFrame) -> pd.DataFrame:
        """Validate the target columns of ``data`` and return ``data`` plus the validation column.

        Raises:
            InvalidConfigError: If any target column is missing from ``data``.
        """
        logger.info(f"🔍 Validating column {self.config.name!r} with {len(data)} records")
        logger.info(f" |-- target columns: {self.config.target_columns}")
        logger.info(f" |-- validator type: {self.config.validator_type}")
        logger.info(f" |-- validator params: {self.config.validator_params}")
        logger.info(f" |-- batch size: {self.config.batch_size}")

        validator = get_validator_from_params(self.config.validator_type, self.config.validator_params)

        missing_columns = set(self.config.target_columns) - set(data.columns)
        if missing_columns:
            raise InvalidConfigError(
                f"Target columns {missing_columns} defined in validation column {self.config.name!r} are missing in dataset"
            )

        # Code validators expect single-column input, so multiple targets are validated one by one.
        validate_columns_separately = (
            self.config.validator_type == ValidatorType.CODE and len(self.config.target_columns) > 1
        )
        if validate_columns_separately:
            columns_to_validate = [[col] for col in self.config.target_columns]
        else:
            columns_to_validate = [self.config.target_columns]

        outputs_as_dicts = None
        for cols in columns_to_validate:
            records = data[cols].to_dict(orient="records")
            outputs = self._validate_records(validator, records)

            if not validate_columns_separately:
                outputs_as_dicts = [output.model_dump(mode="json") for output in outputs]
            elif outputs_as_dicts is None:
                outputs_as_dicts = [{cols[0]: output.model_dump(mode="json")} for output in outputs]
            else:
                # Pair each row's dict with that same row's result. The previous code wrote
                # the first row's result into every row.
                for dict_output, output in zip(outputs_as_dicts, outputs):
                    dict_output[cols[0]] = output.model_dump(mode="json")

        # Reuse data's index so the axis=1 concat lines results up row-for-row, even when
        # `data` does not carry a default RangeIndex.
        validation_results = pd.DataFrame({self.config.name: outputs_as_dicts}, index=data.index)
        return pd.concat([data, validation_results], axis=1)

    def _validate_records(self, validator: BaseValidator, records: list[dict]) -> list:
        """Validate ``records`` in batches of ``config.batch_size``.

        REMOTE validators with ``max_parallel_requests > 1`` run their batches
        concurrently; everything else runs sequentially.

        Returns:
            The outputs of all batches concatenated in batch order.
        """
        batch_size = self.config.batch_size
        batched_records = [records[start : start + batch_size] for start in range(0, len(records), batch_size)]

        if (
            self.config.validator_type == ValidatorType.REMOTE
            and self.config.validator_params.max_parallel_requests > 1
        ):
            return self._validate_in_parallel(validator, batched_records)

        outputs = []
        for batch in batched_records:
            # Use `.data`, matching how the parallel path flattens results.
            outputs.extend(self._validate_batch(validator, batch).data)
        return outputs

    def _validate_in_parallel(self, validator: BaseValidator, batched_records: list[list[dict]]) -> list:
        """Validate batches concurrently and return their outputs concatenated in batch order.

        A batch whose validation raises is replaced by ``ValidationResult.empty`` of the
        batch's size, so a failing request does not change the number of outputs.

        Raises:
            DataDesignerRuntimeError: If any batch ends up with neither a result nor a
                placeholder (presumably when the executor shuts down early).
        """
        outputs = [None] * len(batched_records)

        def result_callback(result: ValidationResult, context: dict) -> None:
            outputs[context["index"]] = result

        def error_callback(error: Exception, context: dict) -> None:
            outputs[context["index"]] = ValidationResult.empty(size=len(batched_records[context["index"]]))

        settings = self.resource_provider.run_config
        with ConcurrentThreadExecutor(
            max_workers=self.config.validator_params.max_parallel_requests,
            column_name=self.config.name,
            result_callback=result_callback,
            error_callback=error_callback,
            shutdown_error_rate=settings.shutdown_error_rate,
            shutdown_error_window=settings.shutdown_error_window,
            disable_early_shutdown=settings.disable_early_shutdown,
        ) as executor:
            for i, batch in enumerate(batched_records):
                executor.submit(lambda batch: self._validate_batch(validator, batch), batch, context={"index": i})

        if any(output is None for output in outputs):
            raise DataDesignerRuntimeError("Validation task failed due to an unexpected error in parallel execution")

        # Flatten in one linear pass; sum(lists, []) re-copies the accumulator on every step.
        return [item for output in outputs for item in output.data]

    def _validate_batch(self, validator: BaseValidator, batch: list[dict]) -> ValidationResult:
        """Run ``validator`` on one batch; log any error with indented lines, then re-raise it."""
        try:
            return validator.run_validation(batch)
        except Exception as e:
            error_to_display = str(e).replace("\n", "\n ")  # add spaces to improve readability
            logger.error(f"Batch could not be validated:\n {error_to_display}")
            raise
|
|
@@ -1,60 +0,0 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
-
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
-
|
|
4
|
-
from __future__ import annotations
|
|
5
|
-
|
|
6
|
-
from data_designer.config.base import ConfigBase
|
|
7
|
-
from data_designer.config.column_configs import (
|
|
8
|
-
EmbeddingColumnConfig,
|
|
9
|
-
ExpressionColumnConfig,
|
|
10
|
-
LLMCodeColumnConfig,
|
|
11
|
-
LLMJudgeColumnConfig,
|
|
12
|
-
LLMStructuredColumnConfig,
|
|
13
|
-
LLMTextColumnConfig,
|
|
14
|
-
ValidationColumnConfig,
|
|
15
|
-
)
|
|
16
|
-
from data_designer.config.column_types import DataDesignerColumnType
|
|
17
|
-
from data_designer.engine.column_generators.generators.base import ColumnGenerator
|
|
18
|
-
from data_designer.engine.column_generators.generators.embedding import EmbeddingCellGenerator
|
|
19
|
-
from data_designer.engine.column_generators.generators.expression import ExpressionColumnGenerator
|
|
20
|
-
from data_designer.engine.column_generators.generators.llm_completion import (
|
|
21
|
-
LLMCodeCellGenerator,
|
|
22
|
-
LLMJudgeCellGenerator,
|
|
23
|
-
LLMStructuredCellGenerator,
|
|
24
|
-
LLMTextCellGenerator,
|
|
25
|
-
)
|
|
26
|
-
from data_designer.engine.column_generators.generators.samplers import SamplerColumnGenerator
|
|
27
|
-
from data_designer.engine.column_generators.generators.seed_dataset import SeedDatasetColumnGenerator
|
|
28
|
-
from data_designer.engine.column_generators.generators.validation import ValidationColumnGenerator
|
|
29
|
-
from data_designer.engine.dataset_builders.multi_column_configs import (
|
|
30
|
-
SamplerMultiColumnConfig,
|
|
31
|
-
SeedDatasetMultiColumnConfig,
|
|
32
|
-
)
|
|
33
|
-
from data_designer.engine.registry.base import TaskRegistry
|
|
34
|
-
from data_designer.plugins.plugin import PluginType
|
|
35
|
-
from data_designer.plugins.registry import PluginRegistry
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
class ColumnGeneratorRegistry(TaskRegistry[DataDesignerColumnType, ColumnGenerator, ConfigBase]): ...
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
def create_default_column_generator_registry(with_plugins: bool = True) -> ColumnGeneratorRegistry:
|
|
42
|
-
registry = ColumnGeneratorRegistry()
|
|
43
|
-
registry.register(DataDesignerColumnType.LLM_TEXT, LLMTextCellGenerator, LLMTextColumnConfig)
|
|
44
|
-
registry.register(DataDesignerColumnType.LLM_CODE, LLMCodeCellGenerator, LLMCodeColumnConfig)
|
|
45
|
-
registry.register(DataDesignerColumnType.LLM_JUDGE, LLMJudgeCellGenerator, LLMJudgeColumnConfig)
|
|
46
|
-
registry.register(DataDesignerColumnType.EXPRESSION, ExpressionColumnGenerator, ExpressionColumnConfig)
|
|
47
|
-
registry.register(DataDesignerColumnType.EMBEDDING, EmbeddingCellGenerator, EmbeddingColumnConfig)
|
|
48
|
-
registry.register(DataDesignerColumnType.SAMPLER, SamplerColumnGenerator, SamplerMultiColumnConfig)
|
|
49
|
-
registry.register(DataDesignerColumnType.SEED_DATASET, SeedDatasetColumnGenerator, SeedDatasetMultiColumnConfig)
|
|
50
|
-
registry.register(DataDesignerColumnType.VALIDATION, ValidationColumnGenerator, ValidationColumnConfig)
|
|
51
|
-
registry.register(DataDesignerColumnType.LLM_STRUCTURED, LLMStructuredCellGenerator, LLMStructuredColumnConfig)
|
|
52
|
-
if with_plugins:
|
|
53
|
-
for plugin in PluginRegistry().get_plugins(PluginType.COLUMN_GENERATOR):
|
|
54
|
-
registry.register(
|
|
55
|
-
DataDesignerColumnType(plugin.name),
|
|
56
|
-
plugin.impl_cls,
|
|
57
|
-
plugin.config_cls,
|
|
58
|
-
)
|
|
59
|
-
|
|
60
|
-
return registry
|
|
@@ -1,15 +0,0 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
-
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
-
|
|
4
|
-
from __future__ import annotations
|
|
5
|
-
|
|
6
|
-
from data_designer.engine.errors import DataDesignerError
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
class PromptTemplateRenderError(DataDesignerError): ...
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
class ExpressionTemplateRenderError(DataDesignerError): ...
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
class SeedDatasetError(DataDesignerError): ...
|
|
@@ -1,43 +0,0 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
-
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
-
|
|
4
|
-
from __future__ import annotations
|
|
5
|
-
|
|
6
|
-
from data_designer.config.column_types import DataDesignerColumnType
|
|
7
|
-
from data_designer.config.utils.type_helpers import resolve_string_enum
|
|
8
|
-
from data_designer.engine.column_generators.generators.base import ColumnGeneratorWithModelRegistry
|
|
9
|
-
from data_designer.plugin_manager import PluginManager
|
|
10
|
-
|
|
11
|
-
plugin_manager = PluginManager()
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
def column_type_used_in_execution_dag(column_type: str | DataDesignerColumnType) -> bool:
|
|
15
|
-
"""Return True if the column type is used in the workflow execution DAG."""
|
|
16
|
-
column_type = resolve_string_enum(column_type, DataDesignerColumnType)
|
|
17
|
-
dag_column_types = {
|
|
18
|
-
DataDesignerColumnType.EXPRESSION,
|
|
19
|
-
DataDesignerColumnType.LLM_CODE,
|
|
20
|
-
DataDesignerColumnType.LLM_JUDGE,
|
|
21
|
-
DataDesignerColumnType.LLM_STRUCTURED,
|
|
22
|
-
DataDesignerColumnType.LLM_TEXT,
|
|
23
|
-
DataDesignerColumnType.VALIDATION,
|
|
24
|
-
DataDesignerColumnType.EMBEDDING,
|
|
25
|
-
}
|
|
26
|
-
dag_column_types.update(plugin_manager.get_plugin_column_types(DataDesignerColumnType))
|
|
27
|
-
return column_type in dag_column_types
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
def column_type_is_model_generated(column_type: str | DataDesignerColumnType) -> bool:
|
|
31
|
-
"""Return True if the column type is a model-generated column."""
|
|
32
|
-
column_type = resolve_string_enum(column_type, DataDesignerColumnType)
|
|
33
|
-
model_generated_column_types = {
|
|
34
|
-
DataDesignerColumnType.LLM_TEXT,
|
|
35
|
-
DataDesignerColumnType.LLM_CODE,
|
|
36
|
-
DataDesignerColumnType.LLM_STRUCTURED,
|
|
37
|
-
DataDesignerColumnType.LLM_JUDGE,
|
|
38
|
-
DataDesignerColumnType.EMBEDDING,
|
|
39
|
-
}
|
|
40
|
-
for plugin in plugin_manager.get_column_generator_plugins():
|
|
41
|
-
if issubclass(plugin.impl_cls, ColumnGeneratorWithModelRegistry):
|
|
42
|
-
model_generated_column_types.add(plugin.name)
|
|
43
|
-
return column_type in model_generated_column_types
|
|
@@ -1,58 +0,0 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
-
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
-
|
|
4
|
-
from __future__ import annotations
|
|
5
|
-
|
|
6
|
-
from enum import Enum
|
|
7
|
-
|
|
8
|
-
from pydantic import BaseModel, ConfigDict, Field, create_model
|
|
9
|
-
|
|
10
|
-
from data_designer.config.column_configs import Score
|
|
11
|
-
|
|
12
|
-
SCORING_FORMAT = "* {score}: {description}"
|
|
13
|
-
SCORE_FIELD_DESCRIPTION_FORMAT = "Score Descriptions for {enum_name}:\n{scoring}"
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
class BaseJudgeResponse(BaseModel):
|
|
17
|
-
"""Base model for all rubrics."""
|
|
18
|
-
|
|
19
|
-
model_config = ConfigDict(use_enum_values=True)
|
|
20
|
-
reasoning: str = Field(..., description="Reasoning for the assigned score.")
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
def _stringify_scoring(options: dict, enum_type: type[Enum]) -> str:
|
|
24
|
-
"""Convert score descriptions into a single text block."""
|
|
25
|
-
list_block = "\n".join(
|
|
26
|
-
[SCORING_FORMAT.format(score=score, description=description) for score, description in options.items()]
|
|
27
|
-
)
|
|
28
|
-
return SCORE_FIELD_DESCRIPTION_FORMAT.format(enum_name=enum_type.__name__, scoring=list_block)
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
def create_judge_response_model(score: Score) -> type[BaseJudgeResponse]:
|
|
32
|
-
"""Create a JudgeResponse data type."""
|
|
33
|
-
enum_members = {}
|
|
34
|
-
for option in score.options.keys():
|
|
35
|
-
member_name = f"VALUE_{option}"
|
|
36
|
-
enum_members[member_name] = option
|
|
37
|
-
|
|
38
|
-
DynamicScaleEnum = Enum(f"{score.name}Enum", enum_members)
|
|
39
|
-
options = _stringify_scoring(score.options, enum_type=DynamicScaleEnum)
|
|
40
|
-
|
|
41
|
-
return create_model(
|
|
42
|
-
score.name,
|
|
43
|
-
__doc__=score.description if score.description else None,
|
|
44
|
-
__base__=BaseJudgeResponse,
|
|
45
|
-
score=(DynamicScaleEnum, Field(..., description=options)),
|
|
46
|
-
)
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
def create_judge_structured_output_model(
|
|
50
|
-
judge_responses: list[type[BaseJudgeResponse]],
|
|
51
|
-
) -> type[BaseModel]:
|
|
52
|
-
"""Create a JudgeStructuredOutput class dynamically."""
|
|
53
|
-
return create_model(
|
|
54
|
-
"JudgeStructuredOutput",
|
|
55
|
-
__doc__=f"Response schema for scores with the following names: {[response.__name__ for response in judge_responses]}.",
|
|
56
|
-
__base__=BaseModel,
|
|
57
|
-
**{response.__name__: (response, ...) for response in judge_responses},
|
|
58
|
-
)
|