nvidia-nat-data-flywheel 1.3.0a20250828__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/meta/pypi.md +23 -0
- nat/plugins/data_flywheel/observability/__init__.py +14 -0
- nat/plugins/data_flywheel/observability/exporter/__init__.py +14 -0
- nat/plugins/data_flywheel/observability/exporter/dfw_elasticsearch_exporter.py +74 -0
- nat/plugins/data_flywheel/observability/exporter/dfw_exporter.py +99 -0
- nat/plugins/data_flywheel/observability/mixin/__init__.py +14 -0
- nat/plugins/data_flywheel/observability/mixin/elasticsearch_mixin.py +75 -0
- nat/plugins/data_flywheel/observability/processor/__init__.py +27 -0
- nat/plugins/data_flywheel/observability/processor/dfw_record_processor.py +86 -0
- nat/plugins/data_flywheel/observability/processor/trace_conversion/__init__.py +30 -0
- nat/plugins/data_flywheel/observability/processor/trace_conversion/adapter/__init__.py +14 -0
- nat/plugins/data_flywheel/observability/processor/trace_conversion/adapter/elasticsearch/__init__.py +14 -0
- nat/plugins/data_flywheel/observability/processor/trace_conversion/adapter/elasticsearch/nim_converter.py +44 -0
- nat/plugins/data_flywheel/observability/processor/trace_conversion/adapter/elasticsearch/openai_converter.py +368 -0
- nat/plugins/data_flywheel/observability/processor/trace_conversion/adapter/register.py +24 -0
- nat/plugins/data_flywheel/observability/processor/trace_conversion/span_extractor.py +79 -0
- nat/plugins/data_flywheel/observability/processor/trace_conversion/span_to_dfw_record.py +119 -0
- nat/plugins/data_flywheel/observability/processor/trace_conversion/trace_adapter_registry.py +255 -0
- nat/plugins/data_flywheel/observability/register.py +61 -0
- nat/plugins/data_flywheel/observability/schema/__init__.py +14 -0
- nat/plugins/data_flywheel/observability/schema/provider/__init__.py +14 -0
- nat/plugins/data_flywheel/observability/schema/provider/nim_trace_source.py +24 -0
- nat/plugins/data_flywheel/observability/schema/provider/openai_message.py +31 -0
- nat/plugins/data_flywheel/observability/schema/provider/openai_trace_source.py +95 -0
- nat/plugins/data_flywheel/observability/schema/register.py +21 -0
- nat/plugins/data_flywheel/observability/schema/schema_registry.py +144 -0
- nat/plugins/data_flywheel/observability/schema/sink/__init__.py +14 -0
- nat/plugins/data_flywheel/observability/schema/sink/elasticsearch/__init__.py +20 -0
- nat/plugins/data_flywheel/observability/schema/sink/elasticsearch/contract_version.py +31 -0
- nat/plugins/data_flywheel/observability/schema/sink/elasticsearch/dfw_es_record.py +222 -0
- nat/plugins/data_flywheel/observability/schema/trace_container.py +79 -0
- nat/plugins/data_flywheel/observability/schema/trace_source_base.py +22 -0
- nat/plugins/data_flywheel/observability/utils/deserialize.py +42 -0
- nvidia_nat_data_flywheel-1.3.0a20250828.dist-info/METADATA +34 -0
- nvidia_nat_data_flywheel-1.3.0a20250828.dist-info/RECORD +38 -0
- nvidia_nat_data_flywheel-1.3.0a20250828.dist-info/WHEEL +5 -0
- nvidia_nat_data_flywheel-1.3.0a20250828.dist-info/entry_points.txt +4 -0
- nvidia_nat_data_flywheel-1.3.0a20250828.dist-info/top_level.txt +1 -0
@@ -0,0 +1,255 @@
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
import logging
|
17
|
+
from collections.abc import Callable
|
18
|
+
from functools import reduce
|
19
|
+
from typing import Any
|
20
|
+
|
21
|
+
from nat.plugins.data_flywheel.observability.schema.trace_container import TraceContainer
|
22
|
+
|
23
|
+
# Module-level logger; the registry logs (un)registrations at debug level and TraceContainer update failures as warnings.
logger = logging.getLogger(__name__)
|
24
|
+
|
25
|
+
|
26
|
+
class TraceAdapterRegistry:
    """Registry for trace source to target type conversions.

    Maintains schema detection through Pydantic unions while enabling dynamic registration
    of converter functions for different trace source types.

    All state lives in class attributes, so a single registry is shared by every caller in
    the process. Every mutation rebuilds the source union and pushes it onto TraceContainer.
    """

    # source_type -> {target_type -> converter}
    _registered_types: dict[type, dict[type, Callable]] = {}
    # Union of all registered source types: None before the first build, Any when nothing is registered.
    _union_cache: Any = None

    @classmethod
    def register_adapter(cls, trace_source_model: type) -> Callable[[Callable], Callable]:
        """Register adapter with a trace source Pydantic model.

        The model defines the schema for union-based detection, allowing automatic
        schema matching without explicit framework/provider specification. The converter's
        return annotation is used as the target-type key. String annotations (for example
        under ``from __future__ import annotations``) are resolved to real objects when possible.

        Args:
            trace_source_model (type): Pydantic model class that defines the trace source schema
                (e.g., OpenAITraceSource, NIMTraceSource, CustomTraceSource)

        Returns:
            Callable: Decorator function that registers the converter and returns it unchanged.

        Raises:
            ValueError: Raised by the returned decorator when the converter has no return
                annotation (or is annotated ``-> None``).
        """

        def decorator(func):
            return_type = cls._resolve_return_type(func)

            # Store converter: source_type -> target_type -> converter_func.
            # Registering the same pair again replaces the previous converter.
            cls._registered_types.setdefault(trace_source_model, {})[return_type] = func

            # Immediately rebuild union and update TraceContainer model
            cls._rebuild_union()

            logger.debug("Registered %s -> %s converter",
                         trace_source_model.__name__,
                         getattr(return_type, '__name__', str(return_type)))
            return func

        return decorator

    @staticmethod
    def _resolve_return_type(func: Callable) -> Any:
        """Return ``func``'s return annotation, resolving string (postponed) annotations.

        Raises:
            ValueError: If ``func`` has no return annotation or it is ``None``.
        """
        return_type = func.__annotations__.get('return')

        # Validate return type annotation exists and is meaningful
        if return_type is None:
            raise ValueError(f"Converter function '{func.__name__}' must have a return type annotation.\n"
                             f"Example: def {func.__name__}(trace: TraceContainer) -> DFWESRecord:")

        if isinstance(return_type, str):
            # Postponed annotations (PEP 563) arrive as strings. A string key never equals the class
            # later passed to convert()/get_adapter(), so the converter would be unreachable.
            from typing import get_type_hints  # pylint: disable=import-outside-toplevel
            try:
                return_type = get_type_hints(func).get('return', return_type)
            except Exception as e:  # best effort: keep the raw string if names cannot be resolved
                logger.warning("Could not resolve return annotation %r of converter '%s'; "
                               "registering under the unresolved string: %s",
                               return_type,
                               func.__name__,
                               e)
        return return_type

    @classmethod
    def convert(cls, trace_container: TraceContainer, to_type: type) -> Any:
        """Convert trace to target type using registered converter function.

        The converter is looked up by the exact type of ``trace_container.source``. Subclasses
        of a registered source model are not matched.

        Args:
            trace_container (TraceContainer): TraceContainer with source data to convert
            to_type (type): Target type to convert to

        Returns:
            Converted object of to_type

        Raises:
            ValueError: If no converter is registered for source->target combination
        """
        source_type = type(trace_container.source)

        # Look up converter: source_type -> target_type -> converter_func
        source_converters = cls._registered_types.get(source_type, {})
        converter = source_converters.get(to_type)

        if converter is None:
            available_target_names = [getattr(t, '__name__', str(t)) for t in source_converters]
            raise ValueError(
                f"No converter from {source_type.__name__} to {getattr(to_type, '__name__', str(to_type))}. "
                f"Available targets: {available_target_names}")

        return converter(trace_container)

    @classmethod
    def get_adapter(cls, trace_container: TraceContainer, to_type: type) -> Callable | None:
        """Get the converter function for a given trace source and target type.

        Args:
            trace_container (TraceContainer): TraceContainer with source data
            to_type (type): Target type to convert to

        Returns:
            Converter function if registered, None if not found
        """
        source_type = type(trace_container.source)
        return cls._registered_types.get(source_type, {}).get(to_type)

    @classmethod
    def get_current_union(cls) -> type:
        """Get the current source union with all registered source types.

        Returns:
            type: Union type containing all registered trace source types (``Any`` when empty)
        """
        if cls._union_cache is None:
            cls._rebuild_union()
        return cls._union_cache

    @classmethod
    def _rebuild_union(cls):
        """Rebuild the union with all registered trace source types."""
        # Sort by class name (module and qualname break ties between same-named classes) so the
        # union member order, and therefore the generated Pydantic schema, is deterministic.
        sorted_types = sorted(cls._registered_types, key=lambda t: (t.__name__, t.__module__, t.__qualname__))

        if not sorted_types:
            # No types registered yet - use Any as permissive fallback
            cls._union_cache = Any
        elif len(sorted_types) == 1:
            cls._union_cache = sorted_types[0]
        else:
            # Create Union from multiple types using reduce
            cls._union_cache = reduce(lambda a, b: a | b, sorted_types)

        logger.debug("Rebuilt source union with %d registered source types: %s",
                     len(sorted_types), [t.__name__ for t in sorted_types])

        # Update TraceContainer model with new union
        cls._update_trace_source_model()

    @classmethod
    def _update_trace_source_model(cls):
        """Update the TraceContainer model to use the current dynamic union.

        Best effort: any exception is logged as a warning and swallowed.
        """
        try:
            # Update the source field annotation to use current union
            if hasattr(TraceContainer, '__annotations__'):
                TraceContainer.__annotations__['source'] = cls._union_cache

            # Ask Pydantic to rebuild the model with the new annotations.
            # NOTE(review): in Pydantic v2, model_rebuild() without force=True returns early for an
            # already-complete model, and editing __annotations__ may not refresh model_fields.
            # Confirm whether TraceContainer relies on this or validates `source` via get_current_union().
            TraceContainer.model_rebuild()
            logger.debug("Updated TraceContainer model with new union type")
        except Exception as e:
            logger.warning("Failed to update TraceContainer model: %s", e)

    @classmethod
    def unregister_adapter(cls, source_type: type, target_type: type) -> bool:
        """Unregister a specific adapter.

        Args:
            source_type (type): The trace source type
            target_type (type): The target conversion type

        Returns:
            bool: True if adapter was found and removed, False if not found
        """
        target_converters = cls._registered_types.get(source_type)
        if target_converters is None or target_type not in target_converters:
            return False

        # Remove the specific converter
        del target_converters[target_type]

        # Clean up empty source entry so the type drops out of the union
        if not target_converters:
            del cls._registered_types[source_type]

        # Rebuild union since registered types changed
        cls._rebuild_union()

        logger.debug("Unregistered %s -> %s converter",
                     source_type.__name__,
                     getattr(target_type, '__name__', str(target_type)))
        return True

    @classmethod
    def unregister_all_adapters(cls, source_type: type) -> int:
        """Unregister all adapters for a given source type.

        Args:
            source_type (type): The trace source type to remove all converters for

        Returns:
            int: Number of converters removed
        """
        removed = cls._registered_types.pop(source_type, None)
        if removed is None:
            return 0

        removed_count = len(removed)

        # Rebuild union since registered types changed
        cls._rebuild_union()

        logger.debug("Unregistered all %d converters for %s", removed_count, source_type.__name__)
        return removed_count

    @classmethod
    def clear_registry(cls) -> int:
        """Clear all registered adapters. Useful for testing cleanup.

        Returns:
            int: Total number of converters removed
        """
        total_removed = sum(len(converters) for converters in cls._registered_types.values())
        cls._registered_types.clear()

        # Rebuild union (resets the cache to Any since nothing is registered)
        cls._rebuild_union()

        logger.debug("Cleared registry - removed %d total converters", total_removed)
        return total_removed

    @classmethod
    def list_registered_types(cls) -> dict[type, dict[type, Callable]]:
        """List all registered conversions: source_type -> {target_type -> converter}.

        Returns:
            dict[type, dict[type, Callable]]: A snapshot copy mapping source types to their available
                target conversions. Mutating it does not affect the registry; use the register and
                unregister methods instead.
        """
        return {source: dict(targets) for source, targets in cls._registered_types.items()}
