data-designer-engine 0.4.0rc2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. data_designer/engine/analysis/column_profilers/base.py +1 -2
  2. data_designer/engine/analysis/dataset_profiler.py +1 -2
  3. data_designer/engine/column_generators/generators/base.py +1 -6
  4. data_designer/engine/column_generators/generators/custom.py +195 -0
  5. data_designer/engine/column_generators/generators/llm_completion.py +34 -4
  6. data_designer/engine/column_generators/registry.py +3 -0
  7. data_designer/engine/column_generators/utils/errors.py +3 -0
  8. data_designer/engine/column_generators/utils/prompt_renderer.py +1 -1
  9. data_designer/engine/dataset_builders/column_wise_builder.py +47 -10
  10. data_designer/engine/dataset_builders/multi_column_configs.py +2 -2
  11. data_designer/engine/dataset_builders/utils/progress_tracker.py +122 -0
  12. data_designer/engine/mcp/__init__.py +30 -0
  13. data_designer/engine/mcp/errors.py +22 -0
  14. data_designer/engine/mcp/facade.py +485 -0
  15. data_designer/engine/mcp/factory.py +46 -0
  16. data_designer/engine/mcp/io.py +487 -0
  17. data_designer/engine/mcp/registry.py +203 -0
  18. data_designer/engine/model_provider.py +68 -0
  19. data_designer/engine/models/facade.py +92 -30
  20. data_designer/engine/models/factory.py +18 -1
  21. data_designer/engine/models/utils.py +111 -21
  22. data_designer/engine/resources/resource_provider.py +72 -3
  23. data_designer/engine/testing/fixtures.py +233 -0
  24. data_designer/engine/testing/stubs.py +1 -2
  25. {data_designer_engine-0.4.0rc2.dist-info → data_designer_engine-0.5.0rc1.dist-info}/METADATA +3 -2
  26. {data_designer_engine-0.4.0rc2.dist-info → data_designer_engine-0.5.0rc1.dist-info}/RECORD +27 -19
  27. data_designer/engine/_version.py +0 -34
  28. {data_designer_engine-0.4.0rc2.dist-info → data_designer_engine-0.5.0rc1.dist-info}/WHEEL +0 -0
