data-designer-engine 0.4.0rc3__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/engine/analysis/column_profilers/base.py +1 -2
- data_designer/engine/analysis/dataset_profiler.py +1 -2
- data_designer/engine/column_generators/generators/base.py +1 -6
- data_designer/engine/column_generators/generators/custom.py +195 -0
- data_designer/engine/column_generators/generators/llm_completion.py +32 -5
- data_designer/engine/column_generators/registry.py +3 -0
- data_designer/engine/column_generators/utils/errors.py +3 -0
- data_designer/engine/column_generators/utils/prompt_renderer.py +1 -1
- data_designer/engine/dataset_builders/column_wise_builder.py +23 -5
- data_designer/engine/dataset_builders/multi_column_configs.py +2 -2
- data_designer/engine/mcp/__init__.py +30 -0
- data_designer/engine/mcp/errors.py +22 -0
- data_designer/engine/mcp/facade.py +485 -0
- data_designer/engine/mcp/factory.py +46 -0
- data_designer/engine/mcp/io.py +487 -0
- data_designer/engine/mcp/registry.py +203 -0
- data_designer/engine/model_provider.py +68 -0
- data_designer/engine/models/facade.py +74 -9
- data_designer/engine/models/factory.py +18 -1
- data_designer/engine/models/utils.py +28 -1
- data_designer/engine/resources/resource_provider.py +72 -3
- data_designer/engine/testing/fixtures.py +233 -0
- data_designer/engine/testing/stubs.py +1 -2
- {data_designer_engine-0.4.0rc3.dist-info → data_designer_engine-0.5.0rc1.dist-info}/METADATA +3 -2
- {data_designer_engine-0.4.0rc3.dist-info → data_designer_engine-0.5.0rc1.dist-info}/RECORD +26 -19
- data_designer/engine/_version.py +0 -34
- {data_designer_engine-0.4.0rc3.dist-info → data_designer_engine-0.5.0rc1.dist-info}/WHEEL +0 -0
|
@@ -8,6 +8,7 @@ from functools import cached_property
|
|
|
8
8
|
from pydantic import BaseModel, field_validator, model_validator
|
|
9
9
|
from typing_extensions import Self
|
|
10
10
|
|
|
11
|
+
from data_designer.config.mcp import MCPProviderT
|
|
11
12
|
from data_designer.config.models import ModelProvider
|
|
12
13
|
from data_designer.engine.errors import NoModelProvidersError, UnknownProviderError
|
|
13
14
|
|
|
@@ -75,3 +76,70 @@ def resolve_model_provider_registry(
|
|
|
75
76
|
providers=model_providers,
|
|
76
77
|
default=default_provider_name or model_providers[0].name,
|
|
77
78
|
)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class MCPProviderRegistry(BaseModel):
    """Registry for MCP providers.

    Unlike ModelProviderRegistry, MCPProviderRegistry can be empty since MCP providers
    are optional. Users only need to register MCP providers if they want to use MCP tools
    for generation.

    Attributes:
        providers: List of MCP providers (both MCPProvider and LocalStdioMCPProvider).
    """

    providers: list[MCPProviderT] = []

    @field_validator("providers", mode="after")
    @classmethod
    def validate_providers_have_unique_names(cls, v: list[MCPProviderT]) -> list[MCPProviderT]:
        """Reject provider lists in which any provider name appears more than once.

        Provider names are the lookup key used by ``get_provider``, so duplicates would
        make lookups ambiguous.

        Args:
            v: The validated list of providers.

        Returns:
            The same list, unchanged, when all names are unique.

        Raises:
            ValueError: If one or more provider names are duplicated.
        """
        names: set[str] = set()
        dupes: set[str] = set()
        for provider in v:
            if provider.name in names:
                dupes.add(provider.name)
            names.add(provider.name)

        if dupes:
            raise ValueError(f"MCP providers must have unique names, found duplicates: {dupes}")
        return v

    @cached_property
    def _providers_dict(self) -> dict[str, MCPProviderT]:
        # Name -> provider index, built lazily on first lookup.
        # NOTE(review): the cached dict is not rebuilt if ``providers`` is reassigned
        # after the first lookup.
        return {p.name: p for p in self.providers}

    def get_provider(self, name: str) -> MCPProviderT:
        """Get an MCP provider by name.

        Args:
            name: The name of the MCP provider.

        Returns:
            The MCP provider with the given name.

        Raises:
            UnknownProviderError: If no provider with the given name is registered.
        """
        try:
            return self._providers_dict[name]
        except KeyError:
            # Suppress the internal KeyError so callers see only the domain error
            # instead of a chained "During handling of the above exception" traceback.
            raise UnknownProviderError(f"No MCP provider named {name!r} registered") from None

    def is_empty(self) -> bool:
        """Check if the registry has no providers."""
        return not self.providers
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def resolve_mcp_provider_registry(
    mcp_providers: list[MCPProviderT] | None = None,
) -> MCPProviderRegistry:
    """Build an MCPProviderRegistry from an optional list of MCP providers.

    Args:
        mcp_providers: MCP provider configurations to register. Both ``None`` and an
            empty list produce an empty registry.

    Returns:
        An MCPProviderRegistry holding the given providers.
    """
    if not mcp_providers:
        mcp_providers = []
    return MCPProviderRegistry(providers=mcp_providers)