|
249
|
+
|
250
|
+
|
251
|
+
# Convenience functions for adapter management: module-level handles to the
# TraceAdapterRegistry classmethods, all operating on the same shared registry.
(register_adapter,
 unregister_adapter,
 unregister_all_adapters,
 clear_registry) = (TraceAdapterRegistry.register_adapter,
                    TraceAdapterRegistry.unregister_adapter,
                    TraceAdapterRegistry.unregister_all_adapters,
                    TraceAdapterRegistry.clear_registry)
|
@@ -0,0 +1,61 @@
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
import logging
|
17
|
+
|
18
|
+
from pydantic import Field
|
19
|
+
|
20
|
+
from nat.builder.builder import Builder
|
21
|
+
from nat.cli.register_workflow import register_telemetry_exporter
|
22
|
+
from nat.data_models.telemetry_exporter import TelemetryExporterBaseConfig
|
23
|
+
from nat.observability.mixin.batch_config_mixin import BatchConfigMixin
|
24
|
+
from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch import ContractVersion
|
25
|
+
|
26
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
27
|
+
|
28
|
+
|
29
|
+
class DFWElasticsearchTelemetryExporter(TelemetryExporterBaseConfig, BatchConfigMixin,
                                        name="data_flywheel_elasticsearch"):
    """A telemetry exporter to transmit traces to NVIDIA Data Flywheel via Elasticsearch."""

    # Required connection/target settings (no defaults).
    client_id: str = Field(description="The data flywheel client ID.")
    index: str = Field(description="The elasticsearch index name.")
    endpoint: str = Field(description="The elasticsearch endpoint.")

    # Record schema version written to Elasticsearch; defaults to ContractVersion.V1_1.
    contract_version: ContractVersion = Field(ContractVersion.V1_1,
                                              description="The DFW Elasticsearch record schema version to use.")

    # Optional basic-auth credentials; both default to None.
    # NOTE(review): password is a plain str and is therefore included in the model's repr/dumps;
    # consider pydantic.SecretStr.
    username: str | None = Field(None, description="The elasticsearch username.")
    password: str | None = Field(None, description="The elasticsearch password.")

    # Optional extra request headers; defaults to None.
    headers: dict | None = Field(None, description="Additional headers for elasticsearch requests.")
