nvidia-nat-mcp 1.4.0a20251008__py3-none-any.whl → 1.4.0a20251010__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -44,10 +44,48 @@ class SessionData:
44
44
  ref_count: int = 0
45
45
  lock: asyncio.Lock = field(default_factory=asyncio.Lock)
46
46
 
47
+ # lifetime task to respect task boundaries
48
+ stop_event: asyncio.Event = field(default_factory=asyncio.Event)
49
+ lifetime_task: asyncio.Task | None = None
50
+
47
51
 
48
52
  class MCPFunctionGroup(FunctionGroup):
49
53
  """
50
- A specialized FunctionGroup for MCP clients that includes MCP-specific attributes with session management.
54
+ A specialized FunctionGroup for MCP clients that includes MCP-specific attributes
55
+ with session management.
56
+
57
+ Locking model (simple + safe; occasional 'temporarily unavailable' is acceptable).
58
+
59
+ RW semantics:
60
+ - Multiple readers may hold the reader lock concurrently.
61
+ - While any reader holds the lock, writers cannot proceed.
62
+ - While the writer holds the lock, no new readers can proceed.
63
+
64
+ Data:
65
+ - _sessions: dict[str, SessionData]; SessionData = {client, last_activity, ref_count, lock}.
66
+
67
+ Locks:
68
+ - _session_rwlock (aiorwlock.RWLock)
69
+ • Reader: very short sections — dict lookups, ref_count ++/--, touch last_activity.
70
+ • Writer: structural changes — create session entries, enforce limits, remove on cleanup.
71
+ - SessionData.lock (asyncio.Lock)
72
+ • Protects per-session ref_count only, taken only while holding RW *reader*.
73
+ • last_activity: written without session lock (timestamp races acceptable for cleanup heuristic).
74
+
75
+ Ordering & awaits:
76
+ - Always acquire RWLock (reader/writer) before SessionData.lock; never the reverse.
77
+ - Never await network I/O under the writer (client creation is the one intentional exception).
78
+ - Client close happens after releasing the writer.
79
+
80
+ Cleanup:
81
+ - Under writer: find inactive (ref_count == 0 and idle > max_age), pop from _sessions, stash clients.
82
+ - After writer: await client.__aexit__() for each stashed client.
83
+ - TOCTOU race: cleanup may read ref_count==0 then a usage increments it; accepted, yields None gracefully.
84
+
85
+ Invariants:
86
+ - ref_count > 0 prevents cleanup.
87
+ - Usage context increments ref_count before yielding and decrements on exit.
88
+ - If a session disappears between ensure/use, callers return "Tool temporarily unavailable".
51
89
  """
52
90
 
53
91
  def __init__(self, *args, **kwargs):
@@ -71,6 +109,9 @@ class MCPFunctionGroup(FunctionGroup):
71
109
  self._shared_auth_provider: AuthProviderBase | None = None
72
110
  self._client_config: MCPClientConfig | None = None
73
111
 
112
+ # Use random session id for testing only
113
+ self._use_random_session_id_for_testing: bool = False
114
+
74
115
  @property
75
116
  def mcp_client(self):
76
117
  """Get the MCP client instance."""
@@ -111,6 +152,11 @@ class MCPFunctionGroup(FunctionGroup):
111
152
  """Maximum allowed sessions."""
112
153
  return self._client_config.max_sessions if self._client_config else 100
113
154
 
155
+ def _get_random_session_id(self) -> str:
156
+ """Get a random session ID."""
157
+ import uuid
158
+ return str(uuid.uuid4())
159
+
114
160
  def _get_session_id_from_context(self) -> str | None:
115
161
  """Get the session ID from the current context."""
116
162
  try:
@@ -118,9 +164,15 @@ class MCPFunctionGroup(FunctionGroup):
118
164
 
119
165
  # Get session id from context, authentication is done per-websocket session for tool calls
120
166
  session_id = None
