ioa-observe-sdk 1.0.15__py3-none-any.whl → 1.0.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ioa_observe/sdk/__init__.py +3 -3
- ioa_observe/sdk/client/client.py +3 -3
- ioa_observe/sdk/decorators/__init__.py +2 -2
- ioa_observe/sdk/decorators/base.py +2 -2
- ioa_observe/sdk/instrumentations/a2a.py +84 -31
- ioa_observe/sdk/instrumentations/mcp.py +494 -0
- ioa_observe/sdk/instrumentations/slim.py +376 -128
- ioa_observe/sdk/tracing/tracing.py +214 -9
- ioa_observe/sdk/tracing/transform_span.py +210 -0
- ioa_observe/sdk/utils/const.py +7 -0
- {ioa_observe_sdk-1.0.15.dist-info → ioa_observe_sdk-1.0.17.dist-info}/METADATA +3 -1
- {ioa_observe_sdk-1.0.15.dist-info → ioa_observe_sdk-1.0.17.dist-info}/RECORD +15 -13
- {ioa_observe_sdk-1.0.15.dist-info → ioa_observe_sdk-1.0.17.dist-info}/WHEEL +0 -0
- {ioa_observe_sdk-1.0.15.dist-info → ioa_observe_sdk-1.0.17.dist-info}/licenses/LICENSE.md +0 -0
- {ioa_observe_sdk-1.0.15.dist-info → ioa_observe_sdk-1.0.17.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,494 @@
|
|
|
1
|
+
# Copyright AGNTCY Contributors (https://github.com/agntcy)
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from contextlib import asynccontextmanager
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Any, AsyncGenerator, Callable, Collection, Tuple, cast, Union
|
|
7
|
+
import json
|
|
8
|
+
import logging
|
|
9
|
+
import traceback
|
|
10
|
+
import re
|
|
11
|
+
from http import HTTPStatus
|
|
12
|
+
|
|
13
|
+
from opentelemetry import context, propagate
|
|
14
|
+
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
|
|
15
|
+
from opentelemetry.instrumentation.utils import unwrap
|
|
16
|
+
from opentelemetry.trace import get_tracer, Tracer
|
|
17
|
+
from wrapt import ObjectProxy, register_post_import_hook, wrap_function_wrapper
|
|
18
|
+
from opentelemetry.trace.status import Status, StatusCode
|
|
19
|
+
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
|
|
20
|
+
from opentelemetry.semconv.attributes.error_attributes import ERROR_TYPE
|
|
21
|
+
|
|
22
|
+
from ..utils.const import (
|
|
23
|
+
MCP_REQUEST_ID,
|
|
24
|
+
MCP_METHOD_NAME,
|
|
25
|
+
MCP_REQUEST_ARGUMENT,
|
|
26
|
+
MCP_RESPONSE_VALUE,
|
|
27
|
+
OBSERVE_ENTITY_OUTPUT,
|
|
28
|
+
OBSERVE_ENTITY_INPUT,
|
|
29
|
+
)
|
|
30
|
+
from ..version import __version__
|
|
31
|
+
|
|
32
|
+
# Package requirement(s) reported by McpInstrumentor.instrumentation_dependencies().
_instruments = ("mcp >= 1.6.0",)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class Config:
    """Module-wide settings for the MCP instrumentation.

    ``exception_logger`` is ``None`` by default. When it is set to a callable,
    ``dont_throw`` passes it every exception it suppresses.
    """

    exception_logger = None
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def dont_throw(func):
    """
    A decorator that wraps the passed in function and logs exceptions instead of throwing them.

    When an exception is suppressed, the wrapper returns ``None``. The
    traceback is logged at DEBUG level on the logger of ``func``'s module. The
    exception is also passed to ``Config.exception_logger`` when that is set.

    The wrapper carries ``func``'s metadata (``__name__``, ``__qualname__``,
    ``__doc__``, ``__wrapped__``) so wrapped callables stay identifiable.

    Caveat: only the synchronous call is guarded. For an ``async def`` function
    or an async generator, calling it merely creates the coroutine or generator
    object. Exceptions raised later, while it is awaited or iterated, still
    propagate to the caller.

    @param func: The function to wrap
    @return: The wrapper function
    """
    import functools

    # Obtain a logger specific to the function's module
    logger = logging.getLogger(func.__module__)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            logger.debug(
                "failed to trace in %s, error: %s",
                func.__name__,
                traceback.format_exc(),
            )
            if Config.exception_logger:
                Config.exception_logger(e)

    return wrapper
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class McpInstrumentor(BaseInstrumentor):
    """OpenTelemetry instrumentor for the ``mcp`` (Model Context Protocol) SDK.

    ``_instrument`` installs three kinds of wrappers:

    * The SSE, stdio and streamable-HTTP client/server transport entry points
      are wrapped. The streams they yield are replaced by
      ``InstrumentedStreamReader`` / ``InstrumentedStreamWriter`` proxies.
    * ``ServerSession.__init__`` is wrapped. The session's internal
      incoming-message stream pair is replaced by context-saving and
      context-attaching proxies.
    * ``BaseSession.send_request`` is wrapped. Each outgoing request gets one
      span, and session/trace ids are stamped into the request's
      ``params.meta``.
    """

    # (module, dotted attribute) pairs wrapped with ``_transport_wrapper``.
    # Both ``_instrument`` and ``_uninstrument`` read these, so the two
    # methods cannot drift apart.
    _TRANSPORT_TARGETS = (
        ("mcp.client.sse", "sse_client"),
        ("mcp.server.sse", "SseServerTransport.connect_sse"),
        ("mcp.client.stdio", "stdio_client"),
        ("mcp.server.stdio", "stdio_server"),
        ("mcp.client.streamable_http", "streamablehttp_client"),
        ("mcp.server.streamable_http", "StreamableHTTPServerTransport.connect"),
    )
    _SESSION_INIT_TARGET = ("mcp.server.session", "ServerSession.__init__")
    _SEND_REQUEST_TARGET = ("mcp.shared.session", "BaseSession.send_request")

    def instrumentation_dependencies(self) -> Collection[str]:
        """Return the package requirement(s) this instrumentor targets."""
        return _instruments

    def _instrument(self, **kwargs):
        """Install all wrappers, using the ``tracer_provider`` kwarg if given."""
        tracer_provider = kwargs.get("tracer_provider")
        tracer = get_tracer(__name__, __version__, tracer_provider)

        # Transport and session modules are patched through wrapt post-import
        # hooks rather than imported here.
        for module_name, attr_path in self._TRANSPORT_TARGETS:
            self._wrap_on_import(
                module_name, attr_path, self._transport_wrapper(tracer)
            )
        self._wrap_on_import(
            *self._SESSION_INIT_TARGET, self._base_session_init_wrapper(tracer)
        )

        # send_request is patched right away.
        wrap_function_wrapper(
            *self._SEND_REQUEST_TARGET, self.patch_mcp_client(tracer)
        )

    @staticmethod
    def _wrap_on_import(module_name, attr_path, wrapper):
        """Wrap ``module_name.attr_path`` with ``wrapper`` when the module is imported."""
        # The arguments are bound per call, so each hook patches its own
        # target. A lambda created directly inside the loop in _instrument
        # would late-bind to the last target.
        register_post_import_hook(
            lambda _module: wrap_function_wrapper(module_name, attr_path, wrapper),
            module_name,
        )

    def _uninstrument(self, **kwargs):
        """Restore every target patched by ``_instrument``.

        Importing a module here fires any still-pending post-import hook first.
        The target therefore ends up unwrapped instead of being patched on a
        later import. Modules that cannot be imported are skipped.
        """
        targets = self._TRANSPORT_TARGETS + (
            self._SESSION_INIT_TARGET,
            self._SEND_REQUEST_TARGET,
        )
        for module_name, attr_path in targets:
            self._unwrap_target(module_name, attr_path)

    @staticmethod
    def _unwrap_target(module_name, attr_path):
        """Unwrap one patched attribute; log and skip it if it is unavailable."""
        import importlib

        try:
            owner = importlib.import_module(module_name)
            *parents, attr = attr_path.split(".")
            for parent in parents:
                owner = getattr(owner, parent)
        except (ImportError, AttributeError) as e:
            logging.getLogger(__name__).debug(
                "mcp instrumentation: cannot unwrap %s.%s: %s",
                module_name,
                attr_path,
                e,
            )
            return
        unwrap(owner, attr)

    def _transport_wrapper(self, tracer):
        """Build a wrapt wrapper for transport async context managers.

        The wrapped transport yields a ``(read, write)`` or
        ``(read, write, get_session_id_callback)`` tuple. The wrapper yields the
        same shape with the streams replaced by instrumented proxies. Any other
        result is logged and yielded unchanged.
        """

        def instrument_streams(result):
            # Return ``result`` with its streams wrapped, or ``result`` itself
            # if it has neither of the expected tuple shapes.
            try:
                read_stream, write_stream = result
                return (
                    InstrumentedStreamReader(read_stream, tracer),
                    InstrumentedStreamWriter(write_stream, tracer),
                )
            except ValueError:
                try:
                    read_stream, write_stream, get_session_id_callback = result
                    return (
                        InstrumentedStreamReader(read_stream, tracer),
                        InstrumentedStreamWriter(write_stream, tracer),
                        get_session_id_callback,
                    )
                except Exception as e:
                    logging.warning(
                        f"mcp instrumentation _transport_wrapper exception: {e}"
                    )
                    return result
            except Exception as e:
                logging.warning(
                    f"mcp instrumentation transport_wrapper exception: {e}"
                )
                return result

        @asynccontextmanager
        async def traced_method(
            wrapped: Callable[..., Any], instance: Any, args: Any, kwargs: Any
        ) -> AsyncGenerator[
            Union[
                Tuple[InstrumentedStreamReader, InstrumentedStreamWriter],
                Tuple[InstrumentedStreamReader, InstrumentedStreamWriter, Any],
            ],
            None,
        ]:
            async with wrapped(*args, **kwargs) as result:
                # The streams are wrapped before yielding, and the yield sits
                # outside any exception handler. An exception from the caller's
                # ``async with`` body is thrown in at this yield. It must
                # propagate unchanged rather than trigger a second yield.
                yield instrument_streams(result)

        return traced_method

    def _base_session_init_wrapper(self, tracer):
        """Build a wrapper for ``ServerSession.__init__``.

        After the original ``__init__`` runs, the session's
        ``_incoming_message_stream_reader`` / ``_incoming_message_stream_writer``
        are replaced with ``ContextAttachingStreamReader`` /
        ``ContextSavingStreamWriter`` proxies, provided both attributes are set.
        """

        def traced_method(
            wrapped: Callable[..., None], instance: Any, args: Any, kwargs: Any
        ) -> None:
            wrapped(*args, **kwargs)
            reader = getattr(instance, "_incoming_message_stream_reader", None)
            writer = getattr(instance, "_incoming_message_stream_writer", None)
            if reader and writer:
                setattr(
                    instance,
                    "_incoming_message_stream_reader",
                    ContextAttachingStreamReader(reader, tracer),
                )
                setattr(
                    instance,
                    "_incoming_message_stream_writer",
                    ContextSavingStreamWriter(writer, tracer),
                )

        return traced_method

    def patch_mcp_client(self, tracer: Tracer):
        """Build the wrapt wrapper for ``BaseSession.send_request``.

        Each outgoing request gets a ``"<method>.mcp"`` span that records the
        serialized request and response. The current traceparent, and the
        session id stored for it in ``kv_store``, are written into the
        request's ``params.meta``.

        Failures inside the instrumentation are logged at DEBUG level. Errors
        raised by the request itself are recorded on the span and re-raised
        unchanged.
        """
        logger = logging.getLogger(__name__)

        def prepare_request(span, request, params):
            # Record the request on the span, then stamp the trace ids into
            # its meta.
            if request is not None:
                span.set_attribute(OBSERVE_ENTITY_INPUT, f"{serialize(request)}")

            from ioa_observe.sdk.client import kv_store
            from ioa_observe.sdk.tracing import get_current_traceparent

            traceparent = get_current_traceparent()
            session_id = None
            if traceparent:
                session_id = kv_store.get(f"execution.{traceparent}")
                if session_id:
                    kv_store.set(f"execution.{traceparent}", session_id)

            # A request without params has no meta slot. Assigning
            # ``params.meta`` there would raise AttributeError.
            if not params:
                return
            meta = getattr(params, "meta", None)
            if isinstance(meta, dict):
                merged = meta
            elif meta is not None and hasattr(meta, "model_dump"):
                # Keep any fields already set on the meta model instead of
                # discarding them.
                merged = meta.model_dump(by_alias=True, exclude_none=True)
            else:
                merged = {}
            merged["session.id"] = session_id
            merged["traceparent"] = traceparent
            params.meta = merged

        def record_result(span, result):
            # Record the response and derive the span status from ``isError``.
            span.set_attribute(OBSERVE_ENTITY_OUTPUT, serialize(result))
            if getattr(result, "isError", False):
                content = getattr(result, "content", None) or []
                if len(content) > 0:
                    # Content items without a ``text`` attribute yield None.
                    text = getattr(content[0], "text", None)
                    span.set_status(Status(StatusCode.ERROR, f"{text}"))
                    error_type = get_error_type(text)
                    if error_type is not None:
                        span.set_attribute(ERROR_TYPE, error_type)
            else:
                span.set_status(Status(StatusCode.OK))

        # NOTE: dont_throw only guards creation of the coroutine. Errors inside
        # it are handled explicitly below.
        @dont_throw
        async def traced_method(wrapped, instance, args, kwargs):
            request = args[0] if len(args) > 0 else None
            root = getattr(request, "root", None)
            method = getattr(root, "method", None)
            params = getattr(root, "params", None)

            with tracer.start_as_current_span(f"{method}.mcp") as span:
                try:
                    prepare_request(span, request, params)
                except Exception:
                    logger.debug(
                        "mcp instrumentation: failed to prepare request",
                        exc_info=True,
                    )

                try:
                    result = await wrapped(*args, **kwargs)
                except Exception as e:
                    span.set_attribute(ERROR_TYPE, type(e).__name__)
                    span.record_exception(e)
                    span.set_status(Status(StatusCode.ERROR, str(e)))
                    raise

                try:
                    record_result(span, result)
                except Exception:
                    logger.debug(
                        "mcp instrumentation: failed to record result",
                        exc_info=True,
                    )
                return result

        return traced_method
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
def get_error_type(error_message):
    """Map the first 4xx/5xx status code found in ``error_message`` to its name.

    Returns the ``http.HTTPStatus`` member name (e.g. ``"NOT_FOUND"`` for 404).
    Returns ``None`` when:

    * the input is not a string,
    * it contains no standalone 4xx/5xx number, or
    * the first such number is not a status that ``HTTPStatus`` defines
      (e.g. 499).
    """
    if not isinstance(error_message, str):
        return None
    match = re.search(r"\b(4\d{2}|5\d{2})\b", error_message)
    if match is None:
        return None
    try:
        return HTTPStatus(int(match.group())).name
    except ValueError:
        # The regex admits codes that HTTPStatus does not define. Raising here
        # would break the request/response path that calls this helper.
        return None
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
def serialize(request, depth=0, max_depth=4):
    """Serialize input args to MCP server into JSON.

    Converts the input object into a JSON string, bounding recursion depth to
    avoid building large nested JSON.

    * JSON-serializable values are dumped directly.
    * Objects exposing ``model_dump_json`` (e.g. pydantic models) use it.
    * Other objects are reduced to their public (non-underscore) ``__dict__``
      attributes, recursively.

    Nested values are embedded as JSON values, not as re-encoded JSON strings.
    Subtrees deeper than ``max_depth`` become ``{}``. Serialization errors are
    swallowed, and whatever was collected so far is returned. The return value
    is always a ``str``.
    """

    def to_jsonable(obj, level):
        # Return a json.dumps-compatible stand-in for ``obj`` at nesting ``level``.
        if level > max_depth:
            return {}
        try:
            json.dumps(obj)
            return obj
        except Exception:
            pass
        if hasattr(obj, "model_dump_json"):
            try:
                return json.loads(obj.model_dump_json())
            except Exception:
                return {}
        return public_fields(obj, level)

    def public_fields(obj, level):
        # Collect ``obj``'s public attributes; keep a partial result on failure.
        fields = {}
        try:
            if hasattr(obj, "__dict__"):
                for name, value in obj.__dict__.items():
                    if not name.startswith("_"):
                        fields[str(name)] = to_jsonable(value, level + 1)
        except Exception:
            pass
        return fields

    if depth > max_depth:
        return json.dumps({})
    try:
        return json.dumps(request)
    except Exception:
        pass
    try:
        if hasattr(request, "model_dump_json"):
            return request.model_dump_json()
    except Exception:
        return json.dumps({})
    return json.dumps(public_fields(request, depth))
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
class InstrumentedStreamReader(ObjectProxy):  # type: ignore
    """Proxy over a transport read stream that restores the sender's trace context.

    Iterating the proxy yields the wrapped stream's items unchanged. Some
    items are JSON-RPC requests whose meta carries a ``traceparent``. For
    those, the extracted context stays attached while the consumer handles
    the item, and the accompanying session id is registered with the SDK.
    """

    def __init__(self, wrapped, tracer):
        super().__init__(wrapped)
        # wrapt forwards ordinary attribute writes on a proxy to the wrapped
        # object. Only ``_self_``-prefixed names live on the proxy itself.
        self._self_tracer = tracer

    async def __aenter__(self) -> Any:
        return await self.__wrapped__.__aenter__()

    async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any:
        return await self.__wrapped__.__aexit__(exc_type, exc_value, traceback)

    @staticmethod
    def _context_from_request(request):
        """Return the trace context carried in ``request``'s meta, or ``None``.

        As a side effect, a non-empty session id found in the meta (other than
        the string ``"None"``) is passed to ``set_session_id`` and stored in
        ``kv_store`` under ``execution.<traceparent>``.
        """
        params = request.params
        if not params:
            return None
        # Check both _meta and meta fields
        meta = params.get("_meta") if isinstance(params, dict) else None
        meta = meta or getattr(params, "meta", None)
        if not meta:
            return None

        if isinstance(meta, dict):
            session_id = meta.get("session.id")
            traceparent = meta.get("traceparent")
            carrier = meta
        else:
            session_id = getattr(meta, "session.id", None)
            traceparent = getattr(meta, "traceparent", None)
            # Convert the object into a dict carrier for propagate.extract.
            # traceparent must be included, or extraction finds nothing.
            carrier = {}
            if session_id:
                carrier["session.id"] = session_id
            if traceparent:
                carrier["traceparent"] = traceparent

        if not (carrier and traceparent):
            return None
        ctx = propagate.extract(carrier)

        # Add session_id extraction and storage like in a2a.py
        if session_id and session_id != "None":
            from ioa_observe.sdk.client import kv_store
            from ioa_observe.sdk.tracing import set_session_id

            set_session_id(session_id, traceparent=traceparent)
            kv_store.set(f"execution.{traceparent}", session_id)
        return ctx

    @dont_throw
    async def __aiter__(self) -> AsyncGenerator[Any, None]:
        from mcp.types import JSONRPCMessage, JSONRPCRequest
        from mcp.shared.message import SessionMessage

        async for item in self.__wrapped__:
            if isinstance(item, SessionMessage):
                request = cast(JSONRPCMessage, item.message).root
            elif type(item) is JSONRPCMessage:
                request = cast(JSONRPCMessage, item).root
            else:
                yield item
                continue

            ctx = None
            if isinstance(request, JSONRPCRequest):
                try:
                    ctx = self._context_from_request(request)
                except Exception:
                    # A malformed meta must not terminate the transport's
                    # read loop; deliver the item without remote context.
                    logging.getLogger(__name__).debug(
                        "mcp instrumentation: failed to extract request context",
                        exc_info=True,
                    )

            if ctx is None:
                yield item
                continue

            # Keep the remote context attached while the consumer processes
            # this item; detach before fetching the next one.
            restore = context.attach(ctx)
            try:
                yield item
            finally:
                context.detach(restore)
|
|
383
|
+
|
|
384
|
+
|
|
385
|
+
class InstrumentedStreamWriter(ObjectProxy):  # type: ignore
    """Proxy over a transport write stream that traces outgoing JSON-RPC messages.

    Each JSON-RPC message sent gets a ``ResponseStreamWriter`` span that
    records its id, its result and any ``isError`` outcome. For requests, the
    current context is injected into ``params._meta``.

    Instrumentation failures are logged. They never prevent the item from
    being forwarded to the wrapped stream.
    """

    def __init__(self, wrapped, tracer):
        super().__init__(wrapped)
        # wrapt forwards ordinary attribute writes on a proxy to the wrapped
        # object. Only ``_self_``-prefixed names live on the proxy itself.
        self._self_tracer = tracer

    async def __aenter__(self) -> Any:
        return await self.__wrapped__.__aenter__()

    async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any:
        return await self.__wrapped__.__aexit__(exc_type, exc_value, traceback)

    @staticmethod
    def _annotate_span(span, request):
        """Record id, result and error status of an outgoing JSON-RPC message."""
        if hasattr(request, "id"):
            span.set_attribute(MCP_REQUEST_ID, f"{request.id}")
        if not hasattr(request, "result"):
            return
        result = request.result
        span.set_attribute(MCP_RESPONSE_VALUE, f"{serialize(result)}")
        if "isError" in result and result["isError"] is True:
            # Missing or non-text content must not raise (a raise would abort
            # delivery of the response).
            content = result.get("content") or []
            first = content[0] if len(content) > 0 else None
            text = first.get("text") if isinstance(first, dict) else None
            span.set_status(Status(StatusCode.ERROR, f"{text}"))
            error_type = get_error_type(text)
            if error_type is not None:
                span.set_attribute(ERROR_TYPE, error_type)

    @dont_throw
    async def send(self, item: Any) -> Any:
        from mcp.types import JSONRPCMessage, JSONRPCRequest
        from mcp.shared.message import SessionMessage

        if isinstance(item, SessionMessage):
            request = cast(JSONRPCMessage, item.message).root
        elif type(item) is JSONRPCMessage:
            request = cast(JSONRPCMessage, item).root
        else:
            # Not a recognized JSON-RPC payload: forward it untraced rather
            # than silently dropping it.
            return await self.__wrapped__.send(item)

        logger = logging.getLogger(__name__)
        with self._self_tracer.start_as_current_span("ResponseStreamWriter") as span:
            try:
                self._annotate_span(span, request)
            except Exception:
                logger.debug(
                    "mcp instrumentation: failed to annotate outgoing message",
                    exc_info=True,
                )

            if isinstance(request, JSONRPCRequest):
                try:
                    if not request.params:
                        request.params = {}
                    meta = request.params.setdefault("_meta", {})
                    propagate.get_global_textmap().inject(meta)
                except Exception:
                    logger.debug(
                        "mcp instrumentation: failed to inject trace context",
                        exc_info=True,
                    )
            return await self.__wrapped__.send(item)
|
|
436
|
+
|
|
437
|
+
|
|
438
|
+
@dataclass(frozen=True, slots=True)
class ItemWithContext:
    """Immutable pairing of a queued message with the context it was sent under.

    Built by ``ContextSavingStreamWriter.send`` and unpacked by
    ``ContextAttachingStreamReader.__aiter__``.
    """

    # The message object that was handed to ``send``.
    item: Any
    # OpenTelemetry context that was current when ``item`` was sent.
    ctx: context.Context
|
|
442
|
+
|
|
443
|
+
|
|
444
|
+
class ContextSavingStreamWriter(ObjectProxy):  # type: ignore
    """Proxy over ``ServerSession``'s incoming-message writer that captures context.

    ``send`` opens a ``RequestStreamWriter`` span and annotates it with the
    item's request id, method and params when they are present. It then
    forwards ``ItemWithContext(item, ctx)``, where ``ctx`` is the context
    current inside that span. ``ContextAttachingStreamReader`` restores that
    context on the reading side.
    """

    def __init__(self, wrapped, tracer):
        super().__init__(wrapped)
        # wrapt forwards ordinary attribute writes on a proxy to the wrapped
        # object. Only ``_self_``-prefixed names live on the proxy itself.
        self._self_tracer = tracer

    async def __aenter__(self) -> Any:
        return await self.__wrapped__.__aenter__()

    async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any:
        return await self.__wrapped__.__aexit__(exc_type, exc_value, traceback)

    @staticmethod
    def _annotate_span(span, item):
        """Copy request id, method and params of ``item`` onto ``span``."""
        if hasattr(item, "request_id"):
            span.set_attribute(MCP_REQUEST_ID, f"{item.request_id}")
        if not hasattr(item, "request"):
            return
        if not hasattr(item.request, "root"):
            return
        root = item.request.root
        if hasattr(root, "method"):
            span.set_attribute(MCP_METHOD_NAME, f"{root.method}")
        if hasattr(root, "params"):
            span.set_attribute(MCP_REQUEST_ARGUMENT, f"{serialize(root.params)}")

    @dont_throw
    async def send(self, item: Any) -> Any:
        with self._self_tracer.start_as_current_span("RequestStreamWriter") as span:
            try:
                self._annotate_span(span, item)
            except Exception:
                # Annotation is best-effort; the item must still be queued.
                logging.getLogger(__name__).debug(
                    "mcp instrumentation: failed to annotate incoming item",
                    exc_info=True,
                )
            ctx = context.get_current()
            return await self.__wrapped__.send(ItemWithContext(item, ctx))
|
|
474
|
+
|
|
475
|
+
|
|
476
|
+
class ContextAttachingStreamReader(ObjectProxy):  # type: ignore
    """Proxy over ``ServerSession``'s incoming-message reader that restores context.

    Iteration unwraps ``ItemWithContext`` entries and yields the original item.
    The saved context stays attached while the consumer handles that item.
    Any other entry is yielded unchanged.
    """

    def __init__(self, wrapped, tracer):
        super().__init__(wrapped)
        # wrapt forwards ordinary attribute writes on a proxy to the wrapped
        # object. Only ``_self_``-prefixed names live on the proxy itself.
        self._self_tracer = tracer

    async def __aenter__(self) -> Any:
        return await self.__wrapped__.__aenter__()

    async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any:
        return await self.__wrapped__.__aexit__(exc_type, exc_value, traceback)

    async def __aiter__(self) -> AsyncGenerator[Any, None]:
        async for item in self.__wrapped__:
            if not isinstance(item, ItemWithContext):
                # ContextSavingStreamWriter only overrides ``send``. Entries
                # queued any other way carry no saved context, so pass them
                # through instead of failing on ``.ctx``.
                yield item
                continue
            restore = context.attach(item.ctx)
            try:
                yield item.item
            finally:
                context.detach(restore)
|