nvidia-nat 1.3.0rc1__py3-none-any.whl → 1.3.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/prompt_optimizer/register.py +2 -2
- nat/agent/react_agent/register.py +20 -21
- nat/agent/rewoo_agent/register.py +18 -20
- nat/agent/tool_calling_agent/register.py +7 -3
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +31 -18
- nat/builder/component_utils.py +1 -1
- nat/builder/context.py +22 -6
- nat/builder/function.py +3 -2
- nat/builder/workflow_builder.py +46 -3
- nat/cli/commands/mcp/mcp.py +6 -6
- nat/cli/commands/workflow/templates/config.yml.j2 +14 -12
- nat/cli/commands/workflow/templates/register.py.j2 +2 -2
- nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
- nat/cli/commands/workflow/workflow_commands.py +54 -10
- nat/cli/entrypoint.py +9 -1
- nat/cli/main.py +3 -0
- nat/data_models/api_server.py +143 -66
- nat/data_models/config.py +1 -1
- nat/data_models/span.py +41 -3
- nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +2 -2
- nat/front_ends/console/console_front_end_plugin.py +11 -2
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +5 -35
- nat/front_ends/fastapi/message_validator.py +3 -1
- nat/observability/exporter/span_exporter.py +34 -14
- nat/observability/register.py +16 -0
- nat/profiler/decorators/framework_wrapper.py +1 -1
- nat/profiler/forecasting/models/linear_model.py +1 -1
- nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +1 -1
- nat/runtime/runner.py +103 -6
- nat/runtime/session.py +27 -1
- nat/tool/memory_tools/add_memory_tool.py +3 -3
- nat/tool/memory_tools/delete_memory_tool.py +3 -4
- nat/tool/memory_tools/get_memory_tool.py +4 -4
- nat/utils/decorators.py +210 -0
- nat/utils/type_converter.py +8 -0
- nvidia_nat-1.3.0rc3.dist-info/METADATA +195 -0
- {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc3.dist-info}/RECORD +46 -45
- nvidia_nat-1.3.0rc1.dist-info/METADATA +0 -391
- {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc3.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc3.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc3.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc3.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc3.dist-info}/top_level.txt +0 -0
|
@@ -689,10 +689,13 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
689
689
|
|
|
690
690
|
async def post_openai_api_compatible(response: Response, request: Request, payload: request_type):
|
|
691
691
|
# Check if streaming is requested
|
|
692
|
+
|
|
693
|
+
response.headers["Content-Type"] = "application/json"
|
|
692
694
|
stream_requested = getattr(payload, 'stream', False)
|
|
693
695
|
|
|
694
696
|
async with session_manager.session(http_connection=request):
|
|
695
697
|
if stream_requested:
|
|
698
|
+
|
|
696
699
|
# Return streaming response
|
|
697
700
|
return StreamingResponse(headers={"Content-Type": "text/event-stream; charset=utf-8"},
|
|
698
701
|
content=generate_streaming_response_as_str(
|
|
@@ -703,40 +706,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
703
706
|
result_type=ChatResponseChunk,
|
|
704
707
|
output_type=ChatResponseChunk))
|
|
705
708
|
|
|
706
|
-
|
|
707
|
-
try:
|
|
708
|
-
response.headers["Content-Type"] = "application/json"
|
|
709
|
-
return await generate_single_response(payload, session_manager, result_type=ChatResponse)
|
|
710
|
-
except ValueError as e:
|
|
711
|
-
if "Cannot get a single output value for streaming workflows" in str(e):
|
|
712
|
-
# Workflow only supports streaming, but client requested non-streaming
|
|
713
|
-
# Fall back to streaming and collect the result
|
|
714
|
-
chunks = []
|
|
715
|
-
async for chunk_str in generate_streaming_response_as_str(
|
|
716
|
-
payload,
|
|
717
|
-
session_manager=session_manager,
|
|
718
|
-
streaming=True,
|
|
719
|
-
step_adaptor=self.get_step_adaptor(),
|
|
720
|
-
result_type=ChatResponseChunk,
|
|
721
|
-
output_type=ChatResponseChunk):
|
|
722
|
-
if chunk_str.startswith("data: ") and not chunk_str.startswith("data: [DONE]"):
|
|
723
|
-
chunk_data = chunk_str[6:].strip() # Remove "data: " prefix
|
|
724
|
-
if chunk_data:
|
|
725
|
-
try:
|
|
726
|
-
chunk_json = ChatResponseChunk.model_validate_json(chunk_data)
|
|
727
|
-
if (chunk_json.choices and len(chunk_json.choices) > 0
|
|
728
|
-
and chunk_json.choices[0].delta
|
|
729
|
-
and chunk_json.choices[0].delta.content is not None):
|
|
730
|
-
chunks.append(chunk_json.choices[0].delta.content)
|
|
731
|
-
except Exception:
|
|
732
|
-
continue
|
|
733
|
-
|
|
734
|
-
# Create a single response from collected chunks
|
|
735
|
-
content = "".join(chunks)
|
|
736
|
-
single_response = ChatResponse.from_string(content)
|
|
737
|
-
response.headers["Content-Type"] = "application/json"
|
|
738
|
-
return single_response
|
|
739
|
-
raise
|
|
709
|
+
return await generate_single_response(payload, session_manager, result_type=ChatResponse)
|
|
740
710
|
|
|
741
711
|
return post_openai_api_compatible
|
|
742
712
|
|
|
@@ -1128,7 +1098,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
1128
1098
|
if configured_group.config.type != "mcp_client":
|
|
1129
1099
|
continue
|
|
1130
1100
|
|
|
1131
|
-
from nat.plugins.mcp.
|
|
1101
|
+
from nat.plugins.mcp.client_config import MCPClientConfig
|
|
1132
1102
|
|
|
1133
1103
|
config = configured_group.config
|
|
1134
1104
|
assert isinstance(config, MCPClientConfig)
|
|
@@ -139,8 +139,10 @@ class MessageValidator:
|
|
|
139
139
|
text_content: str = str(data_model.payload)
|
|
140
140
|
validated_message_content = SystemResponseContent(text=text_content)
|
|
141
141
|
|
|
142
|
-
elif
|
|
142
|
+
elif isinstance(data_model, ChatResponse):
|
|
143
143
|
validated_message_content = SystemResponseContent(text=data_model.choices[0].message.content)
|
|
144
|
+
elif isinstance(data_model, ChatResponseChunk):
|
|
145
|
+
validated_message_content = SystemResponseContent(text=data_model.choices[0].delta.content)
|
|
144
146
|
|
|
145
147
|
elif (isinstance(data_model, ResponseIntermediateStep)):
|
|
146
148
|
validated_message_content = SystemIntermediateStepContent(name=data_model.name,
|
|
@@ -126,6 +126,7 @@ class SpanExporter(ProcessingExporter[InputSpanT, OutputSpanT], SerializeMixin):
|
|
|
126
126
|
|
|
127
127
|
parent_span = None
|
|
128
128
|
span_ctx = None
|
|
129
|
+
workflow_trace_id = self._context_state.workflow_trace_id.get()
|
|
129
130
|
|
|
130
131
|
# Look up the parent span to establish hierarchy
|
|
131
132
|
# event.parent_id is the UUID of the last START step with a different UUID from current step
|
|
@@ -141,6 +142,9 @@ class SpanExporter(ProcessingExporter[InputSpanT, OutputSpanT], SerializeMixin):
|
|
|
141
142
|
parent_span = parent_span.model_copy() if isinstance(parent_span, Span) else None
|
|
142
143
|
if parent_span and parent_span.context:
|
|
143
144
|
span_ctx = SpanContext(trace_id=parent_span.context.trace_id)
|
|
145
|
+
# No parent: adopt workflow trace id if available to keep all spans in the same trace
|
|
146
|
+
if span_ctx is None and workflow_trace_id:
|
|
147
|
+
span_ctx = SpanContext(trace_id=workflow_trace_id)
|
|
144
148
|
|
|
145
149
|
# Extract start/end times from the step
|
|
146
150
|
# By convention, `span_event_timestamp` is the time we started, `event_timestamp` is the time we ended.
|
|
@@ -154,23 +158,39 @@ class SpanExporter(ProcessingExporter[InputSpanT, OutputSpanT], SerializeMixin):
|
|
|
154
158
|
else:
|
|
155
159
|
sub_span_name = f"{event.payload.event_type}"
|
|
156
160
|
|
|
161
|
+
# Prefer parent/context trace id for attribute, else workflow trace id
|
|
162
|
+
_attr_trace_id = None
|
|
163
|
+
if span_ctx is not None:
|
|
164
|
+
_attr_trace_id = span_ctx.trace_id
|
|
165
|
+
elif parent_span and parent_span.context:
|
|
166
|
+
_attr_trace_id = parent_span.context.trace_id
|
|
167
|
+
elif workflow_trace_id:
|
|
168
|
+
_attr_trace_id = workflow_trace_id
|
|
169
|
+
|
|
170
|
+
attributes = {
|
|
171
|
+
f"{self._span_prefix}.event_type":
|
|
172
|
+
event.payload.event_type.value,
|
|
173
|
+
f"{self._span_prefix}.function.id":
|
|
174
|
+
event.function_ancestry.function_id if event.function_ancestry else "unknown",
|
|
175
|
+
f"{self._span_prefix}.function.name":
|
|
176
|
+
event.function_ancestry.function_name if event.function_ancestry else "unknown",
|
|
177
|
+
f"{self._span_prefix}.subspan.name":
|
|
178
|
+
event.payload.name or "",
|
|
179
|
+
f"{self._span_prefix}.event_timestamp":
|
|
180
|
+
event.event_timestamp,
|
|
181
|
+
f"{self._span_prefix}.framework":
|
|
182
|
+
event.payload.framework.value if event.payload.framework else "unknown",
|
|
183
|
+
f"{self._span_prefix}.conversation.id":
|
|
184
|
+
self._context_state.conversation_id.get() or "unknown",
|
|
185
|
+
f"{self._span_prefix}.workflow.run_id":
|
|
186
|
+
self._context_state.workflow_run_id.get() or "unknown",
|
|
187
|
+
f"{self._span_prefix}.workflow.trace_id": (f"{_attr_trace_id:032x}" if _attr_trace_id else "unknown"),
|
|
188
|
+
}
|
|
189
|
+
|
|
157
190
|
sub_span = Span(name=sub_span_name,
|
|
158
191
|
parent=parent_span,
|
|
159
192
|
context=span_ctx,
|
|
160
|
-
attributes=
|
|
161
|
-
f"{self._span_prefix}.event_type":
|
|
162
|
-
event.payload.event_type.value,
|
|
163
|
-
f"{self._span_prefix}.function.id":
|
|
164
|
-
event.function_ancestry.function_id if event.function_ancestry else "unknown",
|
|
165
|
-
f"{self._span_prefix}.function.name":
|
|
166
|
-
event.function_ancestry.function_name if event.function_ancestry else "unknown",
|
|
167
|
-
f"{self._span_prefix}.subspan.name":
|
|
168
|
-
event.payload.name or "",
|
|
169
|
-
f"{self._span_prefix}.event_timestamp":
|
|
170
|
-
event.event_timestamp,
|
|
171
|
-
f"{self._span_prefix}.framework":
|
|
172
|
-
event.payload.framework.value if event.payload.framework else "unknown",
|
|
173
|
-
},
|
|
193
|
+
attributes=attributes,
|
|
174
194
|
start_time=start_ns)
|
|
175
195
|
|
|
176
196
|
span_kind = event_type_to_span_kind(event.event_type)
|
nat/observability/register.py
CHANGED
|
@@ -77,6 +77,14 @@ async def console_logging_method(config: ConsoleLoggingMethodConfig, builder: Bu
|
|
|
77
77
|
level = getattr(logging, config.level.upper(), logging.INFO)
|
|
78
78
|
handler = logging.StreamHandler(stream=sys.stdout)
|
|
79
79
|
handler.setLevel(level)
|
|
80
|
+
|
|
81
|
+
# Set formatter to match the default CLI format
|
|
82
|
+
formatter = logging.Formatter(
|
|
83
|
+
fmt="%(asctime)s - %(levelname)-8s - %(name)s:%(lineno)d - %(message)s",
|
|
84
|
+
datefmt="%Y-%m-%d %H:%M:%S",
|
|
85
|
+
)
|
|
86
|
+
handler.setFormatter(formatter)
|
|
87
|
+
|
|
80
88
|
yield handler
|
|
81
89
|
|
|
82
90
|
|
|
@@ -95,4 +103,12 @@ async def file_logging_method(config: FileLoggingMethod, builder: Builder):
|
|
|
95
103
|
level = getattr(logging, config.level.upper(), logging.INFO)
|
|
96
104
|
handler = logging.FileHandler(filename=config.path, mode="a", encoding="utf-8")
|
|
97
105
|
handler.setLevel(level)
|
|
106
|
+
|
|
107
|
+
# Set formatter to match the default CLI format
|
|
108
|
+
formatter = logging.Formatter(
|
|
109
|
+
fmt="%(asctime)s - %(levelname)-8s - %(name)s:%(lineno)d - %(message)s",
|
|
110
|
+
datefmt="%Y-%m-%d %H:%M:%S",
|
|
111
|
+
)
|
|
112
|
+
handler.setFormatter(formatter)
|
|
113
|
+
|
|
98
114
|
yield handler
|
|
@@ -123,7 +123,7 @@ def set_framework_profiler_handler(
|
|
|
123
123
|
except ImportError as e:
|
|
124
124
|
logger.warning(
|
|
125
125
|
"ADK profiler not available. " +
|
|
126
|
-
"Install NAT with ADK extras: pip install
|
|
126
|
+
"Install NAT with ADK extras: pip install \"nvidia-nat[adk]\". Error: %s",
|
|
127
127
|
e)
|
|
128
128
|
else:
|
|
129
129
|
handler = ADKProfilerHandler()
|
|
@@ -36,7 +36,7 @@ class LinearModel(ForecastingBaseModel):
|
|
|
36
36
|
except ImportError:
|
|
37
37
|
logger.error(
|
|
38
38
|
"scikit-learn is not installed. Please install scikit-learn to use the LinearModel "
|
|
39
|
-
"profiling model or install
|
|
39
|
+
"profiling model or install \"nvidia-nat[profiler]\" to install all necessary profiling packages.")
|
|
40
40
|
|
|
41
41
|
raise
|
|
42
42
|
|
|
@@ -36,7 +36,7 @@ class RandomForestModel(ForecastingBaseModel):
|
|
|
36
36
|
except ImportError:
|
|
37
37
|
logger.error(
|
|
38
38
|
"scikit-learn is not installed. Please install scikit-learn to use the RandomForest "
|
|
39
|
-
"profiling model or install
|
|
39
|
+
"profiling model or install \"nvidia-nat[profiler]\" to install all necessary profiling packages.")
|
|
40
40
|
|
|
41
41
|
raise
|
|
42
42
|
|
|
@@ -304,7 +304,7 @@ def save_gantt_chart(all_nodes: list[CallNode], output_path: str) -> None:
|
|
|
304
304
|
import matplotlib.pyplot as plt
|
|
305
305
|
except ImportError:
|
|
306
306
|
logger.error("matplotlib is not installed. Please install matplotlib to use generate plots for the profiler "
|
|
307
|
-
"or install
|
|
307
|
+
"or install \"nvidia-nat[profiler]\" to install all necessary profiling packages.")
|
|
308
308
|
|
|
309
309
|
raise
|
|
310
310
|
|
|
@@ -212,7 +212,7 @@ def run_prefixspan(sequences_map: dict[int, list[PrefixCallNode]],
|
|
|
212
212
|
from prefixspan import PrefixSpan
|
|
213
213
|
except ImportError:
|
|
214
214
|
logger.error("prefixspan is not installed. Please install prefixspan to run the prefix analysis in the "
|
|
215
|
-
"profiler or install
|
|
215
|
+
"profiler or install \"nvidia-nat[profiler]\" to install all necessary profiling packages.")
|
|
216
216
|
|
|
217
217
|
raise
|
|
218
218
|
|
nat/runtime/runner.py
CHANGED
|
@@ -15,11 +15,16 @@
|
|
|
15
15
|
|
|
16
16
|
import logging
|
|
17
17
|
import typing
|
|
18
|
+
import uuid
|
|
18
19
|
from enum import Enum
|
|
19
20
|
|
|
20
21
|
from nat.builder.context import Context
|
|
21
22
|
from nat.builder.context import ContextState
|
|
22
23
|
from nat.builder.function import Function
|
|
24
|
+
from nat.data_models.intermediate_step import IntermediateStepPayload
|
|
25
|
+
from nat.data_models.intermediate_step import IntermediateStepType
|
|
26
|
+
from nat.data_models.intermediate_step import StreamEventData
|
|
27
|
+
from nat.data_models.intermediate_step import TraceMetadata
|
|
23
28
|
from nat.data_models.invocation_node import InvocationNode
|
|
24
29
|
from nat.observability.exporter_manager import ExporterManager
|
|
25
30
|
from nat.utils.reactive.subject import Subject
|
|
@@ -130,17 +135,59 @@ class Runner:
|
|
|
130
135
|
if (self._state != RunnerState.INITIALIZED):
|
|
131
136
|
raise ValueError("Cannot run the workflow without entering the context")
|
|
132
137
|
|
|
138
|
+
token_run_id = None
|
|
139
|
+
token_trace_id = None
|
|
133
140
|
try:
|
|
134
141
|
self._state = RunnerState.RUNNING
|
|
135
142
|
|
|
136
143
|
if (not self._entry_fn.has_single_output):
|
|
137
144
|
raise ValueError("Workflow does not support single output")
|
|
138
145
|
|
|
146
|
+
# Establish workflow run and trace identifiers
|
|
147
|
+
existing_run_id = self._context_state.workflow_run_id.get()
|
|
148
|
+
existing_trace_id = self._context_state.workflow_trace_id.get()
|
|
149
|
+
|
|
150
|
+
workflow_run_id = existing_run_id or str(uuid.uuid4())
|
|
151
|
+
|
|
152
|
+
workflow_trace_id = existing_trace_id or uuid.uuid4().int
|
|
153
|
+
|
|
154
|
+
token_run_id = self._context_state.workflow_run_id.set(workflow_run_id)
|
|
155
|
+
token_trace_id = self._context_state.workflow_trace_id.set(workflow_trace_id)
|
|
156
|
+
|
|
157
|
+
# Prepare workflow-level intermediate step identifiers
|
|
158
|
+
workflow_step_uuid = str(uuid.uuid4())
|
|
159
|
+
workflow_name = getattr(self._entry_fn, 'instance_name', None) or "workflow"
|
|
160
|
+
|
|
139
161
|
async with self._exporter_manager.start(context_state=self._context_state):
|
|
140
|
-
#
|
|
141
|
-
|
|
162
|
+
# Emit WORKFLOW_START
|
|
163
|
+
start_metadata = TraceMetadata(
|
|
164
|
+
provided_metadata={
|
|
165
|
+
"workflow_run_id": workflow_run_id,
|
|
166
|
+
"workflow_trace_id": f"{workflow_trace_id:032x}",
|
|
167
|
+
"conversation_id": self._context_state.conversation_id.get(),
|
|
168
|
+
})
|
|
169
|
+
self._context.intermediate_step_manager.push_intermediate_step(
|
|
170
|
+
IntermediateStepPayload(UUID=workflow_step_uuid,
|
|
171
|
+
event_type=IntermediateStepType.WORKFLOW_START,
|
|
172
|
+
name=workflow_name,
|
|
173
|
+
metadata=start_metadata))
|
|
174
|
+
|
|
175
|
+
result = await self._entry_fn.ainvoke(self._input_message, to_type=to_type) # type: ignore
|
|
176
|
+
|
|
177
|
+
# Emit WORKFLOW_END with output
|
|
178
|
+
end_metadata = TraceMetadata(
|
|
179
|
+
provided_metadata={
|
|
180
|
+
"workflow_run_id": workflow_run_id,
|
|
181
|
+
"workflow_trace_id": f"{workflow_trace_id:032x}",
|
|
182
|
+
"conversation_id": self._context_state.conversation_id.get(),
|
|
183
|
+
})
|
|
184
|
+
self._context.intermediate_step_manager.push_intermediate_step(
|
|
185
|
+
IntermediateStepPayload(UUID=workflow_step_uuid,
|
|
186
|
+
event_type=IntermediateStepType.WORKFLOW_END,
|
|
187
|
+
name=workflow_name,
|
|
188
|
+
metadata=end_metadata,
|
|
189
|
+
data=StreamEventData(output=result)))
|
|
142
190
|
|
|
143
|
-
# Close the intermediate stream
|
|
144
191
|
event_stream = self._context_state.event_stream.get()
|
|
145
192
|
if event_stream:
|
|
146
193
|
event_stream.on_complete()
|
|
@@ -155,25 +202,71 @@ class Runner:
|
|
|
155
202
|
if event_stream:
|
|
156
203
|
event_stream.on_complete()
|
|
157
204
|
self._state = RunnerState.FAILED
|
|
158
|
-
|
|
159
205
|
raise
|
|
206
|
+
finally:
|
|
207
|
+
if token_run_id is not None:
|
|
208
|
+
self._context_state.workflow_run_id.reset(token_run_id)
|
|
209
|
+
if token_trace_id is not None:
|
|
210
|
+
self._context_state.workflow_trace_id.reset(token_trace_id)
|
|
160
211
|
|
|
161
212
|
async def result_stream(self, to_type: type | None = None):
|
|
162
213
|
|
|
163
214
|
if (self._state != RunnerState.INITIALIZED):
|
|
164
215
|
raise ValueError("Cannot run the workflow without entering the context")
|
|
165
216
|
|
|
217
|
+
token_run_id = None
|
|
218
|
+
token_trace_id = None
|
|
166
219
|
try:
|
|
167
220
|
self._state = RunnerState.RUNNING
|
|
168
221
|
|
|
169
222
|
if (not self._entry_fn.has_streaming_output):
|
|
170
223
|
raise ValueError("Workflow does not support streaming output")
|
|
171
224
|
|
|
225
|
+
# Establish workflow run and trace identifiers
|
|
226
|
+
existing_run_id = self._context_state.workflow_run_id.get()
|
|
227
|
+
existing_trace_id = self._context_state.workflow_trace_id.get()
|
|
228
|
+
|
|
229
|
+
workflow_run_id = existing_run_id or str(uuid.uuid4())
|
|
230
|
+
|
|
231
|
+
workflow_trace_id = existing_trace_id or uuid.uuid4().int
|
|
232
|
+
|
|
233
|
+
token_run_id = self._context_state.workflow_run_id.set(workflow_run_id)
|
|
234
|
+
token_trace_id = self._context_state.workflow_trace_id.set(workflow_trace_id)
|
|
235
|
+
|
|
236
|
+
# Prepare workflow-level intermediate step identifiers
|
|
237
|
+
workflow_step_uuid = str(uuid.uuid4())
|
|
238
|
+
workflow_name = getattr(self._entry_fn, 'instance_name', None) or "workflow"
|
|
239
|
+
|
|
172
240
|
# Run the workflow
|
|
173
241
|
async with self._exporter_manager.start(context_state=self._context_state):
|
|
174
|
-
|
|
242
|
+
# Emit WORKFLOW_START
|
|
243
|
+
start_metadata = TraceMetadata(
|
|
244
|
+
provided_metadata={
|
|
245
|
+
"workflow_run_id": workflow_run_id,
|
|
246
|
+
"workflow_trace_id": f"{workflow_trace_id:032x}",
|
|
247
|
+
"conversation_id": self._context_state.conversation_id.get(),
|
|
248
|
+
})
|
|
249
|
+
self._context.intermediate_step_manager.push_intermediate_step(
|
|
250
|
+
IntermediateStepPayload(UUID=workflow_step_uuid,
|
|
251
|
+
event_type=IntermediateStepType.WORKFLOW_START,
|
|
252
|
+
name=workflow_name,
|
|
253
|
+
metadata=start_metadata))
|
|
254
|
+
|
|
255
|
+
async for m in self._entry_fn.astream(self._input_message, to_type=to_type): # type: ignore
|
|
175
256
|
yield m
|
|
176
257
|
|
|
258
|
+
# Emit WORKFLOW_END
|
|
259
|
+
end_metadata = TraceMetadata(
|
|
260
|
+
provided_metadata={
|
|
261
|
+
"workflow_run_id": workflow_run_id,
|
|
262
|
+
"workflow_trace_id": f"{workflow_trace_id:032x}",
|
|
263
|
+
"conversation_id": self._context_state.conversation_id.get(),
|
|
264
|
+
})
|
|
265
|
+
self._context.intermediate_step_manager.push_intermediate_step(
|
|
266
|
+
IntermediateStepPayload(UUID=workflow_step_uuid,
|
|
267
|
+
event_type=IntermediateStepType.WORKFLOW_END,
|
|
268
|
+
name=workflow_name,
|
|
269
|
+
metadata=end_metadata))
|
|
177
270
|
self._state = RunnerState.COMPLETED
|
|
178
271
|
|
|
179
272
|
# Close the intermediate stream
|
|
@@ -187,8 +280,12 @@ class Runner:
|
|
|
187
280
|
if event_stream:
|
|
188
281
|
event_stream.on_complete()
|
|
189
282
|
self._state = RunnerState.FAILED
|
|
190
|
-
|
|
191
283
|
raise
|
|
284
|
+
finally:
|
|
285
|
+
if token_run_id is not None:
|
|
286
|
+
self._context_state.workflow_run_id.reset(token_run_id)
|
|
287
|
+
if token_trace_id is not None:
|
|
288
|
+
self._context_state.workflow_trace_id.reset(token_trace_id)
|
|
192
289
|
|
|
193
290
|
|
|
194
291
|
# Compatibility aliases with previous releases
|
nat/runtime/session.py
CHANGED
|
@@ -16,6 +16,7 @@
|
|
|
16
16
|
import asyncio
|
|
17
17
|
import contextvars
|
|
18
18
|
import typing
|
|
19
|
+
import uuid
|
|
19
20
|
from collections.abc import Awaitable
|
|
20
21
|
from collections.abc import Callable
|
|
21
22
|
from contextlib import asynccontextmanager
|
|
@@ -161,12 +162,37 @@ class SessionManager:
|
|
|
161
162
|
if request.headers.get("user-message-id"):
|
|
162
163
|
self._context_state.user_message_id.set(request.headers["user-message-id"])
|
|
163
164
|
|
|
165
|
+
# W3C Trace Context header: traceparent: 00-<trace-id>-<span-id>-<flags>
|
|
166
|
+
traceparent = request.headers.get("traceparent")
|
|
167
|
+
if traceparent:
|
|
168
|
+
try:
|
|
169
|
+
parts = traceparent.split("-")
|
|
170
|
+
if len(parts) >= 4:
|
|
171
|
+
trace_id_hex = parts[1]
|
|
172
|
+
if len(trace_id_hex) == 32:
|
|
173
|
+
trace_id_int = uuid.UUID(trace_id_hex).int
|
|
174
|
+
self._context_state.workflow_trace_id.set(trace_id_int)
|
|
175
|
+
except Exception:
|
|
176
|
+
pass
|
|
177
|
+
|
|
178
|
+
if not self._context_state.workflow_trace_id.get():
|
|
179
|
+
workflow_trace_id = request.headers.get("workflow-trace-id")
|
|
180
|
+
if workflow_trace_id:
|
|
181
|
+
try:
|
|
182
|
+
self._context_state.workflow_trace_id.set(uuid.UUID(workflow_trace_id).int)
|
|
183
|
+
except Exception:
|
|
184
|
+
pass
|
|
185
|
+
|
|
186
|
+
workflow_run_id = request.headers.get("workflow-run-id")
|
|
187
|
+
if workflow_run_id:
|
|
188
|
+
self._context_state.workflow_run_id.set(workflow_run_id)
|
|
189
|
+
|
|
164
190
|
def set_metadata_from_websocket(self,
|
|
165
191
|
websocket: WebSocket,
|
|
166
192
|
user_message_id: str | None,
|
|
167
193
|
conversation_id: str | None) -> None:
|
|
168
194
|
"""
|
|
169
|
-
Extracts and sets user metadata for
|
|
195
|
+
Extracts and sets user metadata for WebSocket connections.
|
|
170
196
|
"""
|
|
171
197
|
|
|
172
198
|
# Extract cookies from WebSocket headers (similar to HTTP request)
|
|
@@ -30,10 +30,10 @@ logger = logging.getLogger(__name__)
|
|
|
30
30
|
class AddToolConfig(FunctionBaseConfig, name="add_memory"):
|
|
31
31
|
"""Function to add memory to a hosted memory platform."""
|
|
32
32
|
|
|
33
|
-
description: str = Field(default=("Tool to add memory about a user's interactions to a system "
|
|
33
|
+
description: str = Field(default=("Tool to add a memory about a user's interactions to a system "
|
|
34
34
|
"for retrieval later."),
|
|
35
35
|
description="The description of this function's use for tool calling agents.")
|
|
36
|
-
memory: MemoryRef = Field(default="saas_memory",
|
|
36
|
+
memory: MemoryRef = Field(default=MemoryRef("saas_memory"),
|
|
37
37
|
description=("Instance name of the memory client instance from the workflow "
|
|
38
38
|
"configuration object."))
|
|
39
39
|
|
|
@@ -46,7 +46,7 @@ async def add_memory_tool(config: AddToolConfig, builder: Builder):
|
|
|
46
46
|
from langchain_core.tools import ToolException
|
|
47
47
|
|
|
48
48
|
# First, retrieve the memory client
|
|
49
|
-
memory_editor = builder.get_memory_client(config.memory)
|
|
49
|
+
memory_editor = await builder.get_memory_client(config.memory)
|
|
50
50
|
|
|
51
51
|
async def _arun(item: MemoryItem) -> str:
|
|
52
52
|
"""
|
|
@@ -30,10 +30,9 @@ logger = logging.getLogger(__name__)
|
|
|
30
30
|
class DeleteToolConfig(FunctionBaseConfig, name="delete_memory"):
|
|
31
31
|
"""Function to delete memory from a hosted memory platform."""
|
|
32
32
|
|
|
33
|
-
description: str = Field(default=
|
|
34
|
-
"interactions to help answer questions in a personalized way."),
|
|
33
|
+
description: str = Field(default="Tool to delete a memory from a hosted memory platform.",
|
|
35
34
|
description="The description of this function's use for tool calling agents.")
|
|
36
|
-
memory: MemoryRef = Field(default="saas_memory",
|
|
35
|
+
memory: MemoryRef = Field(default=MemoryRef("saas_memory"),
|
|
37
36
|
description=("Instance name of the memory client instance from the workflow "
|
|
38
37
|
"configuration object."))
|
|
39
38
|
|
|
@@ -47,7 +46,7 @@ async def delete_memory_tool(config: DeleteToolConfig, builder: Builder):
|
|
|
47
46
|
from langchain_core.tools import ToolException
|
|
48
47
|
|
|
49
48
|
# First, retrieve the memory client
|
|
50
|
-
memory_editor = builder.get_memory_client(config.memory)
|
|
49
|
+
memory_editor = await builder.get_memory_client(config.memory)
|
|
51
50
|
|
|
52
51
|
async def _arun(user_id: str) -> str:
|
|
53
52
|
"""
|
|
@@ -30,10 +30,10 @@ logger = logging.getLogger(__name__)
|
|
|
30
30
|
class GetToolConfig(FunctionBaseConfig, name="get_memory"):
|
|
31
31
|
"""Function to get memory to a hosted memory platform."""
|
|
32
32
|
|
|
33
|
-
description: str = Field(default=("Tool to retrieve memory about a user's "
|
|
33
|
+
description: str = Field(default=("Tool to retrieve a memory about a user's "
|
|
34
34
|
"interactions to help answer questions in a personalized way."),
|
|
35
35
|
description="The description of this function's use for tool calling agents.")
|
|
36
|
-
memory: MemoryRef = Field(default="saas_memory",
|
|
36
|
+
memory: MemoryRef = Field(default=MemoryRef("saas_memory"),
|
|
37
37
|
description=("Instance name of the memory client instance from the workflow "
|
|
38
38
|
"configuration object."))
|
|
39
39
|
|
|
@@ -49,7 +49,7 @@ async def get_memory_tool(config: GetToolConfig, builder: Builder):
|
|
|
49
49
|
from langchain_core.tools import ToolException
|
|
50
50
|
|
|
51
51
|
# First, retrieve the memory client
|
|
52
|
-
memory_editor = builder.get_memory_client(config.memory)
|
|
52
|
+
memory_editor = await builder.get_memory_client(config.memory)
|
|
53
53
|
|
|
54
54
|
async def _arun(search_input: SearchMemoryInput) -> str:
|
|
55
55
|
"""
|
|
@@ -67,6 +67,6 @@ async def get_memory_tool(config: GetToolConfig, builder: Builder):
|
|
|
67
67
|
|
|
68
68
|
except Exception as e:
|
|
69
69
|
|
|
70
|
-
raise ToolException(f"Error
|
|
70
|
+
raise ToolException(f"Error retrieving memory: {e}") from e
|
|
71
71
|
|
|
72
72
|
yield FunctionInfo.from_fn(_arun, description=config.description)
|