|
42
|
+
|
43
|
+
|
44
|
+
@register_telemetry_exporter(config_type=DFWElasticsearchTelemetryExporter)
async def dfw_elasticsearch_telemetry_exporter(config: DFWElasticsearchTelemetryExporter, _builder: Builder):
    """Yield a DFWElasticsearchExporter configured from ``config``.

    Basic auth is passed as a ``(username, password)`` tuple only when both values are set.
    Otherwise an empty tuple is passed, and a warning is logged if exactly one was provided.

    Args:
        config: Exporter configuration (connection, auth, contract version and batching settings).
        _builder: Not used by this function.

    Yields:
        DFWElasticsearchExporter: The configured exporter instance.
    """
    # pylint: disable=import-outside-toplevel
    from nat.plugins.data_flywheel.observability.exporter.dfw_elasticsearch_exporter import DFWElasticsearchExporter

    if config.username and config.password:
        elasticsearch_auth = (config.username, config.password)
    else:
        if config.username or config.password:
            # A half-configured credential pair would otherwise be dropped without any signal.
            logger.warning("Elasticsearch basic auth requires both 'username' and 'password'; only one was "
                           "provided, so requests will be sent without authentication.")
        elasticsearch_auth = ()

    yield DFWElasticsearchExporter(client_id=config.client_id,
                                   index=config.index,
                                   endpoint=config.endpoint,
                                   elasticsearch_auth=elasticsearch_auth,
                                   headers=config.headers,
                                   contract_version=config.contract_version,
                                   batch_size=config.batch_size,
                                   flush_interval=config.flush_interval,
                                   max_queue_size=config.max_queue_size,
                                   drop_on_overflow=config.drop_on_overflow,
                                   shutdown_timeout=config.shutdown_timeout)
