nvidia-nat 1.3.0rc1__py3-none-any.whl → 1.3.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/prompt_optimizer/register.py +2 -2
- nat/agent/react_agent/register.py +20 -21
- nat/agent/rewoo_agent/register.py +18 -20
- nat/agent/tool_calling_agent/register.py +7 -3
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +31 -18
- nat/builder/component_utils.py +1 -1
- nat/builder/context.py +22 -6
- nat/builder/function.py +3 -2
- nat/builder/workflow_builder.py +46 -3
- nat/cli/commands/mcp/mcp.py +6 -6
- nat/cli/commands/workflow/templates/config.yml.j2 +14 -12
- nat/cli/commands/workflow/templates/register.py.j2 +2 -2
- nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
- nat/cli/commands/workflow/workflow_commands.py +54 -10
- nat/cli/entrypoint.py +9 -1
- nat/cli/main.py +3 -0
- nat/data_models/api_server.py +143 -66
- nat/data_models/config.py +1 -1
- nat/data_models/span.py +41 -3
- nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +2 -2
- nat/front_ends/console/console_front_end_plugin.py +11 -2
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +5 -35
- nat/front_ends/fastapi/message_validator.py +3 -1
- nat/observability/exporter/span_exporter.py +34 -14
- nat/observability/register.py +16 -0
- nat/profiler/decorators/framework_wrapper.py +1 -1
- nat/profiler/forecasting/models/linear_model.py +1 -1
- nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +1 -1
- nat/runtime/runner.py +103 -6
- nat/runtime/session.py +27 -1
- nat/tool/memory_tools/add_memory_tool.py +3 -3
- nat/tool/memory_tools/delete_memory_tool.py +3 -4
- nat/tool/memory_tools/get_memory_tool.py +4 -4
- nat/utils/decorators.py +210 -0
- nat/utils/type_converter.py +8 -0
- nvidia_nat-1.3.0rc3.dist-info/METADATA +195 -0
- {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc3.dist-info}/RECORD +46 -45
- nvidia_nat-1.3.0rc1.dist-info/METADATA +0 -391
- {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc3.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc3.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc3.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc3.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc3.dist-info}/top_level.txt +0 -0
|
@@ -51,7 +51,7 @@ async def prompt_optimizer_function(config: PromptOptimizerConfig, builder: Buil
|
|
|
51
51
|
from .prompt import mutator_prompt
|
|
52
52
|
except ImportError as exc:
|
|
53
53
|
raise ImportError("langchain-core is not installed. Please install it to use MultiLLMPlanner.\n"
|
|
54
|
-
"This error can be resolve by installing nvidia-nat[langchain]") from exc
|
|
54
|
+
"This error can be resolve by installing \"nvidia-nat[langchain]\".") from exc
|
|
55
55
|
|
|
56
56
|
llm = await builder.get_llm(config.optimizer_llm, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
57
57
|
|
|
@@ -111,7 +111,7 @@ async def prompt_recombiner_function(config: PromptRecombinerConfig, builder: Bu
|
|
|
111
111
|
from langchain_core.prompts import PromptTemplate
|
|
112
112
|
except ImportError as exc:
|
|
113
113
|
raise ImportError("langchain-core is not installed. Please install it to use MultiLLMPlanner.\n"
|
|
114
|
-
"This error can be resolve by installing nvidia-nat[langchain].") from exc
|
|
114
|
+
"This error can be resolve by installing \"nvidia-nat[langchain]\".") from exc
|
|
115
115
|
|
|
116
116
|
llm = await builder.get_llm(config.optimizer_llm, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
117
117
|
|
|
@@ -24,7 +24,9 @@ from nat.builder.function_info import FunctionInfo
|
|
|
24
24
|
from nat.cli.register_workflow import register_function
|
|
25
25
|
from nat.data_models.agent import AgentBaseConfig
|
|
26
26
|
from nat.data_models.api_server import ChatRequest
|
|
27
|
+
from nat.data_models.api_server import ChatRequestOrMessage
|
|
27
28
|
from nat.data_models.api_server import ChatResponse
|
|
29
|
+
from nat.data_models.api_server import Usage
|
|
28
30
|
from nat.data_models.component_ref import FunctionGroupRef
|
|
29
31
|
from nat.data_models.component_ref import FunctionRef
|
|
30
32
|
from nat.data_models.optimizable import OptimizableField
|
|
@@ -69,9 +71,6 @@ class ReActAgentWorkflowConfig(AgentBaseConfig, OptimizableMixin, name="react_ag
|
|
|
69
71
|
default=None,
|
|
70
72
|
description="Provides the SYSTEM_PROMPT to use with the agent") # defaults to SYSTEM_PROMPT in prompt.py
|
|
71
73
|
max_history: int = Field(default=15, description="Maximum number of messages to keep in the conversation history.")
|
|
72
|
-
use_openai_api: bool = Field(default=False,
|
|
73
|
-
description=("Use OpenAI API for the input/output types to the function. "
|
|
74
|
-
"If False, strings will be used."))
|
|
75
74
|
additional_instructions: str | None = OptimizableField(
|
|
76
75
|
default=None,
|
|
77
76
|
description="Additional instructions to provide to the agent in addition to the base prompt.",
|
|
@@ -117,21 +116,23 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
|
|
|
117
116
|
pass_tool_call_errors_to_agent=config.pass_tool_call_errors_to_agent,
|
|
118
117
|
normalize_tool_input_quotes=config.normalize_tool_input_quotes).build_graph()
|
|
119
118
|
|
|
120
|
-
async def _response_fn(
|
|
119
|
+
async def _response_fn(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse | str:
|
|
121
120
|
"""
|
|
122
121
|
Main workflow entry function for the ReAct Agent.
|
|
123
122
|
|
|
124
123
|
This function invokes the ReAct Agent Graph and returns the response.
|
|
125
124
|
|
|
126
125
|
Args:
|
|
127
|
-
|
|
126
|
+
chat_request_or_message (ChatRequestOrMessage): The input message to process
|
|
128
127
|
|
|
129
128
|
Returns:
|
|
130
|
-
ChatResponse: The response from the agent or error message
|
|
129
|
+
ChatResponse | str: The response from the agent or error message
|
|
131
130
|
"""
|
|
132
131
|
try:
|
|
132
|
+
message = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest)
|
|
133
|
+
|
|
133
134
|
# initialize the starting state with the user query
|
|
134
|
-
messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in
|
|
135
|
+
messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in message.messages],
|
|
135
136
|
max_tokens=config.max_history,
|
|
136
137
|
strategy="last",
|
|
137
138
|
token_counter=len,
|
|
@@ -149,21 +150,19 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
|
|
|
149
150
|
# get and return the output from the state
|
|
150
151
|
state = ReActGraphState(**state)
|
|
151
152
|
output_message = state.messages[-1]
|
|
152
|
-
|
|
153
|
-
|
|
153
|
+
content = str(output_message.content)
|
|
154
|
+
|
|
155
|
+
# Create usage statistics for the response
|
|
156
|
+
prompt_tokens = sum(len(str(msg.content).split()) for msg in message.messages)
|
|
157
|
+
completion_tokens = len(content.split()) if content else 0
|
|
158
|
+
total_tokens = prompt_tokens + completion_tokens
|
|
159
|
+
usage = Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens)
|
|
160
|
+
response = ChatResponse.from_string(content, usage=usage)
|
|
161
|
+
if chat_request_or_message.is_string:
|
|
162
|
+
return GlobalTypeConverter.get().convert(response, to_type=str)
|
|
163
|
+
return response
|
|
154
164
|
except Exception as ex:
|
|
155
165
|
logger.exception("%s ReAct Agent failed with exception: %s", AGENT_LOG_PREFIX, str(ex))
|
|
156
166
|
raise RuntimeError
|
|
157
167
|
|
|
158
|
-
|
|
159
|
-
yield FunctionInfo.from_fn(_response_fn, description=config.description)
|
|
160
|
-
else:
|
|
161
|
-
|
|
162
|
-
async def _str_api_fn(input_message: str) -> str:
|
|
163
|
-
oai_input = GlobalTypeConverter.get().try_convert(input_message, to_type=ChatRequest)
|
|
164
|
-
|
|
165
|
-
oai_output = await _response_fn(oai_input)
|
|
166
|
-
|
|
167
|
-
return GlobalTypeConverter.get().try_convert(oai_output, to_type=str)
|
|
168
|
-
|
|
169
|
-
yield FunctionInfo.from_fn(_str_api_fn, description=config.description)
|
|
168
|
+
yield FunctionInfo.from_fn(_response_fn, description=config.description)
|
|
@@ -25,7 +25,9 @@ from nat.builder.function_info import FunctionInfo
|
|
|
25
25
|
from nat.cli.register_workflow import register_function
|
|
26
26
|
from nat.data_models.agent import AgentBaseConfig
|
|
27
27
|
from nat.data_models.api_server import ChatRequest
|
|
28
|
+
from nat.data_models.api_server import ChatRequestOrMessage
|
|
28
29
|
from nat.data_models.api_server import ChatResponse
|
|
30
|
+
from nat.data_models.api_server import Usage
|
|
29
31
|
from nat.data_models.component_ref import FunctionGroupRef
|
|
30
32
|
from nat.data_models.component_ref import FunctionRef
|
|
31
33
|
from nat.utils.type_converter import GlobalTypeConverter
|
|
@@ -53,9 +55,6 @@ class ReWOOAgentWorkflowConfig(AgentBaseConfig, name="rewoo_agent"):
|
|
|
53
55
|
description="The number of retries before raising a tool call error.",
|
|
54
56
|
ge=1)
|
|
55
57
|
max_history: int = Field(default=15, description="Maximum number of messages to keep in the conversation history.")
|
|
56
|
-
use_openai_api: bool = Field(default=False,
|
|
57
|
-
description=("Use OpenAI API for the input/output types to the function. "
|
|
58
|
-
"If False, strings will be used."))
|
|
59
58
|
additional_planner_instructions: str | None = Field(
|
|
60
59
|
default=None,
|
|
61
60
|
validation_alias=AliasChoices("additional_planner_instructions", "additional_instructions"),
|
|
@@ -124,21 +123,23 @@ async def rewoo_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builde
|
|
|
124
123
|
tool_call_max_retries=config.tool_call_max_retries,
|
|
125
124
|
raise_tool_call_error=config.raise_tool_call_error).build_graph()
|
|
126
125
|
|
|
127
|
-
async def _response_fn(
|
|
126
|
+
async def _response_fn(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse | str:
|
|
128
127
|
"""
|
|
129
128
|
Main workflow entry function for the ReWOO Agent.
|
|
130
129
|
|
|
131
130
|
This function invokes the ReWOO Agent Graph and returns the response.
|
|
132
131
|
|
|
133
132
|
Args:
|
|
134
|
-
|
|
133
|
+
chat_request_or_message (ChatRequestOrMessage): The input message to process
|
|
135
134
|
|
|
136
135
|
Returns:
|
|
137
|
-
ChatResponse: The response from the agent or error message
|
|
136
|
+
ChatResponse | str: The response from the agent or error message
|
|
138
137
|
"""
|
|
139
138
|
try:
|
|
139
|
+
message = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest)
|
|
140
|
+
|
|
140
141
|
# initialize the starting state with the user query
|
|
141
|
-
messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in
|
|
142
|
+
messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in message.messages],
|
|
142
143
|
max_tokens=config.max_history,
|
|
143
144
|
strategy="last",
|
|
144
145
|
token_counter=len,
|
|
@@ -157,21 +158,18 @@ async def rewoo_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builde
|
|
|
157
158
|
# Ensure output_message is a string
|
|
158
159
|
if isinstance(output_message, list | dict):
|
|
159
160
|
output_message = str(output_message)
|
|
160
|
-
return ChatResponse.from_string(output_message)
|
|
161
161
|
|
|
162
|
+
# Create usage statistics for the response
|
|
163
|
+
prompt_tokens = sum(len(str(msg.content).split()) for msg in message.messages)
|
|
164
|
+
completion_tokens = len(output_message.split()) if output_message else 0
|
|
165
|
+
total_tokens = prompt_tokens + completion_tokens
|
|
166
|
+
usage = Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens)
|
|
167
|
+
response = ChatResponse.from_string(output_message, usage=usage)
|
|
168
|
+
if chat_request_or_message.is_string:
|
|
169
|
+
return GlobalTypeConverter.get().convert(response, to_type=str)
|
|
170
|
+
return response
|
|
162
171
|
except Exception as ex:
|
|
163
172
|
logger.exception("ReWOO Agent failed with exception: %s", ex)
|
|
164
173
|
raise RuntimeError
|
|
165
174
|
|
|
166
|
-
|
|
167
|
-
yield FunctionInfo.from_fn(_response_fn, description=config.description)
|
|
168
|
-
|
|
169
|
-
else:
|
|
170
|
-
|
|
171
|
-
async def _str_api_fn(input_message: str) -> str:
|
|
172
|
-
oai_input = GlobalTypeConverter.get().try_convert(input_message, to_type=ChatRequest)
|
|
173
|
-
oai_output = await _response_fn(oai_input)
|
|
174
|
-
|
|
175
|
-
return GlobalTypeConverter.get().try_convert(oai_output, to_type=str)
|
|
176
|
-
|
|
177
|
-
yield FunctionInfo.from_fn(_str_api_fn, description=config.description)
|
|
175
|
+
yield FunctionInfo.from_fn(_response_fn, description=config.description)
|
|
@@ -23,8 +23,10 @@ from nat.builder.function_info import FunctionInfo
|
|
|
23
23
|
from nat.cli.register_workflow import register_function
|
|
24
24
|
from nat.data_models.agent import AgentBaseConfig
|
|
25
25
|
from nat.data_models.api_server import ChatRequest
|
|
26
|
+
from nat.data_models.api_server import ChatRequestOrMessage
|
|
26
27
|
from nat.data_models.component_ref import FunctionGroupRef
|
|
27
28
|
from nat.data_models.component_ref import FunctionRef
|
|
29
|
+
from nat.utils.type_converter import GlobalTypeConverter
|
|
28
30
|
|
|
29
31
|
logger = logging.getLogger(__name__)
|
|
30
32
|
|
|
@@ -81,21 +83,23 @@ async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, build
|
|
|
81
83
|
handle_tool_errors=config.handle_tool_errors,
|
|
82
84
|
return_direct=return_direct_tools).build_graph()
|
|
83
85
|
|
|
84
|
-
async def _response_fn(
|
|
86
|
+
async def _response_fn(chat_request_or_message: ChatRequestOrMessage) -> str:
|
|
85
87
|
"""
|
|
86
88
|
Main workflow entry function for the Tool Calling Agent.
|
|
87
89
|
|
|
88
90
|
This function invokes the Tool Calling Agent Graph and returns the response.
|
|
89
91
|
|
|
90
92
|
Args:
|
|
91
|
-
|
|
93
|
+
chat_request_or_message (ChatRequestOrMessage): The input message to process
|
|
92
94
|
|
|
93
95
|
Returns:
|
|
94
96
|
str: The response from the agent or error message
|
|
95
97
|
"""
|
|
96
98
|
try:
|
|
99
|
+
message = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest)
|
|
100
|
+
|
|
97
101
|
# initialize the starting state with the user query
|
|
98
|
-
messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in
|
|
102
|
+
messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in message.messages],
|
|
99
103
|
max_tokens=config.max_history,
|
|
100
104
|
strategy="last",
|
|
101
105
|
token_counter=len,
|
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import logging
|
|
17
|
+
from collections.abc import Awaitable
|
|
17
18
|
from collections.abc import Callable
|
|
18
19
|
from datetime import UTC
|
|
19
20
|
from datetime import datetime
|
|
@@ -35,10 +36,15 @@ logger = logging.getLogger(__name__)
|
|
|
35
36
|
|
|
36
37
|
class OAuth2AuthCodeFlowProvider(AuthProviderBase[OAuth2AuthCodeFlowProviderConfig]):
|
|
37
38
|
|
|
38
|
-
def __init__(self, config: OAuth2AuthCodeFlowProviderConfig):
|
|
39
|
+
def __init__(self, config: OAuth2AuthCodeFlowProviderConfig, token_storage=None):
|
|
39
40
|
super().__init__(config)
|
|
40
|
-
self._authenticated_tokens: dict[str, AuthResult] = {}
|
|
41
41
|
self._auth_callback = None
|
|
42
|
+
# Always use token storage - defaults to in-memory if not provided
|
|
43
|
+
if token_storage is None:
|
|
44
|
+
from nat.plugins.mcp.auth.token_storage import InMemoryTokenStorage
|
|
45
|
+
self._token_storage = InMemoryTokenStorage()
|
|
46
|
+
else:
|
|
47
|
+
self._token_storage = token_storage
|
|
42
48
|
|
|
43
49
|
async def _attempt_token_refresh(self, user_id: str, auth_result: AuthResult) -> AuthResult | None:
|
|
44
50
|
refresh_token = auth_result.raw.get("refresh_token")
|
|
@@ -61,7 +67,7 @@ class OAuth2AuthCodeFlowProvider(AuthProviderBase[OAuth2AuthCodeFlowProviderConf
|
|
|
61
67
|
raw=new_token_data,
|
|
62
68
|
)
|
|
63
69
|
|
|
64
|
-
self.
|
|
70
|
+
await self._token_storage.store(user_id, new_auth_result)
|
|
65
71
|
except httpx.HTTPStatusError:
|
|
66
72
|
return None
|
|
67
73
|
except httpx.RequestError:
|
|
@@ -74,26 +80,30 @@ class OAuth2AuthCodeFlowProvider(AuthProviderBase[OAuth2AuthCodeFlowProviderConf
|
|
|
74
80
|
|
|
75
81
|
def _set_custom_auth_callback(self,
|
|
76
82
|
auth_callback: Callable[[OAuth2AuthCodeFlowProviderConfig, AuthFlowType],
|
|
77
|
-
AuthenticatedContext]):
|
|
83
|
+
Awaitable[AuthenticatedContext]]):
|
|
78
84
|
self._auth_callback = auth_callback
|
|
79
85
|
|
|
80
86
|
async def authenticate(self, user_id: str | None = None, **kwargs) -> AuthResult:
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
87
|
+
context = Context.get()
|
|
88
|
+
if user_id is None and hasattr(context, "metadata") and hasattr(
|
|
89
|
+
context.metadata, "cookies") and context.metadata.cookies is not None:
|
|
90
|
+
session_id = context.metadata.cookies.get("nat-session", None)
|
|
84
91
|
if not session_id:
|
|
85
92
|
raise RuntimeError("Authentication failed. No session ID found. Cannot identify user.")
|
|
86
93
|
|
|
87
94
|
user_id = session_id
|
|
88
95
|
|
|
89
|
-
if user_id
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
96
|
+
if user_id:
|
|
97
|
+
# Try to retrieve from token storage
|
|
98
|
+
auth_result = await self._token_storage.retrieve(user_id)
|
|
99
|
+
|
|
100
|
+
if auth_result:
|
|
101
|
+
if not auth_result.is_expired():
|
|
102
|
+
return auth_result
|
|
93
103
|
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
104
|
+
refreshed_auth_result = await self._attempt_token_refresh(user_id, auth_result)
|
|
105
|
+
if refreshed_auth_result:
|
|
106
|
+
return refreshed_auth_result
|
|
97
107
|
|
|
98
108
|
# Try getting callback from the context if that's not set, use the default callback
|
|
99
109
|
try:
|
|
@@ -109,19 +119,22 @@ class OAuth2AuthCodeFlowProvider(AuthProviderBase[OAuth2AuthCodeFlowProviderConf
|
|
|
109
119
|
except Exception as e:
|
|
110
120
|
raise RuntimeError(f"Authentication callback failed: {e}") from e
|
|
111
121
|
|
|
112
|
-
|
|
122
|
+
headers = authenticated_context.headers or {}
|
|
123
|
+
auth_header = headers.get("Authorization", "")
|
|
113
124
|
if not auth_header.startswith("Bearer "):
|
|
114
125
|
raise RuntimeError("Invalid Authorization header")
|
|
115
126
|
|
|
116
127
|
token = auth_header.split(" ")[1]
|
|
117
128
|
|
|
129
|
+
# Safely access metadata
|
|
130
|
+
metadata = authenticated_context.metadata or {}
|
|
118
131
|
auth_result = AuthResult(
|
|
119
132
|
credentials=[BearerTokenCred(token=SecretStr(token))],
|
|
120
|
-
token_expires_at=
|
|
121
|
-
raw=
|
|
133
|
+
token_expires_at=metadata.get("expires_at"),
|
|
134
|
+
raw=metadata.get("raw_token") or {},
|
|
122
135
|
)
|
|
123
136
|
|
|
124
137
|
if user_id:
|
|
125
|
-
self.
|
|
138
|
+
await self._token_storage.store(user_id, auth_result)
|
|
126
139
|
|
|
127
140
|
return auth_result
|
nat/builder/component_utils.py
CHANGED
|
@@ -153,7 +153,7 @@ def recursive_componentref_discovery(cls: TypedBaseModel, value: typing.Any,
|
|
|
153
153
|
for v in value.values():
|
|
154
154
|
yield from recursive_componentref_discovery(cls, v, decomposed_type.args[1])
|
|
155
155
|
elif (issubclass(type(value), BaseModel)):
|
|
156
|
-
for field, field_info in value.model_fields.items():
|
|
156
|
+
for field, field_info in type(value).model_fields.items():
|
|
157
157
|
field_data = getattr(value, field)
|
|
158
158
|
yield from recursive_componentref_discovery(cls, field_data, field_info.annotation)
|
|
159
159
|
if (decomposed_type.is_union):
|
nat/builder/context.py
CHANGED
|
@@ -67,6 +67,8 @@ class ContextState(metaclass=Singleton):
|
|
|
67
67
|
def __init__(self):
|
|
68
68
|
self.conversation_id: ContextVar[str | None] = ContextVar("conversation_id", default=None)
|
|
69
69
|
self.user_message_id: ContextVar[str | None] = ContextVar("user_message_id", default=None)
|
|
70
|
+
self.workflow_run_id: ContextVar[str | None] = ContextVar("workflow_run_id", default=None)
|
|
71
|
+
self.workflow_trace_id: ContextVar[int | None] = ContextVar("workflow_trace_id", default=None)
|
|
70
72
|
self.input_message: ContextVar[typing.Any] = ContextVar("input_message", default=None)
|
|
71
73
|
self.user_manager: ContextVar[typing.Any] = ContextVar("user_manager", default=None)
|
|
72
74
|
self._metadata: ContextVar[RequestAttributes | None] = ContextVar("request_attributes", default=None)
|
|
@@ -120,14 +122,14 @@ class Context:
|
|
|
120
122
|
@property
|
|
121
123
|
def input_message(self):
|
|
122
124
|
"""
|
|
123
|
-
|
|
125
|
+
Retrieves the input message from the context state.
|
|
124
126
|
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
127
|
+
The input_message property is used to access the message stored in the
|
|
128
|
+
context state. This property returns the message as it is currently
|
|
129
|
+
maintained in the context.
|
|
128
130
|
|
|
129
|
-
|
|
130
|
-
|
|
131
|
+
Returns:
|
|
132
|
+
str: The input message retrieved from the context state.
|
|
131
133
|
"""
|
|
132
134
|
return self._context_state.input_message.get()
|
|
133
135
|
|
|
@@ -196,6 +198,20 @@ class Context:
|
|
|
196
198
|
"""
|
|
197
199
|
return self._context_state.user_message_id.get()
|
|
198
200
|
|
|
201
|
+
@property
|
|
202
|
+
def workflow_run_id(self) -> str | None:
|
|
203
|
+
"""
|
|
204
|
+
Returns a stable identifier for the current workflow/agent invocation (UUID string).
|
|
205
|
+
"""
|
|
206
|
+
return self._context_state.workflow_run_id.get()
|
|
207
|
+
|
|
208
|
+
@property
|
|
209
|
+
def workflow_trace_id(self) -> int | None:
|
|
210
|
+
"""
|
|
211
|
+
Returns the 128-bit trace identifier for the current run, used as the OpenTelemetry trace_id.
|
|
212
|
+
"""
|
|
213
|
+
return self._context_state.workflow_trace_id.get()
|
|
214
|
+
|
|
199
215
|
@contextmanager
|
|
200
216
|
def push_active_function(self,
|
|
201
217
|
function_name: str,
|
nat/builder/function.py
CHANGED
|
@@ -416,8 +416,9 @@ class FunctionGroup:
|
|
|
416
416
|
"""
|
|
417
417
|
if not name.strip():
|
|
418
418
|
raise ValueError("Function name cannot be empty or blank")
|
|
419
|
-
if not re.match(r"^[a-zA-Z0-9_
|
|
420
|
-
raise ValueError(
|
|
419
|
+
if not re.match(r"^[a-zA-Z0-9_.-]+$", name):
|
|
420
|
+
raise ValueError(
|
|
421
|
+
f"Function name can only contain letters, numbers, underscores, periods, and hyphens: {name}")
|
|
421
422
|
if name in self._functions:
|
|
422
423
|
raise ValueError(f"Function {name} already exists in function group {self._instance_name}")
|
|
423
424
|
|
nat/builder/workflow_builder.py
CHANGED
|
@@ -156,6 +156,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
156
156
|
self._registry = registry
|
|
157
157
|
|
|
158
158
|
self._logging_handlers: dict[str, logging.Handler] = {}
|
|
159
|
+
self._removed_root_handlers: list[tuple[logging.Handler, int]] = []
|
|
159
160
|
self._telemetry_exporters: dict[str, ConfiguredTelemetryExporter] = {}
|
|
160
161
|
|
|
161
162
|
self._functions: dict[str, ConfiguredFunction] = {}
|
|
@@ -187,6 +188,15 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
187
188
|
# Get the telemetry info from the config
|
|
188
189
|
telemetry_config = self.general_config.telemetry
|
|
189
190
|
|
|
191
|
+
# If we have logging configuration, we need to manage the root logger properly
|
|
192
|
+
root_logger = logging.getLogger()
|
|
193
|
+
|
|
194
|
+
# Collect configured handler types to determine if we need to adjust existing handlers
|
|
195
|
+
# This is somewhat of a hack by inspecting the class name of the config object
|
|
196
|
+
has_console_handler = any(
|
|
197
|
+
hasattr(config, "__class__") and "console" in config.__class__.__name__.lower()
|
|
198
|
+
for config in telemetry_config.logging.values())
|
|
199
|
+
|
|
190
200
|
for key, logging_config in telemetry_config.logging.items():
|
|
191
201
|
# Use the same pattern as tracing, but for logging
|
|
192
202
|
logging_info = self._registry.get_logging_method(type(logging_config))
|
|
@@ -200,7 +210,31 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
200
210
|
self._logging_handlers[key] = handler
|
|
201
211
|
|
|
202
212
|
# Now attach to NAT's root logger
|
|
203
|
-
|
|
213
|
+
root_logger.addHandler(handler)
|
|
214
|
+
|
|
215
|
+
# If we added logging handlers, manage existing handlers appropriately
|
|
216
|
+
if self._logging_handlers:
|
|
217
|
+
min_handler_level = min((handler.level for handler in root_logger.handlers), default=logging.CRITICAL)
|
|
218
|
+
|
|
219
|
+
# Ensure the root logger level allows messages through
|
|
220
|
+
root_logger.level = max(root_logger.level, min_handler_level)
|
|
221
|
+
|
|
222
|
+
# If a console handler is configured, adjust or remove default CLI handlers
|
|
223
|
+
# to avoid duplicate output while preserving workflow visibility
|
|
224
|
+
if has_console_handler:
|
|
225
|
+
# Remove existing StreamHandlers that are not the newly configured ones
|
|
226
|
+
for handler in root_logger.handlers[:]:
|
|
227
|
+
if type(handler) is logging.StreamHandler and handler not in self._logging_handlers.values():
|
|
228
|
+
self._removed_root_handlers.append((handler, handler.level))
|
|
229
|
+
root_logger.removeHandler(handler)
|
|
230
|
+
else:
|
|
231
|
+
# No console handler configured, but adjust existing handler levels
|
|
232
|
+
# to respect the minimum configured level for file/other handlers
|
|
233
|
+
for handler in root_logger.handlers[:]:
|
|
234
|
+
if type(handler) is logging.StreamHandler:
|
|
235
|
+
old_level = handler.level
|
|
236
|
+
handler.setLevel(min_handler_level)
|
|
237
|
+
self._removed_root_handlers.append((handler, old_level))
|
|
204
238
|
|
|
205
239
|
# Add the telemetry exporters
|
|
206
240
|
for key, telemetry_exporter_config in telemetry_config.tracing.items():
|
|
@@ -212,8 +246,17 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
212
246
|
|
|
213
247
|
assert self._exit_stack is not None, "Exit stack not initialized"
|
|
214
248
|
|
|
215
|
-
|
|
216
|
-
|
|
249
|
+
root_logger = logging.getLogger()
|
|
250
|
+
|
|
251
|
+
# Remove custom logging handlers
|
|
252
|
+
for handler in self._logging_handlers.values():
|
|
253
|
+
root_logger.removeHandler(handler)
|
|
254
|
+
|
|
255
|
+
# Restore original handlers and their levels
|
|
256
|
+
for handler, old_level in self._removed_root_handlers:
|
|
257
|
+
if handler not in root_logger.handlers:
|
|
258
|
+
root_logger.addHandler(handler)
|
|
259
|
+
handler.setLevel(old_level)
|
|
217
260
|
|
|
218
261
|
await self._exit_stack.__aexit__(*exc_details)
|
|
219
262
|
|
nat/cli/commands/mcp/mcp.py
CHANGED
|
@@ -194,7 +194,7 @@ async def _create_mcp_client_config(
|
|
|
194
194
|
auth_user_id: str | None,
|
|
195
195
|
auth_scopes: list[str] | None,
|
|
196
196
|
):
|
|
197
|
-
from nat.plugins.mcp.
|
|
197
|
+
from nat.plugins.mcp.client_config import MCPClientConfig
|
|
198
198
|
|
|
199
199
|
if url and transport == "streamable-http" and auth_redirect_uri:
|
|
200
200
|
try:
|
|
@@ -236,8 +236,8 @@ async def list_tools_via_function_group(
|
|
|
236
236
|
try:
|
|
237
237
|
# Ensure the registration side-effects are loaded
|
|
238
238
|
from nat.builder.workflow_builder import WorkflowBuilder
|
|
239
|
-
from nat.plugins.mcp.
|
|
240
|
-
from nat.plugins.mcp.
|
|
239
|
+
from nat.plugins.mcp.client_config import MCPClientConfig
|
|
240
|
+
from nat.plugins.mcp.client_config import MCPServerConfig
|
|
241
241
|
except ImportError:
|
|
242
242
|
click.echo(
|
|
243
243
|
"MCP client functionality requires nvidia-nat-mcp package. Install with: uv pip install nvidia-nat-mcp",
|
|
@@ -297,7 +297,7 @@ async def list_tools_via_function_group(
|
|
|
297
297
|
if fn is not None:
|
|
298
298
|
tools.append(to_tool_entry(full, fn))
|
|
299
299
|
else:
|
|
300
|
-
for full, fn in
|
|
300
|
+
for full, fn in fns.items():
|
|
301
301
|
tools.append(to_tool_entry(full, fn))
|
|
302
302
|
|
|
303
303
|
return tools
|
|
@@ -826,8 +826,8 @@ async def call_tool_and_print(command: str | None,
|
|
|
826
826
|
|
|
827
827
|
try:
|
|
828
828
|
from nat.builder.workflow_builder import WorkflowBuilder
|
|
829
|
-
from nat.plugins.mcp.
|
|
830
|
-
from nat.plugins.mcp.
|
|
829
|
+
from nat.plugins.mcp.client_config import MCPClientConfig
|
|
830
|
+
from nat.plugins.mcp.client_config import MCPServerConfig
|
|
831
831
|
except ImportError:
|
|
832
832
|
click.echo(
|
|
833
833
|
"MCP client functionality requires nvidia-nat-mcp package. Install with: uv pip install nvidia-nat-mcp",
|
|
@@ -1,15 +1,17 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
1
|
+
functions:
|
|
2
|
+
current_datetime:
|
|
3
|
+
_type: current_datetime
|
|
4
|
+
{{python_safe_workflow_name}}:
|
|
5
|
+
_type: {{python_safe_workflow_name}}
|
|
6
|
+
prefix: "Hello:"
|
|
6
7
|
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
8
|
+
llms:
|
|
9
|
+
nim_llm:
|
|
10
|
+
_type: nim
|
|
11
|
+
model_name: meta/llama-3.1-70b-instruct
|
|
12
|
+
temperature: 0.0
|
|
12
13
|
|
|
13
14
|
workflow:
|
|
14
|
-
_type:
|
|
15
|
-
|
|
15
|
+
_type: react_agent
|
|
16
|
+
llm_name: nim_llm
|
|
17
|
+
tool_names: [current_datetime, {{python_safe_workflow_name}}]
|
|
@@ -1,4 +1,4 @@
|
|
|
1
1
|
# flake8: noqa
|
|
2
2
|
|
|
3
|
-
# Import
|
|
4
|
-
from {{package_name}} import {{
|
|
3
|
+
# Import the generated workflow function to trigger registration
|
|
4
|
+
from .{{package_name}} import {{ python_safe_workflow_name }}_function
|
|
@@ -3,6 +3,7 @@ import logging
|
|
|
3
3
|
from pydantic import Field
|
|
4
4
|
|
|
5
5
|
from nat.builder.builder import Builder
|
|
6
|
+
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
6
7
|
from nat.builder.function_info import FunctionInfo
|
|
7
8
|
from nat.cli.register_workflow import register_function
|
|
8
9
|
from nat.data_models.function import FunctionBaseConfig
|
|
@@ -12,25 +13,38 @@ logger = logging.getLogger(__name__)
|
|
|
12
13
|
|
|
13
14
|
class {{ workflow_class_name }}(FunctionBaseConfig, name="{{ workflow_name }}"):
|
|
14
15
|
"""
|
|
15
|
-
{{workflow_description}}
|
|
16
|
+
{{ workflow_description }}
|
|
16
17
|
"""
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
18
|
+
prefix: str = Field(default="Echo:", description="Prefix to add before the echoed text.")
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@register_function(config_type={{ workflow_class_name }}, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
|
|
22
|
+
async def {{ python_safe_workflow_name }}_function(config: {{ workflow_class_name }}, builder: Builder):
|
|
23
|
+
"""
|
|
24
|
+
Registers a function (addressable via `{{ workflow_name }}` in the configuration).
|
|
25
|
+
This registration ensures a static mapping of the function type, `{{ workflow_name }}`, to the `{{ workflow_class_name }}` configuration object.
|
|
26
|
+
|
|
27
|
+
Args:
|
|
28
|
+
config ({{ workflow_class_name }}): The configuration for the function.
|
|
29
|
+
builder (Builder): The builder object.
|
|
30
|
+
|
|
31
|
+
Returns:
|
|
32
|
+
FunctionInfo: The function info object for the function.
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
# Define the function that will be registered.
|
|
36
|
+
async def _echo(text: str) -> str:
|
|
37
|
+
"""
|
|
38
|
+
Takes a text input and echoes back with a pre-defined prefix.
|
|
39
|
+
|
|
40
|
+
Args:
|
|
41
|
+
text (str): The text to echo back.
|
|
42
|
+
|
|
43
|
+
Returns:
|
|
44
|
+
str: The text with the prefix.
|
|
45
|
+
"""
|
|
46
|
+
return f"{config.prefix} {text}"
|
|
47
|
+
|
|
48
|
+
# The callable is wrapped in a FunctionInfo object.
|
|
49
|
+
# The description parameter is used to describe the function.
|
|
50
|
+
yield FunctionInfo.from_fn(_echo, description=_echo.__doc__)
|