ioa-observe-sdk 1.0.14__py3-none-any.whl → 1.0.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,494 @@
1
+ # Copyright AGNTCY Contributors (https://github.com/agntcy)
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from contextlib import asynccontextmanager
5
+ from dataclasses import dataclass
6
+ from typing import Any, AsyncGenerator, Callable, Collection, Tuple, cast, Union
7
+ import json
8
+ import logging
9
+ import traceback
10
+ import re
11
+ from http import HTTPStatus
12
+
13
+ from opentelemetry import context, propagate
14
+ from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
15
+ from opentelemetry.instrumentation.utils import unwrap
16
+ from opentelemetry.trace import get_tracer, Tracer
17
+ from wrapt import ObjectProxy, register_post_import_hook, wrap_function_wrapper
18
+ from opentelemetry.trace.status import Status, StatusCode
19
+ from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
20
+ from opentelemetry.semconv.attributes.error_attributes import ERROR_TYPE
21
+
22
+ from ..utils.const import (
23
+ MCP_REQUEST_ID,
24
+ MCP_METHOD_NAME,
25
+ MCP_REQUEST_ARGUMENT,
26
+ MCP_RESPONSE_VALUE,
27
+ OBSERVE_ENTITY_OUTPUT,
28
+ OBSERVE_ENTITY_INPUT,
29
+ )
30
+ from ..version import __version__
31
+
32
+ _instruments = ("mcp >= 1.6.0",)
33
+
34
+
35
class Config:
    """Module-wide settings for the MCP instrumentation."""

    # Optional callable; when set, ``dont_throw`` calls it with every
    # exception it swallows. ``None`` means no callback is invoked.
    exception_logger = None
37
+
38
+
39
def dont_throw(func):
    """
    A decorator that wraps the passed in function and logs exceptions instead of throwing them.

    The wrapper calls ``func`` synchronously. When ``func`` is an ``async def``
    (or async generator) function, that call only creates the coroutine /
    generator object, so errors raised while it is awaited or iterated are
    NOT intercepted here.

    @param func: The function to wrap
    @return: The wrapper function; it returns ``func``'s result, or ``None``
             when ``func`` raised an ``Exception``
    """
    import functools

    # Obtain a logger specific to the function's module
    logger = logging.getLogger(func.__module__)

    # functools.wraps keeps __name__/__doc__/__wrapped__ so the decorated
    # function stays recognisable in logs, tracebacks and introspection.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            logger.debug(
                "failed to trace in %s, error: %s",
                func.__name__,
                traceback.format_exc(),
            )
            if Config.exception_logger:
                Config.exception_logger(e)

    return wrapper
