nvidia-nat-data-flywheel 1.3.0a20250828__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/meta/pypi.md +23 -0
- nat/plugins/data_flywheel/observability/__init__.py +14 -0
- nat/plugins/data_flywheel/observability/exporter/__init__.py +14 -0
- nat/plugins/data_flywheel/observability/exporter/dfw_elasticsearch_exporter.py +74 -0
- nat/plugins/data_flywheel/observability/exporter/dfw_exporter.py +99 -0
- nat/plugins/data_flywheel/observability/mixin/__init__.py +14 -0
- nat/plugins/data_flywheel/observability/mixin/elasticsearch_mixin.py +75 -0
- nat/plugins/data_flywheel/observability/processor/__init__.py +27 -0
- nat/plugins/data_flywheel/observability/processor/dfw_record_processor.py +86 -0
- nat/plugins/data_flywheel/observability/processor/trace_conversion/__init__.py +30 -0
- nat/plugins/data_flywheel/observability/processor/trace_conversion/adapter/__init__.py +14 -0
- nat/plugins/data_flywheel/observability/processor/trace_conversion/adapter/elasticsearch/__init__.py +14 -0
- nat/plugins/data_flywheel/observability/processor/trace_conversion/adapter/elasticsearch/nim_converter.py +44 -0
- nat/plugins/data_flywheel/observability/processor/trace_conversion/adapter/elasticsearch/openai_converter.py +368 -0
- nat/plugins/data_flywheel/observability/processor/trace_conversion/adapter/register.py +24 -0
- nat/plugins/data_flywheel/observability/processor/trace_conversion/span_extractor.py +79 -0
- nat/plugins/data_flywheel/observability/processor/trace_conversion/span_to_dfw_record.py +119 -0
- nat/plugins/data_flywheel/observability/processor/trace_conversion/trace_adapter_registry.py +255 -0
- nat/plugins/data_flywheel/observability/register.py +61 -0
- nat/plugins/data_flywheel/observability/schema/__init__.py +14 -0
- nat/plugins/data_flywheel/observability/schema/provider/__init__.py +14 -0
- nat/plugins/data_flywheel/observability/schema/provider/nim_trace_source.py +24 -0
- nat/plugins/data_flywheel/observability/schema/provider/openai_message.py +31 -0
- nat/plugins/data_flywheel/observability/schema/provider/openai_trace_source.py +95 -0
- nat/plugins/data_flywheel/observability/schema/register.py +21 -0
- nat/plugins/data_flywheel/observability/schema/schema_registry.py +144 -0
- nat/plugins/data_flywheel/observability/schema/sink/__init__.py +14 -0
- nat/plugins/data_flywheel/observability/schema/sink/elasticsearch/__init__.py +20 -0
- nat/plugins/data_flywheel/observability/schema/sink/elasticsearch/contract_version.py +31 -0
- nat/plugins/data_flywheel/observability/schema/sink/elasticsearch/dfw_es_record.py +222 -0
- nat/plugins/data_flywheel/observability/schema/trace_container.py +79 -0
- nat/plugins/data_flywheel/observability/schema/trace_source_base.py +22 -0
- nat/plugins/data_flywheel/observability/utils/deserialize.py +42 -0
- nvidia_nat_data_flywheel-1.3.0a20250828.dist-info/METADATA +34 -0
- nvidia_nat_data_flywheel-1.3.0a20250828.dist-info/RECORD +38 -0
- nvidia_nat_data_flywheel-1.3.0a20250828.dist-info/WHEEL +5 -0
- nvidia_nat_data_flywheel-1.3.0a20250828.dist-info/entry_points.txt +4 -0
- nvidia_nat_data_flywheel-1.3.0a20250828.dist-info/top_level.txt +1 -0
@@ -0,0 +1,368 @@
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
import json
|
17
|
+
import logging
|
18
|
+
|
19
|
+
from nat.data_models.intermediate_step import ToolSchema
|
20
|
+
from nat.plugins.data_flywheel.observability.processor.trace_conversion.span_extractor import extract_timestamp
|
21
|
+
from nat.plugins.data_flywheel.observability.processor.trace_conversion.span_extractor import extract_usage_info
|
22
|
+
from nat.plugins.data_flywheel.observability.processor.trace_conversion.trace_adapter_registry import register_adapter
|
23
|
+
from nat.plugins.data_flywheel.observability.schema.provider.openai_message import OpenAIMessage
|
24
|
+
from nat.plugins.data_flywheel.observability.schema.provider.openai_trace_source import OpenAITraceSource
|
25
|
+
from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import AssistantMessage
|
26
|
+
from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import DFWESRecord
|
27
|
+
from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import FinishReason
|
28
|
+
from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import Function
|
29
|
+
from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import FunctionDetails
|
30
|
+
from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import FunctionMessage
|
31
|
+
from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import Message
|
32
|
+
from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import Request
|
33
|
+
from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import RequestTool
|
34
|
+
from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import Response
|
35
|
+
from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import ResponseChoice
|
36
|
+
from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import ResponseMessage
|
37
|
+
from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import SystemMessage
|
38
|
+
from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import ToolCall
|
39
|
+
from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import ToolMessage
|
40
|
+
from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import UserMessage
|
41
|
+
from nat.plugins.data_flywheel.observability.schema.trace_container import TraceContainer
|
42
|
+
|
43
|
+
logger = logging.getLogger(__name__)

