nvidia-nat-data-flywheel 1.3.0a20250828__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. nat/meta/pypi.md +23 -0
  2. nat/plugins/data_flywheel/observability/__init__.py +14 -0
  3. nat/plugins/data_flywheel/observability/exporter/__init__.py +14 -0
  4. nat/plugins/data_flywheel/observability/exporter/dfw_elasticsearch_exporter.py +74 -0
  5. nat/plugins/data_flywheel/observability/exporter/dfw_exporter.py +99 -0
  6. nat/plugins/data_flywheel/observability/mixin/__init__.py +14 -0
  7. nat/plugins/data_flywheel/observability/mixin/elasticsearch_mixin.py +75 -0
  8. nat/plugins/data_flywheel/observability/processor/__init__.py +27 -0
  9. nat/plugins/data_flywheel/observability/processor/dfw_record_processor.py +86 -0
  10. nat/plugins/data_flywheel/observability/processor/trace_conversion/__init__.py +30 -0
  11. nat/plugins/data_flywheel/observability/processor/trace_conversion/adapter/__init__.py +14 -0
  12. nat/plugins/data_flywheel/observability/processor/trace_conversion/adapter/elasticsearch/__init__.py +14 -0
  13. nat/plugins/data_flywheel/observability/processor/trace_conversion/adapter/elasticsearch/nim_converter.py +44 -0
  14. nat/plugins/data_flywheel/observability/processor/trace_conversion/adapter/elasticsearch/openai_converter.py +368 -0
  15. nat/plugins/data_flywheel/observability/processor/trace_conversion/adapter/register.py +24 -0
  16. nat/plugins/data_flywheel/observability/processor/trace_conversion/span_extractor.py +79 -0
  17. nat/plugins/data_flywheel/observability/processor/trace_conversion/span_to_dfw_record.py +119 -0
  18. nat/plugins/data_flywheel/observability/processor/trace_conversion/trace_adapter_registry.py +255 -0
  19. nat/plugins/data_flywheel/observability/register.py +61 -0
  20. nat/plugins/data_flywheel/observability/schema/__init__.py +14 -0
  21. nat/plugins/data_flywheel/observability/schema/provider/__init__.py +14 -0
  22. nat/plugins/data_flywheel/observability/schema/provider/nim_trace_source.py +24 -0
  23. nat/plugins/data_flywheel/observability/schema/provider/openai_message.py +31 -0
  24. nat/plugins/data_flywheel/observability/schema/provider/openai_trace_source.py +95 -0
  25. nat/plugins/data_flywheel/observability/schema/register.py +21 -0
  26. nat/plugins/data_flywheel/observability/schema/schema_registry.py +144 -0
  27. nat/plugins/data_flywheel/observability/schema/sink/__init__.py +14 -0
  28. nat/plugins/data_flywheel/observability/schema/sink/elasticsearch/__init__.py +20 -0
  29. nat/plugins/data_flywheel/observability/schema/sink/elasticsearch/contract_version.py +31 -0
  30. nat/plugins/data_flywheel/observability/schema/sink/elasticsearch/dfw_es_record.py +222 -0
  31. nat/plugins/data_flywheel/observability/schema/trace_container.py +79 -0
  32. nat/plugins/data_flywheel/observability/schema/trace_source_base.py +22 -0
  33. nat/plugins/data_flywheel/observability/utils/deserialize.py +42 -0
  34. nvidia_nat_data_flywheel-1.3.0a20250828.dist-info/METADATA +34 -0
  35. nvidia_nat_data_flywheel-1.3.0a20250828.dist-info/RECORD +38 -0
  36. nvidia_nat_data_flywheel-1.3.0a20250828.dist-info/WHEEL +5 -0
  37. nvidia_nat_data_flywheel-1.3.0a20250828.dist-info/entry_points.txt +4 -0
  38. nvidia_nat_data_flywheel-1.3.0a20250828.dist-info/top_level.txt +1 -0