62
+
63
+
64
class McpInstrumentor(BaseInstrumentor):
    """Instrument the ``mcp`` package with OpenTelemetry tracing.

    * Transport factories are wrapped so the streams they yield are replaced
      by ``InstrumentedStreamReader`` / ``InstrumentedStreamWriter`` proxies.
    * ``ServerSession.__init__`` is wrapped so the session's incoming-message
      streams are replaced by ``ContextAttachingStreamReader`` /
      ``ContextSavingStreamWriter`` proxies.
    * ``BaseSession.send_request`` is wrapped so each request runs inside its
      own span and carries ``session.id`` / ``traceparent`` metadata.
    """

    # (module, owning class or None, attribute) of every transport factory
    # whose yielded streams get instrumented.
    _TRANSPORT_TARGETS = (
        ("mcp.client.sse", None, "sse_client"),
        ("mcp.server.sse", "SseServerTransport", "connect_sse"),
        ("mcp.client.stdio", None, "stdio_client"),
        ("mcp.server.stdio", None, "stdio_server"),
        ("mcp.client.streamable_http", None, "streamablehttp_client"),
        ("mcp.server.streamable_http", "StreamableHTTPServerTransport", "connect"),
    )
    _SESSION_INIT_TARGET = ("mcp.server.session", "ServerSession", "__init__")
    _SEND_REQUEST_TARGET = ("mcp.shared.session", "BaseSession", "send_request")

    def instrumentation_dependencies(self) -> Collection[str]:
        """Return the package requirement(s) this instrumentor applies to."""
        return _instruments

    def _instrument(self, **kwargs):
        """Install all wrappers.

        Transport and ``ServerSession`` wrappers are installed through wrapt
        post-import hooks, keyed by module name; ``BaseSession.send_request``
        is wrapped immediately.

        :param kwargs: may contain ``tracer_provider`` used to build the tracer.
        """
        tracer_provider = kwargs.get("tracer_provider")
        tracer = get_tracer(__name__, __version__, tracer_provider)

        def make_hook(module_name, owner, attr, wrapper_factory):
            # A dedicated factory function binds the loop variables per hook
            # (a bare lambda in the loop would late-bind them).
            dotted = f"{owner}.{attr}" if owner else attr

            def hook(_module):
                wrap_function_wrapper(module_name, dotted, wrapper_factory(tracer))

            return hook

        for module_name, owner, attr in self._TRANSPORT_TARGETS:
            register_post_import_hook(
                make_hook(module_name, owner, attr, self._transport_wrapper),
                module_name,
            )

        module_name, owner, attr = self._SESSION_INIT_TARGET
        register_post_import_hook(
            make_hook(module_name, owner, attr, self._base_session_init_wrapper),
            module_name,
        )

        module_name, owner, attr = self._SEND_REQUEST_TARGET
        wrap_function_wrapper(
            module_name,
            f"{owner}.{attr}",
            self.patch_mcp_client(tracer),
        )

    def _uninstrument(self, **kwargs):
        """Remove every wrapper installed by ``_instrument``.

        Only modules that are already imported are touched, so uninstrumenting
        never imports an ``mcp`` submodule as a side effect.

        NOTE(review): post-import hooks registered for modules that were never
        imported are not removed here; confirm whether that matters for
        re-instrumentation scenarios.
        """
        import sys

        targets = (
            *self._TRANSPORT_TARGETS,
            self._SESSION_INIT_TARGET,
            self._SEND_REQUEST_TARGET,
        )
        for module_name, owner, attr in targets:
            module = sys.modules.get(module_name)
            if module is None:
                continue
            holder = getattr(module, owner, None) if owner else module
            if holder is not None:
                unwrap(holder, attr)

    @staticmethod
    def _wrap_streams(result, tracer):
        """Return ``result`` with its read/write streams replaced by proxies.

        ``result`` is expected to be ``(read, write)`` or
        ``(read, write, get_session_id_callback)``. Anything else is logged
        and returned unchanged so the transport keeps working uninstrumented.
        """
        try:
            read_stream, write_stream, *extra = result
            if len(extra) > 1:
                raise ValueError(
                    f"expected 2 or 3 values from the transport, got {len(extra) + 2}"
                )
            return (
                InstrumentedStreamReader(read_stream, tracer),
                InstrumentedStreamWriter(write_stream, tracer),
                *extra,
            )
        except Exception as e:
            logging.warning(
                f"mcp instrumentation _transport_wrapper exception: {e}"
            )
            return result

    def _transport_wrapper(self, tracer):
        """Build the wrapt wrapper used for all transport context managers."""

        @asynccontextmanager
        async def traced_method(
            wrapped: Callable[..., Any], instance: Any, args: Any, kwargs: Any
        ) -> AsyncGenerator[
            Union[
                Tuple[InstrumentedStreamReader, InstrumentedStreamWriter],
                Tuple[InstrumentedStreamReader, InstrumentedStreamWriter, Any],
            ],
            None,
        ]:
            async with wrapped(*args, **kwargs) as result:
                # The single ``yield`` is deliberately NOT inside a try/except:
                # an exception raised in the caller's ``async with`` body is
                # thrown in at this point and must propagate to the wrapped
                # transport's ``__aexit__``. Catching it and yielding again
                # would end in "generator didn't stop after athrow()".
                yield self._wrap_streams(result, tracer)

        return traced_method

    def _base_session_init_wrapper(self, tracer):
        """Build the wrapt wrapper for ``ServerSession.__init__``."""

        def traced_method(
            wrapped: Callable[..., None], instance: Any, args: Any, kwargs: Any
        ) -> None:
            wrapped(*args, **kwargs)
            reader = getattr(instance, "_incoming_message_stream_reader", None)
            writer = getattr(instance, "_incoming_message_stream_writer", None)
            # Replace the pair only when both ends exist; the reader unwraps
            # what the writer wraps, so they must be swapped together.
            if reader and writer:
                setattr(
                    instance,
                    "_incoming_message_stream_reader",
                    ContextAttachingStreamReader(reader, tracer),
                )
                setattr(
                    instance,
                    "_incoming_message_stream_writer",
                    ContextSavingStreamWriter(writer, tracer),
                )

        return traced_method

    @staticmethod
    def _record_result(span, result):
        """Copy the outcome of a request onto ``span``.

        Sets the output attribute, then either an ERROR status (for results
        flagged ``isError`` with non-empty content) or an OK status.
        """
        span.set_attribute(OBSERVE_ENTITY_OUTPUT, serialize(result))
        if hasattr(result, "isError") and result.isError:
            content = getattr(result, "content", None) or []
            if len(content) > 0:
                # Error content is not guaranteed to expose ``.text``.
                text = getattr(content[0], "text", None)
                if text is None:
                    span.set_status(Status(StatusCode.ERROR))
                else:
                    span.set_status(Status(StatusCode.ERROR, f"{text}"))
                error_type = get_error_type(text)
                if error_type is not None:
                    span.set_attribute(ERROR_TYPE, error_type)
        else:
            span.set_status(Status(StatusCode.OK))

    def patch_mcp_client(self, tracer: Tracer):
        """Build the wrapt wrapper for ``BaseSession.send_request``.

        The wrapper opens a ``"<method>.mcp"`` span, records the serialized
        request, stores ``session.id`` / ``traceparent`` in the request
        params' ``meta`` (when the request has params), awaits the wrapped
        call and records its result. Exceptions from the wrapped call are
        recorded on the span and re-raised.
        """

        # NOTE: ``dont_throw`` calls this coroutine function synchronously, so
        # it does not swallow exceptions raised while the coroutine runs.
        @dont_throw
        async def traced_method(wrapped, instance, args, kwargs):
            # The request may be passed positionally or as ``request=``.
            request = args[0] if args else kwargs.get("request")
            root = getattr(request, "root", None)
            method = getattr(root, "method", None)
            params = getattr(root, "params", None)
            meta = getattr(params, "meta", None) if params else None

            with tracer.start_as_current_span(f"{method}.mcp") as span:
                span.set_attribute(OBSERVE_ENTITY_INPUT, f"{serialize(request)}")
                from ioa_observe.sdk.client import kv_store
                from ioa_observe.sdk.tracing import get_current_traceparent

                traceparent = get_current_traceparent()
                session_id = None
                if traceparent:
                    session_id = kv_store.get(f"execution.{traceparent}")
                    if session_id:
                        kv_store.set(f"execution.{traceparent}", session_id)

                meta = meta or {}
                if isinstance(meta, dict):
                    meta["session.id"] = session_id
                    meta["traceparent"] = traceparent
                else:
                    # NOTE(review): a non-dict meta object is replaced by a
                    # fresh dict, so any fields it carried are not forwarded.
                    meta = {
                        "session.id": session_id,
                        "traceparent": traceparent,
                    }

                # Requests without params have nowhere to carry the metadata;
                # skip rather than fail on ``None.meta = ...``.
                if params is not None:
                    try:
                        params.meta = meta
                    except Exception as e:
                        logging.debug(
                            "mcp instrumentation could not attach request meta: %s", e
                        )

                try:
                    result = await wrapped(*args, **kwargs)
                except Exception as e:
                    span.set_attribute(ERROR_TYPE, type(e).__name__)
                    span.record_exception(e)
                    span.set_status(Status(StatusCode.ERROR, str(e)))
                    raise

                # Annotating the span is best-effort: a failure here must not
                # turn a completed request into an exception for the caller.
                try:
                    self._record_result(span, result)
                except Exception as e:
                    logging.debug(
                        "mcp instrumentation could not record result: %s", e
                    )
                return result

        return traced_method
