nvidia-nat 1.3.0rc1__py3-none-any.whl → 1.3.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. nat/agent/prompt_optimizer/register.py +2 -2
  2. nat/agent/react_agent/register.py +20 -21
  3. nat/agent/rewoo_agent/register.py +18 -20
  4. nat/agent/tool_calling_agent/register.py +7 -3
  5. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +31 -18
  6. nat/builder/component_utils.py +1 -1
  7. nat/builder/context.py +22 -6
  8. nat/builder/function.py +3 -2
  9. nat/builder/workflow_builder.py +46 -3
  10. nat/cli/commands/mcp/mcp.py +6 -6
  11. nat/cli/commands/workflow/templates/config.yml.j2 +14 -12
  12. nat/cli/commands/workflow/templates/register.py.j2 +2 -2
  13. nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
  14. nat/cli/commands/workflow/workflow_commands.py +54 -10
  15. nat/cli/entrypoint.py +9 -1
  16. nat/cli/main.py +3 -0
  17. nat/data_models/api_server.py +143 -66
  18. nat/data_models/config.py +1 -1
  19. nat/data_models/span.py +41 -3
  20. nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
  21. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +2 -2
  22. nat/front_ends/console/console_front_end_plugin.py +11 -2
  23. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
  24. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +5 -35
  25. nat/front_ends/fastapi/message_validator.py +3 -1
  26. nat/observability/exporter/span_exporter.py +34 -14
  27. nat/observability/register.py +16 -0
  28. nat/profiler/decorators/framework_wrapper.py +1 -1
  29. nat/profiler/forecasting/models/linear_model.py +1 -1
  30. nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
  31. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
  32. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +1 -1
  33. nat/runtime/runner.py +103 -6
  34. nat/runtime/session.py +27 -1
  35. nat/tool/memory_tools/add_memory_tool.py +3 -3
  36. nat/tool/memory_tools/delete_memory_tool.py +3 -4
  37. nat/tool/memory_tools/get_memory_tool.py +4 -4
  38. nat/utils/decorators.py +210 -0
  39. nat/utils/type_converter.py +8 -0
  40. nvidia_nat-1.3.0rc3.dist-info/METADATA +195 -0
  41. {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc3.dist-info}/RECORD +46 -45
  42. nvidia_nat-1.3.0rc1.dist-info/METADATA +0 -391
  43. {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc3.dist-info}/WHEEL +0 -0
  44. {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc3.dist-info}/entry_points.txt +0 -0
  45. {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc3.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  46. {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc3.dist-info}/licenses/LICENSE.md +0 -0
  47. {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc3.dist-info}/top_level.txt +0 -0
nat/agent/prompt_optimizer/register.py CHANGED
@@ -51,7 +51,7 @@ async def prompt_optimizer_function(config: PromptOptimizerConfig, builder: Buil
51
51
  from .prompt import mutator_prompt
52
52
  except ImportError as exc:
53
53
  raise ImportError("langchain-core is not installed. Please install it to use MultiLLMPlanner.\n"
54
- "This error can be resolve by installing nvidia-nat[langchain]") from exc
54
+ "This error can be resolve by installing \"nvidia-nat[langchain]\".") from exc
55
55
 
56
56
  llm = await builder.get_llm(config.optimizer_llm, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
57
57
 
@@ -111,7 +111,7 @@ async def prompt_recombiner_function(config: PromptRecombinerConfig, builder: Bu
111
111
  from langchain_core.prompts import PromptTemplate
112
112
  except ImportError as exc:
113
113
  raise ImportError("langchain-core is not installed. Please install it to use MultiLLMPlanner.\n"
114
- "This error can be resolve by installing nvidia-nat[langchain].") from exc
114
+ "This error can be resolve by installing \"nvidia-nat[langchain]\".") from exc
115
115
 
116
116
  llm = await builder.get_llm(config.optimizer_llm, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
117
117
 
nat/agent/react_agent/register.py CHANGED
@@ -24,7 +24,9 @@ from nat.builder.function_info import FunctionInfo
24
24
  from nat.cli.register_workflow import register_function
25
25
  from nat.data_models.agent import AgentBaseConfig
26
26
  from nat.data_models.api_server import ChatRequest
27
+ from nat.data_models.api_server import ChatRequestOrMessage
27
28
  from nat.data_models.api_server import ChatResponse
29
+ from nat.data_models.api_server import Usage
28
30
  from nat.data_models.component_ref import FunctionGroupRef
29
31
  from nat.data_models.component_ref import FunctionRef
30
32
  from nat.data_models.optimizable import OptimizableField
@@ -69,9 +71,6 @@ class ReActAgentWorkflowConfig(AgentBaseConfig, OptimizableMixin, name="react_ag
69
71
  default=None,
70
72
  description="Provides the SYSTEM_PROMPT to use with the agent") # defaults to SYSTEM_PROMPT in prompt.py
71
73
  max_history: int = Field(default=15, description="Maximum number of messages to keep in the conversation history.")
72
- use_openai_api: bool = Field(default=False,
73
- description=("Use OpenAI API for the input/output types to the function. "
74
- "If False, strings will be used."))
75
74
  additional_instructions: str | None = OptimizableField(
76
75
  default=None,
77
76
  description="Additional instructions to provide to the agent in addition to the base prompt.",
@@ -117,21 +116,23 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
117
116
  pass_tool_call_errors_to_agent=config.pass_tool_call_errors_to_agent,
118
117
  normalize_tool_input_quotes=config.normalize_tool_input_quotes).build_graph()
119
118
 
120
- async def _response_fn(input_message: ChatRequest) -> ChatResponse:
119
+ async def _response_fn(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse | str:
121
120
  """
122
121
  Main workflow entry function for the ReAct Agent.
123
122
 
124
123
  This function invokes the ReAct Agent Graph and returns the response.
125
124
 
126
125
  Args:
127
- input_message (ChatRequest): The input message to process
126
+ chat_request_or_message (ChatRequestOrMessage): The input message to process
128
127
 
129
128
  Returns:
130
- ChatResponse: The response from the agent or error message
129
+ ChatResponse | str: The response from the agent or error message
131
130
  """
132
131
  try:
132
+ message = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest)
133
+
133
134
  # initialize the starting state with the user query
134
- messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in input_message.messages],
135
+ messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in message.messages],
135
136
  max_tokens=config.max_history,
136
137
  strategy="last",
137
138
  token_counter=len,
@@ -149,21 +150,19 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
149
150
  # get and return the output from the state
150
151
  state = ReActGraphState(**state)
151
152
  output_message = state.messages[-1]
152
- return ChatResponse.from_string(str(output_message.content))
153
-
153
+ content = str(output_message.content)
154
+
155
+ # Create usage statistics for the response
156
+ prompt_tokens = sum(len(str(msg.content).split()) for msg in message.messages)
157
+ completion_tokens = len(content.split()) if content else 0
158
+ total_tokens = prompt_tokens + completion_tokens
159
+ usage = Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens)
160
+ response = ChatResponse.from_string(content, usage=usage)
161
+ if chat_request_or_message.is_string:
162
+ return GlobalTypeConverter.get().convert(response, to_type=str)
163
+ return response
154
164
  except Exception as ex:
