nvidia-nat-mcp 1.4.0a20260117__py3-none-any.whl → 1.5.0a20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. nat/meta/pypi.md +1 -1
  2. nat/plugins/mcp/__init__.py +1 -1
  3. nat/plugins/mcp/auth/__init__.py +1 -1
  4. nat/plugins/mcp/auth/auth_flow_handler.py +1 -1
  5. nat/plugins/mcp/auth/auth_provider.py +1 -1
  6. nat/plugins/mcp/auth/auth_provider_config.py +1 -1
  7. nat/plugins/mcp/auth/register.py +1 -1
  8. nat/plugins/mcp/auth/service_account/__init__.py +1 -1
  9. nat/plugins/mcp/auth/service_account/provider.py +1 -1
  10. nat/plugins/mcp/auth/service_account/provider_config.py +1 -1
  11. nat/plugins/mcp/auth/service_account/token_client.py +1 -1
  12. nat/plugins/mcp/auth/token_storage.py +2 -2
  13. nat/plugins/mcp/{client/client_base.py → client_base.py} +1 -1
  14. nat/plugins/mcp/{client/client_config.py → client_config.py} +9 -24
  15. nat/plugins/mcp/{client/client_impl.py → client_impl.py} +51 -219
  16. nat/plugins/mcp/exception_handler.py +1 -1
  17. nat/plugins/mcp/exceptions.py +1 -1
  18. nat/plugins/mcp/register.py +4 -5
  19. nat/plugins/mcp/tool.py +138 -0
  20. nat/plugins/mcp/utils.py +1 -1
  21. {nvidia_nat_mcp-1.4.0a20260117.dist-info → nvidia_nat_mcp-1.5.0a20251222.dist-info}/METADATA +5 -5
  22. nvidia_nat_mcp-1.5.0a20251222.dist-info/RECORD +27 -0
  23. nvidia_nat_mcp-1.5.0a20251222.dist-info/entry_points.txt +3 -0
  24. nat/plugins/mcp/cli/__init__.py +0 -15
  25. nat/plugins/mcp/cli/commands.py +0 -1055
  26. nat/plugins/mcp/client/__init__.py +0 -15
  27. nat/plugins/mcp/server/__init__.py +0 -15
  28. nat/plugins/mcp/server/front_end_config.py +0 -109
  29. nat/plugins/mcp/server/front_end_plugin.py +0 -155
  30. nat/plugins/mcp/server/front_end_plugin_worker.py +0 -415
  31. nat/plugins/mcp/server/introspection_token_verifier.py +0 -72
  32. nat/plugins/mcp/server/memory_profiler.py +0 -320
  33. nat/plugins/mcp/server/register_frontend.py +0 -27
  34. nat/plugins/mcp/server/tool_converter.py +0 -290
  35. nvidia_nat_mcp-1.4.0a20260117.dist-info/RECORD +0 -37
  36. nvidia_nat_mcp-1.4.0a20260117.dist-info/entry_points.txt +0 -9
  37. {nvidia_nat_mcp-1.4.0a20260117.dist-info → nvidia_nat_mcp-1.5.0a20251222.dist-info}/WHEEL +0 -0
  38. {nvidia_nat_mcp-1.4.0a20260117.dist-info → nvidia_nat_mcp-1.5.0a20251222.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  39. {nvidia_nat_mcp-1.4.0a20260117.dist-info → nvidia_nat_mcp-1.5.0a20251222.dist-info}/licenses/LICENSE.md +0 -0
  40. {nvidia_nat_mcp-1.4.0a20260117.dist-info → nvidia_nat_mcp-1.5.0a20251222.dist-info}/top_level.txt +0 -0
nat/meta/pypi.md CHANGED
@@ -1,5 +1,5 @@
1
1
  <!--
2
- SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
3
  SPDX-License-Identifier: Apache-2.0
4
4
 
5
5
  Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -206,7 +206,7 @@ class ObjectStoreTokenStorage(TokenStorageBase):
206
206
 
207
207
  class InMemoryTokenStorage(TokenStorageBase):
208
208
  """
209
- In-memory token storage using the built-in object store provided by the NeMo Agent toolkit.
209
+ In-memory token storage using NeMo Agent toolkit's built-in object store.
210
210
 
211
211
  This implementation uses the in-memory object store for token persistence,
212
212
  which provides a secure default option that doesn't require external storage
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -80,9 +80,9 @@ class MCPServerConfig(BaseModel):
80
80
  return self
81
81
 
82
82
 
83
- class MCPClientBaseConfig(FunctionGroupBaseConfig):
83
+ class MCPClientConfig(FunctionGroupBaseConfig, name="mcp_client"):
84
84
  """
85
- Base configuration shared by MCP client variants.
85
+ Configuration for connecting to an MCP server as a client and exposing selected tools.
86
86
  """
87
87
  server: MCPServerConfig = Field(..., description="Server connection details (transport, url/command, etc.)")
88
88
  tool_call_timeout: timedelta = Field(
@@ -114,19 +114,6 @@ class MCPClientBaseConfig(FunctionGroupBaseConfig):
114
114
  calculator_multiply:
115
115
  description: "Multiply two numbers" # alias defaults to original name
116
116
  """)