259
+
260
+
261
def get_error_type(error_message):
    """Extract an HTTP error name from a free-text error message.

    Scans ``error_message`` for standalone three-digit numbers in the 4xx/5xx
    range and returns the ``http.HTTPStatus`` member name of the first one
    that is a defined status code (e.g. ``"NOT_FOUND"`` for 404).

    :param error_message: text to scan; non-string values yield ``None``.
    :return: the status name, or ``None`` when no defined 4xx/5xx code occurs.
    """
    if not isinstance(error_message, str):
        return None
    for match in re.finditer(r"\b(4\d{2}|5\d{2})\b", error_message):
        try:
            return HTTPStatus(int(match.group())).name
        except ValueError:
            # Number looks like a status code but is not a defined one
            # (e.g. 499); keep scanning instead of raising.
            continue
    return None
273
+
274
+
275
def serialize(request, depth=0, max_depth=4):
    """Serialize input args to MCP server into JSON.

    The function accepts an input object and converts it into a JSON string,
    keeping depth in mind to prevent creating large nested JSON.

    Strategy, in order:
      1. values ``json.dumps`` accepts are dumped directly;
      2. objects exposing ``model_dump_json`` use it;
      3. otherwise the public (non-underscore) entries of ``__dict__`` are
         collected; non-primitive values are serialized recursively and
         embedded as JSON *strings*.

    :param request: object to serialize.
    :param depth: current recursion depth (callers normally leave this at 0).
    :param max_depth: recursion limit; it is honoured by nested calls too.
    :return: a JSON string, or an empty dict ``{}`` when ``depth`` already
             exceeds ``max_depth``.
    """
    if depth > max_depth:
        return {}
    depth += 1

    # Try the direct encoding once and reuse its output.
    try:
        return json.dumps(request)
    except Exception:
        pass  # not directly encodable; fall through to the attribute walk

    primitives = (bool, str, int, float, type(None))
    result = {}
    try:
        if hasattr(request, "model_dump_json"):
            return request.model_dump_json()
        if hasattr(request, "__dict__"):
            for attrib, value in request.__dict__.items():
                if attrib.startswith("_"):
                    continue
                # Exact type match on purpose: subclasses go through the
                # recursive path.
                if type(value) in primitives:
                    result[str(attrib)] = value
                else:
                    result[str(attrib)] = serialize(value, depth, max_depth)
    except Exception:
        # Best-effort: keep whatever was collected so far.
        logging.debug("mcp instrumentation serialize failed", exc_info=True)
    return json.dumps(result)
