nvidia-nat 1.3.0a20250929__py3-none-any.whl → 1.3.0a20250930__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/base.py +1 -1
- nat/agent/rewoo_agent/agent.py +100 -108
- nat/agent/rewoo_agent/register.py +4 -1
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +26 -18
- nat/builder/builder.py +1 -1
- nat/builder/context.py +2 -2
- nat/builder/front_end.py +1 -1
- nat/cli/cli_utils/config_override.py +1 -1
- nat/cli/commands/mcp/mcp.py +2 -2
- nat/cli/commands/start.py +1 -1
- nat/cli/type_registry.py +1 -1
- nat/control_flow/router_agent/register.py +1 -1
- nat/data_models/api_server.py +9 -9
- nat/data_models/authentication.py +3 -9
- nat/data_models/dataset_handler.py +1 -1
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/swe_bench_evaluator/evaluate.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +1 -1
- nat/experimental/decorators/experimental_warning_decorator.py +1 -2
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -1
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +1 -1
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +188 -2
- nat/front_ends/fastapi/job_store.py +2 -2
- nat/front_ends/fastapi/message_handler.py +4 -4
- nat/front_ends/fastapi/message_validator.py +5 -5
- nat/front_ends/mcp/tool_converter.py +1 -1
- nat/llm/utils/thinking.py +1 -1
- nat/observability/exporter/base_exporter.py +1 -1
- nat/observability/exporter/span_exporter.py +1 -1
- nat/observability/exporter_manager.py +2 -2
- nat/observability/processor/batching_processor.py +1 -1
- nat/profiler/decorators/function_tracking.py +2 -2
- nat/profiler/parameter_optimization/parameter_selection.py +3 -4
- nat/profiler/parameter_optimization/pareto_visualizer.py +1 -1
- nat/retriever/milvus/retriever.py +1 -1
- nat/settings/global_settings.py +2 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +1 -1
- nat/tool/datetime_tools.py +1 -1
- nat/utils/data_models/schema_validator.py +1 -1
- nat/utils/exception_handlers/automatic_retries.py +1 -1
- nat/utils/io/yaml_tools.py +1 -1
- nat/utils/type_utils.py +1 -1
- {nvidia_nat-1.3.0a20250929.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/METADATA +2 -1
- {nvidia_nat-1.3.0a20250929.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/RECORD +51 -51
- {nvidia_nat-1.3.0a20250929.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0a20250929.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.3.0a20250929.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0a20250929.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0a20250929.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/top_level.txt +0 -0
|
@@ -14,8 +14,8 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import typing
|
|
17
|
+
from datetime import UTC
|
|
17
18
|
from datetime import datetime
|
|
18
|
-
from datetime import timezone
|
|
19
19
|
from enum import Enum
|
|
20
20
|
|
|
21
21
|
import httpx
|
|
@@ -166,13 +166,7 @@ class BearerTokenCred(_CredBase):
|
|
|
166
166
|
|
|
167
167
|
|
|
168
168
|
Credential = typing.Annotated[
|
|
169
|
-
|
|
170
|
-
HeaderCred,
|
|
171
|
-
QueryCred,
|
|
172
|
-
CookieCred,
|
|
173
|
-
BasicAuthCred,
|
|
174
|
-
BearerTokenCred,
|
|
175
|
-
],
|
|
169
|
+
HeaderCred | QueryCred | CookieCred | BasicAuthCred | BearerTokenCred,
|
|
176
170
|
Field(discriminator="kind"),
|
|
177
171
|
]
|
|
178
172
|
|
|
@@ -213,7 +207,7 @@ class AuthResult(BaseModel):
|
|
|
213
207
|
"""
|
|
214
208
|
Checks if the authentication token has expired.
|
|
215
209
|
"""
|
|
216
|
-
return bool(self.token_expires_at and datetime.now(
|
|
210
|
+
return bool(self.token_expires_at and datetime.now(UTC) >= self.token_expires_at)
|
|
217
211
|
|
|
218
212
|
def as_requests_kwargs(self) -> dict[str, typing.Any]:
|
|
219
213
|
"""
|
|
@@ -80,7 +80,7 @@ class EvalDatasetJsonConfig(EvalDatasetBaseConfig, name="json"):
|
|
|
80
80
|
|
|
81
81
|
|
|
82
82
|
def read_jsonl(file_path: FilePath):
|
|
83
|
-
with open(file_path,
|
|
83
|
+
with open(file_path, encoding='utf-8') as f:
|
|
84
84
|
data = [json.loads(line) for line in f]
|
|
85
85
|
return pd.DataFrame(data)
|
|
86
86
|
|
|
@@ -71,7 +71,7 @@ class BaseEvaluator(ABC):
|
|
|
71
71
|
TqdmPositionRegistry.release(tqdm_position)
|
|
72
72
|
|
|
73
73
|
# Compute average if possible
|
|
74
|
-
numeric_scores = [item.score for item in output_items if isinstance(item.score,
|
|
74
|
+
numeric_scores = [item.score for item in output_items if isinstance(item.score, int | float)]
|
|
75
75
|
avg_score = round(sum(numeric_scores) / len(numeric_scores), 2) if numeric_scores else None
|
|
76
76
|
|
|
77
77
|
return EvalOutput(average_score=avg_score, eval_output_items=output_items)
|
|
@@ -204,7 +204,7 @@ class SweBenchEvaluator:
|
|
|
204
204
|
# if report file is not present, return empty EvalOutput
|
|
205
205
|
avg_score = 0.0
|
|
206
206
|
if report_file.exists():
|
|
207
|
-
with open(report_file,
|
|
207
|
+
with open(report_file, encoding="utf-8") as f:
|
|
208
208
|
report = json.load(f)
|
|
209
209
|
resolved_instances = report.get("resolved_instances", 0)
|
|
210
210
|
total_instances = report.get("total_instances", 0)
|
|
@@ -137,8 +137,7 @@ def experimental(func: Any = None, *, feature_name: str | None = None, metadata:
|
|
|
137
137
|
@functools.wraps(func)
|
|
138
138
|
def sync_gen_wrapper(*args, **kwargs) -> Generator[Any, Any, Any]:
|
|
139
139
|
issue_experimental_warning(function_name, feature_name, metadata)
|
|
140
|
-
|
|
141
|
-
yield item # yield the original item
|
|
140
|
+
yield from func(*args, **kwargs) # yield the original item
|
|
142
141
|
|
|
143
142
|
return sync_gen_wrapper
|
|
144
143
|
|
|
@@ -71,7 +71,7 @@ class LLMBasedOutputMergingSelector(StrategyBase):
|
|
|
71
71
|
raise ImportError("langchain-core is not installed. Please install it to use SingleShotMultiPlanPlanner.\n"
|
|
72
72
|
"This error can be resolved by installing nvidia-nat-langchain.")
|
|
73
73
|
|
|
74
|
-
from
|
|
74
|
+
from collections.abc import Callable
|
|
75
75
|
|
|
76
76
|
from pydantic import BaseModel
|
|
77
77
|
|
|
@@ -14,13 +14,16 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import asyncio
|
|
17
|
+
import logging
|
|
17
18
|
import secrets
|
|
18
19
|
import webbrowser
|
|
19
20
|
from dataclasses import dataclass
|
|
20
21
|
from dataclasses import field
|
|
21
22
|
|
|
22
23
|
import click
|
|
24
|
+
import httpx
|
|
23
25
|
import pkce
|
|
26
|
+
from authlib.common.errors import AuthlibBaseError as OAuthError
|
|
24
27
|
from authlib.integrations.httpx_client import AsyncOAuth2Client
|
|
25
28
|
from fastapi import FastAPI
|
|
26
29
|
from fastapi import Request
|
|
@@ -32,6 +35,8 @@ from nat.data_models.authentication import AuthFlowType
|
|
|
32
35
|
from nat.data_models.authentication import AuthProviderBaseConfig
|
|
33
36
|
from nat.front_ends.fastapi.fastapi_front_end_controller import _FastApiFrontEndController
|
|
34
37
|
|
|
38
|
+
logger = logging.getLogger(__name__)
|
|
39
|
+
|
|
35
40
|
|
|
36
41
|
# --------------------------------------------------------------------------- #
|
|
37
42
|
# Helpers #
|
|
@@ -87,17 +92,53 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
87
92
|
"""
|
|
88
93
|
Separated for easy overriding in tests (to inject ASGITransport).
|
|
89
94
|
"""
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
95
|
+
try:
|
|
96
|
+
client = AsyncOAuth2Client(
|
|
97
|
+
client_id=cfg.client_id,
|
|
98
|
+
client_secret=cfg.client_secret,
|
|
99
|
+
redirect_uri=cfg.redirect_uri,
|
|
100
|
+
scope=" ".join(cfg.scopes) if cfg.scopes else None,
|
|
101
|
+
token_endpoint=cfg.token_url,
|
|
102
|
+
token_endpoint_auth_method=cfg.token_endpoint_auth_method,
|
|
103
|
+
code_challenge_method="S256" if cfg.use_pkce else None,
|
|
104
|
+
)
|
|
105
|
+
self._oauth_client = client
|
|
106
|
+
return client
|
|
107
|
+
except (OAuthError, ValueError, TypeError) as e:
|
|
108
|
+
raise RuntimeError(f"Invalid OAuth2 configuration: {e}") from e
|
|
109
|
+
except Exception as e:
|
|
110
|
+
raise RuntimeError(f"Failed to create OAuth2 client: {e}") from e
|
|
111
|
+
|
|
112
|
+
def _create_authorization_url(self,
|
|
113
|
+
client: AsyncOAuth2Client,
|
|
114
|
+
config: OAuth2AuthCodeFlowProviderConfig,
|
|
115
|
+
state: str,
|
|
116
|
+
verifier: str | None = None,
|
|
117
|
+
challenge: str | None = None) -> str:
|
|
118
|
+
"""
|
|
119
|
+
Create OAuth authorization URL with proper error handling.
|
|
120
|
+
|
|
121
|
+
Args:
|
|
122
|
+
client: The OAuth2 client instance
|
|
123
|
+
config: OAuth2 configuration
|
|
124
|
+
state: OAuth state parameter
|
|
125
|
+
verifier: PKCE verifier (if using PKCE)
|
|
126
|
+
challenge: PKCE challenge (if using PKCE)
|
|
127
|
+
|
|
128
|
+
Returns:
|
|
129
|
+
The authorization URL
|
|
130
|
+
"""
|
|
131
|
+
try:
|
|
132
|
+
auth_url, _ = client.create_authorization_url(
|
|
133
|
+
config.authorization_url,
|
|
134
|
+
state=state,
|
|
135
|
+
code_verifier=verifier if config.use_pkce else None,
|
|
136
|
+
code_challenge=challenge if config.use_pkce else None,
|
|
137
|
+
**(config.authorization_kwargs or {})
|
|
138
|
+
)
|
|
139
|
+
return auth_url
|
|
140
|
+
except (OAuthError, ValueError, TypeError) as e:
|
|
141
|
+
raise RuntimeError(f"Error creating OAuth authorization URL: {e}") from e
|
|
101
142
|
|
|
102
143
|
# --------------------------- HTTP Basic ------------------------------ #
|
|
103
144
|
@staticmethod
|
|
@@ -131,13 +172,12 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
131
172
|
flow_state.verifier = verifier
|
|
132
173
|
flow_state.challenge = challenge
|
|
133
174
|
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
)
|
|
175
|
+
# Create authorization URL using helper function
|
|
176
|
+
auth_url = self._create_authorization_url(client=client,
|
|
177
|
+
config=cfg,
|
|
178
|
+
state=state,
|
|
179
|
+
verifier=flow_state.verifier,
|
|
180
|
+
challenge=flow_state.challenge)
|
|
141
181
|
|
|
142
182
|
# Register flow + maybe spin up redirect handler
|
|
143
183
|
async with self._server_lock:
|
|
@@ -149,14 +189,18 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
149
189
|
self._flows[state] = flow_state
|
|
150
190
|
self._active_flows += 1
|
|
151
191
|
|
|
152
|
-
|
|
153
|
-
|
|
192
|
+
try:
|
|
193
|
+
webbrowser.open(auth_url)
|
|
194
|
+
click.echo("Your browser has been opened for authentication.")
|
|
195
|
+
except Exception as e:
|
|
196
|
+
logger.error("Browser open failed: %s", e)
|
|
197
|
+
raise RuntimeError(f"Browser open failed: {e}") from e
|
|
154
198
|
|
|
155
199
|
# Wait for the redirect to land
|
|
156
200
|
try:
|
|
157
201
|
token = await asyncio.wait_for(flow_state.future, timeout=300)
|
|
158
|
-
except
|
|
159
|
-
raise RuntimeError("Authentication timed out (5 min).")
|
|
202
|
+
except TimeoutError as exc:
|
|
203
|
+
raise RuntimeError("Authentication timed out (5 min).") from exc
|
|
160
204
|
finally:
|
|
161
205
|
async with self._server_lock:
|
|
162
206
|
self._flows.pop(state, None)
|
|
@@ -175,9 +219,9 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
175
219
|
# --------------- redirect server / in‑process app -------------------- #
|
|
176
220
|
async def _build_redirect_app(self) -> FastAPI:
|
|
177
221
|
"""
|
|
178
|
-
* If cfg.run_redirect_local_server == True → start a
|
|
179
|
-
* Else → only build the
|
|
180
|
-
for in‑process testing
|
|
222
|
+
* If cfg.run_redirect_local_server == True → start a local server.
|
|
223
|
+
* Else → only build the redirect app and save it to `self._redirect_app`
|
|
224
|
+
for in‑process testing.
|
|
181
225
|
"""
|
|
182
226
|
app = FastAPI()
|
|
183
227
|
|
|
@@ -195,8 +239,16 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
195
239
|
state=state,
|
|
196
240
|
)
|
|
197
241
|
flow_state.future.set_result(token)
|
|
198
|
-
except
|
|
199
|
-
flow_state.future.set_exception(
|
|
242
|
+
except OAuthError as e:
|
|
243
|
+
flow_state.future.set_exception(
|
|
244
|
+
RuntimeError(f"Authorization server rejected request: {e.error} ({e.description})"))
|
|
245
|
+
return "Authentication failed: Authorization server rejected the request. You may close this tab."
|
|
246
|
+
except httpx.HTTPError as e:
|
|
247
|
+
flow_state.future.set_exception(RuntimeError(f"Network error during token fetch: {e}"))
|
|
248
|
+
return "Authentication failed: Network error occurred. You may close this tab."
|
|
249
|
+
except Exception as e:
|
|
250
|
+
flow_state.future.set_exception(RuntimeError(f"Authentication failed: {e}"))
|
|
251
|
+
return "Authentication failed: An unexpected error occurred. You may close this tab."
|
|
200
252
|
return "Authentication successful – you may close this tab."
|
|
201
253
|
|
|
202
254
|
return app
|
|
@@ -213,7 +265,7 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
213
265
|
|
|
214
266
|
asyncio.create_task(self._server_controller.start_server(host="localhost", port=8000))
|
|
215
267
|
|
|
216
|
-
# Give
|
|
268
|
+
# Give the server a moment to bind sockets before we return
|
|
217
269
|
await asyncio.sleep(0.3)
|
|
218
270
|
except Exception as exc: # noqa: BLE001
|
|
219
271
|
raise RuntimeError(f"Failed to start redirect server: {exc}") from exc
|
|
@@ -227,7 +279,7 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
227
279
|
@property
|
|
228
280
|
def redirect_app(self) -> FastAPI | None:
|
|
229
281
|
"""
|
|
230
|
-
In
|
|
231
|
-
app is exposed
|
|
282
|
+
In test mode (run_redirect_local_server=False) the in‑memory redirect
|
|
283
|
+
app is exposed for testing purposes.
|
|
232
284
|
"""
|
|
233
285
|
return self._redirect_app
|
|
@@ -88,7 +88,7 @@ class ConsoleFrontEndPlugin(SimpleFrontEndPluginBase[ConsoleFrontEndConfig]):
|
|
|
88
88
|
elif (self.front_end_config.input_file):
|
|
89
89
|
|
|
90
90
|
# Run the workflow
|
|
91
|
-
with open(self.front_end_config.input_file,
|
|
91
|
+
with open(self.front_end_config.input_file, encoding="utf-8") as f:
|
|
92
92
|
|
|
93
93
|
async with session_manager.workflow.run(f) as runner:
|
|
94
94
|
runner_outputs = await runner.result(to_type=str)
|
|
@@ -22,6 +22,7 @@ from dataclasses import dataclass
|
|
|
22
22
|
from dataclasses import field
|
|
23
23
|
|
|
24
24
|
import pkce
|
|
25
|
+
from authlib.common.errors import AuthlibBaseError as OAuthError
|
|
25
26
|
from authlib.integrations.httpx_client import AsyncOAuth2Client
|
|
26
27
|
|
|
27
28
|
from nat.authentication.interfaces import FlowHandlerBase
|
|
@@ -61,14 +62,50 @@ class WebSocketAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
61
62
|
|
|
62
63
|
raise NotImplementedError(f"Authentication method '{method}' is not supported by the websocket frontend.")
|
|
63
64
|
|
|
64
|
-
def create_oauth_client(self, config: OAuth2AuthCodeFlowProviderConfig):
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
65
|
+
def create_oauth_client(self, config: OAuth2AuthCodeFlowProviderConfig) -> AsyncOAuth2Client:
|
|
66
|
+
try:
|
|
67
|
+
return AsyncOAuth2Client(client_id=config.client_id,
|
|
68
|
+
client_secret=config.client_secret,
|
|
69
|
+
redirect_uri=config.redirect_uri,
|
|
70
|
+
scope=" ".join(config.scopes) if config.scopes else None,
|
|
71
|
+
token_endpoint=config.token_url,
|
|
72
|
+
code_challenge_method='S256' if config.use_pkce else None,
|
|
73
|
+
token_endpoint_auth_method=config.token_endpoint_auth_method)
|
|
74
|
+
except (OAuthError, ValueError, TypeError) as e:
|
|
75
|
+
raise RuntimeError(f"Invalid OAuth2 configuration: {e}") from e
|
|
76
|
+
except Exception as e:
|
|
77
|
+
raise RuntimeError(f"Failed to create OAuth2 client: {e}") from e
|
|
78
|
+
|
|
79
|
+
def _create_authorization_url(self,
|
|
80
|
+
client: AsyncOAuth2Client,
|
|
81
|
+
config: OAuth2AuthCodeFlowProviderConfig,
|
|
82
|
+
state: str,
|
|
83
|
+
verifier: str = None,
|
|
84
|
+
challenge: str = None) -> str:
|
|
85
|
+
"""
|
|
86
|
+
Create OAuth authorization URL with proper error handling.
|
|
87
|
+
|
|
88
|
+
Args:
|
|
89
|
+
client: The OAuth2 client instance
|
|
90
|
+
config: OAuth2 configuration
|
|
91
|
+
state: OAuth state parameter
|
|
92
|
+
verifier: PKCE verifier (if using PKCE)
|
|
93
|
+
challenge: PKCE challenge (if using PKCE)
|
|
94
|
+
|
|
95
|
+
Returns:
|
|
96
|
+
The authorization URL
|
|
97
|
+
"""
|
|
98
|
+
try:
|
|
99
|
+
authorization_url, _ = client.create_authorization_url(
|
|
100
|
+
config.authorization_url,
|
|
101
|
+
state=state,
|
|
102
|
+
code_verifier=verifier if config.use_pkce else None,
|
|
103
|
+
code_challenge=challenge if config.use_pkce else None,
|
|
104
|
+
**(config.authorization_kwargs or {})
|
|
105
|
+
)
|
|
106
|
+
return authorization_url
|
|
107
|
+
except (OAuthError, ValueError, TypeError) as e:
|
|
108
|
+
raise RuntimeError(f"Error creating OAuth authorization URL: {e}") from e
|
|
72
109
|
|
|
73
110
|
async def _handle_oauth2_auth_code_flow(self, config: OAuth2AuthCodeFlowProviderConfig) -> AuthenticatedContext:
|
|
74
111
|
|
|
@@ -82,21 +119,19 @@ class WebSocketAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
82
119
|
flow_state.verifier = verifier
|
|
83
120
|
flow_state.challenge = challenge
|
|
84
121
|
|
|
85
|
-
authorization_url
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
**(config.authorization_kwargs or {})
|
|
91
|
-
)
|
|
122
|
+
authorization_url = self._create_authorization_url(client=flow_state.client,
|
|
123
|
+
config=config,
|
|
124
|
+
state=state,
|
|
125
|
+
verifier=flow_state.verifier,
|
|
126
|
+
challenge=flow_state.challenge)
|
|
92
127
|
|
|
93
128
|
await self._add_flow_cb(state, flow_state)
|
|
94
129
|
await self._web_socket_message_handler.create_websocket_message(_HumanPromptOAuthConsent(text=authorization_url)
|
|
95
130
|
)
|
|
96
131
|
try:
|
|
97
132
|
token = await asyncio.wait_for(flow_state.future, timeout=300)
|
|
98
|
-
except
|
|
99
|
-
raise RuntimeError("Authentication flow timed out after 5 minutes.")
|
|
133
|
+
except TimeoutError as exc:
|
|
134
|
+
raise RuntimeError("Authentication flow timed out after 5 minutes.") from exc
|
|
100
135
|
finally:
|
|
101
136
|
|
|
102
137
|
await self._remove_flow_cb(state)
|
|
@@ -25,18 +25,21 @@ from collections.abc import Callable
|
|
|
25
25
|
from contextlib import asynccontextmanager
|
|
26
26
|
from pathlib import Path
|
|
27
27
|
|
|
28
|
+
import httpx
|
|
29
|
+
from authlib.common.errors import AuthlibBaseError as OAuthError
|
|
28
30
|
from fastapi import Body
|
|
29
31
|
from fastapi import FastAPI
|
|
32
|
+
from fastapi import HTTPException
|
|
30
33
|
from fastapi import Request
|
|
31
34
|
from fastapi import Response
|
|
32
35
|
from fastapi import UploadFile
|
|
33
|
-
from fastapi.exceptions import HTTPException
|
|
34
36
|
from fastapi.middleware.cors import CORSMiddleware
|
|
35
37
|
from fastapi.responses import StreamingResponse
|
|
36
38
|
from pydantic import BaseModel
|
|
37
39
|
from pydantic import Field
|
|
38
40
|
from starlette.websockets import WebSocket
|
|
39
41
|
|
|
42
|
+
from nat.builder.function import Function
|
|
40
43
|
from nat.builder.workflow_builder import WorkflowBuilder
|
|
41
44
|
from nat.data_models.api_server import ChatRequest
|
|
42
45
|
from nat.data_models.api_server import ChatResponse
|
|
@@ -241,6 +244,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
241
244
|
await self.add_evaluate_route(app, SessionManager(await builder.build()))
|
|
242
245
|
await self.add_static_files_route(app, builder)
|
|
243
246
|
await self.add_authorization_route(app)
|
|
247
|
+
await self.add_mcp_client_tool_list_route(app, builder)
|
|
244
248
|
|
|
245
249
|
for ep in self.front_end_config.endpoints:
|
|
246
250
|
|
|
@@ -1071,8 +1075,13 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
1071
1075
|
code_verifier=verifier,
|
|
1072
1076
|
state=state)
|
|
1073
1077
|
flow_state.future.set_result(res)
|
|
1078
|
+
except OAuthError as e:
|
|
1079
|
+
flow_state.future.set_exception(
|
|
1080
|
+
RuntimeError(f"Authorization server rejected request: {e.error} ({e.description})"))
|
|
1081
|
+
except httpx.HTTPError as e:
|
|
1082
|
+
flow_state.future.set_exception(RuntimeError(f"Network error during token fetch: {e}"))
|
|
1074
1083
|
except Exception as e:
|
|
1075
|
-
flow_state.future.set_exception(e)
|
|
1084
|
+
flow_state.future.set_exception(RuntimeError(f"Authentication failed: {e}"))
|
|
1076
1085
|
|
|
1077
1086
|
return HTMLResponse(content=AUTH_REDIRECT_SUCCESS_HTML,
|
|
1078
1087
|
status_code=200,
|
|
@@ -1088,6 +1097,183 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
1088
1097
|
methods=["GET"],
|
|
1089
1098
|
description="Handles the authorization code and state returned from the Authorization Code Grant Flow.")
|
|
1090
1099
|
|
|
1100
|
+
async def add_mcp_client_tool_list_route(self, app: FastAPI, builder: WorkflowBuilder):
|
|
1101
|
+
"""Add the MCP client tool list endpoint to the FastAPI app."""
|
|
1102
|
+
from typing import Any
|
|
1103
|
+
|
|
1104
|
+
from pydantic import BaseModel
|
|
1105
|
+
|
|
1106
|
+
class MCPToolInfo(BaseModel):
|
|
1107
|
+
name: str
|
|
1108
|
+
description: str
|
|
1109
|
+
server: str
|
|
1110
|
+
available: bool
|
|
1111
|
+
|
|
1112
|
+
class MCPClientToolListResponse(BaseModel):
|
|
1113
|
+
mcp_clients: list[dict[str, Any]]
|
|
1114
|
+
|
|
1115
|
+
async def get_mcp_client_tool_list() -> MCPClientToolListResponse:
|
|
1116
|
+
"""
|
|
1117
|
+
Get the list of MCP tools from all MCP clients in the workflow configuration.
|
|
1118
|
+
Checks session health and compares with workflow function group configuration.
|
|
1119
|
+
"""
|
|
1120
|
+
mcp_clients_info = []
|
|
1121
|
+
|
|
1122
|
+
try:
|
|
1123
|
+
# Get all function groups from the builder
|
|
1124
|
+
function_groups = builder._function_groups
|
|
1125
|
+
|
|
1126
|
+
# Find MCP client function groups
|
|
1127
|
+
for group_name, configured_group in function_groups.items():
|
|
1128
|
+
if configured_group.config.type != "mcp_client":
|
|
1129
|
+
continue
|
|
1130
|
+
|
|
1131
|
+
from nat.plugins.mcp.client_impl import MCPClientConfig
|
|
1132
|
+
|
|
1133
|
+
config = configured_group.config
|
|
1134
|
+
assert isinstance(config, MCPClientConfig)
|
|
1135
|
+
|
|
1136
|
+
# Reuse the existing MCP client session stored on the function group instance
|
|
1137
|
+
group_instance = configured_group.instance
|
|
1138
|
+
|
|
1139
|
+
client = group_instance.mcp_client
|
|
1140
|
+
if client is None:
|
|
1141
|
+
raise RuntimeError(f"MCP client not found for group {group_name}")
|
|
1142
|
+
|
|
1143
|
+
try:
|
|
1144
|
+
session_healthy = False
|
|
1145
|
+
server_tools: dict[str, Any] = {}
|
|
1146
|
+
|
|
1147
|
+
try:
|
|
1148
|
+
server_tools = await client.get_tools()
|
|
1149
|
+
session_healthy = True
|
|
1150
|
+
except Exception as e:
|
|
1151
|
+
logger.exception(f"Failed to connect to MCP server {client.server_name}: {e}")
|
|
1152
|
+
session_healthy = False
|
|
1153
|
+
|
|
1154
|
+
# Get workflow function group configuration (configured client-side tools)
|
|
1155
|
+
configured_short_names: set[str] = set()
|
|
1156
|
+
configured_full_to_fn: dict[str, Function] = {}
|
|
1157
|
+
try:
|
|
1158
|
+
# Pass a no-op filter function to bypass any default filtering that might check
|
|
1159
|
+
# health status, preventing potential infinite recursion during health status checks.
|
|
1160
|
+
async def pass_through_filter(fn):
|
|
1161
|
+
return fn
|
|
1162
|
+
|
|
1163
|
+
accessible_functions = await group_instance.get_accessible_functions(
|
|
1164
|
+
filter_fn=pass_through_filter)
|
|
1165
|
+
configured_full_to_fn = accessible_functions
|
|
1166
|
+
configured_short_names = {name.split('.', 1)[1] for name in accessible_functions.keys()}
|
|
1167
|
+
except Exception as e:
|
|
1168
|
+
logger.exception(f"Failed to get accessible functions for group {group_name}: {e}")
|
|
1169
|
+
|
|
1170
|
+
# Build alias->original mapping and override configs from overrides
|
|
1171
|
+
alias_to_original: dict[str, str] = {}
|
|
1172
|
+
override_configs: dict[str, Any] = {}
|
|
1173
|
+
try:
|
|
1174
|
+
if config.tool_overrides is not None:
|
|
1175
|
+
for orig_name, override in config.tool_overrides.items():
|
|
1176
|
+
if override.alias is not None:
|
|
1177
|
+
alias_to_original[override.alias] = orig_name
|
|
1178
|
+
override_configs[override.alias] = override
|
|
1179
|
+
else:
|
|
1180
|
+
override_configs[orig_name] = override
|
|
1181
|
+
except Exception:
|
|
1182
|
+
pass
|
|
1183
|
+
|
|
1184
|
+
# Create tool info list (always return configured tools; mark availability)
|
|
1185
|
+
tools_info: list[dict[str, Any]] = []
|
|
1186
|
+
available_count = 0
|
|
1187
|
+
for wf_fn, fn_short in zip(configured_full_to_fn.values(), configured_short_names):
|
|
1188
|
+
orig_name = alias_to_original.get(fn_short, fn_short)
|
|
1189
|
+
available = session_healthy and (orig_name in server_tools)
|
|
1190
|
+
if available:
|
|
1191
|
+
available_count += 1
|
|
1192
|
+
|
|
1193
|
+
# Prefer tool override description, then workflow function description,
|
|
1194
|
+
# then server description
|
|
1195
|
+
description = ""
|
|
1196
|
+
if fn_short in override_configs and override_configs[fn_short].description:
|
|
1197
|
+
description = override_configs[fn_short].description
|
|
1198
|
+
elif wf_fn.description:
|
|
1199
|
+
description = wf_fn.description
|
|
1200
|
+
elif available and orig_name in server_tools:
|
|
1201
|
+
description = server_tools[orig_name].description or ""
|
|
1202
|
+
|
|
1203
|
+
tools_info.append(
|
|
1204
|
+
MCPToolInfo(name=fn_short,
|
|
1205
|
+
description=description or "",
|
|
1206
|
+
server=client.server_name,
|
|
1207
|
+
available=available).model_dump())
|
|
1208
|
+
|
|
1209
|
+
# Sort tools_info by name to maintain consistent ordering
|
|
1210
|
+
tools_info.sort(key=lambda x: x['name'])
|
|
1211
|
+
|
|
1212
|
+
mcp_clients_info.append({
|
|
1213
|
+
"function_group": group_name,
|
|
1214
|
+
"server": client.server_name,
|
|
1215
|
+
"transport": config.server.transport,
|
|
1216
|
+
"session_healthy": session_healthy,
|
|
1217
|
+
"tools": tools_info,
|
|
1218
|
+
"total_tools": len(configured_short_names),
|
|
1219
|
+
"available_tools": available_count
|
|
1220
|
+
})
|
|
1221
|
+
|
|
1222
|
+
except Exception as e:
|
|
1223
|
+
logger.error(f"Error processing MCP client {group_name}: {e}")
|
|
1224
|
+
mcp_clients_info.append({
|
|
1225
|
+
"function_group": group_name,
|
|
1226
|
+
"server": "unknown",
|
|
1227
|
+
"transport": config.server.transport if config.server else "unknown",
|
|
1228
|
+
"session_healthy": False,
|
|
1229
|
+
"error": str(e),
|
|
1230
|
+
"tools": [],
|
|
1231
|
+
"total_tools": 0,
|
|
1232
|
+
"workflow_tools": 0
|
|
1233
|
+
})
|
|
1234
|
+
|
|
1235
|
+
return MCPClientToolListResponse(mcp_clients=mcp_clients_info)
|
|
1236
|
+
|
|
1237
|
+
except Exception as e:
|
|
1238
|
+
logger.error(f"Error in MCP client tool list endpoint: {e}")
|
|
1239
|
+
raise HTTPException(status_code=500, detail=f"Failed to retrieve MCP client information: {str(e)}")
|
|
1240
|
+
|
|
1241
|
+
# Add the route to the FastAPI app
|
|
1242
|
+
app.add_api_route(
|
|
1243
|
+
path="/mcp/client/tool/list",
|
|
1244
|
+
endpoint=get_mcp_client_tool_list,
|
|
1245
|
+
methods=["GET"],
|
|
1246
|
+
response_model=MCPClientToolListResponse,
|
|
1247
|
+
description="Get list of MCP client tools with session health and workflow configuration comparison",
|
|
1248
|
+
responses={
|
|
1249
|
+
200: {
|
|
1250
|
+
"description": "Successfully retrieved MCP client tool information",
|
|
1251
|
+
"content": {
|
|
1252
|
+
"application/json": {
|
|
1253
|
+
"example": {
|
|
1254
|
+
"mcp_clients": [{
|
|
1255
|
+
"function_group": "mcp_tools",
|
|
1256
|
+
"server": "streamable-http:http://localhost:9901/mcp",
|
|
1257
|
+
"transport": "streamable-http",
|
|
1258
|
+
"session_healthy": True,
|
|
1259
|
+
"tools": [{
|
|
1260
|
+
"name": "tool_a",
|
|
1261
|
+
"description": "Tool A description",
|
|
1262
|
+
"server": "streamable-http:http://localhost:9901/mcp",
|
|
1263
|
+
"available": True
|
|
1264
|
+
}],
|
|
1265
|
+
"total_tools": 1,
|
|
1266
|
+
"available_tools": 1
|
|
1267
|
+
}]
|
|
1268
|
+
}
|
|
1269
|
+
}
|
|
1270
|
+
}
|
|
1271
|
+
},
|
|
1272
|
+
500: {
|
|
1273
|
+
"description": "Internal Server Error"
|
|
1274
|
+
}
|
|
1275
|
+
})
|
|
1276
|
+
|
|
1091
1277
|
async def _add_flow(self, state: str, flow_state: FlowState):
|
|
1092
1278
|
async with self._outstanding_flows_lock:
|
|
1093
1279
|
self._outstanding_flows[state] = flow_state
|
|
@@ -370,7 +370,7 @@ class JobStore(DaskClientMixin):
|
|
|
370
370
|
# Convert BaseModel to JSON string for storage
|
|
371
371
|
output = output.model_dump_json(round_trip=True)
|
|
372
372
|
|
|
373
|
-
if isinstance(output,
|
|
373
|
+
if isinstance(output, dict | list):
|
|
374
374
|
# Convert dict or list to JSON string for storage
|
|
375
375
|
output = json.dumps(output)
|
|
376
376
|
|
|
@@ -555,7 +555,7 @@ class JobStore(DaskClientMixin):
|
|
|
555
555
|
logger.exception("Failed to expire %s", job_id)
|
|
556
556
|
|
|
557
557
|
await session.execute(
|
|
558
|
-
|
|
558
|
+
update(JobInfo).where(JobInfo.job_id.in_(successfully_expired)).values(is_expired=True))
|
|
559
559
|
|
|
560
560
|
|
|
561
561
|
def get_db_engine(db_url: str | None = None, echo: bool = False, use_async: bool = True) -> "Engine | AsyncEngine":
|
|
@@ -105,10 +105,10 @@ class WebSocketMessageHandler:
|
|
|
105
105
|
if (isinstance(validated_message, WebSocketUserMessage)):
|
|
106
106
|
await self.process_workflow_request(validated_message)
|
|
107
107
|
|
|
108
|
-
elif isinstance(
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
108
|
+
elif isinstance(
|
|
109
|
+
validated_message,
|
|
110
|
+
WebSocketSystemResponseTokenMessage | WebSocketSystemIntermediateStepMessage
|
|
111
|
+
| WebSocketSystemInteractionMessage):
|
|
112
112
|
# These messages are already handled by self.create_websocket_message(data_model=value, …)
|
|
113
113
|
# No further processing is needed here.
|
|
114
114
|
pass
|