|
@@ -0,0 +1,14 @@
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
@@ -0,0 +1,14 @@
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
@@ -0,0 +1,24 @@
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
import logging
|
17
|
+
|
18
|
+
from nat.plugins.data_flywheel.observability.schema.provider.openai_trace_source import OpenAITraceSourceBase
|
19
|
+
|
20
|
+
# Module-level logger (not referenced by the code in this module).
logger = logging.getLogger(__name__)
|
21
|
+
|
22
|
+
|
23
|
+
class NIMTraceSource(OpenAITraceSourceBase):
    # Adds no fields or validators: the schema is exactly OpenAITraceSourceBase's. It is presumably
    # a separate class so converters can be registered against it distinctly from
    # OpenAITraceSource (the adapter registry keys on the exact source type).
    ...
|
@@ -0,0 +1,31 @@
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
from typing import Any
|
17
|
+
|
18
|
+
from pydantic import BaseModel
|
19
|
+
from pydantic import Field
|
20
|
+
|
21
|
+
|
22
|
+
# LangChain-style message model used to validate each entry of a trace's input messages.
class OpenAIMessage(BaseModel):
    # Field names mirror a serialized LangChain message (content, additional_kwargs, response_metadata,
    # type, name, id, example, tool_call_id). NOTE(review): confirm against the producing span format.
    # Only `type` is required; every other field has a default.
    content: str | None = Field(default=None, description="The content of the message.")
    additional_kwargs: dict[str, Any] = Field(default_factory=dict, description="Additional kwargs for the message.")
    response_metadata: dict[str, Any] = Field(default_factory=dict, description="Response metadata for the message.")
    type: str = Field(description="The type of the message.")
    name: str | None = Field(default=None, description="The name of the message.")
    id: str | None = Field(default=None, description="The ID of the message.")
    example: bool = Field(default=False, description="Whether the message is an example.")
    tool_call_id: str | None = Field(default=None, description="The tool call ID for the message.")
