nvidia-nat 1.3.0rc1__py3-none-any.whl → 1.3.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. nat/agent/prompt_optimizer/register.py +2 -2
  2. nat/agent/react_agent/register.py +20 -21
  3. nat/agent/rewoo_agent/register.py +18 -20
  4. nat/agent/tool_calling_agent/register.py +7 -3
  5. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +31 -18
  6. nat/builder/component_utils.py +1 -1
  7. nat/builder/context.py +22 -6
  8. nat/builder/function.py +3 -2
  9. nat/builder/workflow_builder.py +46 -3
  10. nat/cli/commands/mcp/mcp.py +6 -6
  11. nat/cli/commands/workflow/templates/config.yml.j2 +14 -12
  12. nat/cli/commands/workflow/templates/register.py.j2 +2 -2
  13. nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
  14. nat/cli/commands/workflow/workflow_commands.py +54 -10
  15. nat/cli/entrypoint.py +9 -1
  16. nat/cli/main.py +3 -0
  17. nat/data_models/api_server.py +143 -66
  18. nat/data_models/config.py +1 -1
  19. nat/data_models/span.py +41 -3
  20. nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
  21. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +2 -2
  22. nat/front_ends/console/console_front_end_plugin.py +11 -2
  23. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
  24. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +5 -35
  25. nat/front_ends/fastapi/message_validator.py +3 -1
  26. nat/observability/exporter/span_exporter.py +34 -14
  27. nat/observability/register.py +16 -0
  28. nat/profiler/decorators/framework_wrapper.py +1 -1
  29. nat/profiler/forecasting/models/linear_model.py +1 -1
  30. nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
  31. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
  32. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +1 -1
  33. nat/runtime/runner.py +103 -6
  34. nat/runtime/session.py +27 -1
  35. nat/tool/memory_tools/add_memory_tool.py +3 -3
  36. nat/tool/memory_tools/delete_memory_tool.py +3 -4
  37. nat/tool/memory_tools/get_memory_tool.py +4 -4
  38. nat/utils/decorators.py +210 -0
  39. nat/utils/type_converter.py +8 -0
  40. nvidia_nat-1.3.0rc3.dist-info/METADATA +195 -0
  41. {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc3.dist-info}/RECORD +46 -45
  42. nvidia_nat-1.3.0rc1.dist-info/METADATA +0 -391
  43. {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc3.dist-info}/WHEEL +0 -0
  44. {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc3.dist-info}/entry_points.txt +0 -0
  45. {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc3.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  46. {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc3.dist-info}/licenses/LICENSE.md +0 -0
  47. {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc3.dist-info}/top_level.txt +0 -0
@@ -689,10 +689,13 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
689
689
 
690
690
  async def post_openai_api_compatible(response: Response, request: Request, payload: request_type):
691
691
  # Check if streaming is requested
692
+
693
+ response.headers["Content-Type"] = "application/json"
692
694
  stream_requested = getattr(payload, 'stream', False)
693
695
 
694
696
  async with session_manager.session(http_connection=request):
695
697
  if stream_requested:
698
+
696
699
  # Return streaming response
697
700
  return StreamingResponse(headers={"Content-Type": "text/event-stream; charset=utf-8"},
698
701
  content=generate_streaming_response_as_str(
@@ -703,40 +706,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
703
706
  result_type=ChatResponseChunk,
704
707
  output_type=ChatResponseChunk))
705
708
 
706
- # Return single response - check if workflow supports non-streaming
707
- try:
708
- response.headers["Content-Type"] = "application/json"
709
- return await generate_single_response(payload, session_manager, result_type=ChatResponse)
710
- except ValueError as e:
711
- if "Cannot get a single output value for streaming workflows" in str(e):
712
- # Workflow only supports streaming, but client requested non-streaming
713
- # Fall back to streaming and collect the result
714
- chunks = []
715
- async for chunk_str in generate_streaming_response_as_str(
716
- payload,
717
- session_manager=session_manager,
718
- streaming=True,
719
- step_adaptor=self.get_step_adaptor(),
720
- result_type=ChatResponseChunk,
721
- output_type=ChatResponseChunk):
722
- if chunk_str.startswith("data: ") and not chunk_str.startswith("data: [DONE]"):
723
- chunk_data = chunk_str[6:].strip() # Remove "data: " prefix
724
- if chunk_data:
725
- try:
726
- chunk_json = ChatResponseChunk.model_validate_json(chunk_data)
727
- if (chunk_json.choices and len(chunk_json.choices) > 0
728
- and chunk_json.choices[0].delta
729
- and chunk_json.choices[0].delta.content is not None):
730
- chunks.append(chunk_json.choices[0].delta.content)
731
- except Exception:
732
- continue
733
-
734
- # Create a single response from collected chunks
735
- content = "".join(chunks)
736
- single_response = ChatResponse.from_string(content)
737
- response.headers["Content-Type"] = "application/json"
738
- return single_response
739
- raise
709
+ return await generate_single_response(payload, session_manager, result_type=ChatResponse)
740
710
 
741
711
  return post_openai_api_compatible
742
712
 
@@ -1128,7 +1098,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
1128
1098
  if configured_group.config.type != "mcp_client":
1129
1099
  continue
1130
1100
 
1131
- from nat.plugins.mcp.client_impl import MCPClientConfig
1101
+ from nat.plugins.mcp.client_config import MCPClientConfig
1132
1102
 
1133
1103
  config = configured_group.config
1134
1104
  assert isinstance(config, MCPClientConfig)
@@ -139,8 +139,10 @@ class MessageValidator:
139
139
  text_content: str = str(data_model.payload)
140
140
  validated_message_content = SystemResponseContent(text=text_content)
141
141
 
142
- elif (isinstance(data_model, ChatResponse | ChatResponseChunk)):
142
+ elif isinstance(data_model, ChatResponse):
143
143
  validated_message_content = SystemResponseContent(text=data_model.choices[0].message.content)
144
+ elif isinstance(data_model, ChatResponseChunk):
145
+ validated_message_content = SystemResponseContent(text=data_model.choices[0].delta.content)
144
146
 
145
147
  elif (isinstance(data_model, ResponseIntermediateStep)):
146
148
  validated_message_content = SystemIntermediateStepContent(name=data_model.name,
@@ -126,6 +126,7 @@ class SpanExporter(ProcessingExporter[InputSpanT, OutputSpanT], SerializeMixin):
126
126
 
127
127
  parent_span = None
128
128
  span_ctx = None
129
+ workflow_trace_id = self._context_state.workflow_trace_id.get()
129
130
 
130
131
  # Look up the parent span to establish hierarchy
131
132
  # event.parent_id is the UUID of the last START step with a different UUID from current step
@@ -141,6 +142,9 @@ class SpanExporter(ProcessingExporter[InputSpanT, OutputSpanT], SerializeMixin):
141
142
  parent_span = parent_span.model_copy() if isinstance(parent_span, Span) else None
142
143
  if parent_span and parent_span.context:
143
144
  span_ctx = SpanContext(trace_id=parent_span.context.trace_id)
145
+ # No parent: adopt workflow trace id if available to keep all spans in the same trace
146
+ if span_ctx is None and workflow_trace_id:
147
+ span_ctx = SpanContext(trace_id=workflow_trace_id)
144
148
 
145
149
  # Extract start/end times from the step
146
150
  # By convention, `span_event_timestamp` is the time we started, `event_timestamp` is the time we ended.
@@ -154,23 +158,39 @@ class SpanExporter(ProcessingExporter[InputSpanT, OutputSpanT], SerializeMixin):
154
158
  else:
155
159
  sub_span_name = f"{event.payload.event_type}"
156
160
 
161
+ # Prefer parent/context trace id for attribute, else workflow trace id
162
+ _attr_trace_id = None
163
+ if span_ctx is not None:
164
+ _attr_trace_id = span_ctx.trace_id
165
+ elif parent_span and parent_span.context:
166
+ _attr_trace_id = parent_span.context.trace_id
167
+ elif workflow_trace_id:
168
+ _attr_trace_id = workflow_trace_id
169
+
170
+ attributes = {
171
+ f"{self._span_prefix}.event_type":
172
+ event.payload.event_type.value,
173
+ f"{self._span_prefix}.function.id":
174
+ event.function_ancestry.function_id if event.function_ancestry else "unknown",
175
+ f"{self._span_prefix}.function.name":
176
+ event.function_ancestry.function_name if event.function_ancestry else "unknown",
177
+ f"{self._span_prefix}.subspan.name":
178
+ event.payload.name or "",
179
+ f"{self._span_prefix}.event_timestamp":
180
+ event.event_timestamp,
181
+ f"{self._span_prefix}.framework":
182
+ event.payload.framework.value if event.payload.framework else "unknown",
183
+ f"{self._span_prefix}.conversation.id":
184
+ self._context_state.conversation_id.get() or "unknown",
185
+ f"{self._span_prefix}.workflow.run_id":
186
+ self._context_state.workflow_run_id.get() or "unknown",
187
+ f"{self._span_prefix}.workflow.trace_id": (f"{_attr_trace_id:032x}" if _attr_trace_id else "unknown"),
188
+ }
189
+
157
190
  sub_span = Span(name=sub_span_name,
158
191
  parent=parent_span,
159
192
  context=span_ctx,
160
- attributes={
161
- f"{self._span_prefix}.event_type":
162
- event.payload.event_type.value,
163
- f"{self._span_prefix}.function.id":
164
- event.function_ancestry.function_id if event.function_ancestry else "unknown",
165
- f"{self._span_prefix}.function.name":
166
- event.function_ancestry.function_name if event.function_ancestry else "unknown",
167
- f"{self._span_prefix}.subspan.name":
168
- event.payload.name or "",
169
- f"{self._span_prefix}.event_timestamp":
170
- event.event_timestamp,
171
- f"{self._span_prefix}.framework":
172
- event.payload.framework.value if event.payload.framework else "unknown",
173
- },
193
+ attributes=attributes,
174
194
  start_time=start_ns)
175
195
 
176
196
  span_kind = event_type_to_span_kind(event.event_type)
@@ -77,6 +77,14 @@ async def console_logging_method(config: ConsoleLoggingMethodConfig, builder: Bu
77
77
  level = getattr(logging, config.level.upper(), logging.INFO)
78
78
  handler = logging.StreamHandler(stream=sys.stdout)
79
79
  handler.setLevel(level)
80
+
81
+ # Set formatter to match the default CLI format
82
+ formatter = logging.Formatter(
83
+ fmt="%(asctime)s - %(levelname)-8s - %(name)s:%(lineno)d - %(message)s",
84
+ datefmt="%Y-%m-%d %H:%M:%S",
85
+ )
86
+ handler.setFormatter(formatter)
87
+
80
88
  yield handler
81
89
 
82
90
 
@@ -95,4 +103,12 @@ async def file_logging_method(config: FileLoggingMethod, builder: Builder):
95
103
  level = getattr(logging, config.level.upper(), logging.INFO)
96
104
  handler = logging.FileHandler(filename=config.path, mode="a", encoding="utf-8")
97
105
  handler.setLevel(level)
106
+
107
+ # Set formatter to match the default CLI format
108
+ formatter = logging.Formatter(
109
+ fmt="%(asctime)s - %(levelname)-8s - %(name)s:%(lineno)d - %(message)s",
110
+ datefmt="%Y-%m-%d %H:%M:%S",
111
+ )
112
+ handler.setFormatter(formatter)
113
+
98
114
  yield handler
@@ -123,7 +123,7 @@ def set_framework_profiler_handler(
123
123
  except ImportError as e:
124
124
  logger.warning(
125
125
  "ADK profiler not available. " +
126
- "Install NAT with ADK extras: pip install 'nvidia-nat[adk]'. Error: %s",
126
+ "Install NAT with ADK extras: pip install \"nvidia-nat[adk]\". Error: %s",
127
127
  e)
128
128
  else:
129
129
  handler = ADKProfilerHandler()
@@ -36,7 +36,7 @@ class LinearModel(ForecastingBaseModel):
36
36
  except ImportError:
37
37
  logger.error(
38
38
  "scikit-learn is not installed. Please install scikit-learn to use the LinearModel "
39
- "profiling model or install `nvidia-nat[profiler]` to install all necessary profiling packages.")
39
+ "profiling model or install \"nvidia-nat[profiler]\" to install all necessary profiling packages.")
40
40
 
41
41
  raise
42
42
 
@@ -36,7 +36,7 @@ class RandomForestModel(ForecastingBaseModel):
36
36
  except ImportError:
37
37
  logger.error(
38
38
  "scikit-learn is not installed. Please install scikit-learn to use the RandomForest "
39
- "profiling model or install `nvidia-nat[profiler]` to install all necessary profiling packages.")
39
+ "profiling model or install \"nvidia-nat[profiler]\" to install all necessary profiling packages.")
40
40
 
41
41
  raise
42
42
 
@@ -304,7 +304,7 @@ def save_gantt_chart(all_nodes: list[CallNode], output_path: str) -> None:
304
304
  import matplotlib.pyplot as plt
305
305
  except ImportError:
306
306
  logger.error("matplotlib is not installed. Please install matplotlib to use generate plots for the profiler "
307
- "or install `nvidia-nat[profiler]` to install all necessary profiling packages.")
307
+ "or install \"nvidia-nat[profiler]\" to install all necessary profiling packages.")
308
308
 
309
309
  raise
310
310
 
@@ -212,7 +212,7 @@ def run_prefixspan(sequences_map: dict[int, list[PrefixCallNode]],
212
212
  from prefixspan import PrefixSpan
213
213
  except ImportError:
214
214
  logger.error("prefixspan is not installed. Please install prefixspan to run the prefix analysis in the "
215
- "profiler or install `nvidia-nat[profiler]` to install all necessary profiling packages.")
215
+ "profiler or install \"nvidia-nat[profiler]\" to install all necessary profiling packages.")
216
216
 
217
217
  raise
218
218
 
nat/runtime/runner.py CHANGED
@@ -15,11 +15,16 @@
15
15
 
16
16
  import logging
17
17
  import typing
18
+ import uuid
18
19
  from enum import Enum
19
20
 
20
21
  from nat.builder.context import Context
21
22
  from nat.builder.context import ContextState
22
23
  from nat.builder.function import Function
24
+ from nat.data_models.intermediate_step import IntermediateStepPayload
25
+ from nat.data_models.intermediate_step import IntermediateStepType
26
+ from nat.data_models.intermediate_step import StreamEventData
27
+ from nat.data_models.intermediate_step import TraceMetadata
23
28
  from nat.data_models.invocation_node import InvocationNode
24
29
  from nat.observability.exporter_manager import ExporterManager
25
30
  from nat.utils.reactive.subject import Subject
@@ -130,17 +135,59 @@ class Runner:
130
135
  if (self._state != RunnerState.INITIALIZED):
131
136
  raise ValueError("Cannot run the workflow without entering the context")
132
137
 
138
+ token_run_id = None
139
+ token_trace_id = None
133
140
  try:
134
141
  self._state = RunnerState.RUNNING
135
142
 
136
143
  if (not self._entry_fn.has_single_output):
137
144
  raise ValueError("Workflow does not support single output")
138
145
 
146
+ # Establish workflow run and trace identifiers
147
+ existing_run_id = self._context_state.workflow_run_id.get()
148
+ existing_trace_id = self._context_state.workflow_trace_id.get()
149
+
150
+ workflow_run_id = existing_run_id or str(uuid.uuid4())
151
+
152
+ workflow_trace_id = existing_trace_id or uuid.uuid4().int
153
+
154
+ token_run_id = self._context_state.workflow_run_id.set(workflow_run_id)
155
+ token_trace_id = self._context_state.workflow_trace_id.set(workflow_trace_id)
156
+
157
+ # Prepare workflow-level intermediate step identifiers
158
+ workflow_step_uuid = str(uuid.uuid4())
159
+ workflow_name = getattr(self._entry_fn, 'instance_name', None) or "workflow"
160
+
139
161
  async with self._exporter_manager.start(context_state=self._context_state):
140
- # Run the workflow
141
- result = await self._entry_fn.ainvoke(self._input_message, to_type=to_type)
162
+ # Emit WORKFLOW_START
163
+ start_metadata = TraceMetadata(
164
+ provided_metadata={
165
+ "workflow_run_id": workflow_run_id,
166
+ "workflow_trace_id": f"{workflow_trace_id:032x}",
167
+ "conversation_id": self._context_state.conversation_id.get(),
168
+ })
169
+ self._context.intermediate_step_manager.push_intermediate_step(
170
+ IntermediateStepPayload(UUID=workflow_step_uuid,
171
+ event_type=IntermediateStepType.WORKFLOW_START,
172
+ name=workflow_name,
173
+ metadata=start_metadata))
174
+
175
+ result = await self._entry_fn.ainvoke(self._input_message, to_type=to_type) # type: ignore
176
+
177
+ # Emit WORKFLOW_END with output
178
+ end_metadata = TraceMetadata(
179
+ provided_metadata={
180
+ "workflow_run_id": workflow_run_id,
181
+ "workflow_trace_id": f"{workflow_trace_id:032x}",
182
+ "conversation_id": self._context_state.conversation_id.get(),
183
+ })
184
+ self._context.intermediate_step_manager.push_intermediate_step(
185
+ IntermediateStepPayload(UUID=workflow_step_uuid,
186
+ event_type=IntermediateStepType.WORKFLOW_END,
187
+ name=workflow_name,
188
+ metadata=end_metadata,
189
+ data=StreamEventData(output=result)))
142
190
 
143
- # Close the intermediate stream
144
191
  event_stream = self._context_state.event_stream.get()
145
192
  if event_stream:
146
193
  event_stream.on_complete()
@@ -155,25 +202,71 @@ class Runner:
155
202
  if event_stream:
156
203
  event_stream.on_complete()
157
204
  self._state = RunnerState.FAILED
158
-
159
205
  raise
206
+ finally:
207
+ if token_run_id is not None:
208
+ self._context_state.workflow_run_id.reset(token_run_id)
209
+ if token_trace_id is not None:
210
+ self._context_state.workflow_trace_id.reset(token_trace_id)
160
211
 
161
212
  async def result_stream(self, to_type: type | None = None):
162
213
 
163
214
  if (self._state != RunnerState.INITIALIZED):
164
215
  raise ValueError("Cannot run the workflow without entering the context")
165
216
 
217
+ token_run_id = None
218
+ token_trace_id = None
166
219
  try:
167
220
  self._state = RunnerState.RUNNING
168
221
 
169
222
  if (not self._entry_fn.has_streaming_output):
170
223
  raise ValueError("Workflow does not support streaming output")
171
224
 
225
+ # Establish workflow run and trace identifiers
226
+ existing_run_id = self._context_state.workflow_run_id.get()
227
+ existing_trace_id = self._context_state.workflow_trace_id.get()
228
+
229
+ workflow_run_id = existing_run_id or str(uuid.uuid4())
230
+
231
+ workflow_trace_id = existing_trace_id or uuid.uuid4().int
232
+
233
+ token_run_id = self._context_state.workflow_run_id.set(workflow_run_id)
234
+ token_trace_id = self._context_state.workflow_trace_id.set(workflow_trace_id)
235
+
236
+ # Prepare workflow-level intermediate step identifiers
237
+ workflow_step_uuid = str(uuid.uuid4())
238
+ workflow_name = getattr(self._entry_fn, 'instance_name', None) or "workflow"
239
+
172
240
  # Run the workflow
173
241
  async with self._exporter_manager.start(context_state=self._context_state):
174
- async for m in self._entry_fn.astream(self._input_message, to_type=to_type):
242
+ # Emit WORKFLOW_START
243
+ start_metadata = TraceMetadata(
244
+ provided_metadata={
245
+ "workflow_run_id": workflow_run_id,
246
+ "workflow_trace_id": f"{workflow_trace_id:032x}",
247
+ "conversation_id": self._context_state.conversation_id.get(),
248
+ })
249
+ self._context.intermediate_step_manager.push_intermediate_step(
250
+ IntermediateStepPayload(UUID=workflow_step_uuid,
251
+ event_type=IntermediateStepType.WORKFLOW_START,
252
+ name=workflow_name,
253
+ metadata=start_metadata))
254
+
255
+ async for m in self._entry_fn.astream(self._input_message, to_type=to_type): # type: ignore
175
256
  yield m
176
257
 
258
+ # Emit WORKFLOW_END
259
+ end_metadata = TraceMetadata(
260
+ provided_metadata={
261
+ "workflow_run_id": workflow_run_id,
262
+ "workflow_trace_id": f"{workflow_trace_id:032x}",
263
+ "conversation_id": self._context_state.conversation_id.get(),
264
+ })
265
+ self._context.intermediate_step_manager.push_intermediate_step(
266
+ IntermediateStepPayload(UUID=workflow_step_uuid,
267
+ event_type=IntermediateStepType.WORKFLOW_END,
268
+ name=workflow_name,
269
+ metadata=end_metadata))
177
270
  self._state = RunnerState.COMPLETED
178
271
 
179
272
  # Close the intermediate stream
@@ -187,8 +280,12 @@ class Runner:
187
280
  if event_stream:
188
281
  event_stream.on_complete()
189
282
  self._state = RunnerState.FAILED
190
-
191
283
  raise
284
+ finally:
285
+ if token_run_id is not None:
286
+ self._context_state.workflow_run_id.reset(token_run_id)
287
+ if token_trace_id is not None:
288
+ self._context_state.workflow_trace_id.reset(token_trace_id)
192
289
 
193
290
 
194
291
  # Compatibility aliases with previous releases
nat/runtime/session.py CHANGED
@@ -16,6 +16,7 @@
16
16
  import asyncio
17
17
  import contextvars
18
18
  import typing
19
+ import uuid
19
20
  from collections.abc import Awaitable
20
21
  from collections.abc import Callable
21
22
  from contextlib import asynccontextmanager
@@ -161,12 +162,37 @@ class SessionManager:
161
162
  if request.headers.get("user-message-id"):
162
163
  self._context_state.user_message_id.set(request.headers["user-message-id"])
163
164
 
165
+ # W3C Trace Context header: traceparent: 00-<trace-id>-<span-id>-<flags>
166
+ traceparent = request.headers.get("traceparent")
167
+ if traceparent:
168
+ try:
169
+ parts = traceparent.split("-")
170
+ if len(parts) >= 4:
171
+ trace_id_hex = parts[1]
172
+ if len(trace_id_hex) == 32:
173
+ trace_id_int = uuid.UUID(trace_id_hex).int
174
+ self._context_state.workflow_trace_id.set(trace_id_int)
175
+ except Exception:
176
+ pass
177
+
178
+ if not self._context_state.workflow_trace_id.get():
179
+ workflow_trace_id = request.headers.get("workflow-trace-id")
180
+ if workflow_trace_id:
181
+ try:
182
+ self._context_state.workflow_trace_id.set(uuid.UUID(workflow_trace_id).int)
183
+ except Exception:
184
+ pass
185
+
186
+ workflow_run_id = request.headers.get("workflow-run-id")
187
+ if workflow_run_id:
188
+ self._context_state.workflow_run_id.set(workflow_run_id)
189
+
164
190
  def set_metadata_from_websocket(self,
165
191
  websocket: WebSocket,
166
192
  user_message_id: str | None,
167
193
  conversation_id: str | None) -> None:
168
194
  """
169
- Extracts and sets user metadata for Websocket connections.
195
+ Extracts and sets user metadata for WebSocket connections.
170
196
  """
171
197
 
172
198
  # Extract cookies from WebSocket headers (similar to HTTP request)
@@ -30,10 +30,10 @@ logger = logging.getLogger(__name__)
30
30
  class AddToolConfig(FunctionBaseConfig, name="add_memory"):
31
31
  """Function to add memory to a hosted memory platform."""
32
32
 
33
- description: str = Field(default=("Tool to add memory about a user's interactions to a system "
33
+ description: str = Field(default=("Tool to add a memory about a user's interactions to a system "
34
34
  "for retrieval later."),
35
35
  description="The description of this function's use for tool calling agents.")
36
- memory: MemoryRef = Field(default="saas_memory",
36
+ memory: MemoryRef = Field(default=MemoryRef("saas_memory"),
37
37
  description=("Instance name of the memory client instance from the workflow "
38
38
  "configuration object."))
39
39
 
@@ -46,7 +46,7 @@ async def add_memory_tool(config: AddToolConfig, builder: Builder):
46
46
  from langchain_core.tools import ToolException
47
47
 
48
48
  # First, retrieve the memory client
49
- memory_editor = builder.get_memory_client(config.memory)
49
+ memory_editor = await builder.get_memory_client(config.memory)
50
50
 
51
51
  async def _arun(item: MemoryItem) -> str:
52
52
  """
@@ -30,10 +30,9 @@ logger = logging.getLogger(__name__)
30
30
  class DeleteToolConfig(FunctionBaseConfig, name="delete_memory"):
31
31
  """Function to delete memory from a hosted memory platform."""
32
32
 
33
- description: str = Field(default=("Tool to retrieve memory about a user's "
34
- "interactions to help answer questions in a personalized way."),
33
+ description: str = Field(default="Tool to delete a memory from a hosted memory platform.",
35
34
  description="The description of this function's use for tool calling agents.")
36
- memory: MemoryRef = Field(default="saas_memory",
35
+ memory: MemoryRef = Field(default=MemoryRef("saas_memory"),
37
36
  description=("Instance name of the memory client instance from the workflow "
38
37
  "configuration object."))
39
38
 
@@ -47,7 +46,7 @@ async def delete_memory_tool(config: DeleteToolConfig, builder: Builder):
47
46
  from langchain_core.tools import ToolException
48
47
 
49
48
  # First, retrieve the memory client
50
- memory_editor = builder.get_memory_client(config.memory)
49
+ memory_editor = await builder.get_memory_client(config.memory)
51
50
 
52
51
  async def _arun(user_id: str) -> str:
53
52
  """
@@ -30,10 +30,10 @@ logger = logging.getLogger(__name__)
30
30
  class GetToolConfig(FunctionBaseConfig, name="get_memory"):
31
31
  """Function to get memory to a hosted memory platform."""
32
32
 
33
- description: str = Field(default=("Tool to retrieve memory about a user's "
33
+ description: str = Field(default=("Tool to retrieve a memory about a user's "
34
34
  "interactions to help answer questions in a personalized way."),
35
35
  description="The description of this function's use for tool calling agents.")
36
- memory: MemoryRef = Field(default="saas_memory",
36
+ memory: MemoryRef = Field(default=MemoryRef("saas_memory"),
37
37
  description=("Instance name of the memory client instance from the workflow "
38
38
  "configuration object."))
39
39
 
@@ -49,7 +49,7 @@ async def get_memory_tool(config: GetToolConfig, builder: Builder):
49
49
  from langchain_core.tools import ToolException
50
50
 
51
51
  # First, retrieve the memory client
52
- memory_editor = builder.get_memory_client(config.memory)
52
+ memory_editor = await builder.get_memory_client(config.memory)
53
53
 
54
54
  async def _arun(search_input: SearchMemoryInput) -> str:
55
55
  """
@@ -67,6 +67,6 @@ async def get_memory_tool(config: GetToolConfig, builder: Builder):
67
67
 
68
68
  except Exception as e:
69
69
 
70
- raise ToolException(f"Error retreiving memory: {e}") from e
70
+ raise ToolException(f"Error retrieving memory: {e}") from e
71
71
 
72
72
  yield FunctionInfo.from_fn(_arun, description=config.description)