# Role used when a message type is empty or is not a key of ROLE_MAP.
DEFAULT_ROLE = "user"

# Maps message types found in traces (e.g. LangChain-style "human"/"ai" as well as
# OpenAI-style role names) onto the roles used by the DFW Elasticsearch records.
ROLE_MAP = {
    # user-authored input
    "human": "user",
    "user": "user",
    # model output
    "assistant": "assistant",
    "ai": "assistant",
    # instructions
    "system": "system",
    # tool / function results; "chain" outputs are recorded as function messages
    "tool": "tool",
    "function": "function",
    "chain": "function",
}

# Maps string finish reasons from chat response metadata onto the FinishReason enum.
# Strings not listed here have no FinishReason equivalent.
FINISH_REASON_MAP = {
    "tool_calls": FinishReason.TOOL_CALLS,
    "stop": FinishReason.STOP,
    "length": FinishReason.LENGTH,
}
|
60
|
+
|
61
|
+
|
62
|
+
def convert_role(role: str) -> str:
    """Normalize a trace message role to one of the DFW record roles.

    Unknown roles are not rejected; they fall back to ``DEFAULT_ROLE``.

    Args:
        role (str): Role or message type as it appears in the trace.

    Returns:
        str: The value mapped in ``ROLE_MAP``, or ``DEFAULT_ROLE`` when ``role``
        is not one of its keys.
    """
    if role in ROLE_MAP:
        return ROLE_MAP[role]
    return DEFAULT_ROLE
|
72
|
+
|
73
|
+
|
74
|
+
def create_message_by_role(role: str, content: str | None, **kwargs) -> Message:
|
75
|
+
"""Factory function for creating messages by role.
|
76
|
+
|
77
|
+
Args:
|
78
|
+
role (str): The message role
|
79
|
+
content (str): The message content
|
80
|
+
**kwargs: Additional role-specific parameters
|
81
|
+
|
82
|
+
Returns:
|
83
|
+
Message: The appropriate message type for the role
|
84
|
+
|
85
|
+
Raises:
|
86
|
+
ValueError: If the role is unsupported
|
87
|
+
"""
|
88
|
+
role = convert_role(role)
|
89
|
+
|
90
|
+
match role:
|
91
|
+
case "user":
|
92
|
+
if content is None:
|
93
|
+
raise ValueError("User message content cannot be None")
|
94
|
+
return UserMessage(content=content, role="user")
|
95
|
+
case "system":
|
96
|
+
if content is None:
|
97
|
+
raise ValueError("System message content cannot be None")
|
98
|
+
return SystemMessage(content=content, role="system")
|
99
|
+
case "assistant":
|
100
|
+
tool_calls = kwargs.get("tool_calls", [])
|
101
|
+
if len(tool_calls) > 0:
|
102
|
+
content = None
|
103
|
+
return AssistantMessage(content=content, role="assistant", tool_calls=tool_calls if tool_calls else None)
|
104
|
+
case "tool":
|
105
|
+
tool_call_id = kwargs.get("tool_call_id", "")
|
106
|
+
if content is None:
|
107
|
+
raise ValueError("Tool message content cannot be None")
|
108
|
+
return ToolMessage(content=content, role="tool", tool_call_id=tool_call_id)
|
109
|
+
case "function":
|
110
|
+
return FunctionMessage(content=content, role="function")
|
111
|
+
case _:
|
112
|
+
raise ValueError(f"Unsupported message role: {role}. Supported roles: {list(ROLE_MAP.keys())}")
|
113
|
+
|
114
|
+
|
115
|
+
def _parse_function_arguments(raw_args) -> dict:
    """Parse raw tool-call arguments (JSON string or dict) into a dict, falling back to ``{}``."""
    if isinstance(raw_args, dict):
        return raw_args
    if not isinstance(raw_args, str):
        # e.g. None: nothing usable, keep the silent empty-arguments fallback.
        return {}
    try:
        parsed = json.loads(raw_args)
    except json.JSONDecodeError:
        logger.warning("Invalid JSON in function arguments: %s", raw_args)
        return {}
    if isinstance(parsed, dict):
        return parsed
    if parsed:
        # Valid JSON but not an object (e.g. a list or a number). Every other path here
        # yields a dict, so do not hand a non-dict on as the arguments.
        logger.warning("Function arguments are not a JSON object: %s", raw_args)
    return {}