|
@@ -0,0 +1,95 @@
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
import logging
|
17
|
+
from typing import Any
|
18
|
+
from typing import TypeVar
|
19
|
+
|
20
|
+
from pydantic import BaseModel
|
21
|
+
from pydantic import Field
|
22
|
+
from pydantic import field_validator
|
23
|
+
|
24
|
+
from nat.data_models.intermediate_step import ToolSchema
|
25
|
+
from nat.plugins.data_flywheel.observability.schema.provider.openai_message import OpenAIMessage
|
26
|
+
from nat.plugins.data_flywheel.observability.schema.trace_source_base import TraceSourceBase
|
27
|
+
from nat.plugins.data_flywheel.observability.utils.deserialize import deserialize_span_attribute
|
28
|
+
|
29
|
+
# NOTE(review): ProviderT is not referenced anywhere in this module; confirm whether other modules import it before removing.
ProviderT = TypeVar("ProviderT")
|
30
|
+
|
31
|
+
# Module-level logger (not referenced by the code in this module).
logger = logging.getLogger(__name__)
|
32
|
+
|
33
|
+
|
34
|
+
class OpenAIMetadata(BaseModel):
    """Metadata for the OpenAITraceSource."""

    # Optional list of ToolSchema entries; None when absent.
    tools_schema: list[ToolSchema] | None = Field(None, description="The tools schema for the OpenAITraceSource.")
    # Optional list of free-form response dicts; None when absent.
    chat_responses: list[dict[str, Any]] | None = Field(
        None, description="The chat responses for the OpenAITraceSource.")
