data-designer-engine 0.4.0rc3__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/engine/analysis/column_profilers/base.py +1 -2
- data_designer/engine/analysis/dataset_profiler.py +1 -2
- data_designer/engine/column_generators/generators/base.py +1 -6
- data_designer/engine/column_generators/generators/custom.py +195 -0
- data_designer/engine/column_generators/generators/llm_completion.py +32 -5
- data_designer/engine/column_generators/registry.py +3 -0
- data_designer/engine/column_generators/utils/errors.py +3 -0
- data_designer/engine/column_generators/utils/prompt_renderer.py +1 -1
- data_designer/engine/dataset_builders/column_wise_builder.py +23 -5
- data_designer/engine/dataset_builders/multi_column_configs.py +2 -2
- data_designer/engine/mcp/__init__.py +30 -0
- data_designer/engine/mcp/errors.py +22 -0
- data_designer/engine/mcp/facade.py +485 -0
- data_designer/engine/mcp/factory.py +46 -0
- data_designer/engine/mcp/io.py +487 -0
- data_designer/engine/mcp/registry.py +203 -0
- data_designer/engine/model_provider.py +68 -0
- data_designer/engine/models/facade.py +74 -9
- data_designer/engine/models/factory.py +18 -1
- data_designer/engine/models/utils.py +28 -1
- data_designer/engine/resources/resource_provider.py +72 -3
- data_designer/engine/testing/fixtures.py +233 -0
- data_designer/engine/testing/stubs.py +1 -2
- {data_designer_engine-0.4.0rc3.dist-info → data_designer_engine-0.5.0rc1.dist-info}/METADATA +3 -2
- {data_designer_engine-0.4.0rc3.dist-info → data_designer_engine-0.5.0rc1.dist-info}/RECORD +26 -19
- data_designer/engine/_version.py +0 -34
- {data_designer_engine-0.4.0rc3.dist-info → data_designer_engine-0.5.0rc1.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,485 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import json
|
|
7
|
+
import uuid
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
from data_designer.config.mcp import MCPProviderT, ToolConfig
|
|
11
|
+
from data_designer.engine.mcp import io as mcp_io
|
|
12
|
+
from data_designer.engine.mcp.errors import DuplicateToolNameError, MCPConfigurationError, MCPToolError
|
|
13
|
+
from data_designer.engine.mcp.registry import MCPToolDefinition
|
|
14
|
+
from data_designer.engine.model_provider import MCPProviderRegistry
|
|
15
|
+
from data_designer.engine.models.utils import ChatMessage
|
|
16
|
+
from data_designer.engine.secret_resolver import SecretResolver
|
|
17
|
+
|
|
18
|
+
# Text placed in the tool-role message sent back for each tool call that is
# refused in refuse_completion_response (used when no custom message is given).
DEFAULT_TOOL_REFUSAL_MESSAGE = " ".join(
    (
        "Tool call refused: You have reached the maximum number of tool-calling turns.",
        "Please provide your final response without requesting additional tool calls.",
    )
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class MCPFacade:
    """Lightweight facade scoped to a specific ToolConfig.

    MCPFacade provides a clean interface for MCP tool operations within the context
    of a specific tool configuration. It handles tool call extraction, validation,
    and execution using the mcp.io module for communication.

    This mirrors the ModelFacade pattern where each facade is scoped to a specific
    configuration while sharing underlying resources through caching in the io module.

    LLM messages and tool calls are read either as attributes or, for plain
    dictionaries, as keys, so both response shapes are handled uniformly.
    """

    def __init__(
        self,
        tool_config: ToolConfig,
        secret_resolver: SecretResolver,
        mcp_provider_registry: MCPProviderRegistry,
    ) -> None:
        """Initialize the MCPFacade.

        Args:
            tool_config: The tool configuration this facade is scoped to.
            secret_resolver: Resolver for secrets referenced in provider configs.
            mcp_provider_registry: Registry of MCP provider configurations.
        """
        self._tool_config = tool_config
        self._secret_resolver = secret_resolver
        self._mcp_provider_registry = mcp_provider_registry

    @property
    def tool_alias(self) -> str:
        """The alias for this tool configuration."""
        return self._tool_config.tool_alias

    @property
    def providers(self) -> list[str]:
        """List of MCP provider names for this configuration."""
        return self._tool_config.providers

    @property
    def max_tool_call_turns(self) -> int:
        """Maximum number of tool-calling turns permitted in a single generation.

        A turn is one iteration where the LLM requests tool calls. With parallel
        tool calling, a single turn may execute multiple tools simultaneously.
        """
        return self._tool_config.max_tool_call_turns

    @property
    def allow_tools(self) -> list[str] | None:
        """Optional allowlist of permitted tool names.

        Note: an empty list is treated the same as None (no filtering) by this facade.
        """
        return self._tool_config.allow_tools

    @property
    def timeout_sec(self) -> float | None:
        """Timeout in seconds for MCP tool calls."""
        return self._tool_config.timeout_sec

    @staticmethod
    def _get_field(obj: Any, key: str) -> Any:
        """Read ``key`` from ``obj`` as an attribute, falling back to a dict lookup.

        Returns None when ``obj`` is None or exposes no such field.
        """
        value = getattr(obj, key, None)
        if value is None and isinstance(obj, dict):
            value = obj.get(key)
        return value

    @staticmethod
    def tool_call_count(completion_response: Any) -> int:
        """Count the number of tool calls in a completion response.

        Args:
            completion_response: The LLM completion response (litellm.ModelResponse).
                The message may be an object or a plain dict.

        Returns:
            Number of tool calls in the response (0 if none).
        """
        message = completion_response.choices[0].message
        tool_calls = MCPFacade._get_field(message, "tool_calls")
        if tool_calls is None:
            return 0
        return len(tool_calls)

    @staticmethod
    def has_tool_calls(completion_response: Any) -> bool:
        """Returns True if tool calls are present in the completion response."""
        return MCPFacade.tool_call_count(completion_response) > 0

    def _resolve_provider(self, provider: MCPProviderT) -> MCPProviderT:
        """Resolve secret references in an MCP provider's api_key.

        Creates a copy of the provider with the api_key resolved from any secret
        reference (e.g., "env:API_KEY") to its actual value.

        Args:
            provider: The MCP provider config.

        Returns:
            A copy of the provider with resolved api_key, or the original provider
            if no api_key is configured.
        """
        api_key_ref = getattr(provider, "api_key", None)
        if not api_key_ref:
            return provider
        resolved_key = self._secret_resolver.resolve(api_key_ref)
        return provider.model_copy(update={"api_key": resolved_key})

    def get_tool_schemas(self) -> list[dict[str, Any]]:
        """Get OpenAI-compatible tool schemas for this configuration.

        Fetches tools from all providers in the configuration and applies
        allow_tools filtering if specified (an empty allowlist means no filtering).
        Uses cached results from mcp_io.

        Returns:
            List of tool schemas in OpenAI function calling format.

        Raises:
            MCPConfigurationError: If allowed tools are not found on any provider.
            DuplicateToolNameError: If the same tool name appears in multiple providers.
        """
        all_tools: list[MCPToolDefinition] = []
        tool_to_providers: dict[str, list[str]] = {}

        for provider_name in self._tool_config.providers:
            provider = self._mcp_provider_registry.get_provider(provider_name)
            resolved_provider = self._resolve_provider(provider)
            tools = mcp_io.list_tools(
                resolved_provider, timeout_sec=self._tool_config.timeout_sec
            )  # Cached in io module
            for tool in tools:
                tool_to_providers.setdefault(tool.name, []).append(provider_name)
            all_tools.extend(tools)

        # Check for duplicate tool names across providers
        duplicates = {name: providers for name, providers in tool_to_providers.items() if len(providers) > 1}
        if duplicates:
            dup_details = [f"'{name}' (in: {', '.join(providers)})" for name, providers in sorted(duplicates.items())]
            raise DuplicateToolNameError(
                f"Duplicate tool names found across MCP providers: {'; '.join(dup_details)}. "
                "Each tool name must be unique across all providers in a ToolConfig."
            )

        all_available_names = set(tool_to_providers.keys())
        allowed_names = set(self._tool_config.allow_tools) if self._tool_config.allow_tools else None
        if allowed_names is not None:
            missing = allowed_names.difference(all_available_names)
            if missing:
                provider_list = ", ".join(repr(p) for p in self._tool_config.providers)
                raise MCPConfigurationError(
                    f"Tool(s) {sorted(missing)!r} not found on any of the MCP providers: {provider_list}."
                )
            all_tools = [tool for tool in all_tools if tool.name in allowed_names]

        return [tool.to_openai_tool_schema() for tool in all_tools]

    def process_completion_response(
        self,
        completion_response: Any,
    ) -> list[ChatMessage]:
        """Process an LLM completion response and execute any tool calls.

        This is the primary method for handling tool calls from an LLM response.
        It extracts the response content, reasoning content, and all tool calls
        from the completion response, executes each tool call (including parallel
        tool calls), and returns the messages for continuing the conversation.

        Args:
            completion_response: The completion response object from the LLM,
                typically from `router.completion()`. Expected to have a
                `choices[0].message` structure with optional `content`,
                `reasoning_content`, and `tool_calls` fields (as attributes or
                dict keys).

        Returns:
            A list of ChatMessages to append to the conversation history:
            - If tool calls were present: [assistant_message_with_tool_calls, *tool_response_messages]
            - If no tool calls: [assistant_message]

        Raises:
            MCPToolError: If a tool call is missing a name.
            MCPToolError: If tool call arguments cannot be parsed as a JSON object.
            MCPToolError: If tool call arguments are an unsupported type.
            MCPToolError: If a requested tool is not in the allowed tools list.
            MCPToolError: If tool execution fails or times out.
            MCPConfigurationError: If a requested tool is not found on any configured provider.
        """
        response_content, reasoning_content, tool_calls = self._parse_completion_response(completion_response)

        if not tool_calls:
            # No tool calls - just return the assistant message
            return [ChatMessage.as_assistant(content=response_content, reasoning_content=reasoning_content)]

        # Has tool calls - execute and return all messages
        assistant_message = self._build_assistant_tool_message(response_content, tool_calls, reasoning_content)
        tool_messages = self._execute_tool_calls_internal(tool_calls)
        return [assistant_message, *tool_messages]

    def refuse_completion_response(
        self,
        completion_response: Any,
        refusal_message: str | None = None,
    ) -> list[ChatMessage]:
        """Refuse tool calls without executing them.

        Used when the tool call turn budget is exhausted. Returns messages
        that include the assistant's tool call request but with refusal
        responses instead of actual tool results. This allows the model
        to gracefully degrade and provide a final response without tools.

        Args:
            completion_response: The LLM completion response containing tool calls.
            refusal_message: Optional custom refusal message. Defaults to a
                standard message about tool budget exhaustion.

        Returns:
            A list of ChatMessages to append to the conversation history:
            - If tool calls were present: [assistant_message_with_tool_calls, *refusal_messages]
            - If no tool calls: [assistant_message]
        """
        response_content, reasoning_content, tool_calls = self._parse_completion_response(completion_response)

        if not tool_calls:
            # No tool calls to refuse - just return assistant message
            return [ChatMessage.as_assistant(content=response_content, reasoning_content=reasoning_content)]

        # The assistant message is built exactly as in the normal path so the
        # conversation history stays well-formed (every tool call gets a reply).
        assistant_message = self._build_assistant_tool_message(response_content, tool_calls, reasoning_content)

        refusal = refusal_message or DEFAULT_TOOL_REFUSAL_MESSAGE
        tool_messages = [ChatMessage.as_tool(content=refusal, tool_call_id=tc["id"]) for tc in tool_calls]
        return [assistant_message, *tool_messages]

    def _parse_completion_response(
        self,
        completion_response: Any,
    ) -> tuple[str, str | None, list[dict[str, Any]]]:
        """Split a completion response into (content, reasoning_content, tool_calls).

        Content defaults to "" and reasoning content is None when absent or blank.
        When reasoning content is present, both texts are stripped because models
        often add extra newlines around them.

        Raises:
            MCPToolError: Propagated from tool call normalization.
        """
        message = completion_response.choices[0].message

        response_content = self._get_field(message, "content") or ""
        reasoning_content = self._get_field(message, "reasoning_content")
        if reasoning_content:
            response_content = response_content.strip()
            reasoning_content = reasoning_content.strip()

        tool_calls = self._extract_tool_calls(message)
        return response_content, reasoning_content or None, tool_calls

    def _extract_tool_calls(self, message: Any) -> list[dict[str, Any]]:
        """Extract and normalize tool calls from an LLM response message.

        Handles various LLM response formats (dict or object with attributes)
        and normalizes them into a consistent dictionary format. Supports
        parallel tool calling where the model returns multiple tool calls
        in a single response.

        Args:
            message: The LLM response message, either as a dictionary or an object
                with a 'tool_calls' attribute.

        Returns:
            A list of normalized tool call dictionaries. Each dictionary contains:
            - 'id': Unique identifier for the tool call (generated if not provided)
            - 'name': The name of the tool to call
            - 'arguments': Parsed arguments as a dictionary
            - 'arguments_json': Arguments serialized as a JSON string

            Returns an empty list if no tool calls are present in the message.

        Raises:
            MCPToolError: If a tool call is missing a name.
            MCPToolError: If tool call arguments cannot be parsed as a JSON object.
            MCPToolError: If tool call arguments are an unsupported type.
        """
        raw_tool_calls = self._get_field(message, "tool_calls")
        if not raw_tool_calls:
            return []
        return [self._normalize_tool_call(raw_tool_call) for raw_tool_call in raw_tool_calls]

    def _normalize_tool_call(self, raw_tool_call: Any) -> dict[str, Any]:
        """Normalize a tool call from various LLM response formats.

        Handles both dictionary and object representations of tool calls,
        supporting the OpenAI format (with nested 'function' key) and
        flattened formats.

        Args:
            raw_tool_call: A tool call in any supported format.

        Returns:
            A normalized tool call dictionary with keys:
            - 'id': Tool call identifier (UUID generated if not provided)
            - 'name': The tool name
            - 'arguments': Parsed arguments dictionary
            - 'arguments_json': JSON string of arguments

        Raises:
            MCPToolError: If the tool call is missing a name, has arguments that
                are not a JSON object, or has arguments that cannot be serialized.
        """
        if isinstance(raw_tool_call, dict):
            tool_call_id = raw_tool_call.get("id")
            function = raw_tool_call.get("function") or {}
            name = function.get("name") or raw_tool_call.get("name")
            arguments = function.get("arguments") or raw_tool_call.get("arguments")
        else:
            tool_call_id = getattr(raw_tool_call, "id", None)
            function = getattr(raw_tool_call, "function", None)
            name = getattr(function, "name", None) if function is not None else getattr(raw_tool_call, "name", None)
            arguments = (
                getattr(function, "arguments", None)
                if function is not None
                else getattr(raw_tool_call, "arguments", None)
            )

        if not name:
            raise MCPToolError("MCP tool call is missing a tool name.")

        arguments_payload = self._parse_tool_arguments(name, arguments)

        # Normalize arguments_json to ensure valid, canonical JSON.
        # json.dumps raises TypeError for unsupported types and ValueError for
        # circular references; both are reported as tool errors.
        try:
            arguments_json = json.dumps(arguments_payload)
        except (TypeError, ValueError) as exc:
            raise MCPToolError(f"Non-serializable tool arguments for '{name}': {exc}") from exc

        return {
            "id": tool_call_id or uuid.uuid4().hex,
            "name": name,
            "arguments": arguments_payload,
            "arguments_json": arguments_json,
        }

    @staticmethod
    def _parse_tool_arguments(name: str, arguments: Any) -> dict[str, Any]:
        """Convert raw tool call arguments into a dictionary.

        None, blank strings and the JSON literal ``null`` all mean "no arguments".
        Dicts are passed through unchanged, and strings must decode to a JSON object.

        Raises:
            MCPToolError: If the string is not valid JSON, decodes to a non-object,
                or the arguments are of an unsupported type.
        """
        if arguments is None or (isinstance(arguments, str) and not arguments.strip()):
            return {}
        if isinstance(arguments, dict):
            return arguments
        if not isinstance(arguments, str):
            raise MCPToolError(f"Unsupported tool arguments type for '{name}': {type(arguments)!r}")

        try:
            parsed = json.loads(arguments)
        except json.JSONDecodeError as exc:
            raise MCPToolError(f"Invalid tool arguments for '{name}': {arguments}") from exc

        if parsed is None:
            return {}
        if not isinstance(parsed, dict):
            # Tools take named arguments; arrays/scalars cannot be mapped onto them.
            raise MCPToolError(
                f"Invalid tool arguments for '{name}': expected a JSON object, "
                f"got {type(parsed).__name__}: {arguments}"
            )
        return parsed

    def _build_assistant_tool_message(
        self,
        response: str | None,
        tool_calls: list[dict[str, Any]],
        reasoning_content: str | None = None,
    ) -> ChatMessage:
        """Build the assistant message containing tool call requests.

        Constructs a message in the format expected by the LLM conversation
        history, representing the assistant's request to call tools.

        Args:
            response: The assistant's text response content. May be empty if
                the assistant only requested tool calls without additional text.
            tool_calls: List of normalized tool call dictionaries.
            reasoning_content: Optional reasoning content from the assistant's
                response. If provided, will be included under the 'reasoning_content' key.

        Returns:
            A ChatMessage representing the assistant message with tool call requests.
        """
        tool_calls_payload = [
            {
                "id": tool_call["id"],
                "type": "function",
                "function": {"name": tool_call["name"], "arguments": tool_call["arguments_json"]},
            }
            for tool_call in tool_calls
        ]
        return ChatMessage.as_assistant(
            content=response or "",
            reasoning_content=reasoning_content or None,
            tool_calls=tool_calls_payload,
        )

    def _execute_tool_calls_internal(
        self,
        tool_calls: list[dict[str, Any]],
    ) -> list[ChatMessage]:
        """Execute tool calls and return tool response messages.

        Validates all tool calls first (so nothing runs if any call is rejected),
        then hands the batch to ``mcp_io.call_tools``. This means the model can
        return multiple tool calls in a single response and they are sent as
        one batch.

        Args:
            tool_calls: List of normalized tool call dictionaries to execute.

        Returns:
            A list of tool response messages, one per tool call, in call order.

        Raises:
            MCPToolError: If a tool is not in the allowed tools list, if the
                MCP provider returns an error, or if the number of results does
                not match the number of calls.
            MCPConfigurationError: If a tool is not found on any configured provider.
        """
        allowed_tools = set(self._tool_config.allow_tools) if self._tool_config.allow_tools else None

        # Memoize provider lookup so repeated calls to the same tool resolve once.
        providers_by_tool: dict[str, MCPProviderT] = {}
        calls_to_execute: list[tuple[MCPProviderT, str, dict[str, Any], str]] = []
        for tool_call in tool_calls:
            tool_name = tool_call["name"]
            if allowed_tools is not None and tool_name not in allowed_tools:
                providers_str = ", ".join(repr(p) for p in self._tool_config.providers)
                raise MCPToolError(f"Tool {tool_name!r} is not permitted for providers: {providers_str}.")

            if tool_name not in providers_by_tool:
                providers_by_tool[tool_name] = self._find_resolved_provider_for_tool(tool_name)
            calls_to_execute.append((providers_by_tool[tool_name], tool_name, tool_call["arguments"], tool_call["id"]))

        results = list(
            mcp_io.call_tools(
                [(provider, name, args) for provider, name, args, _ in calls_to_execute],
                timeout_sec=self._tool_config.timeout_sec,
            )
        )

        # Every tool call must receive exactly one response message; a silent
        # zip truncation would leave dangling tool calls in the conversation.
        if len(results) != len(calls_to_execute):
            raise MCPToolError(
                f"Expected {len(calls_to_execute)} MCP tool result(s) but received {len(results)}."
            )

        return [
            ChatMessage.as_tool(content=result.content, tool_call_id=call_id)
            for result, (_, _, _, call_id) in zip(results, calls_to_execute)
        ]

    def _find_resolved_provider_for_tool(self, tool_name: str) -> MCPProviderT:
        """Find the provider that has the given tool and return it with resolved api_key.

        Providers are searched in configuration order and the first match wins.

        Args:
            tool_name: The name of the tool to find.

        Returns:
            The provider object (with resolved api_key) that has the tool.

        Raises:
            MCPConfigurationError: If no provider has the tool.
        """
        for provider_name in self._tool_config.providers:
            provider = self._mcp_provider_registry.get_provider(provider_name)
            resolved_provider = self._resolve_provider(provider)
            tools = mcp_io.list_tools(
                resolved_provider, timeout_sec=self._tool_config.timeout_sec
            )  # Cached in io module
            if any(tool.name == tool_name for tool in tools):
                return resolved_provider

        raise MCPConfigurationError(f"Tool {tool_name!r} not found on any configured provider.")
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from data_designer.config.mcp import ToolConfig
|
|
7
|
+
from data_designer.engine.mcp.facade import MCPFacade
|
|
8
|
+
from data_designer.engine.mcp.registry import MCPRegistry
|
|
9
|
+
from data_designer.engine.model_provider import MCPProviderRegistry
|
|
10
|
+
from data_designer.engine.secret_resolver import SecretResolver
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def create_mcp_registry(
    *,
    tool_configs: list[ToolConfig] | None = None,
    secret_resolver: SecretResolver,
    mcp_provider_registry: MCPProviderRegistry,
) -> MCPRegistry:
    """Build an MCPRegistry wired with a factory that produces MCPFacade objects.

    The registry receives a small callable which it can use to create
    MCPFacade instances on demand; this follows the same wiring pattern as
    create_model_registry.

    Args:
        tool_configs: Tool configurations to register, if any.
        secret_resolver: Resolver used for secret references in provider configs.
        mcp_provider_registry: Registry holding the MCP provider configurations.

    Returns:
        The configured MCPRegistry instance.
    """

    def build_facade(
        tool_config: ToolConfig,
        secret_resolver: SecretResolver,
        provider_registry: MCPProviderRegistry,
    ) -> MCPFacade:
        # Adapt the registry's factory parameter names to MCPFacade's keywords.
        facade = MCPFacade(
            tool_config=tool_config,
            secret_resolver=secret_resolver,
            mcp_provider_registry=provider_registry,
        )
        return facade

    registry = MCPRegistry(
        tool_configs=tool_configs,
        mcp_facade_factory=build_facade,
        mcp_provider_registry=mcp_provider_registry,
        secret_resolver=secret_resolver,
    )
    return registry
|