117
-
118
- @model_validator(mode="after")
119
- def _validate_reconnect_backoff(self) -> "MCPClientBaseConfig":
120
- """Validate reconnect backoff values."""
121
- if self.reconnect_max_backoff < self.reconnect_initial_backoff:
122
- raise ValueError("reconnect_max_backoff must be greater than or equal to reconnect_initial_backoff")
123
- return self
124
-
125
-
126
- class MCPClientConfig(MCPClientBaseConfig, name="mcp_client"):
127
- """
128
- Configuration for connecting to an MCP server as a client and exposing selected tools.
129
- """
130
117
  session_aware_tools: bool = Field(default=True,
131
118
  description="Session-aware tools are created if True. Defaults to True.")
132
119
  max_sessions: int = Field(default=100,
@@ -136,11 +123,9 @@ class MCPClientConfig(MCPClientBaseConfig, name="mcp_client"):
136
123
  default=timedelta(hours=1),
137
124
  description="Time after which inactive sessions are cleaned up. Defaults to 1 hour.")
138
125
 
139
-
140
- class PerUserMCPClientConfig(MCPClientBaseConfig, name="per_user_mcp_client"):
141
- """
142
- MCP Client configuration for per-user workflows that are registered with @register_per_user_function,
143
-
144
- and each user gets their own MCP client instance.
145
- """
146
- pass
126
+ @model_validator(mode="after")
127
+ def _validate_reconnect_backoff(self) -> "MCPClientConfig":
128
+ """Validate reconnect backoff values."""
129
+ if self.reconnect_max_backoff < self.reconnect_initial_backoff:
130
+ raise ValueError("reconnect_max_backoff must be greater than or equal to reconnect_initial_backoff")
131
+ return self
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -26,77 +26,16 @@ from pydantic import BaseModel
26
26
 
27
27
  from nat.authentication.interfaces import AuthProviderBase
28
28
  from nat.builder.builder import Builder
29
- from nat.builder.context import Context
30
29
  from nat.builder.function import FunctionGroup
31
30
  from nat.cli.register_workflow import register_function_group
32
- from nat.cli.register_workflow import register_per_user_function_group
33
- from nat.plugins.mcp.client.client_base import MCPBaseClient
34
- from nat.plugins.mcp.client.client_config import MCPClientConfig
35
- from nat.plugins.mcp.client.client_config import MCPToolOverrideConfig
36
- from nat.plugins.mcp.client.client_config import PerUserMCPClientConfig
31
+ from nat.plugins.mcp.client_base import MCPBaseClient
32
+ from nat.plugins.mcp.client_config import MCPClientConfig
33
+ from nat.plugins.mcp.client_config import MCPToolOverrideConfig
37
34
  from nat.plugins.mcp.utils import truncate_session_id
38
35
 
39
36
  logger = logging.getLogger(__name__)
40
37
 
41
38
 
