nvidia-nat 1.4.0a20251008__py3-none-any.whl → 1.4.0a20251011__py3-none-any.whl

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that public registry.
Files changed (31)
  1. nat/agent/react_agent/register.py +15 -24
  2. nat/agent/rewoo_agent/register.py +15 -24
  3. nat/agent/tool_calling_agent/register.py +9 -5
  4. nat/builder/component_utils.py +1 -1
  5. nat/builder/function.py +4 -4
  6. nat/builder/intermediate_step_manager.py +32 -0
  7. nat/builder/workflow_builder.py +46 -3
  8. nat/cli/entrypoint.py +9 -1
  9. nat/data_models/api_server.py +78 -9
  10. nat/data_models/config.py +1 -1
  11. nat/front_ends/console/console_front_end_plugin.py +11 -2
  12. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
  13. nat/front_ends/mcp/mcp_front_end_config.py +13 -0
  14. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +18 -1
  15. nat/front_ends/mcp/memory_profiler.py +320 -0
  16. nat/front_ends/mcp/tool_converter.py +21 -2
  17. nat/observability/register.py +16 -0
  18. nat/runtime/runner.py +1 -2
  19. nat/runtime/session.py +1 -1
  20. nat/tool/memory_tools/add_memory_tool.py +3 -3
  21. nat/tool/memory_tools/delete_memory_tool.py +3 -4
  22. nat/tool/memory_tools/get_memory_tool.py +3 -3
  23. nat/utils/type_converter.py +8 -0
  24. nvidia_nat-1.4.0a20251011.dist-info/METADATA +195 -0
  25. {nvidia_nat-1.4.0a20251008.dist-info → nvidia_nat-1.4.0a20251011.dist-info}/RECORD +30 -29
  26. nvidia_nat-1.4.0a20251008.dist-info/METADATA +0 -389
  27. {nvidia_nat-1.4.0a20251008.dist-info → nvidia_nat-1.4.0a20251011.dist-info}/WHEEL +0 -0
  28. {nvidia_nat-1.4.0a20251008.dist-info → nvidia_nat-1.4.0a20251011.dist-info}/entry_points.txt +0 -0
  29. {nvidia_nat-1.4.0a20251008.dist-info → nvidia_nat-1.4.0a20251011.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  30. {nvidia_nat-1.4.0a20251008.dist-info → nvidia_nat-1.4.0a20251011.dist-info}/licenses/LICENSE.md +0 -0
  31. {nvidia_nat-1.4.0a20251008.dist-info → nvidia_nat-1.4.0a20251011.dist-info}/top_level.txt +0 -0
nat/front_ends/mcp/memory_profiler.py ADDED
@@ -0,0 +1,320 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Memory profiling utilities for MCP frontend."""
+
+ import gc
+ import logging
+ import tracemalloc
+ from typing import Any
+
+ logger = logging.getLogger(__name__)
+
+
+ class MemoryProfiler:
+     """Memory profiler for tracking memory usage and potential leaks."""
+
+     def __init__(self, enabled: bool = False, log_interval: int = 50, top_n: int = 10, log_level: str = "DEBUG"):
+         """Initialize the memory profiler.
+
+         Args:
+             enabled: Whether memory profiling is enabled
+             log_interval: Log stats every N requests
+             top_n: Number of top allocations to log
+             log_level: Log level for memory profiling output (e.g., "DEBUG", "INFO")
+         """
+         self.enabled = enabled
+         # normalize interval to avoid modulo-by-zero
+         self.log_interval = max(1, int(log_interval))
+         self.top_n = top_n
+         self.log_level = getattr(logging, log_level.upper(), logging.DEBUG)
+         self.request_count = 0
+         self.baseline_snapshot = None
+
+         # Track whether this instance started tracemalloc (to avoid resetting external tracing)
+         self._we_started_tracemalloc = False
+
+         if self.enabled:
+             logger.info("Memory profiling ENABLED (interval=%d, top_n=%d, log_level=%s)",
+                         self.log_interval,
+                         top_n,
+                         log_level)
+             try:
+                 if not tracemalloc.is_tracing():
+                     tracemalloc.start()
+                     self._we_started_tracemalloc = True
+                 # Take baseline snapshot
+                 gc.collect()
+                 self.baseline_snapshot = tracemalloc.take_snapshot()
+             except RuntimeError as e:
+                 logger.warning("tracemalloc unavailable or not tracing: %s", e)
+         else:
+             logger.info("Memory profiling DISABLED")
+
+     def _log(self, message: str, *args: Any) -> None:
+         """Log a message at the configured log level.
+
+         Args:
+             message: Log message format string
+             args: Arguments for the format string
+         """
+         logger.log(self.log_level, message, *args)
+
+     def on_request_complete(self) -> None:
+         """Called after each request completes."""
+         if not self.enabled:
+             return
+         self.request_count += 1
+         if self.request_count % self.log_interval == 0:
+             self.log_memory_stats()
+
+     def _ensure_tracing(self) -> bool:
+         """Ensure tracemalloc is running if we started it originally.
+
+         Returns:
+             True if tracemalloc is active, False otherwise
+         """
+         if tracemalloc.is_tracing():
+             return True
+
+         # Only restart if we started it originally (respect external control)
+         if not self._we_started_tracemalloc:
+             return False
+
+         # Attempt to restart
+         try:
+             logger.warning("tracemalloc was stopped externally; restarting (we started it originally)")
+             tracemalloc.start()
+
+             # Reset baseline since old tracking data is lost
+             gc.collect()
+             self.baseline_snapshot = tracemalloc.take_snapshot()
+             logger.info("Baseline snapshot reset after tracemalloc restart")
+
+             return True
+         except RuntimeError as e:
+             logger.error("Failed to restart tracemalloc: %s", e)
+             return False
+
+     def _safe_traced_memory(self) -> tuple[float, float] | None:
+         """Return (current, peak) usage in MB if tracemalloc is active, else None."""
+         if not self._ensure_tracing():
+             return None
+
+         try:
+             current, peak = tracemalloc.get_traced_memory()
+             megabyte = (1 << 20)
+             return (current / megabyte, peak / megabyte)
+         except RuntimeError:
+             return None
+
+     def _safe_snapshot(self) -> tracemalloc.Snapshot | None:
+         """Return a tracemalloc Snapshot if available, else None."""
+         if not self._ensure_tracing():
+             return None
+
+         try:
+             return tracemalloc.take_snapshot()
+         except RuntimeError:
+             return None
+
+     def log_memory_stats(self) -> dict[str, Any]:
+         """Log current memory statistics and return them."""
+         if not self.enabled:
+             return {}
+
+         # Force garbage collection first
+         gc.collect()
+
+         # Get current memory usage
+         mem = self._safe_traced_memory()
+         if mem is None:
+             logger.info("tracemalloc is not active; cannot collect memory stats.")
+             # still return structural fields
+             stats = {
+                 "request_count": self.request_count,
+                 "current_memory_mb": None,
+                 "peak_memory_mb": None,
+                 "active_intermediate_managers": self._safe_intermediate_step_manager_count(),
+                 "outstanding_steps": self._safe_outstanding_step_count(),
+                 "active_exporters": self._safe_exporter_count(),
+                 "isolated_exporters": self._safe_isolated_exporter_count(),
+                 "subject_instances": self._count_instances_of_type("Subject"),
+             }
+             return stats
+
+         current_mb, peak_mb = mem
+
+         # Take snapshot and compare to baseline
+         snapshot = self._safe_snapshot()
+
+         # Track BaseExporter instances (observability layer)
+         exporter_count = self._safe_exporter_count()
+         isolated_exporter_count = self._safe_isolated_exporter_count()
+
+         # Track Subject instances (event streams)
+         subject_count = self._count_instances_of_type("Subject")
+
+         stats = {
+             "request_count": self.request_count,
+             "current_memory_mb": round(current_mb, 2),
+             "peak_memory_mb": round(peak_mb, 2),
+             "active_intermediate_managers": self._safe_intermediate_step_manager_count(),
+             "outstanding_steps": self._safe_outstanding_step_count(),
+             "active_exporters": exporter_count,
+             "isolated_exporters": isolated_exporter_count,
+             "subject_instances": subject_count,
+         }
+
+         self._log("=" * 80)
+         self._log("MEMORY PROFILE AFTER %d REQUESTS:", self.request_count)
+         self._log("  Current Memory: %.2f MB", current_mb)
+         self._log("  Peak Memory: %.2f MB", peak_mb)
+         self._log("")
+         self._log("NAT COMPONENT INSTANCES:")
+         self._log("  IntermediateStepManagers: %d active (%d outstanding steps)",
+                   stats["active_intermediate_managers"],
+                   stats["outstanding_steps"])
+         self._log("  BaseExporters: %d active (%d isolated)", stats["active_exporters"], stats["isolated_exporters"])
+         self._log("  Subject (event streams): %d instances", stats["subject_instances"])
+
+         # Show top allocations
+         if snapshot is None:
+             self._log("tracemalloc snapshot unavailable.")
+         else:
+             if self.baseline_snapshot:
+                 self._log("TOP %d MEMORY GROWTH SINCE BASELINE:", self.top_n)
+                 top_stats = snapshot.compare_to(self.baseline_snapshot, 'lineno')
+             else:
+                 self._log("TOP %d MEMORY ALLOCATIONS:", self.top_n)
+                 top_stats = snapshot.statistics('lineno')
+
+             for i, stat in enumerate(top_stats[:self.top_n], 1):
+                 self._log("  #%d: %s", i, stat)
+
+         self._log("=" * 80)
+
+         return stats
+
+     def _count_instances_of_type(self, type_name: str) -> int:
+         """Count instances of a specific type in memory."""
+         count = 0
+         for obj in gc.get_objects():
+             try:
+                 if type(obj).__name__ == type_name:
+                     count += 1
+             except Exception:
+                 pass
+         return count
+
+     def _safe_exporter_count(self) -> int:
+         try:
+             from nat.observability.exporter.base_exporter import BaseExporter
+             return BaseExporter.get_active_instance_count()
+         except Exception as e:
+             logger.debug("Could not get BaseExporter stats: %s", e)
+             return 0
+
+     def _safe_isolated_exporter_count(self) -> int:
+         try:
+             from nat.observability.exporter.base_exporter import BaseExporter
+             return BaseExporter.get_isolated_instance_count()
+         except Exception:
+             return 0
+
+     def _safe_intermediate_step_manager_count(self) -> int:
+         try:
+             from nat.builder.intermediate_step_manager import IntermediateStepManager
+             # len() is atomic in CPython, but catch RuntimeError just in case
+             try:
+                 return IntermediateStepManager.get_active_instance_count()
+             except RuntimeError:
+                 # Set was modified during len() - very rare
+                 logger.debug("Set changed during count, returning 0")
+                 return 0
+         except Exception as e:
+             logger.debug("Could not get IntermediateStepManager stats: %s", e)
+             return 0
+
+     def _safe_outstanding_step_count(self) -> int:
+         """Get total outstanding steps across all active IntermediateStepManager instances."""
+         try:
+             from nat.builder.intermediate_step_manager import IntermediateStepManager
+
+             # Make a snapshot to avoid "Set changed size during iteration" if GC runs
+             try:
+                 instances_snapshot = list(IntermediateStepManager._active_instances)
+             except RuntimeError:
+                 # Set changed during list() call - rare but possible
+                 logger.debug("Set changed during snapshot, returning 0 for outstanding steps")
+                 return 0
+
+             total_outstanding = 0
+             # Iterate through snapshot safely
+             for ref in instances_snapshot:
+                 try:
+                     manager = ref()
+                     if manager is not None:
+                         total_outstanding += manager.get_outstanding_step_count()
+                 except (ReferenceError, AttributeError):
+                     # Manager was GC'd or in invalid state - skip it
+                     continue
+             return total_outstanding
+         except Exception as e:
+             logger.debug("Could not get outstanding step count: %s", e)
+             return 0
+
+     def get_stats(self) -> dict[str, Any]:
+         """Get current memory statistics without logging."""
+         if not self.enabled:
+             return {"enabled": False}
+
+         mem = self._safe_traced_memory()
+         if mem is None:
+             return {
+                 "enabled": True,
+                 "request_count": self.request_count,
+                 "current_memory_mb": None,
+                 "peak_memory_mb": None,
+                 "active_intermediate_managers": self._safe_intermediate_step_manager_count(),
+                 "outstanding_steps": self._safe_outstanding_step_count(),
+                 "active_exporters": self._safe_exporter_count(),
+                 "isolated_exporters": self._safe_isolated_exporter_count(),
+                 "subject_instances": self._count_instances_of_type("Subject"),
+             }
+
+         current_mb, peak_mb = mem
+         return {
+             "enabled": True,
+             "request_count": self.request_count,
+             "current_memory_mb": round(current_mb, 2),
+             "peak_memory_mb": round(peak_mb, 2),
+             "active_intermediate_managers": self._safe_intermediate_step_manager_count(),
+             "outstanding_steps": self._safe_outstanding_step_count(),
+             "active_exporters": self._safe_exporter_count(),
+             "isolated_exporters": self._safe_isolated_exporter_count(),
+             "subject_instances": self._count_instances_of_type("Subject"),
+         }
+
+     def reset_baseline(self) -> None:
+         """Reset the baseline snapshot to current state."""
+         if not self.enabled:
+             return
+         gc.collect()
+         snap = self._safe_snapshot()
+         if snap is None:
+             logger.info("Cannot reset baseline: tracemalloc is not active.")
+             return
+         self.baseline_snapshot = snap
+         logger.info("Memory profiling baseline reset at request %d", self.request_count)
nat/front_ends/mcp/tool_converter.py CHANGED
@@ -28,6 +28,7 @@ from nat.builder.function_base import FunctionBase
 
  if TYPE_CHECKING:
      from nat.builder.workflow import Workflow
+     from nat.front_ends.mcp.memory_profiler import MemoryProfiler
 
  logger = logging.getLogger(__name__)
 
@@ -38,6 +39,7 @@ def create_function_wrapper(
          schema: type[BaseModel],
          is_workflow: bool = False,
          workflow: 'Workflow | None' = None,
+         memory_profiler: 'MemoryProfiler | None' = None,
  ):
      """Create a wrapper function that exposes the actual parameters of a NAT Function as an MCP tool.
 
@@ -47,6 +49,7 @@ def create_function_wrapper(
          schema (type[BaseModel]): The input schema of the function
          is_workflow (bool): Whether the function is a Workflow
          workflow (Workflow | None): The parent workflow for observability context
+         memory_profiler: Optional memory profiler to track requests
 
      Returns:
          A wrapper function suitable for registration with MCP
@@ -172,6 +175,10 @@ def create_function_wrapper(
              if ctx:
                  await ctx.report_progress(100, 100)
 
+             # Track request completion for memory profiling
+             if memory_profiler:
+                 memory_profiler.on_request_complete()
+
              # Handle different result types for proper formatting
              if isinstance(result, str):
                  return result
@@ -181,6 +188,11 @@ def create_function_wrapper(
          except Exception as e:
              if ctx:
                  ctx.error("Error calling function %s: %s", function_name, str(e))
+
+             # Track request completion even on error
+             if memory_profiler:
+                 memory_profiler.on_request_complete()
+
              raise
 
      return wrapper_with_ctx
@@ -242,7 +254,8 @@ def get_function_description(function: FunctionBase) -> str:
  def register_function_with_mcp(mcp: FastMCP,
                                 function_name: str,
                                 function: FunctionBase,
-                                workflow: 'Workflow | None' = None) -> None:
+                                workflow: 'Workflow | None' = None,
+                                memory_profiler: 'MemoryProfiler | None' = None) -> None:
      """Register a NAT Function as an MCP tool.
 
      Args:
@@ -250,6 +263,7 @@ def register_function_with_mcp(mcp: FastMCP,
          function_name: The name to register the function under
          function: The NAT Function to register
          workflow: The parent workflow for observability context (if available)
+         memory_profiler: Optional memory profiler to track requests
      """
      logger.info("Registering function %s with MCP", function_name)
 
@@ -267,5 +281,10 @@ def register_function_with_mcp(mcp: FastMCP,
      function_description = get_function_description(function)
 
      # Create and register the wrapper function with MCP
-     wrapper_func = create_function_wrapper(function_name, function, input_schema, is_workflow, workflow)
+     wrapper_func = create_function_wrapper(function_name,
+                                            function,
+                                            input_schema,
+                                            is_workflow,
+                                            workflow,
+                                            memory_profiler)
      mcp.tool(name=function_name, description=function_description)(wrapper_func)
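A sketch of how the new memory_profiler parameter can be threaded through registration (illustrative; mcp_server is assumed to be a FastMCP instance and my_function an already-built NAT FunctionBase, both constructed by code outside this diff):

    from nat.front_ends.mcp.memory_profiler import MemoryProfiler
    from nat.front_ends.mcp.tool_converter import register_function_with_mcp

    profiler = MemoryProfiler(enabled=True, log_interval=50)

    # mcp_server and my_function are assumed to exist already (see lead-in above)
    register_function_with_mcp(mcp_server,
                               "my_tool",
                               my_function,
                               workflow=None,
                               memory_profiler=profiler)
    # Each invocation of the generated MCP tool now calls profiler.on_request_complete(),
    # on success and on error alike, so periodic memory reports reflect real traffic.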
nat/observability/register.py CHANGED
@@ -77,6 +77,14 @@ async def console_logging_method(config: ConsoleLoggingMethodConfig, builder: Builder):
      level = getattr(logging, config.level.upper(), logging.INFO)
      handler = logging.StreamHandler(stream=sys.stdout)
      handler.setLevel(level)
+
+     # Set formatter to match the default CLI format
+     formatter = logging.Formatter(
+         fmt="%(asctime)s - %(levelname)-8s - %(name)s:%(lineno)d - %(message)s",
+         datefmt="%Y-%m-%d %H:%M:%S",
+     )
+     handler.setFormatter(formatter)
+
      yield handler
 
 
@@ -95,4 +103,12 @@ async def file_logging_method(config: FileLoggingMethod, builder: Builder):
      level = getattr(logging, config.level.upper(), logging.INFO)
      handler = logging.FileHandler(filename=config.path, mode="a", encoding="utf-8")
      handler.setLevel(level)
+
+     # Set formatter to match the default CLI format
+     formatter = logging.Formatter(
+         fmt="%(asctime)s - %(levelname)-8s - %(name)s:%(lineno)d - %(message)s",
+         datefmt="%Y-%m-%d %H:%M:%S",
+     )
+     handler.setFormatter(formatter)
+
      yield handler
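Both logging methods now attach the same formatter, so console and file output match the default CLI layout. A standalone sketch of the resulting line format (outside NAT, using only the standard library):

    import logging
    import sys

    handler = logging.StreamHandler(stream=sys.stdout)
    handler.setFormatter(
        logging.Formatter(fmt="%(asctime)s - %(levelname)-8s - %(name)s:%(lineno)d - %(message)s",
                          datefmt="%Y-%m-%d %H:%M:%S"))
    logging.getLogger("demo").addHandler(handler)
    logging.getLogger("demo").warning("hello")
    # Example output (lineno is that of the warning() call):
    # 2025-10-11 12:00:00 - WARNING  - demo:10 - hello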
nat/runtime/runner.py CHANGED
@@ -196,8 +196,7 @@ class Runner:
 
                  return result
              except Exception as e:
-                 err_msg = f": {e}" if str(e).strip() else "."
-                 logger.error("Error running workflow%s", err_msg)
+                 logger.error("Error running workflow: %s", e)
                  event_stream = self._context_state.event_stream.get()
                  if event_stream:
                      event_stream.on_complete()
nat/runtime/session.py CHANGED
@@ -192,7 +192,7 @@ class SessionManager:
                                    user_message_id: str | None,
                                    conversation_id: str | None) -> None:
          """
-         Extracts and sets user metadata for Websocket connections.
+         Extracts and sets user metadata for WebSocket connections.
          """
 
          # Extract cookies from WebSocket headers (similar to HTTP request)
nat/tool/memory_tools/add_memory_tool.py CHANGED
@@ -30,10 +30,10 @@ logger = logging.getLogger(__name__)
  class AddToolConfig(FunctionBaseConfig, name="add_memory"):
      """Function to add memory to a hosted memory platform."""
 
-     description: str = Field(default=("Tool to add memory about a user's interactions to a system "
+     description: str = Field(default=("Tool to add a memory about a user's interactions to a system "
                                        "for retrieval later."),
                               description="The description of this function's use for tool calling agents.")
-     memory: MemoryRef = Field(default="saas_memory",
+     memory: MemoryRef = Field(default=MemoryRef("saas_memory"),
                                description=("Instance name of the memory client instance from the workflow "
                                             "configuration object."))
 
@@ -46,7 +46,7 @@ async def add_memory_tool(config: AddToolConfig, builder: Builder):
      from langchain_core.tools import ToolException
 
      # First, retrieve the memory client
-     memory_editor = builder.get_memory_client(config.memory)
+     memory_editor = await builder.get_memory_client(config.memory)
 
      async def _arun(item: MemoryItem) -> str:
          """
nat/tool/memory_tools/delete_memory_tool.py CHANGED
@@ -30,10 +30,9 @@ logger = logging.getLogger(__name__)
  class DeleteToolConfig(FunctionBaseConfig, name="delete_memory"):
      """Function to delete memory from a hosted memory platform."""
 
-     description: str = Field(default=("Tool to retrieve memory about a user's "
-                                       "interactions to help answer questions in a personalized way."),
+     description: str = Field(default="Tool to delete a memory from a hosted memory platform.",
                               description="The description of this function's use for tool calling agents.")
-     memory: MemoryRef = Field(default="saas_memory",
+     memory: MemoryRef = Field(default=MemoryRef("saas_memory"),
                                description=("Instance name of the memory client instance from the workflow "
                                             "configuration object."))
 
@@ -47,7 +46,7 @@ async def delete_memory_tool(config: DeleteToolConfig, builder: Builder):
      from langchain_core.tools import ToolException
 
      # First, retrieve the memory client
-     memory_editor = builder.get_memory_client(config.memory)
+     memory_editor = await builder.get_memory_client(config.memory)
 
      async def _arun(user_id: str) -> str:
          """
nat/tool/memory_tools/get_memory_tool.py CHANGED
@@ -30,10 +30,10 @@ logger = logging.getLogger(__name__)
  class GetToolConfig(FunctionBaseConfig, name="get_memory"):
      """Function to get memory to a hosted memory platform."""
 
-     description: str = Field(default=("Tool to retrieve memory about a user's "
+     description: str = Field(default=("Tool to retrieve a memory about a user's "
                                        "interactions to help answer questions in a personalized way."),
                               description="The description of this function's use for tool calling agents.")
-     memory: MemoryRef = Field(default="saas_memory",
+     memory: MemoryRef = Field(default=MemoryRef("saas_memory"),
                                description=("Instance name of the memory client instance from the workflow "
                                             "configuration object."))
 
@@ -49,7 +49,7 @@ async def get_memory_tool(config: GetToolConfig, builder: Builder):
      from langchain_core.tools import ToolException
 
      # First, retrieve the memory client
-     memory_editor = builder.get_memory_client(config.memory)
+     memory_editor = await builder.get_memory_client(config.memory)
 
      async def _arun(search_input: SearchMemoryInput) -> str:
          """
nat/utils/type_converter.py CHANGED
@@ -93,6 +93,14 @@ class TypeConverter:
          if to_type is None or decomposed.is_instance(data):
              return data
 
+         # 2) If data is a union type, try to convert to each type in the union
+         if decomposed.is_union:
+             for union_type in decomposed.args:
+                 result = self._convert(data, union_type)
+                 if result is not None:
+                     return result
+             return None
+
          root = decomposed.root
 
          # 2) Attempt direct in *this* converter
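The new branch handles a union target by attempting each member type in order and returning the first non-None conversion. A standalone sketch of the same idea built only on the typing module (this illustrates the technique, not the TypeConverter API itself; the naive member(data) call stands in for the converter's per-type logic):

    import typing


    def convert_to_union(data, to_type):
        # Try each member of the union until one conversion succeeds
        for member in typing.get_args(to_type):
            if member is type(None):
                continue  # nothing to convert for NoneType
            try:
                return member(data)  # naive per-member conversion, for the sketch only
            except (TypeError, ValueError):
                continue
        return None


    print(convert_to_union("5", int | str))      # -> 5 (int succeeds first)
    print(convert_to_union("abc", int | float))  # -> None (no member accepts the value)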