@@ -0,0 +1,368 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import json
17
+ import logging
18
+
19
+ from nat.data_models.intermediate_step import ToolSchema
20
+ from nat.plugins.data_flywheel.observability.processor.trace_conversion.span_extractor import extract_timestamp
21
+ from nat.plugins.data_flywheel.observability.processor.trace_conversion.span_extractor import extract_usage_info
22
+ from nat.plugins.data_flywheel.observability.processor.trace_conversion.trace_adapter_registry import register_adapter
23
+ from nat.plugins.data_flywheel.observability.schema.provider.openai_message import OpenAIMessage
24
+ from nat.plugins.data_flywheel.observability.schema.provider.openai_trace_source import OpenAITraceSource
25
+ from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import AssistantMessage
26
+ from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import DFWESRecord
27
+ from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import FinishReason
28
+ from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import Function
29
+ from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import FunctionDetails
30
+ from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import FunctionMessage
31
+ from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import Message
32
+ from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import Request
33
+ from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import RequestTool
34
+ from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import Response
35
+ from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import ResponseChoice
36
+ from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import ResponseMessage
37
+ from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import SystemMessage
38
+ from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import ToolCall
39
+ from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import ToolMessage
40
+ from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import UserMessage
41
+ from nat.plugins.data_flywheel.observability.schema.trace_container import TraceContainer
42
+
43
+ logger = logging.getLogger(__name__)
44
+
45
# Fallback role used when an incoming role label is not recognized.
DEFAULT_ROLE = "user"

# Role mapping from various role types to standard roles.
# Keys cover both OpenAI-style roles and LangChain message type names
# ("human", "ai", "chain"); values are the standardized OpenAI roles.
ROLE_MAP = {
    "human": "user",
    "user": "user",
    "assistant": "assistant",
    "ai": "assistant",
    "system": "system",
    "tool": "tool",
    "function": "function",
    "chain": "function"
}

# Maps raw finish-reason strings from chat responses to the FinishReason enum.
# Unlisted reasons intentionally map to None via FINISH_REASON_MAP.get(...).
FINISH_REASON_MAP = {"tool_calls": FinishReason.TOOL_CALLS, "stop": FinishReason.STOP, "length": FinishReason.LENGTH}
60
+
61
+
62
def convert_role(role: str) -> str:
    """Map an arbitrary role label onto one of the standard roles.

    Args:
        role (str): The role to convert

    Returns:
        str: The standardized role, or ``DEFAULT_ROLE`` when the label
            is not present in ``ROLE_MAP``
    """
    try:
        return ROLE_MAP[role]
    except KeyError:
        return DEFAULT_ROLE
72
+
73
+
74
def create_message_by_role(role: str, content: str | None, **kwargs) -> Message:
    """Factory function for creating messages by role.

    The role is first normalized via :func:`convert_role`, then dispatched to
    the matching message constructor.

    Args:
        role (str): The message role
        content (str | None): The message content
        **kwargs: Additional role-specific parameters (``tool_calls`` for
            assistant messages, ``tool_call_id`` for tool messages)

    Returns:
        Message: The appropriate message type for the role

    Raises:
        ValueError: If the role is unsupported, or required content is missing
    """
    role = convert_role(role)

    if role == "user":
        if content is None:
            raise ValueError("User message content cannot be None")
        return UserMessage(content=content, role="user")

    if role == "system":
        if content is None:
            raise ValueError("System message content cannot be None")
        return SystemMessage(content=content, role="system")

    if role == "assistant":
        tool_calls = kwargs.get("tool_calls", [])
        # An assistant message carrying tool calls drops its textual content.
        if len(tool_calls) > 0:
            content = None
        return AssistantMessage(content=content, role="assistant", tool_calls=tool_calls if tool_calls else None)

    if role == "tool":
        tool_call_id = kwargs.get("tool_call_id", "")
        if content is None:
            raise ValueError("Tool message content cannot be None")
        return ToolMessage(content=content, role="tool", tool_call_id=tool_call_id)

    if role == "function":
        return FunctionMessage(content=content, role="function")

    # Defensive: convert_role only emits values handled above, so this is
    # unreachable unless ROLE_MAP is extended without updating this factory.
    raise ValueError(f"Unsupported message role: {role}. Supported roles: {list(ROLE_MAP.keys())}")