155
165
  logger.exception("%s ReAct Agent failed with exception: %s", AGENT_LOG_PREFIX, str(ex))
156
166
  raise RuntimeError
157
167
 
158
- if (config.use_openai_api):
159
- yield FunctionInfo.from_fn(_response_fn, description=config.description)
160
- else:
161
-
162
- async def _str_api_fn(input_message: str) -> str:
163
- oai_input = GlobalTypeConverter.get().try_convert(input_message, to_type=ChatRequest)
164
-
165
- oai_output = await _response_fn(oai_input)
166
-
167
- return GlobalTypeConverter.get().try_convert(oai_output, to_type=str)
168
-
169
- yield FunctionInfo.from_fn(_str_api_fn, description=config.description)
168
+ yield FunctionInfo.from_fn(_response_fn, description=config.description)
nat/agent/rewoo_agent/register.py CHANGED
@@ -25,7 +25,9 @@ from nat.builder.function_info import FunctionInfo
25
25
  from nat.cli.register_workflow import register_function
26
26
  from nat.data_models.agent import AgentBaseConfig
27
27
  from nat.data_models.api_server import ChatRequest
28
+ from nat.data_models.api_server import ChatRequestOrMessage
28
29
  from nat.data_models.api_server import ChatResponse
30
+ from nat.data_models.api_server import Usage
29
31
  from nat.data_models.component_ref import FunctionGroupRef
