nvidia-nat 1.4.0a20251008__py3-none-any.whl → 1.4.0a20251011__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/react_agent/register.py +15 -24
- nat/agent/rewoo_agent/register.py +15 -24
- nat/agent/tool_calling_agent/register.py +9 -5
- nat/builder/component_utils.py +1 -1
- nat/builder/function.py +4 -4
- nat/builder/intermediate_step_manager.py +32 -0
- nat/builder/workflow_builder.py +46 -3
- nat/cli/entrypoint.py +9 -1
- nat/data_models/api_server.py +78 -9
- nat/data_models/config.py +1 -1
- nat/front_ends/console/console_front_end_plugin.py +11 -2
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
- nat/front_ends/mcp/mcp_front_end_config.py +13 -0
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +18 -1
- nat/front_ends/mcp/memory_profiler.py +320 -0
- nat/front_ends/mcp/tool_converter.py +21 -2
- nat/observability/register.py +16 -0
- nat/runtime/runner.py +1 -2
- nat/runtime/session.py +1 -1
- nat/tool/memory_tools/add_memory_tool.py +3 -3
- nat/tool/memory_tools/delete_memory_tool.py +3 -4
- nat/tool/memory_tools/get_memory_tool.py +3 -3
- nat/utils/type_converter.py +8 -0
- nvidia_nat-1.4.0a20251011.dist-info/METADATA +195 -0
- {nvidia_nat-1.4.0a20251008.dist-info → nvidia_nat-1.4.0a20251011.dist-info}/RECORD +30 -29
- nvidia_nat-1.4.0a20251008.dist-info/METADATA +0 -389
- {nvidia_nat-1.4.0a20251008.dist-info → nvidia_nat-1.4.0a20251011.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.4.0a20251008.dist-info → nvidia_nat-1.4.0a20251011.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.4.0a20251008.dist-info → nvidia_nat-1.4.0a20251011.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.4.0a20251008.dist-info → nvidia_nat-1.4.0a20251011.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.4.0a20251008.dist-info → nvidia_nat-1.4.0a20251011.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,320 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
"""Memory profiling utilities for MCP frontend."""
|
|
16
|
+
|
|
17
|
+
import gc
|
|
18
|
+
import logging
|
|
19
|
+
import tracemalloc
|
|
20
|
+
from typing import Any
|
|
21
|
+
|
|
22
|
+
logger = logging.getLogger(__name__)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class MemoryProfiler:
|
|
26
|
+
"""Memory profiler for tracking memory usage and potential leaks."""
|
|
27
|
+
|
|
28
|
+
def __init__(self, enabled: bool = False, log_interval: int = 50, top_n: int = 10, log_level: str = "DEBUG"):
|
|
29
|
+
"""Initialize the memory profiler.
|
|
30
|
+
|
|
31
|
+
Args:
|
|
32
|
+
enabled: Whether memory profiling is enabled
|
|
33
|
+
log_interval: Log stats every N requests
|
|
34
|
+
top_n: Number of top allocations to log
|
|
35
|
+
log_level: Log level for memory profiling output (e.g., "DEBUG", "INFO")
|
|
36
|
+
"""
|
|
37
|
+
self.enabled = enabled
|
|
38
|
+
# normalize interval to avoid modulo-by-zero
|
|
39
|
+
self.log_interval = max(1, int(log_interval))
|
|
40
|
+
self.top_n = top_n
|
|
41
|
+
self.log_level = getattr(logging, log_level.upper(), logging.DEBUG)
|
|
42
|
+
self.request_count = 0
|
|
43
|
+
self.baseline_snapshot = None
|
|
44
|
+
|
|
45
|
+
# Track whether this instance started tracemalloc (to avoid resetting external tracing)
|
|
46
|
+
self._we_started_tracemalloc = False
|
|
47
|
+
|
|
48
|
+
if self.enabled:
|
|
49
|
+
logger.info("Memory profiling ENABLED (interval=%d, top_n=%d, log_level=%s)",
|
|
50
|
+
self.log_interval,
|
|
51
|
+
top_n,
|
|
52
|
+
log_level)
|
|
53
|
+
try:
|
|
54
|
+
if not tracemalloc.is_tracing():
|
|
55
|
+
tracemalloc.start()
|
|
56
|
+
self._we_started_tracemalloc = True
|
|
57
|
+
# Take baseline snapshot
|
|
58
|
+
gc.collect()
|
|
59
|
+
self.baseline_snapshot = tracemalloc.take_snapshot()
|
|
60
|
+
except RuntimeError as e:
|
|
61
|
+
logger.warning("tracemalloc unavailable or not tracing: %s", e)
|
|
62
|
+
else:
|
|
63
|
+
logger.info("Memory profiling DISABLED")
|
|
64
|
+
|
|
65
|
+
def _log(self, message: str, *args: Any) -> None:
|
|
66
|
+
"""Log a message at the configured log level.
|
|
67
|
+
|
|
68
|
+
Args:
|
|
69
|
+
message: Log message format string
|
|
70
|
+
args: Arguments for the format string
|
|
71
|
+
"""
|
|
72
|
+
logger.log(self.log_level, message, *args)
|
|
73
|
+
|
|
74
|
+
def on_request_complete(self) -> None:
|
|
75
|
+
"""Called after each request completes."""
|
|
76
|
+
if not self.enabled:
|
|
77
|
+
return
|
|
78
|
+
self.request_count += 1
|
|
79
|
+
if self.request_count % self.log_interval == 0:
|
|
80
|
+
self.log_memory_stats()
|
|
81
|
+
|
|
82
|
+
def _ensure_tracing(self) -> bool:
|
|
83
|
+
"""Ensure tracemalloc is running if we started it originally.
|
|
84
|
+
|
|
85
|
+
Returns:
|
|
86
|
+
True if tracemalloc is active, False otherwise
|
|
87
|
+
"""
|
|
88
|
+
if tracemalloc.is_tracing():
|
|
89
|
+
return True
|
|
90
|
+
|
|
91
|
+
# Only restart if we started it originally (respect external control)
|
|
92
|
+
if not self._we_started_tracemalloc:
|
|
93
|
+
return False
|
|
94
|
+
|
|
95
|
+
# Attempt to restart
|
|
96
|
+
try:
|
|
97
|
+
logger.warning("tracemalloc was stopped externally; restarting (we started it originally)")
|
|
98
|
+
tracemalloc.start()
|
|
99
|
+
|
|
100
|
+
# Reset baseline since old tracking data is lost
|
|
101
|
+
gc.collect()
|
|
102
|
+
self.baseline_snapshot = tracemalloc.take_snapshot()
|
|
103
|
+
logger.info("Baseline snapshot reset after tracemalloc restart")
|
|
104
|
+
|
|
105
|
+
return True
|
|
106
|
+
except RuntimeError as e:
|
|
107
|
+
logger.error("Failed to restart tracemalloc: %s", e)
|
|
108
|
+
return False
|
|
109
|
+
|
|
110
|
+
def _safe_traced_memory(self) -> tuple[float, float] | None:
|
|
111
|
+
"""Return (current, peak usage in MB) if tracemalloc is active, else None."""
|
|
112
|
+
if not self._ensure_tracing():
|
|
113
|
+
return None
|
|
114
|
+
|
|
115
|
+
try:
|
|
116
|
+
current, peak = tracemalloc.get_traced_memory()
|
|
117
|
+
megabyte = (1 << 20)
|
|
118
|
+
return (current / megabyte, peak / megabyte)
|
|
119
|
+
except RuntimeError:
|
|
120
|
+
return None
|
|
121
|
+
|
|
122
|
+
def _safe_snapshot(self) -> tracemalloc.Snapshot | None:
|
|
123
|
+
"""Return a tracemalloc Snapshot if available, else None."""
|
|
124
|
+
if not self._ensure_tracing():
|
|
125
|
+
return None
|
|
126
|
+
|
|
127
|
+
try:
|
|
128
|
+
return tracemalloc.take_snapshot()
|
|
129
|
+
except RuntimeError:
|
|
130
|
+
return None
|
|
131
|
+
|
|
132
|
+
def log_memory_stats(self) -> dict[str, Any]:
|
|
133
|
+
"""Log current memory statistics and return them."""
|
|
134
|
+
if not self.enabled:
|
|
135
|
+
return {}
|
|
136
|
+
|
|
137
|
+
# Force garbage collection first
|
|
138
|
+
gc.collect()
|
|
139
|
+
|
|
140
|
+
# Get current memory usage
|
|
141
|
+
mem = self._safe_traced_memory()
|
|
142
|
+
if mem is None:
|
|
143
|
+
logger.info("tracemalloc is not active; cannot collect memory stats.")
|
|
144
|
+
# still return structural fields
|
|
145
|
+
stats = {
|
|
146
|
+
"request_count": self.request_count,
|
|
147
|
+
"current_memory_mb": None,
|
|
148
|
+
"peak_memory_mb": None,
|
|
149
|
+
"active_intermediate_managers": self._safe_intermediate_step_manager_count(),
|
|
150
|
+
"outstanding_steps": self._safe_outstanding_step_count(),
|
|
151
|
+
"active_exporters": self._safe_exporter_count(),
|
|
152
|
+
"isolated_exporters": self._safe_isolated_exporter_count(),
|
|
153
|
+
"subject_instances": self._count_instances_of_type("Subject"),
|
|
154
|
+
}
|
|
155
|
+
return stats
|
|
156
|
+
|
|
157
|
+
current_mb, peak_mb = mem
|
|
158
|
+
|
|
159
|
+
# Take snapshot and compare to baseline
|
|
160
|
+
snapshot = self._safe_snapshot()
|
|
161
|
+
|
|
162
|
+
# Track BaseExporter instances (observability layer)
|
|
163
|
+
exporter_count = self._safe_exporter_count()
|
|
164
|
+
isolated_exporter_count = self._safe_isolated_exporter_count()
|
|
165
|
+
|
|
166
|
+
# Track Subject instances (event streams)
|
|
167
|
+
subject_count = self._count_instances_of_type("Subject")
|
|
168
|
+
|
|
169
|
+
stats = {
|
|
170
|
+
"request_count": self.request_count,
|
|
171
|
+
"current_memory_mb": round(current_mb, 2),
|
|
172
|
+
"peak_memory_mb": round(peak_mb, 2),
|
|
173
|
+
"active_intermediate_managers": self._safe_intermediate_step_manager_count(),
|
|
174
|
+
"outstanding_steps": self._safe_outstanding_step_count(),
|
|
175
|
+
"active_exporters": exporter_count,
|
|
176
|
+
"isolated_exporters": isolated_exporter_count,
|
|
177
|
+
"subject_instances": subject_count,
|
|
178
|
+
}
|
|
179
|
+
|
|
180
|
+
self._log("=" * 80)
|
|
181
|
+
self._log("MEMORY PROFILE AFTER %d REQUESTS:", self.request_count)
|
|
182
|
+
self._log(" Current Memory: %.2f MB", current_mb)
|
|
183
|
+
self._log(" Peak Memory: %.2f MB", peak_mb)
|
|
184
|
+
self._log("")
|
|
185
|
+
self._log("NAT COMPONENT INSTANCES:")
|
|
186
|
+
self._log(" IntermediateStepManagers: %d active (%d outstanding steps)",
|
|
187
|
+
stats["active_intermediate_managers"],
|
|
188
|
+
stats["outstanding_steps"])
|
|
189
|
+
self._log(" BaseExporters: %d active (%d isolated)", stats["active_exporters"], stats["isolated_exporters"])
|
|
190
|
+
self._log(" Subject (event streams): %d instances", stats["subject_instances"])
|
|
191
|
+
|
|
192
|
+
# Show top allocations
|
|
193
|
+
if snapshot is None:
|
|
194
|
+
self._log("tracemalloc snapshot unavailable.")
|
|
195
|
+
else:
|
|
196
|
+
if self.baseline_snapshot:
|
|
197
|
+
self._log("TOP %d MEMORY GROWTH SINCE BASELINE:", self.top_n)
|
|
198
|
+
top_stats = snapshot.compare_to(self.baseline_snapshot, 'lineno')
|
|
199
|
+
else:
|
|
200
|
+
self._log("TOP %d MEMORY ALLOCATIONS:", self.top_n)
|
|
201
|
+
top_stats = snapshot.statistics('lineno')
|
|
202
|
+
|
|
203
|
+
for i, stat in enumerate(top_stats[:self.top_n], 1):
|
|
204
|
+
self._log(" #%d: %s", i, stat)
|
|
205
|
+
|
|
206
|
+
self._log("=" * 80)
|
|
207
|
+
|
|
208
|
+
return stats
|
|
209
|
+
|
|
210
|
+
def _count_instances_of_type(self, type_name: str) -> int:
|
|
211
|
+
"""Count instances of a specific type in memory."""
|
|
212
|
+
count = 0
|
|
213
|
+
for obj in gc.get_objects():
|
|
214
|
+
try:
|
|
215
|
+
if type(obj).__name__ == type_name:
|
|
216
|
+
count += 1
|
|
217
|
+
except Exception:
|
|
218
|
+
pass
|
|
219
|
+
return count
|
|
220
|
+
|
|
221
|
+
def _safe_exporter_count(self) -> int:
|
|
222
|
+
try:
|
|
223
|
+
from nat.observability.exporter.base_exporter import BaseExporter
|
|
224
|
+
return BaseExporter.get_active_instance_count()
|
|
225
|
+
except Exception as e:
|
|
226
|
+
logger.debug("Could not get BaseExporter stats: %s", e)
|
|
227
|
+
return 0
|
|
228
|
+
|
|
229
|
+
def _safe_isolated_exporter_count(self) -> int:
|
|
230
|
+
try:
|
|
231
|
+
from nat.observability.exporter.base_exporter import BaseExporter
|
|
232
|
+
return BaseExporter.get_isolated_instance_count()
|
|
233
|
+
except Exception:
|
|
234
|
+
return 0
|
|
235
|
+
|
|
236
|
+
def _safe_intermediate_step_manager_count(self) -> int:
|
|
237
|
+
try:
|
|
238
|
+
from nat.builder.intermediate_step_manager import IntermediateStepManager
|
|
239
|
+
# len() is atomic in CPython, but catch RuntimeError just in case
|
|
240
|
+
try:
|
|
241
|
+
return IntermediateStepManager.get_active_instance_count()
|
|
242
|
+
except RuntimeError:
|
|
243
|
+
# Set was modified during len() - very rare
|
|
244
|
+
logger.debug("Set changed during count, returning 0")
|
|
245
|
+
return 0
|
|
246
|
+
except Exception as e:
|
|
247
|
+
logger.debug("Could not get IntermediateStepManager stats: %s", e)
|
|
248
|
+
return 0
|
|
249
|
+
|
|
250
|
+
def _safe_outstanding_step_count(self) -> int:
|
|
251
|
+
"""Get total outstanding steps across all active IntermediateStepManager instances."""
|
|
252
|
+
try:
|
|
253
|
+
from nat.builder.intermediate_step_manager import IntermediateStepManager
|
|
254
|
+
|
|
255
|
+
# Make a snapshot to avoid "Set changed size during iteration" if GC runs
|
|
256
|
+
try:
|
|
257
|
+
instances_snapshot = list(IntermediateStepManager._active_instances)
|
|
258
|
+
except RuntimeError:
|
|
259
|
+
# Set changed during list() call - rare but possible
|
|
260
|
+
logger.debug("Set changed during snapshot, returning 0 for outstanding steps")
|
|
261
|
+
return 0
|
|
262
|
+
|
|
263
|
+
total_outstanding = 0
|
|
264
|
+
# Iterate through snapshot safely
|
|
265
|
+
for ref in instances_snapshot:
|
|
266
|
+
try:
|
|
267
|
+
manager = ref()
|
|
268
|
+
if manager is not None:
|
|
269
|
+
total_outstanding += manager.get_outstanding_step_count()
|
|
270
|
+
except (ReferenceError, AttributeError):
|
|
271
|
+
# Manager was GC'd or in invalid state - skip it
|
|
272
|
+
continue
|
|
273
|
+
return total_outstanding
|
|
274
|
+
except Exception as e:
|
|
275
|
+
logger.debug("Could not get outstanding step count: %s", e)
|
|
276
|
+
return 0
|
|
277
|
+
|
|
278
|
+
def get_stats(self) -> dict[str, Any]:
|
|
279
|
+
"""Get current memory statistics without logging."""
|
|
280
|
+
if not self.enabled:
|
|
281
|
+
return {"enabled": False}
|
|
282
|
+
|
|
283
|
+
mem = self._safe_traced_memory()
|
|
284
|
+
if mem is None:
|
|
285
|
+
return {
|
|
286
|
+
"enabled": True,
|
|
287
|
+
"request_count": self.request_count,
|
|
288
|
+
"current_memory_mb": None,
|
|
289
|
+
"peak_memory_mb": None,
|
|
290
|
+
"active_intermediate_managers": self._safe_intermediate_step_manager_count(),
|
|
291
|
+
"outstanding_steps": self._safe_outstanding_step_count(),
|
|
292
|
+
"active_exporters": self._safe_exporter_count(),
|
|
293
|
+
"isolated_exporters": self._safe_isolated_exporter_count(),
|
|
294
|
+
"subject_instances": self._count_instances_of_type("Subject"),
|
|
295
|
+
}
|
|
296
|
+
|
|
297
|
+
current_mb, peak_mb = mem
|
|
298
|
+
return {
|
|
299
|
+
"enabled": True,
|
|
300
|
+
"request_count": self.request_count,
|
|
301
|
+
"current_memory_mb": round(current_mb, 2),
|
|
302
|
+
"peak_memory_mb": round(peak_mb, 2),
|
|
303
|
+
"active_intermediate_managers": self._safe_intermediate_step_manager_count(),
|
|
304
|
+
"outstanding_steps": self._safe_outstanding_step_count(),
|
|
305
|
+
"active_exporters": self._safe_exporter_count(),
|
|
306
|
+
"isolated_exporters": self._safe_isolated_exporter_count(),
|
|
307
|
+
"subject_instances": self._count_instances_of_type("Subject"),
|
|
308
|
+
}
|
|
309
|
+
|
|
310
|
+
def reset_baseline(self) -> None:
|
|
311
|
+
"""Reset the baseline snapshot to current state."""
|
|
312
|
+
if not self.enabled:
|
|
313
|
+
return
|
|
314
|
+
gc.collect()
|
|
315
|
+
snap = self._safe_snapshot()
|
|
316
|
+
if snap is None:
|
|
317
|
+
logger.info("Cannot reset baseline: tracemalloc is not active.")
|
|
318
|
+
return
|
|
319
|
+
self.baseline_snapshot = snap
|
|
320
|
+
logger.info("Memory profiling baseline reset at request %d", self.request_count)
|
|
@@ -28,6 +28,7 @@ from nat.builder.function_base import FunctionBase
|
|
|
28
28
|
|
|
29
29
|
if TYPE_CHECKING:
|
|
30
30
|
from nat.builder.workflow import Workflow
|
|
31
|
+
from nat.front_ends.mcp.memory_profiler import MemoryProfiler
|
|
31
32
|
|
|
32
33
|
logger = logging.getLogger(__name__)
|
|
33
34
|
|
|
@@ -38,6 +39,7 @@ def create_function_wrapper(
|
|
|
38
39
|
schema: type[BaseModel],
|
|
39
40
|
is_workflow: bool = False,
|
|
40
41
|
workflow: 'Workflow | None' = None,
|
|
42
|
+
memory_profiler: 'MemoryProfiler | None' = None,
|
|
41
43
|
):
|
|
42
44
|
"""Create a wrapper function that exposes the actual parameters of a NAT Function as an MCP tool.
|
|
43
45
|
|
|
@@ -47,6 +49,7 @@ def create_function_wrapper(
|
|
|
47
49
|
schema (type[BaseModel]): The input schema of the function
|
|
48
50
|
is_workflow (bool): Whether the function is a Workflow
|
|
49
51
|
workflow (Workflow | None): The parent workflow for observability context
|
|
52
|
+
memory_profiler: Optional memory profiler to track requests
|
|
50
53
|
|
|
51
54
|
Returns:
|
|
52
55
|
A wrapper function suitable for registration with MCP
|
|
@@ -172,6 +175,10 @@ def create_function_wrapper(
|
|
|
172
175
|
if ctx:
|
|
173
176
|
await ctx.report_progress(100, 100)
|
|
174
177
|
|
|
178
|
+
# Track request completion for memory profiling
|
|
179
|
+
if memory_profiler:
|
|
180
|
+
memory_profiler.on_request_complete()
|
|
181
|
+
|
|
175
182
|
# Handle different result types for proper formatting
|
|
176
183
|
if isinstance(result, str):
|
|
177
184
|
return result
|
|
@@ -181,6 +188,11 @@ def create_function_wrapper(
|
|
|
181
188
|
except Exception as e:
|
|
182
189
|
if ctx:
|
|
183
190
|
ctx.error("Error calling function %s: %s", function_name, str(e))
|
|
191
|
+
|
|
192
|
+
# Track request completion even on error
|
|
193
|
+
if memory_profiler:
|
|
194
|
+
memory_profiler.on_request_complete()
|
|
195
|
+
|
|
184
196
|
raise
|
|
185
197
|
|
|
186
198
|
return wrapper_with_ctx
|
|
@@ -242,7 +254,8 @@ def get_function_description(function: FunctionBase) -> str:
|
|
|
242
254
|
def register_function_with_mcp(mcp: FastMCP,
|
|
243
255
|
function_name: str,
|
|
244
256
|
function: FunctionBase,
|
|
245
|
-
workflow: 'Workflow | None' = None
|
|
257
|
+
workflow: 'Workflow | None' = None,
|
|
258
|
+
memory_profiler: 'MemoryProfiler | None' = None) -> None:
|
|
246
259
|
"""Register a NAT Function as an MCP tool.
|
|
247
260
|
|
|
248
261
|
Args:
|
|
@@ -250,6 +263,7 @@ def register_function_with_mcp(mcp: FastMCP,
|
|
|
250
263
|
function_name: The name to register the function under
|
|
251
264
|
function: The NAT Function to register
|
|
252
265
|
workflow: The parent workflow for observability context (if available)
|
|
266
|
+
memory_profiler: Optional memory profiler to track requests
|
|
253
267
|
"""
|
|
254
268
|
logger.info("Registering function %s with MCP", function_name)
|
|
255
269
|
|
|
@@ -267,5 +281,10 @@ def register_function_with_mcp(mcp: FastMCP,
|
|
|
267
281
|
function_description = get_function_description(function)
|
|
268
282
|
|
|
269
283
|
# Create and register the wrapper function with MCP
|
|
270
|
-
wrapper_func = create_function_wrapper(function_name,
|
|
284
|
+
wrapper_func = create_function_wrapper(function_name,
|
|
285
|
+
function,
|
|
286
|
+
input_schema,
|
|
287
|
+
is_workflow,
|
|
288
|
+
workflow,
|
|
289
|
+
memory_profiler)
|
|
271
290
|
mcp.tool(name=function_name, description=function_description)(wrapper_func)
|
nat/observability/register.py
CHANGED
|
@@ -77,6 +77,14 @@ async def console_logging_method(config: ConsoleLoggingMethodConfig, builder: Bu
|
|
|
77
77
|
level = getattr(logging, config.level.upper(), logging.INFO)
|
|
78
78
|
handler = logging.StreamHandler(stream=sys.stdout)
|
|
79
79
|
handler.setLevel(level)
|
|
80
|
+
|
|
81
|
+
# Set formatter to match the default CLI format
|
|
82
|
+
formatter = logging.Formatter(
|
|
83
|
+
fmt="%(asctime)s - %(levelname)-8s - %(name)s:%(lineno)d - %(message)s",
|
|
84
|
+
datefmt="%Y-%m-%d %H:%M:%S",
|
|
85
|
+
)
|
|
86
|
+
handler.setFormatter(formatter)
|
|
87
|
+
|
|
80
88
|
yield handler
|
|
81
89
|
|
|
82
90
|
|
|
@@ -95,4 +103,12 @@ async def file_logging_method(config: FileLoggingMethod, builder: Builder):
|
|
|
95
103
|
level = getattr(logging, config.level.upper(), logging.INFO)
|
|
96
104
|
handler = logging.FileHandler(filename=config.path, mode="a", encoding="utf-8")
|
|
97
105
|
handler.setLevel(level)
|
|
106
|
+
|
|
107
|
+
# Set formatter to match the default CLI format
|
|
108
|
+
formatter = logging.Formatter(
|
|
109
|
+
fmt="%(asctime)s - %(levelname)-8s - %(name)s:%(lineno)d - %(message)s",
|
|
110
|
+
datefmt="%Y-%m-%d %H:%M:%S",
|
|
111
|
+
)
|
|
112
|
+
handler.setFormatter(formatter)
|
|
113
|
+
|
|
98
114
|
yield handler
|
nat/runtime/runner.py
CHANGED
|
@@ -196,8 +196,7 @@ class Runner:
|
|
|
196
196
|
|
|
197
197
|
return result
|
|
198
198
|
except Exception as e:
|
|
199
|
-
|
|
200
|
-
logger.error("Error running workflow%s", err_msg)
|
|
199
|
+
logger.error("Error running workflow: %s", e)
|
|
201
200
|
event_stream = self._context_state.event_stream.get()
|
|
202
201
|
if event_stream:
|
|
203
202
|
event_stream.on_complete()
|
nat/runtime/session.py
CHANGED
|
@@ -192,7 +192,7 @@ class SessionManager:
|
|
|
192
192
|
user_message_id: str | None,
|
|
193
193
|
conversation_id: str | None) -> None:
|
|
194
194
|
"""
|
|
195
|
-
Extracts and sets user metadata for
|
|
195
|
+
Extracts and sets user metadata for WebSocket connections.
|
|
196
196
|
"""
|
|
197
197
|
|
|
198
198
|
# Extract cookies from WebSocket headers (similar to HTTP request)
|
|
@@ -30,10 +30,10 @@ logger = logging.getLogger(__name__)
|
|
|
30
30
|
class AddToolConfig(FunctionBaseConfig, name="add_memory"):
|
|
31
31
|
"""Function to add memory to a hosted memory platform."""
|
|
32
32
|
|
|
33
|
-
description: str = Field(default=("Tool to add memory about a user's interactions to a system "
|
|
33
|
+
description: str = Field(default=("Tool to add a memory about a user's interactions to a system "
|
|
34
34
|
"for retrieval later."),
|
|
35
35
|
description="The description of this function's use for tool calling agents.")
|
|
36
|
-
memory: MemoryRef = Field(default="saas_memory",
|
|
36
|
+
memory: MemoryRef = Field(default=MemoryRef("saas_memory"),
|
|
37
37
|
description=("Instance name of the memory client instance from the workflow "
|
|
38
38
|
"configuration object."))
|
|
39
39
|
|
|
@@ -46,7 +46,7 @@ async def add_memory_tool(config: AddToolConfig, builder: Builder):
|
|
|
46
46
|
from langchain_core.tools import ToolException
|
|
47
47
|
|
|
48
48
|
# First, retrieve the memory client
|
|
49
|
-
memory_editor = builder.get_memory_client(config.memory)
|
|
49
|
+
memory_editor = await builder.get_memory_client(config.memory)
|
|
50
50
|
|
|
51
51
|
async def _arun(item: MemoryItem) -> str:
|
|
52
52
|
"""
|
|
@@ -30,10 +30,9 @@ logger = logging.getLogger(__name__)
|
|
|
30
30
|
class DeleteToolConfig(FunctionBaseConfig, name="delete_memory"):
|
|
31
31
|
"""Function to delete memory from a hosted memory platform."""
|
|
32
32
|
|
|
33
|
-
description: str = Field(default=
|
|
34
|
-
"interactions to help answer questions in a personalized way."),
|
|
33
|
+
description: str = Field(default="Tool to delete a memory from a hosted memory platform.",
|
|
35
34
|
description="The description of this function's use for tool calling agents.")
|
|
36
|
-
memory: MemoryRef = Field(default="saas_memory",
|
|
35
|
+
memory: MemoryRef = Field(default=MemoryRef("saas_memory"),
|
|
37
36
|
description=("Instance name of the memory client instance from the workflow "
|
|
38
37
|
"configuration object."))
|
|
39
38
|
|
|
@@ -47,7 +46,7 @@ async def delete_memory_tool(config: DeleteToolConfig, builder: Builder):
|
|
|
47
46
|
from langchain_core.tools import ToolException
|
|
48
47
|
|
|
49
48
|
# First, retrieve the memory client
|
|
50
|
-
memory_editor = builder.get_memory_client(config.memory)
|
|
49
|
+
memory_editor = await builder.get_memory_client(config.memory)
|
|
51
50
|
|
|
52
51
|
async def _arun(user_id: str) -> str:
|
|
53
52
|
"""
|
|
@@ -30,10 +30,10 @@ logger = logging.getLogger(__name__)
|
|
|
30
30
|
class GetToolConfig(FunctionBaseConfig, name="get_memory"):
|
|
31
31
|
"""Function to get memory to a hosted memory platform."""
|
|
32
32
|
|
|
33
|
-
description: str = Field(default=("Tool to retrieve memory about a user's "
|
|
33
|
+
description: str = Field(default=("Tool to retrieve a memory about a user's "
|
|
34
34
|
"interactions to help answer questions in a personalized way."),
|
|
35
35
|
description="The description of this function's use for tool calling agents.")
|
|
36
|
-
memory: MemoryRef = Field(default="saas_memory",
|
|
36
|
+
memory: MemoryRef = Field(default=MemoryRef("saas_memory"),
|
|
37
37
|
description=("Instance name of the memory client instance from the workflow "
|
|
38
38
|
"configuration object."))
|
|
39
39
|
|
|
@@ -49,7 +49,7 @@ async def get_memory_tool(config: GetToolConfig, builder: Builder):
|
|
|
49
49
|
from langchain_core.tools import ToolException
|
|
50
50
|
|
|
51
51
|
# First, retrieve the memory client
|
|
52
|
-
memory_editor = builder.get_memory_client(config.memory)
|
|
52
|
+
memory_editor = await builder.get_memory_client(config.memory)
|
|
53
53
|
|
|
54
54
|
async def _arun(search_input: SearchMemoryInput) -> str:
|
|
55
55
|
"""
|
nat/utils/type_converter.py
CHANGED
|
@@ -93,6 +93,14 @@ class TypeConverter:
|
|
|
93
93
|
if to_type is None or decomposed.is_instance(data):
|
|
94
94
|
return data
|
|
95
95
|
|
|
96
|
+
# 2) If data is a union type, try to convert to each type in the union
|
|
97
|
+
if decomposed.is_union:
|
|
98
|
+
for union_type in decomposed.args:
|
|
99
|
+
result = self._convert(data, union_type)
|
|
100
|
+
if result is not None:
|
|
101
|
+
return result
|
|
102
|
+
return None
|
|
103
|
+
|
|
96
104
|
root = decomposed.root
|
|
97
105
|
|
|
98
106
|
# 2) Attempt direct in *this* converter
|