|
41
|
+
|
42
|
+
|
43
|
+
class OpenAITraceSourceBase(TraceSourceBase):
    """Base class for the OpenAITraceSource."""

    # Normalized by validate_input_value from a JSON string, a single dict, or a list.
    input_value: list[OpenAIMessage]
    # Normalized by validate_metadata from OpenAIMetadata, a dict, a JSON string, or None.
    metadata: OpenAIMetadata

    @field_validator("input_value", mode="before")
    @classmethod
    def validate_input_value(cls, v: Any) -> list[OpenAIMessage]:
        """Coerce raw ``input_value`` data into a list of ``OpenAIMessage`` objects.

        A string is first decoded with ``deserialize_span_attribute`` (JSON string input). A lone
        dict is treated as a one-message list. Each list item may be a dict or an existing
        ``OpenAIMessage``.

        Raises:
            ValueError: If the value is None, is not a list after decoding/wrapping, or contains an
                item that is neither a dict nor an ``OpenAIMessage``.
        """
        if v is None:
            raise ValueError("Input value is required")

        payload = deserialize_span_attribute(v) if isinstance(v, str) else v
        if isinstance(payload, dict):
            payload = [payload]
        if not isinstance(payload, list):
            raise ValueError(f"Invalid input_value format: {payload}")

        messages: list[OpenAIMessage] = []
        for item in payload:
            if isinstance(item, OpenAIMessage):
                messages.append(item)
            elif isinstance(item, dict):
                messages.append(OpenAIMessage(**item))
            else:
                raise ValueError(f"Invalid message format: {item}")
        return messages

    @field_validator("metadata", mode="before")
    @classmethod
    def validate_metadata(cls, v: Any) -> "OpenAIMetadata | dict[str, Any]":
        """Normalize metadata supplied as OpenAIMetadata, dict, JSON string, or None.

        None becomes an empty dict. An ``OpenAIMetadata`` instance passes through untouched. A
        string is decoded with ``deserialize_span_attribute`` and must yield a dict.

        Raises:
            ValueError: If the (decoded) value is not a dict.
        """
        if v is None:
            return {}
        if isinstance(v, OpenAIMetadata):
            return v

        decoded = deserialize_span_attribute(v) if isinstance(v, str) else v
        if isinstance(decoded, dict):
            return decoded
        raise ValueError(f"Invalid metadata format: {decoded!r}")
|
91
|
+
|
92
|
+
|
93
|
+
class OpenAITraceSource(OpenAITraceSourceBase):
    """Concrete OpenAI trace source."""
    # The docstring alone forms the class body. There are no extra fields; schema and validators
    # come from OpenAITraceSourceBase, presumably kept as a distinct type for converter registration.