def create_tool_calls(tool_calls_data: list) -> list[ToolCall]:
    """Create standardized tool calls from raw tool call data.

    Entries that are not dicts, or whose ``function`` entry is not a dict, are
    skipped. Function arguments may be a JSON string or a dict; anything that does
    not yield a JSON object becomes an empty argument dict. A missing or empty
    function name becomes ``"unknown"``.

    Args:
        tool_calls_data (list): Raw tool call entries, each expected to look like
            ``{"function": {"name": ..., "arguments": ...}}``.

    Returns:
        list[ToolCall]: The converted tool calls, in input order.
    """
    validated_tool_calls = []

    for tool_call in tool_calls_data:
        if not isinstance(tool_call, dict):
            continue

        function = tool_call.get("function", {})
        if not isinstance(function, dict):
            continue

        function_args = _parse_function_arguments(function.get("arguments", "{}"))
        validated_tool_calls.append(
            ToolCall(type="function",
                     function=Function(name=function.get("name", "unknown") or "unknown", arguments=function_args)))

    return validated_tool_calls
|
151
|
+
|
152
|
+
|
153
|
+
def convert_message_to_dfw(message: OpenAIMessage) -> Message:
    """Translate one trace input message into the matching DFW message model.

    Content comes from ``response_metadata["content"]`` when that key exists and from
    ``message.content`` otherwise. The message type selects the role
    (``DEFAULT_ROLE`` when empty). Tool calls are taken from
    ``additional_kwargs["tool_calls"]``, and ``tool_call_id`` is forwarded (None when
    falsy).

    Args:
        message (OpenAIMessage): The trace message to convert.

    Returns:
        Message: The DFW message built by ``create_message_by_role``.

    Raises:
        ValueError: If the message cannot be converted.
    """
    metadata = message.response_metadata
    content = metadata.get("content", None) if "content" in metadata else message.content

    raw_tool_calls = message.additional_kwargs.get("tool_calls", [])
    tool_calls = create_tool_calls(raw_tool_calls) if raw_tool_calls else []

    return create_message_by_role(role=message.type or DEFAULT_ROLE,
                                  content=content,
                                  tool_calls=tool_calls,
                                  tool_call_id=message.tool_call_id or None)
|
185
|
+
|
186
|
+
|
187
|
+
def validate_and_convert_tools(tools_schema: list) -> list[RequestTool]:
    """Turn a raw tools schema into ``RequestTool`` entries, dropping invalid ones.

    Each entry may be a ``ToolSchema`` (dumped to a dict first) or a dict shaped like
    ``{"function": {"name": ..., "description": ..., "parameters": ...}}``. Entries
    that fail a check, or that ``FunctionDetails``/``RequestTool`` reject, are logged
    at warning level and skipped.

    Args:
        tools_schema (list): Raw tool definitions.

    Returns:
        list[RequestTool]: The tools that passed validation, in input order.
    """
    required_fields = ["name", "description", "parameters"]
    converted = []

    for entry in tools_schema:
        tool = entry.model_dump() if isinstance(entry, ToolSchema) else entry

        if not isinstance(tool, dict):
            logger.warning("Invalid tool schema: expected 'dict', got '%s'", type(tool))
            continue
        if "function" not in tool:
            logger.warning("Tool schema missing 'function' key: '%s'", tool)
            continue

        details = tool["function"]
        if not isinstance(details, dict):
            logger.warning("Tool function details must be 'dict', got '%s'", details)
            continue
        if any(field not in details for field in required_fields):
            logger.warning("Tool function missing required fields '%s': '%s'", required_fields, details)
            continue

        try:
            converted.append(RequestTool(type="function", function=FunctionDetails(**details)))
        except Exception as exc:
            # Best-effort: a single malformed tool must not abort the whole conversion.
            logger.warning("Failed to create RequestTool: '%s'", str(exc))

    return converted