315
+
316
+
317
class InstrumentedStreamReader(ObjectProxy):  # type: ignore
    """Proxy around a transport read stream that restores trace context.

    While iterating, every JSON-RPC request carrying ``_meta`` with a
    ``traceparent`` is yielded with the extracted OpenTelemetry context
    attached; all other items are yielded unchanged.
    """

    # Declared at class scope so the assignment in ``__init__`` is stored on
    # the proxy itself. Without it, ObjectProxy forwards the attribute
    # assignment to the wrapped stream object.
    _tracer = None

    def __init__(self, wrapped, tracer):
        super().__init__(wrapped)
        self._tracer = tracer

    async def __aenter__(self) -> Any:
        return await self.__wrapped__.__aenter__()

    async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any:
        return await self.__wrapped__.__aexit__(exc_type, exc_value, traceback)

    @dont_throw
    async def __aiter__(self) -> AsyncGenerator[Any, None]:
        from mcp.types import JSONRPCMessage, JSONRPCRequest
        from mcp.shared.message import SessionMessage

        async for item in self.__wrapped__:
            if isinstance(item, SessionMessage):
                request = cast(JSONRPCMessage, item.message).root
            elif type(item) is JSONRPCMessage:
                request = cast(JSONRPCMessage, item).root
            else:
                yield item
                continue

            if not isinstance(request, JSONRPCRequest):
                yield item
                continue

            if request.params:
                # Check both _meta and meta fields
                meta = request.params.get("_meta") or getattr(
                    request.params, "meta", None
                )

                if meta:
                    if isinstance(meta, dict):
                        session_id = meta.get("session.id")
                        traceparent = meta.get("traceparent")
                        carrier = meta
                    else:
                        session_id = getattr(meta, "session.id", None)
                        traceparent = getattr(meta, "traceparent", None)
                        # Convert object to dict for propagate.extract
                        carrier = {}
                        if session_id:
                            carrier["session.id"] = session_id
                        if traceparent:
                            # The carrier must contain the traceparent itself,
                            # otherwise extract() has nothing to read.
                            carrier["traceparent"] = traceparent

                    if carrier and traceparent:
                        ctx = propagate.extract(carrier)

                        # Add session_id extraction and storage like in a2a.py
                        if session_id and session_id != "None":
                            from ioa_observe.sdk.client import kv_store
                            from ioa_observe.sdk.tracing import set_session_id

                            set_session_id(session_id, traceparent=traceparent)
                            kv_store.set(f"execution.{traceparent}", session_id)

                        # Keep the extracted context attached for as long as
                        # the consumer is processing this item.
                        restore = context.attach(ctx)
                        try:
                            yield item
                            continue
                        finally:
                            context.detach(restore)
            yield item
