data-designer-engine 0.4.0rc2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/engine/analysis/column_profilers/base.py +1 -2
- data_designer/engine/analysis/dataset_profiler.py +1 -2
- data_designer/engine/column_generators/generators/base.py +1 -6
- data_designer/engine/column_generators/generators/custom.py +195 -0
- data_designer/engine/column_generators/generators/llm_completion.py +34 -4
- data_designer/engine/column_generators/registry.py +3 -0
- data_designer/engine/column_generators/utils/errors.py +3 -0
- data_designer/engine/column_generators/utils/prompt_renderer.py +1 -1
- data_designer/engine/dataset_builders/column_wise_builder.py +47 -10
- data_designer/engine/dataset_builders/multi_column_configs.py +2 -2
- data_designer/engine/dataset_builders/utils/progress_tracker.py +122 -0
- data_designer/engine/mcp/__init__.py +30 -0
- data_designer/engine/mcp/errors.py +22 -0
- data_designer/engine/mcp/facade.py +485 -0
- data_designer/engine/mcp/factory.py +46 -0
- data_designer/engine/mcp/io.py +487 -0
- data_designer/engine/mcp/registry.py +203 -0
- data_designer/engine/model_provider.py +68 -0
- data_designer/engine/models/facade.py +92 -30
- data_designer/engine/models/factory.py +18 -1
- data_designer/engine/models/utils.py +111 -21
- data_designer/engine/resources/resource_provider.py +72 -3
- data_designer/engine/testing/fixtures.py +233 -0
- data_designer/engine/testing/stubs.py +1 -2
- {data_designer_engine-0.4.0rc2.dist-info → data_designer_engine-0.5.0rc1.dist-info}/METADATA +3 -2
- {data_designer_engine-0.4.0rc2.dist-info → data_designer_engine-0.5.0rc1.dist-info}/RECORD +27 -19
- data_designer/engine/_version.py +0 -34
- {data_designer_engine-0.4.0rc2.dist-info → data_designer_engine-0.5.0rc1.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,487 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
"""Low-level MCP I/O operations with caching and session pooling.
|
|
5
|
+
|
|
6
|
+
This module provides stateless functions for MCP communication using an actor-style
|
|
7
|
+
service that owns all async state within a single background event loop. Public APIs
|
|
8
|
+
are synchronous wrappers that submit coroutines to the loop and wait for results.
|
|
9
|
+
|
|
10
|
+
Architecture:
|
|
11
|
+
All MCP I/O is funneled through a single dedicated asyncio event loop running
|
|
12
|
+
in a background daemon thread. This avoids the complexity of managing multiple
|
|
13
|
+
event loops and allows sessions to be reused across calls from any thread.
|
|
14
|
+
|
|
15
|
+
Worker Thread 1 ──┐
|
|
16
|
+
Worker Thread 2 ──┼──► MCP Event Loop Thread ──► MCP Servers
|
|
17
|
+
Worker Thread N ──┘ (all sessions live here)
|
|
18
|
+
|
|
19
|
+
Request Coalescing:
|
|
20
|
+
When multiple threads request tools from the same provider simultaneously,
|
|
21
|
+
only one request is made to the MCP server. Other callers wait for the
|
|
22
|
+
in-flight request to complete and share the result. This prevents N
|
|
23
|
+
concurrent workers from making N separate ListToolsRequest calls.
|
|
24
|
+
|
|
25
|
+
The caller (MCPFacade) is responsible for resolving any secret references in
|
|
26
|
+
provider api_key fields before passing providers to these functions.
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
from __future__ import annotations
|
|
30
|
+
|
|
31
|
+
import asyncio
|
|
32
|
+
import atexit
|
|
33
|
+
import json
|
|
34
|
+
import logging
|
|
35
|
+
import threading
|
|
36
|
+
from collections.abc import Coroutine, Iterable
|
|
37
|
+
from typing import Any
|
|
38
|
+
|
|
39
|
+
from mcp import ClientSession, StdioServerParameters
|
|
40
|
+
from mcp.client.sse import sse_client
|
|
41
|
+
from mcp.client.stdio import stdio_client
|
|
42
|
+
|
|
43
|
+
from data_designer.config.mcp import LocalStdioMCPProvider, MCPProviderT
|
|
44
|
+
from data_designer.engine.mcp.errors import MCPToolError
|
|
45
|
+
from data_designer.engine.mcp.registry import MCPToolDefinition, MCPToolResult
|
|
46
|
+
|
|
47
|
+
logger = logging.getLogger(__name__)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _provider_cache_key(provider: MCPProviderT) -> str:
    """Derive a deterministic string key identifying a provider configuration.

    The provider is dumped in JSON mode and serialized compactly with sorted
    keys, so two providers with equal field values yield the same key.
    """
    payload = provider.model_dump(mode="json")
    return json.dumps(payload, ensure_ascii=False, separators=(",", ":"), sort_keys=True)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class MCPIOService:
    """Actor-style MCP I/O service owning all async state.

    Every coroutine runs on a single background event loop that is started
    lazily in a daemon thread. The public methods are synchronous wrappers:
    they submit a coroutine to that loop and block until it completes or the
    given timeout elapses.
    """

    def __init__(self) -> None:
        """Create an idle service; the event loop thread is started on first use."""
        self._loop: asyncio.AbstractEventLoop | None = None
        self._thread: threading.Thread | None = None
        self._loop_lock = threading.Lock()

        # Pooled sessions, the transport context managers backing them, and
        # in-flight session creations, all keyed by the provider cache key.
        self._sessions: dict[str, ClientSession] = {}
        self._session_contexts: dict[str, Any] = {}
        self._session_inflight: dict[str, asyncio.Task[ClientSession]] = {}

        # list_tools results, a per-key invalidation counter, and in-flight fetches.
        self._tools_cache: dict[str, tuple[MCPToolDefinition, ...]] = {}
        self._tools_cache_epoch: dict[str, int] = {}
        self._inflight_tools: dict[str, asyncio.Task[tuple[MCPToolDefinition, ...]]] = {}

    def list_tools(self, provider: MCPProviderT, timeout_sec: float | None = None) -> tuple[MCPToolDefinition, ...]:
        """List tools from an MCP provider (cached with request coalescing).

        Args:
            provider: Provider whose tools should be listed.
            timeout_sec: Maximum seconds to wait; ``None`` waits indefinitely.

        Raises:
            MCPToolError: If the request does not finish within ``timeout_sec``.
        """
        try:
            return self._run_on_loop(self._list_tools_async(provider), timeout_sec)
        except TimeoutError as exc:
            timeout_label = f"{timeout_sec:.1f}" if timeout_sec is not None else "unknown"
            raise MCPToolError(f"Timed out after {timeout_label}s while listing tools on {provider.name!r}.") from exc

    def call_tools(
        self,
        calls: list[tuple[MCPProviderT, str, dict[str, Any]]],
        *,
        timeout_sec: float | None = None,
    ) -> list[MCPToolResult]:
        """Call multiple tools in parallel.

        Args:
            calls: ``(provider, tool_name, arguments)`` triples to execute.
            timeout_sec: Maximum seconds to wait for the whole batch.

        Returns:
            One result per call, in the same order as ``calls``.

        Raises:
            MCPToolError: If the batch does not finish within ``timeout_sec``.
        """
        if not calls:
            return []
        try:
            return self._run_on_loop(self._call_tools_async(calls), timeout_sec)
        except TimeoutError as exc:
            timeout_label = f"{timeout_sec:.1f}" if timeout_sec is not None else "unknown"
            raise MCPToolError(f"Timed out after {timeout_label}s while calling tools in parallel.") from exc

    def clear_provider_caches(self, providers: list[MCPProviderT]) -> int:
        """Clear caches and session pool entries for specific providers.

        Returns:
            The number of pooled sessions removed (0 if clearing on the loop failed).
        """
        if not providers:
            return 0
        if self._loop is not None and self._loop.is_running():
            try:
                return self._run_on_loop(self._clear_provider_caches_async(providers), timeout_sec=5)
            except Exception:
                logger.debug("Failed to clear provider caches on MCP IO service.", exc_info=True)
                return 0
        return self._clear_provider_caches_sync(providers)

    def clear_tools_cache(self) -> None:
        """Clear the list_tools cache (best effort)."""
        if self._loop is not None and self._loop.is_running():
            try:
                self._run_on_loop(self._clear_tools_cache_async(), timeout_sec=5)
                return
            except Exception:
                logger.debug("Failed to clear tools cache on MCP IO service.", exc_info=True)
                return
        self._clear_tools_cache_sync()

    def get_cache_info(self) -> dict[str, Any]:
        """Get cache statistics for list_tools.

        Returns:
            A dict with ``currsize`` (entry count) and ``providers`` (cache keys).
        """
        if self._loop is not None and self._loop.is_running():
            try:
                return self._run_on_loop(self._get_cache_info_async(), timeout_sec=5)
            except Exception:
                logger.debug("Failed to read tools cache info on MCP IO service.", exc_info=True)
        # No running loop (or the loop call failed): read the state directly.
        return {"currsize": len(self._tools_cache), "providers": list(self._tools_cache.keys())}

    def clear_session_pool(self) -> None:
        """Clear all pooled MCP sessions (best effort)."""
        if self._loop is not None and self._loop.is_running():
            try:
                self._run_on_loop(self._close_all_sessions_async(), timeout_sec=5)
                return
            except Exception:
                logger.debug("Failed to clear session pool on MCP IO service.", exc_info=True)
                # Fall through to sync cleanup
        self._clear_session_pool_sync()

    def get_session_pool_info(self) -> dict[str, Any]:
        """Get information about the session pool.

        Returns:
            A dict with ``active_sessions`` (count) and ``provider_keys`` (pool keys).
        """
        if self._loop is not None and self._loop.is_running():
            try:
                return self._run_on_loop(self._get_session_pool_info_async(), timeout_sec=5)
            except Exception:
                logger.debug("Failed to read session pool info on MCP IO service.", exc_info=True)
        return {"active_sessions": len(self._sessions), "provider_keys": list(self._sessions.keys())}

    def shutdown(self) -> None:
        """Shutdown the MCP event loop and close all sessions.

        Sessions are closed on the loop (waiting up to 5 seconds), the loop is
        stopped, the thread is joined (up to 5 seconds), and the loop is closed
        once its thread has exited. All cached state is reset in every case.
        """
        loop = self._loop
        thread = self._thread
        if loop is None:
            self._reset_state()
            return
        try:
            # Only submit work when a thread exists to run it; otherwise the
            # coroutine would never execute and we would block for the full timeout.
            if thread is not None and thread.is_alive():
                future = asyncio.run_coroutine_threadsafe(self._close_all_sessions_async(), loop)
                try:
                    future.result(timeout=5)
                except Exception:
                    logger.debug("Failed to close MCP sessions during shutdown.", exc_info=True)
                loop.call_soon_threadsafe(loop.stop)
                thread.join(timeout=5)
            # Release the loop's resources once nothing is running it anymore.
            if (thread is None or not thread.is_alive()) and not loop.is_closed():
                loop.close()
        finally:
            self._loop = None
            self._thread = None
            self._reset_state()

    def _ensure_loop(self) -> asyncio.AbstractEventLoop:
        """Return the background event loop, starting its thread if needed."""
        with self._loop_lock:
            thread = self._thread
            # Gate on thread liveness rather than loop.is_running(): the loop only
            # reports running once run_forever() has begun, so a second caller
            # arriving right after Thread.start() would otherwise spawn another loop.
            if self._loop is None or thread is None or not thread.is_alive():
                loop = asyncio.new_event_loop()
                self._loop = loop
                self._thread = threading.Thread(
                    target=self._run_loop,
                    args=(loop,),
                    daemon=True,
                    name="MCP-EventLoop",
                )
                self._thread.start()
                logger.debug("Started MCP background event loop")
            # Capture local reference to avoid race with concurrent shutdown()
            loop = self._loop
        return loop

    @staticmethod
    def _run_loop(loop: asyncio.AbstractEventLoop) -> None:
        """Thread target: bind ``loop`` to this thread and run it until stopped."""
        asyncio.set_event_loop(loop)
        loop.run_forever()

    def _run_on_loop(self, coro: Coroutine[Any, Any, Any], timeout_sec: float | None) -> Any:
        """Run ``coro`` on the background loop and block for its result.

        Raises:
            TimeoutError: If the result is not available within ``timeout_sec``.
        """
        from concurrent.futures import TimeoutError as FutureTimeoutError

        loop = self._ensure_loop()
        future = asyncio.run_coroutine_threadsafe(coro, loop)
        try:
            return future.result(timeout=timeout_sec)
        except FutureTimeoutError as exc:
            # Before Python 3.11 concurrent.futures.TimeoutError is not the builtin
            # TimeoutError; normalize so callers' `except TimeoutError` always matches.
            if isinstance(exc, TimeoutError):
                raise
            raise TimeoutError(str(exc)) from exc

    @staticmethod
    async def _close_session_resources(session: ClientSession | None, ctx: Any | None) -> None:
        """Best-effort exit of a session and then its transport context."""
        for resource in (session, ctx):
            if resource is None:
                continue
            try:
                await resource.__aexit__(None, None, None)
            except Exception:
                logger.debug("Error while closing MCP session resource.", exc_info=True)

    async def _get_or_create_session(self, provider: MCPProviderT) -> ClientSession:
        """Return the pooled session for ``provider``, creating it if absent.

        Concurrent requests for the same provider share one creation task.
        """
        key = _provider_cache_key(provider)
        session = self._sessions.get(key)
        if session is not None:
            return session

        inflight = self._session_inflight.get(key)
        if inflight is not None:
            return await inflight

        async def create_session() -> ClientSession:
            ctx: Any | None = None
            new_session: ClientSession | None = None
            try:
                if isinstance(provider, LocalStdioMCPProvider):
                    params = StdioServerParameters(
                        command=provider.command,
                        args=provider.args,
                        env=provider.env,
                    )
                    ctx = stdio_client(params)
                else:
                    headers = _build_auth_headers(provider.api_key)
                    ctx = sse_client(provider.endpoint, headers=headers)

                read, write = await ctx.__aenter__()
                new_session = ClientSession(read, write)
                await new_session.__aenter__()
                await new_session.initialize()

                # Keep the transport context alongside the session so both can
                # be exited when the pool entry is cleared.
                self._sessions[key] = new_session
                self._session_contexts[key] = ctx
                logger.debug("Created pooled MCP session for provider %r", provider.name)
                return new_session
            except Exception:
                # Undo whatever was entered before re-raising the original error.
                await self._close_session_resources(new_session, ctx)
                raise

        task = asyncio.create_task(create_session())
        self._session_inflight[key] = task
        try:
            return await task
        finally:
            self._session_inflight.pop(key, None)

    async def _list_tools_async(self, provider: MCPProviderT) -> tuple[MCPToolDefinition, ...]:
        """Return cached tools for ``provider`` or fetch them, coalescing concurrent fetches."""
        key = _provider_cache_key(provider)
        cached = self._tools_cache.get(key)
        if cached is not None:
            return cached

        inflight = self._inflight_tools.get(key)
        if inflight is not None:
            return await inflight

        # Remember the invalidation counter so a result fetched before an
        # invalidation is not written back into the cache afterwards.
        epoch = self._tools_cache_epoch.get(key, 0)

        async def fetch_tools() -> tuple[MCPToolDefinition, ...]:
            session = await self._get_or_create_session(provider)
            result = await session.list_tools()
            raw_tools = getattr(result, "tools", result)
            if not isinstance(raw_tools, list):
                raise MCPToolError("Unexpected response from MCP provider when listing tools.")
            tools = tuple(_coerce_tool_definition(tool, MCPToolDefinition) for tool in raw_tools)
            if self._tools_cache_epoch.get(key, 0) == epoch:
                self._tools_cache[key] = tools
                logger.debug("Cached tools for provider %r (%d tools)", provider.name, len(tools))
            return tools

        task = asyncio.create_task(fetch_tools())
        self._inflight_tools[key] = task
        try:
            return await task
        finally:
            self._inflight_tools.pop(key, None)

    async def _call_tool_async(
        self,
        provider: MCPProviderT,
        tool_name: str,
        arguments: dict[str, Any],
    ) -> MCPToolResult:
        """Call one tool on ``provider`` and wrap the response in an MCPToolResult."""
        session = await self._get_or_create_session(provider)
        result = await session.call_tool(tool_name, arguments)

        content = _serialize_tool_result_content(result)
        # Accept either spelling of the error flag on the response object.
        is_error = getattr(result, "isError", None)
        if is_error is None:
            is_error = getattr(result, "is_error", False)

        return MCPToolResult(content=content, is_error=bool(is_error))

    async def _call_tools_async(
        self,
        calls: list[tuple[MCPProviderT, str, dict[str, Any]]],
    ) -> list[MCPToolResult]:
        """Run all calls concurrently; results keep the order of ``calls``."""
        return await asyncio.gather(*[self._call_tool_async(p, n, a) for p, n, a in calls])

    async def _clear_provider_caches_async(self, providers: list[MCPProviderT]) -> int:
        """Invalidate tools caches and close pooled sessions for ``providers`` on the loop."""
        keys = [_provider_cache_key(provider) for provider in providers]
        self._invalidate_tools_cache(keys)

        cleared_count = 0
        for key in keys:
            session = self._sessions.pop(key, None)
            ctx = self._session_contexts.pop(key, None)
            if session is not None:
                cleared_count += 1
            await self._close_session_resources(session, ctx)

        if cleared_count > 0:
            logger.debug("Cleared %d provider cache entries", cleared_count)
        return cleared_count

    def _clear_provider_caches_sync(self, providers: list[MCPProviderT]) -> int:
        """Drop cache and pool entries for ``providers`` without closing sessions.

        Used when no loop is running, so the sessions cannot be exited.
        """
        keys = [_provider_cache_key(provider) for provider in providers]
        self._invalidate_tools_cache(keys)

        cleared_count = 0
        for key in keys:
            if key in self._sessions:
                del self._sessions[key]
                cleared_count += 1
            self._session_contexts.pop(key, None)

        if cleared_count > 0:
            logger.debug("Cleared %d provider cache entries", cleared_count)
        return cleared_count

    async def _clear_tools_cache_async(self) -> None:
        """Invalidate every known tools-cache key (loop-side variant)."""
        self._invalidate_tools_cache(self._all_tools_keys())

    def _clear_tools_cache_sync(self) -> None:
        """Invalidate every known tools-cache key (caller-thread variant)."""
        self._invalidate_tools_cache(self._all_tools_keys())

    async def _get_cache_info_async(self) -> dict[str, Any]:
        """Return tools-cache statistics, read on the loop."""
        return {"currsize": len(self._tools_cache), "providers": list(self._tools_cache.keys())}

    async def _close_all_sessions_async(self) -> None:
        """Close every pooled session and cancel in-flight session creations."""
        for key in list(self._sessions):
            session = self._sessions.pop(key, None)
            ctx = self._session_contexts.pop(key, None)
            await self._close_session_resources(session, ctx)

        for task in self._session_inflight.values():
            task.cancel()
        self._session_inflight.clear()

    def _clear_session_pool_sync(self) -> None:
        """Forget all pooled sessions without exiting them."""
        self._sessions.clear()
        self._session_contexts.clear()
        self._session_inflight.clear()

    async def _get_session_pool_info_async(self) -> dict[str, Any]:
        """Return session-pool statistics, read on the loop."""
        return {"active_sessions": len(self._sessions), "provider_keys": list(self._sessions.keys())}

    def _invalidate_tools_cache(self, keys: Iterable[str]) -> None:
        """Remove cached tools for ``keys`` and bump each key's epoch counter."""
        for key in keys:
            self._tools_cache.pop(key, None)
            self._tools_cache_epoch[key] = self._tools_cache_epoch.get(key, 0) + 1

    def _all_tools_keys(self) -> set[str]:
        """Return every key that is cached, being fetched, or has an epoch entry."""
        return set(self._tools_cache) | set(self._inflight_tools) | set(self._tools_cache_epoch)

    def _reset_state(self) -> None:
        """Clear every session, cache, epoch and in-flight mapping."""
        self._sessions.clear()
        self._session_contexts.clear()
        self._session_inflight.clear()
        self._tools_cache.clear()
        self._tools_cache_epoch.clear()
        self._inflight_tools.clear()
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
# Single service instance backing the module-level functions in this file.
_MCP_IO_SERVICE: MCPIOService = MCPIOService()
# Close pooled sessions and stop the background loop when the interpreter exits.
atexit.register(_MCP_IO_SERVICE.shutdown)
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
def list_tools(provider: MCPProviderT, timeout_sec: float | None = None) -> tuple[MCPToolDefinition, ...]:
    """Return the tools exposed by ``provider``.

    Forwards to the module-level service instance, which caches results and
    coalesces concurrent requests.
    """
    service = _MCP_IO_SERVICE
    return service.list_tools(provider, timeout_sec=timeout_sec)
|
|
400
|
+
|
|
401
|
+
|
|
402
|
+
def call_tools(
    calls: list[tuple[MCPProviderT, str, dict[str, Any]]],
    *,
    timeout_sec: float | None = None,
) -> list[MCPToolResult]:
    """Execute a batch of ``(provider, tool_name, arguments)`` calls in parallel.

    Forwards to the module-level service instance.
    """
    service = _MCP_IO_SERVICE
    return service.call_tools(calls, timeout_sec=timeout_sec)
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
def clear_provider_caches(providers: list[MCPProviderT]) -> int:
    """Drop every cache entry held for the given MCP providers.

    Forwards to the module-level service instance and returns its count.
    """
    service = _MCP_IO_SERVICE
    return service.clear_provider_caches(providers)
|
|
414
|
+
|
|
415
|
+
|
|
416
|
+
def clear_tools_cache() -> None:
    """Empty the list_tools cache of the module-level service instance."""
    service = _MCP_IO_SERVICE
    service.clear_tools_cache()
|
|
419
|
+
|
|
420
|
+
|
|
421
|
+
def get_cache_info() -> dict[str, Any]:
    """Report list_tools cache statistics from the module-level service instance."""
    service = _MCP_IO_SERVICE
    return service.get_cache_info()
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
def clear_session_pool() -> None:
    """Discard every pooled MCP session held by the module-level service instance."""
    service = _MCP_IO_SERVICE
    service.clear_session_pool()
|
|
429
|
+
|
|
430
|
+
|
|
431
|
+
def get_session_pool_info() -> dict[str, Any]:
    """Report session-pool statistics from the module-level service instance."""
    service = _MCP_IO_SERVICE
    return service.get_session_pool_info()
|
|
434
|
+
|
|
435
|
+
|
|
436
|
+
def _build_auth_headers(api_key: str | None) -> dict[str, Any] | None:
    """Return a bearer ``Authorization`` header for the SSE client.

    Yields ``None`` when ``api_key`` is missing or empty, so no header is sent.
    """
    return {"Authorization": f"Bearer {api_key}"} if api_key else None
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
def _coerce_tool_definition(tool: Any, tool_definition_cls: type[MCPToolDefinition]) -> MCPToolDefinition:
    """Build a tool definition from either a mapping or an attribute-bearing object.

    The input schema is read from ``inputSchema`` first and from
    ``input_schema`` when the former is missing or falsy.

    Raises:
        MCPToolError: If the tool has no (truthy) name.
    """
    if isinstance(tool, dict):

        def read(field: str) -> Any:
            return tool.get(field)

    else:

        def read(field: str) -> Any:
            return getattr(tool, field, None)

    tool_name = read("name")
    tool_description = read("description")
    schema = read("inputSchema") or read("input_schema")

    if not tool_name:
        raise MCPToolError("Encountered MCP tool without a name.")

    return tool_definition_cls(name=tool_name, description=tool_description, input_schema=schema)
|
|
458
|
+
|
|
459
|
+
|
|
460
|
+
def _serialize_tool_result_content(result: Any) -> str:
    """Flatten a tool result into a single string.

    Reads ``result.content`` when present (otherwise uses ``result`` itself).
    ``None`` becomes ``""``, strings pass through, dicts are JSON-encoded, and
    lists are rendered item by item and joined with newlines. Anything else is
    converted with ``str``.
    """
    payload = getattr(result, "content", result)
    if payload is None:
        return ""
    if isinstance(payload, str):
        return payload
    if isinstance(payload, dict):
        return json.dumps(payload)
    if not isinstance(payload, list):
        return str(payload)

    def render(entry: Any) -> str:
        # Plain strings are kept verbatim.
        if isinstance(entry, str):
            return entry
        # Text-typed dicts contribute their text; other dicts are JSON-encoded.
        if isinstance(entry, dict):
            if entry.get("type") == "text":
                return str(entry.get("text", ""))
            return json.dumps(entry)
        # Objects contribute their ``text`` attribute when it is set.
        text_attr = getattr(entry, "text", None)
        return str(entry) if text_attr is None else str(text_attr)

    return "\n".join(map(render, payload))
|
|
@@ -0,0 +1,203 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
from collections.abc import Callable
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from typing import TYPE_CHECKING, Any
|
|
10
|
+
|
|
11
|
+
from data_designer.config.mcp import ToolConfig
|
|
12
|
+
from data_designer.engine.model_provider import MCPProviderRegistry
|
|
13
|
+
from data_designer.engine.secret_resolver import SecretResolver
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from data_designer.engine.mcp.facade import MCPFacade
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass(frozen=True)
class MCPToolDefinition:
    """Immutable description of one MCP tool: its name, description and input schema."""

    name: str
    description: str | None
    input_schema: dict[str, Any] | None

    def to_openai_tool_schema(self) -> dict[str, Any]:
        """Render this tool in OpenAI function-calling format.

        Returns:
            A dict with ``type`` set to ``"function"`` and a nested ``function``
            entry holding the name, the description (empty string when unset)
            and the parameters (an empty object schema when no input schema
            is set).
        """
        parameters = self.input_schema if self.input_schema else {"type": "object", "properties": {}}
        function_spec = {
            "name": self.name,
            "description": self.description if self.description else "",
            "parameters": parameters,
        }
        return {"type": "function", "function": function_spec}
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclass(frozen=True)
class MCPToolResult:
    """Immutable outcome of a single MCP tool invocation."""

    # Tool output as a string.
    content: str
    # Error flag for the call; defaults to False.
    is_error: bool = False
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class MCPRegistry:
    """Registry for MCP tool configurations and facades.

    MCPRegistry manages ToolConfig instances by tool_alias and lazily creates
    MCPFacade instances when requested. This is a config-only registry - all
    actual MCP operations are delegated to the MCPFacade and io module.

    This mirrors the ModelRegistry pattern for consistency across the codebase.
    """

    def __init__(
        self,
        *,
        secret_resolver: SecretResolver,
        mcp_provider_registry: MCPProviderRegistry,
        mcp_facade_factory: Callable[[ToolConfig, SecretResolver, MCPProviderRegistry], MCPFacade],
        tool_configs: list[ToolConfig] | None = None,
    ) -> None:
        """Initialize the MCPRegistry.

        Args:
            secret_resolver: Resolver for secrets referenced in provider configs.
            mcp_provider_registry: Registry of MCP provider configurations.
            mcp_facade_factory: Factory for creating MCPFacade instances.
            tool_configs: Optional list of tool configurations to register.
        """
        self._secret_resolver = secret_resolver
        self._mcp_provider_registry = mcp_provider_registry
        self._mcp_facade_factory = mcp_facade_factory
        self._tool_configs: dict[str, ToolConfig] = {}
        # Facades are created on first get_mcp() call and cached per alias.
        self._facades: dict[str, MCPFacade] = {}
        # Aliases that have passed _validate_tool_alias since they were last registered.
        self._validated_tool_aliases: set[str] = set()

        self._set_tool_configs(tool_configs)

    @property
    def tool_configs(self) -> dict[str, ToolConfig]:
        """Get all registered tool configurations."""
        return self._tool_configs

    @property
    def facades(self) -> dict[str, MCPFacade]:
        """Get all instantiated facades."""
        return self._facades

    @property
    def mcp_provider_registry(self) -> MCPProviderRegistry:
        """Get the MCP provider registry."""
        return self._mcp_provider_registry

    def register_tool_configs(self, tool_configs: list[ToolConfig]) -> None:
        """Register tool configurations at runtime.

        Any facade already created for a re-registered alias is discarded, so
        the next ``get_mcp`` call builds one from the new configuration.

        Args:
            tool_configs: List of tool configurations to register. If a configuration
                with the same alias already exists, it will be overwritten.
        """
        incoming = list(tool_configs)
        self._set_tool_configs(list(self._tool_configs.values()) + incoming)
        # A cached facade or validation result was derived from the previous
        # config for that alias; keeping it would silently ignore the overwrite.
        for tool_config in incoming:
            self._facades.pop(tool_config.tool_alias, None)
            self._validated_tool_aliases.discard(tool_config.tool_alias)

    def get_mcp(self, *, tool_alias: str) -> MCPFacade:
        """Get or lazily create an MCPFacade for the given tool alias.

        Args:
            tool_alias: The alias of the tool configuration.

        Returns:
            An MCPFacade configured for the specified tool alias.

        Raises:
            ValueError: If no tool config with the given alias is found.
        """
        if tool_alias not in self._tool_configs:
            raise ValueError(f"No tool config with alias {tool_alias!r} found!")

        if tool_alias not in self._facades:
            self._facades[tool_alias] = self._create_facade(self._tool_configs[tool_alias])

        return self._facades[tool_alias]

    def get_tool_config(self, *, tool_alias: str) -> ToolConfig:
        """Get a tool configuration by alias.

        Args:
            tool_alias: The alias of the tool configuration.

        Returns:
            The tool configuration.

        Raises:
            ValueError: If no tool config with the given alias is found.
        """
        if tool_alias not in self._tool_configs:
            raise ValueError(f"No tool config with alias {tool_alias!r} found!")
        return self._tool_configs[tool_alias]

    def _set_tool_configs(self, tool_configs: list[ToolConfig] | None) -> None:
        """Set tool configurations from a list.

        Later entries win when several configs share the same alias.
        """
        tool_configs = tool_configs or []
        self._tool_configs = {tc.tool_alias: tc for tc in tool_configs}

    def _create_facade(self, tool_config: ToolConfig) -> MCPFacade:
        """Create an MCPFacade for a tool configuration."""
        return self._mcp_facade_factory(tool_config, self._secret_resolver, self._mcp_provider_registry)

    def _validate_tool_config_providers(self, tool_config: ToolConfig) -> None:
        """Ensure every provider named by ``tool_config`` is registered.

        Raises:
            ValueError: If the config references providers missing from the
                provider registry.
        """
        available_providers = {provider.name for provider in self._mcp_provider_registry.providers}
        missing_providers = [provider for provider in tool_config.providers if provider not in available_providers]
        if missing_providers:
            available_list = sorted(available_providers) if available_providers else ["(none configured)"]
            raise ValueError(
                f"ToolConfig '{tool_config.tool_alias}' references provider(s) {missing_providers!r} "
                f"which are not registered. Available providers: {available_list}"
            )

    def _validate_tool_alias(self, tool_alias: str) -> None:
        """Validate one alias: known config, registered providers, fetchable schemas.

        Raises:
            ValueError: If the alias is unknown or references unregistered providers.
        """
        if tool_alias not in self._tool_configs:
            raise ValueError(f"No tool config with alias {tool_alias!r} found!")
        tool_config = self._tool_configs[tool_alias]
        self._validate_tool_config_providers(tool_config)
        facade = self.get_mcp(tool_alias=tool_alias)
        # Fetching the schemas exercises the facade; its errors propagate to the caller.
        facade.get_tool_schemas()
        self._validated_tool_aliases.add(tool_alias)

    def run_health_check(self, tool_aliases: list[str]) -> None:
        """Validate each alias in ``tool_aliases``, logging progress.

        The first failure is logged and re-raised; nothing happens for an
        empty list.
        """
        if not tool_aliases:
            return
        logger.info("🧰 Running health checks for MCP tools...")
        for tool_alias in tool_aliases:
            logger.info(f" |-- 👀 Checking tools for tool alias {tool_alias!r}...")
            try:
                self._validate_tool_alias(tool_alias)
                logger.info(" |-- ✅ Passed!")
            except Exception:
                logger.error(" |-- ❌ Failed!")
                raise

    def validate_no_duplicate_tool_names(self) -> None:
        """Validate that no ToolConfig has duplicate tool names across its providers.

        This method eagerly fetches tool schemas for all registered ToolConfigs,
        which triggers duplicate tool name detection. This catches cases where
        multiple providers in the same ToolConfig expose a tool with the same name.

        Raises:
            DuplicateToolNameError: If any ToolConfig has duplicate tool names across providers.
        """
        for tool_alias in self._tool_configs:
            self._validate_tool_alias(tool_alias)
|