113
+
114
+
115
def create_tool_calls(tool_calls_data: list) -> list[ToolCall]:
    """Create standardized tool calls from raw data.

    Entries that are not dicts, or whose ``"function"`` value is not a dict,
    are skipped. Function arguments supplied as a JSON string are parsed;
    invalid JSON — or valid JSON that does not encode an object — falls back
    to an empty dict with a warning.

    Args:
        tool_calls_data (list): Raw tool call data

    Returns:
        list[ToolCall]: List of validated tool calls
    """
    validated_tool_calls = []

    for tool_call in tool_calls_data:
        if not isinstance(tool_call, dict):
            continue

        function = tool_call.get("function", {})
        if not isinstance(function, dict):
            continue

        # Parse function arguments safely
        function_args = {}
        raw_args = function.get("arguments", "{}")
        if isinstance(raw_args, dict):
            function_args = raw_args
        elif isinstance(raw_args, str):
            try:
                parsed = json.loads(raw_args)
            except json.JSONDecodeError:
                logger.warning("Invalid JSON in function arguments: %s", raw_args)
            else:
                # Fix: previously `json.loads(raw_args) or {}` let truthy
                # non-object JSON (e.g. "[1, 2]" or "\"x\"") through as a
                # non-dict arguments value. Only accept JSON objects.
                if isinstance(parsed, dict):
                    function_args = parsed
                elif parsed is not None:
                    logger.warning("Function arguments JSON is not an object: %s", raw_args)

        validated_tool_calls.append(
            ToolCall(type="function",
                     function=Function(name=function.get("name", "unknown") or "unknown", arguments=function_args)))

    return validated_tool_calls
151
+
152
+
153
def convert_message_to_dfw(message: OpenAIMessage) -> Message:
    """Convert a message to the appropriate DFW message type.

    Args:
        message (OpenAIMessage): The message to convert

    Returns:
        Message: The converted message

    Raises:
        ValueError: If the message cannot be converted
    """
    # Content from response metadata wins when the key is present (even if
    # its value is None); otherwise fall back to the message body.
    metadata = message.response_metadata
    content = metadata.get("content", None) if "content" in metadata else message.content

    # Role defaults when the message type is falsy.
    role = message.type or DEFAULT_ROLE

    # Tool calls only matter for assistant messages; normalize them here.
    raw_tool_calls = message.additional_kwargs.get("tool_calls", [])
    tool_calls = create_tool_calls(raw_tool_calls) if raw_tool_calls else []

    # tool_call_id is only meaningful for tool messages; empty becomes None.
    tool_call_id = message.tool_call_id or None

    return create_message_by_role(role=role, content=content, tool_calls=tool_calls, tool_call_id=tool_call_id)
185
+
186
+
187
def validate_and_convert_tools(tools_schema: list) -> list[RequestTool]:
    """Validate and convert a tools schema to ``RequestTool`` format.

    Malformed entries are logged and skipped rather than aborting the
    whole conversion.

    Args:
        tools_schema (list): Raw tools schema

    Returns:
        list[RequestTool]: Validated request tools
    """
    request_tools = []

    for tool in tools_schema:
        # Normalize ToolSchema models to plain dicts before validation.
        if isinstance(tool, ToolSchema):
            tool = tool.model_dump()

        if not isinstance(tool, dict):
            logger.warning("Invalid tool schema: expected 'dict', got '%s'", type(tool))
            continue

        if "function" not in tool:
            logger.warning("Tool schema missing 'function' key: '%s'", tool)
            continue

        function_details = tool["function"]
        if not isinstance(function_details, dict):
            logger.warning("Tool function details must be 'dict', got '%s'", function_details)
            continue

        # Every required function field must be present.
        required_fields = ["name", "description", "parameters"]
        missing = [field for field in required_fields if field not in function_details]
        if missing:
            logger.warning("Tool function missing required fields '%s': '%s'", required_fields, function_details)
            continue

        try:
            # Pydantic validation may still reject the field values.
            function_obj = FunctionDetails(**function_details)
            request_tools.append(RequestTool(type="function", function=function_obj))
        except Exception as e:
            logger.warning("Failed to create RequestTool: '%s'", str(e))
            continue

    return request_tools