121
- cookies = getattr(_Ctx.get().metadata, "cookies", None)
122
- if cookies:
123
- session_id = cookies.get("nat-session")
167
+ # get session id from cookies if session_aware_tools is enabled
168
+ if self._client_config and self._client_config.session_aware_tools:
169
+ cookies = getattr(_Ctx.get().metadata, "cookies", None)
170
+ if cookies:
171
+ if self._use_random_session_id_for_testing:
172
+ # This path is for testing only and should not be used in production
173
+ session_id = self._get_random_session_id()
174
+ else:
175
+ session_id = cookies.get("nat-session")
124
176
 
125
177
  if not session_id:
126
178
  # use default user id if allowed
@@ -154,6 +206,8 @@ class MCPFunctionGroup(FunctionGroup):
154
206
  if max_age is None:
155
207
  max_age = self._client_config.session_idle_timeout if self._client_config else timedelta(hours=1)
156
208
 
209
+ to_close: list[tuple[str, SessionData]] = []
210
+
157
211
  async with self._session_rwlock.writer:
158
212
  current_time = datetime.now()
159
213
  inactive_sessions = []
@@ -171,8 +225,8 @@ class MCPFunctionGroup(FunctionGroup):
171
225
  logger.info("Cleaning up inactive session client: %s", truncate_session_id(session_id))
172
226
  session_data = self._sessions[session_id]
173
227
  # Close the client connection
174
- await session_data.client.__aexit__(None, None, None)
175
- logger.info("Cleaned up inactive session client: %s", truncate_session_id(session_id))
228
+ if session_data:
229
+ to_close.append((session_id, session_data))
176
230
  except Exception as e:
177
231
  logger.warning("Error cleaning up session client %s: %s", truncate_session_id(session_id), e)
178
232
  finally:
@@ -181,6 +235,26 @@ class MCPFunctionGroup(FunctionGroup):
181
235
  logger.info("Cleaned up session tracking for: %s", truncate_session_id(session_id))
182
236
  logger.info(" Total sessions: %d", len(self._sessions))
183
237
 
238
+ # Close sessions outside the writer lock to avoid deadlock
239
+ for session_id, sdata in to_close:
240
+ try:
241
+ if sdata.stop_event and sdata.lifetime_task:
242
+ if not sdata.lifetime_task.done():
243
+ # Instead of directly exiting the task, set the stop event
244
+ # and wait for the task to exit. This ensures the cancel scope
245
+ # is entered and exited in the same task.
246
+ sdata.stop_event.set()
247
+ await sdata.lifetime_task # __aexit__ runs in that task
248
+ else:
249
+ logger.debug("Session client %s lifetime task already done", truncate_session_id(session_id))
250
+ else:
251
+ # add fallback to ensure we clean up the client
252
+ logger.warning("Session client %s lifetime task not found, cleaning up client",
253
+ truncate_session_id(session_id))
254
+ await sdata.client.__aexit__(None, None, None)
255
+ except Exception as e:
256
+ logger.warning("Error cleaning up session client %s: %s", truncate_session_id(session_id), e)
257
+
184
258
  async def _get_session_client(self, session_id: str) -> MCPBaseClient:
185
259
  """Get the appropriate MCP client for the session."""
186
260
  # Throttled cleanup on access
@@ -220,15 +294,19 @@ class MCPFunctionGroup(FunctionGroup):
220
294
  logger.warning("Session limit reached (%d), rejecting new session: %s",
221
295
  self._client_config.max_sessions,
222
296
  truncate_session_id(session_id))
223
- raise RuntimeError(f"Service temporarily unavailable: Maximum concurrent sessions "
224
- f"({self._client_config.max_sessions}) exceeded. Please try again later.")
297
+ raise RuntimeError(f"Tool unavailable: Maximum concurrent sessions "
298
+ f"({self._client_config.max_sessions}) exceeded.")
225
299
 
226
300
  # Create session client lazily
