nvidia-nat 1.3.0a20250928__py3-none-any.whl → 1.3.0a20250930__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/base.py +1 -1
- nat/agent/rewoo_agent/agent.py +298 -118
- nat/agent/rewoo_agent/prompt.py +19 -22
- nat/agent/rewoo_agent/register.py +4 -1
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +26 -18
- nat/builder/builder.py +1 -1
- nat/builder/context.py +2 -2
- nat/builder/front_end.py +1 -1
- nat/cli/cli_utils/config_override.py +1 -1
- nat/cli/commands/mcp/mcp.py +2 -2
- nat/cli/commands/start.py +1 -1
- nat/cli/type_registry.py +1 -1
- nat/control_flow/router_agent/register.py +1 -1
- nat/data_models/api_server.py +9 -9
- nat/data_models/authentication.py +3 -9
- nat/data_models/dataset_handler.py +1 -1
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/swe_bench_evaluator/evaluate.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +1 -1
- nat/experimental/decorators/experimental_warning_decorator.py +1 -2
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -1
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +1 -1
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +188 -2
- nat/front_ends/fastapi/job_store.py +2 -2
- nat/front_ends/fastapi/message_handler.py +4 -4
- nat/front_ends/fastapi/message_validator.py +5 -5
- nat/front_ends/mcp/tool_converter.py +1 -1
- nat/llm/utils/thinking.py +1 -1
- nat/observability/exporter/base_exporter.py +1 -1
- nat/observability/exporter/span_exporter.py +1 -1
- nat/observability/exporter_manager.py +2 -2
- nat/observability/processor/batching_processor.py +1 -1
- nat/profiler/decorators/function_tracking.py +2 -2
- nat/profiler/parameter_optimization/parameter_selection.py +3 -4
- nat/profiler/parameter_optimization/pareto_visualizer.py +1 -1
- nat/retriever/milvus/retriever.py +1 -1
- nat/settings/global_settings.py +2 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +1 -1
- nat/tool/datetime_tools.py +1 -1
- nat/utils/data_models/schema_validator.py +1 -1
- nat/utils/exception_handlers/automatic_retries.py +1 -1
- nat/utils/io/yaml_tools.py +1 -1
- nat/utils/type_utils.py +1 -1
- {nvidia_nat-1.3.0a20250928.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/METADATA +2 -1
- {nvidia_nat-1.3.0a20250928.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/RECORD +52 -52
- {nvidia_nat-1.3.0a20250928.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0a20250928.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.3.0a20250928.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0a20250928.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0a20250928.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/top_level.txt +0 -0
|
@@ -25,18 +25,21 @@ from collections.abc import Callable
|
|
|
25
25
|
from contextlib import asynccontextmanager
|
|
26
26
|
from pathlib import Path
|
|
27
27
|
|
|
28
|
+
import httpx
|
|
29
|
+
from authlib.common.errors import AuthlibBaseError as OAuthError
|
|
28
30
|
from fastapi import Body
|
|
29
31
|
from fastapi import FastAPI
|
|
32
|
+
from fastapi import HTTPException
|
|
30
33
|
from fastapi import Request
|
|
31
34
|
from fastapi import Response
|
|
32
35
|
from fastapi import UploadFile
|
|
33
|
-
from fastapi.exceptions import HTTPException
|
|
34
36
|
from fastapi.middleware.cors import CORSMiddleware
|
|
35
37
|
from fastapi.responses import StreamingResponse
|
|
36
38
|
from pydantic import BaseModel
|
|
37
39
|
from pydantic import Field
|
|
38
40
|
from starlette.websockets import WebSocket
|
|
39
41
|
|
|
42
|
+
from nat.builder.function import Function
|
|
40
43
|
from nat.builder.workflow_builder import WorkflowBuilder
|
|
41
44
|
from nat.data_models.api_server import ChatRequest
|
|
42
45
|
from nat.data_models.api_server import ChatResponse
|
|
@@ -241,6 +244,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
241
244
|
await self.add_evaluate_route(app, SessionManager(await builder.build()))
|
|
242
245
|
await self.add_static_files_route(app, builder)
|
|
243
246
|
await self.add_authorization_route(app)
|
|
247
|
+
await self.add_mcp_client_tool_list_route(app, builder)
|
|
244
248
|
|
|
245
249
|
for ep in self.front_end_config.endpoints:
|
|
246
250
|
|
|
@@ -1071,8 +1075,13 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
1071
1075
|
code_verifier=verifier,
|
|
1072
1076
|
state=state)
|
|
1073
1077
|
flow_state.future.set_result(res)
|
|
1078
|
+
except OAuthError as e:
|
|
1079
|
+
flow_state.future.set_exception(
|
|
1080
|
+
RuntimeError(f"Authorization server rejected request: {e.error} ({e.description})"))
|
|
1081
|
+
except httpx.HTTPError as e:
|
|
1082
|
+
flow_state.future.set_exception(RuntimeError(f"Network error during token fetch: {e}"))
|
|
1074
1083
|
except Exception as e:
|
|
1075
|
-
flow_state.future.set_exception(e)
|
|
1084
|
+
flow_state.future.set_exception(RuntimeError(f"Authentication failed: {e}"))
|
|
1076
1085
|
|
|
1077
1086
|
return HTMLResponse(content=AUTH_REDIRECT_SUCCESS_HTML,
|
|
1078
1087
|
status_code=200,
|
|
@@ -1088,6 +1097,183 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
1088
1097
|
methods=["GET"],
|
|
1089
1098
|
description="Handles the authorization code and state returned from the Authorization Code Grant Flow.")
|
|
1090
1099
|
|
|
1100
|
+
async def add_mcp_client_tool_list_route(self, app: FastAPI, builder: WorkflowBuilder):
|
|
1101
|
+
"""Add the MCP client tool list endpoint to the FastAPI app."""
|
|
1102
|
+
from typing import Any
|
|
1103
|
+
|
|
1104
|
+
from pydantic import BaseModel
|
|
1105
|
+
|
|
1106
|
+
class MCPToolInfo(BaseModel):
|
|
1107
|
+
name: str
|
|
1108
|
+
description: str
|
|
1109
|
+
server: str
|
|
1110
|
+
available: bool
|
|
1111
|
+
|
|
1112
|
+
class MCPClientToolListResponse(BaseModel):
|
|
1113
|
+
mcp_clients: list[dict[str, Any]]
|
|
1114
|
+
|
|
1115
|
+
async def get_mcp_client_tool_list() -> MCPClientToolListResponse:
|
|
1116
|
+
"""
|
|
1117
|
+
Get the list of MCP tools from all MCP clients in the workflow configuration.
|
|
1118
|
+
Checks session health and compares with workflow function group configuration.
|
|
1119
|
+
"""
|
|
1120
|
+
mcp_clients_info = []
|
|
1121
|
+
|
|
1122
|
+
try:
|
|
1123
|
+
# Get all function groups from the builder
|
|
1124
|
+
function_groups = builder._function_groups
|
|
1125
|
+
|
|
1126
|
+
# Find MCP client function groups
|
|
1127
|
+
for group_name, configured_group in function_groups.items():
|
|
1128
|
+
if configured_group.config.type != "mcp_client":
|
|
1129
|
+
continue
|
|
1130
|
+
|
|
1131
|
+
from nat.plugins.mcp.client_impl import MCPClientConfig
|
|
1132
|
+
|
|
1133
|
+
config = configured_group.config
|
|
1134
|
+
assert isinstance(config, MCPClientConfig)
|
|
1135
|
+
|
|
1136
|
+
# Reuse the existing MCP client session stored on the function group instance
|
|
1137
|
+
group_instance = configured_group.instance
|
|
1138
|
+
|
|
1139
|
+
client = group_instance.mcp_client
|
|
1140
|
+
if client is None:
|
|
1141
|
+
raise RuntimeError(f"MCP client not found for group {group_name}")
|
|
1142
|
+
|
|
1143
|
+
try:
|
|
1144
|
+
session_healthy = False
|
|
1145
|
+
server_tools: dict[str, Any] = {}
|
|
1146
|
+
|
|
1147
|
+
try:
|
|
1148
|
+
server_tools = await client.get_tools()
|
|
1149
|
+
session_healthy = True
|
|
1150
|
+
except Exception as e:
|
|
1151
|
+
logger.exception(f"Failed to connect to MCP server {client.server_name}: {e}")
|
|
1152
|
+
session_healthy = False
|
|
1153
|
+
|
|
1154
|
+
# Get workflow function group configuration (configured client-side tools)
|
|
1155
|
+
configured_short_names: set[str] = set()
|
|
1156
|
+
configured_full_to_fn: dict[str, Function] = {}
|
|
1157
|
+
try:
|
|
1158
|
+
# Pass a no-op filter function to bypass any default filtering that might check
|
|
1159
|
+
# health status, preventing potential infinite recursion during health status checks.
|
|
1160
|
+
async def pass_through_filter(fn):
|
|
1161
|
+
return fn
|
|
1162
|
+
|
|
1163
|
+
accessible_functions = await group_instance.get_accessible_functions(
|
|
1164
|
+
filter_fn=pass_through_filter)
|
|
1165
|
+
configured_full_to_fn = accessible_functions
|
|
1166
|
+
configured_short_names = {name.split('.', 1)[1] for name in accessible_functions.keys()}
|
|
1167
|
+
except Exception as e:
|
|
1168
|
+
logger.exception(f"Failed to get accessible functions for group {group_name}: {e}")
|
|
1169
|
+
|
|
1170
|
+
# Build alias->original mapping and override configs from overrides
|
|
1171
|
+
alias_to_original: dict[str, str] = {}
|
|
1172
|
+
override_configs: dict[str, Any] = {}
|
|
1173
|
+
try:
|
|
1174
|
+
if config.tool_overrides is not None:
|
|
1175
|
+
for orig_name, override in config.tool_overrides.items():
|
|
1176
|
+
if override.alias is not None:
|
|
1177
|
+
alias_to_original[override.alias] = orig_name
|
|
1178
|
+
override_configs[override.alias] = override
|
|
1179
|
+
else:
|
|
1180
|
+
override_configs[orig_name] = override
|
|
1181
|
+
except Exception:
|
|
1182
|
+
pass
|
|
1183
|
+
|
|
1184
|
+
# Create tool info list (always return configured tools; mark availability)
|
|
1185
|
+
tools_info: list[dict[str, Any]] = []
|
|
1186
|
+
available_count = 0
|
|
1187
|
+
for wf_fn, fn_short in zip(configured_full_to_fn.values(), configured_short_names):
|
|
1188
|
+
orig_name = alias_to_original.get(fn_short, fn_short)
|
|
1189
|
+
available = session_healthy and (orig_name in server_tools)
|
|
1190
|
+
if available:
|
|
1191
|
+
available_count += 1
|
|
1192
|
+
|
|
1193
|
+
# Prefer tool override description, then workflow function description,
|
|
1194
|
+
# then server description
|
|
1195
|
+
description = ""
|
|
1196
|
+
if fn_short in override_configs and override_configs[fn_short].description:
|
|
1197
|
+
description = override_configs[fn_short].description
|
|
1198
|
+
elif wf_fn.description:
|
|
1199
|
+
description = wf_fn.description
|
|
1200
|
+
elif available and orig_name in server_tools:
|
|
1201
|
+
description = server_tools[orig_name].description or ""
|
|
1202
|
+
|
|
1203
|
+
tools_info.append(
|
|
1204
|
+
MCPToolInfo(name=fn_short,
|
|
1205
|
+
description=description or "",
|
|
1206
|
+
server=client.server_name,
|
|
1207
|
+
available=available).model_dump())
|
|
1208
|
+
|
|
1209
|
+
# Sort tools_info by name to maintain consistent ordering
|
|
1210
|
+
tools_info.sort(key=lambda x: x['name'])
|
|
1211
|
+
|
|
1212
|
+
mcp_clients_info.append({
|
|
1213
|
+
"function_group": group_name,
|
|
1214
|
+
"server": client.server_name,
|
|
1215
|
+
"transport": config.server.transport,
|
|
1216
|
+
"session_healthy": session_healthy,
|
|
1217
|
+
"tools": tools_info,
|
|
1218
|
+
"total_tools": len(configured_short_names),
|
|
1219
|
+
"available_tools": available_count
|
|
1220
|
+
})
|
|
1221
|
+
|
|
1222
|
+
except Exception as e:
|
|
1223
|
+
logger.error(f"Error processing MCP client {group_name}: {e}")
|
|
1224
|
+
mcp_clients_info.append({
|
|
1225
|
+
"function_group": group_name,
|
|
1226
|
+
"server": "unknown",
|
|
1227
|
+
"transport": config.server.transport if config.server else "unknown",
|
|
1228
|
+
"session_healthy": False,
|
|
1229
|
+
"error": str(e),
|
|
1230
|
+
"tools": [],
|
|
1231
|
+
"total_tools": 0,
|
|
1232
|
+
"workflow_tools": 0
|
|
1233
|
+
})
|
|
1234
|
+
|
|
1235
|
+
return MCPClientToolListResponse(mcp_clients=mcp_clients_info)
|
|
1236
|
+
|
|
1237
|
+
except Exception as e:
|
|
1238
|
+
logger.error(f"Error in MCP client tool list endpoint: {e}")
|
|
1239
|
+
raise HTTPException(status_code=500, detail=f"Failed to retrieve MCP client information: {str(e)}")
|
|
1240
|
+
|
|
1241
|
+
# Add the route to the FastAPI app
|
|
1242
|
+
app.add_api_route(
|
|
1243
|
+
path="/mcp/client/tool/list",
|
|
1244
|
+
endpoint=get_mcp_client_tool_list,
|
|
1245
|
+
methods=["GET"],
|
|
1246
|
+
response_model=MCPClientToolListResponse,
|
|
1247
|
+
description="Get list of MCP client tools with session health and workflow configuration comparison",
|
|
1248
|
+
responses={
|
|
1249
|
+
200: {
|
|
1250
|
+
"description": "Successfully retrieved MCP client tool information",
|
|
1251
|
+
"content": {
|
|
1252
|
+
"application/json": {
|
|
1253
|
+
"example": {
|
|
1254
|
+
"mcp_clients": [{
|
|
1255
|
+
"function_group": "mcp_tools",
|
|
1256
|
+
"server": "streamable-http:http://localhost:9901/mcp",
|
|
1257
|
+
"transport": "streamable-http",
|
|
1258
|
+
"session_healthy": True,
|
|
1259
|
+
"tools": [{
|
|
1260
|
+
"name": "tool_a",
|
|
1261
|
+
"description": "Tool A description",
|
|
1262
|
+
"server": "streamable-http:http://localhost:9901/mcp",
|
|
1263
|
+
"available": True
|
|
1264
|
+
}],
|
|
1265
|
+
"total_tools": 1,
|
|
1266
|
+
"available_tools": 1
|
|
1267
|
+
}]
|
|
1268
|
+
}
|
|
1269
|
+
}
|
|
1270
|
+
}
|
|
1271
|
+
},
|
|
1272
|
+
500: {
|
|
1273
|
+
"description": "Internal Server Error"
|
|
1274
|
+
}
|
|
1275
|
+
})
|
|
1276
|
+
|
|
1091
1277
|
async def _add_flow(self, state: str, flow_state: FlowState):
|
|
1092
1278
|
async with self._outstanding_flows_lock:
|
|
1093
1279
|
self._outstanding_flows[state] = flow_state
|
|
@@ -370,7 +370,7 @@ class JobStore(DaskClientMixin):
|
|
|
370
370
|
# Convert BaseModel to JSON string for storage
|
|
371
371
|
output = output.model_dump_json(round_trip=True)
|
|
372
372
|
|
|
373
|
-
if isinstance(output,
|
|
373
|
+
if isinstance(output, dict | list):
|
|
374
374
|
# Convert dict or list to JSON string for storage
|
|
375
375
|
output = json.dumps(output)
|
|
376
376
|
|
|
@@ -555,7 +555,7 @@ class JobStore(DaskClientMixin):
|
|
|
555
555
|
logger.exception("Failed to expire %s", job_id)
|
|
556
556
|
|
|
557
557
|
await session.execute(
|
|
558
|
-
|
|
558
|
+
update(JobInfo).where(JobInfo.job_id.in_(successfully_expired)).values(is_expired=True))
|
|
559
559
|
|
|
560
560
|
|
|
561
561
|
def get_db_engine(db_url: str | None = None, echo: bool = False, use_async: bool = True) -> "Engine | AsyncEngine":
|
|
@@ -105,10 +105,10 @@ class WebSocketMessageHandler:
|
|
|
105
105
|
if (isinstance(validated_message, WebSocketUserMessage)):
|
|
106
106
|
await self.process_workflow_request(validated_message)
|
|
107
107
|
|
|
108
|
-
elif isinstance(
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
108
|
+
elif isinstance(
|
|
109
|
+
validated_message,
|
|
110
|
+
WebSocketSystemResponseTokenMessage | WebSocketSystemIntermediateStepMessage
|
|
111
|
+
| WebSocketSystemInteractionMessage):
|
|
112
112
|
# These messages are already handled by self.create_websocket_message(data_model=value, …)
|
|
113
113
|
# No further processing is needed here.
|
|
114
114
|
pass
|
|
@@ -139,7 +139,7 @@ class MessageValidator:
|
|
|
139
139
|
text_content: str = str(data_model.payload)
|
|
140
140
|
validated_message_content = SystemResponseContent(text=text_content)
|
|
141
141
|
|
|
142
|
-
elif (isinstance(data_model,
|
|
142
|
+
elif (isinstance(data_model, ChatResponse | ChatResponseChunk)):
|
|
143
143
|
validated_message_content = SystemResponseContent(text=data_model.choices[0].message.content)
|
|
144
144
|
|
|
145
145
|
elif (isinstance(data_model, ResponseIntermediateStep)):
|
|
@@ -204,7 +204,7 @@ class MessageValidator:
|
|
|
204
204
|
|
|
205
205
|
validated_message_type: str = ""
|
|
206
206
|
try:
|
|
207
|
-
if (isinstance(data_model,
|
|
207
|
+
if (isinstance(data_model, ResponsePayloadOutput | ChatResponse | ChatResponseChunk)):
|
|
208
208
|
validated_message_type = WebSocketMessageType.RESPONSE_MESSAGE
|
|
209
209
|
|
|
210
210
|
elif (isinstance(data_model, ResponseIntermediateStep)):
|
|
@@ -241,7 +241,7 @@ class MessageValidator:
|
|
|
241
241
|
content: SystemResponseContent
|
|
242
242
|
| Error = SystemResponseContent(),
|
|
243
243
|
status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
|
|
244
|
-
timestamp: str = str(datetime.datetime.now(datetime.
|
|
244
|
+
timestamp: str = str(datetime.datetime.now(datetime.UTC))
|
|
245
245
|
) -> WebSocketSystemResponseTokenMessage | None:
|
|
246
246
|
"""
|
|
247
247
|
Creates a system response token message with default values.
|
|
@@ -280,7 +280,7 @@ class MessageValidator:
|
|
|
280
280
|
conversation_id: str | None = None,
|
|
281
281
|
content: SystemIntermediateStepContent = SystemIntermediateStepContent(name="default", payload="default"),
|
|
282
282
|
status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
|
|
283
|
-
timestamp: str = str(datetime.datetime.now(datetime.
|
|
283
|
+
timestamp: str = str(datetime.datetime.now(datetime.UTC))
|
|
284
284
|
) -> WebSocketSystemIntermediateStepMessage | None:
|
|
285
285
|
"""
|
|
286
286
|
Creates a system intermediate step message with default values.
|
|
@@ -320,7 +320,7 @@ class MessageValidator:
|
|
|
320
320
|
conversation_id: str | None = None,
|
|
321
321
|
content: HumanPrompt,
|
|
322
322
|
status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
|
|
323
|
-
timestamp: str = str(datetime.datetime.now(datetime.
|
|
323
|
+
timestamp: str = str(datetime.datetime.now(datetime.UTC))
|
|
324
324
|
) -> WebSocketSystemInteractionMessage | None:
|
|
325
325
|
"""
|
|
326
326
|
Creates a system interaction message with default values.
|
|
@@ -175,7 +175,7 @@ def create_function_wrapper(
|
|
|
175
175
|
# Handle different result types for proper formatting
|
|
176
176
|
if isinstance(result, str):
|
|
177
177
|
return result
|
|
178
|
-
if isinstance(result,
|
|
178
|
+
if isinstance(result, dict | list):
|
|
179
179
|
return json.dumps(result, default=str)
|
|
180
180
|
return str(result)
|
|
181
181
|
except Exception as e:
|
nat/llm/utils/thinking.py
CHANGED
|
@@ -19,10 +19,10 @@ import logging
|
|
|
19
19
|
import types
|
|
20
20
|
from abc import abstractmethod
|
|
21
21
|
from collections.abc import AsyncGenerator
|
|
22
|
+
from collections.abc import Callable
|
|
22
23
|
from collections.abc import Iterable
|
|
23
24
|
from dataclasses import dataclass
|
|
24
25
|
from typing import Any
|
|
25
|
-
from typing import Callable
|
|
26
26
|
from typing import TypeVar
|
|
27
27
|
|
|
28
28
|
ModelType = TypeVar("ModelType")
|
|
@@ -372,7 +372,7 @@ class BaseExporter(Exporter):
|
|
|
372
372
|
try:
|
|
373
373
|
# Wait for all tasks to complete with a timeout
|
|
374
374
|
await asyncio.wait_for(asyncio.gather(*self._tasks, return_exceptions=True), timeout=timeout)
|
|
375
|
-
except
|
|
375
|
+
except TimeoutError:
|
|
376
376
|
logger.warning("%s: Some tasks did not complete within %s seconds", self.name, timeout)
|
|
377
377
|
except Exception as e:
|
|
378
378
|
logger.exception("%s: Error while waiting for tasks: %s", self.name, e)
|
|
@@ -252,7 +252,7 @@ class SpanExporter(ProcessingExporter[InputSpanT, OutputSpanT], SerializeMixin):
|
|
|
252
252
|
|
|
253
253
|
end_metadata = event.payload.metadata or {}
|
|
254
254
|
|
|
255
|
-
if not isinstance(end_metadata,
|
|
255
|
+
if not isinstance(end_metadata, dict | TraceMetadata):
|
|
256
256
|
logger.warning("Invalid metadata type for step %s", event.UUID)
|
|
257
257
|
return
|
|
258
258
|
|
|
@@ -184,7 +184,7 @@ class ExporterManager:
|
|
|
184
184
|
try:
|
|
185
185
|
await asyncio.wait_for(asyncio.gather(*cleanup_tasks, return_exceptions=True),
|
|
186
186
|
timeout=self._shutdown_timeout)
|
|
187
|
-
except
|
|
187
|
+
except TimeoutError:
|
|
188
188
|
logger.warning("Some isolated exporters did not clean up within timeout")
|
|
189
189
|
|
|
190
190
|
self._active_isolated_exporters.clear()
|
|
@@ -301,7 +301,7 @@ class ExporterManager:
|
|
|
301
301
|
try:
|
|
302
302
|
task.cancel()
|
|
303
303
|
await asyncio.wait_for(task, timeout=self._shutdown_timeout)
|
|
304
|
-
except
|
|
304
|
+
except TimeoutError:
|
|
305
305
|
logger.warning("Exporter '%s' task did not shut down in time and may be stuck.", name)
|
|
306
306
|
stuck_tasks.append(name)
|
|
307
307
|
except asyncio.CancelledError:
|
|
@@ -241,7 +241,7 @@ class BatchingProcessor(CallbackProcessor[T, list[T]], Generic[T]):
|
|
|
241
241
|
try:
|
|
242
242
|
await asyncio.wait_for(self._shutdown_complete_event.wait(), timeout=self._shutdown_timeout)
|
|
243
243
|
logger.debug("Shutdown completion detected via event")
|
|
244
|
-
except
|
|
244
|
+
except TimeoutError:
|
|
245
245
|
logger.warning("Shutdown completion timeout exceeded (%s seconds)", self._shutdown_timeout)
|
|
246
246
|
return
|
|
247
247
|
|
|
@@ -40,10 +40,10 @@ def _serialize_data(obj: Any) -> Any:
|
|
|
40
40
|
|
|
41
41
|
if isinstance(obj, dict):
|
|
42
42
|
return {str(k): _serialize_data(v) for k, v in obj.items()}
|
|
43
|
-
if isinstance(obj,
|
|
43
|
+
if isinstance(obj, list | tuple | set):
|
|
44
44
|
return [_serialize_data(item) for item in obj]
|
|
45
45
|
|
|
46
|
-
if isinstance(obj,
|
|
46
|
+
if isinstance(obj, str | int | float | bool | type(None)):
|
|
47
47
|
return obj
|
|
48
48
|
|
|
49
49
|
# Fallback
|
|
@@ -13,8 +13,7 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
from
|
|
17
|
-
from typing import Sequence
|
|
16
|
+
from collections.abc import Sequence
|
|
18
17
|
|
|
19
18
|
import numpy as np
|
|
20
19
|
import optuna
|
|
@@ -41,8 +40,8 @@ def pick_trial(
|
|
|
41
40
|
study: Study,
|
|
42
41
|
mode: str = "harmonic",
|
|
43
42
|
*,
|
|
44
|
-
weights:
|
|
45
|
-
ref_point:
|
|
43
|
+
weights: Sequence[float] | None = None,
|
|
44
|
+
ref_point: Sequence[float] | None = None,
|
|
46
45
|
eps: float = 1e-12,
|
|
47
46
|
) -> optuna.trial.FrozenTrial:
|
|
48
47
|
"""
|
|
@@ -324,7 +324,7 @@ def create_pareto_visualization(data_source: optuna.Study | Path | pd.DataFrame,
|
|
|
324
324
|
if hasattr(data_source, 'trials_dataframe'):
|
|
325
325
|
# Optuna study object
|
|
326
326
|
trials_df, pareto_trials_df = load_trials_from_study(data_source)
|
|
327
|
-
elif isinstance(data_source,
|
|
327
|
+
elif isinstance(data_source, str | Path):
|
|
328
328
|
# CSV file path
|
|
329
329
|
trials_df, pareto_trials_df = load_trials_from_csv(Path(data_source), metric_names, directions)
|
|
330
330
|
elif isinstance(data_source, pd.DataFrame):
|
|
@@ -214,7 +214,7 @@ def _wrap_milvus_results(res: list[Hit], content_field: str):
|
|
|
214
214
|
|
|
215
215
|
|
|
216
216
|
def _wrap_milvus_single_results(res: Hit | dict, content_field: str) -> Document:
|
|
217
|
-
if not isinstance(res,
|
|
217
|
+
if not isinstance(res, Hit | dict):
|
|
218
218
|
raise ValueError(f"Milvus search returned object of type {type(res)}. Expected 'Hit' or 'dict'.")
|
|
219
219
|
|
|
220
220
|
if isinstance(res, Hit):
|
nat/settings/global_settings.py
CHANGED
|
@@ -124,7 +124,7 @@ class Settings(HashableBaseModel):
|
|
|
124
124
|
if (short_names[key.local_name] == 1):
|
|
125
125
|
type_list.append((key.local_name, key.config_type))
|
|
126
126
|
|
|
127
|
-
return typing.Union[tuple(typing.Annotated[x_type, Tag(x_id)] for x_id, x_type in type_list)]
|
|
127
|
+
return typing.Union[*tuple(typing.Annotated[x_type, Tag(x_id)] for x_id, x_type in type_list)]
|
|
128
128
|
|
|
129
129
|
RegistryHandlerAnnotation = dict[
|
|
130
130
|
str,
|
|
@@ -169,7 +169,7 @@ class Settings(HashableBaseModel):
|
|
|
169
169
|
if (not os.path.exists(configuration_file)):
|
|
170
170
|
loaded_config = {}
|
|
171
171
|
else:
|
|
172
|
-
with open(file_path,
|
|
172
|
+
with open(file_path, encoding="utf-8") as f:
|
|
173
173
|
try:
|
|
174
174
|
loaded_config = json.load(f)
|
|
175
175
|
except Exception as e:
|
|
@@ -62,7 +62,7 @@ class CodeExecutionResponse(Response):
|
|
|
62
62
|
super().__init__(status=status_code, mimetype="application/json", response=result.model_dump_json())
|
|
63
63
|
|
|
64
64
|
@classmethod
|
|
65
|
-
def with_error(cls, status_code: int, error_message: str) ->
|
|
65
|
+
def with_error(cls, status_code: int, error_message: str) -> CodeExecutionResponse:
|
|
66
66
|
return cls(status_code,
|
|
67
67
|
CodeExecutionResult(process_status=CodeExecutionStatus.ERROR, stdout="", stderr=error_message))
|
|
68
68
|
|
nat/tool/datetime_tools.py
CHANGED
|
@@ -72,7 +72,7 @@ async def current_datetime(_config: CurrentTimeToolConfig, _builder: Builder):
|
|
|
72
72
|
timezone_obj = _get_timezone_obj(headers)
|
|
73
73
|
|
|
74
74
|
now = datetime.datetime.now(timezone_obj)
|
|
75
|
-
now_machine_readable = now.strftime(
|
|
75
|
+
now_machine_readable = now.strftime("%Y-%m-%d %H:%M:%S %z")
|
|
76
76
|
|
|
77
77
|
# Returns the current time in machine readable format with timezone offset.
|
|
78
78
|
return f"The current time of day is {now_machine_readable}"
|
|
@@ -310,7 +310,7 @@ def patch_with_retry(
|
|
|
310
310
|
descriptor = inspect.getattr_static(cls, name)
|
|
311
311
|
|
|
312
312
|
# Skip dunders, privates and all descriptors we must not wrap
|
|
313
|
-
if (name.startswith("_") or isinstance(descriptor,
|
|
313
|
+
if (name.startswith("_") or isinstance(descriptor, property | staticmethod | classmethod)):
|
|
314
314
|
continue
|
|
315
315
|
|
|
316
316
|
original = descriptor.__func__ if isinstance(descriptor, types.MethodType) else descriptor
|
nat/utils/io/yaml_tools.py
CHANGED
nat/utils/type_utils.py
CHANGED
|
@@ -250,7 +250,7 @@ class DecomposedType:
|
|
|
250
250
|
remaining_args = tuple(arg for arg in self.args if arg is not types.NoneType)
|
|
251
251
|
|
|
252
252
|
if (len(remaining_args) > 1):
|
|
253
|
-
return DecomposedType(typing.Union[remaining_args])
|
|
253
|
+
return DecomposedType(typing.Union[*remaining_args])
|
|
254
254
|
if (len(remaining_args) == 1):
|
|
255
255
|
return DecomposedType(remaining_args[0])
|
|
256
256
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: nvidia-nat
|
|
3
|
-
Version: 1.3.
|
|
3
|
+
Version: 1.3.0a20250930
|
|
4
4
|
Summary: NVIDIA NeMo Agent toolkit
|
|
5
5
|
Author: NVIDIA Corporation
|
|
6
6
|
Maintainer: NVIDIA Corporation
|
|
@@ -232,6 +232,7 @@ Requires-Dist: numpy~=2.3; python_version >= "3.12"
|
|
|
232
232
|
Requires-Dist: openinference-semantic-conventions~=0.1.14
|
|
233
233
|
Requires-Dist: openpyxl~=3.1
|
|
234
234
|
Requires-Dist: optuna~=4.4.0
|
|
235
|
+
Requires-Dist: pip>=24.3.1
|
|
235
236
|
Requires-Dist: pkce==1.0.3
|
|
236
237
|
Requires-Dist: pkginfo~=1.12
|
|
237
238
|
Requires-Dist: platformdirs~=4.3
|