nvidia-nat 1.3.0a20250929__py3-none-any.whl → 1.3.0a20251001__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/base.py +1 -1
- nat/agent/rewoo_agent/agent.py +100 -108
- nat/agent/rewoo_agent/register.py +4 -1
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +26 -18
- nat/builder/builder.py +1 -1
- nat/builder/context.py +2 -2
- nat/builder/front_end.py +1 -1
- nat/cli/cli_utils/config_override.py +1 -1
- nat/cli/commands/mcp/mcp.py +2 -2
- nat/cli/commands/start.py +1 -1
- nat/cli/type_registry.py +1 -1
- nat/control_flow/router_agent/register.py +1 -1
- nat/data_models/api_server.py +9 -9
- nat/data_models/authentication.py +3 -9
- nat/data_models/dataset_handler.py +1 -1
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/swe_bench_evaluator/evaluate.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +1 -1
- nat/experimental/decorators/experimental_warning_decorator.py +1 -2
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -1
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +1 -1
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +188 -2
- nat/front_ends/fastapi/job_store.py +2 -2
- nat/front_ends/fastapi/message_handler.py +4 -4
- nat/front_ends/fastapi/message_validator.py +5 -5
- nat/front_ends/mcp/tool_converter.py +1 -1
- nat/llm/utils/thinking.py +1 -1
- nat/observability/exporter/base_exporter.py +1 -1
- nat/observability/exporter/span_exporter.py +1 -1
- nat/observability/exporter_manager.py +2 -2
- nat/observability/processor/batching_processor.py +1 -1
- nat/profiler/decorators/function_tracking.py +2 -2
- nat/profiler/parameter_optimization/parameter_selection.py +3 -4
- nat/profiler/parameter_optimization/pareto_visualizer.py +1 -1
- nat/retriever/milvus/retriever.py +1 -1
- nat/settings/global_settings.py +2 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +1 -1
- nat/tool/datetime_tools.py +1 -1
- nat/utils/data_models/schema_validator.py +1 -1
- nat/utils/exception_handlers/automatic_retries.py +1 -1
- nat/utils/io/yaml_tools.py +1 -1
- nat/utils/type_utils.py +1 -1
- {nvidia_nat-1.3.0a20250929.dist-info → nvidia_nat-1.3.0a20251001.dist-info}/METADATA +2 -1
- {nvidia_nat-1.3.0a20250929.dist-info → nvidia_nat-1.3.0a20251001.dist-info}/RECORD +51 -51
- {nvidia_nat-1.3.0a20250929.dist-info → nvidia_nat-1.3.0a20251001.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0a20250929.dist-info → nvidia_nat-1.3.0a20251001.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.3.0a20250929.dist-info → nvidia_nat-1.3.0a20251001.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0a20250929.dist-info → nvidia_nat-1.3.0a20251001.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0a20250929.dist-info → nvidia_nat-1.3.0a20251001.dist-info}/top_level.txt +0 -0
nat/agent/base.py
CHANGED
|
@@ -192,7 +192,7 @@ class BaseAgent(ABC):
|
|
|
192
192
|
await asyncio.sleep(sleep_time)
|
|
193
193
|
|
|
194
194
|
# All retries exhausted, return error message
|
|
195
|
-
error_content = "Tool call failed after all retry attempts. Last error:
|
|
195
|
+
error_content = f"Tool call failed after all retry attempts. Last error: {str(last_exception)}"
|
|
196
196
|
logger.error("%s %s", AGENT_LOG_PREFIX, error_content, exc_info=True)
|
|
197
197
|
return ToolMessage(name=tool.name, tool_call_id=tool.name, content=error_content, status="error")
|
|
198
198
|
|
nat/agent/rewoo_agent/agent.py
CHANGED
|
@@ -18,9 +18,7 @@ import json
|
|
|
18
18
|
import logging
|
|
19
19
|
import re
|
|
20
20
|
from json import JSONDecodeError
|
|
21
|
-
from typing import
|
|
22
|
-
from typing import List
|
|
23
|
-
from typing import Tuple
|
|
21
|
+
from typing import Any
|
|
24
22
|
|
|
25
23
|
from langchain_core.callbacks.base import AsyncCallbackHandler
|
|
26
24
|
from langchain_core.language_models import BaseChatModel
|
|
@@ -47,6 +45,17 @@ from nat.agent.base import BaseAgent
|
|
|
47
45
|
logger = logging.getLogger(__name__)
|
|
48
46
|
|
|
49
47
|
|
|
48
|
+
class ReWOOEvidence(BaseModel):
|
|
49
|
+
placeholder: str
|
|
50
|
+
tool: str
|
|
51
|
+
tool_input: Any
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class ReWOOPlanStep(BaseModel):
|
|
55
|
+
plan: str
|
|
56
|
+
evidence: ReWOOEvidence
|
|
57
|
+
|
|
58
|
+
|
|
50
59
|
class ReWOOGraphState(BaseModel):
|
|
51
60
|
"""State schema for the ReWOO Agent Graph"""
|
|
52
61
|
messages: list[BaseMessage] = Field(default_factory=list) # input and output of the ReWOO Agent
|
|
@@ -56,8 +65,8 @@ class ReWOOGraphState(BaseModel):
|
|
|
56
65
|
steps: AIMessage = Field(
|
|
57
66
|
default_factory=lambda: AIMessage(content="")) # the steps to solve the task, parsed from the plan
|
|
58
67
|
# New fields for parallel execution support
|
|
59
|
-
evidence_map:
|
|
60
|
-
execution_levels:
|
|
68
|
+
evidence_map: dict[str, ReWOOPlanStep] = Field(default_factory=dict) # mapping from placeholders to step info
|
|
69
|
+
execution_levels: list[list[str]] = Field(default_factory=list) # levels for parallel execution
|
|
61
70
|
current_level: int = Field(default=0) # current execution level
|
|
62
71
|
intermediate_results: dict[str, ToolMessage] = Field(default_factory=dict) # the intermediate results of each step
|
|
63
72
|
result: AIMessage = Field(
|
|
@@ -91,18 +100,15 @@ class ReWOOAgentGraph(BaseAgent):
|
|
|
91
100
|
logger.debug(
|
|
92
101
|
"%s Filling the prompt variables 'tools' and 'tool_names', using the tools provided in the config.",
|
|
93
102
|
AGENT_LOG_PREFIX)
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
for tool in tools[:-1]
|
|
104
|
-
]) + "\n" + (f"{tools[-1].name}: {tools[-1].description}. "
|
|
105
|
-
f"{INPUT_SCHEMA_MESSAGE.format(schema=tools[-1].input_schema.model_fields)}")
|
|
103
|
+
|
|
104
|
+
def describe_tool(tool: BaseTool) -> str:
|
|
105
|
+
description = f"{tool.name}: {tool.description}"
|
|
106
|
+
if use_tool_schema:
|
|
107
|
+
description += f". {INPUT_SCHEMA_MESSAGE.format(schema=tool.input_schema.model_fields)}"
|
|
108
|
+
return description
|
|
109
|
+
|
|
110
|
+
tool_names = ",".join(tool.name for tool in tools)
|
|
111
|
+
tool_names_and_descriptions = "\n".join(describe_tool(tool) for tool in tools)
|
|
106
112
|
|
|
107
113
|
self.planner_prompt = planner_prompt.partial(tools=tool_names_and_descriptions, tool_names=tool_names)
|
|
108
114
|
self.solver_prompt = solver_prompt
|
|
@@ -123,9 +129,12 @@ class ReWOOAgentGraph(BaseAgent):
|
|
|
123
129
|
def _get_current_level_status(state: ReWOOGraphState) -> tuple[int, bool]:
|
|
124
130
|
"""
|
|
125
131
|
Get the current execution level and whether it's complete.
|
|
126
|
-
|
|
127
|
-
:
|
|
128
|
-
|
|
132
|
+
|
|
133
|
+
Args:
|
|
134
|
+
state: The ReWOO graph state.
|
|
135
|
+
|
|
136
|
+
Returns:
|
|
137
|
+
tuple of (current_level, is_complete). Level -1 means all execution is complete.
|
|
129
138
|
"""
|
|
130
139
|
if not state.execution_levels:
|
|
131
140
|
return -1, True
|
|
@@ -143,63 +152,43 @@ class ReWOOAgentGraph(BaseAgent):
|
|
|
143
152
|
return current_level, level_complete
|
|
144
153
|
|
|
145
154
|
@staticmethod
|
|
146
|
-
def _parse_planner_output(planner_output: str) ->
|
|
155
|
+
def _parse_planner_output(planner_output: str) -> list[ReWOOPlanStep]:
|
|
147
156
|
|
|
148
157
|
try:
|
|
149
|
-
|
|
150
|
-
except
|
|
158
|
+
return [ReWOOPlanStep(**step) for step in json.loads(planner_output)]
|
|
159
|
+
except Exception as ex:
|
|
151
160
|
raise ValueError(f"The output of planner is invalid JSON format: {planner_output}") from ex
|
|
152
161
|
|
|
153
|
-
return AIMessage(content=steps)
|
|
154
|
-
|
|
155
162
|
@staticmethod
|
|
156
|
-
def _parse_planner_dependencies(steps:
|
|
163
|
+
def _parse_planner_dependencies(steps: list[ReWOOPlanStep]) -> tuple[dict[str, ReWOOPlanStep], list[list[str]]]:
|
|
157
164
|
"""
|
|
158
165
|
Parse planner steps to identify dependencies and create execution levels for parallel processing.
|
|
159
166
|
This creates a dependency map and identifies which evidence placeholders can be executed in parallel.
|
|
160
167
|
|
|
161
|
-
:
|
|
162
|
-
|
|
163
|
-
:return: A mapping from evidence placeholders to step info and execution levels for parallel processing.
|
|
164
|
-
:rtype: Tuple[Dict[str, Dict], List[List[str]]]
|
|
165
|
-
"""
|
|
166
|
-
evidences = {}
|
|
167
|
-
dependence = {}
|
|
168
|
+
Args:
|
|
169
|
+
steps: list of plan steps from the planner.
|
|
168
170
|
|
|
171
|
+
Returns:
|
|
172
|
+
A mapping from evidence placeholders to step info and execution levels for parallel processing.
|
|
173
|
+
"""
|
|
169
174
|
# First pass: collect all evidence placeholders and their info
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
evidence_info = step["evidence"]
|
|
175
|
-
placeholder = evidence_info.get("placeholder", "")
|
|
176
|
-
|
|
177
|
-
if placeholder:
|
|
178
|
-
# Store the complete step info for this evidence
|
|
179
|
-
evidences[placeholder] = {"plan": step.get("plan", ""), "evidence": evidence_info}
|
|
175
|
+
evidences: dict[str, ReWOOPlanStep] = {
|
|
176
|
+
step.evidence.placeholder: step
|
|
177
|
+
for step in steps if step.evidence and step.evidence.placeholder
|
|
178
|
+
}
|
|
180
179
|
|
|
181
180
|
# Second pass: find dependencies now that we have all placeholders
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
if placeholder:
|
|
191
|
-
# Find dependencies by looking for other placeholders in tool_input
|
|
192
|
-
dependence[placeholder] = []
|
|
193
|
-
|
|
194
|
-
# Convert tool_input to string to search for placeholders
|
|
195
|
-
tool_input_str = str(tool_input)
|
|
196
|
-
for var in re.findall(r"#E\d+", tool_input_str):
|
|
197
|
-
if var in evidences and var != placeholder:
|
|
198
|
-
dependence[placeholder].append(var)
|
|
181
|
+
dependencies = {
|
|
182
|
+
step.evidence.placeholder: [
|
|
183
|
+
var for var in re.findall(r"#E\d+", str(step.evidence.tool_input))
|
|
184
|
+
if var in evidences and var != step.evidence.placeholder
|
|
185
|
+
]
|
|
186
|
+
for step in steps if step.evidence and step.evidence.placeholder
|
|
187
|
+
}
|
|
199
188
|
|
|
200
189
|
# Create execution levels using topological sort
|
|
201
|
-
levels = []
|
|
202
|
-
remaining = dict(
|
|
190
|
+
levels: list[list[str]] = []
|
|
191
|
+
remaining = dict(dependencies)
|
|
203
192
|
|
|
204
193
|
while remaining:
|
|
205
194
|
# Find items with no dependencies (can be executed in parallel)
|
|
@@ -215,10 +204,8 @@ class ReWOOAgentGraph(BaseAgent):
|
|
|
215
204
|
remaining.pop(placeholder)
|
|
216
205
|
|
|
217
206
|
# Remove completed items from other dependencies
|
|
218
|
-
# for placeholder in remaining.items():
|
|
219
|
-
# remaining[placeholder] = [dep for dep in remaining[placeholder] if dep not in ready]
|
|
220
207
|
for ph, deps in list(remaining.items()):
|
|
221
|
-
remaining[ph] =
|
|
208
|
+
remaining[ph] = list(set(deps) - set(ready))
|
|
222
209
|
return evidences, levels
|
|
223
210
|
|
|
224
211
|
@staticmethod
|
|
@@ -239,6 +226,7 @@ class ReWOOAgentGraph(BaseAgent):
|
|
|
239
226
|
|
|
240
227
|
else:
|
|
241
228
|
assert False, f"Unexpected type for tool_input: {type(tool_input)}"
|
|
229
|
+
|
|
242
230
|
return tool_input
|
|
243
231
|
|
|
244
232
|
@staticmethod
|
|
@@ -293,7 +281,7 @@ class ReWOOAgentGraph(BaseAgent):
|
|
|
293
281
|
steps = self._parse_planner_output(str(plan.content))
|
|
294
282
|
|
|
295
283
|
# Parse dependencies and create execution levels for parallel processing
|
|
296
|
-
evidence_map, execution_levels = self._parse_planner_dependencies(steps
|
|
284
|
+
evidence_map, execution_levels = self._parse_planner_dependencies(steps)
|
|
297
285
|
|
|
298
286
|
if self.detailed_logs:
|
|
299
287
|
agent_response_log_message = AGENT_CALL_LOG_MESSAGE % (task, str(plan.content))
|
|
@@ -302,10 +290,9 @@ class ReWOOAgentGraph(BaseAgent):
|
|
|
302
290
|
|
|
303
291
|
return {
|
|
304
292
|
"plan": plan,
|
|
305
|
-
"steps": steps,
|
|
306
293
|
"evidence_map": evidence_map,
|
|
307
294
|
"execution_levels": execution_levels,
|
|
308
|
-
"current_level": 0
|
|
295
|
+
"current_level": 0,
|
|
309
296
|
}
|
|
310
297
|
|
|
311
298
|
except Exception as ex:
|
|
@@ -339,7 +326,7 @@ class ReWOOAgentGraph(BaseAgent):
|
|
|
339
326
|
current_level_placeholders = state.execution_levels[current_level]
|
|
340
327
|
|
|
341
328
|
# Filter to only placeholders not yet completed
|
|
342
|
-
pending_placeholders =
|
|
329
|
+
pending_placeholders = list(set(current_level_placeholders) - set(state.intermediate_results.keys()))
|
|
343
330
|
|
|
344
331
|
if not pending_placeholders:
|
|
345
332
|
# All placeholders in this level are done, move to next level
|
|
@@ -365,10 +352,8 @@ class ReWOOAgentGraph(BaseAgent):
|
|
|
365
352
|
# Process results and update intermediate_results
|
|
366
353
|
updated_intermediate_results = dict(state.intermediate_results)
|
|
367
354
|
|
|
368
|
-
for
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
if isinstance(result, Exception):
|
|
355
|
+
for placeholder, result in zip(pending_placeholders, results):
|
|
356
|
+
if isinstance(result, BaseException):
|
|
372
357
|
logger.error("%s Tool execution failed for %s: %s", AGENT_LOG_PREFIX, placeholder, result)
|
|
373
358
|
# Create error tool message
|
|
374
359
|
error_message = f"Tool execution failed: {str(result)}"
|
|
@@ -398,29 +383,32 @@ class ReWOOAgentGraph(BaseAgent):
|
|
|
398
383
|
|
|
399
384
|
async def _execute_single_tool(self,
|
|
400
385
|
placeholder: str,
|
|
401
|
-
step_info:
|
|
402
|
-
intermediate_results:
|
|
386
|
+
step_info: ReWOOPlanStep,
|
|
387
|
+
intermediate_results: dict[str, ToolMessage]) -> ToolMessage:
|
|
403
388
|
"""
|
|
404
389
|
Execute a single tool with proper placeholder replacement.
|
|
405
390
|
|
|
406
|
-
:
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
391
|
+
Args:
|
|
392
|
+
placeholder: The evidence placeholder (e.g., "#E1").
|
|
393
|
+
step_info: Step information containing tool and tool_input.
|
|
394
|
+
intermediate_results: Current intermediate results for placeholder replacement.
|
|
395
|
+
|
|
396
|
+
Returns:
|
|
397
|
+
ToolMessage with the tool execution result.
|
|
410
398
|
"""
|
|
411
|
-
evidence_info = step_info
|
|
412
|
-
tool_name = evidence_info.
|
|
413
|
-
tool_input = evidence_info.
|
|
399
|
+
evidence_info = step_info.evidence
|
|
400
|
+
tool_name = evidence_info.tool
|
|
401
|
+
tool_input = evidence_info.tool_input
|
|
414
402
|
|
|
415
403
|
# Replace placeholders in tool input with previous results
|
|
416
|
-
for
|
|
417
|
-
|
|
404
|
+
for ph_key, tool_output in intermediate_results.items():
|
|
405
|
+
tool_output_content = tool_output.content
|
|
418
406
|
# If the content is a list, get the first element which should be a dict
|
|
419
|
-
if isinstance(
|
|
420
|
-
|
|
421
|
-
assert isinstance(
|
|
407
|
+
if isinstance(tool_output_content, list):
|
|
408
|
+
tool_output_content = tool_output_content[0]
|
|
409
|
+
assert isinstance(tool_output_content, dict)
|
|
422
410
|
|
|
423
|
-
tool_input = self._replace_placeholder(
|
|
411
|
+
tool_input = self._replace_placeholder(ph_key, tool_input, tool_output_content)
|
|
424
412
|
|
|
425
413
|
# Get the requested tool
|
|
426
414
|
requested_tool = self._get_tool(tool_name)
|
|
@@ -442,10 +430,11 @@ class ReWOOAgentGraph(BaseAgent):
|
|
|
442
430
|
|
|
443
431
|
# Parse and execute the tool
|
|
444
432
|
tool_input_parsed = self._parse_tool_input(tool_input)
|
|
445
|
-
tool_response = await self._call_tool(
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
433
|
+
tool_response = await self._call_tool(
|
|
434
|
+
requested_tool,
|
|
435
|
+
tool_input_parsed,
|
|
436
|
+
RunnableConfig(callbacks=self.callbacks), # type: ignore
|
|
437
|
+
max_retries=self.tool_call_max_retries)
|
|
449
438
|
|
|
450
439
|
if self.detailed_logs:
|
|
451
440
|
self._log_tool_response(requested_tool.name, tool_input_parsed, str(tool_response))
|
|
@@ -459,20 +448,20 @@ class ReWOOAgentGraph(BaseAgent):
|
|
|
459
448
|
plan = ""
|
|
460
449
|
# Add the tool outputs of each step to the plan using evidence_map
|
|
461
450
|
for placeholder, step_info in state.evidence_map.items():
|
|
462
|
-
evidence_info = step_info
|
|
463
|
-
original_tool_input = evidence_info.
|
|
464
|
-
tool_name = evidence_info.
|
|
451
|
+
evidence_info = step_info.evidence
|
|
452
|
+
original_tool_input = evidence_info.tool_input
|
|
453
|
+
tool_name = evidence_info.tool
|
|
465
454
|
|
|
466
455
|
# Replace placeholders in tool input with actual results
|
|
467
456
|
final_tool_input = original_tool_input
|
|
468
|
-
for
|
|
469
|
-
|
|
457
|
+
for ph_key, tool_output in state.intermediate_results.items():
|
|
458
|
+
tool_output_content = tool_output.content
|
|
470
459
|
# If the content is a list, get the first element which should be a dict
|
|
471
|
-
if isinstance(
|
|
472
|
-
|
|
473
|
-
assert isinstance(
|
|
460
|
+
if isinstance(tool_output_content, list):
|
|
461
|
+
tool_output_content = tool_output_content[0]
|
|
462
|
+
assert isinstance(tool_output_content, dict)
|
|
474
463
|
|
|
475
|
-
final_tool_input = self._replace_placeholder(
|
|
464
|
+
final_tool_input = self._replace_placeholder(ph_key, final_tool_input, tool_output_content)
|
|
476
465
|
|
|
477
466
|
# Get the final result for this placeholder
|
|
478
467
|
final_result = ""
|
|
@@ -482,14 +471,15 @@ class ReWOOAgentGraph(BaseAgent):
|
|
|
482
471
|
result_content = result_content[0]
|
|
483
472
|
if isinstance(result_content, dict):
|
|
484
473
|
final_result = str(result_content)
|
|
485
|
-
else:
|
|
486
|
-
final_result = str(result_content)
|
|
487
474
|
else:
|
|
488
475
|
final_result = str(result_content)
|
|
489
476
|
|
|
490
|
-
step_plan = step_info.
|
|
491
|
-
plan +=
|
|
492
|
-
|
|
477
|
+
step_plan = step_info.plan
|
|
478
|
+
plan += '\n'.join([
|
|
479
|
+
f"Plan: {step_plan}",
|
|
480
|
+
f"{placeholder} = {tool_name}[{final_tool_input}",
|
|
481
|
+
f"Result: {final_result}\n\n"
|
|
482
|
+
])
|
|
493
483
|
|
|
494
484
|
task = str(state.task.content)
|
|
495
485
|
solver_prompt = self.solver_prompt.partial(plan=plan)
|
|
@@ -547,8 +537,10 @@ class ReWOOAgentGraph(BaseAgent):
|
|
|
547
537
|
graph.add_node("solver", self.solver_node)
|
|
548
538
|
|
|
549
539
|
graph.add_edge("planner", "executor")
|
|
550
|
-
|
|
551
|
-
|
|
540
|
+
graph.add_conditional_edges("executor",
|
|
541
|
+
self.conditional_edge, {
|
|
542
|
+
AgentDecision.TOOL: "executor", AgentDecision.END: "solver"
|
|
543
|
+
})
|
|
552
544
|
|
|
553
545
|
graph.set_entry_point("planner")
|
|
554
546
|
graph.set_finish_point("solver")
|
|
@@ -71,8 +71,8 @@ class ReWOOAgentWorkflowConfig(AgentBaseConfig, name="rewoo_agent"):
|
|
|
71
71
|
|
|
72
72
|
@register_function(config_type=ReWOOAgentWorkflowConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
|
|
73
73
|
async def rewoo_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builder):
|
|
74
|
-
from langchain.schema import BaseMessage
|
|
75
74
|
from langchain_core.messages import trim_messages
|
|
75
|
+
from langchain_core.messages.base import BaseMessage
|
|
76
76
|
from langchain_core.messages.human import HumanMessage
|
|
77
77
|
from langchain_core.prompts import ChatPromptTemplate
|
|
78
78
|
from langgraph.graph.state import CompiledStateGraph
|
|
@@ -154,6 +154,9 @@ async def rewoo_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builde
|
|
|
154
154
|
# get and return the output from the state
|
|
155
155
|
state = ReWOOGraphState(**state)
|
|
156
156
|
output_message = state.result.content
|
|
157
|
+
# Ensure output_message is a string
|
|
158
|
+
if isinstance(output_message, list | dict):
|
|
159
|
+
output_message = str(output_message)
|
|
157
160
|
return ChatResponse.from_string(output_message)
|
|
158
161
|
|
|
159
162
|
except Exception as ex:
|
|
@@ -13,10 +13,12 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import logging
|
|
17
|
+
from collections.abc import Callable
|
|
18
|
+
from datetime import UTC
|
|
16
19
|
from datetime import datetime
|
|
17
|
-
from datetime import timezone
|
|
18
|
-
from typing import Callable
|
|
19
20
|
|
|
21
|
+
import httpx
|
|
20
22
|
from authlib.integrations.httpx_client import OAuth2Client as AuthlibOAuth2Client
|
|
21
23
|
from pydantic import SecretStr
|
|
22
24
|
|
|
@@ -28,6 +30,8 @@ from nat.data_models.authentication import AuthFlowType
|
|
|
28
30
|
from nat.data_models.authentication import AuthResult
|
|
29
31
|
from nat.data_models.authentication import BearerTokenCred
|
|
30
32
|
|
|
33
|
+
logger = logging.getLogger(__name__)
|
|
34
|
+
|
|
31
35
|
|
|
32
36
|
class OAuth2AuthCodeFlowProvider(AuthProviderBase[OAuth2AuthCodeFlowProviderConfig]):
|
|
33
37
|
|
|
@@ -41,26 +45,30 @@ class OAuth2AuthCodeFlowProvider(AuthProviderBase[OAuth2AuthCodeFlowProviderConf
|
|
|
41
45
|
if not isinstance(refresh_token, str):
|
|
42
46
|
return None
|
|
43
47
|
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
48
|
+
try:
|
|
49
|
+
with AuthlibOAuth2Client(
|
|
50
|
+
client_id=self.config.client_id,
|
|
51
|
+
client_secret=self.config.client_secret,
|
|
52
|
+
) as client:
|
|
49
53
|
new_token_data = client.refresh_token(self.config.token_url, refresh_token=refresh_token)
|
|
50
|
-
except Exception:
|
|
51
|
-
# On any failure, we'll fall back to the full auth flow.
|
|
52
|
-
return None
|
|
53
54
|
|
|
54
|
-
|
|
55
|
-
|
|
55
|
+
expires_at_ts = new_token_data.get("expires_at")
|
|
56
|
+
new_expires_at = datetime.fromtimestamp(expires_at_ts, tz=UTC) if expires_at_ts else None
|
|
56
57
|
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
58
|
+
new_auth_result = AuthResult(
|
|
59
|
+
credentials=[BearerTokenCred(token=SecretStr(new_token_data["access_token"]))],
|
|
60
|
+
token_expires_at=new_expires_at,
|
|
61
|
+
raw=new_token_data,
|
|
62
|
+
)
|
|
62
63
|
|
|
63
|
-
|
|
64
|
+
self._authenticated_tokens[user_id] = new_auth_result
|
|
65
|
+
except httpx.HTTPStatusError:
|
|
66
|
+
return None
|
|
67
|
+
except httpx.RequestError:
|
|
68
|
+
return None
|
|
69
|
+
except Exception:
|
|
70
|
+
# On any other failure, we'll fall back to the full auth flow.
|
|
71
|
+
return None
|
|
64
72
|
|
|
65
73
|
return new_auth_result
|
|
66
74
|
|
nat/builder/builder.py
CHANGED
nat/builder/context.py
CHANGED
|
@@ -40,12 +40,12 @@ from nat.utils.reactive.subject import Subject
|
|
|
40
40
|
class Singleton(type):
|
|
41
41
|
|
|
42
42
|
def __init__(cls, name, bases, dict):
|
|
43
|
-
super(
|
|
43
|
+
super().__init__(name, bases, dict)
|
|
44
44
|
cls.instance = None
|
|
45
45
|
|
|
46
46
|
def __call__(cls, *args, **kw):
|
|
47
47
|
if cls.instance is None:
|
|
48
|
-
cls.instance = super(
|
|
48
|
+
cls.instance = super().__call__(*args, **kw)
|
|
49
49
|
return cls.instance
|
|
50
50
|
|
|
51
51
|
|
nat/builder/front_end.py
CHANGED
|
@@ -37,7 +37,7 @@ class FrontEndBase(typing.Generic[FrontEndConfigT], ABC):
|
|
|
37
37
|
|
|
38
38
|
super().__init__()
|
|
39
39
|
|
|
40
|
-
self._full_config:
|
|
40
|
+
self._full_config: Config = full_config
|
|
41
41
|
self._front_end_config: FrontEndConfigT = typing.cast(FrontEndConfigT, full_config.general.front_end)
|
|
42
42
|
|
|
43
43
|
@property
|
|
@@ -84,7 +84,7 @@ class LayeredConfig:
|
|
|
84
84
|
if lower_value not in ['true', 'false']:
|
|
85
85
|
raise ValueError(f"Boolean value must be 'true' or 'false', got '{value}'")
|
|
86
86
|
value = lower_value == 'true'
|
|
87
|
-
elif isinstance(original_value,
|
|
87
|
+
elif isinstance(original_value, int | float):
|
|
88
88
|
value = type(original_value)(value)
|
|
89
89
|
elif isinstance(original_value, list):
|
|
90
90
|
value = [v.strip() for v in value.split(',')]
|
nat/cli/commands/mcp/mcp.py
CHANGED
|
@@ -297,7 +297,7 @@ async def list_tools_via_function_group(
|
|
|
297
297
|
if fn is not None:
|
|
298
298
|
tools.append(to_tool_entry(full, fn))
|
|
299
299
|
else:
|
|
300
|
-
for full, fn in fns.items():
|
|
300
|
+
for full, fn in (await fns).items():
|
|
301
301
|
tools.append(to_tool_entry(full, fn))
|
|
302
302
|
|
|
303
303
|
return tools
|
|
@@ -443,7 +443,7 @@ async def ping_mcp_server(url: str,
|
|
|
443
443
|
# Apply timeout to the entire ping operation
|
|
444
444
|
return await asyncio.wait_for(_ping_operation(), timeout=timeout)
|
|
445
445
|
|
|
446
|
-
except
|
|
446
|
+
except TimeoutError:
|
|
447
447
|
return MCPPingResult(url=url,
|
|
448
448
|
status="unhealthy",
|
|
449
449
|
response_time_ms=None,
|
nat/cli/commands/start.py
CHANGED
|
@@ -111,7 +111,7 @@ class StartCommandGroup(click.Group):
|
|
|
111
111
|
elif (issubclass(decomposed_type.root, Path)):
|
|
112
112
|
param_type = click.Path(exists=True, file_okay=True, dir_okay=False, path_type=Path)
|
|
113
113
|
|
|
114
|
-
elif (issubclass(decomposed_type.root,
|
|
114
|
+
elif (issubclass(decomposed_type.root, list | tuple | set)):
|
|
115
115
|
if (len(decomposed_type.args) == 1):
|
|
116
116
|
inner = DecomposedType(decomposed_type.args[0])
|
|
117
117
|
# Support containers of Literal values -> multiple Choice
|
nat/cli/type_registry.py
CHANGED
|
@@ -992,7 +992,7 @@ class TypeRegistry:
|
|
|
992
992
|
if (short_names[key.local_name] == 1):
|
|
993
993
|
type_list.append((key.local_name, key.config_type))
|
|
994
994
|
|
|
995
|
-
return typing.Union[tuple(typing.Annotated[x_type, Tag(x_id)] for x_id, x_type in type_list)]
|
|
995
|
+
return typing.Union[*tuple(typing.Annotated[x_type, Tag(x_id)] for x_id, x_type in type_list)]
|
|
996
996
|
|
|
997
997
|
def compute_annotation(self, cls: type[TypedBaseModelT]):
|
|
998
998
|
|
|
@@ -81,7 +81,7 @@ async def router_agent_workflow(config: RouterAgentWorkflowConfig, builder: Buil
|
|
|
81
81
|
logger.exception("%s Router Agent failed with exception: %s", AGENT_LOG_PREFIX, ex)
|
|
82
82
|
if config.verbose:
|
|
83
83
|
return str(ex)
|
|
84
|
-
return "Router agent failed with exception:
|
|
84
|
+
return f"Router agent failed with exception: {ex}"
|
|
85
85
|
|
|
86
86
|
try:
|
|
87
87
|
yield FunctionInfo.from_fn(_response_fn, description=config.description)
|
nat/data_models/api_server.py
CHANGED
|
@@ -273,7 +273,7 @@ class ChatResponse(ResponseBaseModelOutput):
|
|
|
273
273
|
if model is None:
|
|
274
274
|
model = ""
|
|
275
275
|
if created is None:
|
|
276
|
-
created = datetime.datetime.now(datetime.
|
|
276
|
+
created = datetime.datetime.now(datetime.UTC)
|
|
277
277
|
|
|
278
278
|
return ChatResponse(id=id_,
|
|
279
279
|
object=object_,
|
|
@@ -317,7 +317,7 @@ class ChatResponseChunk(ResponseBaseModelOutput):
|
|
|
317
317
|
if id_ is None:
|
|
318
318
|
id_ = str(uuid.uuid4())
|
|
319
319
|
if created is None:
|
|
320
|
-
created = datetime.datetime.now(datetime.
|
|
320
|
+
created = datetime.datetime.now(datetime.UTC)
|
|
321
321
|
if model is None:
|
|
322
322
|
model = ""
|
|
323
323
|
if object_ is None:
|
|
@@ -343,7 +343,7 @@ class ChatResponseChunk(ResponseBaseModelOutput):
|
|
|
343
343
|
if id_ is None:
|
|
344
344
|
id_ = str(uuid.uuid4())
|
|
345
345
|
if created is None:
|
|
346
|
-
created = datetime.datetime.now(datetime.
|
|
346
|
+
created = datetime.datetime.now(datetime.UTC)
|
|
347
347
|
if model is None:
|
|
348
348
|
model = ""
|
|
349
349
|
|
|
@@ -485,7 +485,7 @@ class WebSocketUserMessage(BaseModel):
|
|
|
485
485
|
security: Security = Security()
|
|
486
486
|
error: Error = Error()
|
|
487
487
|
schema_version: str = "1.0.0"
|
|
488
|
-
timestamp: str = str(datetime.datetime.now(datetime.
|
|
488
|
+
timestamp: str = str(datetime.datetime.now(datetime.UTC))
|
|
489
489
|
|
|
490
490
|
|
|
491
491
|
class WebSocketUserInteractionResponseMessage(BaseModel):
|
|
@@ -501,7 +501,7 @@ class WebSocketUserInteractionResponseMessage(BaseModel):
|
|
|
501
501
|
security: Security = Security()
|
|
502
502
|
error: Error = Error()
|
|
503
503
|
schema_version: str = "1.0.0"
|
|
504
|
-
timestamp: str = str(datetime.datetime.now(datetime.
|
|
504
|
+
timestamp: str = str(datetime.datetime.now(datetime.UTC))
|
|
505
505
|
|
|
506
506
|
|
|
507
507
|
class SystemIntermediateStepContent(BaseModel):
|
|
@@ -527,7 +527,7 @@ class WebSocketSystemIntermediateStepMessage(BaseModel):
|
|
|
527
527
|
conversation_id: str | None = None
|
|
528
528
|
content: SystemIntermediateStepContent
|
|
529
529
|
status: WebSocketMessageStatus
|
|
530
|
-
timestamp: str = str(datetime.datetime.now(datetime.
|
|
530
|
+
timestamp: str = str(datetime.datetime.now(datetime.UTC))
|
|
531
531
|
|
|
532
532
|
|
|
533
533
|
class SystemResponseContent(BaseModel):
|
|
@@ -551,7 +551,7 @@ class WebSocketSystemResponseTokenMessage(BaseModel):
|
|
|
551
551
|
conversation_id: str | None = None
|
|
552
552
|
content: SystemResponseContent | Error | GenerateResponse
|
|
553
553
|
status: WebSocketMessageStatus
|
|
554
|
-
timestamp: str = str(datetime.datetime.now(datetime.
|
|
554
|
+
timestamp: str = str(datetime.datetime.now(datetime.UTC))
|
|
555
555
|
|
|
556
556
|
@field_validator("content")
|
|
557
557
|
@classmethod
|
|
@@ -560,7 +560,7 @@ class WebSocketSystemResponseTokenMessage(BaseModel):
|
|
|
560
560
|
raise ValueError(f"Field: content must be 'Error' when type is {WebSocketMessageType.ERROR_MESSAGE}")
|
|
561
561
|
|
|
562
562
|
if info.data.get("type") == WebSocketMessageType.RESPONSE_MESSAGE and not isinstance(
|
|
563
|
-
value,
|
|
563
|
+
value, SystemResponseContent | GenerateResponse):
|
|
564
564
|
raise ValueError(
|
|
565
565
|
f"Field: content must be 'SystemResponseContent' when type is {WebSocketMessageType.RESPONSE_MESSAGE}")
|
|
566
566
|
return value
|
|
@@ -582,7 +582,7 @@ class WebSocketSystemInteractionMessage(BaseModel):
|
|
|
582
582
|
conversation_id: str | None = None
|
|
583
583
|
content: HumanPrompt
|
|
584
584
|
status: WebSocketMessageStatus
|
|
585
|
-
timestamp: str = str(datetime.datetime.now(datetime.
|
|
585
|
+
timestamp: str = str(datetime.datetime.now(datetime.UTC))
|
|
586
586
|
|
|
587
587
|
|
|
588
588
|
# ======== GenerateResponse Converters ========
|