|
230
|
+
|
231
|
+
|
232
|
+
def convert_chat_response(chat_response: dict, span_name: str = "", index: int = 0) -> ResponseChoice:
    """Convert one chat response from span metadata into a DFW ``ResponseChoice``.

    Args:
        chat_response (dict): The chat response. It is expected to hold a ``message``
            dict with optional ``content``, ``response_metadata`` and
            ``additional_kwargs`` entries.
        span_name (str): Span name used in error messages.
        index (int): Position of this choice among the span's responses.

    Returns:
        ResponseChoice: An assistant ``ResponseMessage`` (content is None when tool
        calls are present) with the mapped finish reason.

    Raises:
        ValueError: If the chat response has a missing or empty message.
    """
    message = chat_response.get("message", {})
    if not message:
        raise ValueError(f"Chat response missing message for span: '{span_name}'")

    content = message.get("content", None)

    # Nested sections may be present but explicitly null. `or {}` keeps the lookups
    # below from raising AttributeError, which the caller (catching only ValueError)
    # would not handle.
    response_metadata = message.get("response_metadata") or {}
    finish_reason = response_metadata.get("finish_reason")

    additional_kwargs = message.get("additional_kwargs") or {}
    validated_tool_calls = create_tool_calls(additional_kwargs.get("tool_calls") or [])

    # A response that requests tool calls carries no text content.
    if validated_tool_calls:
        content = None

    # Only known string finish reasons map to the enum; anything else is recorded as None.
    mapped_finish_reason = FINISH_REASON_MAP.get(finish_reason) if isinstance(finish_reason, str) else None

    return ResponseChoice(message=ResponseMessage(content=content,
                                                  role="assistant",
                                                  tool_calls=validated_tool_calls or None),
                          finish_reason=mapped_finish_reason,
                          index=index)
|
281
|
+
|
282
|
+
|
283
|
+
@register_adapter(trace_source_model=OpenAITraceSource)
def convert_langchain_openai(trace_source: TraceContainer) -> DFWESRecord:
    """Build a ``DFWESRecord`` from a LangChain/OpenAI trace container.

    The request is assembled from the span's input messages and tools schema. The
    response is assembled from the chat responses in the span metadata. Record-level
    fields (timestamp, usage, workload id, response id) come from span attributes.
    Temperature and max_tokens are not available on the span and are left as None.

    Args:
        trace_source (TraceContainer): Container holding the parsed trace source and
            its originating span.

    Returns:
        DFWESRecord: The converted record.

    Raises:
        ValueError: If a message or chat response cannot be converted, if the span has
            no chat responses, or if building the final record fails.
    """
    source = trace_source.source
    span = trace_source.span

    # Request: input messages plus optional tool definitions.
    messages = []
    for message in source.input_value:
        try:
            messages.append(convert_message_to_dfw(message))
        except ValueError as err:
            raise ValueError(f"Failed to convert message in trace source: {err}") from err

    tools_schema = source.metadata.tools_schema
    request_tools = validate_and_convert_tools(tools_schema) if tools_schema else []

    model_name = str(span.attributes.get("nat.subspan.name", "unknown"))

    request = Request(messages=messages,
                      model=model_name,
                      tools=request_tools or None,
                      temperature=None,  # not captured on the span; optional in the schema
                      max_tokens=None)  # not captured on the span; optional in the schema

    # Response: one choice per chat response recorded in the metadata.
    response_choices = []
    for idx, chat_response in enumerate(source.metadata.chat_responses or []):
        try:
            response_choices.append(convert_chat_response(chat_response, span.name, index=idx))
        except ValueError as err:
            raise ValueError(f"Failed to convert chat response {idx}: {err}") from err

    if not response_choices:
        raise ValueError(f"No valid response choices found in span: '{span.name}'. "
                         "Expected at least one chat response in metadata.")

    timestamp_int = extract_timestamp(span)
    response_id = span.attributes.get("response.id") or f"response-{span.name}-{timestamp_int}"
    usage_info = extract_usage_info(span)

    responses = Response(choices=response_choices,
                         id=response_id,
                         object="chat.completion",  # standard OpenAI object type
                         created=timestamp_int,  # same timestamp as the record
                         model=model_name,
                         usage=usage_info.model_dump() if usage_info else None)

    workload_id = span.attributes.get("nat.function.name", "unknown")

    try:
        dfw_payload = DFWESRecord(request=request,
                                  response=responses,
                                  timestamp=timestamp_int,
                                  workload_id=str(workload_id),
                                  client_id=source.client_id,
                                  error_details=None)
        logger.debug("Successfully converted span to DFWESRecord: '%s'", span.name)
        return dfw_payload
    except Exception as err:
        raise ValueError(f"Failed to create DFWESRecord for span '{span.name}': {err}") from err