30
32
  from nat.data_models.component_ref import FunctionRef
31
33
  from nat.utils.type_converter import GlobalTypeConverter
@@ -53,9 +55,6 @@ class ReWOOAgentWorkflowConfig(AgentBaseConfig, name="rewoo_agent"):
53
55
  description="The number of retries before raising a tool call error.",
54
56
  ge=1)
55
57
  max_history: int = Field(default=15, description="Maximum number of messages to keep in the conversation history.")
56
- use_openai_api: bool = Field(default=False,
57
- description=("Use OpenAI API for the input/output types to the function. "
58
- "If False, strings will be used."))
59
58
  additional_planner_instructions: str | None = Field(
60
59
  default=None,
61
60
  validation_alias=AliasChoices("additional_planner_instructions", "additional_instructions"),
@@ -124,21 +123,23 @@ async def rewoo_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builde
124
123
  tool_call_max_retries=config.tool_call_max_retries,
125
124
  raise_tool_call_error=config.raise_tool_call_error).build_graph()
126
125
 
127
- async def _response_fn(input_message: ChatRequest) -> ChatResponse:
126
+ async def _response_fn(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse | str:
128
127
  """
129
128
  Main workflow entry function for the ReWOO Agent.
130
129
 
131
130
  This function invokes the ReWOO Agent Graph and returns the response.
132
131
 
133
132
  Args:
134
- input_message (ChatRequest): The input message to process
133
+ chat_request_or_message (ChatRequestOrMessage): The input message to process
135
134
 
136
135
  Returns:
137
- ChatResponse: The response from the agent or error message
136
+ ChatResponse | str: The response from the agent or error message
138
137
  """
139
138
  try:
139
+ message = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest)
140
+
140
141
  # initialize the starting state with the user query
141
- messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in input_message.messages],
142
+ messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in message.messages],
142
143
  max_tokens=config.max_history,
143
144
  strategy="last",
144
145
  token_counter=len,
@@ -157,21 +158,18 @@ async def rewoo_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builde
157
158
  # Ensure output_message is a string
158
159
  if isinstance(output_message, list | dict):
159
160
  output_message = str(output_message)
160
- return ChatResponse.from_string(output_message)
161
161
 
162
+ # Create usage statistics for the response
163
+ prompt_tokens = sum(len(str(msg.content).split()) for msg in message.messages)
164
+ completion_tokens = len(output_message.split()) if output_message else 0
165
+ total_tokens = prompt_tokens + completion_tokens
166
+ usage = Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens)
167
+ response = ChatResponse.from_string(output_message, usage=usage)
168
+ if chat_request_or_message.is_string:
169
+ return GlobalTypeConverter.get().convert(response, to_type=str)
170
+ return response
162
171
  except Exception as ex:
163
172
  logger.exception("ReWOO Agent failed with exception: %s", ex)
164
173
  raise RuntimeError
165
174
 
166
- if (config.use_openai_api):
167
- yield FunctionInfo.from_fn(_response_fn, description=config.description)
168
-
169
- else:
170
-
171
- async def _str_api_fn(input_message: str) -> str:
172
- oai_input = GlobalTypeConverter.get().try_convert(input_message, to_type=ChatRequest)
173
- oai_output = await _response_fn(oai_input)
174
-
175
- return GlobalTypeConverter.get().try_convert(oai_output, to_type=str)
176
-
177
- yield FunctionInfo.from_fn(_str_api_fn, description=config.description)
175
+ yield FunctionInfo.from_fn(_response_fn, description=config.description)
nat/agent/tool_calling_agent/register.py CHANGED
@@ -23,8 +23,10 @@ from nat.builder.function_info import FunctionInfo
23
23
  from nat.cli.register_workflow import register_function
24
24
  from nat.data_models.agent import AgentBaseConfig
25
25
  from nat.data_models.api_server import ChatRequest