@@ -0,0 +1,487 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Low-level MCP I/O operations with caching and session pooling.
5
+
6
+ This module provides stateless functions for MCP communication using an actor-style
7
+ service that owns all async state within a single background event loop. Public APIs
8
+ are synchronous wrappers that submit coroutines to the loop and wait for results.
9
+
10
+ Architecture:
11
+ All MCP I/O is funneled through a single dedicated asyncio event loop running
12
+ in a background daemon thread. This avoids the complexity of managing multiple
13
+ event loops and allows sessions to be reused across calls from any thread.
14
+
15
+ Worker Thread 1 ──┐
16
+ Worker Thread 2 ──┼──► MCP Event Loop Thread ──► MCP Servers
17
+ Worker Thread N ──┘ (all sessions live here)
18
+
19
+ Request Coalescing:
20
+ When multiple threads request tools from the same provider simultaneously,
21
+ only one request is made to the MCP server. Other callers wait for the
22
+ in-flight request to complete and share the result. This prevents N
23
+ concurrent workers from making N separate ListToolsRequest calls.
24
+
25
+ The caller (MCPFacade) is responsible for resolving any secret references in
26
+ provider api_key fields before passing providers to these functions.
27
+ """
28
+
29
+ from __future__ import annotations
30
+
31
+ import asyncio
32
+ import atexit
33
+ import json
34
+ import logging
35
+ import threading
36
+ from collections.abc import Coroutine, Iterable
37
+ from typing import Any
38
+
39
+ from mcp import ClientSession, StdioServerParameters
40
+ from mcp.client.sse import sse_client
41
+ from mcp.client.stdio import stdio_client
42
+
43
+ from data_designer.config.mcp import LocalStdioMCPProvider, MCPProviderT
44
+ from data_designer.engine.mcp.errors import MCPToolError
45
+ from data_designer.engine.mcp.registry import MCPToolDefinition, MCPToolResult
46
+
47
+ logger = logging.getLogger(__name__)
48
+
49
+
50
+ def _provider_cache_key(provider: MCPProviderT) -> str:
51
+ """Create a stable cache key for a provider."""
52
+ data = provider.model_dump(mode="json")
53
+ return json.dumps(data, sort_keys=True, separators=(",", ":"), ensure_ascii=False)
54
+
55
+
56
class MCPIOService:
    """Actor-style MCP I/O service owning all async state.

    Sessions, the tools cache and the in-flight task tables are manipulated by
    coroutines that run on one background event loop (started lazily in a
    daemon thread). The public methods are synchronous: they submit a coroutine
    to that loop and block for its result. When the loop is not running, the
    maintenance methods fall back to operating on the dictionaries directly.
    """

    # Upper bound (seconds) on how long _ensure_loop() waits for a freshly
    # started loop thread to actually begin running.
    _LOOP_START_TIMEOUT_SEC = 5.0

    def __init__(self) -> None:
        """Create an idle service; the event loop is started on first use."""
        self._loop: asyncio.AbstractEventLoop | None = None
        self._thread: threading.Thread | None = None
        self._loop_lock = threading.Lock()

        # Pooled sessions and the transport context managers backing them,
        # both keyed by _provider_cache_key(provider).
        self._sessions: dict[str, ClientSession] = {}
        self._session_contexts: dict[str, Any] = {}
        self._session_inflight: dict[str, asyncio.Task[ClientSession]] = {}

        # list_tools results per provider key. The epoch counter is bumped on
        # invalidation so a fetch that started earlier does not repopulate
        # the cache with a pre-invalidation result.
        self._tools_cache: dict[str, tuple[MCPToolDefinition, ...]] = {}
        self._tools_cache_epoch: dict[str, int] = {}
        self._inflight_tools: dict[str, asyncio.Task[tuple[MCPToolDefinition, ...]]] = {}

    # ------------------------------------------------------------------
    # Public synchronous API
    # ------------------------------------------------------------------

    def list_tools(self, provider: MCPProviderT, timeout_sec: float | None = None) -> tuple[MCPToolDefinition, ...]:
        """List tools from an MCP provider (cached with request coalescing).

        Args:
            provider: Provider to query.
            timeout_sec: Maximum seconds to wait; ``None`` waits indefinitely.

        Returns:
            The provider's tools as a tuple of tool definitions.

        Raises:
            MCPToolError: If the request does not finish within ``timeout_sec``.
        """
        try:
            return self._run_on_loop(self._list_tools_async(provider), timeout_sec)
        except TimeoutError as exc:
            timeout_label = f"{timeout_sec:.1f}" if timeout_sec is not None else "unknown"
            raise MCPToolError(f"Timed out after {timeout_label}s while listing tools on {provider.name!r}.") from exc

    def call_tools(
        self,
        calls: list[tuple[MCPProviderT, str, dict[str, Any]]],
        *,
        timeout_sec: float | None = None,
    ) -> list[MCPToolResult]:
        """Call multiple tools in parallel.

        Args:
            calls: ``(provider, tool_name, arguments)`` triples to execute.
            timeout_sec: Maximum seconds to wait for all calls together;
                ``None`` waits indefinitely.

        Returns:
            One result per call, in the same order as ``calls``. An empty
            ``calls`` list returns ``[]`` without touching the event loop.

        Raises:
            MCPToolError: If the batch does not finish within ``timeout_sec``.
        """
        if not calls:
            return []
        try:
            return self._run_on_loop(self._call_tools_async(calls), timeout_sec)
        except TimeoutError as exc:
            timeout_label = f"{timeout_sec:.1f}" if timeout_sec is not None else "unknown"
            raise MCPToolError(f"Timed out after {timeout_label}s while calling tools in parallel.") from exc

    def clear_provider_caches(self, providers: list[MCPProviderT]) -> int:
        """Clear caches and session pool entries for specific providers.

        Returns:
            The number of pooled sessions that were removed. ``0`` is returned
            for an empty ``providers`` list or when the on-loop clear fails.
        """
        if not providers:
            return 0
        if self._loop_is_running():
            try:
                return self._run_on_loop(self._clear_provider_caches_async(providers), timeout_sec=5)
            except Exception:
                logger.debug("Failed to clear provider caches on MCP IO service.", exc_info=True)
                return 0
        return self._clear_provider_caches_sync(providers)

    def clear_tools_cache(self) -> None:
        """Clear the list_tools cache (best effort)."""
        if self._loop_is_running():
            try:
                self._run_on_loop(self._clear_tools_cache_async(), timeout_sec=5)
            except Exception:
                logger.debug("Failed to clear tools cache on MCP IO service.", exc_info=True)
            return
        self._clear_tools_cache_sync()

    def get_cache_info(self) -> dict[str, Any]:
        """Get cache statistics for list_tools.

        Returns:
            A dict with ``currsize`` (number of cached providers) and
            ``providers`` (their cache keys).
        """
        if self._loop_is_running():
            try:
                return self._run_on_loop(self._get_cache_info_async(), timeout_sec=5)
            except Exception:
                logger.debug("Failed to read tools cache info on MCP IO service.", exc_info=True)
        # Loop not running or the on-loop read failed: read directly.
        return {"currsize": len(self._tools_cache), "providers": list(self._tools_cache.keys())}

    def clear_session_pool(self) -> None:
        """Clear all pooled MCP sessions (best effort)."""
        if self._loop_is_running():
            try:
                self._run_on_loop(self._close_all_sessions_async(), timeout_sec=5)
                return
            except Exception:
                logger.debug("Failed to clear session pool on MCP IO service.", exc_info=True)
                # Fall through to sync cleanup
        self._clear_session_pool_sync()

    def get_session_pool_info(self) -> dict[str, Any]:
        """Get information about the session pool.

        Returns:
            A dict with ``active_sessions`` (count) and ``provider_keys``.
        """
        if self._loop_is_running():
            try:
                return self._run_on_loop(self._get_session_pool_info_async(), timeout_sec=5)
            except Exception:
                logger.debug("Failed to read session pool info on MCP IO service.", exc_info=True)
        return {"active_sessions": len(self._sessions), "provider_keys": list(self._sessions.keys())}

    def shutdown(self) -> None:
        """Shutdown the MCP event loop and close all sessions.

        Sessions are closed on the loop (waiting at most 5 seconds), the loop
        is stopped, its thread is joined (at most 5 seconds) and, once it is no
        longer running, the loop is closed. All cached state is cleared
        regardless of whether those steps succeed.
        """
        # Work on local references so a concurrent reassignment of self._loop
        # cannot turn an attribute access into a None dereference.
        loop = self._loop
        thread = self._thread
        if loop is None:
            self._reset_state()
            return
        try:
            if loop.is_running():
                future = asyncio.run_coroutine_threadsafe(self._close_all_sessions_async(), loop)
                try:
                    future.result(timeout=5)
                except Exception:
                    logger.debug("Failed to close MCP sessions during shutdown.", exc_info=True)
                loop.call_soon_threadsafe(loop.stop)
            if thread is not None:
                thread.join(timeout=5)
            # Closing a running loop raises, so only close once the thread has
            # actually left run_forever().
            if not loop.is_running() and not loop.is_closed():
                loop.close()
        finally:
            self._loop = None
            self._thread = None
            self._reset_state()

    # ------------------------------------------------------------------
    # Event loop management
    # ------------------------------------------------------------------

    def _loop_is_running(self) -> bool:
        """Return True when a background loop exists and is currently running."""
        loop = self._loop
        return loop is not None and loop.is_running()

    def _ensure_loop(self) -> asyncio.AbstractEventLoop:
        """Return the running background loop, starting it if necessary.

        The lock is held until the new loop reports that it is running.
        Otherwise a second caller arriving right after ``Thread.start()``
        would still observe ``is_running() == False`` and start another loop.
        """
        with self._loop_lock:
            loop = self._loop
            if loop is None or not loop.is_running():
                loop = asyncio.new_event_loop()
                started = threading.Event()
                # Queued before the thread exists, so it runs as one of the
                # first callbacks once run_forever() is executing.
                loop.call_soon(started.set)
                thread = threading.Thread(
                    target=self._run_loop,
                    args=(loop,),
                    daemon=True,
                    name="MCP-EventLoop",
                )
                self._loop = loop
                self._thread = thread
                thread.start()
                if not started.wait(timeout=self._LOOP_START_TIMEOUT_SEC):
                    logger.warning("MCP background event loop did not report startup in time.")
                logger.debug("Started MCP background event loop")
            # Return the local reference to avoid a race with concurrent shutdown()
            return loop

    @staticmethod
    def _run_loop(loop: asyncio.AbstractEventLoop) -> None:
        """Thread target: install ``loop`` for this thread and run it until stopped."""
        asyncio.set_event_loop(loop)
        loop.run_forever()

    def _run_on_loop(self, coro: Coroutine[Any, Any, Any], timeout_sec: float | None) -> Any:
        """Run ``coro`` on the background loop and block for its result.

        Args:
            coro: Coroutine to schedule on the background loop.
            timeout_sec: Maximum seconds to wait; ``None`` waits indefinitely.

        Returns:
            Whatever the coroutine returns.

        Raises:
            TimeoutError: If no result is available within ``timeout_sec``.
                The coroutine is left running on the loop.
        """
        # Function-scope import: only needed to normalise the timeout type.
        from concurrent.futures import TimeoutError as FutureTimeoutError

        loop = self._ensure_loop()
        future = asyncio.run_coroutine_threadsafe(coro, loop)
        try:
            return future.result(timeout=timeout_sec)
        except FutureTimeoutError as exc:
            # Before Python 3.11 concurrent.futures.TimeoutError is a separate
            # class from the builtin, so callers' ``except TimeoutError`` would
            # not match it. Re-raise as the builtin in that case.
            if isinstance(exc, TimeoutError):
                raise
            raise TimeoutError(f"MCP operation did not complete within {timeout_sec}s.") from exc

    # ------------------------------------------------------------------
    # Coroutines executed on the background loop
    # ------------------------------------------------------------------

    @staticmethod
    async def _close_quietly(session: Any | None, ctx: Any | None) -> None:
        """Exit ``session`` and then ``ctx``, logging instead of raising on failure."""
        for label, resource in (("session", session), ("transport context", ctx)):
            if resource is None:
                continue
            try:
                await resource.__aexit__(None, None, None)
            except Exception:
                logger.debug("Ignoring error while closing MCP %s.", label, exc_info=True)

    async def _get_or_create_session(self, provider: MCPProviderT) -> ClientSession:
        """Return the pooled session for ``provider``, creating it if needed.

        Concurrent requests for the same provider await one shared creation
        task instead of opening several sessions.
        """
        key = _provider_cache_key(provider)
        session = self._sessions.get(key)
        if session is not None:
            return session

        inflight = self._session_inflight.get(key)
        if inflight is not None:
            return await inflight

        async def create_session() -> ClientSession:
            ctx: Any | None = None
            new_session: ClientSession | None = None
            try:
                if isinstance(provider, LocalStdioMCPProvider):
                    params = StdioServerParameters(
                        command=provider.command,
                        args=provider.args,
                        env=provider.env,
                    )
                    ctx = stdio_client(params)
                else:
                    headers = _build_auth_headers(provider.api_key)
                    ctx = sse_client(provider.endpoint, headers=headers)

                read, write = await ctx.__aenter__()
                new_session = ClientSession(read, write)
                await new_session.__aenter__()
                await new_session.initialize()

                self._sessions[key] = new_session
                self._session_contexts[key] = ctx
                logger.debug("Created pooled MCP session for provider %r", provider.name)
                return new_session
            except Exception:
                # Release whatever was opened before propagating the failure.
                await self._close_quietly(new_session, ctx)
                raise

        task = asyncio.create_task(create_session())
        self._session_inflight[key] = task
        try:
            return await task
        finally:
            self._session_inflight.pop(key, None)

    async def _list_tools_async(self, provider: MCPProviderT) -> tuple[MCPToolDefinition, ...]:
        """Return tools for ``provider`` from cache, an in-flight fetch, or a new fetch."""
        key = _provider_cache_key(provider)
        cached = self._tools_cache.get(key)
        if cached is not None:
            return cached

        inflight = self._inflight_tools.get(key)
        if inflight is not None:
            return await inflight

        # Remember the epoch at request time; an invalidation that happens
        # while the fetch is running bumps it and suppresses the cache write.
        epoch = self._tools_cache_epoch.get(key, 0)

        async def fetch_tools() -> tuple[MCPToolDefinition, ...]:
            session = await self._get_or_create_session(provider)
            result = await session.list_tools()
            raw_tools = getattr(result, "tools", result)
            if not isinstance(raw_tools, list):
                raise MCPToolError("Unexpected response from MCP provider when listing tools.")
            tools = tuple(_coerce_tool_definition(tool, MCPToolDefinition) for tool in raw_tools)
            if self._tools_cache_epoch.get(key, 0) == epoch:
                self._tools_cache[key] = tools
                logger.debug("Cached tools for provider %r (%d tools)", provider.name, len(tools))
            return tools

        task = asyncio.create_task(fetch_tools())
        self._inflight_tools[key] = task
        try:
            return await task
        finally:
            self._inflight_tools.pop(key, None)

    async def _call_tool_async(
        self,
        provider: MCPProviderT,
        tool_name: str,
        arguments: dict[str, Any],
    ) -> MCPToolResult:
        """Invoke one tool on ``provider`` and wrap the outcome in an MCPToolResult."""
        session = await self._get_or_create_session(provider)
        result = await session.call_tool(tool_name, arguments)

        content = _serialize_tool_result_content(result)
        # Accept either spelling of the error flag on the result object.
        is_error = getattr(result, "isError", None)
        if is_error is None:
            is_error = getattr(result, "is_error", False)

        return MCPToolResult(content=content, is_error=bool(is_error))

    async def _call_tools_async(
        self,
        calls: list[tuple[MCPProviderT, str, dict[str, Any]]],
    ) -> list[MCPToolResult]:
        """Run all ``calls`` concurrently and return their results in order."""
        return await asyncio.gather(*[self._call_tool_async(p, n, a) for p, n, a in calls])

    async def _clear_provider_caches_async(self, providers: list[MCPProviderT]) -> int:
        """Invalidate tools caches and close pooled sessions for ``providers``."""
        keys = [_provider_cache_key(provider) for provider in providers]
        self._invalidate_tools_cache(keys)

        cleared_count = 0
        for key in keys:
            session = self._sessions.pop(key, None)
            ctx = self._session_contexts.pop(key, None)
            if session is not None:
                cleared_count += 1
            await self._close_quietly(session, ctx)

        if cleared_count > 0:
            logger.debug("Cleared %d provider cache entries", cleared_count)
        return cleared_count

    def _clear_provider_caches_sync(self, providers: list[MCPProviderT]) -> int:
        """Drop cache and pool entries for ``providers`` without awaiting session close."""
        keys = [_provider_cache_key(provider) for provider in providers]
        self._invalidate_tools_cache(keys)

        cleared_count = 0
        for key in keys:
            if self._sessions.pop(key, None) is not None:
                cleared_count += 1
            self._session_contexts.pop(key, None)

        if cleared_count > 0:
            logger.debug("Cleared %d provider cache entries", cleared_count)
        return cleared_count

    async def _clear_tools_cache_async(self) -> None:
        """Invalidate every known tools-cache key (on-loop variant)."""
        self._invalidate_tools_cache(self._all_tools_keys())

    def _clear_tools_cache_sync(self) -> None:
        """Invalidate every known tools-cache key (direct variant)."""
        self._invalidate_tools_cache(self._all_tools_keys())

    async def _get_cache_info_async(self) -> dict[str, Any]:
        """Return tools-cache statistics (on-loop variant)."""
        return {"currsize": len(self._tools_cache), "providers": list(self._tools_cache.keys())}

    async def _close_all_sessions_async(self) -> None:
        """Close every pooled session and cancel pending session creations."""
        for key in list(self._sessions.keys()):
            session = self._sessions.pop(key, None)
            ctx = self._session_contexts.pop(key, None)
            await self._close_quietly(session, ctx)

        for task in self._session_inflight.values():
            task.cancel()
        self._session_inflight.clear()

    def _clear_session_pool_sync(self) -> None:
        """Forget all pooled sessions without awaiting their close."""
        self._sessions.clear()
        self._session_contexts.clear()
        self._session_inflight.clear()

    async def _get_session_pool_info_async(self) -> dict[str, Any]:
        """Return session-pool statistics (on-loop variant)."""
        return {"active_sessions": len(self._sessions), "provider_keys": list(self._sessions.keys())}

    # ------------------------------------------------------------------
    # Cache bookkeeping
    # ------------------------------------------------------------------

    def _invalidate_tools_cache(self, keys: Iterable[str]) -> None:
        """Remove cached tools for ``keys`` and bump each key's epoch."""
        for key in keys:
            self._tools_cache.pop(key, None)
            self._tools_cache_epoch[key] = self._tools_cache_epoch.get(key, 0) + 1

    def _all_tools_keys(self) -> set[str]:
        """Return every key that is cached, being fetched, or has an epoch."""
        return set(self._tools_cache) | set(self._inflight_tools) | set(self._tools_cache_epoch)

    def _reset_state(self) -> None:
        """Clear every session, cache and in-flight table."""
        self._sessions.clear()
        self._session_contexts.clear()
        self._session_inflight.clear()
        self._tools_cache.clear()
        self._tools_cache_epoch.clear()
        self._inflight_tools.clear()