230
+
231
+
232
def convert_chat_response(chat_response: dict, span_name: str = "", index: int = 0) -> ResponseChoice:
    """Convert a chat response to a DFW response choice.

    Args:
        chat_response (dict): The chat response to convert
        span_name (str): Span name, used only for error context
        index (int): The index of this choice

    Returns:
        ResponseChoice: The converted chat response

    Raises:
        ValueError: If the chat response is missing its message
    """
    message = chat_response.get("message", {})
    # Covers both a missing/None message and an empty dict.
    if not message:
        raise ValueError(f"Chat response missing message for span: '{span_name}'")

    content = message.get("content", None)

    # Finish reason lives under the message's response metadata.
    finish_reason = message.get("response_metadata", {}).get("finish_reason", {})

    # Tool calls are normalized through the shared helper.
    validated_tool_calls = []
    additional_kwargs = message.get("additional_kwargs", {})
    if additional_kwargs is not None:
        tool_calls = additional_kwargs.get("tool_calls", [])
        if tool_calls is not None:
            validated_tool_calls = create_tool_calls(tool_calls)

    # A choice carrying tool calls drops its textual content.
    if len(validated_tool_calls) > 0:
        content = None

    # Only string finish reasons map onto the enum; anything else is None.
    mapped_finish_reason = FINISH_REASON_MAP.get(finish_reason) if isinstance(finish_reason, str) else None

    return ResponseChoice(message=ResponseMessage(content=content,
                                                  role="assistant",
                                                  tool_calls=validated_tool_calls or None),
                          finish_reason=mapped_finish_reason,
                          index=index)
281
+
282
+
283
@register_adapter(trace_source_model=OpenAITraceSource)
def convert_langchain_openai(trace_source: TraceContainer) -> DFWESRecord:
    """Convert a LangChain OpenAI trace source to a DFWESRecord.

    Builds the request (messages + optional tools), the response (one choice
    per chat response in the trace metadata, plus usage info extracted from
    the span), and assembles both into an Elasticsearch DFW record.

    Args:
        trace_source (TraceContainer): The trace source to convert

    Returns:
        DFWESRecord: The converted DFW record

    Raises:
        ValueError: If any message or chat response fails conversion, no
            response choices are found, or the final record cannot be created
    """
    # Convert each input message; fail fast with context on the first bad one.
    messages = []
    for message in trace_source.source.input_value:
        try:
            msg_result = convert_message_to_dfw(message)
            messages.append(msg_result)
        except ValueError as e:
            raise ValueError(f"Failed to convert message in trace source: {e}") from e

    # Tools schema is optional; absent/empty yields no request tools.
    tools_schema = trace_source.source.metadata.tools_schema
    request_tools = validate_and_convert_tools(tools_schema) if tools_schema else []

    # Model name is taken from the subspan name attribute, "unknown" fallback.
    model_name = str(trace_source.span.attributes.get("nat.subspan.name", "unknown"))

    # These parameters don't exist in current span structure, so set to None.
    # The Request schema allows them to be optional.
    temperature = None
    max_tokens = None

    request = Request(messages=messages,
                      model=model_name,
                      tools=request_tools if request_tools else None,
                      temperature=temperature,
                      max_tokens=max_tokens)

    # Transform chat responses, preserving each choice's position as its index.
    response_choices = []
    chat_responses = trace_source.source.metadata.chat_responses or []
    for idx, chat_response in enumerate(chat_responses):
        try:
            response_choice = convert_chat_response(chat_response, trace_source.span.name, index=idx)
            response_choices.append(response_choice)
        except ValueError as e:
            raise ValueError(f"Failed to convert chat response {idx}: {e}") from e

    # Require at least one response choice — a record without any is invalid.
    if not response_choices:
        raise ValueError(f"No valid response choices found in span: '{trace_source.span.name}'. "
                         f"Expected at least one chat response in metadata.")

    # Timestamp extraction falls back to 0 on malformed span data.
    timestamp_int = extract_timestamp(trace_source.span)

    # Response id comes from the span when present; otherwise synthesize one
    # from the span name and timestamp.
    response_id = trace_source.span.attributes.get(
        "response.id") or f"response-{trace_source.span.name}-{timestamp_int}"
    response_object = "chat.completion"  # Standard OpenAI object type
    created_timestamp = timestamp_int  # Use same timestamp as the record

    # Extract usage information from span attributes using structured models
    usage_info = extract_usage_info(trace_source.span)
    responses = Response(choices=response_choices,
                         id=response_id,
                         object=response_object,
                         created=created_timestamp,
                         model=model_name,
                         usage=usage_info.model_dump() if usage_info else None)

    # Workload id identifies the originating NAT function.
    workload_id = trace_source.span.attributes.get("nat.function.name", "unknown")

    try:
        dfw_payload = DFWESRecord(request=request,
                                  response=responses,
                                  timestamp=timestamp_int,
                                  workload_id=str(workload_id),
                                  client_id=trace_source.source.client_id,
                                  error_details=None)
        logger.debug("Successfully converted span to DFWESRecord: '%s'", trace_source.span.name)
        return dfw_payload
    except Exception as e:
        # Wrap schema-validation failures with span context for the caller.
        raise ValueError(f"Failed to create DFWESRecord for span '{trace_source.span.name}': {e}") from e
