genai-otel-instrument 0.1.1.dev0__py3-none-any.whl → 0.1.4.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of genai-otel-instrument might be problematic.

@@ -1,11 +1,11 @@
 """OpenTelemetry instrumentor for the Cohere SDK.
 
 This instrumentor automatically traces calls to Cohere models, capturing
-relevant attributes such as the model name.
+relevant attributes such as the model name and token usage.
 """
 
 import logging
-from typing import Dict, Optional
+from typing import Any, Dict, Optional
 
 from ..config import OTelConfig
 from .base import BaseInstrumentor
@@ -34,7 +34,7 @@ class CohereInstrumentor(BaseInstrumentor):
         self._cohere_available = False
 
     def instrument(self, config: OTelConfig):
-        """Instrument cohere available if available."""
+        """Instrument cohere if available."""
        if not self._cohere_available:
            logger.debug("Skipping instrumentation - library not available")
            return
@@ -50,27 +50,91 @@ class CohereInstrumentor(BaseInstrumentor):
                 self._instrument_client(instance)
 
             cohere.Client.__init__ = wrapped_init
+            self._instrumented = True
+            logger.info("Cohere instrumentation enabled")
 
-        except ImportError:
-            pass
+        except Exception as e:
+            logger.error("Failed to instrument Cohere: %s", e, exc_info=True)
+            if config.fail_on_error:
+                raise
 
     def _instrument_client(self, client):
+        """Instrument Cohere client methods."""
         original_generate = client.generate
 
-        def wrapped_generate(*args, **kwargs):
-            with self.tracer.start_as_current_span("cohere.generate") as span:
-                model = kwargs.get("model", "command")
+        # Wrap using create_span_wrapper
+        wrapped_generate = self.create_span_wrapper(
+            span_name="cohere.generate",
+            extract_attributes=self._extract_generate_attributes,
+        )(original_generate)
 
-                span.set_attribute("gen_ai.system", "cohere")
-                span.set_attribute("gen_ai.request.model", model)
+        client.generate = wrapped_generate
 
-                if self.request_counter:
-                    self.request_counter.add(1, {"model": model, "provider": "cohere"})
+    def _extract_generate_attributes(self, instance: Any, args: Any, kwargs: Any) -> Dict[str, Any]:
+        """Extract attributes from Cohere generate call.
 
-                result = original_generate(*args, **kwargs)
-                return result
+        Args:
+            instance: The client instance.
+            args: Positional arguments.
+            kwargs: Keyword arguments.
 
-        client.generate = wrapped_generate
+        Returns:
+            Dict[str, Any]: Dictionary of attributes to set on the span.
+        """
+        attrs = {}
+        model = kwargs.get("model", "command")
+        prompt = kwargs.get("prompt", "")
+
+        attrs["gen_ai.system"] = "cohere"
+        attrs["gen_ai.request.model"] = model
+        attrs["gen_ai.operation.name"] = "generate"
+        attrs["gen_ai.request.message_count"] = 1 if prompt else 0
+
+        return attrs
 
     def _extract_usage(self, result) -> Optional[Dict[str, int]]:
-        return None
+        """Extract token usage from Cohere response.
+
+        Cohere responses include meta.tokens with:
+        - input_tokens: Input tokens
+        - output_tokens: Output tokens
+
+        Args:
+            result: The API response object.
+
+        Returns:
+            Optional[Dict[str, int]]: Dictionary with token counts or None.
+        """
+        try:
+            # Handle object response
+            if hasattr(result, "meta") and result.meta:
+                meta = result.meta
+                # Check for tokens object
+                if hasattr(meta, "tokens") and meta.tokens:
+                    tokens = meta.tokens
+                    input_tokens = getattr(tokens, "input_tokens", 0)
+                    output_tokens = getattr(tokens, "output_tokens", 0)
+
+                    if input_tokens or output_tokens:
+                        return {
+                            "prompt_tokens": int(input_tokens) if input_tokens else 0,
+                            "completion_tokens": int(output_tokens) if output_tokens else 0,
+                            "total_tokens": int(input_tokens or 0) + int(output_tokens or 0),
+                        }
+                # Fallback to billed_units
+                elif hasattr(meta, "billed_units") and meta.billed_units:
+                    billed = meta.billed_units
+                    input_tokens = getattr(billed, "input_tokens", 0)
+                    output_tokens = getattr(billed, "output_tokens", 0)
+
+                    if input_tokens or output_tokens:
+                        return {
+                            "prompt_tokens": int(input_tokens) if input_tokens else 0,
+                            "completion_tokens": int(output_tokens) if output_tokens else 0,
+                            "total_tokens": int(input_tokens or 0) + int(output_tokens or 0),
+                        }
+
+            return None
+        except Exception as e:
+            logger.debug("Failed to extract usage from Cohere response: %s", e)
+            return None
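
The new _extract_usage prefers meta.tokens and only falls back to meta.billed_units when no tokens object is present. A condensed sketch of that lookup, using types.SimpleNamespace stand-ins rather than real cohere SDK response objects:

    from types import SimpleNamespace

    # Hypothetical stand-in shaped like the meta.tokens path the instrumentor reads.
    mock_response = SimpleNamespace(
        meta=SimpleNamespace(
            tokens=SimpleNamespace(input_tokens=12, output_tokens=48),
            billed_units=None,
        )
    )

    def extract_usage(result):
        # Condensed version of the diff's logic: meta.tokens first, then billed_units.
        meta = getattr(result, "meta", None)
        counts = getattr(meta, "tokens", None) or getattr(meta, "billed_units", None)
        if counts is None:
            return None
        input_tokens = int(getattr(counts, "input_tokens", 0) or 0)
        output_tokens = int(getattr(counts, "output_tokens", 0) or 0)
        if not (input_tokens or output_tokens):
            return None
        return {
            "prompt_tokens": input_tokens,
            "completion_tokens": output_tokens,
            "total_tokens": input_tokens + output_tokens,
        }

    print(extract_usage(mock_response))
    # -> {'prompt_tokens': 12, 'completion_tokens': 48, 'total_tokens': 60}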
@@ -1,11 +1,14 @@
-"""OpenTelemetry instrumentor for HuggingFace Transformers library.
+"""OpenTelemetry instrumentor for HuggingFace Transformers and Inference API.
 
-This instrumentor automatically traces calls made through HuggingFace pipelines,
-capturing relevant attributes such as the model name and task type.
+This instrumentor automatically traces:
+1. HuggingFace Transformers pipelines (local model execution)
+2. HuggingFace Inference API calls via InferenceClient (used by smolagents)
+
+Note: Transformers runs models locally (no API costs), but InferenceClient makes
+API calls to HuggingFace endpoints which may have costs based on usage.
 """
 
 import logging
-import types
 from typing import Dict, Optional
 
 from ..config import OTelConfig
@@ -15,16 +18,22 @@ logger = logging.getLogger(__name__)
 
 
 class HuggingFaceInstrumentor(BaseInstrumentor):
-    """Instrumentor for HuggingFace Transformers"""
+    """Instrumentor for HuggingFace Transformers and Inference API.
+
+    Instruments both:
+    - transformers.pipeline (local execution, no API costs)
+    - huggingface_hub.InferenceClient (API calls, may have costs)
+    """
 
     def __init__(self):
         """Initialize the instrumentor."""
         super().__init__()
         self._transformers_available = False
+        self._inference_client_available = False
         self._check_availability()
 
     def _check_availability(self):
-        """Check if Transformers library is available."""
+        """Check if Transformers and InferenceClient libraries are available."""
        try:
            import transformers
@@ -34,12 +43,47 @@ class HuggingFaceInstrumentor(BaseInstrumentor):
             logger.debug("Transformers library not installed, instrumentation will be skipped")
             self._transformers_available = False
 
+        try:
+            from huggingface_hub import InferenceClient
+
+            self._inference_client_available = True
+            logger.debug("HuggingFace InferenceClient detected and available for instrumentation")
+        except ImportError:
+            logger.debug("huggingface_hub not installed, InferenceClient instrumentation will be skipped")
+            self._inference_client_available = False
+
     def instrument(self, config: OTelConfig):
+        """Instrument HuggingFace Transformers pipelines and InferenceClient."""
         self.config = config
 
-        if not self._transformers_available:
-            return
-
+        instrumented_count = 0
+
+        # Instrument transformers.pipeline if available
+        if self._transformers_available:
+            try:
+                self._instrument_transformers()
+                instrumented_count += 1
+            except Exception as e:
+                logger.error("Failed to instrument HuggingFace Transformers: %s", e, exc_info=True)
+                if config.fail_on_error:
+                    raise
+
+        # Instrument InferenceClient if available
+        if self._inference_client_available:
+            try:
+                self._instrument_inference_client()
+                instrumented_count += 1
+            except Exception as e:
+                logger.error("Failed to instrument HuggingFace InferenceClient: %s", e, exc_info=True)
+                if config.fail_on_error:
+                    raise
+
+        if instrumented_count > 0:
+            self._instrumented = True
+            logger.info(f"HuggingFace instrumentation enabled ({instrumented_count} components)")
+
+    def _instrument_transformers(self):
+        """Instrument transformers.pipeline for local model execution."""
         try:
             import importlib
@@ -68,6 +112,7 @@ class HuggingFaceInstrumentor(BaseInstrumentor):
 
                     span.set_attribute("gen_ai.system", "huggingface")
                     span.set_attribute("gen_ai.request.model", model)
+                    span.set_attribute("gen_ai.operation.name", task)
                     span.set_attribute("huggingface.task", task)
 
                     if instrumentor.request_counter:
@@ -88,10 +133,90 @@ class HuggingFaceInstrumentor(BaseInstrumentor):
                 return WrappedPipeline(pipe)
 
             transformers_module.pipeline = wrapped_pipeline
-            logger.info("HuggingFace instrumentation enabled")
-
-        except ImportError:
-            pass
+            logger.debug("HuggingFace Transformers pipeline instrumented")
+
+        except Exception as e:
+            raise  # Re-raise to be caught by instrument() method
+
+    def _instrument_inference_client(self):
+        """Instrument HuggingFace InferenceClient for API calls."""
+        from huggingface_hub import InferenceClient
+
+        # Store original methods
+        original_chat_completion = InferenceClient.chat_completion
+        original_text_generation = InferenceClient.text_generation
+
+        # Wrap chat_completion method
+        wrapped_chat_completion = self.create_span_wrapper(
+            span_name="huggingface.inference.chat_completion",
+            extract_attributes=self._extract_inference_client_attributes,
+        )(original_chat_completion)
+
+        # Wrap text_generation method
+        wrapped_text_generation = self.create_span_wrapper(
+            span_name="huggingface.inference.text_generation",
+            extract_attributes=self._extract_inference_client_attributes,
+        )(original_text_generation)
+
+        InferenceClient.chat_completion = wrapped_chat_completion
+        InferenceClient.text_generation = wrapped_text_generation
+        logger.debug("HuggingFace InferenceClient instrumented")
+
+    def _extract_inference_client_attributes(self, instance, args, kwargs) -> Dict[str, str]:
+        """Extract attributes from Inference API call."""
+        attrs = {}
+        model = kwargs.get("model") or (args[0] if args else "unknown")
+
+        attrs["gen_ai.system"] = "huggingface"
+        attrs["gen_ai.request.model"] = str(model)
+        attrs["gen_ai.operation.name"] = "chat"  # Default to chat
+
+        # Extract parameters if available
+        if "max_tokens" in kwargs:
+            attrs["gen_ai.request.max_tokens"] = kwargs["max_tokens"]
+        if "temperature" in kwargs:
+            attrs["gen_ai.request.temperature"] = kwargs["temperature"]
+        if "top_p" in kwargs:
+            attrs["gen_ai.request.top_p"] = kwargs["top_p"]
+
+        return attrs
 
     def _extract_usage(self, result) -> Optional[Dict[str, int]]:
+        """Extract token usage from HuggingFace response.
+
+        Handles both:
+        1. Transformers pipeline (local execution) - returns None
+        2. InferenceClient API calls - extracts token usage from response
+
+        Args:
+            result: The pipeline output or InferenceClient response.
+
+        Returns:
+            Dict with token counts for InferenceClient calls, None for local execution.
+        """
+        # Check if this is an InferenceClient API response
+        if result is not None and hasattr(result, "usage"):
+            usage = result.usage
+
+            # Extract token counts from usage object
+            prompt_tokens = getattr(usage, "prompt_tokens", None)
+            completion_tokens = getattr(usage, "completion_tokens", None)
+            total_tokens = getattr(usage, "total_tokens", None)
+
+            # If usage is a dict instead of object
+            if isinstance(usage, dict):
+                prompt_tokens = usage.get("prompt_tokens")
+                completion_tokens = usage.get("completion_tokens")
+                total_tokens = usage.get("total_tokens")
+
+            # Return token counts if available
+            if prompt_tokens is not None or completion_tokens is not None:
+                return {
+                    "prompt_tokens": prompt_tokens or 0,
+                    "completion_tokens": completion_tokens or 0,
+                    "total_tokens": total_tokens or (prompt_tokens or 0) + (completion_tokens or 0),
+                }
+
+        # HuggingFace Transformers is free (local execution)
+        # No token-based costs to track
         return None
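
InferenceClient responses may expose usage either as an attribute-style object or as a plain dict, which is why the new _extract_usage probes both. A small sketch of that normalization with hypothetical inputs:

    from types import SimpleNamespace

    def normalize_usage(usage):
        # Read dicts with .get and objects with getattr, as in the diff.
        get = usage.get if isinstance(usage, dict) else lambda k: getattr(usage, k, None)
        prompt = get("prompt_tokens")
        completion = get("completion_tokens")
        if prompt is None and completion is None:
            return None
        return {
            "prompt_tokens": prompt or 0,
            "completion_tokens": completion or 0,
            "total_tokens": get("total_tokens") or (prompt or 0) + (completion or 0),
        }

    print(normalize_usage({"prompt_tokens": 9, "completion_tokens": 21}))
    print(normalize_usage(SimpleNamespace(prompt_tokens=9, completion_tokens=21, total_tokens=30)))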
@@ -10,6 +10,7 @@ Supports Mistral SDK v1.0+ with the new API structure:
 """
 
 import logging
+import time
 from typing import Any, Dict, Optional
 
 from ..config import OTelConfig
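
The hunk below replaces per-instance __init__ patching with class-level patching via wrapt.wrap_function_wrapper, so every client created afterwards is traced. A minimal sketch of that wrapt pattern; the Chat class here is a local stand-in, not the Mistral SDK:

    import wrapt

    class Chat:
        def complete(self, model):
            return f"response from {model}"

    def traced_complete(wrapped, instance, args, kwargs):
        # wrapt passes the original callable, the bound instance, and the call args.
        print(f"tracing {type(instance).__name__}.complete")
        return wrapped(*args, **kwargs)

    # Patch by module path and dotted attribute name, as the instrumentor does.
    wrapt.wrap_function_wrapper(__name__, "Chat.complete", traced_complete)

    print(Chat().complete(model="demo-model"))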
@@ -27,50 +28,261 @@ class MistralAIInstrumentor(BaseInstrumentor):
             import wrapt
             from mistralai import Mistral
 
-            # Wrap the Mistral client __init__ to instrument each instance
-            original_init = Mistral.__init__
+            # Get access to the chat and embeddings modules
+            # In Mistral SDK v1.0+, structure is:
+            # - Mistral client has .chat and .embeddings properties
+            # - These are bound methods that call internal APIs
 
-            def wrapped_init(wrapped, instance, args, kwargs):
-                result = wrapped(*args, **kwargs)
-                self._instrument_client(instance)
-                return result
-
-            Mistral.__init__ = wrapt.FunctionWrapper(original_init, wrapped_init)
-            logger.info("MistralAI instrumentation enabled (v1.0+ SDK)")
+            # Store original methods at module level before any instances are created
+            if not hasattr(Mistral, '_genai_otel_instrumented'):
+                self._wrap_mistral_methods(Mistral, wrapt)
+                Mistral._genai_otel_instrumented = True
+                logger.info("MistralAI instrumentation enabled (v1.0+ SDK)")
 
         except ImportError:
             logger.warning("mistralai package not available, skipping instrumentation")
         except Exception as e:
             logger.error(f"Failed to instrument mistralai: {e}", exc_info=True)
+            if config.fail_on_error:
+                raise
+
+    def _wrap_mistral_methods(self, Mistral, wrapt):
+        """Wrap Mistral client methods at the class level."""
+        # Import the internal classes that handle chat and embeddings
+        try:
+            from mistralai.chat import Chat
+            from mistralai.embeddings import Embeddings
+
+            # Wrap Chat.complete method
+            if hasattr(Chat, 'complete'):
+                wrapt.wrap_function_wrapper(
+                    'mistralai.chat',
+                    'Chat.complete',
+                    self._wrap_chat_complete
+                )
+                logger.debug("Wrapped Mistral Chat.complete")
+
+            # Wrap Chat.stream method
+            if hasattr(Chat, 'stream'):
+                wrapt.wrap_function_wrapper(
+                    'mistralai.chat',
+                    'Chat.stream',
+                    self._wrap_chat_stream
+                )
+                logger.debug("Wrapped Mistral Chat.stream")
+
+            # Wrap Embeddings.create method
+            if hasattr(Embeddings, 'create'):
+                wrapt.wrap_function_wrapper(
+                    'mistralai.embeddings',
+                    'Embeddings.create',
+                    self._wrap_embeddings_create
+                )
+                logger.debug("Wrapped Mistral Embeddings.create")
+
+        except (ImportError, AttributeError) as e:
+            logger.warning(f"Could not access Mistral internal classes: {e}")
+
+    def _wrap_chat_complete(self, wrapped, instance, args, kwargs):
+        """Wrapper for chat.complete() method."""
+        model = kwargs.get("model", "mistral-small-latest")
+        span_name = f"mistralai.chat.complete {model}"
+
+        with self.tracer.start_span(span_name) as span:
+            # Set attributes
+            attributes = self._extract_chat_attributes(instance, args, kwargs)
+            for key, value in attributes.items():
+                span.set_attribute(key, value)
+
+            # Record request metric
+            if self.request_counter:
+                self.request_counter.add(1, {"model": model, "provider": "mistralai"})
+
+            # Execute the call
+            start_time = time.time()
+            try:
+                response = wrapped(*args, **kwargs)
+
+                # Record metrics from response
+                self._record_result_metrics(span, response, start_time, kwargs)
+
+                return response
+
+            except Exception as e:
+                if self.error_counter:
+                    self.error_counter.add(
+                        1, {"operation": span_name, "error.type": type(e).__name__}
+                    )
+                span.record_exception(e)
+                raise
+
+    def _wrap_chat_stream(self, wrapped, instance, args, kwargs):
+        """Wrapper for chat.stream() method - handles streaming responses."""
+        model = kwargs.get("model", "mistral-small-latest")
+        span_name = f"mistralai.chat.stream {model}"
+
+        # Start the span
+        span = self.tracer.start_span(span_name)
+
+        # Set attributes
+        attributes = self._extract_chat_attributes(instance, args, kwargs)
+        for key, value in attributes.items():
+            span.set_attribute(key, value)
+
+        # Record request metric
+        if self.request_counter:
+            self.request_counter.add(1, {"model": model, "provider": "mistralai"})
+
+        start_time = time.time()
+
+        # Execute and get the stream
+        try:
+            stream = wrapped(*args, **kwargs)
+
+            # Wrap the stream with our tracking wrapper
+            return self._StreamWrapper(
+                stream, span, self, model, start_time, span_name
+            )
+
+        except Exception as e:
+            if self.error_counter:
+                self.error_counter.add(
+                    1, {"operation": span_name, "error.type": type(e).__name__}
+                )
+            span.record_exception(e)
+            span.end()
+            raise
+
+    def _wrap_embeddings_create(self, wrapped, instance, args, kwargs):
+        """Wrapper for embeddings.create() method."""
+        model = kwargs.get("model", "mistral-embed")
+        span_name = f"mistralai.embeddings.create {model}"
+
+        with self.tracer.start_span(span_name) as span:
+            # Set attributes
+            attributes = self._extract_embeddings_attributes(instance, args, kwargs)
+            for key, value in attributes.items():
+                span.set_attribute(key, value)
+
+            # Record request metric
+            if self.request_counter:
+                self.request_counter.add(1, {"model": model, "provider": "mistralai"})
+
+            # Execute the call
+            start_time = time.time()
+            try:
+                response = wrapped(*args, **kwargs)
+
+                # Record metrics from response
+                self._record_result_metrics(span, response, start_time, kwargs)
+
+                return response
+
+            except Exception as e:
+                if self.error_counter:
+                    self.error_counter.add(
+                        1, {"operation": span_name, "error.type": type(e).__name__}
+                    )
+                span.record_exception(e)
+                raise
+
+    class _StreamWrapper:
+        """Wrapper for streaming responses that collects metrics."""
+
+        def __init__(self, stream, span, instrumentor, model, start_time, span_name):
+            self._stream = stream
+            self._span = span
+            self._instrumentor = instrumentor
+            self._model = model
+            self._start_time = start_time
+            self._span_name = span_name
+            self._usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
+            self._response_text = ""
+            self._first_chunk = True
+            self._ttft = None
+
+        def __iter__(self):
+            return self
+
+        def __next__(self):
+            try:
+                chunk = next(self._stream)
+
+                # Record time to first token
+                if self._first_chunk:
+                    self._ttft = time.time() - self._start_time
+                    self._first_chunk = False
+
+                # Process chunk to extract usage and content
+                self._process_chunk(chunk)
+
+                return chunk
+
+            except StopIteration:
+                # Stream completed - record final metrics
+                try:
+                    # Set TTFT if we got any chunks
+                    if self._ttft is not None:
+                        self._span.set_attribute("gen_ai.server.ttft", self._ttft)
+
+                    # Record usage metrics if available
+                    if self._usage["total_tokens"] > 0:
+                        # Create a mock response object with usage for _record_result_metrics
+                        class MockUsage:
+                            def __init__(self, usage_dict):
+                                self.prompt_tokens = usage_dict["prompt_tokens"]
+                                self.completion_tokens = usage_dict["completion_tokens"]
+                                self.total_tokens = usage_dict["total_tokens"]
+
+                        class MockResponse:
+                            def __init__(self, usage_dict):
+                                self.usage = MockUsage(usage_dict)
+
+                        mock_response = MockResponse(self._usage)
+                        self._instrumentor._record_result_metrics(
+                            self._span,
+                            mock_response,
+                            self._start_time,
+                            {"model": self._model}
+                        )
+
+                finally:
+                    self._span.end()
+
+                raise
+
+        def _process_chunk(self, chunk):
+            """Process a streaming chunk to extract usage."""
+            try:
+                # Mistral streaming chunks have: data.choices[0].delta.content
+                if hasattr(chunk, 'data'):
+                    data = chunk.data
+                    if hasattr(data, 'choices') and len(data.choices) > 0:
+                        delta = data.choices[0].delta
+                        if hasattr(delta, 'content') and delta.content:
+                            self._response_text += delta.content
+
+                    # Extract usage if available on final chunk
+                    if hasattr(data, 'usage') and data.usage:
+                        usage = data.usage
+                        if hasattr(usage, 'prompt_tokens'):
+                            self._usage["prompt_tokens"] = usage.prompt_tokens
+                        if hasattr(usage, 'completion_tokens'):
+                            self._usage["completion_tokens"] = usage.completion_tokens
+                        if hasattr(usage, 'total_tokens'):
+                            self._usage["total_tokens"] = usage.total_tokens
+
+            except Exception as e:
+                logger.debug(f"Error processing Mistral stream chunk: {e}")
+
+        def __enter__(self):
+            return self
 
-    def _instrument_client(self, client):
-        """Instrument Mistral client instance methods."""
-        # Instrument chat.complete()
-        if hasattr(client, "chat") and hasattr(client.chat, "complete"):
-            original_complete = client.chat.complete
-            instrumented_complete = self.create_span_wrapper(
-                span_name="mistralai.chat.complete",
-                extract_attributes=self._extract_chat_attributes,
-            )(original_complete)
-            client.chat.complete = instrumented_complete
-
-        # Instrument chat.stream()
-        if hasattr(client, "chat") and hasattr(client.chat, "stream"):
-            original_stream = client.chat.stream
-            instrumented_stream = self.create_span_wrapper(
-                span_name="mistralai.chat.stream",
-                extract_attributes=self._extract_chat_attributes,
-            )(original_stream)
-            client.chat.stream = instrumented_stream
-
-        # Instrument embeddings.create()
-        if hasattr(client, "embeddings") and hasattr(client.embeddings, "create"):
-            original_embeddings = client.embeddings.create
-            instrumented_embeddings = self.create_span_wrapper(
-                span_name="mistralai.embeddings.create",
-                extract_attributes=self._extract_embeddings_attributes,
-            )(original_embeddings)
-            client.embeddings.create = instrumented_embeddings
+        def __exit__(self, exc_type, exc_val, exc_tb):
+            if exc_type is not None:
+                self._span.record_exception(exc_val)
+            self._span.end()
+            return False
 
     def _extract_chat_attributes(self, instance: Any, args: Any, kwargs: Any) -> Dict[str, Any]:
         """Extract attributes from chat.complete() or chat.stream() call."""