391
+
392
+
393
# Module-level service instance used by the wrapper functions defined below.
+ _MCP_IO_SERVICE = MCPIOService()
394
# Register shutdown() with atexit so it is invoked at interpreter exit.
+ atexit.register(_MCP_IO_SERVICE.shutdown)
395
+
396
+
397
def list_tools(provider: MCPProviderT, timeout_sec: float | None = None) -> tuple[MCPToolDefinition, ...]:
    """Return the tools exposed by ``provider`` via the module-level I/O service.

    Results are cached per provider and concurrent requests are coalesced.

    Args:
        provider: Provider to query.
        timeout_sec: Maximum seconds to wait, forwarded to the service.
    """
    service = _MCP_IO_SERVICE
    return service.list_tools(provider, timeout_sec=timeout_sec)
400
+
401
+
402
def call_tools(
    calls: list[tuple[MCPProviderT, str, dict[str, Any]]],
    *,
    timeout_sec: float | None = None,
) -> list[MCPToolResult]:
    """Execute several tool calls in parallel via the module-level I/O service.

    Args:
        calls: ``(provider, tool_name, arguments)`` triples.
        timeout_sec: Maximum seconds to wait, forwarded to the service.
    """
    service = _MCP_IO_SERVICE
    return service.call_tools(calls, timeout_sec=timeout_sec)
