nvidia-nat 1.2.1rc1__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +2 -2
- nat/agent/base.py +27 -18
- nat/agent/dual_node.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +81 -50
- nat/agent/react_agent/register.py +59 -40
- nat/agent/reasoning_agent/reasoning_agent.py +17 -15
- nat/agent/register.py +1 -1
- nat/agent/rewoo_agent/agent.py +327 -149
- nat/agent/rewoo_agent/prompt.py +19 -22
- nat/agent/rewoo_agent/register.py +64 -46
- nat/agent/tool_calling_agent/agent.py +152 -29
- nat/agent/tool_calling_agent/register.py +61 -38
- nat/authentication/api_key/api_key_auth_provider.py +2 -2
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/interfaces.py +5 -2
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/authentication/register.py +0 -1
- nat/builder/builder.py +56 -24
- nat/builder/component_utils.py +10 -6
- nat/builder/context.py +70 -18
- nat/builder/eval_builder.py +16 -11
- nat/builder/framework_enum.py +1 -0
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +378 -8
- nat/builder/function_base.py +3 -3
- nat/builder/function_info.py +6 -8
- nat/builder/intermediate_step_manager.py +6 -2
- nat/builder/user_interaction_manager.py +2 -2
- nat/builder/workflow.py +13 -1
- nat/builder/workflow_builder.py +327 -79
- nat/cli/cli_utils/config_override.py +2 -2
- nat/cli/commands/evaluate.py +1 -1
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/info/list_components.py +7 -8
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/object_store/__init__.py +14 -0
- nat/cli/commands/object_store/object_store.py +227 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/registry/publish.py +2 -2
- nat/cli/commands/registry/pull.py +2 -2
- nat/cli/commands/registry/remove.py +2 -2
- nat/cli/commands/registry/search.py +15 -17
- nat/cli/commands/start.py +16 -5
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +5 -2
- nat/cli/commands/workflow/templates/register.py.j2 +2 -3
- nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
- nat/cli/commands/workflow/workflow_commands.py +105 -19
- nat/cli/entrypoint.py +17 -11
- nat/cli/main.py +3 -0
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +79 -10
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +196 -67
- nat/data_models/authentication.py +23 -9
- nat/data_models/common.py +1 -1
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +42 -18
- nat/data_models/dataset_handler.py +1 -1
- nat/data_models/discovery_metadata.py +4 -4
- nat/data_models/evaluate.py +4 -1
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +14 -6
- nat/data_models/gated_field_mixin.py +242 -0
- nat/data_models/intermediate_step.py +3 -3
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/span.py +41 -3
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/temperature_mixin.py +44 -0
- nat/data_models/thinking_mixin.py +86 -0
- nat/data_models/top_p_mixin.py +44 -0
- nat/embedder/azure_openai_embedder.py +46 -0
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +2 -3
- nat/embedder/register.py +1 -1
- nat/eval/config.py +3 -1
- nat/eval/dataset_handler/dataset_handler.py +71 -7
- nat/eval/evaluate.py +86 -31
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/evaluator/evaluator_model.py +13 -0
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +9 -6
- nat/eval/rag_evaluator/register.py +3 -3
- nat/eval/register.py +4 -1
- nat/eval/remote_workflow.py +3 -3
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/swe_bench_evaluator/evaluate.py +6 -6
- nat/eval/trajectory_evaluator/evaluate.py +1 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
- nat/eval/utils/eval_trace_ctx.py +89 -0
- nat/eval/utils/weave_eval.py +18 -9
- nat/experimental/decorators/experimental_warning_decorator.py +27 -7
- nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +3 -3
- nat/experimental/test_time_compute/models/strategy_base.py +5 -4
- nat/experimental/test_time_compute/register.py +0 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +19 -7
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
- nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +455 -282
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/message_handler.py +74 -50
- nat/front_ends/fastapi/message_validator.py +20 -21
- nat/front_ends/fastapi/response_helpers.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +2 -2
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +47 -3
- nat/front_ends/mcp/mcp_front_end_plugin.py +48 -13
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +120 -8
- nat/front_ends/mcp/tool_converter.py +44 -14
- nat/front_ends/register.py +0 -1
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
- nat/llm/aws_bedrock_llm.py +24 -12
- nat/llm/azure_openai_llm.py +57 -0
- nat/llm/litellm_llm.py +69 -0
- nat/llm/nim_llm.py +20 -8
- nat/llm/openai_llm.py +14 -6
- nat/llm/register.py +5 -1
- nat/llm/utils/env_config_value.py +2 -3
- nat/llm/utils/thinking.py +215 -0
- nat/meta/pypi.md +9 -9
- nat/object_store/register.py +0 -1
- nat/observability/exporter/base_exporter.py +3 -3
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +309 -81
- nat/observability/exporter/span_exporter.py +35 -15
- nat/observability/exporter_manager.py +7 -7
- nat/observability/mixin/file_mixin.py +7 -7
- nat/observability/mixin/redaction_config_mixin.py +42 -0
- nat/observability/mixin/tagging_config_mixin.py +62 -0
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/batching_processor.py +5 -7
- nat/observability/processor/falsy_batch_filter_processor.py +55 -0
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/processor_factory.py +70 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +68 -0
- nat/observability/register.py +22 -4
- nat/profiler/calc/calc_runner.py +3 -4
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +14 -7
- nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +62 -13
- nat/profiler/decorators/function_tracking.py +160 -3
- nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
- nat/profiler/forecasting/models/linear_model.py +1 -1
- nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +3 -3
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +8 -9
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +164 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +395 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/profile_runner.py +14 -9
- nat/profiler/utils.py +4 -2
- nat/registry_handlers/local/local_handler.py +2 -2
- nat/registry_handlers/package_utils.py +1 -2
- nat/registry_handlers/pypi/pypi_handler.py +23 -26
- nat/registry_handlers/register.py +3 -4
- nat/registry_handlers/rest/rest_handler.py +12 -13
- nat/retriever/milvus/retriever.py +2 -2
- nat/retriever/nemo_retriever/retriever.py +1 -1
- nat/retriever/register.py +0 -1
- nat/runtime/loader.py +2 -2
- nat/runtime/runner.py +105 -8
- nat/runtime/session.py +69 -8
- nat/settings/global_settings.py +16 -5
- nat/tool/chat_completion.py +5 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
- nat/tool/datetime_tools.py +49 -9
- nat/tool/document_search.py +2 -2
- nat/tool/github_tools.py +450 -0
- nat/tool/memory_tools/add_memory_tool.py +3 -3
- nat/tool/memory_tools/delete_memory_tool.py +3 -4
- nat/tool/memory_tools/get_memory_tool.py +4 -4
- nat/tool/nvidia_rag.py +1 -1
- nat/tool/register.py +2 -9
- nat/tool/retriever.py +3 -2
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/schema_validator.py +3 -3
- nat/utils/decorators.py +210 -0
- nat/utils/exception_handlers/automatic_retries.py +104 -51
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/io/yaml_tools.py +2 -2
- nat/utils/log_levels.py +25 -0
- nat/utils/reactive/base/observable_base.py +2 -2
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/observable.py +2 -2
- nat/utils/reactive/observer.py +4 -4
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/settings/global_settings.py +6 -8
- nat/utils/type_converter.py +12 -3
- nat/utils/type_utils.py +9 -5
- nvidia_nat-1.3.0.dist-info/METADATA +195 -0
- {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/RECORD +244 -200
- {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/entry_points.txt +1 -0
- nat/cli/commands/info/list_mcp.py +0 -304
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- nat/tool/mcp/exceptions.py +0 -142
- nat/tool/mcp/mcp_client.py +0 -255
- nat/tool/mcp/mcp_tool.py +0 -96
- nat/utils/exception_handlers/mcp.py +0 -211
- nvidia_nat-1.2.1rc1.dist-info/METADATA +0 -365
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
- {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/top_level.txt +0 -0
|
@@ -22,26 +22,29 @@ from nat.builder.builder import Builder
|
|
|
22
22
|
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
23
23
|
from nat.builder.function_info import FunctionInfo
|
|
24
24
|
from nat.cli.register_workflow import register_function
|
|
25
|
+
from nat.data_models.agent import AgentBaseConfig
|
|
25
26
|
from nat.data_models.api_server import ChatRequest
|
|
27
|
+
from nat.data_models.api_server import ChatRequestOrMessage
|
|
26
28
|
from nat.data_models.api_server import ChatResponse
|
|
29
|
+
from nat.data_models.api_server import Usage
|
|
30
|
+
from nat.data_models.component_ref import FunctionGroupRef
|
|
27
31
|
from nat.data_models.component_ref import FunctionRef
|
|
28
|
-
from nat.data_models.
|
|
29
|
-
from nat.data_models.
|
|
32
|
+
from nat.data_models.optimizable import OptimizableField
|
|
33
|
+
from nat.data_models.optimizable import OptimizableMixin
|
|
34
|
+
from nat.data_models.optimizable import SearchSpace
|
|
30
35
|
from nat.utils.type_converter import GlobalTypeConverter
|
|
31
36
|
|
|
32
37
|
logger = logging.getLogger(__name__)
|
|
33
38
|
|
|
34
39
|
|
|
35
|
-
class ReActAgentWorkflowConfig(
|
|
40
|
+
class ReActAgentWorkflowConfig(AgentBaseConfig, OptimizableMixin, name="react_agent"):
|
|
36
41
|
"""
|
|
37
42
|
Defines a NAT function that uses a ReAct Agent performs reasoning inbetween tool calls, and utilizes the
|
|
38
43
|
tool names and descriptions to select the optimal tool.
|
|
39
44
|
"""
|
|
40
|
-
|
|
41
|
-
tool_names: list[FunctionRef] = Field(
|
|
42
|
-
|
|
43
|
-
llm_name: LLMRef = Field(description="The LLM model to use with the react agent.")
|
|
44
|
-
verbose: bool = Field(default=False, description="Set the verbosity of the react agent's logging.")
|
|
45
|
+
description: str = Field(default="ReAct Agent Workflow", description="The description of this functions use.")
|
|
46
|
+
tool_names: list[FunctionRef | FunctionGroupRef] = Field(
|
|
47
|
+
default_factory=list, description="The list of tools to provide to the react agent.")
|
|
45
48
|
retry_agent_response_parsing_errors: bool = Field(
|
|
46
49
|
default=True,
|
|
47
50
|
validation_alias=AliasChoices("retry_agent_response_parsing_errors", "retry_parsing_errors"),
|
|
@@ -60,23 +63,29 @@ class ReActAgentWorkflowConfig(FunctionBaseConfig, name="react_agent"):
|
|
|
60
63
|
description="Whether to pass tool call errors to agent. If False, failed tool calls will raise an exception.")
|
|
61
64
|
include_tool_input_schema_in_tool_description: bool = Field(
|
|
62
65
|
default=True, description="Specify inclusion of tool input schemas in the prompt.")
|
|
63
|
-
|
|
66
|
+
normalize_tool_input_quotes: bool = Field(
|
|
67
|
+
default=True,
|
|
68
|
+
description="Whether to replace single quotes with double quotes in the tool input. "
|
|
69
|
+
"This is useful for tools that expect structured json input.")
|
|
64
70
|
system_prompt: str | None = Field(
|
|
65
71
|
default=None,
|
|
66
72
|
description="Provides the SYSTEM_PROMPT to use with the agent") # defaults to SYSTEM_PROMPT in prompt.py
|
|
67
73
|
max_history: int = Field(default=15, description="Maximum number of messages to keep in the conversation history.")
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
74
|
+
additional_instructions: str | None = OptimizableField(
|
|
75
|
+
default=None,
|
|
76
|
+
description="Additional instructions to provide to the agent in addition to the base prompt.",
|
|
77
|
+
space=SearchSpace(
|
|
78
|
+
is_prompt=True,
|
|
79
|
+
prompt="No additional instructions.",
|
|
80
|
+
prompt_purpose="Additional instructions to provide to the agent in addition to the base prompt.",
|
|
81
|
+
))
|
|
73
82
|
|
|
74
83
|
|
|
75
84
|
@register_function(config_type=ReActAgentWorkflowConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
|
|
76
85
|
async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builder):
|
|
77
86
|
from langchain.schema import BaseMessage
|
|
78
87
|
from langchain_core.messages import trim_messages
|
|
79
|
-
from langgraph.graph.
|
|
88
|
+
from langgraph.graph.state import CompiledStateGraph
|
|
80
89
|
|
|
81
90
|
from nat.agent.base import AGENT_LOG_PREFIX
|
|
82
91
|
from nat.agent.react_agent.agent import ReActAgentGraph
|
|
@@ -89,26 +98,41 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
|
|
|
89
98
|
llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
90
99
|
# the agent can run any installed tool, simply install the tool and add it to the config file
|
|
91
100
|
# the sample tool provided can easily be copied or changed
|
|
92
|
-
tools = builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
101
|
+
tools = await builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
93
102
|
if not tools:
|
|
94
103
|
raise ValueError(f"No tools specified for ReAct Agent '{config.llm_name}'")
|
|
95
104
|
# configure callbacks, for sending intermediate steps
|
|
96
105
|
# construct the ReAct Agent Graph from the configured llm, prompt, and tools
|
|
97
|
-
graph:
|
|
106
|
+
graph: CompiledStateGraph = await ReActAgentGraph(
|
|
98
107
|
llm=llm,
|
|
99
108
|
prompt=prompt,
|
|
100
109
|
tools=tools,
|
|
101
110
|
use_tool_schema=config.include_tool_input_schema_in_tool_description,
|
|
102
111
|
detailed_logs=config.verbose,
|
|
112
|
+
log_response_max_chars=config.log_response_max_chars,
|
|
103
113
|
retry_agent_response_parsing_errors=config.retry_agent_response_parsing_errors,
|
|
104
114
|
parse_agent_response_max_retries=config.parse_agent_response_max_retries,
|
|
105
115
|
tool_call_max_retries=config.tool_call_max_retries,
|
|
106
|
-
pass_tool_call_errors_to_agent=config.pass_tool_call_errors_to_agent
|
|
116
|
+
pass_tool_call_errors_to_agent=config.pass_tool_call_errors_to_agent,
|
|
117
|
+
normalize_tool_input_quotes=config.normalize_tool_input_quotes).build_graph()
|
|
118
|
+
|
|
119
|
+
async def _response_fn(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse | str:
|
|
120
|
+
"""
|
|
121
|
+
Main workflow entry function for the ReAct Agent.
|
|
122
|
+
|
|
123
|
+
This function invokes the ReAct Agent Graph and returns the response.
|
|
107
124
|
|
|
108
|
-
|
|
125
|
+
Args:
|
|
126
|
+
chat_request_or_message (ChatRequestOrMessage): The input message to process
|
|
127
|
+
|
|
128
|
+
Returns:
|
|
129
|
+
ChatResponse | str: The response from the agent or error message
|
|
130
|
+
"""
|
|
109
131
|
try:
|
|
132
|
+
message = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest)
|
|
133
|
+
|
|
110
134
|
# initialize the starting state with the user query
|
|
111
|
-
messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in
|
|
135
|
+
messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in message.messages],
|
|
112
136
|
max_tokens=config.max_history,
|
|
113
137
|
strategy="last",
|
|
114
138
|
token_counter=len,
|
|
@@ -125,25 +149,20 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
|
|
|
125
149
|
|
|
126
150
|
# get and return the output from the state
|
|
127
151
|
state = ReActGraphState(**state)
|
|
128
|
-
output_message = state.messages[-1]
|
|
129
|
-
|
|
130
|
-
|
|
152
|
+
output_message = state.messages[-1]
|
|
153
|
+
content = str(output_message.content)
|
|
154
|
+
|
|
155
|
+
# Create usage statistics for the response
|
|
156
|
+
prompt_tokens = sum(len(str(msg.content).split()) for msg in message.messages)
|
|
157
|
+
completion_tokens = len(content.split()) if content else 0
|
|
158
|
+
total_tokens = prompt_tokens + completion_tokens
|
|
159
|
+
usage = Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens)
|
|
160
|
+
response = ChatResponse.from_string(content, usage=usage)
|
|
161
|
+
if chat_request_or_message.is_string:
|
|
162
|
+
return GlobalTypeConverter.get().convert(response, to_type=str)
|
|
163
|
+
return response
|
|
131
164
|
except Exception as ex:
|
|
132
|
-
logger.
|
|
133
|
-
|
|
134
|
-
if config.verbose:
|
|
135
|
-
return ChatResponse.from_string(str(ex))
|
|
136
|
-
return ChatResponse.from_string("I seem to be having a problem.")
|
|
137
|
-
|
|
138
|
-
if (config.use_openai_api):
|
|
139
|
-
yield FunctionInfo.from_fn(_response_fn, description=config.description)
|
|
140
|
-
else:
|
|
141
|
-
|
|
142
|
-
async def _str_api_fn(input_message: str) -> str:
|
|
143
|
-
oai_input = GlobalTypeConverter.get().try_convert(input_message, to_type=ChatRequest)
|
|
144
|
-
|
|
145
|
-
oai_output = await _response_fn(oai_input)
|
|
146
|
-
|
|
147
|
-
return GlobalTypeConverter.get().try_convert(oai_output, to_type=str)
|
|
165
|
+
logger.error("%s ReAct Agent failed with exception: %s", AGENT_LOG_PREFIX, str(ex))
|
|
166
|
+
raise
|
|
148
167
|
|
|
149
|
-
|
|
168
|
+
yield FunctionInfo.from_fn(_response_fn, description=config.description)
|
|
@@ -23,25 +23,22 @@ from nat.builder.builder import Builder
|
|
|
23
23
|
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
24
24
|
from nat.builder.function_info import FunctionInfo
|
|
25
25
|
from nat.cli.register_workflow import register_function
|
|
26
|
+
from nat.data_models.agent import AgentBaseConfig
|
|
26
27
|
from nat.data_models.api_server import ChatRequest
|
|
27
28
|
from nat.data_models.component_ref import FunctionRef
|
|
28
|
-
from nat.data_models.component_ref import LLMRef
|
|
29
|
-
from nat.data_models.function import FunctionBaseConfig
|
|
30
29
|
|
|
31
30
|
logger = logging.getLogger(__name__)
|
|
32
31
|
|
|
33
32
|
|
|
34
|
-
class ReasoningFunctionConfig(
|
|
33
|
+
class ReasoningFunctionConfig(AgentBaseConfig, name="reasoning_agent"):
|
|
35
34
|
"""
|
|
36
35
|
Defines a NAT function that performs reasoning on the input data.
|
|
37
36
|
Output is passed to the next function in the workflow.
|
|
38
37
|
|
|
39
38
|
Designed to be used with an InterceptingFunction.
|
|
40
39
|
"""
|
|
41
|
-
|
|
42
|
-
llm_name: LLMRef = Field(description="The name of the LLM to use for reasoning.")
|
|
40
|
+
description: str = Field(default="Reasoning Agent", description="The description of this function's use.")
|
|
43
41
|
augmented_fn: FunctionRef = Field(description="The name of the function to reason on.")
|
|
44
|
-
verbose: bool = Field(default=False, description="Whether to log detailed information.")
|
|
45
42
|
reasoning_prompt_template: str = Field(
|
|
46
43
|
default=("You are an expert reasoning model task with creating a detailed execution plan"
|
|
47
44
|
" for a system that has the following description:\n\n"
|
|
@@ -102,7 +99,7 @@ async def build_reasoning_function(config: ReasoningFunctionConfig, builder: Bui
|
|
|
102
99
|
llm: BaseChatModel = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
103
100
|
|
|
104
101
|
# Get the augmented function's description
|
|
105
|
-
augmented_function = builder.get_function(config.augmented_fn)
|
|
102
|
+
augmented_function = await builder.get_function(config.augmented_fn)
|
|
106
103
|
|
|
107
104
|
# For now, we rely on runtime checking for type conversion
|
|
108
105
|
|
|
@@ -113,11 +110,16 @@ async def build_reasoning_function(config: ReasoningFunctionConfig, builder: Bui
|
|
|
113
110
|
f"function without a description.")
|
|
114
111
|
|
|
115
112
|
# Get the function dependencies of the augmented function
|
|
116
|
-
|
|
113
|
+
function_dependencies = builder.get_function_dependencies(config.augmented_fn)
|
|
114
|
+
function_used_tools = set()
|
|
115
|
+
function_used_tools.update(function_dependencies.functions)
|
|
116
|
+
for function_group in function_dependencies.function_groups:
|
|
117
|
+
function_used_tools.update(builder.get_function_group_dependencies(function_group).functions)
|
|
118
|
+
|
|
117
119
|
tool_names_with_desc: list[tuple[str, str]] = []
|
|
118
120
|
|
|
119
121
|
for tool in function_used_tools:
|
|
120
|
-
tool_impl = builder.get_function(tool)
|
|
122
|
+
tool_impl = await builder.get_function(tool)
|
|
121
123
|
tool_names_with_desc.append((tool, tool_impl.description if hasattr(tool_impl, "description") else ""))
|
|
122
124
|
|
|
123
125
|
# Draft the reasoning prompt for the augmented function
|
|
@@ -155,12 +157,12 @@ async def build_reasoning_function(config: ReasoningFunctionConfig, builder: Bui
|
|
|
155
157
|
prompt = prompt.to_string()
|
|
156
158
|
|
|
157
159
|
# Get the reasoning output from the LLM
|
|
158
|
-
reasoning_output =
|
|
160
|
+
reasoning_output = []
|
|
159
161
|
|
|
160
162
|
async for chunk in llm.astream(prompt):
|
|
161
|
-
reasoning_output
|
|
163
|
+
reasoning_output.append(chunk.content)
|
|
162
164
|
|
|
163
|
-
reasoning_output = remove_r1_think_tags(reasoning_output)
|
|
165
|
+
reasoning_output = remove_r1_think_tags("".join(reasoning_output))
|
|
164
166
|
|
|
165
167
|
output = await downstream_template.ainvoke(input={
|
|
166
168
|
"input_text": input_text, "reasoning_output": reasoning_output
|
|
@@ -198,12 +200,12 @@ async def build_reasoning_function(config: ReasoningFunctionConfig, builder: Bui
|
|
|
198
200
|
prompt = prompt.to_string()
|
|
199
201
|
|
|
200
202
|
# Get the reasoning output from the LLM
|
|
201
|
-
reasoning_output =
|
|
203
|
+
reasoning_output = []
|
|
202
204
|
|
|
203
205
|
async for chunk in llm.astream(prompt):
|
|
204
|
-
reasoning_output
|
|
206
|
+
reasoning_output.append(chunk.content)
|
|
205
207
|
|
|
206
|
-
reasoning_output = remove_r1_think_tags(reasoning_output)
|
|
208
|
+
reasoning_output = remove_r1_think_tags("".join(reasoning_output))
|
|
207
209
|
|
|
208
210
|
output = await downstream_template.ainvoke(input={
|
|
209
211
|
"input_text": input_text, "reasoning_output": reasoning_output
|
nat/agent/register.py
CHANGED
|
@@ -13,10 +13,10 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
# pylint: disable=unused-import
|
|
17
16
|
# flake8: noqa
|
|
18
17
|
|
|
19
18
|
# Import any workflows which need to be automatically registered here
|
|
19
|
+
from .prompt_optimizer import register as prompt_optimizer
|
|
20
20
|
from .react_agent import register as react_agent
|
|
21
21
|
from .reasoning_agent import reasoning_agent
|
|
22
22
|
from .rewoo_agent import register as rewoo_agent
|