42
- class PerUserMCPFunctionGroup(FunctionGroup):
43
- """
44
- A specialized FunctionGroup for per-user MCP clients.
45
- """
46
-
47
- def __init__(self, *args, **kwargs):
48
- super().__init__(*args, **kwargs)
49
-
50
- self.mcp_client: MCPBaseClient | None = None # Will be set to the actual MCP client instance
51
- self.mcp_client_server_name: str | None = None
52
- self.mcp_client_transport: str | None = None
53
- self.user_id: str | None = None
54
-
55
-
56
- def mcp_per_user_tool_function(tool, client: MCPBaseClient):
57
- """
58
- Create a per-user NAT function for an MCP tool.
59
-
60
- Args:
61
- tool: The MCP tool to create a function for
62
- client: The MCP client to use for the function
63
-
64
- Returns:
65
- The NAT function
66
- """
67
- from nat.builder.function import FunctionInfo
68
-
69
- def _convert_from_str(input_str: str) -> tool.input_schema:
70
- return tool.input_schema.model_validate_json(input_str)
71
-
72
- async def _response_fn(tool_input: BaseModel | None = None, **kwargs) -> str:
73
- try:
74
- mcp_tool = await client.get_tool(tool.name)
75
-
76
- if tool_input:
77
- args = tool_input.model_dump(exclude_none=True, mode='json')
78
- return await mcp_tool.acall(args)
79
-
80
- # kwargs arrives with all optional fields set to None because NAT's framework
81
- # converts the input dict to a Pydantic model (filling in all Field(default=None)),
82
- # then dumps it back to a dict. We need to strip out these None values because
83
- # many MCP servers (e.g., Kaggle) reject requests with excessive null fields.
84
- # We re-validate here (yes, redundant) to leverage Pydantic's exclude_none with
85
- # mode='json' for recursive None removal in nested models.
86
- # Reference: function_info.py:_convert_input_pydantic
87
- validated_input = mcp_tool.input_schema.model_validate(kwargs)
88
- args = validated_input.model_dump(exclude_none=True, mode='json')
89
- return await mcp_tool.acall(args)
90
- except Exception as e:
91
- logger.warning("Error calling tool %s", tool.name, exc_info=True)
92
- return str(e)
93
-
94
- return FunctionInfo.create(single_fn=_response_fn,
95
- description=tool.description,
96
- input_schema=tool.input_schema,
97
- converters=[_convert_from_str])
98
-
99
-
100
39
  @dataclass
101
40
  class SessionData:
102
41
  """Container for all session-related data."""
@@ -152,9 +91,9 @@ class MCPFunctionGroup(FunctionGroup):
152
91
  def __init__(self, *args, **kwargs):
153
92
  super().__init__(*args, **kwargs)
154
93
  # MCP client attributes with proper typing
155
- self.mcp_client: MCPBaseClient | None = None # Will be set to the actual MCP client instance
156
- self.mcp_client_server_name: str | None = None
157
- self.mcp_client_transport: str | None = None
94
+ self._mcp_client = None # Will be set to the actual MCP client instance
95
+ self._mcp_client_server_name: str | None = None
96
+ self._mcp_client_transport: str | None = None
158
97
 
159
98
  # Session management - consolidated data structure
160
99
  self._sessions: dict[str, SessionData] = {}
@@ -177,6 +116,36 @@ class MCPFunctionGroup(FunctionGroup):
177
116
  # Use random session id for testing only
178
117
  self._use_random_session_id_for_testing: bool = False
179
118
 
119
+ @property
120
+ def mcp_client(self):
121
+ """Get the MCP client instance."""
122
+ return self._mcp_client
123
+
124
+ @mcp_client.setter
125
+ def mcp_client(self, client):
126
+ """Set the MCP client instance."""
127
+ self._mcp_client = client
128
+
129
+ @property
130
+ def mcp_client_server_name(self) -> str | None:
131
+ """Get the MCP client server name."""
132
+ return self._mcp_client_server_name
133
+
134
+ @mcp_client_server_name.setter
135
+ def mcp_client_server_name(self, server_name: str | None):
136
+ """Set the MCP client server name."""
137
+ self._mcp_client_server_name = server_name
138
+
139
+ @property
140
+ def mcp_client_transport(self) -> str | None:
141
+ """Get the MCP client transport type."""
142
+ return self._mcp_client_transport
143
+
144
+ @mcp_client_transport.setter
145
+ def mcp_client_transport(self, transport: str | None):
146
+ """Set the MCP client transport type."""
147
+ self._mcp_client_transport = transport
148
+
180
149
  @property
181
150
  def session_count(self) -> int:
182
151
  """Current number of active sessions."""
@@ -289,7 +258,7 @@ class MCPFunctionGroup(FunctionGroup):
289
258
  except Exception as e:
290
259
  logger.warning("Error cleaning up session client %s: %s", truncate_session_id(session_id), e)
291
260
 
292
- async def _get_session_client(self, session_id: str) -> MCPBaseClient | None:
261
+ async def _get_session_client(self, session_id: str) -> MCPBaseClient:
293
262
  """Get the appropriate MCP client for the session."""
294
263
  # Throttled cleanup on access
295
264
  now = datetime.now()
@@ -379,7 +348,7 @@ class MCPFunctionGroup(FunctionGroup):
379
348
 
380
349
  async def _create_session_client(self, session_id: str) -> tuple[MCPBaseClient, asyncio.Event, asyncio.Task]:
381
350
  """Create a new MCP client instance for the session."""
382
- from nat.plugins.mcp.client.client_base import MCPStreamableHTTPClient
351
+ from nat.plugins.mcp.client_base import MCPStreamableHTTPClient
383
352
 
384
353
  config = self._client_config
385
354
  if not config:
@@ -471,13 +440,9 @@ def mcp_session_tool_function(tool, function_group: MCPFunctionGroup):
471
440
  if (not function_group._shared_auth_provider or session_id == function_group._default_user_id):
472
441
  # Use base client directly for default user
473
442
  client = function_group.mcp_client
474
- if client is None:
475
- return "Tool temporarily unavailable. Try again."
476
443
  session_tool = await client.get_tool(tool.name)
477
444
  else:
478
445
  # Use session usage context to prevent cleanup during tool execution
479
- if session_id is None:
480
- return "Tool temporarily unavailable. Try again."
481
446
  async with function_group._session_usage_context(session_id) as client:
482
447
  if client is None:
483
448
  return "Tool temporarily unavailable. Try again."
@@ -519,9 +484,9 @@ async def mcp_client_function_group(config: MCPClientConfig, _builder: Builder):
519
484
  Returns:
520
485
  The function group
521
486
  """
522
- from nat.plugins.mcp.client.client_base import MCPSSEClient
523
- from nat.plugins.mcp.client.client_base import MCPStdioClient
524
- from nat.plugins.mcp.client.client_base import MCPStreamableHTTPClient
487
+ from nat.plugins.mcp.client_base import MCPSSEClient
488
+ from nat.plugins.mcp.client_base import MCPStdioClient
489
+ from nat.plugins.mcp.client_base import MCPStreamableHTTPClient
525
490
 
526
491
  # Resolve auth provider if specified
527
492
  auth_provider = None
@@ -609,16 +574,23 @@ async def mcp_client_function_group(config: MCPClientConfig, _builder: Builder):
609
574
  # Create the tool function according to configuration
610
575
  tool_fn = mcp_session_tool_function(tool, group)
611
576
 
577
+ # Normalize optional typing for linter/type-checker compatibility
578
+ single_fn = tool_fn.single_fn
579
+ if single_fn is None:
580
+ # Should not happen because FunctionInfo always sets a single_fn
581
+ logger.warning("Skipping tool %s because single_fn is None", function_name)
582
+ continue
583
+
612
584
  input_schema = tool_fn.input_schema
613
585
  # Convert NoneType sentinel to None for FunctionGroup.add_function signature
614
- if input_schema is type(None):
586
+ if input_schema is type(None): # noqa: E721
615
587
  input_schema = None
616
588
 
617
589
  # Add to group
618
590
  logger.info("Adding tool %s to group", function_name)
619
591
  group.add_function(name=function_name,
620
592
  description=description,
621
- fn=tool_fn.single_fn,
593
+ fn=single_fn,
622
594
  input_schema=input_schema,
623
595
  converters=tool_fn.converters)
624
596
 
@@ -640,143 +612,3 @@ def mcp_apply_tool_alias_and_description(
640
612
  return {}
641
613
 
642
614
  return {name: override for name, override in tool_overrides.items() if name in all_tools}
643
-
644
-
645
- @register_per_user_function_group(config_type=PerUserMCPClientConfig)
646
- async def per_user_mcp_client_function_group(config: PerUserMCPClientConfig, _builder: Builder):
647
- """
648
- Connect to an MCP server and expose tools as a function group for per-user workflows.
649
-
650
- Args:
651
- config: The configuration for the MCP client
652
- _builder: The builder
653
- Returns:
654
- The function group
655
- """
656
- from nat.plugins.mcp.client.client_base import MCPSSEClient
657
- from nat.plugins.mcp.client.client_base import MCPStdioClient
658
- from nat.plugins.mcp.client.client_base import MCPStreamableHTTPClient
659
-
660
- # Resolve auth provider if specified
661
- auth_provider = None
662
- if config.server.auth_provider:
663
- auth_provider = await _builder.get_auth_provider(config.server.auth_provider)
664
-
665
- user_id = Context.get().user_id
666
-
667
- # Build the appropriate client
668
- if config.server.transport == "stdio":
669
- if not config.server.command:
670
- raise ValueError("command is required for stdio transport")
671
- client = MCPStdioClient(config.server.command,
672
- config.server.args,
673
- config.server.env,
674
- tool_call_timeout=config.tool_call_timeout,
675
- auth_flow_timeout=config.auth_flow_timeout,
676
- reconnect_enabled=config.reconnect_enabled,
677
- reconnect_max_attempts=config.reconnect_max_attempts,
678
- reconnect_initial_backoff=config.reconnect_initial_backoff,
679
- reconnect_max_backoff=config.reconnect_max_backoff)
680
- elif config.server.transport == "sse":
681
- client = MCPSSEClient(str(config.server.url),
682
- tool_call_timeout=config.tool_call_timeout,
683
- auth_flow_timeout=config.auth_flow_timeout,
684
- reconnect_enabled=config.reconnect_enabled,
685
- reconnect_max_attempts=config.reconnect_max_attempts,
686
- reconnect_initial_backoff=config.reconnect_initial_backoff,
687
- reconnect_max_backoff=config.reconnect_max_backoff)
688
- elif config.server.transport == "streamable-http":
689
- client = MCPStreamableHTTPClient(str(config.server.url),
690
- auth_provider=auth_provider,
691
- user_id=user_id,
692
- tool_call_timeout=config.tool_call_timeout,
693
- auth_flow_timeout=config.auth_flow_timeout,
694
- reconnect_enabled=config.reconnect_enabled,
695
- reconnect_max_attempts=config.reconnect_max_attempts,
696
- reconnect_initial_backoff=config.reconnect_initial_backoff,
697
- reconnect_max_backoff=config.reconnect_max_backoff)
698
- else:
699
- raise ValueError(f"Unsupported transport: {config.server.transport}")
700
-
701
- logger.info("Per-user MCP client configured for server: %s (user: %s)", client.server_name, user_id)
702
-
703
- group = PerUserMCPFunctionGroup(config=config)
704
-
705
- # Use a lifetime task to ensure the client context is entered and exited in the same task.
706
- # This avoids anyio's "Attempted to exit cancel scope in a different task" error.
707
- ready = asyncio.Event()
708
- stop_event = asyncio.Event()
709
-
710
- async def _lifetime():
711
- """Lifetime task that owns the client's async context."""
712
- try:
713
- async with client:
714
- ready.set()
715
- await stop_event.wait()
716
- except Exception:
717
- ready.set() # Ensure we don't hang the waiter
718
- raise
719
-
720
- lifetime_task = asyncio.create_task(_lifetime(), name=f"mcp-per-user-{user_id}")
721
-
722
- # Wait for client initialization
723
- timeout = config.tool_call_timeout.total_seconds()
724
- try:
725
- await asyncio.wait_for(ready.wait(), timeout=timeout)
726
- except TimeoutError:
727
- lifetime_task.cancel()
728
- try:
729
- await lifetime_task
730
- except asyncio.CancelledError:
731
- pass
732
- raise RuntimeError(f"Per-user MCP client initialization timed out after {timeout}s")
733
-
734
- # Check if initialization failed
735
- if lifetime_task.done():
736
- try:
737
- await lifetime_task
738
- except Exception as e:
739
- raise RuntimeError(f"Failed to initialize per-user MCP client: {e}") from e
740
-
741
- try:
742
- # Expose the live MCP client on the function group instance so other components (e.g., HTTP endpoints)
743
- # can reuse the already-established session instead of creating a new client per request.
744
- group.mcp_client = client
745
- group.mcp_client_server_name = client.server_name
746
- group.mcp_client_transport = client.transport
747
- group.user_id = user_id
748
-
749
- all_tools = await client.get_tools()
750
- tool_overrides = mcp_apply_tool_alias_and_description(all_tools, config.tool_overrides)
751
-
752
- # Add each tool as a function to the group
753
- for tool_name, tool in all_tools.items():
754
- # Get override if it exists
755
- override = tool_overrides.get(tool_name)
756
-
757
- # Use override values or defaults
758
- function_name = override.alias if override and override.alias else tool_name
759
- description = override.description if override and override.description else tool.description
760
-
761
- # Create the tool function according to configuration
762
- tool_fn = mcp_per_user_tool_function(tool, client)
763
-
764
- input_schema = tool_fn.input_schema
765
- # Convert NoneType sentinel to None for FunctionGroup.add_function signature
766
- if input_schema is type(None):
767
- input_schema = None
768
-
769
- # Add to group
770
- logger.info("Adding tool %s to group", function_name)
771
- group.add_function(name=function_name,
772
- description=description,
773
- fn=tool_fn.single_fn,
774
- input_schema=input_schema,
775
- converters=tool_fn.converters)
776
-
777
- yield group
778
- finally:
779
- # Signal the lifetime task to exit and wait for clean shutdown
780
- stop_event.set()
781
- if not lifetime_task.done():
782
- await lifetime_task
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,8 +16,7 @@
16
16
  # flake8: noqa
17
17
  # isort:skip_file
18
18
 
19
- # Register client components
20
- from .client import client_impl
19
+ # Import any providers which need to be automatically registered here
21
20
 
22
- # Register server/frontend components
23
- from .server import register_frontend
21
+ from . import client_impl
22
+ from . import tool