409
+
410
+
411
def clear_provider_caches(providers: list[MCPProviderT]) -> int:
    """Drop every cache entry held for the given MCP providers.

    Delegates to the module-level I/O service and returns its count.
    """
    cleared = _MCP_IO_SERVICE.clear_provider_caches(providers)
    return cleared
414
+
415
+
416
def clear_tools_cache() -> None:
    """Empty the list_tools cache held by the module-level I/O service."""
    service = _MCP_IO_SERVICE
    service.clear_tools_cache()
419
+
420
+
421
def get_cache_info() -> dict[str, Any]:
    """Report list_tools cache statistics from the module-level I/O service."""
    info = _MCP_IO_SERVICE.get_cache_info()
    return info
424
+
425
+
426
def clear_session_pool() -> None:
    """Discard every pooled MCP session held by the module-level I/O service."""
    service = _MCP_IO_SERVICE
    service.clear_session_pool()
429
+
430
+
431
def get_session_pool_info() -> dict[str, Any]:
    """Report session-pool statistics from the module-level I/O service."""
    info = _MCP_IO_SERVICE.get_session_pool_info()
    return info
434
+
435
+
436
+ def _build_auth_headers(api_key: str | None) -> dict[str, Any] | None:
437
+ """Build authentication headers for SSE client."""
438
+ if not api_key:
439
+ return None
440
+ return {"Authorization": f"Bearer {api_key}"}
441
+
442
+
443
+ def _coerce_tool_definition(tool: Any, tool_definition_cls: type[MCPToolDefinition]) -> MCPToolDefinition:
444
+ """Coerce a tool from various formats into MCPToolDefinition."""
445
+ if isinstance(tool, dict):
446
+ name = tool.get("name")
447
+ description = tool.get("description")
448
+ input_schema = tool.get("inputSchema") or tool.get("input_schema")
449
+ else:
450
+ name = getattr(tool, "name", None)
451
+ description = getattr(tool, "description", None)
452
+ input_schema = getattr(tool, "inputSchema", None) or getattr(tool, "input_schema", None)
453
+
454
+ if not name:
455
+ raise MCPToolError("Encountered MCP tool without a name.")
456
+
457
+ return tool_definition_cls(name=name, description=description, input_schema=input_schema)
458
+
459
+
460
+ def _serialize_tool_result_content(result: Any) -> str:
461
+ """Serialize tool result content to a string."""
462
+ content = getattr(result, "content", result)
463
+ if content is None:
464
+ return ""
465
+ if isinstance(content, str):
466
+ return content
467
+ if isinstance(content, dict):
468
+ return json.dumps(content)
469
+ if isinstance(content, list):
470
+ parts: list[str] = []
471
+ for item in content:
472
+ if isinstance(item, str):
473
+ parts.append(item)
474
+ continue
475
+ if isinstance(item, dict):
476
+ if item.get("type") == "text":
477
+ parts.append(str(item.get("text", "")))
478
+ else:
479
+ parts.append(json.dumps(item))
480
+ continue
481
+ text_value = getattr(item, "text", None)
482
+ if text_value is not None:
483
+ parts.append(str(text_value))
484
+ else:
485
+ parts.append(str(item))
486
+ return "\n".join(parts)
487
+ return str(content)
@@ -0,0 +1,203 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+
6
+ import logging
7
+ from collections.abc import Callable
8
+ from dataclasses import dataclass
9
+ from typing import TYPE_CHECKING, Any
10
+
11
+ from data_designer.config.mcp import ToolConfig
12
+ from data_designer.engine.model_provider import MCPProviderRegistry
13
+ from data_designer.engine.secret_resolver import SecretResolver
14
+
15
+ if TYPE_CHECKING:
16
+ from data_designer.engine.mcp.facade import MCPFacade
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
@dataclass(frozen=True)
class MCPToolDefinition:
    """Immutable description of one MCP tool: its name, description and input schema."""

    name: str
    description: str | None
    input_schema: dict[str, Any] | None

    def to_openai_tool_schema(self) -> dict[str, Any]:
        """Render this tool in OpenAI's function-calling tool format.

        Returns:
            ``{"type": "function", "function": {...}}`` where the inner dict
            holds ``name``, ``description`` (empty string when unset) and
            ``parameters`` (an empty object schema when no input schema is set).
        """
        parameters = self.input_schema if self.input_schema else {"type": "object", "properties": {}}
        function_spec = {
            "name": self.name,
            "description": self.description if self.description else "",
            "parameters": parameters,
        }
        return {"type": "function", "function": function_spec}
