nvidia-nat 1.4.0a20251015__py3-none-any.whl → 1.4.0a20251021__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. nat/agent/base.py +3 -3
  2. nat/agent/reasoning_agent/reasoning_agent.py +6 -6
  3. nat/agent/register.py +1 -0
  4. nat/agent/responses_api_agent/__init__.py +14 -0
  5. nat/agent/responses_api_agent/register.py +126 -0
  6. nat/agent/tool_calling_agent/agent.py +6 -10
  7. nat/builder/context.py +2 -1
  8. nat/builder/intermediate_step_manager.py +6 -2
  9. nat/data_models/api_server.py +83 -33
  10. nat/data_models/intermediate_step.py +9 -1
  11. nat/data_models/llm.py +15 -1
  12. nat/data_models/openai_mcp.py +46 -0
  13. nat/data_models/optimizable.py +2 -1
  14. nat/data_models/thinking_mixin.py +2 -2
  15. nat/eval/evaluate.py +2 -0
  16. nat/eval/usage_stats.py +2 -0
  17. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +3 -0
  18. nat/front_ends/fastapi/message_handler.py +65 -40
  19. nat/front_ends/fastapi/message_validator.py +1 -2
  20. nat/front_ends/mcp/mcp_front_end_config.py +32 -0
  21. nat/front_ends/mcp/mcp_front_end_plugin.py +9 -6
  22. nat/llm/aws_bedrock_llm.py +3 -3
  23. nat/llm/litellm_llm.py +6 -3
  24. nat/llm/nim_llm.py +3 -3
  25. nat/llm/openai_llm.py +4 -3
  26. nat/profiler/callbacks/langchain_callback_handler.py +32 -7
  27. nat/profiler/callbacks/llama_index_callback_handler.py +36 -2
  28. nat/profiler/callbacks/token_usage_base_model.py +2 -0
  29. nat/utils/exception_handlers/automatic_retries.py +205 -54
  30. nat/utils/responses_api.py +26 -0
  31. {nvidia_nat-1.4.0a20251015.dist-info → nvidia_nat-1.4.0a20251021.dist-info}/METADATA +4 -4
  32. {nvidia_nat-1.4.0a20251015.dist-info → nvidia_nat-1.4.0a20251021.dist-info}/RECORD +37 -33
  33. {nvidia_nat-1.4.0a20251015.dist-info → nvidia_nat-1.4.0a20251021.dist-info}/WHEEL +0 -0
  34. {nvidia_nat-1.4.0a20251015.dist-info → nvidia_nat-1.4.0a20251021.dist-info}/entry_points.txt +0 -0
  35. {nvidia_nat-1.4.0a20251015.dist-info → nvidia_nat-1.4.0a20251021.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  36. {nvidia_nat-1.4.0a20251015.dist-info → nvidia_nat-1.4.0a20251021.dist-info}/licenses/LICENSE.md +0 -0
  37. {nvidia_nat-1.4.0a20251015.dist-info → nvidia_nat-1.4.0a20251021.dist-info}/top_level.txt +0 -0
nat/front_ends/fastapi/message_handler.py CHANGED
@@ -25,6 +25,7 @@ from pydantic import ValidationError
  from starlette.websockets import WebSocketDisconnect

  from nat.authentication.interfaces import FlowHandlerBase
+ from nat.data_models.api_server import ChatRequest
  from nat.data_models.api_server import ChatResponse
  from nat.data_models.api_server import ChatResponseChunk
  from nat.data_models.api_server import Error
@@ -33,6 +34,8 @@ from nat.data_models.api_server import ResponsePayloadOutput
  from nat.data_models.api_server import ResponseSerializable
  from nat.data_models.api_server import SystemResponseContent
  from nat.data_models.api_server import TextContent
+ from nat.data_models.api_server import UserMessageContentRoleType
+ from nat.data_models.api_server import UserMessages
  from nat.data_models.api_server import WebSocketMessageStatus
  from nat.data_models.api_server import WebSocketMessageType
  from nat.data_models.api_server import WebSocketSystemInteractionMessage
@@ -64,12 +67,12 @@ class WebSocketMessageHandler:
  self._running_workflow_task: asyncio.Task | None = None
  self._message_parent_id: str = "default_id"
  self._conversation_id: str | None = None
- self._workflow_schema_type: str = None
- self._user_interaction_response: asyncio.Future[HumanResponse] | None = None
+ self._workflow_schema_type: str | None = None
+ self._user_interaction_response: asyncio.Future[TextContent] | None = None

  self._flow_handler: FlowHandlerBase | None = None

- self._schema_output_mapping: dict[str, type[BaseModel] | None] = {
+ self._schema_output_mapping: dict[str, type[BaseModel] | type[None]] = {
  WorkflowSchemaType.GENERATE: self._session_manager.workflow.single_output_schema,
  WorkflowSchemaType.CHAT: ChatResponse,
  WorkflowSchemaType.CHAT_STREAM: ChatResponseChunk,
@@ -114,36 +117,58 @@ class WebSocketMessageHandler:
  pass

  elif (isinstance(validated_message, WebSocketUserInteractionResponseMessage)):
- user_content = await self.process_user_message_content(validated_message)
+ user_content = await self._process_websocket_user_interaction_response_message(validated_message)
+ assert self._user_interaction_response is not None
  self._user_interaction_response.set_result(user_content)
  except (asyncio.CancelledError, WebSocketDisconnect):
  # TODO: Handle the disconnect
  break

- async def process_user_message_content(
- self, user_content: WebSocketUserMessage | WebSocketUserInteractionResponseMessage) -> BaseModel | None:
+ def _extract_last_user_message_content(self, messages: list[UserMessages]) -> TextContent:
  """
- Processes the contents of a user message.
+ Extracts the last user's TextContent from a list of messages.

- :param user_content: Incoming content data model.
- :return: A validated Pydantic user content model or None if not found.
- """
+ Args:
+ messages: List of UserMessages.

- for user_message in user_content.content.messages[::-1]:
- if (user_message.role == "user"):
+ Returns:
+ TextContent object from the last user message.

+ Raises:
+ ValueError: If no user text content is found.
+ """
+ for user_message in messages[::-1]:
+ if user_message.role == UserMessageContentRoleType.USER:
  for attachment in user_message.content:
-
  if isinstance(attachment, TextContent):
  return attachment
+ raise ValueError("No user text content found in messages.")
+
+ async def _process_websocket_user_interaction_response_message(
+ self, user_content: WebSocketUserInteractionResponseMessage) -> TextContent:
+ """
+ Processes a WebSocketUserInteractionResponseMessage.
+ """
+ return self._extract_last_user_message_content(user_content.content.messages)

- return None
+ async def _process_websocket_user_message(self, user_content: WebSocketUserMessage) -> ChatRequest | str:
+ """
+ Processes a WebSocketUserMessage based on schema type.
+ """
+ if self._workflow_schema_type in [WorkflowSchemaType.CHAT, WorkflowSchemaType.CHAT_STREAM]:
+ return ChatRequest(**user_content.content.model_dump(include={"messages"}))
+
+ elif self._workflow_schema_type in [WorkflowSchemaType.GENERATE, WorkflowSchemaType.GENERATE_STREAM]:
+ return self._extract_last_user_message_content(user_content.content.messages).text
+
+ raise ValueError("Unsupported workflow schema type for WebSocketUserMessage")

  async def process_workflow_request(self, user_message_as_validated_type: WebSocketUserMessage) -> None:
  """
  Process user messages and routes them appropriately.

- :param user_message_as_validated_type: A WebSocketUserMessage Data Model instance.
+ Args:
+ user_message_as_validated_type (WebSocketUserMessage): The validated user message to process.
  """

  try:
@@ -151,18 +176,15 @@ class WebSocketMessageHandler:
  self._workflow_schema_type = user_message_as_validated_type.schema_type
  self._conversation_id = user_message_as_validated_type.conversation_id

- content: BaseModel | None = await self.process_user_message_content(user_message_as_validated_type)
-
- if content is None:
- raise ValueError(f"User message content could not be found: {user_message_as_validated_type}")
+ message_content: typing.Any = await self._process_websocket_user_message(user_message_as_validated_type)

- if isinstance(content, TextContent) and (self._running_workflow_task is None):
+ if (self._running_workflow_task is None):

- def _done_callback(task: asyncio.Task):
+ def _done_callback(_task: asyncio.Task):
  self._running_workflow_task = None

  self._running_workflow_task = asyncio.create_task(
- self._run_workflow(payload=content.text,
+ self._run_workflow(payload=message_content,
  user_message_id=self._message_parent_id,
  conversation_id=self._conversation_id,
  result_type=self._schema_output_mapping[self._workflow_schema_type],
@@ -180,13 +202,14 @@ class WebSocketMessageHandler:
  async def create_websocket_message(self,
  data_model: BaseModel,
  message_type: str | None = None,
- status: str = WebSocketMessageStatus.IN_PROGRESS) -> None:
+ status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS) -> None:
  """
  Creates a websocket message that will be ready for routing based on message type or data model.

- :param data_model: Message content model.
- :param message_type: Message content model.
- :param status: Message content model.
+ Args:
+ data_model (BaseModel): Message content model.
+ message_type (str | None): Message content model.
+ status (WebSocketMessageStatus): Message content model.
  """
  try:
  message: BaseModel | None = None
@@ -196,8 +219,8 @@ class WebSocketMessageHandler:

  message_schema: type[BaseModel] = await self._message_validator.get_message_schema_by_type(message_type)

- if 'id' in data_model.model_fields:
- message_id: str = data_model.id
+ if hasattr(data_model, 'id'):
+ message_id: str = str(getattr(data_model, 'id'))
  else:
  message_id = str(uuid.uuid4())

@@ -253,12 +276,15 @@ class WebSocketMessageHandler:
  Registered human interaction callback that processes human interactions and returns
  responses from websocket connection.

- :param prompt: Incoming interaction content data model.
- :return: A Text Content Base Pydantic model.
+ Args:
+ prompt: Incoming interaction content data model.
+
+ Returns:
+ A Text Content Base Pydantic model.
  """

  # First create a future from the loop for the human response
- human_response_future: asyncio.Future[HumanResponse] = asyncio.get_running_loop().create_future()
+ human_response_future: asyncio.Future[TextContent] = asyncio.get_running_loop().create_future()

  # Then add the future to the outstanding human prompts dictionary
  self._user_interaction_response = human_response_future
@@ -274,10 +300,10 @@ class WebSocketMessageHandler:
  return HumanResponseNotification()

  # Wait for the human response future to complete
- interaction_response: HumanResponse = await human_response_future
+ text_content: TextContent = await human_response_future

  interaction_response: HumanResponse = await self._message_validator.convert_text_content_to_human_response(
- interaction_response, prompt.content)
+ text_content, prompt.content)

  return interaction_response

@@ -293,13 +319,12 @@ class WebSocketMessageHandler:
  output_type: type | None = None) -> None:

  try:
- async with self._session_manager.session(
- user_message_id=user_message_id,
- conversation_id=conversation_id,
- http_connection=self._socket,
- user_input_callback=self.human_interaction_callback,
- user_authentication_callback=(self._flow_handler.authenticate
- if self._flow_handler else None)) as session:
+ auth_callback = self._flow_handler.authenticate if self._flow_handler else None
+ async with self._session_manager.session(user_message_id=user_message_id,
+ conversation_id=conversation_id,
+ http_connection=self._socket,
+ user_input_callback=self.human_interaction_callback,
+ user_authentication_callback=auth_callback) as session:

  async for value in generate_streaming_response(payload,
  session_manager=session,
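The refactor above replaces the old process_user_message_content with schema-aware helpers: chat-style schemas (CHAT, CHAT_STREAM) now receive a full ChatRequest built from the message list, while generate-style schemas (GENERATE, GENERATE_STREAM) receive only the text of the last user TextContent. Below is a minimal, self-contained sketch of that routing; the Message, TextContent, and ChatRequest classes are simplified stand-ins for illustration, not the nat.data_models types.

from dataclasses import dataclass, field


# Simplified stand-ins for the nat data models (hypothetical, illustration only).
@dataclass
class TextContent:
    text: str


@dataclass
class Message:
    role: str
    content: list  # attachments, e.g. TextContent


@dataclass
class ChatRequest:
    messages: list = field(default_factory=list)


def extract_last_user_text(messages: list) -> TextContent:
    """Walk the messages in reverse and return the last user TextContent."""
    for message in reversed(messages):
        if message.role == "user":
            for attachment in message.content:
                if isinstance(attachment, TextContent):
                    return attachment
    raise ValueError("No user text content found in messages.")


def route_payload(schema_type: str, messages: list) -> ChatRequest | str:
    """Chat schemas get the whole conversation; generate schemas get only the last user text."""
    if schema_type in ("chat", "chat_stream"):
        return ChatRequest(messages=messages)
    if schema_type in ("generate", "generate_stream"):
        return extract_last_user_text(messages).text
    raise ValueError(f"Unsupported workflow schema type: {schema_type}")


msgs = [Message(role="user", content=[TextContent(text="Summarize the release notes.")])]
print(route_payload("generate", msgs))  # the raw text
print(route_payload("chat", msgs))      # a ChatRequest wrapping the full message list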
nat/front_ends/fastapi/message_validator.py CHANGED
@@ -240,8 +240,7 @@ class MessageValidator:
  thread_id: str = "default",
  parent_id: str = "default",
  conversation_id: str | None = None,
- content: SystemResponseContent
- | Error = SystemResponseContent(),
+ content: SystemResponseContent | Error = SystemResponseContent(),
  status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
  timestamp: str = str(datetime.datetime.now(datetime.UTC))
  ) -> WebSocketSystemResponseTokenMessage | None:
nat/front_ends/mcp/mcp_front_end_config.py CHANGED
@@ -13,13 +13,17 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

+ import logging
  from typing import Literal

  from pydantic import Field
+ from pydantic import model_validator

  from nat.authentication.oauth2.oauth2_resource_server_config import OAuth2ResourceServerConfig
  from nat.data_models.front_end import FrontEndBaseConfig

+ logger = logging.getLogger(__name__)
+

  class MCPFrontEndConfig(FrontEndBaseConfig, name="mcp"):
  """MCP front end configuration.
@@ -56,3 +60,31 @@ class MCPFrontEndConfig(FrontEndBaseConfig, name="mcp"):
  le=50)
  memory_profile_log_level: str = Field(default="DEBUG",
  description="Log level for memory profiling output (default: DEBUG)")
+
+ @model_validator(mode="after")
+ def validate_security_configuration(self):
+ """Validate security configuration to prevent accidental misconfigurations."""
+ # Check if server is bound to a non-localhost interface without authentication
+ localhost_hosts = {"localhost", "127.0.0.1", "::1"}
+ if self.host not in localhost_hosts and self.server_auth is None:
+ logger.warning(
+ "MCP server is configured to bind to '%s' without authentication. "
+ "This may expose your server to unauthorized access. "
+ "Consider either: (1) binding to localhost for local-only access, "
+ "or (2) configuring server_auth for production deployments on public interfaces.",
+ self.host)
+
+ # Check if SSE transport is used (which doesn't support authentication)
+ if self.transport == "sse":
+ if self.server_auth is not None:
+ logger.warning("SSE transport does not support authentication. "
+ "The configured server_auth will be ignored. "
+ "For production use with authentication, use 'streamable-http' transport instead.")
+ elif self.host not in localhost_hosts:
+ logger.warning(
+ "SSE transport does not support authentication and is bound to '%s'. "
+ "This configuration is not recommended for production use. "
+ "For production deployments, use 'streamable-http' transport with server_auth configured.",
+ self.host)
+
+ return self
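Note that the new validator only logs warnings; insecure combinations still validate, so existing configurations keep working. Below is a standalone sketch of the same model_validator pattern, using a hypothetical DemoServerConfig whose fields and defaults are illustrative stand-ins rather than the real MCPFrontEndConfig.

import logging

from pydantic import BaseModel, model_validator

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


class DemoServerConfig(BaseModel):
    """Hypothetical stand-in for MCPFrontEndConfig, showing the warn-only validator pattern."""
    host: str = "localhost"
    transport: str = "streamable-http"
    server_auth: dict | None = None

    @model_validator(mode="after")
    def validate_security_configuration(self):
        localhost_hosts = {"localhost", "127.0.0.1", "::1"}
        if self.host not in localhost_hosts and self.server_auth is None:
            logger.warning("Server bound to '%s' without authentication.", self.host)
        if self.transport == "sse" and self.server_auth is not None:
            logger.warning("SSE transport does not support authentication; server_auth is ignored.")
        return self


# A public bind without auth emits a warning but still validates.
cfg = DemoServerConfig(host="0.0.0.0")
print(cfg.host)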
nat/front_ends/mcp/mcp_front_end_plugin.py CHANGED
@@ -105,9 +105,12 @@ class MCPFrontEndPlugin(FrontEndBase[MCPFrontEndConfig]):

  # Start the MCP server with configurable transport
  # streamable-http is the default, but users can choose sse if preferred
- if self.front_end_config.transport == "sse":
- logger.info("Starting MCP server with SSE endpoint at /sse")
- await mcp.run_sse_async()
- else: # streamable-http
- logger.info("Starting MCP server with streamable-http endpoint at /mcp/")
- await mcp.run_streamable_http_async()
+ try:
+ if self.front_end_config.transport == "sse":
+ logger.info("Starting MCP server with SSE endpoint at /sse")
+ await mcp.run_sse_async()
+ else: # streamable-http
+ logger.info("Starting MCP server with streamable-http endpoint at /mcp/")
+ await mcp.run_streamable_http_async()
+ except KeyboardInterrupt:
+ logger.info("MCP server shutdown requested (Ctrl+C). Shutting down gracefully.")
nat/llm/aws_bedrock_llm.py CHANGED
@@ -42,9 +42,9 @@ class AWSBedrockModelConfig(LLMBaseConfig,
  model_config = ConfigDict(protected_namespaces=(), extra="allow")

  # Completion parameters
- model_name: str = Field(validation_alias=AliasChoices("model_name", "model"),
- serialization_alias="model",
- description="The model name for the hosted AWS Bedrock.")
+ model_name: str = OptimizableField(validation_alias=AliasChoices("model_name", "model"),
+ serialization_alias="model",
+ description="The model name for the hosted AWS Bedrock.")
  max_tokens: int = OptimizableField(default=300,
  description="Maximum number of tokens to generate.",
  space=SearchSpace(high=2176, low=128, step=512))
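The same swap of model_name from a plain Field to OptimizableField repeats below for the LiteLlm, NIM, and OpenAI configs; per this diff, OptimizableField accepts the same alias arguments, so model_name presumably becomes tunable by the optimizer while its alias behavior stays unchanged. A small sketch of what those aliases do, using a plain pydantic Field on a hypothetical stand-in config:

from pydantic import AliasChoices, BaseModel, ConfigDict, Field


class DemoLLMConfig(BaseModel):
    """Hypothetical stand-in illustrating the alias behavior shared by these provider configs."""
    model_config = ConfigDict(populate_by_name=True, protected_namespaces=())

    model_name: str = Field(validation_alias=AliasChoices("model_name", "model"),
                            serialization_alias="model",
                            description="The hosted model name.")


# Both spellings are accepted on input, and serialization always uses the "model" key.
a = DemoLLMConfig(model="meta/llama-3.1-8b-instruct")
b = DemoLLMConfig(model_name="meta/llama-3.1-8b-instruct")
print(a.model_dump(by_alias=True))  # {'model': 'meta/llama-3.1-8b-instruct'}
print(a == b)                       # True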
nat/llm/litellm_llm.py CHANGED
@@ -23,6 +23,8 @@ from nat.builder.builder import Builder
  from nat.builder.llm import LLMProviderInfo
  from nat.cli.register_workflow import register_llm_provider
  from nat.data_models.llm import LLMBaseConfig
+ from nat.data_models.optimizable import OptimizableField
+ from nat.data_models.optimizable import OptimizableMixin
  from nat.data_models.retry_mixin import RetryMixin
  from nat.data_models.temperature_mixin import TemperatureMixin
  from nat.data_models.thinking_mixin import ThinkingMixin
@@ -31,6 +33,7 @@ from nat.data_models.top_p_mixin import TopPMixin

  class LiteLlmModelConfig(
  LLMBaseConfig,
+ OptimizableMixin,
  RetryMixin,
  TemperatureMixin,
  TopPMixin,
@@ -46,9 +49,9 @@ class LiteLlmModelConfig(
  description="Base url to the hosted model.",
  validation_alias=AliasChoices("base_url", "api_base"),
  serialization_alias="api_base")
- model_name: str = Field(validation_alias=AliasChoices("model_name", "model"),
- serialization_alias="model",
- description="The LiteLlm hosted model name.")
+ model_name: str = OptimizableField(validation_alias=AliasChoices("model_name", "model"),
+ serialization_alias="model",
+ description="The LiteLlm hosted model name.")
  seed: int | None = Field(default=None, description="Random seed to set for generation.")
nat/llm/nim_llm.py CHANGED
@@ -44,9 +44,9 @@ class NIMModelConfig(LLMBaseConfig,

  api_key: str | None = Field(default=None, description="NVIDIA API key to interact with hosted NIM.")
  base_url: str | None = Field(default=None, description="Base url to the hosted NIM.")
- model_name: str = Field(validation_alias=AliasChoices("model_name", "model"),
- serialization_alias="model",
- description="The model name for the hosted NIM.")
+ model_name: str = OptimizableField(validation_alias=AliasChoices("model_name", "model"),
+ serialization_alias="model",
+ description="The model name for the hosted NIM.")
  max_tokens: PositiveInt = OptimizableField(default=300,
  description="Maximum number of tokens to generate.",
  space=SearchSpace(high=2176, low=128, step=512))
nat/llm/openai_llm.py CHANGED
@@ -21,6 +21,7 @@ from nat.builder.builder import Builder
  from nat.builder.llm import LLMProviderInfo
  from nat.cli.register_workflow import register_llm_provider
  from nat.data_models.llm import LLMBaseConfig
+ from nat.data_models.optimizable import OptimizableField
  from nat.data_models.optimizable import OptimizableMixin
  from nat.data_models.retry_mixin import RetryMixin
  from nat.data_models.temperature_mixin import TemperatureMixin
@@ -41,9 +42,9 @@ class OpenAIModelConfig(LLMBaseConfig,

  api_key: str | None = Field(default=None, description="OpenAI API key to interact with hosted model.")
  base_url: str | None = Field(default=None, description="Base url to the hosted model.")
- model_name: str = Field(validation_alias=AliasChoices("model_name", "model"),
- serialization_alias="model",
- description="The OpenAI hosted model name.")
+ model_name: str = OptimizableField(validation_alias=AliasChoices("model_name", "model"),
+ serialization_alias="model",
+ description="The OpenAI hosted model name.")
  seed: int | None = Field(default=None, description="Random seed to set for generation.")
  max_retries: int = Field(default=10, description="The max number of retries for the request.")
nat/profiler/callbacks/langchain_callback_handler.py CHANGED
@@ -33,6 +33,7 @@ from nat.builder.context import Context
  from nat.builder.framework_enum import LLMFrameworkEnum
  from nat.data_models.intermediate_step import IntermediateStepPayload
  from nat.data_models.intermediate_step import IntermediateStepType
+ from nat.data_models.intermediate_step import ServerToolUseSchema
  from nat.data_models.intermediate_step import StreamEventData
  from nat.data_models.intermediate_step import ToolSchema
  from nat.data_models.intermediate_step import TraceMetadata
@@ -48,7 +49,14 @@ def _extract_tools_schema(invocation_params: dict) -> list:
  tools_schema = []
  if invocation_params is not None:
  for tool in invocation_params.get("tools", []):
- tools_schema.append(ToolSchema(**tool))
+ try:
+ tools_schema.append(ToolSchema(**tool))
+ except Exception:
+ logger.debug(
+ "Failed to parse tool schema from invocation params: %s. \n This "
+ "can occur when the LLM server has native tools and can be ignored if "
+ "using the responses API.",
+ tool)

  return tools_schema

@@ -93,11 +101,15 @@ class LangchainProfilerHandler(AsyncCallbackHandler, BaseProfilerCallback):
  completion_tokens = usage_metadata.get("output_tokens", 0)
  total_tokens = usage_metadata.get("total_tokens", 0)

- return TokenUsageBaseModel(
- prompt_tokens=prompt_tokens,
- completion_tokens=completion_tokens,
- total_tokens=total_tokens,
- )
+ cache_tokens = usage_metadata.get("input_token_details", {}).get("cache_read", 0)
+
+ reasoning_tokens = usage_metadata.get("output_token_details", {}).get("reasoning", 0)
+
+ return TokenUsageBaseModel(prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=total_tokens,
+ cached_tokens=cache_tokens,
+ reasoning_tokens=reasoning_tokens)
  return TokenUsageBaseModel()

  async def on_llm_start(self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any) -> None:
@@ -213,6 +225,7 @@ class LangchainProfilerHandler(AsyncCallbackHandler, BaseProfilerCallback):
  except IndexError:
  generation = None

+ message = None
  if isinstance(generation, ChatGeneration):
  try:
  message = generation.message
@@ -232,6 +245,17 @@ class LangchainProfilerHandler(AsyncCallbackHandler, BaseProfilerCallback):
  else:
  llm_text_output = ""

+ tool_outputs_list = []
+ # Check if message.additional_kwargs as tool_outputs indicative of server side tool calling
+ if message and message.additional_kwargs and "tool_outputs" in message.additional_kwargs:
+ tools_outputs = message.additional_kwargs["tool_outputs"]
+ if isinstance(tools_outputs, list):
+ for tool in tools_outputs:
+ try:
+ tool_outputs_list.append(ServerToolUseSchema(**tool))
+ except Exception:
+ pass
+
  # update shared state behind lock
  with self._lock:
  usage_stat = IntermediateStepPayload(
@@ -243,7 +267,8 @@ class LangchainProfilerHandler(AsyncCallbackHandler, BaseProfilerCallback):
  data=StreamEventData(input=self._run_id_to_llm_input.get(str(kwargs.get("run_id", "")), ""),
  output=llm_text_output),
  usage_info=UsageInfo(token_usage=self._extract_token_base_model(usage_metadata)),
- metadata=TraceMetadata(chat_responses=[generation] if generation else []))
+ metadata=TraceMetadata(chat_responses=[generation] if generation else [],
+ tool_outputs=tool_outputs_list if tool_outputs_list else []))

  self.step_manager.push_intermediate_step(usage_stat)
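With the change above, _extract_token_base_model also pulls cache-read and reasoning counts from the nested token details that LangChain reports. A condensed sketch of that extraction, assuming a LangChain-style usage_metadata dict; TokenCounts below is a hypothetical stand-in for TokenUsageBaseModel.

from pydantic import BaseModel


class TokenCounts(BaseModel):
    """Stand-in for TokenUsageBaseModel with the newly added fields."""
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
    cached_tokens: int = 0
    reasoning_tokens: int = 0


def extract_token_counts(usage_metadata: dict | None) -> TokenCounts:
    if not usage_metadata:
        return TokenCounts()
    return TokenCounts(
        prompt_tokens=usage_metadata.get("input_tokens", 0),
        completion_tokens=usage_metadata.get("output_tokens", 0),
        total_tokens=usage_metadata.get("total_tokens", 0),
        # Nested detail dicts may be missing entirely, so fall back to {}.
        cached_tokens=usage_metadata.get("input_token_details", {}).get("cache_read", 0),
        reasoning_tokens=usage_metadata.get("output_token_details", {}).get("reasoning", 0),
    )


print(extract_token_counts({
    "input_tokens": 1200, "output_tokens": 300, "total_tokens": 1500,
    "input_token_details": {"cache_read": 1024},
    "output_token_details": {"reasoning": 128},
}))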
 
nat/profiler/callbacks/llama_index_callback_handler.py CHANGED
@@ -30,6 +30,7 @@ from nat.builder.context import Context
  from nat.builder.framework_enum import LLMFrameworkEnum
  from nat.data_models.intermediate_step import IntermediateStepPayload
  from nat.data_models.intermediate_step import IntermediateStepType
+ from nat.data_models.intermediate_step import ServerToolUseSchema
  from nat.data_models.intermediate_step import StreamEventData
  from nat.data_models.intermediate_step import TraceMetadata
  from nat.data_models.intermediate_step import UsageInfo
@@ -64,6 +65,26 @@ class LlamaIndexProfilerHandler(BaseCallbackHandler, BaseProfilerCallback):
  self._run_id_to_tool_input = {}
  self._run_id_to_timestamp = {}

+ @staticmethod
+ def _extract_token_usage(response: ChatResponse) -> TokenUsageBaseModel:
+ token_usage = TokenUsageBaseModel()
+ try:
+ if response and response.additional_kwargs and "usage" in response.additional_kwargs:
+ usage = response.additional_kwargs["usage"] if "usage" in response.additional_kwargs else {}
+ token_usage.prompt_tokens = usage.input_tokens if hasattr(usage, "input_tokens") else 0
+ token_usage.completion_tokens = usage.output_tokens if hasattr(usage, "output_tokens") else 0
+
+ if hasattr(usage, "input_tokens_details") and hasattr(usage.input_tokens_details, "cached_tokens"):
+ token_usage.cached_tokens = usage.input_tokens_details.cached_tokens
+
+ if hasattr(usage, "output_tokens_details") and hasattr(usage.output_tokens_details, "reasoning_tokens"):
+ token_usage.reasoning_tokens = usage.output_tokens_details.reasoning_tokens
+
+ except Exception as e:
+ logger.debug("Error extracting token usage: %s", e, exc_info=True)
+
+ return token_usage
+
  def on_event_start(
  self,
  event_type: CBEventType,
@@ -167,6 +188,18 @@ class LlamaIndexProfilerHandler(BaseCallbackHandler, BaseProfilerCallback):
  except Exception as e:
  logger.exception("Error getting model name: %s", e)

+ # Append usage data to NAT usage stats
+ tool_outputs_list = []
+ # Check if message.additional_kwargs as tool_outputs indicative of server side tool calling
+ if response and response.additional_kwargs and "built_in_tool_calls" in response.additional_kwargs:
+ tools_outputs = response.additional_kwargs["built_in_tool_calls"]
+ if isinstance(tools_outputs, list):
+ for tool in tools_outputs:
+ try:
+ tool_outputs_list.append(ServerToolUseSchema(**tool.model_dump()))
+ except Exception:
+ pass
+
  # Append usage data to NAT usage stats
  with self._lock:
  stats = IntermediateStepPayload(
@@ -176,8 +209,9 @@ class LlamaIndexProfilerHandler(BaseCallbackHandler, BaseProfilerCallback):
  name=model_name,
  UUID=event_id,
  data=StreamEventData(input=self._run_id_to_llm_input.get(event_id), output=llm_text_output),
- metadata=TraceMetadata(chat_responses=response.message if response.message else None),
- usage_info=UsageInfo(token_usage=TokenUsageBaseModel(**response.additional_kwargs)))
+ metadata=TraceMetadata(chat_responses=response.message if response.message else None,
+ tool_outputs=tool_outputs_list if tool_outputs_list else []),
+ usage_info=UsageInfo(token_usage=self._extract_token_usage(response)))
  self.step_manager.push_intermediate_step(stats)

  elif event_type == CBEventType.FUNCTION_CALL and payload:
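Unlike the LangChain handler, which reads usage from a plain dict, the usage object LlamaIndex carries in additional_kwargs["usage"] is attribute-style, which is why the new _extract_token_usage probes it with hasattr. A tiny sketch of that probing against a made-up attribute-style usage object (its shape is an assumption for illustration only):

from types import SimpleNamespace

# Made-up attribute-style usage object; field names mirror those probed above.
usage = SimpleNamespace(
    input_tokens=900,
    output_tokens=250,
    input_tokens_details=SimpleNamespace(cached_tokens=512),
    output_tokens_details=SimpleNamespace(reasoning_tokens=64),
)

prompt_tokens = getattr(usage, "input_tokens", 0)
completion_tokens = getattr(usage, "output_tokens", 0)
cached_tokens = getattr(getattr(usage, "input_tokens_details", None), "cached_tokens", 0)
reasoning_tokens = getattr(getattr(usage, "output_tokens_details", None), "reasoning_tokens", 0)
print(prompt_tokens, completion_tokens, cached_tokens, reasoning_tokens)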
nat/profiler/callbacks/token_usage_base_model.py CHANGED
@@ -24,4 +24,6 @@ class TokenUsageBaseModel(BaseModel):

  prompt_tokens: int = Field(default=0, description="Number of tokens in the prompt.")
  completion_tokens: int = Field(default=0, description="Number of tokens in the completion.")
+ cached_tokens: int = Field(default=0, description="Number of tokens read from cache.")
+ reasoning_tokens: int = Field(default=0, description="Number of tokens used for reasoning.")
  total_tokens: int = Field(default=0, description="Number of tokens total.")