26
+ from nat.data_models.api_server import ChatRequestOrMessage
26
27
  from nat.data_models.component_ref import FunctionGroupRef
27
28
  from nat.data_models.component_ref import FunctionRef
29
+ from nat.utils.type_converter import GlobalTypeConverter
28
30
 
29
31
  logger = logging.getLogger(__name__)
30
32
 
@@ -81,21 +83,23 @@ async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, build
81
83
  handle_tool_errors=config.handle_tool_errors,
82
84
  return_direct=return_direct_tools).build_graph()
83
85
 
84
- async def _response_fn(input_message: ChatRequest) -> str:
86
+ async def _response_fn(chat_request_or_message: ChatRequestOrMessage) -> str:
85
87
  """
86
88
  Main workflow entry function for the Tool Calling Agent.
87
89
 
88
90
  This function invokes the Tool Calling Agent Graph and returns the response.
89
91
 
90
92
  Args:
91
- input_message (ChatRequest): The input message to process
93
+ chat_request_or_message (ChatRequestOrMessage): The input message to process
92
94
 
93
95
  Returns:
94
96
  str: The response from the agent or error message
95
97
  """
96
98
  try:
99
+ message = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest)
100
+
97
101
  # initialize the starting state with the user query
98
- messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in input_message.messages],
102
+ messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in message.messages],
99
103
  max_tokens=config.max_history,
100
104
  strategy="last",
101
105
  token_counter=len,
nat/authentication/oauth2/oauth2_auth_code_flow_provider.py CHANGED
@@ -14,6 +14,7 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import logging
17
+ from collections.abc import Awaitable
17
18
  from collections.abc import Callable
18
19
  from datetime import UTC
19
20
  from datetime import datetime
@@ -35,10 +36,15 @@ logger = logging.getLogger(__name__)
35
36
 
36
37
  class OAuth2AuthCodeFlowProvider(AuthProviderBase[OAuth2AuthCodeFlowProviderConfig]):
37
38
 
38
- def __init__(self, config: OAuth2AuthCodeFlowProviderConfig):
39
+ def __init__(self, config: OAuth2AuthCodeFlowProviderConfig, token_storage=None):
39
40
  super().__init__(config)
40
- self._authenticated_tokens: dict[str, AuthResult] = {}
41
41
  self._auth_callback = None
42
+ # Always use token storage - defaults to in-memory if not provided
43
+ if token_storage is None:
44
+ from nat.plugins.mcp.auth.token_storage import InMemoryTokenStorage
45
+ self._token_storage = InMemoryTokenStorage()
46
+ else:
47
+ self._token_storage = token_storage
42
48
 
43
49
  async def _attempt_token_refresh(self, user_id: str, auth_result: AuthResult) -> AuthResult | None:
44
50
  refresh_token = auth_result.raw.get("refresh_token")