383
+
384
+
385
class InstrumentedStreamWriter(ObjectProxy):  # type: ignore
    """Proxy around a transport write stream that traces outgoing messages.

    Each JSON-RPC message is sent inside a ``ResponseStreamWriter`` span;
    requests additionally get the current trace context injected into
    ``params["_meta"]``. Every item is always forwarded to the wrapped stream.
    """

    # Declared at class scope so the assignment in ``__init__`` is stored on
    # the proxy itself. Without it, ObjectProxy forwards the attribute
    # assignment to the wrapped stream object.
    _tracer = None

    def __init__(self, wrapped, tracer):
        super().__init__(wrapped)
        self._tracer = tracer

    async def __aenter__(self) -> Any:
        return await self.__wrapped__.__aenter__()

    async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any:
        return await self.__wrapped__.__aexit__(exc_type, exc_value, traceback)

    @staticmethod
    def _annotate_span(span, request):
        """Record request id, response value and error status on ``span``."""
        if hasattr(request, "id"):
            span.set_attribute(MCP_REQUEST_ID, f"{request.id}")
        if hasattr(request, "result"):
            result = request.result
            span.set_attribute(MCP_RESPONSE_VALUE, f"{serialize(result)}")
            if "isError" in result and result["isError"] is True:
                # Error content may be empty or non-textual; never index blindly.
                content = result.get("content") or []
                first = content[0] if content else None
                text = first.get("text") if isinstance(first, dict) else None
                if text is None:
                    span.set_status(Status(StatusCode.ERROR))
                else:
                    span.set_status(Status(StatusCode.ERROR, f"{text}"))
                error_type = get_error_type(text)
                if error_type is not None:
                    span.set_attribute(ERROR_TYPE, error_type)

    # NOTE: ``dont_throw`` calls this coroutine function synchronously, so it
    # does not swallow exceptions raised while the coroutine runs.
    @dont_throw
    async def send(self, item: Any) -> Any:
        from mcp.types import JSONRPCMessage, JSONRPCRequest
        from mcp.shared.message import SessionMessage

        if isinstance(item, SessionMessage):
            request = cast(JSONRPCMessage, item.message).root
        elif type(item) is JSONRPCMessage:
            request = cast(JSONRPCMessage, item).root
        else:
            # Not a JSON-RPC message: nothing to trace, but it must still be
            # delivered rather than dropped.
            return await self.__wrapped__.send(item)

        with self._tracer.start_as_current_span("ResponseStreamWriter") as span:
            # Span annotation is best-effort; it must never stop the message
            # from being sent.
            try:
                self._annotate_span(span, request)
            except Exception as e:
                logging.debug(
                    "mcp instrumentation could not annotate outgoing message: %s", e
                )

            if not isinstance(request, JSONRPCRequest):
                return await self.__wrapped__.send(item)
            if not request.params:
                request.params = {}
            meta = request.params.setdefault("_meta", {})

            # Inject inside the span so the receiver continues this trace.
            propagate.get_global_textmap().inject(meta)
            return await self.__wrapped__.send(item)