|
@@ -0,0 +1,24 @@
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
# pylint: disable=unused-import
|
17
|
+
# flake8: noqa
|
18
|
+
# isort:skip_file
|
19
|
+
|
20
|
+
# Import any adapters which need to be automatically registered here
|
21
|
+
from nat.plugins.data_flywheel.observability.processor.trace_conversion.adapter.elasticsearch import \
|
22
|
+
nim_converter
|
23
|
+
from nat.plugins.data_flywheel.observability.processor.trace_conversion.adapter.elasticsearch import \
|
24
|
+
openai_converter
|
@@ -0,0 +1,79 @@
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
import logging
|
17
|
+
|
18
|
+
from nat.data_models.intermediate_step import TokenUsageBaseModel
|
19
|
+
from nat.data_models.intermediate_step import UsageInfo
|
20
|
+
from nat.data_models.span import Span
|
21
|
+
|
22
|
+
logger = logging.getLogger(__name__)


def extract_token_usage(span: Span) -> TokenUsageBaseModel:
    """Read prompt/completion/total token counts from span attributes.

    Uses the ``llm.token_count.prompt``, ``llm.token_count.completion`` and
    ``llm.token_count.total`` attributes; each missing attribute counts as 0.

    Args:
        span (Span): Span whose attributes carry the token counts.

    Returns:
        TokenUsageBaseModel: The collected token counts.
    """
    attributes = span.attributes
    prompt = attributes.get("llm.token_count.prompt", 0)
    completion = attributes.get("llm.token_count.completion", 0)
    total = attributes.get("llm.token_count.total", 0)
    return TokenUsageBaseModel(prompt_tokens=prompt, completion_tokens=completion, total_tokens=total)
|
40
|
+
|
41
|
+
|
42
|
+
def extract_usage_info(span: Span) -> UsageInfo:
    """Collect token usage and LLM call statistics from a span.

    Args:
        span (Span): Span to read. Token counts come from ``extract_token_usage``;
            ``nat.usage.num_llm_calls`` and ``nat.usage.seconds_between_calls``
            default to 0 when absent.

    Returns:
        UsageInfo: The combined usage information.
    """
    attributes = span.attributes
    return UsageInfo(token_usage=extract_token_usage(span),
                     num_llm_calls=attributes.get("nat.usage.num_llm_calls", 0),
                     seconds_between_calls=attributes.get("nat.usage.seconds_between_calls", 0))
|
61
|
+
|
62
|
+
|
63
|
+
def extract_timestamp(span: Span) -> int:
    """Read the ``nat.event_timestamp`` attribute of a span as an integer.

    Integer values are returned unchanged. Strings holding an integer are parsed
    exactly. Other values are parsed as floats and truncated toward zero. A missing
    attribute yields 0.

    Args:
        span (Span): The span to extract the timestamp from.

    Returns:
        int: The timestamp, or 0 (with a warning) when the value is not a finite number.
    """
    timestamp = span.attributes.get("nat.event_timestamp", 0)

    # Return ints as-is: a float round-trip silently loses precision above 2**53.
    # bool is excluded so True/False keep being rejected as before.
    if isinstance(timestamp, int) and not isinstance(timestamp, bool):
        return timestamp

    try:
        text = str(timestamp)
        try:
            # Exact parse for integer strings.
            return int(text)
        except ValueError:
            # Decimal / exponent forms such as "1724812345.67".
            return int(float(text))
    except (ValueError, TypeError, OverflowError):
        # OverflowError covers infinities ("inf", "1e400"), which int() cannot represent.
        logger.warning("Invalid timestamp in span '%s', using 0", span.name)
        return 0