@@ -61,7 +67,7 @@ class OAuth2AuthCodeFlowProvider(AuthProviderBase[OAuth2AuthCodeFlowProviderConf
61
67
  raw=new_token_data,
62
68
  )
63
69
 
64
- self._authenticated_tokens[user_id] = new_auth_result
70
+ await self._token_storage.store(user_id, new_auth_result)
65
71
  except httpx.HTTPStatusError:
66
72
  return None
67
73
  except httpx.RequestError:
@@ -74,26 +80,30 @@ class OAuth2AuthCodeFlowProvider(AuthProviderBase[OAuth2AuthCodeFlowProviderConf
74
80
 
75
81
  def _set_custom_auth_callback(self,
76
82
  auth_callback: Callable[[OAuth2AuthCodeFlowProviderConfig, AuthFlowType],
77
- AuthenticatedContext]):
83
+ Awaitable[AuthenticatedContext]]):
78
84
  self._auth_callback = auth_callback
79
85
 
80
86
  async def authenticate(self, user_id: str | None = None, **kwargs) -> AuthResult:
81
- if user_id is None and hasattr(Context.get(), "metadata") and hasattr(
82
- Context.get().metadata, "cookies") and Context.get().metadata.cookies is not None:
83
- session_id = Context.get().metadata.cookies.get("nat-session", None)
87
+ context = Context.get()
88
+ if user_id is None and hasattr(context, "metadata") and hasattr(
89
+ context.metadata, "cookies") and context.metadata.cookies is not None:
90
+ session_id = context.metadata.cookies.get("nat-session", None)
84
91
  if not session_id:
85
92
  raise RuntimeError("Authentication failed. No session ID found. Cannot identify user.")
86
93
 
87
94
  user_id = session_id
88
95
 
89
- if user_id and user_id in self._authenticated_tokens:
90
- auth_result = self._authenticated_tokens[user_id]
91
- if not auth_result.is_expired():
92
- return auth_result
96
+ if user_id:
97
+ # Try to retrieve from token storage
98
+ auth_result = await self._token_storage.retrieve(user_id)
99
+
100
+ if auth_result:
101
+ if not auth_result.is_expired():
102
+ return auth_result
93
103
 
94
- refreshed_auth_result = await self._attempt_token_refresh(user_id, auth_result)
95
- if refreshed_auth_result:
96
- return refreshed_auth_result
104
+ refreshed_auth_result = await self._attempt_token_refresh(user_id, auth_result)
105
+ if refreshed_auth_result:
106
+ return refreshed_auth_result
97
107
 
98
108
  # Try getting callback from the context if that's not set, use the default callback
99
109
  try:
@@ -109,19 +119,22 @@ class OAuth2AuthCodeFlowProvider(AuthProviderBase[OAuth2AuthCodeFlowProviderConf
109
119
  except Exception as e:
110
120
  raise RuntimeError(f"Authentication callback failed: {e}") from e
111
121
 
112
- auth_header = authenticated_context.headers.get("Authorization", "")
122
+ headers = authenticated_context.headers or {}
123
+ auth_header = headers.get("Authorization", "")
113
124
  if not auth_header.startswith("Bearer "):
114
125
  raise RuntimeError("Invalid Authorization header")
115
126
 
116
127
  token = auth_header.split(" ")[1]
117
128
 
129
+ # Safely access metadata
130
+ metadata = authenticated_context.metadata or {}
118
131
  auth_result = AuthResult(
119
132
  credentials=[BearerTokenCred(token=SecretStr(token))],
120
- token_expires_at=authenticated_context.metadata.get("expires_at"),
121
- raw=authenticated_context.metadata.get("raw_token"),
133
+ token_expires_at=metadata.get("expires_at"),
134
+ raw=metadata.get("raw_token") or {},
122
135
  )
123
136
 
124
137
  if user_id:
125
- self._authenticated_tokens[user_id] = auth_result
138
+ await self._token_storage.store(user_id, auth_result)
126
139
 
127
140
  return auth_result
nat/builder/component_utils.py CHANGED
@@ -153,7 +153,7 @@ def recursive_componentref_discovery(cls: TypedBaseModel, value: typing.Any,
153
153
  for v in value.values():
154
154
  yield from recursive_componentref_discovery(cls, v, decomposed_type.args[1])
155
155
  elif (issubclass(type(value), BaseModel)):
156
- for field, field_info in value.model_fields.items():
156
+ for field, field_info in type(value).model_fields.items():
157
157
  field_data = getattr(value, field)
158
158
  yield from recursive_componentref_discovery(cls, field_data, field_info.annotation)
159
159
  if (decomposed_type.is_union):
nat/builder/context.py CHANGED
@@ -67,6 +67,8 @@ class ContextState(metaclass=Singleton):
67
67
  def __init__(self):
68
68
  self.conversation_id: ContextVar[str | None] = ContextVar("conversation_id", default=None)
69
69
  self.user_message_id: ContextVar[str | None] = ContextVar("user_message_id", default=None)
70
+ self.workflow_run_id: ContextVar[str | None] = ContextVar("workflow_run_id", default=None)
71
+ self.workflow_trace_id: ContextVar[int | None] = ContextVar("workflow_trace_id", default=None)
70
72
  self.input_message: ContextVar[typing.Any] = ContextVar("input_message", default=None)
71
73
  self.user_manager: ContextVar[typing.Any] = ContextVar("user_manager", default=None)
72
74
  self._metadata: ContextVar[RequestAttributes | None] = ContextVar("request_attributes", default=None)
@@ -120,14 +122,14 @@ class Context:
120
122
  @property
121
123
  def input_message(self):
122
124
  """
123
- Retrieves the input message from the context state.
125
+ Retrieves the input message from the context state.
124
126
 
125
- The input_message property is used to access the message stored in the
126
- context state. This property returns the message as it is currently
127
- maintained in the context.
127
+ The input_message property is used to access the message stored in the
128
+ context state. This property returns the message as it is currently
129
+ maintained in the context.
128
130
 
129
- Returns:
130
- str: The input message retrieved from the context state.
131
+ Returns:
132
+ str: The input message retrieved from the context state.
131
133
  """
132
134
  return self._context_state.input_message.get()
133
135
 
@@ -196,6 +198,20 @@ class Context:
196
198
  """
197
199
  return self._context_state.user_message_id.get()
198
200
 
201
+ @property
202
+ def workflow_run_id(self) -> str | None:
203
+ """
204
+ Returns a stable identifier for the current workflow/agent invocation (UUID string).
205
+ """
206
+ return self._context_state.workflow_run_id.get()
207
+
208
+ @property
209
+ def workflow_trace_id(self) -> int | None:
210
+ """
211
+ Returns the 128-bit trace identifier for the current run, used as the OpenTelemetry trace_id.
212
+ """
213
+ return self._context_state.workflow_trace_id.get()
214
+
199
215
  @contextmanager
200
216
  def push_active_function(self,
201
217
  function_name: str,
nat/builder/function.py CHANGED
@@ -416,8 +416,9 @@ class FunctionGroup:
416
416
  """
417
417
  if not name.strip():
418
418
  raise ValueError("Function name cannot be empty or blank")
419
- if not re.match(r"^[a-zA-Z0-9_-]+$", name):
420
- raise ValueError(f"Function name can only contain letters, numbers, underscores, and hyphens: {name}")
419
+ if not re.match(r"^[a-zA-Z0-9_.-]+$", name):
420
+ raise ValueError(
421
+ f"Function name can only contain letters, numbers, underscores, periods, and hyphens: {name}")
421
422
  if name in self._functions:
422
423
  raise ValueError(f"Function {name} already exists in function group {self._instance_name}")
423
424
 
nat/builder/workflow_builder.py CHANGED
@@ -156,6 +156,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
156
156
  self._registry = registry
157
157
 
158
158
  self._logging_handlers: dict[str, logging.Handler] = {}
159
+ self._removed_root_handlers: list[tuple[logging.Handler, int]] = []
159
160
  self._telemetry_exporters: dict[str, ConfiguredTelemetryExporter] = {}
160
161
 
161
162
  self._functions: dict[str, ConfiguredFunction] = {}
@@ -187,6 +188,15 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
187
188
  # Get the telemetry info from the config
188
189
  telemetry_config = self.general_config.telemetry
189
190
 
191
+ # If we have logging configuration, we need to manage the root logger properly
192
+ root_logger = logging.getLogger()
193
+
194
+ # Collect configured handler types to determine if we need to adjust existing handlers
195
+ # This is somewhat of a hack by inspecting the class name of the config object
196
+ has_console_handler = any(
197
+ hasattr(config, "__class__") and "console" in config.__class__.__name__.lower()
198
+ for config in telemetry_config.logging.values())
199
+
190
200
  for key, logging_config in telemetry_config.logging.items():
191
201
  # Use the same pattern as tracing, but for logging
192
202
  logging_info = self._registry.get_logging_method(type(logging_config))
@@ -200,7 +210,31 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
200
210
  self._logging_handlers[key] = handler
201
211
 
202
212
  # Now attach to NAT's root logger
203
- logging.getLogger().addHandler(handler)
213
+ root_logger.addHandler(handler)
214
+
215
+ # If we added logging handlers, manage existing handlers appropriately
216
+ if self._logging_handlers:
217
+ min_handler_level = min((handler.level for handler in root_logger.handlers), default=logging.CRITICAL)
218
+
219
+ # Ensure the root logger level allows messages through
220
+ root_logger.level = max(root_logger.level, min_handler_level)
221
+
222
+ # If a console handler is configured, adjust or remove default CLI handlers
223
+ # to avoid duplicate output while preserving workflow visibility
224
+ if has_console_handler:
225
+ # Remove existing StreamHandlers that are not the newly configured ones
226
+ for handler in root_logger.handlers[:]:
227
+ if type(handler) is logging.StreamHandler and handler not in self._logging_handlers.values():
228
+ self._removed_root_handlers.append((handler, handler.level))
229
+ root_logger.removeHandler(handler)
230
+ else:
231
+ # No console handler configured, but adjust existing handler levels
232
+ # to respect the minimum configured level for file/other handlers
233
+ for handler in root_logger.handlers[:]:
234
+ if type(handler) is logging.StreamHandler:
235
+ old_level = handler.level
236
+ handler.setLevel(min_handler_level)
237
+ self._removed_root_handlers.append((handler, old_level))
204
238
 
205
239
  # Add the telemetry exporters
206
240
  for key, telemetry_exporter_config in telemetry_config.tracing.items():
@@ -212,8 +246,17 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
212
246
 
213
247
  assert self._exit_stack is not None, "Exit stack not initialized"
214
248
 
215
- for _, handler in self._logging_handlers.items():
216
- logging.getLogger().removeHandler(handler)
249
+ root_logger = logging.getLogger()
250
+
251
+ # Remove custom logging handlers
252
+ for handler in self._logging_handlers.values():
253
+ root_logger.removeHandler(handler)
254
+
255
+ # Restore original handlers and their levels
256
+ for handler, old_level in self._removed_root_handlers:
257
+ if handler not in root_logger.handlers:
258
+ root_logger.addHandler(handler)
259
+ handler.setLevel(old_level)
217
260
 
218
261
  await self._exit_stack.__aexit__(*exc_details)
219
262
 
nat/cli/commands/mcp/mcp.py CHANGED
@@ -194,7 +194,7 @@ async def _create_mcp_client_config(
194
194
  auth_user_id: str | None,
195
195
  auth_scopes: list[str] | None,
196
196
  ):
197
- from nat.plugins.mcp.client_impl import MCPClientConfig
197
+ from nat.plugins.mcp.client_config import MCPClientConfig
198
198
 
199
199
  if url and transport == "streamable-http" and auth_redirect_uri:
200
200
  try:
@@ -236,8 +236,8 @@ async def list_tools_via_function_group(
236
236
  try:
237
237
  # Ensure the registration side-effects are loaded
238
238
  from nat.builder.workflow_builder import WorkflowBuilder
239
- from nat.plugins.mcp.client_impl import MCPClientConfig
240
- from nat.plugins.mcp.client_impl import MCPServerConfig
239
+ from nat.plugins.mcp.client_config import MCPClientConfig
240
+ from nat.plugins.mcp.client_config import MCPServerConfig
241
241
  except ImportError:
242
242
  click.echo(
243
243
  "MCP client functionality requires nvidia-nat-mcp package. Install with: uv pip install nvidia-nat-mcp",
@@ -297,7 +297,7 @@ async def list_tools_via_function_group(
297
297
  if fn is not None:
298
298
  tools.append(to_tool_entry(full, fn))
299
299
  else:
300
- for full, fn in (await fns).items():
300
+ for full, fn in fns.items():
301
301
  tools.append(to_tool_entry(full, fn))
302
302
 
303
303
  return tools
@@ -826,8 +826,8 @@ async def call_tool_and_print(command: str | None,
826
826
 
827
827
  try:
828
828
  from nat.builder.workflow_builder import WorkflowBuilder
829
- from nat.plugins.mcp.client_impl import MCPClientConfig
830
- from nat.plugins.mcp.client_impl import MCPServerConfig
829
+ from nat.plugins.mcp.client_config import MCPClientConfig
830
+ from nat.plugins.mcp.client_config import MCPServerConfig
831
831
  except ImportError:
832
832
  click.echo(
833
833
  "MCP client functionality requires nvidia-nat-mcp package. Install with: uv pip install nvidia-nat-mcp",
nat/cli/commands/workflow/templates/config.yml.j2 CHANGED
@@ -1,15 +1,17 @@
1
- general:
2
- logging:
3
- console:
4
- _type: console
5
- level: WARN
1
+ functions:
2
+ current_datetime:
3
+ _type: current_datetime
4
+ {{python_safe_workflow_name}}:
5
+ _type: {{python_safe_workflow_name}}
6
+ prefix: "Hello:"
6
7
 
7
- front_end:
8
- _type: fastapi
9
-
10
- front_end:
11
- _type: console
8
+ llms:
9
+ nim_llm:
10
+ _type: nim
11
+ model_name: meta/llama-3.1-70b-instruct
12
+ temperature: 0.0
12
13
 
13
14
  workflow:
14
- _type: {{workflow_name}}
15
- parameter: default_value
15
+ _type: react_agent
16
+ llm_name: nim_llm
17
+ tool_names: [current_datetime, {{python_safe_workflow_name}}]
nat/cli/commands/workflow/templates/register.py.j2 CHANGED
@@ -1,4 +1,4 @@
1
1
  # flake8: noqa
2
2
 
3
- # Import any tools which need to be automatically registered here
4
- from {{package_name}} import {{workflow_name}}_function
3
+ # Import the generated workflow function to trigger registration
4
+ from .{{package_name}} import {{ python_safe_workflow_name }}_function
nat/cli/commands/workflow/templates/workflow.py.j2 CHANGED
@@ -3,6 +3,7 @@ import logging
3
3
  from pydantic import Field
4
4
 
5
5
  from nat.builder.builder import Builder
6
+ from nat.builder.framework_enum import LLMFrameworkEnum
6
7
  from nat.builder.function_info import FunctionInfo
7
8
  from nat.cli.register_workflow import register_function
8
9
  from nat.data_models.function import FunctionBaseConfig
@@ -12,25 +13,38 @@ logger = logging.getLogger(__name__)
12
13
 
13
14
  class {{ workflow_class_name }}(FunctionBaseConfig, name="{{ workflow_name }}"):
14
15
  """
15
- {{workflow_description}}
16
+ {{ workflow_description }}
16
17
  """
17
- # Add your custom configuration parameters here
18
- parameter: str = Field(default="default_value", description="Notional description for this parameter")
19
-
20
-
21
- @register_function(config_type={{ workflow_class_name }})
22
- async def {{ python_safe_workflow_name }}_function(
23
- config: {{ workflow_class_name }}, builder: Builder
24
- ):
25
- # Implement your function logic here
26
- async def _response_fn(input_message: str) -> str:
27
- # Process the input_message and generate output
28
- output_message = f"Hello from {{ workflow_name }} workflow! You said: {input_message}"
29
- return output_message
30
-
31
- try:
32
- yield FunctionInfo.create(single_fn=_response_fn)
33
- except GeneratorExit:
34
- logger.warning("Function exited early!")
35
- finally:
36
- logger.info("Cleaning up {{ workflow_name }} workflow.")
18
+ prefix: str = Field(default="Echo:", description="Prefix to add before the echoed text.")
19
+
20
+
21
+ @register_function(config_type={{ workflow_class_name }}, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
22
+ async def {{ python_safe_workflow_name }}_function(config: {{ workflow_class_name }}, builder: Builder):
23
+ """
24
+ Registers a function (addressable via `{{ workflow_name }}` in the configuration).
25
+ This registration ensures a static mapping of the function type, `{{ workflow_name }}`, to the `{{ workflow_class_name }}` configuration object.
26
+
27
+ Args:
28
+ config ({{ workflow_class_name }}): The configuration for the function.
29
+ builder (Builder): The builder object.
30
+
31
+ Returns:
32
+ FunctionInfo: The function info object for the function.
33
+ """
34
+
35
+ # Define the function that will be registered.
36
+ async def _echo(text: str) -> str:
37
+ """
38
+ Takes a text input and echoes back with a pre-defined prefix.
39
+
40
+ Args:
41
+ text (str): The text to echo back.
42
+
43
+ Returns:
44
+ str: The text with the prefix.
45
+ """
46
+ return f"{config.prefix} {text}"
47
+
48
+ # The callable is wrapped in a FunctionInfo object.
49
+ # The description parameter is used to describe the function.
50
+ yield FunctionInfo.from_fn(_echo, description=_echo.__doc__)