nvidia-nat 1.4.0a20251015__py3-none-any.whl → 1.4.0a20251021__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
- nat/agent/base.py +3 -3
- nat/agent/reasoning_agent/reasoning_agent.py +6 -6
- nat/agent/register.py +1 -0
- nat/agent/responses_api_agent/__init__.py +14 -0
- nat/agent/responses_api_agent/register.py +126 -0
- nat/agent/tool_calling_agent/agent.py +6 -10
- nat/builder/context.py +2 -1
- nat/builder/intermediate_step_manager.py +6 -2
- nat/data_models/api_server.py +83 -33
- nat/data_models/intermediate_step.py +9 -1
- nat/data_models/llm.py +15 -1
- nat/data_models/openai_mcp.py +46 -0
- nat/data_models/optimizable.py +2 -1
- nat/data_models/thinking_mixin.py +2 -2
- nat/eval/evaluate.py +2 -0
- nat/eval/usage_stats.py +2 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +3 -0
- nat/front_ends/fastapi/message_handler.py +65 -40
- nat/front_ends/fastapi/message_validator.py +1 -2
- nat/front_ends/mcp/mcp_front_end_config.py +32 -0
- nat/front_ends/mcp/mcp_front_end_plugin.py +9 -6
- nat/llm/aws_bedrock_llm.py +3 -3
- nat/llm/litellm_llm.py +6 -3
- nat/llm/nim_llm.py +3 -3
- nat/llm/openai_llm.py +4 -3
- nat/profiler/callbacks/langchain_callback_handler.py +32 -7
- nat/profiler/callbacks/llama_index_callback_handler.py +36 -2
- nat/profiler/callbacks/token_usage_base_model.py +2 -0
- nat/utils/exception_handlers/automatic_retries.py +205 -54
- nat/utils/responses_api.py +26 -0
- {nvidia_nat-1.4.0a20251015.dist-info → nvidia_nat-1.4.0a20251021.dist-info}/METADATA +4 -4
- {nvidia_nat-1.4.0a20251015.dist-info → nvidia_nat-1.4.0a20251021.dist-info}/RECORD +37 -33
- {nvidia_nat-1.4.0a20251015.dist-info → nvidia_nat-1.4.0a20251021.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.4.0a20251015.dist-info → nvidia_nat-1.4.0a20251021.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.4.0a20251015.dist-info → nvidia_nat-1.4.0a20251021.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.4.0a20251015.dist-info → nvidia_nat-1.4.0a20251021.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.4.0a20251015.dist-info → nvidia_nat-1.4.0a20251021.dist-info}/top_level.txt +0 -0

nat/front_ends/fastapi/message_handler.py
CHANGED

@@ -25,6 +25,7 @@ from pydantic import ValidationError
 from starlette.websockets import WebSocketDisconnect
 from nat.authentication.interfaces import FlowHandlerBase
+from nat.data_models.api_server import ChatRequest
 from nat.data_models.api_server import ChatResponse
 from nat.data_models.api_server import ChatResponseChunk
 from nat.data_models.api_server import Error
@@ -33,6 +34,8 @@ from nat.data_models.api_server import ResponsePayloadOutput
 from nat.data_models.api_server import ResponseSerializable
 from nat.data_models.api_server import SystemResponseContent
 from nat.data_models.api_server import TextContent
+from nat.data_models.api_server import UserMessageContentRoleType
+from nat.data_models.api_server import UserMessages
 from nat.data_models.api_server import WebSocketMessageStatus
 from nat.data_models.api_server import WebSocketMessageType
 from nat.data_models.api_server import WebSocketSystemInteractionMessage
@@ -64,12 +67,12 @@ class WebSocketMessageHandler:
         self._running_workflow_task: asyncio.Task | None = None
         self._message_parent_id: str = "default_id"
         self._conversation_id: str | None = None
-        self._workflow_schema_type: str = None
-        self._user_interaction_response: asyncio.Future[
+        self._workflow_schema_type: str | None = None
+        self._user_interaction_response: asyncio.Future[TextContent] | None = None

         self._flow_handler: FlowHandlerBase | None = None

-        self._schema_output_mapping: dict[str, type[BaseModel] | None] = {
+        self._schema_output_mapping: dict[str, type[BaseModel] | type[None]] = {
             WorkflowSchemaType.GENERATE: self._session_manager.workflow.single_output_schema,
             WorkflowSchemaType.CHAT: ChatResponse,
             WorkflowSchemaType.CHAT_STREAM: ChatResponseChunk,
@@ -114,36 +117,58 @@ class WebSocketMessageHandler:
                     pass

                 elif (isinstance(validated_message, WebSocketUserInteractionResponseMessage)):
-                    user_content = await self.
+                    user_content = await self._process_websocket_user_interaction_response_message(validated_message)
+                    assert self._user_interaction_response is not None
                     self._user_interaction_response.set_result(user_content)
             except (asyncio.CancelledError, WebSocketDisconnect):
                 # TODO: Handle the disconnect
                 break

-
-        self, user_content: WebSocketUserMessage | WebSocketUserInteractionResponseMessage) -> BaseModel | None:
+    def _extract_last_user_message_content(self, messages: list[UserMessages]) -> TextContent:
         """
-
+        Extracts the last user's TextContent from a list of messages.

-        :
-
-        """
+        Args:
+            messages: List of UserMessages.

-
-
+        Returns:
+            TextContent object from the last user message.

+        Raises:
+            ValueError: If no user text content is found.
+        """
+        for user_message in messages[::-1]:
+            if user_message.role == UserMessageContentRoleType.USER:
                 for attachment in user_message.content:
-
                     if isinstance(attachment, TextContent):
                         return attachment
+        raise ValueError("No user text content found in messages.")
+
+    async def _process_websocket_user_interaction_response_message(
+            self, user_content: WebSocketUserInteractionResponseMessage) -> TextContent:
+        """
+        Processes a WebSocketUserInteractionResponseMessage.
+        """
+        return self._extract_last_user_message_content(user_content.content.messages)

-
+    async def _process_websocket_user_message(self, user_content: WebSocketUserMessage) -> ChatRequest | str:
+        """
+        Processes a WebSocketUserMessage based on schema type.
+        """
+        if self._workflow_schema_type in [WorkflowSchemaType.CHAT, WorkflowSchemaType.CHAT_STREAM]:
+            return ChatRequest(**user_content.content.model_dump(include={"messages"}))
+
+        elif self._workflow_schema_type in [WorkflowSchemaType.GENERATE, WorkflowSchemaType.GENERATE_STREAM]:
+            return self._extract_last_user_message_content(user_content.content.messages).text
+
+        raise ValueError("Unsupported workflow schema type for WebSocketUserMessage")

     async def process_workflow_request(self, user_message_as_validated_type: WebSocketUserMessage) -> None:
         """
         Process user messages and routes them appropriately.

-        :
+        Args:
+            user_message_as_validated_type (WebSocketUserMessage): The validated user message to process.
         """

         try:
@@ -151,18 +176,15 @@ class WebSocketMessageHandler:
             self._workflow_schema_type = user_message_as_validated_type.schema_type
             self._conversation_id = user_message_as_validated_type.conversation_id

-
-
-            if content is None:
-                raise ValueError(f"User message content could not be found: {user_message_as_validated_type}")
+            message_content: typing.Any = await self._process_websocket_user_message(user_message_as_validated_type)

-            if
+            if (self._running_workflow_task is None):

-                def _done_callback(
+                def _done_callback(_task: asyncio.Task):
                     self._running_workflow_task = None

                 self._running_workflow_task = asyncio.create_task(
-                    self._run_workflow(payload=
+                    self._run_workflow(payload=message_content,
                                        user_message_id=self._message_parent_id,
                                        conversation_id=self._conversation_id,
                                        result_type=self._schema_output_mapping[self._workflow_schema_type],
@@ -180,13 +202,14 @@ class WebSocketMessageHandler:
     async def create_websocket_message(self,
                                        data_model: BaseModel,
                                        message_type: str | None = None,
-                                       status:
+                                       status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS) -> None:
         """
         Creates a websocket message that will be ready for routing based on message type or data model.

-        :
-
-
+        Args:
+            data_model (BaseModel): Message content model.
+            message_type (str | None): Message content model.
+            status (WebSocketMessageStatus): Message content model.
         """
         try:
             message: BaseModel | None = None
@@ -196,8 +219,8 @@ class WebSocketMessageHandler:

             message_schema: type[BaseModel] = await self._message_validator.get_message_schema_by_type(message_type)

-            if 'id'
-                message_id: str = data_model
+            if hasattr(data_model, 'id'):
+                message_id: str = str(getattr(data_model, 'id'))
             else:
                 message_id = str(uuid.uuid4())

@@ -253,12 +276,15 @@ class WebSocketMessageHandler:
         Registered human interaction callback that processes human interactions and returns
         responses from websocket connection.

-        :
-
+        Args:
+            prompt: Incoming interaction content data model.
+
+        Returns:
+            A Text Content Base Pydantic model.
         """

         # First create a future from the loop for the human response
-        human_response_future: asyncio.Future[
+        human_response_future: asyncio.Future[TextContent] = asyncio.get_running_loop().create_future()

         # Then add the future to the outstanding human prompts dictionary
         self._user_interaction_response = human_response_future
@@ -274,10 +300,10 @@ class WebSocketMessageHandler:
             return HumanResponseNotification()

         # Wait for the human response future to complete
-
+        text_content: TextContent = await human_response_future

         interaction_response: HumanResponse = await self._message_validator.convert_text_content_to_human_response(
-
+            text_content, prompt.content)

         return interaction_response

@@ -293,13 +319,12 @@ class WebSocketMessageHandler:
                         output_type: type | None = None) -> None:

         try:
-
-
-
-
-
-                                            if self._flow_handler else None)) as session:
+            auth_callback = self._flow_handler.authenticate if self._flow_handler else None
+            async with self._session_manager.session(user_message_id=user_message_id,
+                                                      conversation_id=conversation_id,
+                                                      http_connection=self._socket,
+                                                      user_input_callback=self.human_interaction_callback,
+                                                      user_authentication_callback=auth_callback) as session:

                 async for value in generate_streaming_response(payload,
                                                                session_manager=session,
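Taken together, the new helpers replace the old catch-all processing method with a role-aware lookup: walk the message list newest-to-oldest and return the first text attachment on a user-role message. The stand-alone sketch below mirrors that logic with stub types; StubText, StubMessage, and last_user_text are illustrative names, not part of the nat package.

    from dataclasses import dataclass

    @dataclass
    class StubText:  # stand-in for nat's TextContent
        text: str

    @dataclass
    class StubMessage:  # stand-in for nat's UserMessages
        role: str
        content: list

    def last_user_text(messages: list[StubMessage]) -> StubText:
        # Walk newest-to-oldest; return the first text attachment from a "user" message.
        for message in reversed(messages):
            if message.role == "user":
                for attachment in message.content:
                    if isinstance(attachment, StubText):
                        return attachment
        raise ValueError("No user text content found in messages.")

    history = [StubMessage(role="user", content=[StubText(text="hello")]),
               StubMessage(role="assistant", content=[StubText(text="hi there")])]
    print(last_user_text(history).text)  # -> "hello"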
nat/front_ends/fastapi/message_validator.py
CHANGED

@@ -240,8 +240,7 @@ class MessageValidator:
                                             thread_id: str = "default",
                                             parent_id: str = "default",
                                             conversation_id: str | None = None,
-                                             content: SystemResponseContent
-                                             | Error = SystemResponseContent(),
+                                             content: SystemResponseContent | Error = SystemResponseContent(),
                                             status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
                                             timestamp: str = str(datetime.datetime.now(datetime.UTC))
                                             ) -> WebSocketSystemResponseTokenMessage | None:
nat/front_ends/mcp/mcp_front_end_config.py
CHANGED

@@ -13,13 +13,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import logging
 from typing import Literal

 from pydantic import Field
+from pydantic import model_validator

 from nat.authentication.oauth2.oauth2_resource_server_config import OAuth2ResourceServerConfig
 from nat.data_models.front_end import FrontEndBaseConfig

+logger = logging.getLogger(__name__)
+

 class MCPFrontEndConfig(FrontEndBaseConfig, name="mcp"):
     """MCP front end configuration.

@@ -56,3 +60,31 @@ class MCPFrontEndConfig(FrontEndBaseConfig, name="mcp"):
                                       le=50)
     memory_profile_log_level: str = Field(default="DEBUG",
                                           description="Log level for memory profiling output (default: DEBUG)")
+
+    @model_validator(mode="after")
+    def validate_security_configuration(self):
+        """Validate security configuration to prevent accidental misconfigurations."""
+        # Check if server is bound to a non-localhost interface without authentication
+        localhost_hosts = {"localhost", "127.0.0.1", "::1"}
+        if self.host not in localhost_hosts and self.server_auth is None:
+            logger.warning(
+                "MCP server is configured to bind to '%s' without authentication. "
+                "This may expose your server to unauthorized access. "
+                "Consider either: (1) binding to localhost for local-only access, "
+                "or (2) configuring server_auth for production deployments on public interfaces.",
+                self.host)
+
+        # Check if SSE transport is used (which doesn't support authentication)
+        if self.transport == "sse":
+            if self.server_auth is not None:
+                logger.warning("SSE transport does not support authentication. "
+                               "The configured server_auth will be ignored. "
+                               "For production use with authentication, use 'streamable-http' transport instead.")
+            elif self.host not in localhost_hosts:
+                logger.warning(
+                    "SSE transport does not support authentication and is bound to '%s'. "
+                    "This configuration is not recommended for production use. "
+                    "For production deployments, use 'streamable-http' transport with server_auth configured.",
+                    self.host)
+
+        return self
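Note that the new validator only warns; an insecure combination still constructs successfully. Below is a minimal, self-contained sketch of the same Pydantic model_validator(mode="after") pattern; StubMCPConfig is illustrative and much smaller than the real MCPFrontEndConfig.

    import logging

    from pydantic import BaseModel
    from pydantic import model_validator

    logger = logging.getLogger(__name__)

    class StubMCPConfig(BaseModel):
        host: str = "localhost"
        transport: str = "streamable-http"
        server_auth: dict | None = None

        @model_validator(mode="after")
        def warn_on_insecure_setup(self):
            # Warn, but do not fail, when binding publicly without authentication.
            if self.host not in {"localhost", "127.0.0.1", "::1"} and self.server_auth is None:
                logger.warning("Binding to %s without server_auth.", self.host)
            return self

    logging.basicConfig(level=logging.WARNING)
    StubMCPConfig(host="0.0.0.0")  # logs the warning, still constructs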
nat/front_ends/mcp/mcp_front_end_plugin.py
CHANGED

@@ -105,9 +105,12 @@ class MCPFrontEndPlugin(FrontEndBase[MCPFrontEndConfig]):

         # Start the MCP server with configurable transport
         # streamable-http is the default, but users can choose sse if preferred
-
-
-
-
-
-
+        try:
+            if self.front_end_config.transport == "sse":
+                logger.info("Starting MCP server with SSE endpoint at /sse")
+                await mcp.run_sse_async()
+            else:  # streamable-http
+                logger.info("Starting MCP server with streamable-http endpoint at /mcp/")
+                await mcp.run_streamable_http_async()
+        except KeyboardInterrupt:
+            logger.info("MCP server shutdown requested (Ctrl+C). Shutting down gracefully.")
nat/llm/aws_bedrock_llm.py
CHANGED

@@ -42,9 +42,9 @@ class AWSBedrockModelConfig(LLMBaseConfig,
     model_config = ConfigDict(protected_namespaces=(), extra="allow")

     # Completion parameters
-    model_name: str =
-
-
+    model_name: str = OptimizableField(validation_alias=AliasChoices("model_name", "model"),
+                                       serialization_alias="model",
+                                       description="The model name for the hosted AWS Bedrock.")
     max_tokens: int = OptimizableField(default=300,
                                        description="Maximum number of tokens to generate.",
                                        space=SearchSpace(high=2176, low=128, step=512))
nat/llm/litellm_llm.py
CHANGED

@@ -23,6 +23,8 @@ from nat.builder.builder import Builder
 from nat.builder.llm import LLMProviderInfo
 from nat.cli.register_workflow import register_llm_provider
 from nat.data_models.llm import LLMBaseConfig
+from nat.data_models.optimizable import OptimizableField
+from nat.data_models.optimizable import OptimizableMixin
 from nat.data_models.retry_mixin import RetryMixin
 from nat.data_models.temperature_mixin import TemperatureMixin
 from nat.data_models.thinking_mixin import ThinkingMixin
@@ -31,6 +33,7 @@ from nat.data_models.top_p_mixin import TopPMixin

 class LiteLlmModelConfig(
         LLMBaseConfig,
+        OptimizableMixin,
         RetryMixin,
         TemperatureMixin,
         TopPMixin,
@@ -46,9 +49,9 @@ class LiteLlmModelConfig(
                                description="Base url to the hosted model.",
                                validation_alias=AliasChoices("base_url", "api_base"),
                                serialization_alias="api_base")
-    model_name: str =
-
-
+    model_name: str = OptimizableField(validation_alias=AliasChoices("model_name", "model"),
+                                       serialization_alias="model",
+                                       description="The LiteLlm hosted model name.")
     seed: int | None = Field(default=None, description="Random seed to set for generation.")

nat/llm/nim_llm.py
CHANGED

@@ -44,9 +44,9 @@ class NIMModelConfig(LLMBaseConfig,

     api_key: str | None = Field(default=None, description="NVIDIA API key to interact with hosted NIM.")
     base_url: str | None = Field(default=None, description="Base url to the hosted NIM.")
-    model_name: str =
-
-
+    model_name: str = OptimizableField(validation_alias=AliasChoices("model_name", "model"),
+                                       serialization_alias="model",
+                                       description="The model name for the hosted NIM.")
     max_tokens: PositiveInt = OptimizableField(default=300,
                                                description="Maximum number of tokens to generate.",
                                                space=SearchSpace(high=2176, low=128, step=512))
nat/llm/openai_llm.py
CHANGED

@@ -21,6 +21,7 @@ from nat.builder.builder import Builder
 from nat.builder.llm import LLMProviderInfo
 from nat.cli.register_workflow import register_llm_provider
 from nat.data_models.llm import LLMBaseConfig
+from nat.data_models.optimizable import OptimizableField
 from nat.data_models.optimizable import OptimizableMixin
 from nat.data_models.retry_mixin import RetryMixin
 from nat.data_models.temperature_mixin import TemperatureMixin
@@ -41,9 +42,9 @@ class OpenAIModelConfig(LLMBaseConfig,

     api_key: str | None = Field(default=None, description="OpenAI API key to interact with hosted model.")
     base_url: str | None = Field(default=None, description="Base url to the hosted model.")
-    model_name: str =
-
-
+    model_name: str = OptimizableField(validation_alias=AliasChoices("model_name", "model"),
+                                       serialization_alias="model",
+                                       description="The OpenAI hosted model name.")
     seed: int | None = Field(default=None, description="Random seed to set for generation.")
     max_retries: int = Field(default=10, description="The max number of retries for the request.")

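The Bedrock, LiteLLM, NIM, and OpenAI configs above all follow the same pattern: model_name becomes an optimizable field that accepts either "model_name" or "model" on input and serializes back out as "model". A plain-pydantic sketch of that alias behavior, using Field instead of the package's OptimizableField and an illustrative model string:

    from pydantic import AliasChoices
    from pydantic import BaseModel
    from pydantic import ConfigDict
    from pydantic import Field

    class StubLLMConfig(BaseModel):
        model_config = ConfigDict(protected_namespaces=())
        model_name: str = Field(validation_alias=AliasChoices("model_name", "model"),
                                serialization_alias="model")

    # Both spellings validate into the same field...
    a = StubLLMConfig(model_name="meta/llama-3.1-8b-instruct")
    b = StubLLMConfig(model="meta/llama-3.1-8b-instruct")
    assert a.model_name == b.model_name

    # ...and dumping by alias emits the provider-friendly "model" key.
    print(a.model_dump(by_alias=True))  # {'model': 'meta/llama-3.1-8b-instruct'}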
nat/profiler/callbacks/langchain_callback_handler.py
CHANGED

@@ -33,6 +33,7 @@ from nat.builder.context import Context
 from nat.builder.framework_enum import LLMFrameworkEnum
 from nat.data_models.intermediate_step import IntermediateStepPayload
 from nat.data_models.intermediate_step import IntermediateStepType
+from nat.data_models.intermediate_step import ServerToolUseSchema
 from nat.data_models.intermediate_step import StreamEventData
 from nat.data_models.intermediate_step import ToolSchema
 from nat.data_models.intermediate_step import TraceMetadata
@@ -48,7 +49,14 @@ def _extract_tools_schema(invocation_params: dict) -> list:
     tools_schema = []
     if invocation_params is not None:
         for tool in invocation_params.get("tools", []):
-
+            try:
+                tools_schema.append(ToolSchema(**tool))
+            except Exception:
+                logger.debug(
+                    "Failed to parse tool schema from invocation params: %s. \n This "
+                    "can occur when the LLM server has native tools and can be ignored if "
+                    "using the responses API.",
+                    tool)

     return tools_schema

@@ -93,11 +101,15 @@ class LangchainProfilerHandler(AsyncCallbackHandler, BaseProfilerCallback):
             completion_tokens = usage_metadata.get("output_tokens", 0)
             total_tokens = usage_metadata.get("total_tokens", 0)

-
-
-
-
-
+            cache_tokens = usage_metadata.get("input_token_details", {}).get("cache_read", 0)
+
+            reasoning_tokens = usage_metadata.get("output_token_details", {}).get("reasoning", 0)
+
+            return TokenUsageBaseModel(prompt_tokens=prompt_tokens,
+                                       completion_tokens=completion_tokens,
+                                       total_tokens=total_tokens,
+                                       cached_tokens=cache_tokens,
+                                       reasoning_tokens=reasoning_tokens)
         return TokenUsageBaseModel()

     async def on_llm_start(self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any) -> None:
@@ -213,6 +225,7 @@ class LangchainProfilerHandler(AsyncCallbackHandler, BaseProfilerCallback):
         except IndexError:
             generation = None

+        message = None
         if isinstance(generation, ChatGeneration):
             try:
                 message = generation.message
@@ -232,6 +245,17 @@ class LangchainProfilerHandler(AsyncCallbackHandler, BaseProfilerCallback):
         else:
             llm_text_output = ""

+        tool_outputs_list = []
+        # Check if message.additional_kwargs as tool_outputs indicative of server side tool calling
+        if message and message.additional_kwargs and "tool_outputs" in message.additional_kwargs:
+            tools_outputs = message.additional_kwargs["tool_outputs"]
+            if isinstance(tools_outputs, list):
+                for tool in tools_outputs:
+                    try:
+                        tool_outputs_list.append(ServerToolUseSchema(**tool))
+                    except Exception:
+                        pass
+
         # update shared state behind lock
         with self._lock:
             usage_stat = IntermediateStepPayload(
@@ -243,7 +267,8 @@ class LangchainProfilerHandler(AsyncCallbackHandler, BaseProfilerCallback):
                 data=StreamEventData(input=self._run_id_to_llm_input.get(str(kwargs.get("run_id", "")), ""),
                                      output=llm_text_output),
                 usage_info=UsageInfo(token_usage=self._extract_token_base_model(usage_metadata)),
-                metadata=TraceMetadata(chat_responses=[generation] if generation else []
+                metadata=TraceMetadata(chat_responses=[generation] if generation else [],
+                                       tool_outputs=tool_outputs_list if tool_outputs_list else []))

             self.step_manager.push_intermediate_step(usage_stat)

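The cache and reasoning counts above come from nested dict.get lookups with empty-dict defaults, so the handler falls back to zero when a provider omits the detail blocks. A stand-alone sketch of that lookup pattern with fabricated sample data; the key layout mirrors the usage_metadata dict read by the handler.

    usage_metadata = {
        "input_tokens": 1200,
        "output_tokens": 350,
        "total_tokens": 1550,
        "input_token_details": {"cache_read": 900},
        # "output_token_details" intentionally absent to show the fallback
    }

    prompt_tokens = usage_metadata.get("input_tokens", 0)
    completion_tokens = usage_metadata.get("output_tokens", 0)
    total_tokens = usage_metadata.get("total_tokens", 0)
    cached_tokens = usage_metadata.get("input_token_details", {}).get("cache_read", 0)
    reasoning_tokens = usage_metadata.get("output_token_details", {}).get("reasoning", 0)

    print(prompt_tokens, completion_tokens, total_tokens, cached_tokens, reasoning_tokens)
    # -> 1200 350 1550 900 0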
nat/profiler/callbacks/llama_index_callback_handler.py
CHANGED

@@ -30,6 +30,7 @@ from nat.builder.context import Context
 from nat.builder.framework_enum import LLMFrameworkEnum
 from nat.data_models.intermediate_step import IntermediateStepPayload
 from nat.data_models.intermediate_step import IntermediateStepType
+from nat.data_models.intermediate_step import ServerToolUseSchema
 from nat.data_models.intermediate_step import StreamEventData
 from nat.data_models.intermediate_step import TraceMetadata
 from nat.data_models.intermediate_step import UsageInfo
@@ -64,6 +65,26 @@ class LlamaIndexProfilerHandler(BaseCallbackHandler, BaseProfilerCallback):
         self._run_id_to_tool_input = {}
         self._run_id_to_timestamp = {}

+    @staticmethod
+    def _extract_token_usage(response: ChatResponse) -> TokenUsageBaseModel:
+        token_usage = TokenUsageBaseModel()
+        try:
+            if response and response.additional_kwargs and "usage" in response.additional_kwargs:
+                usage = response.additional_kwargs["usage"] if "usage" in response.additional_kwargs else {}
+                token_usage.prompt_tokens = usage.input_tokens if hasattr(usage, "input_tokens") else 0
+                token_usage.completion_tokens = usage.output_tokens if hasattr(usage, "output_tokens") else 0
+
+                if hasattr(usage, "input_tokens_details") and hasattr(usage.input_tokens_details, "cached_tokens"):
+                    token_usage.cached_tokens = usage.input_tokens_details.cached_tokens
+
+                if hasattr(usage, "output_tokens_details") and hasattr(usage.output_tokens_details, "reasoning_tokens"):
+                    token_usage.reasoning_tokens = usage.output_tokens_details.reasoning_tokens
+
+        except Exception as e:
+            logger.debug("Error extracting token usage: %s", e, exc_info=True)
+
+        return token_usage
+
     def on_event_start(
         self,
         event_type: CBEventType,
@@ -167,6 +188,18 @@ class LlamaIndexProfilerHandler(BaseCallbackHandler, BaseProfilerCallback):
         except Exception as e:
             logger.exception("Error getting model name: %s", e)

+        # Append usage data to NAT usage stats
+        tool_outputs_list = []
+        # Check if message.additional_kwargs as tool_outputs indicative of server side tool calling
+        if response and response.additional_kwargs and "built_in_tool_calls" in response.additional_kwargs:
+            tools_outputs = response.additional_kwargs["built_in_tool_calls"]
+            if isinstance(tools_outputs, list):
+                for tool in tools_outputs:
+                    try:
+                        tool_outputs_list.append(ServerToolUseSchema(**tool.model_dump()))
+                    except Exception:
+                        pass
+
         # Append usage data to NAT usage stats
         with self._lock:
             stats = IntermediateStepPayload(
@@ -176,8 +209,9 @@ class LlamaIndexProfilerHandler(BaseCallbackHandler, BaseProfilerCallback):
                 name=model_name,
                 UUID=event_id,
                 data=StreamEventData(input=self._run_id_to_llm_input.get(event_id), output=llm_text_output),
-                metadata=TraceMetadata(chat_responses=response.message if response.message else None
-
+                metadata=TraceMetadata(chat_responses=response.message if response.message else None,
+                                       tool_outputs=tool_outputs_list if tool_outputs_list else []),
+                usage_info=UsageInfo(token_usage=self._extract_token_usage(response)))
             self.step_manager.push_intermediate_step(stats)

         elif event_type == CBEventType.FUNCTION_CALL and payload:
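Unlike the LangChain handler, the LlamaIndex path reads token counts off an object via hasattr checks rather than dict lookups. A small sketch of that defensive access, using SimpleNamespace as a stand-in for the provider's usage object (numbers are fabricated):

    from types import SimpleNamespace

    # Stand-in for response.additional_kwargs["usage"]
    usage = SimpleNamespace(input_tokens=800,
                            output_tokens=120,
                            input_tokens_details=SimpleNamespace(cached_tokens=256))

    prompt_tokens = usage.input_tokens if hasattr(usage, "input_tokens") else 0
    completion_tokens = usage.output_tokens if hasattr(usage, "output_tokens") else 0

    cached_tokens = 0
    if hasattr(usage, "input_tokens_details") and hasattr(usage.input_tokens_details, "cached_tokens"):
        cached_tokens = usage.input_tokens_details.cached_tokens

    reasoning_tokens = 0
    if hasattr(usage, "output_tokens_details") and hasattr(usage.output_tokens_details, "reasoning_tokens"):
        reasoning_tokens = usage.output_tokens_details.reasoning_tokens

    print(prompt_tokens, completion_tokens, cached_tokens, reasoning_tokens)  # -> 800 120 256 0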
nat/profiler/callbacks/token_usage_base_model.py
CHANGED

@@ -24,4 +24,6 @@ class TokenUsageBaseModel(BaseModel):

     prompt_tokens: int = Field(default=0, description="Number of tokens in the prompt.")
     completion_tokens: int = Field(default=0, description="Number of tokens in the completion.")
+    cached_tokens: int = Field(default=0, description="Number of tokens read from cache.")
+    reasoning_tokens: int = Field(default=0, description="Number of tokens used for reasoning.")
     total_tokens: int = Field(default=0, description="Number of tokens total.")