|
@@ -0,0 +1,119 @@
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
import logging
|
17
|
+
from enum import Enum
|
18
|
+
from typing import Any
|
19
|
+
|
20
|
+
from pydantic import BaseModel
|
21
|
+
|
22
|
+
from nat.data_models.span import Span
|
23
|
+
from nat.plugins.data_flywheel.observability.processor.trace_conversion.trace_adapter_registry import (
|
24
|
+
TraceAdapterRegistry, # noqa: F401
|
25
|
+
)
|
26
|
+
from nat.plugins.data_flywheel.observability.schema.trace_container import TraceContainer
|
27
|
+
|
28
|
+
logger = logging.getLogger(__name__)


def _get_string_value(value: Any) -> str:
    """Return the string form of ``value``, unwrapping Enum members first.

    Args:
        value (Any): An Enum member, a string, or any other object.

    Returns:
        str: ``str(value.value)`` for Enum members, otherwise ``str(value)``.
    """
    raw = value.value if isinstance(value, Enum) else value
    return str(raw)
|
43
|
+
|
44
|
+
|
45
|
+
def get_trace_container(span: Span, client_id: str) -> TraceContainer:
    """Wrap the trace attributes of a span in a ``TraceContainer``.

    The ``source`` payload is built from the span's ``nat.framework`` (default
    ``"langchain"``), ``input.value`` and ``nat.metadata`` attributes plus the given
    client id. Validating the ``TraceContainer`` is what selects the concrete trace
    source schema.

    Args:
        span (Span): Span carrying the trace attributes.
        client_id (str): Client identifier stored on the trace source.

    Returns:
        TraceContainer: The validated container holding the detected source and the span.

    Raises:
        ValueError: If constructing the container fails. The message lists the adapters
            currently registered in ``TraceAdapterRegistry``.
    """
    attributes = span.attributes
    framework = _get_string_value(attributes.get("nat.framework", "langchain"))

    source = {
        "framework": framework,
        "input_value": attributes.get("input.value", None),
        "metadata": attributes.get("nat.metadata", None),
        "client_id": client_id,
    }

    try:
        container = TraceContainer(source=source, span=span)
        logger.debug("Pydantic union detected source type: %s for framework: %s",
                     type(container.source).__name__,
                     framework)
        return container
    except Exception as exc:
        # List what is registered ("Source -> Target") to make the failure actionable.
        registered = [
            f"{source_type.__name__} -> {getattr(target_type, '__name__', str(target_type))}"
            for source_type, converters in TraceAdapterRegistry.list_registered_types().items()
            for target_type in converters.keys()
        ]
        raise ValueError(f"Trace source schema detection failed for framework '{framework}'. "
                         f"Span data structure doesn't match any registered trace source schemas. "
                         f"Available registered adapters: {registered}. "
                         f"Ensure a schema is registered with @register_adapter() for this trace format. "
                         f"Original error: {exc}") from exc
|
97
|
+
|
98
|
+
|
99
|
+
def span_to_dfw_record(span: Span, to_type: type[BaseModel], client_id: str) -> BaseModel:
    """Convert a span into a Data Flywheel record of type ``to_type``.

    The span is first wrapped by ``get_trace_container``, which detects the trace
    source type. The container is then handed to ``TraceAdapterRegistry.convert``.

    Args:
        span (Span): Span holding the trace data.
        to_type (type[BaseModel]): Desired record model.
        client_id (str): Client identifier recorded on the trace source.

    Returns:
        BaseModel: The record produced by the registered converter.

    Raises:
        ValueError: From ``get_trace_container`` when schema detection fails.
            ``TraceAdapterRegistry.convert`` is expected to raise ValueError when no
            converter is registered or the conversion fails (its code is not visible
            here).
    """
    return TraceAdapterRegistry.convert(get_trace_container(span, client_id), to_type=to_type)
|