data-designer-engine 0.4.0rc3__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/engine/analysis/column_profilers/base.py +1 -2
- data_designer/engine/analysis/dataset_profiler.py +1 -2
- data_designer/engine/column_generators/generators/base.py +1 -6
- data_designer/engine/column_generators/generators/custom.py +195 -0
- data_designer/engine/column_generators/generators/llm_completion.py +32 -5
- data_designer/engine/column_generators/registry.py +3 -0
- data_designer/engine/column_generators/utils/errors.py +3 -0
- data_designer/engine/column_generators/utils/prompt_renderer.py +1 -1
- data_designer/engine/dataset_builders/column_wise_builder.py +23 -5
- data_designer/engine/dataset_builders/multi_column_configs.py +2 -2
- data_designer/engine/mcp/__init__.py +30 -0
- data_designer/engine/mcp/errors.py +22 -0
- data_designer/engine/mcp/facade.py +485 -0
- data_designer/engine/mcp/factory.py +46 -0
- data_designer/engine/mcp/io.py +487 -0
- data_designer/engine/mcp/registry.py +203 -0
- data_designer/engine/model_provider.py +68 -0
- data_designer/engine/models/facade.py +74 -9
- data_designer/engine/models/factory.py +18 -1
- data_designer/engine/models/utils.py +28 -1
- data_designer/engine/resources/resource_provider.py +72 -3
- data_designer/engine/testing/fixtures.py +233 -0
- data_designer/engine/testing/stubs.py +1 -2
- {data_designer_engine-0.4.0rc3.dist-info → data_designer_engine-0.5.0rc1.dist-info}/METADATA +3 -2
- {data_designer_engine-0.4.0rc3.dist-info → data_designer_engine-0.5.0rc1.dist-info}/RECORD +26 -19
- data_designer/engine/_version.py +0 -34
- {data_designer_engine-0.4.0rc3.dist-info → data_designer_engine-0.5.0rc1.dist-info}/WHEEL +0 -0
|
@@ -10,8 +10,7 @@ from typing import TYPE_CHECKING
|
|
|
10
10
|
from pydantic import BaseModel, model_validator
|
|
11
11
|
from typing_extensions import Self
|
|
12
12
|
|
|
13
|
-
from data_designer.config.base import ConfigBase
|
|
14
|
-
from data_designer.config.column_configs import SingleColumnConfig
|
|
13
|
+
from data_designer.config.base import ConfigBase, SingleColumnConfig
|
|
15
14
|
from data_designer.config.column_types import DataDesignerColumnType
|
|
16
15
|
from data_designer.engine.configurable_task import ConfigurableTask, TaskConfigT
|
|
17
16
|
from data_designer.lazy_heavy_imports import pd
|
|
@@ -12,8 +12,7 @@ from pydantic import Field, field_validator
|
|
|
12
12
|
|
|
13
13
|
from data_designer.config.analysis.column_profilers import ColumnProfilerConfigT
|
|
14
14
|
from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults
|
|
15
|
-
from data_designer.config.base import ConfigBase
|
|
16
|
-
from data_designer.config.column_configs import SingleColumnConfig
|
|
15
|
+
from data_designer.config.base import ConfigBase, SingleColumnConfig
|
|
17
16
|
from data_designer.config.column_types import ColumnConfigT
|
|
18
17
|
from data_designer.engine.analysis.column_profilers.base import ColumnConfigWithDataFrame, ColumnProfiler
|
|
19
18
|
from data_designer.engine.analysis.column_statistics import get_column_statistics_calculator
|
|
@@ -6,9 +6,9 @@ from __future__ import annotations
|
|
|
6
6
|
import functools
|
|
7
7
|
import logging
|
|
8
8
|
from abc import ABC, abstractmethod
|
|
9
|
-
from enum import Enum
|
|
10
9
|
from typing import TYPE_CHECKING, overload
|
|
11
10
|
|
|
11
|
+
from data_designer.config.column_configs import GenerationStrategy
|
|
12
12
|
from data_designer.engine.configurable_task import ConfigurableTask, DataT, TaskConfigT
|
|
13
13
|
from data_designer.lazy_heavy_imports import pd
|
|
14
14
|
|
|
@@ -22,11 +22,6 @@ if TYPE_CHECKING:
|
|
|
22
22
|
logger = logging.getLogger(__name__)
|
|
23
23
|
|
|
24
24
|
|
|
25
|
-
class GenerationStrategy(str, Enum):
|
|
26
|
-
CELL_BY_CELL = "cell_by_cell"
|
|
27
|
-
FULL_COLUMN = "full_column"
|
|
28
|
-
|
|
29
|
-
|
|
30
25
|
class ColumnGenerator(ConfigurableTask[TaskConfigT], ABC):
|
|
31
26
|
@property
|
|
32
27
|
def can_generate_from_scratch(self) -> bool:
|
|
@@ -0,0 +1,195 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
"""Custom column generator using user-provided callable functions."""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import inspect
|
|
9
|
+
import logging
|
|
10
|
+
from typing import TYPE_CHECKING, Any
|
|
11
|
+
|
|
12
|
+
from data_designer.config.column_configs import CustomColumnConfig, GenerationStrategy
|
|
13
|
+
from data_designer.engine.column_generators.generators.base import ColumnGenerator
|
|
14
|
+
from data_designer.engine.column_generators.utils.errors import CustomColumnGenerationError
|
|
15
|
+
from data_designer.lazy_heavy_imports import pd
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
import pandas as pd
|
|
19
|
+
|
|
20
|
+
logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class CustomColumnGenerator(ColumnGenerator[CustomColumnConfig]):
|
|
24
|
+
"""Column generator that uses a user-provided callable function.
|
|
25
|
+
|
|
26
|
+
Supports two strategies based on config.strategy:
|
|
27
|
+
- cell_by_cell: Processes rows one at a time (dict -> dict), parallelized by framework.
|
|
28
|
+
- full_column: Processes entire batch (DataFrame -> DataFrame) for vectorized ops.
|
|
29
|
+
|
|
30
|
+
Supported function signatures (validated by parameter name):
|
|
31
|
+
- fn(row) -> dict # cell_by_cell, simple transform
|
|
32
|
+
- fn(row, generator_params) -> dict # cell_by_cell, with typed params
|
|
33
|
+
- fn(row, generator_params, models) -> dict # cell_by_cell, with LLM access
|
|
34
|
+
- fn(df) -> DataFrame # full_column, simple transform
|
|
35
|
+
- fn(df, generator_params) -> DataFrame # full_column, with typed params
|
|
36
|
+
- fn(df, generator_params, models) -> DataFrame # full_column, with LLM access
|
|
37
|
+
|
|
38
|
+
The models dict provides direct access to ModelFacade instances keyed by alias.
|
|
39
|
+
"""
|
|
40
|
+
|
|
41
|
+
def get_generation_strategy(self) -> GenerationStrategy:
|
|
42
|
+
"""Return strategy based on config."""
|
|
43
|
+
return self.config.generation_strategy
|
|
44
|
+
|
|
45
|
+
def generate(self, data: dict | pd.DataFrame) -> dict | pd.DataFrame:
|
|
46
|
+
"""Generate column value(s) for a row (dict) or batch (DataFrame)."""
|
|
47
|
+
is_full_column = self.config.generation_strategy == GenerationStrategy.FULL_COLUMN
|
|
48
|
+
is_dataframe = not isinstance(data, dict)
|
|
49
|
+
|
|
50
|
+
# Validate data type matches strategy
|
|
51
|
+
if is_full_column and not is_dataframe:
|
|
52
|
+
raise CustomColumnGenerationError(
|
|
53
|
+
f"Custom generator {self.config.name!r} is configured for 'full_column' strategy "
|
|
54
|
+
"but received a dict. Expected a DataFrame."
|
|
55
|
+
)
|
|
56
|
+
if not is_full_column and is_dataframe:
|
|
57
|
+
raise CustomColumnGenerationError(
|
|
58
|
+
f"Custom generator {self.config.name!r} is configured for 'cell_by_cell' strategy "
|
|
59
|
+
"but received a DataFrame. Expected a dict."
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
return self._generate(data, is_dataframe)
|
|
63
|
+
|
|
64
|
+
def _generate(self, data: dict | pd.DataFrame, is_dataframe: bool) -> dict | pd.DataFrame:
|
|
65
|
+
"""Unified generation logic for both strategies."""
|
|
66
|
+
# Get columns/keys using unified accessor
|
|
67
|
+
get_keys = (lambda d: set(d.columns)) if is_dataframe else (lambda d: set(d.keys()))
|
|
68
|
+
expected_type = pd.DataFrame if is_dataframe else dict
|
|
69
|
+
type_name = "DataFrame" if is_dataframe else "dict"
|
|
70
|
+
|
|
71
|
+
# Check required columns
|
|
72
|
+
missing = set(self.config.required_columns) - get_keys(data)
|
|
73
|
+
if missing:
|
|
74
|
+
raise CustomColumnGenerationError(
|
|
75
|
+
f"Missing required columns for custom generator '{self.config.name}': {sorted(missing)}"
|
|
76
|
+
)
|
|
77
|
+
|
|
78
|
+
keys_before = get_keys(data)
|
|
79
|
+
|
|
80
|
+
# Invoke generator
|
|
81
|
+
try:
|
|
82
|
+
result = self._invoke_generator_function(data)
|
|
83
|
+
except CustomColumnGenerationError:
|
|
84
|
+
raise
|
|
85
|
+
except Exception as e:
|
|
86
|
+
raise CustomColumnGenerationError(
|
|
87
|
+
f"Custom generator function failed for column '{self.config.name}': {e}"
|
|
88
|
+
) from e
|
|
89
|
+
|
|
90
|
+
# Validate return type
|
|
91
|
+
if not isinstance(result, expected_type):
|
|
92
|
+
raise CustomColumnGenerationError(
|
|
93
|
+
f"Custom generator for column '{self.config.name}' must return a {type_name}, "
|
|
94
|
+
f"got {type(result).__name__}"
|
|
95
|
+
)
|
|
96
|
+
|
|
97
|
+
return self._validate_output(result, keys_before, is_dataframe)
|
|
98
|
+
|
|
99
|
+
def _validate_output(
|
|
100
|
+
self, result: dict | pd.DataFrame, keys_before: set[str], is_dataframe: bool
|
|
101
|
+
) -> dict | pd.DataFrame:
|
|
102
|
+
"""Validate output columns and remove undeclared ones."""
|
|
103
|
+
# Unified accessors
|
|
104
|
+
get_keys = (lambda d: set(d.columns)) if is_dataframe else (lambda d: set(d.keys()))
|
|
105
|
+
container_name = "DataFrame" if is_dataframe else "row"
|
|
106
|
+
|
|
107
|
+
expected_new = {self.config.name} | set(self.config.side_effect_columns)
|
|
108
|
+
result_keys = get_keys(result)
|
|
109
|
+
|
|
110
|
+
# Check primary column exists
|
|
111
|
+
if self.config.name not in result_keys:
|
|
112
|
+
raise CustomColumnGenerationError(
|
|
113
|
+
f"Custom generator for column '{self.config.name}' did not create the expected column. "
|
|
114
|
+
f"The generator_function must add a key named '{self.config.name}' to the {container_name}."
|
|
115
|
+
)
|
|
116
|
+
|
|
117
|
+
# Check side effect columns exist
|
|
118
|
+
missing = set(self.config.side_effect_columns) - result_keys
|
|
119
|
+
if missing:
|
|
120
|
+
raise CustomColumnGenerationError(
|
|
121
|
+
f"Custom generator for column '{self.config.name}' did not create declared side_effect_columns: "
|
|
122
|
+
f"{sorted(missing)}. Declared side_effect_columns must be added to the {container_name}."
|
|
123
|
+
)
|
|
124
|
+
|
|
125
|
+
# Check no pre-existing columns removed
|
|
126
|
+
removed = keys_before - result_keys
|
|
127
|
+
if removed:
|
|
128
|
+
raise CustomColumnGenerationError(
|
|
129
|
+
f"Custom generator for column '{self.config.name}' removed pre-existing columns: "
|
|
130
|
+
f"{sorted(removed)}. The generator_function must not remove any existing columns."
|
|
131
|
+
)
|
|
132
|
+
|
|
133
|
+
# Remove undeclared columns with warning
|
|
134
|
+
undeclared = (result_keys - keys_before) - expected_new
|
|
135
|
+
if undeclared:
|
|
136
|
+
logger.warning(
|
|
137
|
+
f"⚠️ Custom generator for column '{self.config.name}' created undeclared columns: "
|
|
138
|
+
f"{sorted(undeclared)}. These columns will be removed. "
|
|
139
|
+
f"To keep additional columns, declare them in @custom_column_generator(side_effect_columns=[...])."
|
|
140
|
+
)
|
|
141
|
+
if is_dataframe:
|
|
142
|
+
result = result.drop(columns=list(undeclared))
|
|
143
|
+
else:
|
|
144
|
+
for key in undeclared:
|
|
145
|
+
del result[key]
|
|
146
|
+
|
|
147
|
+
return result
|
|
148
|
+
|
|
149
|
+
def _invoke_generator_function(self, data: dict | pd.DataFrame) -> dict | pd.DataFrame:
|
|
150
|
+
"""Invoke the user's generate function with appropriate arguments based on signature."""
|
|
151
|
+
params = self._get_validated_params()
|
|
152
|
+
|
|
153
|
+
if len(params) == 1:
|
|
154
|
+
return self.config.generator_function(data)
|
|
155
|
+
elif len(params) == 2:
|
|
156
|
+
return self.config.generator_function(data, self.config.generator_params)
|
|
157
|
+
else:
|
|
158
|
+
models = self._build_models_dict()
|
|
159
|
+
return self.config.generator_function(data, self.config.generator_params, models)
|
|
160
|
+
|
|
161
|
+
def _build_models_dict(self) -> dict[str, Any]:
|
|
162
|
+
"""Build a dict of ModelFacade instances from model_aliases."""
|
|
163
|
+
return {
|
|
164
|
+
alias: self.resource_provider.model_registry.get_model(model_alias=alias)
|
|
165
|
+
for alias in self.config.model_aliases
|
|
166
|
+
}
|
|
167
|
+
|
|
168
|
+
def _get_validated_params(self) -> list[inspect.Parameter]:
|
|
169
|
+
"""Get positional params and validate first param matches generation strategy."""
|
|
170
|
+
params = [
|
|
171
|
+
p
|
|
172
|
+
for p in inspect.signature(self.config.generator_function).parameters.values()
|
|
173
|
+
if p.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
|
|
174
|
+
]
|
|
175
|
+
# Decorator validated param names; here we only check strategy match
|
|
176
|
+
is_full = self.config.generation_strategy == GenerationStrategy.FULL_COLUMN
|
|
177
|
+
expected = "df" if is_full else "row"
|
|
178
|
+
if params[0].name != expected:
|
|
179
|
+
raise CustomColumnGenerationError(
|
|
180
|
+
f"Generator '{self.config.name}': strategy is {'full_column' if is_full else 'cell_by_cell'}, "
|
|
181
|
+
f"first parameter must be '{expected}', got '{params[0].name}'."
|
|
182
|
+
)
|
|
183
|
+
return params
|
|
184
|
+
|
|
185
|
+
def log_pre_generation(self) -> None:
|
|
186
|
+
logger.info(f"{self.config.get_column_emoji()} Custom column config for column '{self.config.name}'")
|
|
187
|
+
logger.info(f" |-- generator_function: {self.config.generator_function.__name__!r}")
|
|
188
|
+
logger.info(f" |-- generation_strategy: {self.config.generation_strategy!r}")
|
|
189
|
+
logger.info(f" |-- required_columns: {self.config.required_columns}")
|
|
190
|
+
if self.config.side_effect_columns:
|
|
191
|
+
logger.info(f" |-- side_effect_columns: {self.config.side_effect_columns}")
|
|
192
|
+
if self.config.model_aliases:
|
|
193
|
+
logger.info(f" |-- model_aliases: {self.config.model_aliases}")
|
|
194
|
+
if self.config.generator_params:
|
|
195
|
+
logger.info(f" |-- generator_params: {self.config.generator_params}")
|
|
@@ -12,7 +12,8 @@ from data_designer.config.column_configs import (
|
|
|
12
12
|
LLMStructuredColumnConfig,
|
|
13
13
|
LLMTextColumnConfig,
|
|
14
14
|
)
|
|
15
|
-
from data_designer.config.utils.constants import TRACE_COLUMN_POSTFIX
|
|
15
|
+
from data_designer.config.utils.constants import REASONING_CONTENT_COLUMN_POSTFIX, TRACE_COLUMN_POSTFIX
|
|
16
|
+
from data_designer.config.utils.trace_type import TraceType
|
|
16
17
|
from data_designer.engine.column_generators.generators.base import ColumnGeneratorWithModel, GenerationStrategy
|
|
17
18
|
from data_designer.engine.column_generators.utils.prompt_renderer import (
|
|
18
19
|
PromptType,
|
|
@@ -79,6 +80,7 @@ class ColumnGeneratorWithModelChatCompletion(ColumnGeneratorWithModel[TaskConfig
|
|
|
79
80
|
),
|
|
80
81
|
parser=self.response_recipe.parse,
|
|
81
82
|
multi_modal_context=multi_modal_context,
|
|
83
|
+
tool_alias=self.config.tool_alias,
|
|
82
84
|
max_correction_steps=self.max_conversation_correction_steps,
|
|
83
85
|
max_conversation_restarts=self.max_conversation_restarts,
|
|
84
86
|
purpose=f"running generation for column '{self.config.name}'",
|
|
@@ -87,14 +89,39 @@ class ColumnGeneratorWithModelChatCompletion(ColumnGeneratorWithModel[TaskConfig
|
|
|
87
89
|
serialized_output = self.response_recipe.serialize_output(response)
|
|
88
90
|
data[self.config.name] = self._process_serialized_output(serialized_output)
|
|
89
91
|
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
if should_save_trace:
|
|
92
|
+
effective_trace_type = self.config.with_trace
|
|
93
|
+
|
|
94
|
+
if effective_trace_type == TraceType.ALL_MESSAGES:
|
|
94
95
|
data[self.config.name + TRACE_COLUMN_POSTFIX] = [message.to_dict() for message in trace]
|
|
96
|
+
elif effective_trace_type == TraceType.LAST_MESSAGE:
|
|
97
|
+
last_assistant = next((m for m in reversed(trace) if m.role == "assistant"), None)
|
|
98
|
+
data[self.config.name + TRACE_COLUMN_POSTFIX] = [last_assistant.to_dict()] if last_assistant else []
|
|
99
|
+
|
|
100
|
+
if self.config.extract_reasoning_content:
|
|
101
|
+
data[self.config.name + REASONING_CONTENT_COLUMN_POSTFIX] = self._extract_reasoning_content(trace)
|
|
95
102
|
|
|
96
103
|
return data
|
|
97
104
|
|
|
105
|
+
def _extract_reasoning_content(self, trace: list) -> str | None:
|
|
106
|
+
"""Extract reasoning_content from the final assistant message in the trace.
|
|
107
|
+
|
|
108
|
+
Args:
|
|
109
|
+
trace: List of ChatMessage objects from the generation.
|
|
110
|
+
|
|
111
|
+
Returns:
|
|
112
|
+
The stripped reasoning_content from the final assistant message, or None if not present.
|
|
113
|
+
"""
|
|
114
|
+
reasoning_value: str | None = None
|
|
115
|
+
for message in reversed(trace):
|
|
116
|
+
if message.role == "assistant":
|
|
117
|
+
reasoning_value = message.reasoning_content
|
|
118
|
+
break
|
|
119
|
+
|
|
120
|
+
if reasoning_value is not None:
|
|
121
|
+
reasoning_value = reasoning_value.strip() or None
|
|
122
|
+
|
|
123
|
+
return reasoning_value
|
|
124
|
+
|
|
98
125
|
def _process_serialized_output(self, serialized_output: str) -> str | dict | list:
|
|
99
126
|
"""Process the serialized output from the model. Subclasses can override to customize deserialization."""
|
|
100
127
|
return serialized_output
|
|
@@ -5,6 +5,7 @@ from __future__ import annotations
|
|
|
5
5
|
|
|
6
6
|
from data_designer.config.base import ConfigBase
|
|
7
7
|
from data_designer.config.column_configs import (
|
|
8
|
+
CustomColumnConfig,
|
|
8
9
|
EmbeddingColumnConfig,
|
|
9
10
|
ExpressionColumnConfig,
|
|
10
11
|
LLMCodeColumnConfig,
|
|
@@ -15,6 +16,7 @@ from data_designer.config.column_configs import (
|
|
|
15
16
|
)
|
|
16
17
|
from data_designer.config.column_types import DataDesignerColumnType
|
|
17
18
|
from data_designer.engine.column_generators.generators.base import ColumnGenerator
|
|
19
|
+
from data_designer.engine.column_generators.generators.custom import CustomColumnGenerator
|
|
18
20
|
from data_designer.engine.column_generators.generators.embedding import EmbeddingCellGenerator
|
|
19
21
|
from data_designer.engine.column_generators.generators.expression import ExpressionColumnGenerator
|
|
20
22
|
from data_designer.engine.column_generators.generators.llm_completion import (
|
|
@@ -40,6 +42,7 @@ class ColumnGeneratorRegistry(TaskRegistry[DataDesignerColumnType, ColumnGenerat
|
|
|
40
42
|
|
|
41
43
|
def create_default_column_generator_registry(with_plugins: bool = True) -> ColumnGeneratorRegistry:
|
|
42
44
|
registry = ColumnGeneratorRegistry()
|
|
45
|
+
registry.register(DataDesignerColumnType.CUSTOM, CustomColumnGenerator, CustomColumnConfig)
|
|
43
46
|
registry.register(DataDesignerColumnType.LLM_TEXT, LLMTextCellGenerator, LLMTextColumnConfig)
|
|
44
47
|
registry.register(DataDesignerColumnType.LLM_CODE, LLMCodeCellGenerator, LLMCodeColumnConfig)
|
|
45
48
|
registry.register(DataDesignerColumnType.LLM_JUDGE, LLMJudgeCellGenerator, LLMJudgeColumnConfig)
|
|
@@ -6,7 +6,7 @@ from __future__ import annotations
|
|
|
6
6
|
import json
|
|
7
7
|
import logging
|
|
8
8
|
|
|
9
|
-
from data_designer.config.
|
|
9
|
+
from data_designer.config.base import SingleColumnConfig
|
|
10
10
|
from data_designer.config.column_types import DataDesignerColumnType
|
|
11
11
|
from data_designer.config.models import ModelConfig
|
|
12
12
|
from data_designer.config.utils.code_lang import CodeLang
|
|
@@ -12,6 +12,7 @@ import uuid
|
|
|
12
12
|
from pathlib import Path
|
|
13
13
|
from typing import TYPE_CHECKING, Callable
|
|
14
14
|
|
|
15
|
+
from data_designer.config.column_configs import CustomColumnConfig
|
|
15
16
|
from data_designer.config.column_types import ColumnConfigT
|
|
16
17
|
from data_designer.config.config_builder import BuilderConfig
|
|
17
18
|
from data_designer.config.data_designer_config import DataDesignerConfig
|
|
@@ -97,6 +98,7 @@ class ColumnWiseDatasetBuilder:
|
|
|
97
98
|
on_batch_complete: Callable[[Path], None] | None = None,
|
|
98
99
|
) -> Path:
|
|
99
100
|
self._run_model_health_check_if_needed()
|
|
101
|
+
self._run_mcp_tool_check_if_needed()
|
|
100
102
|
self._write_builder_config()
|
|
101
103
|
generators = self._initialize_generators()
|
|
102
104
|
start_time = time.perf_counter()
|
|
@@ -125,6 +127,7 @@ class ColumnWiseDatasetBuilder:
|
|
|
125
127
|
|
|
126
128
|
def build_preview(self, *, num_records: int) -> pd.DataFrame:
|
|
127
129
|
self._run_model_health_check_if_needed()
|
|
130
|
+
self._run_mcp_tool_check_if_needed()
|
|
128
131
|
|
|
129
132
|
generators = self._initialize_generators()
|
|
130
133
|
group_id = uuid.uuid4().hex
|
|
@@ -209,11 +212,26 @@ class ColumnWiseDatasetBuilder:
|
|
|
209
212
|
df = generator.generate(self.batch_manager.get_current_batch(as_dataframe=True))
|
|
210
213
|
self.batch_manager.update_records(df.to_dict(orient="records"))
|
|
211
214
|
|
|
212
|
-
def _run_model_health_check_if_needed(self) ->
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
215
|
+
def _run_model_health_check_if_needed(self) -> None:
|
|
216
|
+
model_aliases: set[str] = set()
|
|
217
|
+
for config in self.single_column_configs:
|
|
218
|
+
if column_type_is_model_generated(config.column_type):
|
|
219
|
+
model_aliases.add(config.model_alias)
|
|
220
|
+
if isinstance(config, CustomColumnConfig) and config.model_aliases:
|
|
221
|
+
model_aliases.update(config.model_aliases)
|
|
222
|
+
|
|
223
|
+
if model_aliases:
|
|
224
|
+
self._resource_provider.model_registry.run_health_check(list(model_aliases))
|
|
225
|
+
|
|
226
|
+
def _run_mcp_tool_check_if_needed(self) -> None:
|
|
227
|
+
tool_aliases = sorted(
|
|
228
|
+
{config.tool_alias for config in self.llm_generated_column_configs if getattr(config, "tool_alias", None)}
|
|
229
|
+
)
|
|
230
|
+
if not tool_aliases:
|
|
231
|
+
return
|
|
232
|
+
if self._resource_provider.mcp_registry is None:
|
|
233
|
+
raise DatasetGenerationError(f"Tool alias(es) {tool_aliases!r} specified but no MCPRegistry configured.")
|
|
234
|
+
self._resource_provider.mcp_registry.run_health_check(tool_aliases)
|
|
217
235
|
|
|
218
236
|
def _fan_out_with_threads(self, generator: ColumnGeneratorWithModelRegistry, max_workers: int) -> None:
|
|
219
237
|
if generator.get_generation_strategy() != GenerationStrategy.CELL_BY_CELL:
|
|
@@ -8,8 +8,8 @@ from typing import TypeAlias
|
|
|
8
8
|
|
|
9
9
|
from pydantic import Field, field_validator
|
|
10
10
|
|
|
11
|
-
from data_designer.config.base import ConfigBase
|
|
12
|
-
from data_designer.config.column_configs import SamplerColumnConfig, SeedDatasetColumnConfig
|
|
11
|
+
from data_designer.config.base import ConfigBase, SingleColumnConfig
|
|
12
|
+
from data_designer.config.column_configs import SamplerColumnConfig, SeedDatasetColumnConfig
|
|
13
13
|
from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType
|
|
14
14
|
from data_designer.config.sampler_constraints import ColumnConstraintT
|
|
15
15
|
from data_designer.config.seed import SeedConfig
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from data_designer.engine.mcp import io
|
|
7
|
+
from data_designer.engine.mcp.errors import (
|
|
8
|
+
DuplicateToolNameError,
|
|
9
|
+
MCPClientUnavailableError,
|
|
10
|
+
MCPConfigurationError,
|
|
11
|
+
MCPError,
|
|
12
|
+
MCPToolError,
|
|
13
|
+
)
|
|
14
|
+
from data_designer.engine.mcp.facade import MCPFacade
|
|
15
|
+
from data_designer.engine.mcp.factory import create_mcp_registry
|
|
16
|
+
from data_designer.engine.mcp.registry import MCPRegistry, MCPToolDefinition, MCPToolResult
|
|
17
|
+
|
|
18
|
+
__all__ = [
|
|
19
|
+
"DuplicateToolNameError",
|
|
20
|
+
"MCPClientUnavailableError",
|
|
21
|
+
"MCPConfigurationError",
|
|
22
|
+
"MCPError",
|
|
23
|
+
"MCPFacade",
|
|
24
|
+
"MCPRegistry",
|
|
25
|
+
"MCPToolDefinition",
|
|
26
|
+
"MCPToolError",
|
|
27
|
+
"MCPToolResult",
|
|
28
|
+
"create_mcp_registry",
|
|
29
|
+
"io",
|
|
30
|
+
]
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from data_designer.errors import DataDesignerError
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class MCPError(DataDesignerError): ...
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class MCPConfigurationError(MCPError): ...
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class MCPClientUnavailableError(MCPError): ...
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class MCPToolError(MCPError): ...
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class DuplicateToolNameError(MCPConfigurationError):
|
|
22
|
+
"""Raised when the same tool name exists in multiple MCP providers or tool configs."""
|