227
301
  logger.info("Creating new MCP client for session: %s", truncate_session_id(session_id))
228
- session_client = await self._create_session_client(session_id)
229
-
230
- # Create session data with all components
231
- session_data = SessionData(client=session_client, last_activity=datetime.now(), ref_count=0)
302
+ session_client, stop_event, lifetime_task = await self._create_session_client(session_id)
303
+ session_data = SessionData(
304
+ client=session_client,
305
+ last_activity=datetime.now(),
306
+ ref_count=0,
307
+ stop_event=stop_event,
308
+ lifetime_task=lifetime_task,
309
+ )
232
310
 
233
311
  # Cache the session data
234
312
  self._sessions[session_id] = session_data
@@ -241,23 +319,32 @@ class MCPFunctionGroup(FunctionGroup):
241
319
  # Ensure session exists - create it if it doesn't
242
320
  if session_id not in self._sessions:
243
321
  # Create session client first
244
- await self._get_session_client(session_id)
245
- # Session should now exist in _sessions
322
+ await self._get_session_client(session_id) # START read phase: bump ref_count under reader + session lock
246
323
 
247
- # Get session data (session must exist at this point)
248
- session_data = self._sessions[session_id]
249
-
250
- # Thread-safe reference counting using per-session lock
251
- async with session_data.lock:
252
- session_data.ref_count += 1
324
+ async with self._session_rwlock.reader:
325
+ sdata = self._sessions.get(session_id)
326
+ if not sdata:
327
+ # this can happen if the session is cleaned up between the check and the lock
328
+ # this is rare and we can just return that the tool is temporarily unavailable
329
+ yield None
330
+ return
331
+ async with sdata.lock:
332
+ sdata.ref_count += 1
333
+ client = sdata.client # capture
334
+ # END read phase (release reader before long await)
253
335
 
254
336
  try:
255
- yield
337
+ yield client
256
338
  finally:
257
- async with session_data.lock:
258
- session_data.ref_count -= 1
259
-
260
- async def _create_session_client(self, session_id: str) -> MCPBaseClient:
339
+ # Brief read phase to decrement ref_count and touch activity
340
+ async with self._session_rwlock.reader:
341
+ sdata = self._sessions.get(session_id)
342
+ if sdata:
343
+ async with sdata.lock:
344
+ sdata.ref_count -= 1
345
+ sdata.last_activity = datetime.now()
346
+
347
+ async def _create_session_client(self, session_id: str) -> tuple[MCPBaseClient, asyncio.Event, asyncio.Task]:
261
348
  """Create a new MCP client instance for the session."""
262
349
  from nat.plugins.mcp.client_base import MCPStreamableHTTPClient
263
350
 
@@ -280,11 +367,50 @@ class MCPFunctionGroup(FunctionGroup):
280
367
  # per-user sessions are only supported for streamable-http transport
281
368
  raise ValueError(f"Unsupported transport: {config.server.transport}")
282
369
 
283
- # Initialize the client
284
- await client.__aenter__()
370
+ ready = asyncio.Event()
371
+ stop_event = asyncio.Event()
372
+
373
+ async def _lifetime():
374
+ """
375
+ Create a lifetime task to respect task boundaries and ensure the
376
+ cancel scope is entered and exited in the same task.
377
+ """
378
+ try:
379
+ async with client:
380
+ ready.set()
381
+ await stop_event.wait()
382
+ except Exception:
383
+ ready.set() # Ensure we don't hang the waiter
384
+ raise
385
+
386
+ task = asyncio.create_task(_lifetime(), name=f"mcp-session-{truncate_session_id(session_id)}")
387
+
388
+ # Wait for initialization with timeout to prevent infinite hangs
389
+ timeout = config.tool_call_timeout.total_seconds() if config else 300
390
+ try:
391
+ await asyncio.wait_for(ready.wait(), timeout=timeout)
392
+ except TimeoutError:
393
+ task.cancel()
394
+ try:
395
+ await task
396
+ except asyncio.CancelledError:
397
+ pass
398
+ logger.error("Session client initialization timed out after %ds for %s",
399
+ timeout,
400
+ truncate_session_id(session_id))
401
+ raise RuntimeError(f"Session client initialization timed out after {timeout}s")
402
+
403
+ # Check if initialization failed before ready was set
404
+ if task.done():
405
+ try:
406
+ await task # Re-raise exception if the task failed
407
+ except Exception as e:
408
+ logger.error("Failed to initialize session client for %s: %s", truncate_session_id(session_id), e)
409
+ raise RuntimeError(f"Failed to initialize session client: {e}") from e
285
410
 
