nvidia-nat 1.3.0a20250928__py3-none-any.whl → 1.3.0a20250930__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. nat/agent/base.py +1 -1
  2. nat/agent/rewoo_agent/agent.py +298 -118
  3. nat/agent/rewoo_agent/prompt.py +19 -22
  4. nat/agent/rewoo_agent/register.py +4 -1
  5. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +26 -18
  6. nat/builder/builder.py +1 -1
  7. nat/builder/context.py +2 -2
  8. nat/builder/front_end.py +1 -1
  9. nat/cli/cli_utils/config_override.py +1 -1
  10. nat/cli/commands/mcp/mcp.py +2 -2
  11. nat/cli/commands/start.py +1 -1
  12. nat/cli/type_registry.py +1 -1
  13. nat/control_flow/router_agent/register.py +1 -1
  14. nat/data_models/api_server.py +9 -9
  15. nat/data_models/authentication.py +3 -9
  16. nat/data_models/dataset_handler.py +1 -1
  17. nat/eval/evaluator/base_evaluator.py +1 -1
  18. nat/eval/swe_bench_evaluator/evaluate.py +1 -1
  19. nat/eval/tunable_rag_evaluator/evaluate.py +1 -1
  20. nat/experimental/decorators/experimental_warning_decorator.py +1 -2
  21. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -1
  22. nat/front_ends/console/authentication_flow_handler.py +82 -30
  23. nat/front_ends/console/console_front_end_plugin.py +1 -1
  24. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  25. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +188 -2
  26. nat/front_ends/fastapi/job_store.py +2 -2
  27. nat/front_ends/fastapi/message_handler.py +4 -4
  28. nat/front_ends/fastapi/message_validator.py +5 -5
  29. nat/front_ends/mcp/tool_converter.py +1 -1
  30. nat/llm/utils/thinking.py +1 -1
  31. nat/observability/exporter/base_exporter.py +1 -1
  32. nat/observability/exporter/span_exporter.py +1 -1
  33. nat/observability/exporter_manager.py +2 -2
  34. nat/observability/processor/batching_processor.py +1 -1
  35. nat/profiler/decorators/function_tracking.py +2 -2
  36. nat/profiler/parameter_optimization/parameter_selection.py +3 -4
  37. nat/profiler/parameter_optimization/pareto_visualizer.py +1 -1
  38. nat/retriever/milvus/retriever.py +1 -1
  39. nat/settings/global_settings.py +2 -2
  40. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +1 -1
  41. nat/tool/datetime_tools.py +1 -1
  42. nat/utils/data_models/schema_validator.py +1 -1
  43. nat/utils/exception_handlers/automatic_retries.py +1 -1
  44. nat/utils/io/yaml_tools.py +1 -1
  45. nat/utils/type_utils.py +1 -1
  46. {nvidia_nat-1.3.0a20250928.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/METADATA +2 -1
  47. {nvidia_nat-1.3.0a20250928.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/RECORD +52 -52
  48. {nvidia_nat-1.3.0a20250928.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/WHEEL +0 -0
  49. {nvidia_nat-1.3.0a20250928.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/entry_points.txt +0 -0
  50. {nvidia_nat-1.3.0a20250928.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  51. {nvidia_nat-1.3.0a20250928.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/licenses/LICENSE.md +0 -0
  52. {nvidia_nat-1.3.0a20250928.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/top_level.txt +0 -0
nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py CHANGED
@@ -25,18 +25,21 @@ from collections.abc import Callable
 from contextlib import asynccontextmanager
 from pathlib import Path
 
+import httpx
+from authlib.common.errors import AuthlibBaseError as OAuthError
 from fastapi import Body
 from fastapi import FastAPI
+from fastapi import HTTPException
 from fastapi import Request
 from fastapi import Response
 from fastapi import UploadFile
-from fastapi.exceptions import HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 from pydantic import Field
 from starlette.websockets import WebSocket
 
+from nat.builder.function import Function
 from nat.builder.workflow_builder import WorkflowBuilder
 from nat.data_models.api_server import ChatRequest
 from nat.data_models.api_server import ChatResponse
@@ -241,6 +244,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
         await self.add_evaluate_route(app, SessionManager(await builder.build()))
         await self.add_static_files_route(app, builder)
         await self.add_authorization_route(app)
+        await self.add_mcp_client_tool_list_route(app, builder)
 
         for ep in self.front_end_config.endpoints:
 
@@ -1071,8 +1075,13 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
                                                     code_verifier=verifier,
                                                     state=state)
                 flow_state.future.set_result(res)
+            except OAuthError as e:
+                flow_state.future.set_exception(
+                    RuntimeError(f"Authorization server rejected request: {e.error} ({e.description})"))
+            except httpx.HTTPError as e:
+                flow_state.future.set_exception(RuntimeError(f"Network error during token fetch: {e}"))
             except Exception as e:
-                flow_state.future.set_exception(e)
+                flow_state.future.set_exception(RuntimeError(f"Authentication failed: {e}"))
 
             return HTMLResponse(content=AUTH_REDIRECT_SUCCESS_HTML,
                                 status_code=200,
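
Note: with this change every failure path resolves the pending future with a RuntimeError instead of leaking raw authlib or httpx exceptions. A minimal sketch of the awaiting side, assuming flow_state.future is an asyncio.Future as in the hunk above (the 300-second timeout is illustrative, not from the diff):

    import asyncio

    async def wait_for_authorization(flow_state):
        try:
            # The redirect handler above resolves this future with the token
            # result, or with a RuntimeError carrying a readable failure reason.
            return await asyncio.wait_for(flow_state.future, timeout=300)
        except RuntimeError as e:
            # OAuth rejections, network failures, and unexpected errors all
            # arrive as RuntimeError, so one except clause covers them.
            raise RuntimeError(f"OAuth login did not complete: {e}") from e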
@@ -1088,6 +1097,183 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
             methods=["GET"],
             description="Handles the authorization code and state returned from the Authorization Code Grant Flow.")
 
+    async def add_mcp_client_tool_list_route(self, app: FastAPI, builder: WorkflowBuilder):
+        """Add the MCP client tool list endpoint to the FastAPI app."""
+        from typing import Any
+
+        from pydantic import BaseModel
+
+        class MCPToolInfo(BaseModel):
+            name: str
+            description: str
+            server: str
+            available: bool
+
+        class MCPClientToolListResponse(BaseModel):
+            mcp_clients: list[dict[str, Any]]
+
+        async def get_mcp_client_tool_list() -> MCPClientToolListResponse:
+            """
+            Get the list of MCP tools from all MCP clients in the workflow configuration.
+            Checks session health and compares with workflow function group configuration.
+            """
+            mcp_clients_info = []
+
+            try:
+                # Get all function groups from the builder
+                function_groups = builder._function_groups
+
+                # Find MCP client function groups
+                for group_name, configured_group in function_groups.items():
+                    if configured_group.config.type != "mcp_client":
+                        continue
+
+                    from nat.plugins.mcp.client_impl import MCPClientConfig
+
+                    config = configured_group.config
+                    assert isinstance(config, MCPClientConfig)
+
+                    # Reuse the existing MCP client session stored on the function group instance
+                    group_instance = configured_group.instance
+
+                    client = group_instance.mcp_client
+                    if client is None:
+                        raise RuntimeError(f"MCP client not found for group {group_name}")
+
+                    try:
+                        session_healthy = False
+                        server_tools: dict[str, Any] = {}
+
+                        try:
+                            server_tools = await client.get_tools()
+                            session_healthy = True
+                        except Exception as e:
+                            logger.exception(f"Failed to connect to MCP server {client.server_name}: {e}")
+                            session_healthy = False
+
+                        # Get workflow function group configuration (configured client-side tools)
+                        configured_short_names: set[str] = set()
+                        configured_full_to_fn: dict[str, Function] = {}
+                        try:
+                            # Pass a no-op filter function to bypass any default filtering that might check
+                            # health status, preventing potential infinite recursion during health status checks.
+                            async def pass_through_filter(fn):
+                                return fn
+
+                            accessible_functions = await group_instance.get_accessible_functions(
+                                filter_fn=pass_through_filter)
+                            configured_full_to_fn = accessible_functions
+                            configured_short_names = {name.split('.', 1)[1] for name in accessible_functions.keys()}
+                        except Exception as e:
+                            logger.exception(f"Failed to get accessible functions for group {group_name}: {e}")
+
+                        # Build alias->original mapping and override configs from overrides
+                        alias_to_original: dict[str, str] = {}
+                        override_configs: dict[str, Any] = {}
+                        try:
+                            if config.tool_overrides is not None:
+                                for orig_name, override in config.tool_overrides.items():
+                                    if override.alias is not None:
+                                        alias_to_original[override.alias] = orig_name
+                                        override_configs[override.alias] = override
+                                    else:
+                                        override_configs[orig_name] = override
+                        except Exception:
+                            pass
+
+                        # Create tool info list (always return configured tools; mark availability)
+                        tools_info: list[dict[str, Any]] = []
+                        available_count = 0
+                        for wf_fn, fn_short in zip(configured_full_to_fn.values(), configured_short_names):
+                            orig_name = alias_to_original.get(fn_short, fn_short)
+                            available = session_healthy and (orig_name in server_tools)
+                            if available:
+                                available_count += 1
+
+                            # Prefer tool override description, then workflow function description,
+                            # then server description
+                            description = ""
+                            if fn_short in override_configs and override_configs[fn_short].description:
+                                description = override_configs[fn_short].description
+                            elif wf_fn.description:
+                                description = wf_fn.description
+                            elif available and orig_name in server_tools:
+                                description = server_tools[orig_name].description or ""
+
+                            tools_info.append(
+                                MCPToolInfo(name=fn_short,
+                                            description=description or "",
+                                            server=client.server_name,
+                                            available=available).model_dump())
+
+                        # Sort tools_info by name to maintain consistent ordering
+                        tools_info.sort(key=lambda x: x['name'])
+
+                        mcp_clients_info.append({
+                            "function_group": group_name,
+                            "server": client.server_name,
+                            "transport": config.server.transport,
+                            "session_healthy": session_healthy,
+                            "tools": tools_info,
+                            "total_tools": len(configured_short_names),
+                            "available_tools": available_count
+                        })
+
+                    except Exception as e:
+                        logger.error(f"Error processing MCP client {group_name}: {e}")
+                        mcp_clients_info.append({
+                            "function_group": group_name,
+                            "server": "unknown",
+                            "transport": config.server.transport if config.server else "unknown",
+                            "session_healthy": False,
+                            "error": str(e),
+                            "tools": [],
+                            "total_tools": 0,
+                            "workflow_tools": 0
+                        })
+
+                return MCPClientToolListResponse(mcp_clients=mcp_clients_info)
+
+            except Exception as e:
+                logger.error(f"Error in MCP client tool list endpoint: {e}")
+                raise HTTPException(status_code=500, detail=f"Failed to retrieve MCP client information: {str(e)}")
+
+        # Add the route to the FastAPI app
+        app.add_api_route(
+            path="/mcp/client/tool/list",
+            endpoint=get_mcp_client_tool_list,
+            methods=["GET"],
+            response_model=MCPClientToolListResponse,
+            description="Get list of MCP client tools with session health and workflow configuration comparison",
+            responses={
+                200: {
+                    "description": "Successfully retrieved MCP client tool information",
+                    "content": {
+                        "application/json": {
+                            "example": {
+                                "mcp_clients": [{
+                                    "function_group": "mcp_tools",
+                                    "server": "streamable-http:http://localhost:9901/mcp",
+                                    "transport": "streamable-http",
+                                    "session_healthy": True,
+                                    "tools": [{
+                                        "name": "tool_a",
+                                        "description": "Tool A description",
+                                        "server": "streamable-http:http://localhost:9901/mcp",
+                                        "available": True
+                                    }],
+                                    "total_tools": 1,
+                                    "available_tools": 1
+                                }]
+                            }
+                        }
+                    }
+                },
+                500: {
+                    "description": "Internal Server Error"
+                }
+            })
+
     async def _add_flow(self, state: str, flow_state: FlowState):
         async with self._outstanding_flows_lock:
             self._outstanding_flows[state] = flow_state
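
The new route is a plain GET endpoint, so it can be exercised with any HTTP client. A usage sketch with httpx (already imported by this module); the host and port are illustrative assumptions, not part of the diff, and .get() defaults guard against the error-path entries that omit some counters:

    import httpx

    def print_mcp_tool_status(base_url: str = "http://localhost:8000") -> None:
        resp = httpx.get(f"{base_url}/mcp/client/tool/list", timeout=10.0)
        resp.raise_for_status()
        for info in resp.json()["mcp_clients"]:
            print(f"{info['function_group']} (healthy={info['session_healthy']}): "
                  f"{info.get('available_tools', 0)}/{info.get('total_tools', 0)} tools available")
            for tool in info.get("tools", []):
                marker = "+" if tool["available"] else "-"
                print(f"  {marker} {tool['name']}: {tool['description']}")

    print_mcp_tool_status()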
nat/front_ends/fastapi/job_store.py CHANGED
@@ -370,7 +370,7 @@ class JobStore(DaskClientMixin):
             # Convert BaseModel to JSON string for storage
             output = output.model_dump_json(round_trip=True)
 
-        if isinstance(output, (dict, list)):
+        if isinstance(output, dict | list):
             # Convert dict or list to JSON string for storage
             output = json.dumps(output)
 
@@ -555,7 +555,7 @@ class JobStore(DaskClientMixin):
                     logger.exception("Failed to expire %s", job_id)
 
             await session.execute(
-                (update(JobInfo).where(JobInfo.job_id.in_(successfully_expired)).values(is_expired=True)))
+                update(JobInfo).where(JobInfo.job_id.in_(successfully_expired)).values(is_expired=True))
 
 
 def get_db_engine(db_url: str | None = None, echo: bool = False, use_async: bool = True) -> "Engine | AsyncEngine":
nat/front_ends/fastapi/message_handler.py CHANGED
@@ -105,10 +105,10 @@ class WebSocketMessageHandler:
         if (isinstance(validated_message, WebSocketUserMessage)):
             await self.process_workflow_request(validated_message)
 
-        elif isinstance(validated_message,
-                        (WebSocketSystemResponseTokenMessage,
-                         WebSocketSystemIntermediateStepMessage,
-                         WebSocketSystemInteractionMessage)):
+        elif isinstance(
+                validated_message,
+                WebSocketSystemResponseTokenMessage | WebSocketSystemIntermediateStepMessage
+                | WebSocketSystemInteractionMessage):
             # These messages are already handled by self.create_websocket_message(data_model=value, …)
             # No further processing is needed here.
             pass
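
The isinstance() changes in these hunks swap tuple arguments for PEP 604 unions, which isinstance() accepts since Python 3.10. The two spellings are equivalent; a quick self-contained check:

    # Equivalent spellings; the union form needs Python 3.10+.
    assert isinstance([], dict | list) == isinstance([], (dict, list))
    assert isinstance({}, dict | list)
    assert not isinstance("text", dict | list)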
nat/front_ends/fastapi/message_validator.py CHANGED
@@ -139,7 +139,7 @@ class MessageValidator:
             text_content: str = str(data_model.payload)
             validated_message_content = SystemResponseContent(text=text_content)
 
-        elif (isinstance(data_model, (ChatResponse, ChatResponseChunk))):
+        elif (isinstance(data_model, ChatResponse | ChatResponseChunk)):
             validated_message_content = SystemResponseContent(text=data_model.choices[0].message.content)
 
         elif (isinstance(data_model, ResponseIntermediateStep)):
@@ -204,7 +204,7 @@ class MessageValidator:
 
         validated_message_type: str = ""
         try:
-            if (isinstance(data_model, (ResponsePayloadOutput, ChatResponse, ChatResponseChunk))):
+            if (isinstance(data_model, ResponsePayloadOutput | ChatResponse | ChatResponseChunk)):
                 validated_message_type = WebSocketMessageType.RESPONSE_MESSAGE
 
             elif (isinstance(data_model, ResponseIntermediateStep)):
@@ -241,7 +241,7 @@ class MessageValidator:
             content: SystemResponseContent
             | Error = SystemResponseContent(),
             status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
-            timestamp: str = str(datetime.datetime.now(datetime.timezone.utc))
+            timestamp: str = str(datetime.datetime.now(datetime.UTC))
     ) -> WebSocketSystemResponseTokenMessage | None:
         """
         Creates a system response token message with default values.
@@ -280,7 +280,7 @@ class MessageValidator:
             conversation_id: str | None = None,
             content: SystemIntermediateStepContent = SystemIntermediateStepContent(name="default", payload="default"),
             status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
-            timestamp: str = str(datetime.datetime.now(datetime.timezone.utc))
+            timestamp: str = str(datetime.datetime.now(datetime.UTC))
     ) -> WebSocketSystemIntermediateStepMessage | None:
         """
         Creates a system intermediate step message with default values.
@@ -320,7 +320,7 @@ class MessageValidator:
             conversation_id: str | None = None,
             content: HumanPrompt,
             status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
-            timestamp: str = str(datetime.datetime.now(datetime.timezone.utc))
+            timestamp: str = str(datetime.datetime.now(datetime.UTC))
     ) -> WebSocketSystemInteractionMessage | None:
         """
         Creates a system interaction message with default values.
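
datetime.UTC, used in the three hunks above, is a Python 3.11+ alias for datetime.timezone.utc, so the replacement is purely cosmetic; note that either way the default timestamp is evaluated once at function definition time, which this change does not alter. A quick check:

    import datetime

    # The alias points at the same singleton timezone object.
    assert datetime.UTC is datetime.timezone.utc
    print(str(datetime.datetime.now(datetime.UTC)))  # e.g. 2025-09-30 12:00:00.000000+00:00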
nat/front_ends/mcp/tool_converter.py CHANGED
@@ -175,7 +175,7 @@ def create_function_wrapper(
             # Handle different result types for proper formatting
             if isinstance(result, str):
                 return result
-            if isinstance(result, (dict, list)):
+            if isinstance(result, dict | list):
                 return json.dumps(result, default=str)
             return str(result)
         except Exception as e:
nat/llm/utils/thinking.py CHANGED
@@ -19,10 +19,10 @@ import logging
 import types
 from abc import abstractmethod
 from collections.abc import AsyncGenerator
+from collections.abc import Callable
 from collections.abc import Iterable
 from dataclasses import dataclass
 from typing import Any
-from typing import Callable
 from typing import TypeVar
 
 ModelType = TypeVar("ModelType")
nat/observability/exporter/base_exporter.py CHANGED
@@ -372,7 +372,7 @@ class BaseExporter(Exporter):
         try:
             # Wait for all tasks to complete with a timeout
             await asyncio.wait_for(asyncio.gather(*self._tasks, return_exceptions=True), timeout=timeout)
-        except asyncio.TimeoutError:
+        except TimeoutError:
             logger.warning("%s: Some tasks did not complete within %s seconds", self.name, timeout)
         except Exception as e:
             logger.exception("%s: Error while waiting for tasks: %s", self.name, e)
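
Since Python 3.11, asyncio.TimeoutError is an alias of the builtin TimeoutError, so this and the following hunks can catch the builtin and still intercept asyncio.wait_for timeouts. A minimal sketch:

    import asyncio

    assert asyncio.TimeoutError is TimeoutError  # true on Python 3.11+

    async def demo() -> None:
        try:
            await asyncio.wait_for(asyncio.sleep(10), timeout=0.01)
        except TimeoutError:
            print("timed out")  # raised by asyncio, caught as the builtin

    asyncio.run(demo())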
nat/observability/exporter/span_exporter.py CHANGED
@@ -252,7 +252,7 @@ class SpanExporter(ProcessingExporter[InputSpanT, OutputSpanT], SerializeMixin):
 
         end_metadata = event.payload.metadata or {}
 
-        if not isinstance(end_metadata, (dict, TraceMetadata)):
+        if not isinstance(end_metadata, dict | TraceMetadata):
             logger.warning("Invalid metadata type for step %s", event.UUID)
             return
 
nat/observability/exporter_manager.py CHANGED
@@ -184,7 +184,7 @@ class ExporterManager:
         try:
             await asyncio.wait_for(asyncio.gather(*cleanup_tasks, return_exceptions=True),
                                    timeout=self._shutdown_timeout)
-        except asyncio.TimeoutError:
+        except TimeoutError:
             logger.warning("Some isolated exporters did not clean up within timeout")
 
         self._active_isolated_exporters.clear()
@@ -301,7 +301,7 @@ class ExporterManager:
             try:
                 task.cancel()
                 await asyncio.wait_for(task, timeout=self._shutdown_timeout)
-            except asyncio.TimeoutError:
+            except TimeoutError:
                 logger.warning("Exporter '%s' task did not shut down in time and may be stuck.", name)
                 stuck_tasks.append(name)
             except asyncio.CancelledError:
nat/observability/processor/batching_processor.py CHANGED
@@ -241,7 +241,7 @@ class BatchingProcessor(CallbackProcessor[T, list[T]], Generic[T]):
         try:
             await asyncio.wait_for(self._shutdown_complete_event.wait(), timeout=self._shutdown_timeout)
             logger.debug("Shutdown completion detected via event")
-        except asyncio.TimeoutError:
+        except TimeoutError:
             logger.warning("Shutdown completion timeout exceeded (%s seconds)", self._shutdown_timeout)
             return
 
nat/profiler/decorators/function_tracking.py CHANGED
@@ -40,10 +40,10 @@ def _serialize_data(obj: Any) -> Any:
 
     if isinstance(obj, dict):
         return {str(k): _serialize_data(v) for k, v in obj.items()}
-    if isinstance(obj, (list, tuple, set)):
+    if isinstance(obj, list | tuple | set):
         return [_serialize_data(item) for item in obj]
 
-    if isinstance(obj, (str, int, float, bool, type(None))):
+    if isinstance(obj, str | int | float | bool | type(None)):
         return obj
 
     # Fallback
nat/profiler/parameter_optimization/parameter_selection.py CHANGED
@@ -13,8 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
-from typing import Sequence
+from collections.abc import Sequence
 
 import numpy as np
 import optuna
@@ -41,8 +40,8 @@ def pick_trial(
     study: Study,
     mode: str = "harmonic",
     *,
-    weights: Optional[Sequence[float]] = None,
-    ref_point: Optional[Sequence[float]] = None,
+    weights: Sequence[float] | None = None,
+    ref_point: Sequence[float] | None = None,
     eps: float = 1e-12,
 ) -> optuna.trial.FrozenTrial:
     """
nat/profiler/parameter_optimization/pareto_visualizer.py CHANGED
@@ -324,7 +324,7 @@ def create_pareto_visualization(data_source: optuna.Study | Path | pd.DataFrame,
     if hasattr(data_source, 'trials_dataframe'):
         # Optuna study object
         trials_df, pareto_trials_df = load_trials_from_study(data_source)
-    elif isinstance(data_source, (str, Path)):
+    elif isinstance(data_source, str | Path):
         # CSV file path
         trials_df, pareto_trials_df = load_trials_from_csv(Path(data_source), metric_names, directions)
     elif isinstance(data_source, pd.DataFrame):
nat/retriever/milvus/retriever.py CHANGED
@@ -214,7 +214,7 @@ def _wrap_milvus_results(res: list[Hit], content_field: str):
 
 
 def _wrap_milvus_single_results(res: Hit | dict, content_field: str) -> Document:
-    if not isinstance(res, (Hit, dict)):
+    if not isinstance(res, Hit | dict):
         raise ValueError(f"Milvus search returned object of type {type(res)}. Expected 'Hit' or 'dict'.")
 
     if isinstance(res, Hit):
nat/settings/global_settings.py CHANGED
@@ -124,7 +124,7 @@ class Settings(HashableBaseModel):
             if (short_names[key.local_name] == 1):
                 type_list.append((key.local_name, key.config_type))
 
-        return typing.Union[tuple(typing.Annotated[x_type, Tag(x_id)] for x_id, x_type in type_list)]
+        return typing.Union[*tuple(typing.Annotated[x_type, Tag(x_id)] for x_id, x_type in type_list)]
 
     RegistryHandlerAnnotation = dict[
         str,
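
Unpacking into a typing.Union subscription (Union[*args]) is valid on Python 3.11+; subscripting with a bare tuple, as the old code did, is the pre-3.11 spelling. A quick check of the equivalence this hunk (and the matching one in nat/utils/type_utils.py below) relies on:

    import typing

    remaining_args = (int, str)
    assert typing.Union[*remaining_args] == typing.Union[int, str]
    # A one-element union collapses to the bare type:
    assert typing.Union[*(int,)] is int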
@@ -169,7 +169,7 @@ class Settings(HashableBaseModel):
         if (not os.path.exists(configuration_file)):
             loaded_config = {}
         else:
-            with open(file_path, mode="r", encoding="utf-8") as f:
+            with open(file_path, encoding="utf-8") as f:
                 try:
                     loaded_config = json.load(f)
                 except Exception as e:
nat/tool/code_execution/local_sandbox/local_sandbox_server.py CHANGED
@@ -62,7 +62,7 @@ class CodeExecutionResponse(Response):
         super().__init__(status=status_code, mimetype="application/json", response=result.model_dump_json())
 
     @classmethod
-    def with_error(cls, status_code: int, error_message: str) -> 'CodeExecutionResponse':
+    def with_error(cls, status_code: int, error_message: str) -> CodeExecutionResponse:
         return cls(status_code,
                    CodeExecutionResult(process_status=CodeExecutionStatus.ERROR, stdout="", stderr=error_message))
 
nat/tool/datetime_tools.py CHANGED
@@ -72,7 +72,7 @@ async def current_datetime(_config: CurrentTimeToolConfig, _builder: Builder):
     timezone_obj = _get_timezone_obj(headers)
 
     now = datetime.datetime.now(timezone_obj)
-    now_machine_readable = now.strftime(("%Y-%m-%d %H:%M:%S %z"))
+    now_machine_readable = now.strftime("%Y-%m-%d %H:%M:%S %z")
 
     # Returns the current time in machine readable format with timezone offset.
     return f"The current time of day is {now_machine_readable}"
nat/utils/data_models/schema_validator.py CHANGED
@@ -52,7 +52,7 @@ def validate_yaml(ctx, param, value):
     if value is None:
         return None
 
-    with open(value, 'r', encoding="utf-8") as f:
+    with open(value, encoding="utf-8") as f:
         yaml.safe_load(f)
 
     return value
nat/utils/exception_handlers/automatic_retries.py CHANGED
@@ -310,7 +310,7 @@ def patch_with_retry(
         descriptor = inspect.getattr_static(cls, name)
 
         # Skip dunders, privates and all descriptors we must not wrap
-        if (name.startswith("_") or isinstance(descriptor, (property, staticmethod, classmethod))):
+        if (name.startswith("_") or isinstance(descriptor, property | staticmethod | classmethod)):
             continue
 
         original = descriptor.__func__ if isinstance(descriptor, types.MethodType) else descriptor
nat/utils/io/yaml_tools.py CHANGED
@@ -57,7 +57,7 @@ def yaml_load(config_path: StrPath) -> dict:
     """
 
     # Read YAML file
-    with open(config_path, "r", encoding="utf-8") as stream:
+    with open(config_path, encoding="utf-8") as stream:
         config_str = stream.read()
 
     return yaml_loads(config_str)
nat/utils/type_utils.py CHANGED
@@ -250,7 +250,7 @@ class DecomposedType:
         remaining_args = tuple(arg for arg in self.args if arg is not types.NoneType)
 
         if (len(remaining_args) > 1):
-            return DecomposedType(typing.Union[remaining_args])
+            return DecomposedType(typing.Union[*remaining_args])
         if (len(remaining_args) == 1):
             return DecomposedType(remaining_args[0])
 
{nvidia_nat-1.3.0a20250928.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: nvidia-nat
-Version: 1.3.0a20250928
+Version: 1.3.0a20250930
 Summary: NVIDIA NeMo Agent toolkit
 Author: NVIDIA Corporation
 Maintainer: NVIDIA Corporation
@@ -232,6 +232,7 @@ Requires-Dist: numpy~=2.3; python_version >= "3.12"
 Requires-Dist: openinference-semantic-conventions~=0.1.14
 Requires-Dist: openpyxl~=3.1
 Requires-Dist: optuna~=4.4.0
+Requires-Dist: pip>=24.3.1
 Requires-Dist: pkce==1.0.3
 Requires-Dist: pkginfo~=1.12
 Requires-Dist: platformdirs~=4.3