@@ -0,0 +1,24 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # pylint: disable=unused-import
17
+ # flake8: noqa
18
+ # isort:skip_file
19
+
20
+ # Import any adapters which need to be automatically registered here
21
+ from nat.plugins.data_flywheel.observability.processor.trace_conversion.adapter.elasticsearch import \
22
+ nim_converter
23
+ from nat.plugins.data_flywheel.observability.processor.trace_conversion.adapter.elasticsearch import \
24
+ openai_converter
@@ -0,0 +1,79 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+
18
+ from nat.data_models.intermediate_step import TokenUsageBaseModel
19
+ from nat.data_models.intermediate_step import UsageInfo
20
+ from nat.data_models.span import Span
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
def extract_token_usage(span: Span) -> TokenUsageBaseModel:
    """Extract token usage information from a span.

    Missing token-count attributes default to 0.

    Args:
        span (Span): The span to extract token usage from

    Returns:
        TokenUsageBaseModel: The token usage information
    """
    attrs = span.attributes
    return TokenUsageBaseModel(
        prompt_tokens=attrs.get("llm.token_count.prompt", 0),
        completion_tokens=attrs.get("llm.token_count.completion", 0),
        total_tokens=attrs.get("llm.token_count.total", 0),
    )
40
+
41
+
42
def extract_usage_info(span: Span) -> UsageInfo:
    """Extract usage information from a span.

    Combines token counts with call-frequency metrics; missing attributes
    default to 0.

    Args:
        span (Span): The span to extract usage information from

    Returns:
        UsageInfo: The usage information
    """
    return UsageInfo(
        token_usage=extract_token_usage(span),
        num_llm_calls=span.attributes.get("nat.usage.num_llm_calls", 0),
        seconds_between_calls=span.attributes.get("nat.usage.seconds_between_calls", 0),
    )
61
+
62
+
63
def extract_timestamp(span: Span) -> int:
    """Extract an integer timestamp from a span.

    The attribute value is coerced via str -> float -> int so that both
    numeric and string-encoded timestamps are accepted; anything
    unparseable yields 0 with a warning.

    Args:
        span (Span): The span to extract the timestamp from

    Returns:
        int: The timestamp, or 0 when missing or invalid
    """
    raw = span.attributes.get("nat.event_timestamp", 0)
    try:
        return int(float(str(raw)))
    except (ValueError, TypeError):
        logger.warning("Invalid timestamp in span '%s', using 0", span.name)
        return 0