|
@@ -0,0 +1,21 @@
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
# pylint: disable=unused-import
|
17
|
+
# flake8: noqa
|
18
|
+
# isort:skip_file
|
19
|
+
|
20
|
+
# Import any destinations contract schemas which need to be automatically registered here
|
21
|
+
from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch import dfw_es_record
|
@@ -0,0 +1,144 @@
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
import logging
|
17
|
+
from typing import TypeVar
|
18
|
+
|
19
|
+
from pydantic import BaseModel
|
20
|
+
|
21
|
+
# Module-level logger; SchemaRegistry.register logs overrides as warnings and registrations at debug level.
logger = logging.getLogger(__name__)
|
22
|
+
|
23
|
+
# Type variable bound to BaseModel so SchemaRegistry.register's decorator is typed as returning the same class it receives.
T = TypeVar('T', bound=BaseModel)
|
24
|
+
|
25
|
+
|
26
|
+
class SchemaRegistry:
    """Registry for managing schema contracts and versions.

    Schemas are keyed by destination name, then by version string. Storage is a class
    attribute, so everything in the process shares one registry.
    """

    # destination name -> {version -> schema class}
    _schemas: dict[str, dict[str, type[BaseModel]]] = {}

    @classmethod
    def register(cls, name: str, version: str):
        """Create a class decorator that records a schema under ``name``/``version``.

        Re-registering an existing name/version pair replaces the stored class and logs a warning.

        Args:
            name (str): The destination/exporter name (e.g., "elasticsearch")
            version (str): The version string (e.g., "1.0", "1.1")

        Returns:
            A decorator that stores the decorated class and returns it unchanged.
        """

        def decorator(schema_cls: type[T]) -> type[T]:
            versions = cls._schemas.setdefault(name, {})
            if version in versions:
                logger.warning("Overriding existing schema for %s:%s", name, version)
            versions[version] = schema_cls
            logger.debug("Registered schema %s for %s:%s", schema_cls.__name__, name, version)
            return schema_cls

        return decorator

    @classmethod
    def get_schema(cls, name: str, version: str) -> type[BaseModel]:
        """Look up the schema class registered for a destination and version.

        Args:
            name (str): The destination/exporter name (e.g., "elasticsearch")
            version (str): The version string to look up

        Returns:
            type[BaseModel]: The Pydantic model class for the requested destination and version

        Raises:
            KeyError: If the destination, or the version within it, is not registered. The message
                lists what is available.
        """
        versions = cls._schemas.get(name)
        if versions is None:
            raise KeyError(f"Destination '{name}' not found. "
                           f"Available destinations: {list(cls._schemas)}")
        if version not in versions:
            raise KeyError(f"Version '{version}' not found for destination '{name}'. "
                           f"Available versions: {list(versions)}")
        return versions[version]

    @classmethod
    def get_available_schemas(cls) -> list[str]:
        """Return every registered combination as a ``"name:version"`` string, in registration order.

        Returns:
            list[str]: Registered schema keys in "name:version" format
        """
        return [f"{name}:{version}" for name, versions in cls._schemas.items() for version in versions]

    @classmethod
    def get_schemas_for_destination(cls, name: str) -> list[str]:
        """Return the registered version strings for one destination (empty if unknown).

        Args:
            name (str): The destination/exporter name

        Returns:
            list[str]: Version strings for the specified destination
        """
        return list(cls._schemas.get(name, {}))

    @classmethod
    def get_available_destinations(cls) -> list[str]:
        """Return the names of all destinations that have at least one registration.

        Returns:
            list[str]: Registered destination names
        """
        return list(cls._schemas)

    @classmethod
    def is_registered(cls, name: str, version: str) -> bool:
        """Report whether a name:version combination has been registered.

        Args:
            name (str): The destination/exporter name
            version (str): The version string to check

        Returns:
            bool: True if the name:version is registered, False otherwise
        """
        return version in cls._schemas.get(name, {})

    @classmethod
    def clear(cls) -> None:
        """Remove every registered schema (all destinations and versions)."""
        cls._schemas.clear()
|
136
|
+
|
137
|
+
|
138
|
+
# Convenience aliases for more concise usage: module-level handles to the SchemaRegistry
# classmethods (all share the same class-level registry; `clear` is intentionally not aliased).
(register_schema,
 get_schema,
 get_available_schemas,
 get_available_destinations,
 get_schemas_for_destination,
 is_registered) = (SchemaRegistry.register,
                   SchemaRegistry.get_schema,
                   SchemaRegistry.get_available_schemas,
                   SchemaRegistry.get_available_destinations,
                   SchemaRegistry.get_schemas_for_destination,
                   SchemaRegistry.is_registered)
|
@@ -0,0 +1,14 @@
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|