46
+
47
+
48
@dataclass(frozen=True)
class MCPToolResult:
    """Immutable outcome of one MCP tool invocation."""

    # Tool output as a string.
    content: str
    # Error flag for the call; defaults to False.
    is_error: bool = False
54
+
55
+
56
class MCPRegistry:
    """Registry for MCP tool configurations and facades.

    MCPRegistry manages ToolConfig instances by tool_alias and lazily creates
    MCPFacade instances when requested. This is a config-only registry - all
    actual MCP operations are delegated to the MCPFacade and io module.

    This mirrors the ModelRegistry pattern for consistency across the codebase.
    """

    def __init__(
        self,
        *,
        secret_resolver: SecretResolver,
        mcp_provider_registry: MCPProviderRegistry,
        mcp_facade_factory: Callable[[ToolConfig, SecretResolver, MCPProviderRegistry], MCPFacade],
        tool_configs: list[ToolConfig] | None = None,
    ) -> None:
        """Initialize the MCPRegistry.

        Args:
            secret_resolver: Resolver for secrets referenced in provider configs.
            mcp_provider_registry: Registry of MCP provider configurations.
            mcp_facade_factory: Factory for creating MCPFacade instances.
            tool_configs: Optional list of tool configurations to register.
        """
        self._secret_resolver = secret_resolver
        self._mcp_provider_registry = mcp_provider_registry
        self._mcp_facade_factory = mcp_facade_factory
        self._tool_configs: dict[str, ToolConfig] = {}
        self._facades: dict[str, MCPFacade] = {}
        # Aliases that have passed _validate_tool_alias().
        self._validated_tool_aliases: set[str] = set()

        self._set_tool_configs(tool_configs)

    @property
    def tool_configs(self) -> dict[str, ToolConfig]:
        """Get all registered tool configurations."""
        return self._tool_configs

    @property
    def facades(self) -> dict[str, MCPFacade]:
        """Get all instantiated facades."""
        return self._facades

    @property
    def mcp_provider_registry(self) -> MCPProviderRegistry:
        """Get the MCP provider registry."""
        return self._mcp_provider_registry

    def register_tool_configs(self, tool_configs: list[ToolConfig]) -> None:
        """Register tool configurations at runtime.

        Args:
            tool_configs: List of tool configurations to register. If a configuration
                with the same alias already exists, it will be overwritten, and any
                facade previously created for that alias is discarded so the next
                ``get_mcp`` call builds one from the new configuration.
        """
        self._set_tool_configs(list(self._tool_configs.values()) + tool_configs)

    def get_mcp(self, *, tool_alias: str) -> MCPFacade:
        """Get or lazily create an MCPFacade for the given tool alias.

        Args:
            tool_alias: The alias of the tool configuration.

        Returns:
            An MCPFacade configured for the specified tool alias.

        Raises:
            ValueError: If no tool config with the given alias is found.
        """
        tool_config = self._require_tool_config(tool_alias)
        if tool_alias not in self._facades:
            self._facades[tool_alias] = self._create_facade(tool_config)
        return self._facades[tool_alias]

    def get_tool_config(self, *, tool_alias: str) -> ToolConfig:
        """Get a tool configuration by alias.

        Args:
            tool_alias: The alias of the tool configuration.

        Returns:
            The tool configuration.

        Raises:
            ValueError: If no tool config with the given alias is found.
        """
        return self._require_tool_config(tool_alias)

    def _require_tool_config(self, tool_alias: str) -> ToolConfig:
        """Return the config registered under ``tool_alias`` or raise ValueError."""
        if tool_alias not in self._tool_configs:
            raise ValueError(f"No tool config with alias {tool_alias!r} found!")
        return self._tool_configs[tool_alias]

    def _set_tool_configs(self, tool_configs: list[ToolConfig] | None) -> None:
        """Replace the registered configurations with those in ``tool_configs``.

        Later entries win when several share an alias. Facades and validation
        marks are dropped for every alias whose configuration was removed or
        changed, so they can never outlive the configuration they were built from.
        """
        previous = self._tool_configs
        current = {tc.tool_alias: tc for tc in (tool_configs or [])}
        self._tool_configs = current

        tracked_aliases = set(self._facades) | self._validated_tool_aliases
        for alias in tracked_aliases:
            if current.get(alias) != previous.get(alias):
                self._facades.pop(alias, None)
                self._validated_tool_aliases.discard(alias)

    def _create_facade(self, tool_config: ToolConfig) -> MCPFacade:
        """Create an MCPFacade for a tool configuration."""
        return self._mcp_facade_factory(tool_config, self._secret_resolver, self._mcp_provider_registry)

    def _validate_tool_config_providers(self, tool_config: ToolConfig) -> None:
        """Check that every provider named by ``tool_config`` is registered.

        Raises:
            ValueError: If one or more referenced providers are unknown.
        """
        available_providers = {provider.name for provider in self._mcp_provider_registry.providers}
        missing_providers = [provider for provider in tool_config.providers if provider not in available_providers]
        if missing_providers:
            available_list = sorted(available_providers) if available_providers else ["(none configured)"]
            raise ValueError(
                f"ToolConfig '{tool_config.tool_alias}' references provider(s) {missing_providers!r} "
                f"which are not registered. Available providers: {available_list}"
            )

    def _validate_tool_alias(self, tool_alias: str) -> None:
        """Validate one alias: known config, known providers, fetchable tool schemas.

        Raises:
            ValueError: If the alias or one of its providers is not registered.
        """
        tool_config = self._require_tool_config(tool_alias)
        self._validate_tool_config_providers(tool_config)
        facade = self.get_mcp(tool_alias=tool_alias)
        # Fetching the schemas exercises the facade; any error it raises propagates.
        facade.get_tool_schemas()
        self._validated_tool_aliases.add(tool_alias)

    def run_health_check(self, tool_aliases: list[str]) -> None:
        """Validate each alias in ``tool_aliases``, logging progress.

        Does nothing for an empty list. The first failing alias is logged and
        its exception is re-raised.
        """
        if not tool_aliases:
            return
        logger.info("🧰 Running health checks for MCP tools...")
        for tool_alias in tool_aliases:
            logger.info(" |-- 👀 Checking tools for tool alias %r...", tool_alias)
            try:
                self._validate_tool_alias(tool_alias)
                logger.info(" |-- ✅ Passed!")
            except Exception:
                logger.error(" |-- ❌ Failed!")
                raise

    def validate_no_duplicate_tool_names(self) -> None:
        """Validate that no ToolConfig has duplicate tool names across its providers.

        This method eagerly fetches tool schemas for all registered ToolConfigs,
        which triggers duplicate tool name detection. This catches cases where
        multiple providers in the same ToolConfig expose a tool with the same name.

        Raises:
            DuplicateToolNameError: If any ToolConfig has duplicate tool names across providers.
        """
        # Iterate over a snapshot: validation creates facades, not configs, but
        # a snapshot keeps the loop safe if a facade registers configs.
        for tool_alias in list(self._tool_configs):
            self._validate_tool_alias(tool_alias)