@@ -0,0 +1,119 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+ from enum import Enum
18
+ from typing import Any
19
+
20
+ from pydantic import BaseModel
21
+
22
+ from nat.data_models.span import Span
23
+ from nat.plugins.data_flywheel.observability.processor.trace_conversion.trace_adapter_registry import (
24
+ TraceAdapterRegistry, # noqa: F401
25
+ )
26
+ from nat.plugins.data_flywheel.observability.schema.trace_container import TraceContainer
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
def _get_string_value(value: Any) -> str:
    """Return the string form of a value, unwrapping enums first.

    Args:
        value (Any): Could be an Enum, string, or other type

    Returns:
        str: String representation of the value (an Enum's ``.value`` is
            stringified rather than the Enum member itself)
    """
    unwrapped = value.value if isinstance(value, Enum) else value
    return str(unwrapped)
43
+
44
+
45
def get_trace_container(span: Span, client_id: str) -> TraceContainer:
    """Create a TraceContainer from a span for schema detection and conversion.

    Extracts trace data from span attributes and creates a TraceContainer where Pydantic's
    discriminated union will automatically detect the correct trace source schema type.

    Args:
        span (Span): The span containing trace attributes to extract
        client_id (str): The client ID to include in the trace source data

    Returns:
        TraceContainer: Container with automatically detected source type and original span

    Raises:
        ValueError: If span data doesn't match any registered trace source schemas
    """
    # Extract framework name from span attributes; defaults to "langchain"
    # when the attribute is absent.
    framework = _get_string_value(span.attributes.get("nat.framework", "langchain"))

    # Create trace source data - Pydantic union will detect correct schema type automatically.
    # Missing attributes become None so schema validation (not KeyError) decides the outcome.
    source_dict = {
        "source": {
            "framework": framework,
            "input_value": span.attributes.get("input.value", None),
            "metadata": span.attributes.get("nat.metadata", None),
            "client_id": client_id,
        },
        "span": span
    }

    try:
        # Create TraceContainer - Pydantic discriminated union automatically detects source type
        trace_container = TraceContainer(**source_dict)
        logger.debug("Pydantic union detected source type: %s for framework: %s",
                     type(trace_container.source).__name__,
                     framework)
        return trace_container

    except Exception as e:
        # Schema detection failed - indicates missing adapter registration or malformed span data.
        # Build a human-readable "Source -> Target" list of registered adapters for the error.
        registry_data = TraceAdapterRegistry.list_registered_types()
        adapter_metadata = []
        for source_type, target_converters in registry_data.items():
            for target_type in target_converters.keys():
                # target_type may not be a class; fall back to str() for its name.
                target_name = getattr(target_type, '__name__', str(target_type))
                adapter_metadata.append(f"{source_type.__name__} -> {target_name}")

        raise ValueError(f"Trace source schema detection failed for framework '{framework}'. "
                         f"Span data structure doesn't match any registered trace source schemas. "
                         f"Available registered adapters: {adapter_metadata}. "
                         f"Ensure a schema is registered with @register_adapter() for this trace format. "
                         f"Original error: {e}") from e
97
+
98
+
99
def span_to_dfw_record(span: Span, to_type: type[BaseModel], client_id: str) -> BaseModel:
    """Convert a span to a Data Flywheel record using registered trace adapters.

    First wraps the span in a TraceContainer (which auto-detects the trace
    source type via Pydantic schema matching), then dispatches to the
    registered converter for the requested target type.

    Args:
        span (Span): The span containing trace data to convert.
        to_type (type[BaseModel]): Target Pydantic model type for the conversion.
        client_id (str): Client identifier to include in the trace data.

    Returns:
        BaseModel: Converted record of the specified type.

    Raises:
        ValueError: If no converter is registered for the detected source
            type -> target type, or if the conversion fails.
    """
    container = get_trace_container(span, client_id)
    return TraceAdapterRegistry.convert(container, to_type=to_type)