nvidia-nat 1.3.0a20250929__py3-none-any.whl → 1.3.0a20250930__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51) hide show
  1. nat/agent/base.py +1 -1
  2. nat/agent/rewoo_agent/agent.py +100 -108
  3. nat/agent/rewoo_agent/register.py +4 -1
  4. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +26 -18
  5. nat/builder/builder.py +1 -1
  6. nat/builder/context.py +2 -2
  7. nat/builder/front_end.py +1 -1
  8. nat/cli/cli_utils/config_override.py +1 -1
  9. nat/cli/commands/mcp/mcp.py +2 -2
  10. nat/cli/commands/start.py +1 -1
  11. nat/cli/type_registry.py +1 -1
  12. nat/control_flow/router_agent/register.py +1 -1
  13. nat/data_models/api_server.py +9 -9
  14. nat/data_models/authentication.py +3 -9
  15. nat/data_models/dataset_handler.py +1 -1
  16. nat/eval/evaluator/base_evaluator.py +1 -1
  17. nat/eval/swe_bench_evaluator/evaluate.py +1 -1
  18. nat/eval/tunable_rag_evaluator/evaluate.py +1 -1
  19. nat/experimental/decorators/experimental_warning_decorator.py +1 -2
  20. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -1
  21. nat/front_ends/console/authentication_flow_handler.py +82 -30
  22. nat/front_ends/console/console_front_end_plugin.py +1 -1
  23. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  24. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +188 -2
  25. nat/front_ends/fastapi/job_store.py +2 -2
  26. nat/front_ends/fastapi/message_handler.py +4 -4
  27. nat/front_ends/fastapi/message_validator.py +5 -5
  28. nat/front_ends/mcp/tool_converter.py +1 -1
  29. nat/llm/utils/thinking.py +1 -1
  30. nat/observability/exporter/base_exporter.py +1 -1
  31. nat/observability/exporter/span_exporter.py +1 -1
  32. nat/observability/exporter_manager.py +2 -2
  33. nat/observability/processor/batching_processor.py +1 -1
  34. nat/profiler/decorators/function_tracking.py +2 -2
  35. nat/profiler/parameter_optimization/parameter_selection.py +3 -4
  36. nat/profiler/parameter_optimization/pareto_visualizer.py +1 -1
  37. nat/retriever/milvus/retriever.py +1 -1
  38. nat/settings/global_settings.py +2 -2
  39. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +1 -1
  40. nat/tool/datetime_tools.py +1 -1
  41. nat/utils/data_models/schema_validator.py +1 -1
  42. nat/utils/exception_handlers/automatic_retries.py +1 -1
  43. nat/utils/io/yaml_tools.py +1 -1
  44. nat/utils/type_utils.py +1 -1
  45. {nvidia_nat-1.3.0a20250929.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/METADATA +2 -1
  46. {nvidia_nat-1.3.0a20250929.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/RECORD +51 -51
  47. {nvidia_nat-1.3.0a20250929.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/WHEEL +0 -0
  48. {nvidia_nat-1.3.0a20250929.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/entry_points.txt +0 -0
  49. {nvidia_nat-1.3.0a20250929.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  50. {nvidia_nat-1.3.0a20250929.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/licenses/LICENSE.md +0 -0
  51. {nvidia_nat-1.3.0a20250929.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/top_level.txt +0 -0
@@ -14,8 +14,8 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import typing
17
+ from datetime import UTC
17
18
  from datetime import datetime
18
- from datetime import timezone
19
19
  from enum import Enum
20
20
 
21
21
  import httpx
@@ -166,13 +166,7 @@ class BearerTokenCred(_CredBase):
166
166
 
167
167
 
168
168
  Credential = typing.Annotated[
169
- typing.Union[
170
- HeaderCred,
171
- QueryCred,
172
- CookieCred,
173
- BasicAuthCred,
174
- BearerTokenCred,
175
- ],
169
+ HeaderCred | QueryCred | CookieCred | BasicAuthCred | BearerTokenCred,
176
170
  Field(discriminator="kind"),
177
171
  ]
178
172
 
@@ -213,7 +207,7 @@ class AuthResult(BaseModel):
213
207
  """
214
208
  Checks if the authentication token has expired.
215
209
  """
216
- return bool(self.token_expires_at and datetime.now(timezone.utc) >= self.token_expires_at)
210
+ return bool(self.token_expires_at and datetime.now(UTC) >= self.token_expires_at)
217
211
 
218
212
  def as_requests_kwargs(self) -> dict[str, typing.Any]:
219
213
  """
@@ -80,7 +80,7 @@ class EvalDatasetJsonConfig(EvalDatasetBaseConfig, name="json"):
80
80
 
81
81
 
82
82
  def read_jsonl(file_path: FilePath):
83
- with open(file_path, 'r', encoding='utf-8') as f:
83
+ with open(file_path, encoding='utf-8') as f:
84
84
  data = [json.loads(line) for line in f]
85
85
  return pd.DataFrame(data)
86
86
 
@@ -71,7 +71,7 @@ class BaseEvaluator(ABC):
71
71
  TqdmPositionRegistry.release(tqdm_position)
72
72
 
73
73
  # Compute average if possible
74
- numeric_scores = [item.score for item in output_items if isinstance(item.score, (int, float))]
74
+ numeric_scores = [item.score for item in output_items if isinstance(item.score, int | float)]
75
75
  avg_score = round(sum(numeric_scores) / len(numeric_scores), 2) if numeric_scores else None
76
76
 
77
77
  return EvalOutput(average_score=avg_score, eval_output_items=output_items)
@@ -204,7 +204,7 @@ class SweBenchEvaluator:
204
204
  # if report file is not present, return empty EvalOutput
205
205
  avg_score = 0.0
206
206
  if report_file.exists():
207
- with open(report_file, "r", encoding="utf-8") as f:
207
+ with open(report_file, encoding="utf-8") as f:
208
208
  report = json.load(f)
209
209
  resolved_instances = report.get("resolved_instances", 0)
210
210
  total_instances = report.get("total_instances", 0)
@@ -14,7 +14,7 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import logging
17
- from typing import Callable
17
+ from collections.abc import Callable
18
18
 
19
19
  from langchain.output_parsers import ResponseSchema
20
20
  from langchain.output_parsers import StructuredOutputParser
@@ -137,8 +137,7 @@ def experimental(func: Any = None, *, feature_name: str | None = None, metadata:
137
137
  @functools.wraps(func)
138
138
  def sync_gen_wrapper(*args, **kwargs) -> Generator[Any, Any, Any]:
139
139
  issue_experimental_warning(function_name, feature_name, metadata)
140
- for item in func(*args, **kwargs):
141
- yield item # yield the original item
140
+ yield from func(*args, **kwargs) # yield the original item
142
141
 
143
142
  return sync_gen_wrapper
144
143
 
@@ -71,7 +71,7 @@ class LLMBasedOutputMergingSelector(StrategyBase):
71
71
  raise ImportError("langchain-core is not installed. Please install it to use SingleShotMultiPlanPlanner.\n"
72
72
  "This error can be resolved by installing nvidia-nat-langchain.")
73
73
 
74
- from typing import Callable
74
+ from collections.abc import Callable
75
75
 
76
76
  from pydantic import BaseModel
77
77
 
@@ -14,13 +14,16 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import asyncio
17
+ import logging
17
18
  import secrets
18
19
  import webbrowser
19
20
  from dataclasses import dataclass
20
21
  from dataclasses import field
21
22
 
22
23
  import click
24
+ import httpx
23
25
  import pkce
26
+ from authlib.common.errors import AuthlibBaseError as OAuthError
24
27
  from authlib.integrations.httpx_client import AsyncOAuth2Client
25
28
  from fastapi import FastAPI
26
29
  from fastapi import Request
@@ -32,6 +35,8 @@ from nat.data_models.authentication import AuthFlowType
32
35
  from nat.data_models.authentication import AuthProviderBaseConfig
33
36
  from nat.front_ends.fastapi.fastapi_front_end_controller import _FastApiFrontEndController
34
37
 
38
+ logger = logging.getLogger(__name__)
39
+
35
40
 
36
41
  # --------------------------------------------------------------------------- #
37
42
  # Helpers #
@@ -87,17 +92,53 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
87
92
  """
88
93
  Separated for easy overriding in tests (to inject ASGITransport).
89
94
  """
90
- client = AsyncOAuth2Client(
91
- client_id=cfg.client_id,
92
- client_secret=cfg.client_secret,
93
- redirect_uri=cfg.redirect_uri,
94
- scope=" ".join(cfg.scopes) if cfg.scopes else None,
95
- token_endpoint=cfg.token_url,
96
- token_endpoint_auth_method=cfg.token_endpoint_auth_method,
97
- code_challenge_method="S256" if cfg.use_pkce else None,
98
- )
99
- self._oauth_client = client
100
- return client
95
+ try:
96
+ client = AsyncOAuth2Client(
97
+ client_id=cfg.client_id,
98
+ client_secret=cfg.client_secret,
99
+ redirect_uri=cfg.redirect_uri,
100
+ scope=" ".join(cfg.scopes) if cfg.scopes else None,
101
+ token_endpoint=cfg.token_url,
102
+ token_endpoint_auth_method=cfg.token_endpoint_auth_method,
103
+ code_challenge_method="S256" if cfg.use_pkce else None,
104
+ )
105
+ self._oauth_client = client
106
+ return client
107
+ except (OAuthError, ValueError, TypeError) as e:
108
+ raise RuntimeError(f"Invalid OAuth2 configuration: {e}") from e
109
+ except Exception as e:
110
+ raise RuntimeError(f"Failed to create OAuth2 client: {e}") from e
111
+
112
+ def _create_authorization_url(self,
113
+ client: AsyncOAuth2Client,
114
+ config: OAuth2AuthCodeFlowProviderConfig,
115
+ state: str,
116
+ verifier: str | None = None,
117
+ challenge: str | None = None) -> str:
118
+ """
119
+ Create OAuth authorization URL with proper error handling.
120
+
121
+ Args:
122
+ client: The OAuth2 client instance
123
+ config: OAuth2 configuration
124
+ state: OAuth state parameter
125
+ verifier: PKCE verifier (if using PKCE)
126
+ challenge: PKCE challenge (if using PKCE)
127
+
128
+ Returns:
129
+ The authorization URL
130
+ """
131
+ try:
132
+ auth_url, _ = client.create_authorization_url(
133
+ config.authorization_url,
134
+ state=state,
135
+ code_verifier=verifier if config.use_pkce else None,
136
+ code_challenge=challenge if config.use_pkce else None,
137
+ **(config.authorization_kwargs or {})
138
+ )
139
+ return auth_url
140
+ except (OAuthError, ValueError, TypeError) as e:
141
+ raise RuntimeError(f"Error creating OAuth authorization URL: {e}") from e
101
142
 
102
143
  # --------------------------- HTTP Basic ------------------------------ #
103
144
  @staticmethod
@@ -131,13 +172,12 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
131
172
  flow_state.verifier = verifier
132
173
  flow_state.challenge = challenge
133
174
 
134
- auth_url, _ = client.create_authorization_url(
135
- cfg.authorization_url,
136
- state=state,
137
- code_verifier=flow_state.verifier if cfg.use_pkce else None,
138
- code_challenge=flow_state.challenge if cfg.use_pkce else None,
139
- **(cfg.authorization_kwargs or {})
140
- )
175
+ # Create authorization URL using helper function
176
+ auth_url = self._create_authorization_url(client=client,
177
+ config=cfg,
178
+ state=state,
179
+ verifier=flow_state.verifier,
180
+ challenge=flow_state.challenge)
141
181
 
142
182
  # Register flow + maybe spin up redirect handler
143
183
  async with self._server_lock:
@@ -149,14 +189,18 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
149
189
  self._flows[state] = flow_state
150
190
  self._active_flows += 1
151
191
 
152
- click.echo("Your browser has been opened for authentication.")
153
- webbrowser.open(auth_url)
192
+ try:
193
+ webbrowser.open(auth_url)
194
+ click.echo("Your browser has been opened for authentication.")
195
+ except Exception as e:
196
+ logger.error("Browser open failed: %s", e)
197
+ raise RuntimeError(f"Browser open failed: {e}") from e
154
198
 
155
199
  # Wait for the redirect to land
156
200
  try:
157
201
  token = await asyncio.wait_for(flow_state.future, timeout=300)
158
- except asyncio.TimeoutError:
159
- raise RuntimeError("Authentication timed out (5 min).")
202
+ except TimeoutError as exc:
203
+ raise RuntimeError("Authentication timed out (5 min).") from exc
160
204
  finally:
161
205
  async with self._server_lock:
162
206
  self._flows.pop(state, None)
@@ -175,9 +219,9 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
175
219
  # --------------- redirect server / in‑process app -------------------- #
176
220
  async def _build_redirect_app(self) -> FastAPI:
177
221
  """
178
- * If cfg.run_redirect_local_server == True → start a uvicorn server (old behaviour).
179
- * Else → only build the FastAPI app and save it to `self._redirect_app`
180
- for in‑process testing with ASGITransport.
222
+ * If cfg.run_redirect_local_server == True → start a local server.
223
+ * Else → only build the redirect app and save it to `self._redirect_app`
224
+ for in‑process testing.
181
225
  """
182
226
  app = FastAPI()
183
227
 
@@ -195,8 +239,16 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
195
239
  state=state,
196
240
  )
197
241
  flow_state.future.set_result(token)
198
- except Exception as exc: # noqa: BLE001
199
- flow_state.future.set_exception(exc)
242
+ except OAuthError as e:
243
+ flow_state.future.set_exception(
244
+ RuntimeError(f"Authorization server rejected request: {e.error} ({e.description})"))
245
+ return "Authentication failed: Authorization server rejected the request. You may close this tab."
246
+ except httpx.HTTPError as e:
247
+ flow_state.future.set_exception(RuntimeError(f"Network error during token fetch: {e}"))
248
+ return "Authentication failed: Network error occurred. You may close this tab."
249
+ except Exception as e:
250
+ flow_state.future.set_exception(RuntimeError(f"Authentication failed: {e}"))
251
+ return "Authentication failed: An unexpected error occurred. You may close this tab."
200
252
  return "Authentication successful – you may close this tab."
201
253
 
202
254
  return app
@@ -213,7 +265,7 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
213
265
 
214
266
  asyncio.create_task(self._server_controller.start_server(host="localhost", port=8000))
215
267
 
216
- # Give uvicorn a moment to bind sockets before we return
268
+ # Give the server a moment to bind sockets before we return
217
269
  await asyncio.sleep(0.3)
218
270
  except Exception as exc: # noqa: BLE001
219
271
  raise RuntimeError(f"Failed to start redirect server: {exc}") from exc
@@ -227,7 +279,7 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
227
279
  @property
228
280
  def redirect_app(self) -> FastAPI | None:
229
281
  """
230
- In testmode (run_redirect_local_server=False) the in‑memory FastAPI
231
- app is exposed so you can mount it on `httpx.ASGITransport`.
282
+ In test mode (run_redirect_local_server=False) the in‑memory redirect
283
+ app is exposed for testing purposes.
232
284
  """
233
285
  return self._redirect_app
@@ -88,7 +88,7 @@ class ConsoleFrontEndPlugin(SimpleFrontEndPluginBase[ConsoleFrontEndConfig]):
88
88
  elif (self.front_end_config.input_file):
89
89
 
90
90
  # Run the workflow
91
- with open(self.front_end_config.input_file, "r", encoding="utf-8") as f:
91
+ with open(self.front_end_config.input_file, encoding="utf-8") as f:
92
92
 
93
93
  async with session_manager.workflow.run(f) as runner:
94
94
  runner_outputs = await runner.result(to_type=str)
@@ -22,6 +22,7 @@ from dataclasses import dataclass
22
22
  from dataclasses import field
23
23
 
24
24
  import pkce
25
+ from authlib.common.errors import AuthlibBaseError as OAuthError
25
26
  from authlib.integrations.httpx_client import AsyncOAuth2Client
26
27
 
27
28
  from nat.authentication.interfaces import FlowHandlerBase
@@ -61,14 +62,50 @@ class WebSocketAuthenticationFlowHandler(FlowHandlerBase):
61
62
 
62
63
  raise NotImplementedError(f"Authentication method '{method}' is not supported by the websocket frontend.")
63
64
 
64
- def create_oauth_client(self, config: OAuth2AuthCodeFlowProviderConfig):
65
- return AsyncOAuth2Client(client_id=config.client_id,
66
- client_secret=config.client_secret,
67
- redirect_uri=config.redirect_uri,
68
- scope=" ".join(config.scopes) if config.scopes else None,
69
- token_endpoint=config.token_url,
70
- code_challenge_method='S256' if config.use_pkce else None,
71
- token_endpoint_auth_method=config.token_endpoint_auth_method)
65
+ def create_oauth_client(self, config: OAuth2AuthCodeFlowProviderConfig) -> AsyncOAuth2Client:
66
+ try:
67
+ return AsyncOAuth2Client(client_id=config.client_id,
68
+ client_secret=config.client_secret,
69
+ redirect_uri=config.redirect_uri,
70
+ scope=" ".join(config.scopes) if config.scopes else None,
71
+ token_endpoint=config.token_url,
72
+ code_challenge_method='S256' if config.use_pkce else None,
73
+ token_endpoint_auth_method=config.token_endpoint_auth_method)
74
+ except (OAuthError, ValueError, TypeError) as e:
75
+ raise RuntimeError(f"Invalid OAuth2 configuration: {e}") from e
76
+ except Exception as e:
77
+ raise RuntimeError(f"Failed to create OAuth2 client: {e}") from e
78
+
79
+ def _create_authorization_url(self,
80
+ client: AsyncOAuth2Client,
81
+ config: OAuth2AuthCodeFlowProviderConfig,
82
+ state: str,
83
+ verifier: str = None,
84
+ challenge: str = None) -> str:
85
+ """
86
+ Create OAuth authorization URL with proper error handling.
87
+
88
+ Args:
89
+ client: The OAuth2 client instance
90
+ config: OAuth2 configuration
91
+ state: OAuth state parameter
92
+ verifier: PKCE verifier (if using PKCE)
93
+ challenge: PKCE challenge (if using PKCE)
94
+
95
+ Returns:
96
+ The authorization URL
97
+ """
98
+ try:
99
+ authorization_url, _ = client.create_authorization_url(
100
+ config.authorization_url,
101
+ state=state,
102
+ code_verifier=verifier if config.use_pkce else None,
103
+ code_challenge=challenge if config.use_pkce else None,
104
+ **(config.authorization_kwargs or {})
105
+ )
106
+ return authorization_url
107
+ except (OAuthError, ValueError, TypeError) as e:
108
+ raise RuntimeError(f"Error creating OAuth authorization URL: {e}") from e
72
109
 
73
110
  async def _handle_oauth2_auth_code_flow(self, config: OAuth2AuthCodeFlowProviderConfig) -> AuthenticatedContext:
74
111
 
@@ -82,21 +119,19 @@ class WebSocketAuthenticationFlowHandler(FlowHandlerBase):
82
119
  flow_state.verifier = verifier
83
120
  flow_state.challenge = challenge
84
121
 
85
- authorization_url, _ = flow_state.client.create_authorization_url(
86
- config.authorization_url,
87
- state=state,
88
- code_verifier=flow_state.verifier if config.use_pkce else None,
89
- code_challenge=flow_state.challenge if config.use_pkce else None,
90
- **(config.authorization_kwargs or {})
91
- )
122
+ authorization_url = self._create_authorization_url(client=flow_state.client,
123
+ config=config,
124
+ state=state,
125
+ verifier=flow_state.verifier,
126
+ challenge=flow_state.challenge)
92
127
 
93
128
  await self._add_flow_cb(state, flow_state)
94
129
  await self._web_socket_message_handler.create_websocket_message(_HumanPromptOAuthConsent(text=authorization_url)
95
130
  )
96
131
  try:
97
132
  token = await asyncio.wait_for(flow_state.future, timeout=300)
98
- except asyncio.TimeoutError:
99
- raise RuntimeError("Authentication flow timed out after 5 minutes.")
133
+ except TimeoutError as exc:
134
+ raise RuntimeError("Authentication flow timed out after 5 minutes.") from exc
100
135
  finally:
101
136
 
102
137
  await self._remove_flow_cb(state)
@@ -25,18 +25,21 @@ from collections.abc import Callable
25
25
  from contextlib import asynccontextmanager
26
26
  from pathlib import Path
27
27
 
28
+ import httpx
29
+ from authlib.common.errors import AuthlibBaseError as OAuthError
28
30
  from fastapi import Body
29
31
  from fastapi import FastAPI
32
+ from fastapi import HTTPException
30
33
  from fastapi import Request
31
34
  from fastapi import Response
32
35
  from fastapi import UploadFile
33
- from fastapi.exceptions import HTTPException
34
36
  from fastapi.middleware.cors import CORSMiddleware
35
37
  from fastapi.responses import StreamingResponse
36
38
  from pydantic import BaseModel
37
39
  from pydantic import Field
38
40
  from starlette.websockets import WebSocket
39
41
 
42
+ from nat.builder.function import Function
40
43
  from nat.builder.workflow_builder import WorkflowBuilder
41
44
  from nat.data_models.api_server import ChatRequest
42
45
  from nat.data_models.api_server import ChatResponse
@@ -241,6 +244,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
241
244
  await self.add_evaluate_route(app, SessionManager(await builder.build()))
242
245
  await self.add_static_files_route(app, builder)
243
246
  await self.add_authorization_route(app)
247
+ await self.add_mcp_client_tool_list_route(app, builder)
244
248
 
245
249
  for ep in self.front_end_config.endpoints:
246
250
 
@@ -1071,8 +1075,13 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
1071
1075
  code_verifier=verifier,
1072
1076
  state=state)
1073
1077
  flow_state.future.set_result(res)
1078
+ except OAuthError as e:
1079
+ flow_state.future.set_exception(
1080
+ RuntimeError(f"Authorization server rejected request: {e.error} ({e.description})"))
1081
+ except httpx.HTTPError as e:
1082
+ flow_state.future.set_exception(RuntimeError(f"Network error during token fetch: {e}"))
1074
1083
  except Exception as e:
1075
- flow_state.future.set_exception(e)
1084
+ flow_state.future.set_exception(RuntimeError(f"Authentication failed: {e}"))
1076
1085
 
1077
1086
  return HTMLResponse(content=AUTH_REDIRECT_SUCCESS_HTML,
1078
1087
  status_code=200,
@@ -1088,6 +1097,183 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
1088
1097
  methods=["GET"],
1089
1098
  description="Handles the authorization code and state returned from the Authorization Code Grant Flow.")
1090
1099
 
1100
+ async def add_mcp_client_tool_list_route(self, app: FastAPI, builder: WorkflowBuilder):
1101
+ """Add the MCP client tool list endpoint to the FastAPI app."""
1102
+ from typing import Any
1103
+
1104
+ from pydantic import BaseModel
1105
+
1106
+ class MCPToolInfo(BaseModel):
1107
+ name: str
1108
+ description: str
1109
+ server: str
1110
+ available: bool
1111
+
1112
+ class MCPClientToolListResponse(BaseModel):
1113
+ mcp_clients: list[dict[str, Any]]
1114
+
1115
+ async def get_mcp_client_tool_list() -> MCPClientToolListResponse:
1116
+ """
1117
+ Get the list of MCP tools from all MCP clients in the workflow configuration.
1118
+ Checks session health and compares with workflow function group configuration.
1119
+ """
1120
+ mcp_clients_info = []
1121
+
1122
+ try:
1123
+ # Get all function groups from the builder
1124
+ function_groups = builder._function_groups
1125
+
1126
+ # Find MCP client function groups
1127
+ for group_name, configured_group in function_groups.items():
1128
+ if configured_group.config.type != "mcp_client":
1129
+ continue
1130
+
1131
+ from nat.plugins.mcp.client_impl import MCPClientConfig
1132
+
1133
+ config = configured_group.config
1134
+ assert isinstance(config, MCPClientConfig)
1135
+
1136
+ # Reuse the existing MCP client session stored on the function group instance
1137
+ group_instance = configured_group.instance
1138
+
1139
+ client = group_instance.mcp_client
1140
+ if client is None:
1141
+ raise RuntimeError(f"MCP client not found for group {group_name}")
1142
+
1143
+ try:
1144
+ session_healthy = False
1145
+ server_tools: dict[str, Any] = {}
1146
+
1147
+ try:
1148
+ server_tools = await client.get_tools()
1149
+ session_healthy = True
1150
+ except Exception as e:
1151
+ logger.exception(f"Failed to connect to MCP server {client.server_name}: {e}")
1152
+ session_healthy = False
1153
+
1154
+ # Get workflow function group configuration (configured client-side tools)
1155
+ configured_short_names: set[str] = set()
1156
+ configured_full_to_fn: dict[str, Function] = {}
1157
+ try:
1158
+ # Pass a no-op filter function to bypass any default filtering that might check
1159
+ # health status, preventing potential infinite recursion during health status checks.
1160
+ async def pass_through_filter(fn):
1161
+ return fn
1162
+
1163
+ accessible_functions = await group_instance.get_accessible_functions(
1164
+ filter_fn=pass_through_filter)
1165
+ configured_full_to_fn = accessible_functions
1166
+ configured_short_names = {name.split('.', 1)[1] for name in accessible_functions.keys()}
1167
+ except Exception as e:
1168
+ logger.exception(f"Failed to get accessible functions for group {group_name}: {e}")
1169
+
1170
+ # Build alias->original mapping and override configs from overrides
1171
+ alias_to_original: dict[str, str] = {}
1172
+ override_configs: dict[str, Any] = {}
1173
+ try:
1174
+ if config.tool_overrides is not None:
1175
+ for orig_name, override in config.tool_overrides.items():
1176
+ if override.alias is not None:
1177
+ alias_to_original[override.alias] = orig_name
1178
+ override_configs[override.alias] = override
1179
+ else:
1180
+ override_configs[orig_name] = override
1181
+ except Exception:
1182
+ pass
1183
+
1184
+ # Create tool info list (always return configured tools; mark availability)
1185
+ tools_info: list[dict[str, Any]] = []
1186
+ available_count = 0
1187
+ for wf_fn, fn_short in zip(configured_full_to_fn.values(), configured_short_names):
1188
+ orig_name = alias_to_original.get(fn_short, fn_short)
1189
+ available = session_healthy and (orig_name in server_tools)
1190
+ if available:
1191
+ available_count += 1
1192
+
1193
+ # Prefer tool override description, then workflow function description,
1194
+ # then server description
1195
+ description = ""
1196
+ if fn_short in override_configs and override_configs[fn_short].description:
1197
+ description = override_configs[fn_short].description
1198
+ elif wf_fn.description:
1199
+ description = wf_fn.description
1200
+ elif available and orig_name in server_tools:
1201
+ description = server_tools[orig_name].description or ""
1202
+
1203
+ tools_info.append(
1204
+ MCPToolInfo(name=fn_short,
1205
+ description=description or "",
1206
+ server=client.server_name,
1207
+ available=available).model_dump())
1208
+
1209
+ # Sort tools_info by name to maintain consistent ordering
1210
+ tools_info.sort(key=lambda x: x['name'])
1211
+
1212
+ mcp_clients_info.append({
1213
+ "function_group": group_name,
1214
+ "server": client.server_name,
1215
+ "transport": config.server.transport,
1216
+ "session_healthy": session_healthy,
1217
+ "tools": tools_info,
1218
+ "total_tools": len(configured_short_names),
1219
+ "available_tools": available_count
1220
+ })
1221
+
1222
+ except Exception as e:
1223
+ logger.error(f"Error processing MCP client {group_name}: {e}")
1224
+ mcp_clients_info.append({
1225
+ "function_group": group_name,
1226
+ "server": "unknown",
1227
+ "transport": config.server.transport if config.server else "unknown",
1228
+ "session_healthy": False,
1229
+ "error": str(e),
1230
+ "tools": [],
1231
+ "total_tools": 0,
1232
+ "workflow_tools": 0
1233
+ })
1234
+
1235
+ return MCPClientToolListResponse(mcp_clients=mcp_clients_info)
1236
+
1237
+ except Exception as e:
1238
+ logger.error(f"Error in MCP client tool list endpoint: {e}")
1239
+ raise HTTPException(status_code=500, detail=f"Failed to retrieve MCP client information: {str(e)}")
1240
+
1241
+ # Add the route to the FastAPI app
1242
+ app.add_api_route(
1243
+ path="/mcp/client/tool/list",
1244
+ endpoint=get_mcp_client_tool_list,
1245
+ methods=["GET"],
1246
+ response_model=MCPClientToolListResponse,
1247
+ description="Get list of MCP client tools with session health and workflow configuration comparison",
1248
+ responses={
1249
+ 200: {
1250
+ "description": "Successfully retrieved MCP client tool information",
1251
+ "content": {
1252
+ "application/json": {
1253
+ "example": {
1254
+ "mcp_clients": [{
1255
+ "function_group": "mcp_tools",
1256
+ "server": "streamable-http:http://localhost:9901/mcp",
1257
+ "transport": "streamable-http",
1258
+ "session_healthy": True,
1259
+ "tools": [{
1260
+ "name": "tool_a",
1261
+ "description": "Tool A description",
1262
+ "server": "streamable-http:http://localhost:9901/mcp",
1263
+ "available": True
1264
+ }],
1265
+ "total_tools": 1,
1266
+ "available_tools": 1
1267
+ }]
1268
+ }
1269
+ }
1270
+ }
1271
+ },
1272
+ 500: {
1273
+ "description": "Internal Server Error"
1274
+ }
1275
+ })
1276
+
1091
1277
  async def _add_flow(self, state: str, flow_state: FlowState):
1092
1278
  async with self._outstanding_flows_lock:
1093
1279
  self._outstanding_flows[state] = flow_state
@@ -370,7 +370,7 @@ class JobStore(DaskClientMixin):
370
370
  # Convert BaseModel to JSON string for storage
371
371
  output = output.model_dump_json(round_trip=True)
372
372
 
373
- if isinstance(output, (dict, list)):
373
+ if isinstance(output, dict | list):
374
374
  # Convert dict or list to JSON string for storage
375
375
  output = json.dumps(output)
376
376
 
@@ -555,7 +555,7 @@ class JobStore(DaskClientMixin):
555
555
  logger.exception("Failed to expire %s", job_id)
556
556
 
557
557
  await session.execute(
558
- (update(JobInfo).where(JobInfo.job_id.in_(successfully_expired)).values(is_expired=True)))
558
+ update(JobInfo).where(JobInfo.job_id.in_(successfully_expired)).values(is_expired=True))
559
559
 
560
560
 
561
561
  def get_db_engine(db_url: str | None = None, echo: bool = False, use_async: bool = True) -> "Engine | AsyncEngine":
@@ -105,10 +105,10 @@ class WebSocketMessageHandler:
105
105
  if (isinstance(validated_message, WebSocketUserMessage)):
106
106
  await self.process_workflow_request(validated_message)
107
107
 
108
- elif isinstance(validated_message,
109
- (WebSocketSystemResponseTokenMessage,
110
- WebSocketSystemIntermediateStepMessage,
111
- WebSocketSystemInteractionMessage)):
108
+ elif isinstance(
109
+ validated_message,
110
+ WebSocketSystemResponseTokenMessage | WebSocketSystemIntermediateStepMessage
111
+ | WebSocketSystemInteractionMessage):
112
112
  # These messages are already handled by self.create_websocket_message(data_model=value, …)
113
113
  # No further processing is needed here.
114
114
  pass