nvidia-nat 1.3.0a20250929__py3-none-any.whl → 1.3.0a20251001__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. nat/agent/base.py +1 -1
  2. nat/agent/rewoo_agent/agent.py +100 -108
  3. nat/agent/rewoo_agent/register.py +4 -1
  4. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +26 -18
  5. nat/builder/builder.py +1 -1
  6. nat/builder/context.py +2 -2
  7. nat/builder/front_end.py +1 -1
  8. nat/cli/cli_utils/config_override.py +1 -1
  9. nat/cli/commands/mcp/mcp.py +2 -2
  10. nat/cli/commands/start.py +1 -1
  11. nat/cli/type_registry.py +1 -1
  12. nat/control_flow/router_agent/register.py +1 -1
  13. nat/data_models/api_server.py +9 -9
  14. nat/data_models/authentication.py +3 -9
  15. nat/data_models/dataset_handler.py +1 -1
  16. nat/eval/evaluator/base_evaluator.py +1 -1
  17. nat/eval/swe_bench_evaluator/evaluate.py +1 -1
  18. nat/eval/tunable_rag_evaluator/evaluate.py +1 -1
  19. nat/experimental/decorators/experimental_warning_decorator.py +1 -2
  20. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -1
  21. nat/front_ends/console/authentication_flow_handler.py +82 -30
  22. nat/front_ends/console/console_front_end_plugin.py +1 -1
  23. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  24. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +188 -2
  25. nat/front_ends/fastapi/job_store.py +2 -2
  26. nat/front_ends/fastapi/message_handler.py +4 -4
  27. nat/front_ends/fastapi/message_validator.py +5 -5
  28. nat/front_ends/mcp/tool_converter.py +1 -1
  29. nat/llm/utils/thinking.py +1 -1
  30. nat/observability/exporter/base_exporter.py +1 -1
  31. nat/observability/exporter/span_exporter.py +1 -1
  32. nat/observability/exporter_manager.py +2 -2
  33. nat/observability/processor/batching_processor.py +1 -1
  34. nat/profiler/decorators/function_tracking.py +2 -2
  35. nat/profiler/parameter_optimization/parameter_selection.py +3 -4
  36. nat/profiler/parameter_optimization/pareto_visualizer.py +1 -1
  37. nat/retriever/milvus/retriever.py +1 -1
  38. nat/settings/global_settings.py +2 -2
  39. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +1 -1
  40. nat/tool/datetime_tools.py +1 -1
  41. nat/utils/data_models/schema_validator.py +1 -1
  42. nat/utils/exception_handlers/automatic_retries.py +1 -1
  43. nat/utils/io/yaml_tools.py +1 -1
  44. nat/utils/type_utils.py +1 -1
  45. {nvidia_nat-1.3.0a20250929.dist-info → nvidia_nat-1.3.0a20251001.dist-info}/METADATA +2 -1
  46. {nvidia_nat-1.3.0a20250929.dist-info → nvidia_nat-1.3.0a20251001.dist-info}/RECORD +51 -51
  47. {nvidia_nat-1.3.0a20250929.dist-info → nvidia_nat-1.3.0a20251001.dist-info}/WHEEL +0 -0
  48. {nvidia_nat-1.3.0a20250929.dist-info → nvidia_nat-1.3.0a20251001.dist-info}/entry_points.txt +0 -0
  49. {nvidia_nat-1.3.0a20250929.dist-info → nvidia_nat-1.3.0a20251001.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  50. {nvidia_nat-1.3.0a20250929.dist-info → nvidia_nat-1.3.0a20251001.dist-info}/licenses/LICENSE.md +0 -0
  51. {nvidia_nat-1.3.0a20250929.dist-info → nvidia_nat-1.3.0a20251001.dist-info}/top_level.txt +0 -0
nat/agent/base.py CHANGED
@@ -192,7 +192,7 @@ class BaseAgent(ABC):
192
192
  await asyncio.sleep(sleep_time)
193
193
 
194
194
  # All retries exhausted, return error message
195
- error_content = "Tool call failed after all retry attempts. Last error: %s" % str(last_exception)
195
+ error_content = f"Tool call failed after all retry attempts. Last error: {str(last_exception)}"
196
196
  logger.error("%s %s", AGENT_LOG_PREFIX, error_content, exc_info=True)
197
197
  return ToolMessage(name=tool.name, tool_call_id=tool.name, content=error_content, status="error")
198
198
 
nat/agent/rewoo_agent/agent.py CHANGED
@@ -18,9 +18,7 @@ import json
18
18
  import logging
19
19
  import re
20
20
  from json import JSONDecodeError
21
- from typing import Dict
22
- from typing import List
23
- from typing import Tuple
21
+ from typing import Any
24
22
 
25
23
  from langchain_core.callbacks.base import AsyncCallbackHandler
26
24
  from langchain_core.language_models import BaseChatModel
@@ -47,6 +45,17 @@ from nat.agent.base import BaseAgent
47
45
  logger = logging.getLogger(__name__)
48
46
 
49
47
 
48
+ class ReWOOEvidence(BaseModel):
49
+ placeholder: str
50
+ tool: str
51
+ tool_input: Any
52
+
53
+
54
+ class ReWOOPlanStep(BaseModel):
55
+ plan: str
56
+ evidence: ReWOOEvidence
57
+
58
+
50
59
  class ReWOOGraphState(BaseModel):
51
60
  """State schema for the ReWOO Agent Graph"""
52
61
  messages: list[BaseMessage] = Field(default_factory=list) # input and output of the ReWOO Agent
@@ -56,8 +65,8 @@ class ReWOOGraphState(BaseModel):
56
65
  steps: AIMessage = Field(
57
66
  default_factory=lambda: AIMessage(content="")) # the steps to solve the task, parsed from the plan
58
67
  # New fields for parallel execution support
59
- evidence_map: Dict[str, Dict] = Field(default_factory=dict) # mapping from placeholders to step info
60
- execution_levels: List[List[str]] = Field(default_factory=list) # levels for parallel execution
68
+ evidence_map: dict[str, ReWOOPlanStep] = Field(default_factory=dict) # mapping from placeholders to step info
69
+ execution_levels: list[list[str]] = Field(default_factory=list) # levels for parallel execution
61
70
  current_level: int = Field(default=0) # current execution level
62
71
  intermediate_results: dict[str, ToolMessage] = Field(default_factory=dict) # the intermediate results of each step
63
72
  result: AIMessage = Field(
@@ -91,18 +100,15 @@ class ReWOOAgentGraph(BaseAgent):
91
100
  logger.debug(
92
101
  "%s Filling the prompt variables 'tools' and 'tool_names', using the tools provided in the config.",
93
102
  AGENT_LOG_PREFIX)
94
- tool_names = ",".join([tool.name for tool in tools[:-1]]) + ',' + tools[-1].name # prevent trailing ","
95
- if not use_tool_schema:
96
- tool_names_and_descriptions = "\n".join(
97
- [f"{tool.name}: {tool.description}"
98
- for tool in tools[:-1]]) + "\n" + f"{tools[-1].name}: {tools[-1].description}" # prevent trailing "\n"
99
- else:
100
- logger.debug("%s Adding the tools' input schema to the tools' description", AGENT_LOG_PREFIX)
101
- tool_names_and_descriptions = "\n".join([
102
- f"{tool.name}: {tool.description}. {INPUT_SCHEMA_MESSAGE.format(schema=tool.input_schema.model_fields)}"
103
- for tool in tools[:-1]
104
- ]) + "\n" + (f"{tools[-1].name}: {tools[-1].description}. "
105
- f"{INPUT_SCHEMA_MESSAGE.format(schema=tools[-1].input_schema.model_fields)}")
103
+
104
+ def describe_tool(tool: BaseTool) -> str:
105
+ description = f"{tool.name}: {tool.description}"
106
+ if use_tool_schema:
107
+ description += f". {INPUT_SCHEMA_MESSAGE.format(schema=tool.input_schema.model_fields)}"
108
+ return description
109
+
110
+ tool_names = ",".join(tool.name for tool in tools)
111
+ tool_names_and_descriptions = "\n".join(describe_tool(tool) for tool in tools)
106
112
 
107
113
  self.planner_prompt = planner_prompt.partial(tools=tool_names_and_descriptions, tool_names=tool_names)
108
114
  self.solver_prompt = solver_prompt
@@ -123,9 +129,12 @@ class ReWOOAgentGraph(BaseAgent):
123
129
  def _get_current_level_status(state: ReWOOGraphState) -> tuple[int, bool]:
124
130
  """
125
131
  Get the current execution level and whether it's complete.
126
- :param state: The ReWOO graph state.
127
- :return: Tuple of (current_level, is_complete). Level -1 means all execution is complete.
128
- :rtype: tuple[int, bool]
132
+
133
+ Args:
134
+ state: The ReWOO graph state.
135
+
136
+ Returns:
137
+ tuple of (current_level, is_complete). Level -1 means all execution is complete.
129
138
  """
130
139
  if not state.execution_levels:
131
140
  return -1, True
@@ -143,63 +152,43 @@ class ReWOOAgentGraph(BaseAgent):
143
152
  return current_level, level_complete
144
153
 
145
154
  @staticmethod
146
- def _parse_planner_output(planner_output: str) -> AIMessage:
155
+ def _parse_planner_output(planner_output: str) -> list[ReWOOPlanStep]:
147
156
 
148
157
  try:
149
- steps = json.loads(planner_output)
150
- except json.JSONDecodeError as ex:
158
+ return [ReWOOPlanStep(**step) for step in json.loads(planner_output)]
159
+ except Exception as ex:
151
160
  raise ValueError(f"The output of planner is invalid JSON format: {planner_output}") from ex
152
161
 
153
- return AIMessage(content=steps)
154
-
155
162
  @staticmethod
156
- def _parse_planner_dependencies(steps: List[Dict]) -> Tuple[Dict[str, Dict], List[List[str]]]:
163
+ def _parse_planner_dependencies(steps: list[ReWOOPlanStep]) -> tuple[dict[str, ReWOOPlanStep], list[list[str]]]:
157
164
  """
158
165
  Parse planner steps to identify dependencies and create execution levels for parallel processing.
159
166
  This creates a dependency map and identifies which evidence placeholders can be executed in parallel.
160
167
 
161
- :param steps: List of plan steps from the planner.
162
- :type steps: List[Dict]
163
- :return: A mapping from evidence placeholders to step info and execution levels for parallel processing.
164
- :rtype: Tuple[Dict[str, Dict], List[List[str]]]
165
- """
166
- evidences = {}
167
- dependence = {}
168
+ Args:
169
+ steps: list of plan steps from the planner.
168
170
 
171
+ Returns:
172
+ A mapping from evidence placeholders to step info and execution levels for parallel processing.
173
+ """
169
174
  # First pass: collect all evidence placeholders and their info
170
- for step in steps:
171
- if "evidence" not in step:
172
- continue
173
-
174
- evidence_info = step["evidence"]
175
- placeholder = evidence_info.get("placeholder", "")
176
-
177
- if placeholder:
178
- # Store the complete step info for this evidence
179
- evidences[placeholder] = {"plan": step.get("plan", ""), "evidence": evidence_info}
175
+ evidences: dict[str, ReWOOPlanStep] = {
176
+ step.evidence.placeholder: step
177
+ for step in steps if step.evidence and step.evidence.placeholder
178
+ }
180
179
 
181
180
  # Second pass: find dependencies now that we have all placeholders
182
- for step in steps:
183
- if "evidence" not in step:
184
- continue
185
-
186
- evidence_info = step["evidence"]
187
- placeholder = evidence_info.get("placeholder", "")
188
- tool_input = evidence_info.get("tool_input", "")
189
-
190
- if placeholder:
191
- # Find dependencies by looking for other placeholders in tool_input
192
- dependence[placeholder] = []
193
-
194
- # Convert tool_input to string to search for placeholders
195
- tool_input_str = str(tool_input)
196
- for var in re.findall(r"#E\d+", tool_input_str):
197
- if var in evidences and var != placeholder:
198
- dependence[placeholder].append(var)
181
+ dependencies = {
182
+ step.evidence.placeholder: [
183
+ var for var in re.findall(r"#E\d+", str(step.evidence.tool_input))
184
+ if var in evidences and var != step.evidence.placeholder
185
+ ]
186
+ for step in steps if step.evidence and step.evidence.placeholder
187
+ }
199
188
 
200
189
  # Create execution levels using topological sort
201
- levels = []
202
- remaining = dict(dependence)
190
+ levels: list[list[str]] = []
191
+ remaining = dict(dependencies)
203
192
 
204
193
  while remaining:
205
194
  # Find items with no dependencies (can be executed in parallel)
@@ -215,10 +204,8 @@ class ReWOOAgentGraph(BaseAgent):
215
204
  remaining.pop(placeholder)
216
205
 
217
206
  # Remove completed items from other dependencies
218
- # for placeholder in remaining.items():
219
- # remaining[placeholder] = [dep for dep in remaining[placeholder] if dep not in ready]
220
207
  for ph, deps in list(remaining.items()):
221
- remaining[ph] = [dep for dep in deps if dep not in ready]
208
+ remaining[ph] = list(set(deps) - set(ready))
222
209
  return evidences, levels
223
210
 
224
211
  @staticmethod
@@ -239,6 +226,7 @@ class ReWOOAgentGraph(BaseAgent):
239
226
 
240
227
  else:
241
228
  assert False, f"Unexpected type for tool_input: {type(tool_input)}"
229
+
242
230
  return tool_input
243
231
 
244
232
  @staticmethod
@@ -293,7 +281,7 @@ class ReWOOAgentGraph(BaseAgent):
293
281
  steps = self._parse_planner_output(str(plan.content))
294
282
 
295
283
  # Parse dependencies and create execution levels for parallel processing
296
- evidence_map, execution_levels = self._parse_planner_dependencies(steps.content)
284
+ evidence_map, execution_levels = self._parse_planner_dependencies(steps)
297
285
 
298
286
  if self.detailed_logs:
299
287
  agent_response_log_message = AGENT_CALL_LOG_MESSAGE % (task, str(plan.content))
@@ -302,10 +290,9 @@ class ReWOOAgentGraph(BaseAgent):
302
290
 
303
291
  return {
304
292
  "plan": plan,
305
- "steps": steps,
306
293
  "evidence_map": evidence_map,
307
294
  "execution_levels": execution_levels,
308
- "current_level": 0
295
+ "current_level": 0,
309
296
  }
310
297
 
311
298
  except Exception as ex:
@@ -339,7 +326,7 @@ class ReWOOAgentGraph(BaseAgent):
339
326
  current_level_placeholders = state.execution_levels[current_level]
340
327
 
341
328
  # Filter to only placeholders not yet completed
342
- pending_placeholders = [p for p in current_level_placeholders if p not in state.intermediate_results]
329
+ pending_placeholders = list(set(current_level_placeholders) - set(state.intermediate_results.keys()))
343
330
 
344
331
  if not pending_placeholders:
345
332
  # All placeholders in this level are done, move to next level
@@ -365,10 +352,8 @@ class ReWOOAgentGraph(BaseAgent):
365
352
  # Process results and update intermediate_results
366
353
  updated_intermediate_results = dict(state.intermediate_results)
367
354
 
368
- for i, result in enumerate(results):
369
- placeholder = pending_placeholders[i]
370
-
371
- if isinstance(result, Exception):
355
+ for placeholder, result in zip(pending_placeholders, results):
356
+ if isinstance(result, BaseException):
372
357
  logger.error("%s Tool execution failed for %s: %s", AGENT_LOG_PREFIX, placeholder, result)
373
358
  # Create error tool message
374
359
  error_message = f"Tool execution failed: {str(result)}"
@@ -398,29 +383,32 @@ class ReWOOAgentGraph(BaseAgent):
398
383
 
399
384
  async def _execute_single_tool(self,
400
385
  placeholder: str,
401
- step_info: Dict,
402
- intermediate_results: Dict[str, ToolMessage]) -> ToolMessage:
386
+ step_info: ReWOOPlanStep,
387
+ intermediate_results: dict[str, ToolMessage]) -> ToolMessage:
403
388
  """
404
389
  Execute a single tool with proper placeholder replacement.
405
390
 
406
- :param placeholder: The evidence placeholder (e.g., "#E1").
407
- :param step_info: Step information containing tool and tool_input.
408
- :param intermediate_results: Current intermediate results for placeholder replacement.
409
- :return: ToolMessage with the tool execution result.
391
+ Args:
392
+ placeholder: The evidence placeholder (e.g., "#E1").
393
+ step_info: Step information containing tool and tool_input.
394
+ intermediate_results: Current intermediate results for placeholder replacement.
395
+
396
+ Returns:
397
+ ToolMessage with the tool execution result.
410
398
  """
411
- evidence_info = step_info["evidence"]
412
- tool_name = evidence_info.get("tool", "")
413
- tool_input = evidence_info.get("tool_input", "")
399
+ evidence_info = step_info.evidence
400
+ tool_name = evidence_info.tool
401
+ tool_input = evidence_info.tool_input
414
402
 
415
403
  # Replace placeholders in tool input with previous results
416
- for _placeholder, _tool_output in intermediate_results.items():
417
- _tool_output_content = _tool_output.content
404
+ for ph_key, tool_output in intermediate_results.items():
405
+ tool_output_content = tool_output.content
418
406
  # If the content is a list, get the first element which should be a dict
419
- if isinstance(_tool_output_content, list):
420
- _tool_output_content = _tool_output_content[0]
421
- assert isinstance(_tool_output_content, dict)
407
+ if isinstance(tool_output_content, list):
408
+ tool_output_content = tool_output_content[0]
409
+ assert isinstance(tool_output_content, dict)
422
410
 
423
- tool_input = self._replace_placeholder(_placeholder, tool_input, _tool_output_content)
411
+ tool_input = self._replace_placeholder(ph_key, tool_input, tool_output_content)
424
412
 
425
413
  # Get the requested tool
426
414
  requested_tool = self._get_tool(tool_name)
@@ -442,10 +430,11 @@ class ReWOOAgentGraph(BaseAgent):
442
430
 
443
431
  # Parse and execute the tool
444
432
  tool_input_parsed = self._parse_tool_input(tool_input)
445
- tool_response = await self._call_tool(requested_tool,
446
- tool_input_parsed,
447
- RunnableConfig(callbacks=self.callbacks),
448
- max_retries=self.tool_call_max_retries)
433
+ tool_response = await self._call_tool(
434
+ requested_tool,
435
+ tool_input_parsed,
436
+ RunnableConfig(callbacks=self.callbacks), # type: ignore
437
+ max_retries=self.tool_call_max_retries)
449
438
 
450
439
  if self.detailed_logs:
451
440
  self._log_tool_response(requested_tool.name, tool_input_parsed, str(tool_response))
@@ -459,20 +448,20 @@ class ReWOOAgentGraph(BaseAgent):
459
448
  plan = ""
460
449
  # Add the tool outputs of each step to the plan using evidence_map
461
450
  for placeholder, step_info in state.evidence_map.items():
462
- evidence_info = step_info["evidence"]
463
- original_tool_input = evidence_info.get("tool_input", "")
464
- tool_name = evidence_info.get("tool", "")
451
+ evidence_info = step_info.evidence
452
+ original_tool_input = evidence_info.tool_input
453
+ tool_name = evidence_info.tool
465
454
 
466
455
  # Replace placeholders in tool input with actual results
467
456
  final_tool_input = original_tool_input
468
- for _placeholder, _tool_output in state.intermediate_results.items():
469
- _tool_output_content = _tool_output.content
457
+ for ph_key, tool_output in state.intermediate_results.items():
458
+ tool_output_content = tool_output.content
470
459
  # If the content is a list, get the first element which should be a dict
471
- if isinstance(_tool_output_content, list):
472
- _tool_output_content = _tool_output_content[0]
473
- assert isinstance(_tool_output_content, dict)
460
+ if isinstance(tool_output_content, list):
461
+ tool_output_content = tool_output_content[0]
462
+ assert isinstance(tool_output_content, dict)
474
463
 
475
- final_tool_input = self._replace_placeholder(_placeholder, final_tool_input, _tool_output_content)
464
+ final_tool_input = self._replace_placeholder(ph_key, final_tool_input, tool_output_content)
476
465
 
477
466
  # Get the final result for this placeholder
478
467
  final_result = ""
@@ -482,14 +471,15 @@ class ReWOOAgentGraph(BaseAgent):
482
471
  result_content = result_content[0]
483
472
  if isinstance(result_content, dict):
484
473
  final_result = str(result_content)
485
- else:
486
- final_result = str(result_content)
487
474
  else:
488
475
  final_result = str(result_content)
489
476
 
490
- step_plan = step_info.get("plan", "")
491
- plan += f"Plan: {step_plan}\n{placeholder} = \
492
- {tool_name}[{final_tool_input}]\nResult: {final_result}\n\n"
477
+ step_plan = step_info.plan
478
+ plan += '\n'.join([
479
+ f"Plan: {step_plan}",
480
+ f"{placeholder} = {tool_name}[{final_tool_input}",
481
+ f"Result: {final_result}\n\n"
482
+ ])
493
483
 
494
484
  task = str(state.task.content)
495
485
  solver_prompt = self.solver_prompt.partial(plan=plan)
@@ -547,8 +537,10 @@ class ReWOOAgentGraph(BaseAgent):
547
537
  graph.add_node("solver", self.solver_node)
548
538
 
549
539
  graph.add_edge("planner", "executor")
550
- conditional_edge_possible_outputs = {AgentDecision.TOOL: "executor", AgentDecision.END: "solver"}
551
- graph.add_conditional_edges("executor", self.conditional_edge, conditional_edge_possible_outputs)
540
+ graph.add_conditional_edges("executor",
541
+ self.conditional_edge, {
542
+ AgentDecision.TOOL: "executor", AgentDecision.END: "solver"
543
+ })
552
544
 
553
545
  graph.set_entry_point("planner")
554
546
  graph.set_finish_point("solver")
nat/agent/rewoo_agent/register.py CHANGED
@@ -71,8 +71,8 @@ class ReWOOAgentWorkflowConfig(AgentBaseConfig, name="rewoo_agent"):
71
71
 
72
72
  @register_function(config_type=ReWOOAgentWorkflowConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
73
73
  async def rewoo_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builder):
74
- from langchain.schema import BaseMessage
75
74
  from langchain_core.messages import trim_messages
75
+ from langchain_core.messages.base import BaseMessage
76
76
  from langchain_core.messages.human import HumanMessage
77
77
  from langchain_core.prompts import ChatPromptTemplate
78
78
  from langgraph.graph.state import CompiledStateGraph
@@ -154,6 +154,9 @@ async def rewoo_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builde
154
154
  # get and return the output from the state
155
155
  state = ReWOOGraphState(**state)
156
156
  output_message = state.result.content
157
+ # Ensure output_message is a string
158
+ if isinstance(output_message, list | dict):
159
+ output_message = str(output_message)
157
160
  return ChatResponse.from_string(output_message)
158
161
 
159
162
  except Exception as ex:
nat/authentication/oauth2/oauth2_auth_code_flow_provider.py CHANGED
@@ -13,10 +13,12 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import logging
17
+ from collections.abc import Callable
18
+ from datetime import UTC
16
19
  from datetime import datetime
17
- from datetime import timezone
18
- from typing import Callable
19
20
 
21
+ import httpx
20
22
  from authlib.integrations.httpx_client import OAuth2Client as AuthlibOAuth2Client
21
23
  from pydantic import SecretStr
22
24
 
@@ -28,6 +30,8 @@ from nat.data_models.authentication import AuthFlowType
28
30
  from nat.data_models.authentication import AuthResult
29
31
  from nat.data_models.authentication import BearerTokenCred
30
32
 
33
+ logger = logging.getLogger(__name__)
34
+
31
35
 
32
36
  class OAuth2AuthCodeFlowProvider(AuthProviderBase[OAuth2AuthCodeFlowProviderConfig]):
33
37
 
@@ -41,26 +45,30 @@ class OAuth2AuthCodeFlowProvider(AuthProviderBase[OAuth2AuthCodeFlowProviderConf
41
45
  if not isinstance(refresh_token, str):
42
46
  return None
43
47
 
44
- with AuthlibOAuth2Client(
45
- client_id=self.config.client_id,
46
- client_secret=self.config.client_secret,
47
- ) as client:
48
- try:
48
+ try:
49
+ with AuthlibOAuth2Client(
50
+ client_id=self.config.client_id,
51
+ client_secret=self.config.client_secret,
52
+ ) as client:
49
53
  new_token_data = client.refresh_token(self.config.token_url, refresh_token=refresh_token)
50
- except Exception:
51
- # On any failure, we'll fall back to the full auth flow.
52
- return None
53
54
 
54
- expires_at_ts = new_token_data.get("expires_at")
55
- new_expires_at = datetime.fromtimestamp(expires_at_ts, tz=timezone.utc) if expires_at_ts else None
55
+ expires_at_ts = new_token_data.get("expires_at")
56
+ new_expires_at = datetime.fromtimestamp(expires_at_ts, tz=UTC) if expires_at_ts else None
56
57
 
57
- new_auth_result = AuthResult(
58
- credentials=[BearerTokenCred(token=SecretStr(new_token_data["access_token"]))],
59
- token_expires_at=new_expires_at,
60
- raw=new_token_data,
61
- )
58
+ new_auth_result = AuthResult(
59
+ credentials=[BearerTokenCred(token=SecretStr(new_token_data["access_token"]))],
60
+ token_expires_at=new_expires_at,
61
+ raw=new_token_data,
62
+ )
62
63
 
63
- self._authenticated_tokens[user_id] = new_auth_result
64
+ self._authenticated_tokens[user_id] = new_auth_result
65
+ except httpx.HTTPStatusError:
66
+ return None
67
+ except httpx.RequestError:
68
+ return None
69
+ except Exception:
70
+ # On any other failure, we'll fall back to the full auth flow.
71
+ return None
64
72
 
65
73
  return new_auth_result
66
74
 
nat/builder/builder.py CHANGED
@@ -56,7 +56,7 @@ if typing.TYPE_CHECKING:
56
56
  from nat.experimental.test_time_compute.models.strategy_base import StrategyBase
57
57
 
58
58
 
59
- class UserManagerHolder():
59
+ class UserManagerHolder:
60
60
 
61
61
  def __init__(self, context: Context) -> None:
62
62
  self._context = context
nat/builder/context.py CHANGED
@@ -40,12 +40,12 @@ from nat.utils.reactive.subject import Subject
40
40
  class Singleton(type):
41
41
 
42
42
  def __init__(cls, name, bases, dict):
43
- super(Singleton, cls).__init__(name, bases, dict)
43
+ super().__init__(name, bases, dict)
44
44
  cls.instance = None
45
45
 
46
46
  def __call__(cls, *args, **kw):
47
47
  if cls.instance is None:
48
- cls.instance = super(Singleton, cls).__call__(*args, **kw)
48
+ cls.instance = super().__call__(*args, **kw)
49
49
  return cls.instance
50
50
 
51
51
 
nat/builder/front_end.py CHANGED
@@ -37,7 +37,7 @@ class FrontEndBase(typing.Generic[FrontEndConfigT], ABC):
37
37
 
38
38
  super().__init__()
39
39
 
40
- self._full_config: "Config" = full_config
40
+ self._full_config: Config = full_config
41
41
  self._front_end_config: FrontEndConfigT = typing.cast(FrontEndConfigT, full_config.general.front_end)
42
42
 
43
43
  @property
nat/cli/cli_utils/config_override.py CHANGED
@@ -84,7 +84,7 @@ class LayeredConfig:
84
84
  if lower_value not in ['true', 'false']:
85
85
  raise ValueError(f"Boolean value must be 'true' or 'false', got '{value}'")
86
86
  value = lower_value == 'true'
87
- elif isinstance(original_value, (int, float)):
87
+ elif isinstance(original_value, int | float):
88
88
  value = type(original_value)(value)
89
89
  elif isinstance(original_value, list):
90
90
  value = [v.strip() for v in value.split(',')]
nat/cli/commands/mcp/mcp.py CHANGED
@@ -297,7 +297,7 @@ async def list_tools_via_function_group(
297
297
  if fn is not None:
298
298
  tools.append(to_tool_entry(full, fn))
299
299
  else:
300
- for full, fn in fns.items():
300
+ for full, fn in (await fns).items():
301
301
  tools.append(to_tool_entry(full, fn))
302
302
 
303
303
  return tools
@@ -443,7 +443,7 @@ async def ping_mcp_server(url: str,
443
443
  # Apply timeout to the entire ping operation
444
444
  return await asyncio.wait_for(_ping_operation(), timeout=timeout)
445
445
 
446
- except asyncio.TimeoutError:
446
+ except TimeoutError:
447
447
  return MCPPingResult(url=url,
448
448
  status="unhealthy",
449
449
  response_time_ms=None,
nat/cli/commands/start.py CHANGED
@@ -111,7 +111,7 @@ class StartCommandGroup(click.Group):
111
111
  elif (issubclass(decomposed_type.root, Path)):
112
112
  param_type = click.Path(exists=True, file_okay=True, dir_okay=False, path_type=Path)
113
113
 
114
- elif (issubclass(decomposed_type.root, (list, tuple, set))):
114
+ elif (issubclass(decomposed_type.root, list | tuple | set)):
115
115
  if (len(decomposed_type.args) == 1):
116
116
  inner = DecomposedType(decomposed_type.args[0])
117
117
  # Support containers of Literal values -> multiple Choice
nat/cli/type_registry.py CHANGED
@@ -992,7 +992,7 @@ class TypeRegistry:
992
992
  if (short_names[key.local_name] == 1):
993
993
  type_list.append((key.local_name, key.config_type))
994
994
 
995
- return typing.Union[tuple(typing.Annotated[x_type, Tag(x_id)] for x_id, x_type in type_list)]
995
+ return typing.Union[*tuple(typing.Annotated[x_type, Tag(x_id)] for x_id, x_type in type_list)]
996
996
 
997
997
  def compute_annotation(self, cls: type[TypedBaseModelT]):
998
998
 
nat/control_flow/router_agent/register.py CHANGED
@@ -81,7 +81,7 @@ async def router_agent_workflow(config: RouterAgentWorkflowConfig, builder: Buil
81
81
  logger.exception("%s Router Agent failed with exception: %s", AGENT_LOG_PREFIX, ex)
82
82
  if config.verbose:
83
83
  return str(ex)
84
- return "Router agent failed with exception: %s" % ex
84
+ return f"Router agent failed with exception: {ex}"
85
85
 
86
86
  try:
87
87
  yield FunctionInfo.from_fn(_response_fn, description=config.description)
nat/data_models/api_server.py CHANGED
@@ -273,7 +273,7 @@ class ChatResponse(ResponseBaseModelOutput):
273
273
  if model is None:
274
274
  model = ""
275
275
  if created is None:
276
- created = datetime.datetime.now(datetime.timezone.utc)
276
+ created = datetime.datetime.now(datetime.UTC)
277
277
 
278
278
  return ChatResponse(id=id_,
279
279
  object=object_,
@@ -317,7 +317,7 @@ class ChatResponseChunk(ResponseBaseModelOutput):
317
317
  if id_ is None:
318
318
  id_ = str(uuid.uuid4())
319
319
  if created is None:
320
- created = datetime.datetime.now(datetime.timezone.utc)
320
+ created = datetime.datetime.now(datetime.UTC)
321
321
  if model is None:
322
322
  model = ""
323
323
  if object_ is None:
@@ -343,7 +343,7 @@ class ChatResponseChunk(ResponseBaseModelOutput):
343
343
  if id_ is None:
344
344
  id_ = str(uuid.uuid4())
345
345
  if created is None:
346
- created = datetime.datetime.now(datetime.timezone.utc)
346
+ created = datetime.datetime.now(datetime.UTC)
347
347
  if model is None:
348
348
  model = ""
349
349
 
@@ -485,7 +485,7 @@ class WebSocketUserMessage(BaseModel):
485
485
  security: Security = Security()
486
486
  error: Error = Error()
487
487
  schema_version: str = "1.0.0"
488
- timestamp: str = str(datetime.datetime.now(datetime.timezone.utc))
488
+ timestamp: str = str(datetime.datetime.now(datetime.UTC))
489
489
 
490
490
 
491
491
  class WebSocketUserInteractionResponseMessage(BaseModel):
@@ -501,7 +501,7 @@ class WebSocketUserInteractionResponseMessage(BaseModel):
501
501
  security: Security = Security()
502
502
  error: Error = Error()
503
503
  schema_version: str = "1.0.0"
504
- timestamp: str = str(datetime.datetime.now(datetime.timezone.utc))
504
+ timestamp: str = str(datetime.datetime.now(datetime.UTC))
505
505
 
506
506
 
507
507
  class SystemIntermediateStepContent(BaseModel):
@@ -527,7 +527,7 @@ class WebSocketSystemIntermediateStepMessage(BaseModel):
527
527
  conversation_id: str | None = None
528
528
  content: SystemIntermediateStepContent
529
529
  status: WebSocketMessageStatus
530
- timestamp: str = str(datetime.datetime.now(datetime.timezone.utc))
530
+ timestamp: str = str(datetime.datetime.now(datetime.UTC))
531
531
 
532
532
 
533
533
  class SystemResponseContent(BaseModel):
@@ -551,7 +551,7 @@ class WebSocketSystemResponseTokenMessage(BaseModel):
551
551
  conversation_id: str | None = None
552
552
  content: SystemResponseContent | Error | GenerateResponse
553
553
  status: WebSocketMessageStatus
554
- timestamp: str = str(datetime.datetime.now(datetime.timezone.utc))
554
+ timestamp: str = str(datetime.datetime.now(datetime.UTC))
555
555
 
556
556
  @field_validator("content")
557
557
  @classmethod
@@ -560,7 +560,7 @@ class WebSocketSystemResponseTokenMessage(BaseModel):
560
560
  raise ValueError(f"Field: content must be 'Error' when type is {WebSocketMessageType.ERROR_MESSAGE}")
561
561
 
562
562
  if info.data.get("type") == WebSocketMessageType.RESPONSE_MESSAGE and not isinstance(
563
- value, (SystemResponseContent, GenerateResponse)):
563
+ value, SystemResponseContent | GenerateResponse):
564
564
  raise ValueError(
565
565
  f"Field: content must be 'SystemResponseContent' when type is {WebSocketMessageType.RESPONSE_MESSAGE}")
566
566
  return value
@@ -582,7 +582,7 @@ class WebSocketSystemInteractionMessage(BaseModel):
582
582
  conversation_id: str | None = None
583
583
  content: HumanPrompt
584
584
  status: WebSocketMessageStatus
585
- timestamp: str = str(datetime.datetime.now(datetime.timezone.utc))
585
+ timestamp: str = str(datetime.datetime.now(datetime.UTC))
586
586
 
587
587
 
588
588
  # ======== GenerateResponse Converters ========