286
411
  logger.info("Created session client for session: %s", truncate_session_id(session_id))
287
- return client
412
+ # NOTE: caller will place client into SessionData and attach stop_event/task
413
+ return client, stop_event, task
288
414
 
289
415
 
290
416
  def mcp_session_tool_function(tool, function_group: MCPFunctionGroup):
@@ -316,8 +442,9 @@ def mcp_session_tool_function(tool, function_group: MCPFunctionGroup):
316
442
  session_tool = await client.get_tool(tool.name)
317
443
  else:
318
444
  # Use session usage context to prevent cleanup during tool execution
319
- async with function_group._session_usage_context(session_id):
320
- client = await function_group._get_session_client(session_id)
445
+ async with function_group._session_usage_context(session_id) as client:
446
+ if client is None:
447
+ return "Tool temporarily unavailable. Try again."
321
448
  session_tool = await client.get_tool(tool.name)
322
449
 
323
450
  # Preserve original calling convention
@@ -428,11 +555,7 @@ async def mcp_client_function_group(config: MCPClientConfig, _builder: Builder):
428
555
  description = override.description if override and override.description else tool.description
429
556
 
430
557
  # Create the tool function according to configuration
431
- if config.session_aware_tools:
432
- tool_fn = mcp_session_tool_function(tool, group)
433
- else:
434
- from nat.plugins.mcp.tool import mcp_tool_function
435
- tool_fn = mcp_tool_function(tool)
558
+ tool_fn = mcp_session_tool_function(tool, group)
436
559
 
437
560
  # Normalize optional typing for linter/type-checker compatibility
438
561
  single_fn = tool_fn.single_fn
@@ -1,7 +1,12 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nvidia-nat-mcp
3
- Version: 1.4.0a20251008
3
+ Version: 1.4.0a20251010
4
4
  Summary: Subpackage for MCP client integration in NeMo Agent toolkit
5
+ Author: NVIDIA Corporation
6
+ Maintainer: NVIDIA Corporation
7
+ License-Expression: Apache-2.0
8
+ Project-URL: documentation, https://docs.nvidia.com/nemo/agent-toolkit/latest/
9
+ Project-URL: source, https://github.com/NVIDIA/NeMo-Agent-Toolkit
5
10
  Keywords: ai,rag,agents,mcp
6
11
  Classifier: Programming Language :: Python
7
12
  Classifier: Programming Language :: Python :: 3.11
@@ -9,9 +14,12 @@ Classifier: Programming Language :: Python :: 3.12
9
14
  Classifier: Programming Language :: Python :: 3.13
10
15
  Requires-Python: <3.14,>=3.11
11
16
  Description-Content-Type: text/markdown
12
- Requires-Dist: nvidia-nat==v1.4.0a20251008
17
+ License-File: LICENSE.md
18
+ License-File: LICENSE-3rd-party.txt
19
+ Requires-Dist: nvidia-nat==v1.4.0a20251010
13
20
  Requires-Dist: aiorwlock~=1.5
14
21
  Requires-Dist: mcp~=1.14
22
+ Dynamic: license-file
15
23
 
16
24
  <!--
17
25
  SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
