mistralai 1.9.10__py3-none-any.whl → 1.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. mistralai/_hooks/registration.py +5 -0
  2. mistralai/_hooks/tracing.py +50 -0
  3. mistralai/_version.py +3 -3
  4. mistralai/accesses.py +51 -116
  5. mistralai/agents.py +58 -85
  6. mistralai/audio.py +8 -3
  7. mistralai/basesdk.py +15 -5
  8. mistralai/batch.py +6 -3
  9. mistralai/beta.py +10 -5
  10. mistralai/chat.py +70 -97
  11. mistralai/classifiers.py +57 -144
  12. mistralai/conversations.py +435 -412
  13. mistralai/documents.py +156 -359
  14. mistralai/embeddings.py +21 -42
  15. mistralai/extra/observability/__init__.py +15 -0
  16. mistralai/extra/observability/otel.py +393 -0
  17. mistralai/extra/run/tools.py +28 -16
  18. mistralai/files.py +53 -176
  19. mistralai/fim.py +46 -73
  20. mistralai/fine_tuning.py +6 -3
  21. mistralai/jobs.py +49 -158
  22. mistralai/libraries.py +71 -178
  23. mistralai/mistral_agents.py +298 -179
  24. mistralai/mistral_jobs.py +51 -138
  25. mistralai/models/__init__.py +94 -5
  26. mistralai/models/agent.py +15 -2
  27. mistralai/models/agentconversation.py +11 -3
  28. mistralai/models/agentcreationrequest.py +6 -2
  29. mistralai/models/agents_api_v1_agents_deleteop.py +16 -0
  30. mistralai/models/agents_api_v1_agents_getop.py +40 -3
  31. mistralai/models/agents_api_v1_agents_listop.py +72 -2
  32. mistralai/models/agents_api_v1_conversations_deleteop.py +18 -0
  33. mistralai/models/agents_api_v1_conversations_listop.py +39 -2
  34. mistralai/models/agentscompletionrequest.py +21 -6
  35. mistralai/models/agentscompletionstreamrequest.py +21 -6
  36. mistralai/models/agentupdaterequest.py +18 -2
  37. mistralai/models/audiotranscriptionrequest.py +2 -0
  38. mistralai/models/batchjobin.py +10 -0
  39. mistralai/models/chatcompletionrequest.py +22 -5
  40. mistralai/models/chatcompletionstreamrequest.py +22 -5
  41. mistralai/models/conversationrequest.py +15 -4
  42. mistralai/models/conversationrestartrequest.py +50 -2
  43. mistralai/models/conversationrestartstreamrequest.py +50 -2
  44. mistralai/models/conversationstreamrequest.py +15 -4
  45. mistralai/models/documentout.py +26 -10
  46. mistralai/models/documentupdatein.py +24 -3
  47. mistralai/models/embeddingrequest.py +8 -8
  48. mistralai/models/files_api_routes_list_filesop.py +7 -0
  49. mistralai/models/fimcompletionrequest.py +8 -9
  50. mistralai/models/fimcompletionstreamrequest.py +8 -9
  51. mistralai/models/httpvalidationerror.py +11 -6
  52. mistralai/models/libraries_documents_list_v1op.py +15 -2
  53. mistralai/models/libraryout.py +10 -7
  54. mistralai/models/listfilesout.py +35 -4
  55. mistralai/models/mistralerror.py +26 -0
  56. mistralai/models/modelcapabilities.py +13 -4
  57. mistralai/models/modelconversation.py +8 -2
  58. mistralai/models/no_response_error.py +13 -0
  59. mistralai/models/ocrpageobject.py +26 -5
  60. mistralai/models/ocrrequest.py +17 -1
  61. mistralai/models/ocrtableobject.py +31 -0
  62. mistralai/models/prediction.py +4 -0
  63. mistralai/models/requestsource.py +7 -0
  64. mistralai/models/responseformat.py +4 -2
  65. mistralai/models/responseformats.py +0 -1
  66. mistralai/models/responsevalidationerror.py +25 -0
  67. mistralai/models/sdkerror.py +30 -14
  68. mistralai/models/sharingdelete.py +36 -5
  69. mistralai/models/sharingin.py +36 -5
  70. mistralai/models/sharingout.py +3 -3
  71. mistralai/models/toolexecutiondeltaevent.py +13 -4
  72. mistralai/models/toolexecutiondoneevent.py +13 -4
  73. mistralai/models/toolexecutionentry.py +9 -4
  74. mistralai/models/toolexecutionstartedevent.py +13 -4
  75. mistralai/models_.py +67 -212
  76. mistralai/ocr.py +33 -36
  77. mistralai/sdk.py +15 -2
  78. mistralai/transcriptions.py +21 -60
  79. mistralai/utils/__init__.py +18 -5
  80. mistralai/utils/eventstreaming.py +10 -0
  81. mistralai/utils/serializers.py +3 -2
  82. mistralai/utils/unmarshal_json_response.py +24 -0
  83. {mistralai-1.9.10.dist-info → mistralai-1.10.0.dist-info}/METADATA +89 -40
  84. {mistralai-1.9.10.dist-info → mistralai-1.10.0.dist-info}/RECORD +86 -75
  85. {mistralai-1.9.10.dist-info → mistralai-1.10.0.dist-info}/WHEEL +1 -1
  86. {mistralai-1.9.10.dist-info → mistralai-1.10.0.dist-info/licenses}/LICENSE +0 -0
mistralai/embeddings.py CHANGED
@@ -5,6 +5,7 @@ from mistralai import models, utils
  from mistralai._hooks import HookContext
  from mistralai.types import OptionalNullable, UNSET
  from mistralai.utils import get_security_from_env
+ from mistralai.utils.unmarshal_json_response import unmarshal_json_response
  from typing import Any, Mapping, Optional, Union


@@ -30,9 +31,9 @@ class Embeddings(BaseSDK):

          Embeddings

-         :param model: ID of the model to use.
-         :param inputs: Text to embed.
-         :param output_dimension: The dimension of the output embeddings.
+         :param model: The ID of the model to be used for embedding.
+         :param inputs: The text content to be embedded, can be a string or an array of strings for fast processing in bulk.
+         :param output_dimension: The dimension of the output embeddings when feature available. If not provided, a default output dimension will be used.
          :param output_dtype:
          :param encoding_format:
          :param retries: Override the default retry configuration for this method
@@ -102,31 +103,20 @@ class Embeddings(BaseSDK):

          response_data: Any = None
          if utils.match_response(http_res, "200", "application/json"):
-             return utils.unmarshal_json(http_res.text, models.EmbeddingResponse)
+             return unmarshal_json_response(models.EmbeddingResponse, http_res)
          if utils.match_response(http_res, "422", "application/json"):
-             response_data = utils.unmarshal_json(
-                 http_res.text, models.HTTPValidationErrorData
+             response_data = unmarshal_json_response(
+                 models.HTTPValidationErrorData, http_res
              )
-             raise models.HTTPValidationError(data=response_data)
+             raise models.HTTPValidationError(response_data, http_res)
          if utils.match_response(http_res, "4XX", "*"):
              http_res_text = utils.stream_to_text(http_res)
-             raise models.SDKError(
-                 "API error occurred", http_res.status_code, http_res_text, http_res
-             )
+             raise models.SDKError("API error occurred", http_res, http_res_text)
          if utils.match_response(http_res, "5XX", "*"):
              http_res_text = utils.stream_to_text(http_res)
-             raise models.SDKError(
-                 "API error occurred", http_res.status_code, http_res_text, http_res
-             )
+             raise models.SDKError("API error occurred", http_res, http_res_text)

-         content_type = http_res.headers.get("Content-Type")
-         http_res_text = utils.stream_to_text(http_res)
-         raise models.SDKError(
-             f"Unexpected response received (code: {http_res.status_code}, type: {content_type})",
-             http_res.status_code,
-             http_res_text,
-             http_res,
-         )
+         raise models.SDKError("Unexpected response received", http_res)

      async def create_async(
          self,
@@ -147,9 +137,9 @@ class Embeddings(BaseSDK):

          Embeddings

-         :param model: ID of the model to use.
-         :param inputs: Text to embed.
-         :param output_dimension: The dimension of the output embeddings.
+         :param model: The ID of the model to be used for embedding.
+         :param inputs: The text content to be embedded, can be a string or an array of strings for fast processing in bulk.
+         :param output_dimension: The dimension of the output embeddings when feature available. If not provided, a default output dimension will be used.
          :param output_dtype:
          :param encoding_format:
          :param retries: Override the default retry configuration for this method
@@ -219,28 +209,17 @@ class Embeddings(BaseSDK):

          response_data: Any = None
          if utils.match_response(http_res, "200", "application/json"):
-             return utils.unmarshal_json(http_res.text, models.EmbeddingResponse)
+             return unmarshal_json_response(models.EmbeddingResponse, http_res)
          if utils.match_response(http_res, "422", "application/json"):
-             response_data = utils.unmarshal_json(
-                 http_res.text, models.HTTPValidationErrorData
+             response_data = unmarshal_json_response(
+                 models.HTTPValidationErrorData, http_res
              )
-             raise models.HTTPValidationError(data=response_data)
+             raise models.HTTPValidationError(response_data, http_res)
          if utils.match_response(http_res, "4XX", "*"):
              http_res_text = await utils.stream_to_text_async(http_res)
-             raise models.SDKError(
-                 "API error occurred", http_res.status_code, http_res_text, http_res
-             )
+             raise models.SDKError("API error occurred", http_res, http_res_text)
          if utils.match_response(http_res, "5XX", "*"):
              http_res_text = await utils.stream_to_text_async(http_res)
-             raise models.SDKError(
-                 "API error occurred", http_res.status_code, http_res_text, http_res
-             )
+             raise models.SDKError("API error occurred", http_res, http_res_text)

-         content_type = http_res.headers.get("Content-Type")
-         http_res_text = await utils.stream_to_text_async(http_res)
-         raise models.SDKError(
-             f"Unexpected response received (code: {http_res.status_code}, type: {content_type})",
-             http_res.status_code,
-             http_res_text,
-             http_res,
-         )
+         raise models.SDKError("Unexpected response received", http_res)
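The visible change for SDK callers is the error path: response parsing now goes through the new `unmarshal_json_response` helper, and `SDKError` / `HTTPValidationError` are constructed from the raw `httpx.Response` rather than a separately threaded status code and body. A minimal sketch of calling code under the new version (the API key is a placeholder, and the exact attributes exposed on `SDKError` should be checked against `mistralai/models/sdkerror.py` in this release):

```python
from mistralai import Mistral, models

client = Mistral(api_key="YOUR_API_KEY")  # placeholder key

try:
    res = client.embeddings.create(
        model="mistral-embed",
        inputs=["Embed this sentence.", "And this one too."],
    )
    print(len(res.data), "embeddings returned")
except models.HTTPValidationError as e:
    # 422: the error now carries the raw httpx.Response alongside the parsed data.
    print("validation error:", e.data)
except models.SDKError as e:
    # 4XX/5XX and unexpected responses all raise SDKError with the response attached.
    print("API error:", e)
```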
mistralai/extra/observability/__init__.py ADDED
@@ -0,0 +1,15 @@
+ from contextlib import contextmanager
+
+ from opentelemetry import trace as otel_trace
+
+ from .otel import MISTRAL_SDK_OTEL_TRACER_NAME
+
+
+ @contextmanager
+ def trace(name: str, **kwargs):
+     tracer = otel_trace.get_tracer(MISTRAL_SDK_OTEL_TRACER_NAME)
+     with tracer.start_as_current_span(name, **kwargs) as span:
+         yield span
+
+
+ __all__ = ["trace"]
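This new package root re-exports a small `trace` helper, so user code can open custom spans on the SDK's tracer without handling OpenTelemetry setup directly. A minimal usage sketch (the span name and attribute key are arbitrary):

```python
from mistralai.extra.observability import trace

# Opens a span on the SDK's tracer; extra keyword arguments are
# forwarded to tracer.start_as_current_span().
with trace("my_workflow_step") as span:
    span.set_attribute("app.step", "preprocess")
```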
mistralai/extra/observability/otel.py ADDED
@@ -0,0 +1,393 @@
+ import copy
+ import json
+ import logging
+ import os
+ import traceback
+ from datetime import datetime, timezone
+ from enum import Enum
+ from typing import Optional, Tuple
+
+ import httpx
+ import opentelemetry.semconv._incubating.attributes.gen_ai_attributes as gen_ai_attributes
+ import opentelemetry.semconv._incubating.attributes.http_attributes as http_attributes
+ import opentelemetry.semconv.attributes.server_attributes as server_attributes
+ from opentelemetry import propagate, trace
+ from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
+ from opentelemetry.sdk.resources import SERVICE_NAME, Resource
+ from opentelemetry.sdk.trace import SpanProcessor, TracerProvider
+ from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExportResult
+ from opentelemetry.trace import Span, Status, StatusCode, Tracer, set_span_in_context
+
+ logger = logging.getLogger(__name__)
+
+
+ OTEL_SERVICE_NAME: str = "mistralai_sdk"
+ OTEL_EXPORTER_OTLP_ENDPOINT: str = os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT", "")
+ OTEL_EXPORTER_OTLP_TIMEOUT: int = int(os.getenv("OTEL_EXPORTER_OTLP_TIMEOUT", "2"))
+ OTEL_EXPORTER_OTLP_MAX_EXPORT_BATCH_SIZE: int = int(os.getenv("OTEL_EXPORTER_OTLP_MAX_EXPORT_BATCH_SIZE", "512"))
+ OTEL_EXPORTER_OTLP_SCHEDULE_DELAY_MILLIS: int = int(os.getenv("OTEL_EXPORTER_OTLP_SCHEDULE_DELAY_MILLIS", "1000"))
+ OTEL_EXPORTER_OTLP_MAX_QUEUE_SIZE: int = int(os.getenv("OTEL_EXPORTER_OTLP_MAX_QUEUE_SIZE", "2048"))
+ OTEL_EXPORTER_OTLP_EXPORT_TIMEOUT_MILLIS: int = int(os.getenv("OTEL_EXPORTER_OTLP_EXPORT_TIMEOUT_MILLIS", "5000"))
+
+ MISTRAL_SDK_OTEL_TRACER_NAME: str = OTEL_SERVICE_NAME + "_tracer"
+
+ MISTRAL_SDK_DEBUG_TRACING: bool = os.getenv("MISTRAL_SDK_DEBUG_TRACING", "false").lower() == "true"
+ DEBUG_HINT: str = "To see detailed exporter logs, set MISTRAL_SDK_DEBUG_TRACING=true."
+
+
+ class MistralAIAttributes:
+     MISTRAL_AI_TOTAL_TOKENS = "mistral_ai.request.total_tokens"
+     MISTRAL_AI_TOOL_CALL_ARGUMENTS = "mistral_ai.tool.call.arguments"
+     MISTRAL_AI_MESSAGE_ID = "mistral_ai.message.id"
+     MISTRAL_AI_OPERATION_NAME = "mistral_ai.operation.name"
+     MISTRAL_AI_OCR_USAGE_PAGES_PROCESSED = "mistral_ai.ocr.usage.pages_processed"
+     MISTRAL_AI_OCR_USAGE_DOC_SIZE_BYTES = "mistral_ai.ocr.usage.doc_size_bytes"
+     MISTRAL_AI_OPERATION_ID = "mistral_ai.operation.id"
+     MISTRAL_AI_ERROR_TYPE = "mistral_ai.error.type"
+     MISTRAL_AI_ERROR_MESSAGE = "mistral_ai.error.message"
+     MISTRAL_AI_ERROR_CODE = "mistral_ai.error.code"
+     MISTRAL_AI_FUNCTION_CALL_ARGUMENTS = "mistral_ai.function.call.arguments"
+
+ class MistralAINameValues(Enum):
+     OCR = "ocr"
+
+ class TracingErrors(Exception, Enum):
+     FAILED_TO_EXPORT_OTEL_SPANS = "Failed to export OpenTelemetry (OTEL) spans."
+     FAILED_TO_INITIALIZE_OPENTELEMETRY_TRACING = "Failed to initialize OpenTelemetry tracing."
+     FAILED_TO_CREATE_SPAN_FOR_REQUEST = "Failed to create span for request."
+     FAILED_TO_ENRICH_SPAN_WITH_RESPONSE = "Failed to enrich span with response."
+     FAILED_TO_HANDLE_ERROR_IN_SPAN = "Failed to handle error in span."
+     FAILED_TO_END_SPAN = "Failed to end span."
+
+     def __str__(self):
+         return str(self.value)
+
+ class GenAISpanEnum(str, Enum):
+     CONVERSATION = "conversation"
+     CONV_REQUEST = "POST /v1/conversations"
+     EXECUTE_TOOL = "execute_tool"
+     VALIDATE_RUN = "validate_run"
+
+     @staticmethod
+     def function_call(func_name: str):
+         return f"function_call[{func_name}]"
+
+
+ def parse_time_to_nanos(ts: str) -> int:
+     dt = datetime.fromisoformat(ts.replace("Z", "+00:00")).astimezone(timezone.utc)
+     return int(dt.timestamp() * 1e9)
+
+ def set_available_attributes(span: Span, attributes: dict) -> None:
+     for attribute, value in attributes.items():
+         if value:
+             span.set_attribute(attribute, value)
+
+
+ def enrich_span_from_request(span: Span, request: httpx.Request) -> Span:
+     if not request.url.port:
+         # From httpx doc:
+         # Note that the URL class performs port normalization as per the WHATWG spec.
+         # Default ports for "http", "https", "ws", "wss", and "ftp" schemes are always treated as None.
+         # Handling default ports since most of the time we are using https
+         if request.url.scheme == "https":
+             port = 443
+         elif request.url.scheme == "http":
+             port = 80
+         else:
+             port = -1
+     else:
+         port = request.url.port
+
+     span.set_attributes({
+         http_attributes.HTTP_REQUEST_METHOD: request.method,
+         http_attributes.HTTP_URL: str(request.url),
+         server_attributes.SERVER_ADDRESS: request.headers.get("host", ""),
+         server_attributes.SERVER_PORT: port
+     })
+     if request._content:
+         request_body = json.loads(request._content)
+
+         attributes = {
+             gen_ai_attributes.GEN_AI_REQUEST_CHOICE_COUNT: request_body.get("n", None),
+             gen_ai_attributes.GEN_AI_REQUEST_ENCODING_FORMATS: request_body.get("encoding_formats", None),
+             gen_ai_attributes.GEN_AI_REQUEST_FREQUENCY_PENALTY: request_body.get("frequency_penalty", None),
+             gen_ai_attributes.GEN_AI_REQUEST_MAX_TOKENS: request_body.get("max_tokens", None),
+             gen_ai_attributes.GEN_AI_REQUEST_MODEL: request_body.get("model", None),
+             gen_ai_attributes.GEN_AI_REQUEST_PRESENCE_PENALTY: request_body.get("presence_penalty", None),
+             gen_ai_attributes.GEN_AI_REQUEST_SEED: request_body.get("random_seed", None),
+             gen_ai_attributes.GEN_AI_REQUEST_STOP_SEQUENCES: request_body.get("stop", None),
+             gen_ai_attributes.GEN_AI_REQUEST_TEMPERATURE: request_body.get("temperature", None),
+             gen_ai_attributes.GEN_AI_REQUEST_TOP_P: request_body.get("top_p", None),
+             gen_ai_attributes.GEN_AI_REQUEST_TOP_K: request_body.get("top_k", None),
+             # Input messages are likely to be large, containing user/PII data and other sensitive information.
+             # Also structured attributes are not yet supported on spans in Python.
+             # For those reasons, we will not record the input messages for now.
+             gen_ai_attributes.GEN_AI_INPUT_MESSAGES: None,
+         }
+         # Set attributes only if they are not None.
+         # From OpenTelemetry documentation: None is not a valid attribute value per spec / is not a permitted value type for an attribute.
+         set_available_attributes(span, attributes)
+     return span
+
+
+ def enrich_span_from_response(tracer: trace.Tracer, span: Span, operation_id: str, response: httpx.Response) -> None:
+     span.set_status(Status(StatusCode.OK))
+     response_data = json.loads(response.content)
+
+     # Base attributes
+     attributes: dict[str, str | int] = {
+         http_attributes.HTTP_RESPONSE_STATUS_CODE: response.status_code,
+         MistralAIAttributes.MISTRAL_AI_OPERATION_ID: operation_id,
+         gen_ai_attributes.GEN_AI_PROVIDER_NAME: gen_ai_attributes.GenAiProviderNameValues.MISTRAL_AI.value
+     }
+
+     # Add usage attributes if available
+     usage = response_data.get("usage", {})
+     if usage:
+         attributes.update({
+             gen_ai_attributes.GEN_AI_USAGE_PROMPT_TOKENS: usage.get("prompt_tokens", 0),
+             gen_ai_attributes.GEN_AI_USAGE_OUTPUT_TOKENS: usage.get("completion_tokens", 0),
+             MistralAIAttributes.MISTRAL_AI_TOTAL_TOKENS: usage.get("total_tokens", 0)
+         })
+
+     span.set_attributes(attributes)
+     if operation_id == "agents_api_v1_agents_create":
+         # Semantics from https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-agent-spans/#create-agent-span
+         agent_attributes = {
+             gen_ai_attributes.GEN_AI_OPERATION_NAME: gen_ai_attributes.GenAiOperationNameValues.CREATE_AGENT.value,
+             gen_ai_attributes.GEN_AI_AGENT_DESCRIPTION: response_data.get("description", ""),
+             gen_ai_attributes.GEN_AI_AGENT_ID: response_data.get("id", ""),
+             gen_ai_attributes.GEN_AI_AGENT_NAME: response_data.get("name", ""),
+             gen_ai_attributes.GEN_AI_REQUEST_MODEL: response_data.get("model", ""),
+             gen_ai_attributes.GEN_AI_SYSTEM_INSTRUCTIONS: response_data.get("instructions", "")
+         }
+         span.set_attributes(agent_attributes)
+     if operation_id in ["agents_api_v1_conversations_start", "agents_api_v1_conversations_append"]:
+         outputs = response_data.get("outputs", [])
+         conversation_attributes = {
+             gen_ai_attributes.GEN_AI_OPERATION_NAME: gen_ai_attributes.GenAiOperationNameValues.INVOKE_AGENT.value,
+             gen_ai_attributes.GEN_AI_CONVERSATION_ID: response_data.get("conversation_id", "")
+         }
+         span.set_attributes(conversation_attributes)
+         parent_context = set_span_in_context(span)
+
+         for output in outputs:
+             # TODO: Only enrich the spans if it's a single turn conversation.
+             # Multi turn conversations are handled in the extra.run.tools.create_function_result function
+             if output["type"] == "function.call":
+                 pass
+             if output["type"] == "tool.execution":
+                 start_ns = parse_time_to_nanos(output["created_at"])
+                 end_ns = parse_time_to_nanos(output["completed_at"])
+                 child_span = tracer.start_span("Tool Execution", start_time=start_ns, context=parent_context)
+                 tool_attributes = {
+                     gen_ai_attributes.GEN_AI_OPERATION_NAME: gen_ai_attributes.GenAiOperationNameValues.EXECUTE_TOOL.value,
+                     gen_ai_attributes.GEN_AI_TOOL_CALL_ID: output.get("id", ""),
+                     MistralAIAttributes.MISTRAL_AI_TOOL_CALL_ARGUMENTS: output.get("arguments", ""),
+                     gen_ai_attributes.GEN_AI_TOOL_NAME: output.get("name", "")
+                 }
+                 child_span.set_attributes(tool_attributes)
+                 child_span.end(end_time=end_ns)
+             if output["type"] == "message.output":
+                 start_ns = parse_time_to_nanos(output["created_at"])
+                 end_ns = parse_time_to_nanos(output["completed_at"])
+                 child_span = tracer.start_span("Message Output", start_time=start_ns, context=parent_context)
+                 message_attributes = {
+                     gen_ai_attributes.GEN_AI_OPERATION_NAME: gen_ai_attributes.GenAiOperationNameValues.CHAT.value,
+                     gen_ai_attributes.GEN_AI_PROVIDER_NAME: gen_ai_attributes.GenAiProviderNameValues.MISTRAL_AI.value,
+                     MistralAIAttributes.MISTRAL_AI_MESSAGE_ID: output.get("id", ""),
+                     gen_ai_attributes.GEN_AI_AGENT_ID: output.get("agent_id", ""),
+                     gen_ai_attributes.GEN_AI_REQUEST_MODEL: output.get("model", "")
+                 }
+                 child_span.set_attributes(message_attributes)
+                 child_span.end(end_time=end_ns)
+     if operation_id == "ocr_v1_ocr_post":
+         usage_info = response_data.get("usage_info", "")
+         ocr_attributes = {
+             MistralAIAttributes.MISTRAL_AI_OPERATION_NAME: MistralAINameValues.OCR.value,
+             MistralAIAttributes.MISTRAL_AI_OCR_USAGE_PAGES_PROCESSED: usage_info.get("pages_processed", "") if usage_info else "",
+             MistralAIAttributes.MISTRAL_AI_OCR_USAGE_DOC_SIZE_BYTES: usage_info.get("doc_size_bytes", "") if usage_info else "",
+             gen_ai_attributes.GEN_AI_REQUEST_MODEL: response_data.get("model", "")
+         }
+         span.set_attributes(ocr_attributes)
+
+
+ class GenAISpanProcessor(SpanProcessor):
+     def on_start(self, span, parent_context=None):
+         span.set_attributes({"agent.trace.public": ""})
+
+
+ class QuietOTLPSpanExporter(OTLPSpanExporter):
+     def export(self, spans):
+         try:
+             return super().export(spans)
+         except Exception:
+             logger.warning(f"{TracingErrors.FAILED_TO_EXPORT_OTEL_SPANS} {(traceback.format_exc() if MISTRAL_SDK_DEBUG_TRACING else DEBUG_HINT)}")
+             return SpanExportResult.FAILURE
+
+
+ def get_or_create_otel_tracer() -> Tuple[bool, Tracer]:
+     """
+     3 possible cases:
+
+     -> [SDK in a Workflow / App] If there is already a tracer provider set -> use that one
+
+     -> [SDK standalone] If no tracer provider is set but the OTEL_EXPORTER_OTLP_ENDPOINT is set -> create a new tracer provider that exports to the OTEL_EXPORTER_OTLP_ENDPOINT
+
+     -> Else tracing is disabled
+     """
+     tracing_enabled = True
+     tracer_provider = trace.get_tracer_provider()
+
+     if isinstance(tracer_provider, trace.ProxyTracerProvider):
+         if OTEL_EXPORTER_OTLP_ENDPOINT:
+             # SDK standalone: No tracer provider but OTEL_EXPORTER_OTLP_ENDPOINT is set -> create a new tracer provider that exports to the OTEL_EXPORTER_OTLP_ENDPOINT
+             try:
+                 exporter = QuietOTLPSpanExporter(
+                     endpoint=OTEL_EXPORTER_OTLP_ENDPOINT,
+                     timeout=OTEL_EXPORTER_OTLP_TIMEOUT
+                 )
+                 resource = Resource.create(attributes={SERVICE_NAME: OTEL_SERVICE_NAME})
+                 tracer_provider = TracerProvider(resource=resource)
+
+                 span_processor = BatchSpanProcessor(
+                     exporter,
+                     export_timeout_millis=OTEL_EXPORTER_OTLP_EXPORT_TIMEOUT_MILLIS,
+                     max_export_batch_size=OTEL_EXPORTER_OTLP_MAX_EXPORT_BATCH_SIZE,
+                     schedule_delay_millis=OTEL_EXPORTER_OTLP_SCHEDULE_DELAY_MILLIS,
+                     max_queue_size=OTEL_EXPORTER_OTLP_MAX_QUEUE_SIZE
+                 )
+
+                 tracer_provider.add_span_processor(span_processor)
+                 tracer_provider.add_span_processor(GenAISpanProcessor())
+                 trace.set_tracer_provider(tracer_provider)
+
+             except Exception:
+                 logger.warning(f"{TracingErrors.FAILED_TO_INITIALIZE_OPENTELEMETRY_TRACING} {(traceback.format_exc() if MISTRAL_SDK_DEBUG_TRACING else DEBUG_HINT)}")
+                 tracing_enabled = False
+         else:
+             # No tracer provider nor OTEL_EXPORTER_OTLP_ENDPOINT set -> tracing is disabled
+             tracing_enabled = False
+
+     tracer = tracer_provider.get_tracer(MISTRAL_SDK_OTEL_TRACER_NAME)
+
+     return tracing_enabled, tracer
+
+ def get_traced_request_and_span(tracing_enabled: bool, tracer: Tracer, span: Optional[Span], operation_id: str, request: httpx.Request) -> Tuple[httpx.Request, Optional[Span]]:
+     if not tracing_enabled:
+         return request, span
+
+     try:
+         span = tracer.start_span(name=operation_id)
+         # Inject the span context into the request headers to be used by the backend service to continue the trace
+         propagate.inject(request.headers)
+         span = enrich_span_from_request(span, request)
+     except Exception:
+         logger.warning(
+             "%s %s",
+             TracingErrors.FAILED_TO_CREATE_SPAN_FOR_REQUEST,
+             traceback.format_exc() if MISTRAL_SDK_DEBUG_TRACING else DEBUG_HINT,
+         )
+         if span:
+             end_span(span=span)
+             span = None
+
+     return request, span
+
+
+ def get_traced_response(tracing_enabled: bool, tracer: Tracer, span: Optional[Span], operation_id: str, response: httpx.Response) -> httpx.Response:
+     if not tracing_enabled or not span:
+         return response
+     try:
+         is_stream_response = not response.is_closed and not response.is_stream_consumed
+         if is_stream_response:
+             return TracedResponse.from_response(resp=response, span=span)
+         enrich_span_from_response(
+             tracer, span, operation_id, response
+         )
+     except Exception:
+         logger.warning(
+             "%s %s",
+             TracingErrors.FAILED_TO_ENRICH_SPAN_WITH_RESPONSE,
+             traceback.format_exc() if MISTRAL_SDK_DEBUG_TRACING else DEBUG_HINT,
+         )
+     if span:
+         end_span(span=span)
+     return response
+
+ def get_response_and_error(tracing_enabled: bool, tracer: Tracer, span: Optional[Span], operation_id: str, response: httpx.Response, error: Optional[Exception]) -> Tuple[httpx.Response, Optional[Exception]]:
+     if not tracing_enabled or not span:
+         return response, error
+     try:
+         if error:
+             span.record_exception(error)
+             span.set_status(Status(StatusCode.ERROR, str(error)))
+         if hasattr(response, "_content") and response._content:
+             response_body = json.loads(response._content)
+             if response_body.get("object", "") == "error":
+                 if error_msg := response_body.get("message", ""):
+                     attributes = {
+                         http_attributes.HTTP_RESPONSE_STATUS_CODE: response.status_code,
+                         MistralAIAttributes.MISTRAL_AI_ERROR_TYPE: response_body.get("type", ""),
+                         MistralAIAttributes.MISTRAL_AI_ERROR_MESSAGE: error_msg,
+                         MistralAIAttributes.MISTRAL_AI_ERROR_CODE: response_body.get("code", ""),
+                     }
+                     for attribute, value in attributes.items():
+                         if value:
+                             span.set_attribute(attribute, value)
+         span.end()
+         span = None
+     except Exception:
+         logger.warning(
+             "%s %s",
+             TracingErrors.FAILED_TO_HANDLE_ERROR_IN_SPAN,
+             traceback.format_exc() if MISTRAL_SDK_DEBUG_TRACING else DEBUG_HINT,
+         )
+
+     if span:
+         span.end()
+         span = None
+     return response, error
+
+
+ def end_span(span: Span) -> None:
+     try:
+         span.end()
+     except Exception:
+         logger.warning(
+             "%s %s",
+             TracingErrors.FAILED_TO_END_SPAN,
+             traceback.format_exc() if MISTRAL_SDK_DEBUG_TRACING else DEBUG_HINT,
+         )
+
+ class TracedResponse(httpx.Response):
+     """
+     TracedResponse is a subclass of httpx.Response that ends the span when the response is closed.
+
+     This hack allows ending the span only once the stream is fully consumed.
+     """
+     def __init__(self, *args, span: Optional[Span], **kwargs) -> None:
+         super().__init__(*args, **kwargs)
+         self.span = span
+
+     def close(self) -> None:
+         if self.span:
+             end_span(span=self.span)
+         super().close()
+
+     async def aclose(self) -> None:
+         if self.span:
+             end_span(span=self.span)
+         await super().aclose()
+
+     @classmethod
+     def from_response(cls, resp: httpx.Response, span: Optional[Span]) -> "TracedResponse":
+         traced_resp = cls.__new__(cls)
+         traced_resp.__dict__ = copy.copy(resp.__dict__)
+         traced_resp.span = span
+
+         # Warning: this syntax bypasses the __init__ method.
+         # If you add init logic in the TracedResponse.__init__ method, you will need to add the following line for it to execute:
+         # traced_resp.__init__(your_arguments)
+
+         return traced_resp
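As the `get_or_create_otel_tracer` docstring spells out, a standalone script gets tracing only when no global tracer provider exists and `OTEL_EXPORTER_OTLP_ENDPOINT` is set; an application that has already installed its own provider keeps it. A sketch of the standalone case (the endpoint URL is a placeholder, and the environment variables must be set before this module is first imported, since they are read into module-level constants at import time):

```python
import os

# Must be set before importing the otel module: the endpoint and batching
# knobs are captured in module-level constants at import time.
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = "http://localhost:4318/v1/traces"
os.environ["MISTRAL_SDK_DEBUG_TRACING"] = "true"  # verbose logs on export failure

from mistralai.extra.observability.otel import (
    get_or_create_otel_tracer,
    parse_time_to_nanos,
)

tracing_enabled, tracer = get_or_create_otel_tracer()
print(tracing_enabled)  # True once the OTLP pipeline has been installed

# Conversation outputs carry ISO "created_at"/"completed_at" timestamps that
# are converted to epoch nanoseconds so child spans can be backdated:
print(parse_time_to_nanos("2024-01-01T00:00:00Z"))  # 1704067200000000000

with tracer.start_as_current_span("demo") as span:
    span.set_attribute("demo.attr", "value")
```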
mistralai/extra/run/tools.py CHANGED
@@ -8,6 +8,7 @@ from pydantic.fields import FieldInfo
  import json
  from typing import cast, Callable, Sequence, Any, ForwardRef, get_type_hints, Union

+ from opentelemetry import trace
  from griffe import (
      Docstring,
      DocstringSectionKind,
@@ -15,9 +16,11 @@ from griffe import (
      DocstringParameter,
      DocstringSection,
  )
+ import opentelemetry.semconv._incubating.attributes.gen_ai_attributes as gen_ai_attributes

  from mistralai.extra.exceptions import RunException
  from mistralai.extra.mcp.base import MCPClientProtocol
+ from mistralai.extra.observability.otel import GenAISpanEnum, MistralAIAttributes, set_available_attributes
  from mistralai.extra.run.result import RunOutputEntries
  from mistralai.models import (
      FunctionResultEntry,
@@ -191,22 +194,31 @@ async def create_function_result(
          if isinstance(function_call.arguments, str)
          else function_call.arguments
      )
-     try:
-         if isinstance(run_tool, RunFunction):
-             res = run_tool.callable(**arguments)
-         elif isinstance(run_tool, RunCoroutine):
-             res = await run_tool.awaitable(**arguments)
-         elif isinstance(run_tool, RunMCPTool):
-             res = await run_tool.mcp_client.execute_tool(function_call.name, arguments)
-     except Exception as e:
-         if continue_on_fn_error is True:
-             return FunctionResultEntry(
-                 tool_call_id=function_call.tool_call_id,
-                 result=f"Error while executing {function_call.name}: {str(e)}",
-             )
-         raise RunException(
-             f"Failed to execute tool {function_call.name} with arguments '{function_call.arguments}'"
-         ) from e
+     tracer = trace.get_tracer(__name__)
+     with tracer.start_as_current_span(GenAISpanEnum.function_call(function_call.name)) as span:
+         try:
+             if isinstance(run_tool, RunFunction):
+                 res = run_tool.callable(**arguments)
+             elif isinstance(run_tool, RunCoroutine):
+                 res = await run_tool.awaitable(**arguments)
+             elif isinstance(run_tool, RunMCPTool):
+                 res = await run_tool.mcp_client.execute_tool(function_call.name, arguments)
+             function_call_attributes = {
+                 gen_ai_attributes.GEN_AI_OPERATION_NAME: gen_ai_attributes.GenAiOperationNameValues.EXECUTE_TOOL.value,
+                 gen_ai_attributes.GEN_AI_TOOL_CALL_ID: function_call.id,
+                 MistralAIAttributes.MISTRAL_AI_TOOL_CALL_ARGUMENTS: str(function_call.arguments),
+                 gen_ai_attributes.GEN_AI_TOOL_NAME: function_call.name
+             }
+             set_available_attributes(span, function_call_attributes)
+         except Exception as e:
+             if continue_on_fn_error is True:
+                 return FunctionResultEntry(
+                     tool_call_id=function_call.tool_call_id,
+                     result=f"Error while executing {function_call.name}: {str(e)}",
+                 )
+             raise RunException(
+                 f"Failed to execute tool {function_call.name} with arguments '{function_call.arguments}'"
+             ) from e

      return FunctionResultEntry(
          tool_call_id=function_call.tool_call_id,
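Net effect of this hunk: every locally executed tool in the run helpers is now wrapped in a span named by `GenAISpanEnum.function_call`, carrying the `gen_ai.*` execute-tool attributes, while the existing error handling (return a `FunctionResultEntry` when `continue_on_fn_error` is set, otherwise raise `RunException`) is preserved inside the span. A tiny illustration of the naming helper (the tool name is made up):

```python
from mistralai.extra.observability.otel import GenAISpanEnum

# Span names embed the tool's function name.
assert GenAISpanEnum.function_call("get_weather") == "function_call[get_weather]"
```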