436
+
437
+
438
@dataclass(frozen=True, slots=True)
class ItemWithContext:
    """Immutable pair of a stream item and an OpenTelemetry context."""

    # The original object that was handed to the stream writer.
    item: Any
    # The context to attach while the item is being consumed.
    ctx: context.Context
442
+
443
+
444
class ContextSavingStreamWriter(ObjectProxy):  # type: ignore
    """Proxy around a write stream that sends items with their trace context.

    Every item is sent inside a ``RequestStreamWriter`` span and wrapped in an
    ``ItemWithContext`` holding the context current at send time.
    """

    # Declared at class scope so the assignment in ``__init__`` is stored on
    # the proxy itself. Without it, ObjectProxy forwards the attribute
    # assignment to the wrapped stream object.
    _tracer = None

    def __init__(self, wrapped, tracer):
        super().__init__(wrapped)
        self._tracer = tracer

    async def __aenter__(self) -> Any:
        return await self.__wrapped__.__aenter__()

    async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any:
        return await self.__wrapped__.__aexit__(exc_type, exc_value, traceback)

    # NOTE: ``dont_throw`` calls this coroutine function synchronously, so it
    # does not swallow exceptions raised while the coroutine runs.
    @dont_throw
    async def send(self, item: Any) -> Any:
        with self._tracer.start_as_current_span("RequestStreamWriter") as span:
            if hasattr(item, "request_id"):
                span.set_attribute(MCP_REQUEST_ID, f"{item.request_id}")

            # item.request.root may be absent at any level; resolve it once.
            root = getattr(getattr(item, "request", None), "root", None)
            if root is not None:
                if hasattr(root, "method"):
                    span.set_attribute(MCP_METHOD_NAME, f"{root.method}")
                if hasattr(root, "params"):
                    span.set_attribute(
                        MCP_REQUEST_ARGUMENT,
                        f"{serialize(root.params)}",
                    )

            # Captured inside the span so the consumer's work is parented to it.
            ctx = context.get_current()
            return await self.__wrapped__.send(ItemWithContext(item, ctx))
474
+
475
+
476
class ContextAttachingStreamReader(ObjectProxy):  # type: ignore
    """Proxy around a read stream that unwraps ``ItemWithContext`` items.

    For each ``ItemWithContext`` received, the stored context is attached
    while the inner item is yielded and detached afterwards. Items of any
    other type are yielded unchanged.
    """

    # Declared at class scope so the assignment in ``__init__`` is stored on
    # the proxy itself. Without it, ObjectProxy forwards the attribute
    # assignment to the wrapped stream object.
    _tracer = None

    def __init__(self, wrapped, tracer):
        super().__init__(wrapped)
        self._tracer = tracer

    async def __aenter__(self) -> Any:
        return await self.__wrapped__.__aenter__()

    async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any:
        return await self.__wrapped__.__aexit__(exc_type, exc_value, traceback)

    async def __aiter__(self) -> AsyncGenerator[Any, None]:
        async for item in self.__wrapped__:
            if not isinstance(item, ItemWithContext):
                # Item did not pass through the context-saving writer; hand it
                # over as-is instead of failing on a missing ``.ctx``.
                yield item
                continue
            restore = context.attach(item.ctx)
            try:
                yield item.item
            finally:
                context.detach(restore)