@@ -2,7 +2,7 @@ nat/meta/pypi.md,sha256=EYyJTCCEOWzuuz-uNaYJ_WBk55Jiig87wcUr9E4g0yw,1484
2
2
  nat/plugins/mcp/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
3
3
  nat/plugins/mcp/client_base.py,sha256=nos9NTQ2NlU9vd0PFb1n_q9AZtPKQ5OZ6EU74Ydo1C0,26533
4
4
  nat/plugins/mcp/client_config.py,sha256=l9tVUHe8WdFPJ9rXDg8dZkQi1dvHGYwoqQ8Glqg2LGs,6783
5
- nat/plugins/mcp/client_impl.py,sha256=uw_iCOwkwbkHYGRW0XSis3wL3jsmf1RDOO6epVy5UPY,21372
5
+ nat/plugins/mcp/client_impl.py,sha256=ekyw-hy5AqBJNkBvVH9Dl4s3wTtadLOcGxOu80Qbv0E,27370
6
6
  nat/plugins/mcp/exception_handler.py,sha256=4JVdZDJL4LyumZEcMIEBK2LYC6djuSMzqUhQDZZ6dUo,7648
7
7
  nat/plugins/mcp/exceptions.py,sha256=EGVOnYlui8xufm8dhJyPL1SUqBLnCGOTvRoeyNcmcWE,5980
8
8
  nat/plugins/mcp/register.py,sha256=HOT2Wl2isGuyFc7BUTi58-BbjI5-EtZMZo7stsv5pN4,831
@@ -14,8 +14,10 @@ nat/plugins/mcp/auth/auth_provider.py,sha256=BgH66DlZgzhLDLO4cBERpHvNAmli5fMo_SC
14
14
  nat/plugins/mcp/auth/auth_provider_config.py,sha256=b1AaXzOuAkygKXAWSxMKWg8wfW8k33tmUUq6Dk5Mmwk,4038
15
15
  nat/plugins/mcp/auth/register.py,sha256=L2x69NjJPS4s6CCE5myzWVrWn3e_ttHyojmGXvBipMg,1228
16
16
  nat/plugins/mcp/auth/token_storage.py,sha256=aS13ZvEJXcYzkZ0GSbrSor4i5bpjD5BkXHQw1iywC9k,9240
17
- nvidia_nat_mcp-1.4.0a20251008.dist-info/METADATA,sha256=A9ksFEOx30iLCUQw3m2s8KhitBt5hAyKL-oL46NeWn8,2013
18
- nvidia_nat_mcp-1.4.0a20251008.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
19
- nvidia_nat_mcp-1.4.0a20251008.dist-info/entry_points.txt,sha256=rYvUp4i-klBr3bVNh7zYOPXret704vTjvCk1qd7FooI,97
20
- nvidia_nat_mcp-1.4.0a20251008.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
21
- nvidia_nat_mcp-1.4.0a20251008.dist-info/RECORD,,
17
+ nvidia_nat_mcp-1.4.0a20251010.dist-info/licenses/LICENSE-3rd-party.txt,sha256=fOk5jMmCX9YoKWyYzTtfgl-SUy477audFC5hNY4oP7Q,284609
18
+ nvidia_nat_mcp-1.4.0a20251010.dist-info/licenses/LICENSE.md,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
19
+ nvidia_nat_mcp-1.4.0a20251010.dist-info/METADATA,sha256=SiuoE4J0XIqm-ApDSMuAUh6CYhP5lMcD5DqtdLOafnA,2330
20
+ nvidia_nat_mcp-1.4.0a20251010.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
21
+ nvidia_nat_mcp-1.4.0a20251010.dist-info/entry_points.txt,sha256=rYvUp4i-klBr3bVNh7zYOPXret704vTjvCk1qd7FooI,97
22
+ nvidia_nat_mcp-1.4.0a20251010.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
23
+ nvidia_nat_mcp-1.4.0a20251010.dist-info/RECORD,,