|
|
@@ -9,6 +9,7 @@ from copy import deepcopy
|
|
|
9
9
|
from typing import TYPE_CHECKING, Any
|
|
10
10
|
|
|
11
11
|
from data_designer.config.models import GenerationType, ModelConfig, ModelProvider
|
|
12
|
+
from data_designer.engine.mcp.errors import MCPConfigurationError
|
|
12
13
|
from data_designer.engine.model_provider import ModelProviderRegistry
|
|
13
14
|
from data_designer.engine.models.errors import (
|
|
14
15
|
GenerationValidationFailureError,
|
|
@@ -25,6 +26,15 @@ from data_designer.lazy_heavy_imports import litellm
|
|
|
25
26
|
if TYPE_CHECKING:
|
|
26
27
|
import litellm
|
|
27
28
|
|
|
29
|
+
from data_designer.engine.mcp.facade import MCPFacade
|
|
30
|
+
from data_designer.engine.mcp.registry import MCPRegistry
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _identity(x: Any) -> Any:
|
|
34
|
+
"""Identity function for default parser. Module-level for pickling compatibility."""
|
|
35
|
+
return x
|
|
36
|
+
|
|
37
|
+
|
|
28
38
|
logger = logging.getLogger(__name__)
|
|
29
39
|
|
|
30
40
|
|
|
@@ -34,10 +44,13 @@ class ModelFacade:
|
|
|
34
44
|
model_config: ModelConfig,
|
|
35
45
|
secret_resolver: SecretResolver,
|
|
36
46
|
model_provider_registry: ModelProviderRegistry,
|
|
37
|
-
|
|
47
|
+
*,
|
|
48
|
+
mcp_registry: MCPRegistry | None = None,
|
|
49
|
+
) -> None:
|
|
38
50
|
self._model_config = model_config
|
|
39
51
|
self._secret_resolver = secret_resolver
|
|
40
52
|
self._model_provider_registry = model_provider_registry
|
|
53
|
+
self._mcp_registry = mcp_registry
|
|
41
54
|
self._litellm_deployment = self._get_litellm_deployment(model_config)
|
|
42
55
|
self._router = CustomRouter([self._litellm_deployment], **LiteLLMRouterDefaultKwargs().model_dump())
|
|
43
56
|
self._usage_stats = ModelUsageStats()
|
|
@@ -104,6 +117,17 @@ class ModelFacade:
|
|
|
104
117
|
kwargs["extra_headers"] = self.model_provider.extra_headers
|
|
105
118
|
return kwargs
|
|
106
119
|
|
|
120
|
+
def _get_mcp_facade(self, tool_alias: str | None) -> MCPFacade | None:
|
|
121
|
+
if tool_alias is None:
|
|
122
|
+
return None
|
|
123
|
+
if self._mcp_registry is None:
|
|
124
|
+
raise MCPConfigurationError(f"Tool alias {tool_alias!r} specified but no MCPRegistry configured.")
|
|
125
|
+
|
|
126
|
+
try:
|
|
127
|
+
return self._mcp_registry.get_mcp(tool_alias=tool_alias)
|
|
128
|
+
except ValueError as exc:
|
|
129
|
+
raise MCPConfigurationError(f"Tool alias {tool_alias!r} is not registered.") from exc
|
|
130
|
+
|
|
107
131
|
@catch_llm_exceptions
|
|
108
132
|
def generate_text_embeddings(
|
|
109
133
|
self, input_texts: list[str], skip_usage_tracking: bool = False, **kwargs
|
|
@@ -142,9 +166,10 @@ class ModelFacade:
|
|
|
142
166
|
self,
|
|
143
167
|
prompt: str,
|
|
144
168
|
*,
|
|
145
|
-
parser: Callable[[str], Any],
|
|
169
|
+
parser: Callable[[str], Any] = _identity,
|
|
146
170
|
system_prompt: str | None = None,
|
|
147
171
|
multi_modal_context: list[dict[str, Any]] | None = None,
|
|
172
|
+
tool_alias: str | None = None,
|
|
148
173
|
max_correction_steps: int = 0,
|
|
149
174
|
max_conversation_restarts: int = 0,
|
|
150
175
|
skip_usage_tracking: bool = False,
|
|
@@ -170,7 +195,10 @@ class ModelFacade:
|
|
|
170
195
|
no system message is provided and the model should use its default system
|
|
171
196
|
prompt.
|
|
172
197
|
parser (func(str) -> Any): A function applied to the LLM response which processes
|
|
173
|
-
an LLM response into some output object.
|
|
198
|
+
an LLM response into some output object. Default: identity function.
|
|
199
|
+
tool_alias (str | None): Optional tool configuration alias. When provided,
|
|
200
|
+
the model may call permitted tools from the configured MCP providers.
|
|
201
|
+
The alias must reference a ToolConfig registered in the MCPRegistry.
|
|
174
202
|
max_correction_steps (int): Maximum number of correction rounds permitted
|
|
175
203
|
within a single conversation. Note, many rounds can lead to increasing
|
|
176
204
|
context size without necessarily improving performance -- small language
|
|
@@ -186,25 +214,61 @@ class ModelFacade:
|
|
|
186
214
|
Returns:
|
|
187
215
|
A tuple containing:
|
|
188
216
|
- The parsed output object from the parser.
|
|
189
|
-
- The full trace of ChatMessage entries in the conversation, including any
|
|
190
|
-
corrections and reasoning traces. Callers can decide whether to store this.
|
|
217
|
+
- The full trace of ChatMessage entries in the conversation, including any tool calls,
|
|
218
|
+
corrections, and reasoning traces. Callers can decide whether to store this.
|
|
191
219
|
|
|
192
220
|
Raises:
|
|
193
221
|
GenerationValidationFailureError: If the maximum number of retries or
|
|
194
222
|
correction steps are met and the last response failures on
|
|
195
223
|
generation validation.
|
|
224
|
+
MCPConfigurationError: If tool_alias is specified but no MCPRegistry is configured.
|
|
196
225
|
"""
|
|
197
226
|
output_obj = None
|
|
227
|
+
tool_schemas = None
|
|
228
|
+
tool_call_turns = 0
|
|
198
229
|
curr_num_correction_steps = 0
|
|
199
230
|
curr_num_restarts = 0
|
|
200
231
|
|
|
201
|
-
|
|
232
|
+
mcp_facade = self._get_mcp_facade(tool_alias)
|
|
233
|
+
|
|
234
|
+
# Checkpoint for restarts - updated after tool calls so we don't repeat them
|
|
235
|
+
restart_checkpoint = prompt_to_messages(
|
|
202
236
|
user_prompt=prompt, system_prompt=system_prompt, multi_modal_context=multi_modal_context
|
|
203
237
|
)
|
|
204
|
-
|
|
238
|
+
checkpoint_tool_call_turns = 0
|
|
239
|
+
messages: list[ChatMessage] = deepcopy(restart_checkpoint)
|
|
240
|
+
|
|
241
|
+
if mcp_facade is not None:
|
|
242
|
+
tool_schemas = mcp_facade.get_tool_schemas()
|
|
205
243
|
|
|
206
244
|
while True:
|
|
207
|
-
|
|
245
|
+
completion_kwargs = dict(kwargs)
|
|
246
|
+
if tool_schemas is not None:
|
|
247
|
+
completion_kwargs["tools"] = tool_schemas
|
|
248
|
+
|
|
249
|
+
completion_response = self.completion(
|
|
250
|
+
messages,
|
|
251
|
+
skip_usage_tracking=skip_usage_tracking,
|
|
252
|
+
**completion_kwargs,
|
|
253
|
+
)
|
|
254
|
+
|
|
255
|
+
# Process any tool calls in the response (handles parallel tool calling)
|
|
256
|
+
if mcp_facade is not None and mcp_facade.has_tool_calls(completion_response):
|
|
257
|
+
tool_call_turns += 1
|
|
258
|
+
|
|
259
|
+
if tool_call_turns > mcp_facade.max_tool_call_turns:
|
|
260
|
+
# Gracefully refuse tool calls when budget is exhausted
|
|
261
|
+
messages.extend(mcp_facade.refuse_completion_response(completion_response))
|
|
262
|
+
else:
|
|
263
|
+
messages.extend(mcp_facade.process_completion_response(completion_response))
|
|
264
|
+
|
|
265
|
+
# Update checkpoint so restarts don't repeat tool calls
|
|
266
|
+
restart_checkpoint = deepcopy(messages)
|
|
267
|
+
checkpoint_tool_call_turns = tool_call_turns
|
|
268
|
+
|
|
269
|
+
continue # Back to top
|
|
270
|
+
|
|
271
|
+
# No tool calls remaining to process
|
|
208
272
|
response = completion_response.choices[0].message.content or ""
|
|
209
273
|
reasoning_trace = getattr(completion_response.choices[0].message, "reasoning_content", None)
|
|
210
274
|
messages.append(ChatMessage.as_assistant(content=response, reasoning_content=reasoning_trace or None))
|
|
@@ -226,7 +290,8 @@ class ModelFacade:
|
|
|
226
290
|
elif curr_num_restarts < max_conversation_restarts:
|
|
227
291
|
curr_num_correction_steps = 0
|
|
228
292
|
curr_num_restarts += 1
|
|
229
|
-
messages = deepcopy(
|
|
293
|
+
messages = deepcopy(restart_checkpoint)
|
|
294
|
+
tool_call_turns = checkpoint_tool_call_turns
|
|
230
295
|
|
|
231
296
|
else:
|
|
232
297
|
raise GenerationValidationFailureError(
|
|
@@ -10,6 +10,7 @@ from data_designer.engine.model_provider import ModelProviderRegistry
|
|
|
10
10
|
from data_designer.engine.secret_resolver import SecretResolver
|
|
11
11
|
|
|
12
12
|
if TYPE_CHECKING:
|
|
13
|
+
from data_designer.engine.mcp.registry import MCPRegistry
|
|
13
14
|
from data_designer.engine.models.registry import ModelRegistry
|
|
14
15
|
|
|
15
16
|
|
|
@@ -18,12 +19,23 @@ def create_model_registry(
|
|
|
18
19
|
model_configs: list[ModelConfig] | None = None,
|
|
19
20
|
secret_resolver: SecretResolver,
|
|
20
21
|
model_provider_registry: ModelProviderRegistry,
|
|
22
|
+
mcp_registry: MCPRegistry | None = None,
|
|
21
23
|
) -> ModelRegistry:
|
|
22
24
|
"""Factory function for creating a ModelRegistry instance.
|
|
23
25
|
|
|
24
26
|
Heavy dependencies (litellm, httpx) are deferred until this function is called.
|
|
25
27
|
This is a factory function pattern - imports inside factories are idiomatic Python
|
|
26
28
|
for lazy initialization.
|
|
29
|
+
|
|
30
|
+
Args:
|
|
31
|
+
model_configs: Optional list of model configurations to register.
|
|
32
|
+
secret_resolver: Resolver for secrets referenced in provider configs.
|
|
33
|
+
model_provider_registry: Registry of model provider configurations.
|
|
34
|
+
mcp_registry: Optional MCP registry for tool operations. When provided,
|
|
35
|
+
ModelFacades can look up MCPFacades by tool_alias for tool-enabled generation.
|
|
36
|
+
|
|
37
|
+
Returns:
|
|
38
|
+
A configured ModelRegistry instance.
|
|
27
39
|
"""
|
|
28
40
|
from data_designer.engine.models.facade import ModelFacade
|
|
29
41
|
from data_designer.engine.models.litellm_overrides import apply_litellm_patches
|
|
@@ -32,7 +44,12 @@ def create_model_registry(
|
|
|
32
44
|
apply_litellm_patches()
|
|
33
45
|
|
|
34
46
|
def model_facade_factory(model_config, secret_resolver, model_provider_registry):
|
|
35
|
-
return ModelFacade(
|
|
47
|
+
return ModelFacade(
|
|
48
|
+
model_config,
|
|
49
|
+
secret_resolver,
|
|
50
|
+
model_provider_registry,
|
|
51
|
+
mcp_registry=mcp_registry,
|
|
52
|
+
)
|
|
36
53
|
|
|
37
54
|
return ModelRegistry(
|
|
38
55
|
model_configs=model_configs,
|
|
@@ -36,11 +36,14 @@ class ChatMessage:
|
|
|
36
36
|
def to_dict(self) -> dict[str, Any]:
|
|
37
37
|
"""Convert the message to a dictionary format for API calls.
|
|
38
38
|
|
|
39
|
+
Content is normalized to a list of ChatML-style blocks to keep a
|
|
40
|
+
consistent schema across traces and API payloads.
|
|
41
|
+
|
|
39
42
|
Returns:
|
|
40
43
|
A dictionary containing the message fields. Only includes non-empty
|
|
41
44
|
optional fields to keep the output clean.
|
|
42
45
|
"""
|
|
43
|
-
result: dict[str, Any] = {"role": self.role, "content": self.content}
|
|
46
|
+
result: dict[str, Any] = {"role": self.role, "content": _normalize_content_blocks(self.content)}
|
|
44
47
|
if self.reasoning_content:
|
|
45
48
|
result["reasoning_content"] = self.reasoning_content
|
|
46
49
|
if self.tool_calls:
|
|
@@ -99,3 +102,27 @@ def prompt_to_messages(
|
|
|
99
102
|
if system_prompt:
|
|
100
103
|
return [ChatMessage.as_system(system_prompt), ChatMessage.as_user(user_content)]
|
|
101
104
|
return [ChatMessage.as_user(user_content)]
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def _normalize_content_blocks(content: Any) -> list[dict[str, Any]]:
    """Coerce message content into a list of ChatML-style content blocks.

    - ``None`` becomes an empty list.
    - A list is normalized element by element with ``_normalize_content_block``.
    - Any other value is wrapped as a single text block.
    """
    if content is None:
        return []
    if not isinstance(content, list):
        return [_text_block(content)]
    normalized: list[dict[str, Any]] = []
    for block in content:
        normalized.append(_normalize_content_block(block))
    return normalized
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def _normalize_content_block(block: Any) -> dict[str, Any]:
|
|
116
|
+
if isinstance(block, dict) and "type" in block:
|
|
117
|
+
return block
|
|
118
|
+
return _text_block(block)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def _text_block(value: Any) -> dict[str, Any]:
|
|
122
|
+
if value is None:
|
|
123
|
+
text_value = ""
|
|
124
|
+
elif isinstance(value, str):
|
|
125
|
+
text_value = value
|
|
126
|
+
else:
|
|
127
|
+
text_value = str(value)
|
|
128
|
+
return {"type": "text", "text": text_value}
|
|
@@ -5,15 +5,21 @@ from __future__ import annotations
|
|
|
5
5
|
|
|
6
6
|
from data_designer.config.base import ConfigBase
|
|
7
7
|
from data_designer.config.dataset_metadata import DatasetMetadata
|
|
8
|
+
from data_designer.config.mcp import MCPProviderT, ToolConfig
|
|
8
9
|
from data_designer.config.models import ModelConfig
|
|
9
10
|
from data_designer.config.run_config import RunConfig
|
|
10
11
|
from data_designer.config.seed_source import SeedSource
|
|
11
12
|
from data_designer.config.utils.type_helpers import StrEnum
|
|
12
13
|
from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
|
|
13
|
-
from data_designer.engine.
|
|
14
|
+
from data_designer.engine.mcp.factory import create_mcp_registry
|
|
15
|
+
from data_designer.engine.mcp.registry import MCPRegistry
|
|
16
|
+
from data_designer.engine.model_provider import (
|
|
17
|
+
ModelProviderRegistry,
|
|
18
|
+
resolve_mcp_provider_registry,
|
|
19
|
+
)
|
|
14
20
|
from data_designer.engine.models.factory import create_model_registry
|
|
15
21
|
from data_designer.engine.models.registry import ModelRegistry
|
|
16
|
-
from data_designer.engine.resources.managed_storage import ManagedBlobStorage
|
|
22
|
+
from data_designer.engine.resources.managed_storage import ManagedBlobStorage
|
|
17
23
|
from data_designer.engine.resources.seed_reader import SeedReader, SeedReaderRegistry
|
|
18
24
|
from data_designer.engine.secret_resolver import SecretResolver
|
|
19
25
|
|
|
@@ -28,6 +34,7 @@ class ResourceProvider(ConfigBase):
|
|
|
28
34
|
artifact_storage: ArtifactStorage
|
|
29
35
|
blob_storage: ManagedBlobStorage | None = None
|
|
30
36
|
model_registry: ModelRegistry | None = None
|
|
37
|
+
mcp_registry: MCPRegistry | None = None
|
|
31
38
|
run_config: RunConfig = RunConfig()
|
|
32
39
|
seed_reader: SeedReader | None = None
|
|
33
40
|
|
|
@@ -43,6 +50,31 @@ class ResourceProvider(ConfigBase):
|
|
|
43
50
|
return DatasetMetadata(seed_column_names=seed_column_names)
|
|
44
51
|
|
|
45
52
|
|
|
53
|
+
def _validate_tool_configs_against_providers(
|
|
54
|
+
tool_configs: list[ToolConfig],
|
|
55
|
+
mcp_providers: list[MCPProviderT],
|
|
56
|
+
) -> None:
|
|
57
|
+
"""Validate that all providers referenced in tool configs exist.
|
|
58
|
+
|
|
59
|
+
Args:
|
|
60
|
+
tool_configs: List of tool configurations to validate.
|
|
61
|
+
mcp_providers: List of available MCP provider configurations.
|
|
62
|
+
|
|
63
|
+
Raises:
|
|
64
|
+
ValueError: If a tool config references a provider that doesn't exist.
|
|
65
|
+
"""
|
|
66
|
+
available_providers = {p.name for p in mcp_providers}
|
|
67
|
+
|
|
68
|
+
for tc in tool_configs:
|
|
69
|
+
missing_providers = [p for p in tc.providers if p not in available_providers]
|
|
70
|
+
if missing_providers:
|
|
71
|
+
available_list = sorted(available_providers) if available_providers else ["(none configured)"]
|
|
72
|
+
raise ValueError(
|
|
73
|
+
f"ToolConfig '{tc.tool_alias}' references provider(s) {missing_providers!r} "
|
|
74
|
+
f"which are not registered. Available providers: {available_list}"
|
|
75
|
+
)
|
|
76
|
+
|
|
77
|
+
|
|
46
78
|
def create_resource_provider(
|
|
47
79
|
*,
|
|
48
80
|
artifact_storage: ArtifactStorage,
|
|
@@ -53,9 +85,31 @@ def create_resource_provider(
|
|
|
53
85
|
blob_storage: ManagedBlobStorage | None = None,
|
|
54
86
|
seed_dataset_source: SeedSource | None = None,
|
|
55
87
|
run_config: RunConfig | None = None,
|
|
88
|
+
mcp_providers: list[MCPProviderT] | None = None,
|
|
89
|
+
tool_configs: list[ToolConfig] | None = None,
|
|
56
90
|
) -> ResourceProvider:
|
|
57
91
|
"""Factory function for creating a ResourceProvider instance.
|
|
92
|
+
|
|
58
93
|
This function triggers lazy loading of heavy dependencies like litellm.
|
|
94
|
+
The creation order is:
|
|
95
|
+
1. MCPProviderRegistry (can be empty)
|
|
96
|
+
2. MCPRegistry with tool_configs
|
|
97
|
+
3. ModelRegistry with mcp_registry
|
|
98
|
+
|
|
99
|
+
Args:
|
|
100
|
+
artifact_storage: Storage for build artifacts.
|
|
101
|
+
model_configs: List of model configurations.
|
|
102
|
+
secret_resolver: Resolver for secrets.
|
|
103
|
+
model_provider_registry: Registry of model providers.
|
|
104
|
+
seed_reader_registry: Registry of seed readers.
|
|
105
|
+
blob_storage: Optional blob storage for large files.
|
|
106
|
+
seed_dataset_source: Optional source for seed datasets.
|
|
107
|
+
run_config: Optional runtime configuration.
|
|
108
|
+
mcp_providers: Optional list of MCP provider configurations.
|
|
109
|
+
tool_configs: Optional list of tool configurations.
|
|
110
|
+
|
|
111
|
+
Returns:
|
|
112
|
+
A configured ResourceProvider instance.
|
|
59
113
|
"""
|
|
60
114
|
seed_reader = None
|
|
61
115
|
if seed_dataset_source:
|
|
@@ -64,14 +118,29 @@ def create_resource_provider(
|
|
|
64
118
|
secret_resolver,
|
|
65
119
|
)
|
|
66
120
|
|
|
121
|
+
# Create MCPProviderRegistry first (can be empty)
|
|
122
|
+
mcp_provider_registry = resolve_mcp_provider_registry(mcp_providers)
|
|
123
|
+
|
|
124
|
+
# Create MCPRegistry with tool configs (only if tool_configs provided)
|
|
125
|
+
# Tool validation is performed during dataset builder health checks.
|
|
126
|
+
mcp_registry = None
|
|
127
|
+
if tool_configs:
|
|
128
|
+
mcp_registry = create_mcp_registry(
|
|
129
|
+
tool_configs=tool_configs,
|
|
130
|
+
secret_resolver=secret_resolver,
|
|
131
|
+
mcp_provider_registry=mcp_provider_registry,
|
|
132
|
+
)
|
|
133
|
+
|
|
67
134
|
return ResourceProvider(
|
|
68
135
|
artifact_storage=artifact_storage,
|
|
69
136
|
model_registry=create_model_registry(
|
|
70
137
|
model_configs=model_configs,
|
|
71
138
|
secret_resolver=secret_resolver,
|
|
72
139
|
model_provider_registry=model_provider_registry,
|
|
140
|
+
mcp_registry=mcp_registry,
|
|
73
141
|
),
|
|
74
|
-
blob_storage=blob_storage
|
|
142
|
+
blob_storage=blob_storage,
|
|
143
|
+
mcp_registry=mcp_registry,
|
|
75
144
|
seed_reader=seed_reader,
|
|
76
145
|
run_config=run_config or RunConfig(),
|
|
77
146
|
)
|
|
@@ -0,0 +1,233 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
"""Pytest fixtures for engine testing.
|
|
5
|
+
|
|
6
|
+
Located in src/ so it can be packaged and shared across subpackages via pytest_plugins.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from typing import Any
|
|
12
|
+
from unittest.mock import MagicMock
|
|
13
|
+
|
|
14
|
+
import pytest
|
|
15
|
+
|
|
16
|
+
from data_designer.config.mcp import LocalStdioMCPProvider, MCPProvider, ToolConfig
|
|
17
|
+
from data_designer.engine.mcp.facade import MCPFacade
|
|
18
|
+
from data_designer.engine.model_provider import MCPProviderRegistry
|
|
19
|
+
from data_designer.engine.secret_resolver import SecretResolver
|
|
20
|
+
from data_designer.engine.testing.stubs import StubHuggingFaceSeedReader
|
|
21
|
+
|
|
22
|
+
# =============================================================================
|
|
23
|
+
# Fake LLM response classes (used by completion response fixtures)
|
|
24
|
+
# =============================================================================
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class _FakeMessage:
|
|
28
|
+
"""Fake message class for mocking LLM completion responses."""
|
|
29
|
+
|
|
30
|
+
def __init__(
|
|
31
|
+
self,
|
|
32
|
+
content: str | None,
|
|
33
|
+
tool_calls: list[dict] | None = None,
|
|
34
|
+
reasoning_content: str | None = None,
|
|
35
|
+
) -> None:
|
|
36
|
+
self.content = content
|
|
37
|
+
self.tool_calls = tool_calls
|
|
38
|
+
self.reasoning_content = reasoning_content
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class _FakeChoice:
|
|
42
|
+
"""Fake choice class for mocking LLM completion responses."""
|
|
43
|
+
|
|
44
|
+
def __init__(self, message: _FakeMessage) -> None:
|
|
45
|
+
self.message = message
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class _FakeResponse:
    """Stand-in for an LLM completion response carrying exactly one choice."""

    def __init__(self, message: _FakeMessage) -> None:
        only_choice = _FakeChoice(message)
        self.choices = [only_choice]
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
# =============================================================================
|
|
56
|
+
# Seed reader fixtures
|
|
57
|
+
# =============================================================================
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@pytest.fixture
def stub_seed_reader() -> StubHuggingFaceSeedReader:
    """Provide a fresh StubHuggingFaceSeedReader for seed-dataset tests."""
    reader = StubHuggingFaceSeedReader()
    return reader
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
# =============================================================================
|
|
67
|
+
# MCP Provider fixtures
|
|
68
|
+
# =============================================================================
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@pytest.fixture
def stub_mcp_provider_registry() -> MCPProviderRegistry:
    """Provide an MCPProviderRegistry with two stdio providers, "tools" and "secondary".

    Both providers use ``command="python"``; only their names differ.
    """
    provider_names = ("tools", "secondary")
    providers = [LocalStdioMCPProvider(name=provider_name, command="python") for provider_name in provider_names]
    return MCPProviderRegistry(providers=providers)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@pytest.fixture
def stub_mcp_provider_registry_single() -> MCPProviderRegistry:
    """Provide an MCPProviderRegistry containing only the stdio provider "tools"."""
    tools_provider = LocalStdioMCPProvider(name="tools", command="python")
    return MCPProviderRegistry(providers=[tools_provider])
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
@pytest.fixture
def stub_secret_resolver() -> MagicMock:
    """Provide a SecretResolver mock whose ``resolve`` returns its argument unchanged.

    ``spec=SecretResolver`` restricts the mock to attributes that SecretResolver defines.
    """

    def _echo(x):
        return x

    resolver = MagicMock(spec=SecretResolver)
    resolver.resolve.side_effect = _echo
    return resolver
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
@pytest.fixture
def stub_stdio_provider() -> LocalStdioMCPProvider:
    """Provide a stdio MCP provider named "test-stdio" with launch args and env populated."""
    launch_args = ["-m", "test_server"]
    environment = {"TEST_VAR": "value"}
    return LocalStdioMCPProvider(name="test-stdio", command="python", args=launch_args, env=environment)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
@pytest.fixture
def stub_sse_provider() -> MCPProvider:
    """Provide a remote MCPProvider named "test-sse" pointing at a localhost SSE endpoint."""
    sse_endpoint = "http://localhost:8080/sse"
    return MCPProvider(name="test-sse", endpoint=sse_endpoint, api_key="test-key")
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
# =============================================================================
|
|
118
|
+
# Tool config fixtures
|
|
119
|
+
# =============================================================================
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
@pytest.fixture
def stub_tool_config() -> ToolConfig:
    """Provide ToolConfig "test-tools" bound to provider "tools" (3 tool-call turns, 30s timeout)."""
    settings = {
        "tool_alias": "test-tools",
        "providers": ["tools"],
        "max_tool_call_turns": 3,
        "timeout_sec": 30.0,
    }
    return ToolConfig(**settings)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
@pytest.fixture
def stub_tool_config_with_allow_list() -> ToolConfig:
    """Provide ToolConfig "test-tools" on provider "tools" that only permits "lookup" and "search"."""
    allowed = ["lookup", "search"]
    return ToolConfig(tool_alias="test-tools", providers=["tools"], allow_tools=allowed, max_tool_call_turns=3)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
# =============================================================================
|
|
145
|
+
# Facade fixtures
|
|
146
|
+
# =============================================================================
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
@pytest.fixture
def stub_mcp_facade(
    stub_tool_config: ToolConfig, stub_secret_resolver: MagicMock, stub_mcp_provider_registry: MCPProviderRegistry
) -> MCPFacade:
    """Provide an MCPFacade wired to the stub tool config, secret resolver and provider registry."""
    facade = MCPFacade(
        tool_config=stub_tool_config,
        secret_resolver=stub_secret_resolver,
        mcp_provider_registry=stub_mcp_provider_registry,
    )
    return facade
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
@pytest.fixture
def stub_mcp_facade_factory() -> Any:
    """Provide a callable that builds an MCPFacade from (tool_config, secret_resolver, provider_registry)."""

    def factory(
        tool_config: ToolConfig, secret_resolver: SecretResolver, provider_registry: MCPProviderRegistry
    ) -> MCPFacade:
        # Maps the ``provider_registry`` argument onto MCPFacade's ``mcp_provider_registry`` keyword.
        return MCPFacade(
            mcp_provider_registry=provider_registry,
            secret_resolver=secret_resolver,
            tool_config=tool_config,
        )

    return factory
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
# =============================================================================
|
|
176
|
+
# Completion response fixtures
|
|
177
|
+
# =============================================================================
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
@pytest.fixture
def mock_completion_response_no_tools() -> _FakeResponse:
    """Provide a fake completion whose message has text content and no tool calls."""
    message = _FakeMessage(content="Hello, I can help with that.")
    return _FakeResponse(message)
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
@pytest.fixture
def mock_completion_response_single_tool() -> _FakeResponse:
    """Provide a fake completion that requests one "lookup" tool call."""
    lookup_call = dict(
        id="call-1",
        type="function",
        function=dict(name="lookup", arguments='{"query": "test"}'),
    )
    message = _FakeMessage(content="Let me look that up.", tool_calls=[lookup_call])
    return _FakeResponse(message)
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
@pytest.fixture
def mock_completion_response_parallel_tools() -> _FakeResponse:
    """Provide a fake completion requesting three tool calls at once (lookup, search, fetch)."""
    call_specs = [
        ("call-1", "lookup", '{"query": "first"}'),
        ("call-2", "search", '{"term": "second"}'),
        ("call-3", "fetch", '{"url": "example.com"}'),
    ]
    tool_calls = [
        {"id": call_id, "type": "function", "function": {"name": tool_name, "arguments": arguments}}
        for call_id, tool_name, arguments in call_specs
    ]
    return _FakeResponse(_FakeMessage(content="Executing multiple tools.", tool_calls=tool_calls))
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
@pytest.fixture
def mock_completion_response_with_reasoning() -> _FakeResponse:
    """Provide a fake completion whose content and reasoning_content both carry surrounding whitespace."""
    message = _FakeMessage(
        content=" Final answer with extra spaces. ",
        reasoning_content=" Thinking about the problem... ",
    )
    return _FakeResponse(message)
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
@pytest.fixture
def mock_completion_response_tool_with_reasoning() -> _FakeResponse:
    """Provide a fake completion with a "lookup" tool call plus whitespace-padded content and reasoning."""
    function_spec = {"name": "lookup", "arguments": '{"query": "test"}'}
    lookup_call = {"id": "call-1", "type": "function", "function": function_spec}
    message = _FakeMessage(
        content=" Looking it up... ",
        tool_calls=[lookup_call],
        reasoning_content=" I should use the lookup tool. ",
    )
    return _FakeResponse(message)
|