mistralai 1.9.11__py3-none-any.whl → 1.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (151)
  1. mistralai/_hooks/registration.py +5 -0
  2. mistralai/_hooks/tracing.py +75 -0
  3. mistralai/_version.py +2 -2
  4. mistralai/accesses.py +8 -8
  5. mistralai/agents.py +29 -17
  6. mistralai/chat.py +41 -29
  7. mistralai/classifiers.py +13 -1
  8. mistralai/conversations.py +294 -62
  9. mistralai/documents.py +19 -3
  10. mistralai/embeddings.py +13 -7
  11. mistralai/extra/README.md +1 -1
  12. mistralai/extra/mcp/auth.py +10 -11
  13. mistralai/extra/mcp/base.py +17 -16
  14. mistralai/extra/mcp/sse.py +13 -15
  15. mistralai/extra/mcp/stdio.py +5 -6
  16. mistralai/extra/observability/__init__.py +15 -0
  17. mistralai/extra/observability/otel.py +372 -0
  18. mistralai/extra/run/context.py +33 -43
  19. mistralai/extra/run/result.py +29 -30
  20. mistralai/extra/run/tools.py +34 -23
  21. mistralai/extra/struct_chat.py +15 -8
  22. mistralai/extra/utils/response_format.py +5 -3
  23. mistralai/files.py +6 -0
  24. mistralai/fim.py +17 -5
  25. mistralai/mistral_agents.py +229 -1
  26. mistralai/mistral_jobs.py +39 -13
  27. mistralai/models/__init__.py +99 -3
  28. mistralai/models/agent.py +15 -2
  29. mistralai/models/agentconversation.py +11 -3
  30. mistralai/models/agentcreationrequest.py +6 -2
  31. mistralai/models/agents_api_v1_agents_deleteop.py +16 -0
  32. mistralai/models/agents_api_v1_agents_getop.py +40 -3
  33. mistralai/models/agents_api_v1_agents_listop.py +72 -2
  34. mistralai/models/agents_api_v1_conversations_deleteop.py +18 -0
  35. mistralai/models/agents_api_v1_conversations_listop.py +39 -2
  36. mistralai/models/agentscompletionrequest.py +21 -6
  37. mistralai/models/agentscompletionstreamrequest.py +21 -6
  38. mistralai/models/agentupdaterequest.py +18 -2
  39. mistralai/models/audioencoding.py +13 -0
  40. mistralai/models/audioformat.py +19 -0
  41. mistralai/models/audiotranscriptionrequest.py +2 -0
  42. mistralai/models/batchjobin.py +26 -5
  43. mistralai/models/batchjobout.py +5 -0
  44. mistralai/models/batchrequest.py +48 -0
  45. mistralai/models/chatcompletionrequest.py +22 -5
  46. mistralai/models/chatcompletionstreamrequest.py +22 -5
  47. mistralai/models/classificationrequest.py +37 -3
  48. mistralai/models/conversationrequest.py +15 -4
  49. mistralai/models/conversationrestartrequest.py +50 -2
  50. mistralai/models/conversationrestartstreamrequest.py +50 -2
  51. mistralai/models/conversationstreamrequest.py +15 -4
  52. mistralai/models/documentout.py +26 -10
  53. mistralai/models/documentupdatein.py +24 -3
  54. mistralai/models/embeddingrequest.py +19 -11
  55. mistralai/models/files_api_routes_list_filesop.py +7 -0
  56. mistralai/models/fimcompletionrequest.py +8 -9
  57. mistralai/models/fimcompletionstreamrequest.py +8 -9
  58. mistralai/models/jobs_api_routes_batch_get_batch_jobop.py +40 -3
  59. mistralai/models/libraries_documents_list_v1op.py +15 -2
  60. mistralai/models/libraryout.py +10 -7
  61. mistralai/models/listfilesout.py +35 -4
  62. mistralai/models/modelcapabilities.py +13 -4
  63. mistralai/models/modelconversation.py +8 -2
  64. mistralai/models/ocrpageobject.py +26 -5
  65. mistralai/models/ocrrequest.py +17 -1
  66. mistralai/models/ocrtableobject.py +31 -0
  67. mistralai/models/prediction.py +4 -0
  68. mistralai/models/requestsource.py +7 -0
  69. mistralai/models/responseformat.py +4 -2
  70. mistralai/models/responseformats.py +0 -1
  71. mistralai/models/sharingdelete.py +36 -5
  72. mistralai/models/sharingin.py +36 -5
  73. mistralai/models/sharingout.py +3 -3
  74. mistralai/models/toolexecutiondeltaevent.py +13 -4
  75. mistralai/models/toolexecutiondoneevent.py +13 -4
  76. mistralai/models/toolexecutionentry.py +9 -4
  77. mistralai/models/toolexecutionstartedevent.py +13 -4
  78. mistralai/models/toolfilechunk.py +11 -4
  79. mistralai/models/toolreferencechunk.py +13 -4
  80. mistralai/models_.py +2 -14
  81. mistralai/ocr.py +18 -0
  82. mistralai/transcriptions.py +4 -4
  83. {mistralai-1.9.11.dist-info → mistralai-1.10.1.dist-info}/METADATA +162 -152
  84. {mistralai-1.9.11.dist-info → mistralai-1.10.1.dist-info}/RECORD +168 -144
  85. {mistralai-1.9.11.dist-info → mistralai-1.10.1.dist-info}/WHEEL +1 -1
  86. mistralai_azure/_version.py +3 -3
  87. mistralai_azure/basesdk.py +15 -5
  88. mistralai_azure/chat.py +59 -98
  89. mistralai_azure/models/__init__.py +50 -3
  90. mistralai_azure/models/chatcompletionrequest.py +16 -4
  91. mistralai_azure/models/chatcompletionstreamrequest.py +16 -4
  92. mistralai_azure/models/httpvalidationerror.py +11 -6
  93. mistralai_azure/models/mistralazureerror.py +26 -0
  94. mistralai_azure/models/no_response_error.py +13 -0
  95. mistralai_azure/models/prediction.py +4 -0
  96. mistralai_azure/models/responseformat.py +4 -2
  97. mistralai_azure/models/responseformats.py +0 -1
  98. mistralai_azure/models/responsevalidationerror.py +25 -0
  99. mistralai_azure/models/sdkerror.py +30 -14
  100. mistralai_azure/models/systemmessage.py +7 -3
  101. mistralai_azure/models/systemmessagecontentchunks.py +21 -0
  102. mistralai_azure/models/thinkchunk.py +35 -0
  103. mistralai_azure/ocr.py +15 -36
  104. mistralai_azure/utils/__init__.py +18 -5
  105. mistralai_azure/utils/eventstreaming.py +10 -0
  106. mistralai_azure/utils/serializers.py +3 -2
  107. mistralai_azure/utils/unmarshal_json_response.py +24 -0
  108. mistralai_gcp/_hooks/types.py +7 -0
  109. mistralai_gcp/_version.py +4 -4
  110. mistralai_gcp/basesdk.py +27 -25
  111. mistralai_gcp/chat.py +75 -98
  112. mistralai_gcp/fim.py +39 -74
  113. mistralai_gcp/httpclient.py +6 -16
  114. mistralai_gcp/models/__init__.py +321 -116
  115. mistralai_gcp/models/assistantmessage.py +1 -1
  116. mistralai_gcp/models/chatcompletionrequest.py +36 -7
  117. mistralai_gcp/models/chatcompletionresponse.py +6 -6
  118. mistralai_gcp/models/chatcompletionstreamrequest.py +36 -7
  119. mistralai_gcp/models/completionresponsestreamchoice.py +1 -1
  120. mistralai_gcp/models/deltamessage.py +1 -1
  121. mistralai_gcp/models/fimcompletionrequest.py +3 -9
  122. mistralai_gcp/models/fimcompletionresponse.py +6 -6
  123. mistralai_gcp/models/fimcompletionstreamrequest.py +3 -9
  124. mistralai_gcp/models/httpvalidationerror.py +11 -6
  125. mistralai_gcp/models/imageurl.py +1 -1
  126. mistralai_gcp/models/jsonschema.py +1 -1
  127. mistralai_gcp/models/mistralgcperror.py +26 -0
  128. mistralai_gcp/models/mistralpromptmode.py +8 -0
  129. mistralai_gcp/models/no_response_error.py +13 -0
  130. mistralai_gcp/models/prediction.py +4 -0
  131. mistralai_gcp/models/responseformat.py +5 -3
  132. mistralai_gcp/models/responseformats.py +0 -1
  133. mistralai_gcp/models/responsevalidationerror.py +25 -0
  134. mistralai_gcp/models/sdkerror.py +30 -14
  135. mistralai_gcp/models/systemmessage.py +7 -3
  136. mistralai_gcp/models/systemmessagecontentchunks.py +21 -0
  137. mistralai_gcp/models/thinkchunk.py +35 -0
  138. mistralai_gcp/models/toolmessage.py +1 -1
  139. mistralai_gcp/models/usageinfo.py +71 -8
  140. mistralai_gcp/models/usermessage.py +1 -1
  141. mistralai_gcp/sdk.py +12 -10
  142. mistralai_gcp/sdkconfiguration.py +0 -7
  143. mistralai_gcp/types/basemodel.py +3 -3
  144. mistralai_gcp/utils/__init__.py +143 -45
  145. mistralai_gcp/utils/datetimes.py +23 -0
  146. mistralai_gcp/utils/enums.py +67 -27
  147. mistralai_gcp/utils/eventstreaming.py +10 -0
  148. mistralai_gcp/utils/forms.py +49 -28
  149. mistralai_gcp/utils/serializers.py +33 -3
  150. mistralai_gcp/utils/unmarshal_json_response.py +24 -0
  151. {mistralai-1.9.11.dist-info → mistralai-1.10.1.dist-info}/licenses/LICENSE +0 -0
mistralai/extra/observability/otel.py (new file)
@@ -0,0 +1,372 @@
+ import copy
+ import json
+ import logging
+ import os
+ import traceback
+ from datetime import datetime, timezone
+ from enum import Enum
+
+ import httpx
+ import opentelemetry.semconv._incubating.attributes.gen_ai_attributes as gen_ai_attributes
+ import opentelemetry.semconv._incubating.attributes.http_attributes as http_attributes
+ import opentelemetry.semconv.attributes.server_attributes as server_attributes
+ from opentelemetry import propagate, trace
+ from opentelemetry.sdk.trace import SpanProcessor
+ from opentelemetry.trace import Span, Status, StatusCode, Tracer, set_span_in_context
+
+ logger = logging.getLogger(__name__)
+
+
+ OTEL_SERVICE_NAME: str = "mistralai_sdk"
+ MISTRAL_SDK_OTEL_TRACER_NAME: str = OTEL_SERVICE_NAME + "_tracer"
+
+ MISTRAL_SDK_DEBUG_TRACING: bool = os.getenv("MISTRAL_SDK_DEBUG_TRACING", "false").lower() == "true"
+ DEBUG_HINT: str = "To see detailed tracing logs, set MISTRAL_SDK_DEBUG_TRACING=true."
+
+
+ class MistralAIAttributes:
+     MISTRAL_AI_TOTAL_TOKENS = "mistral_ai.request.total_tokens"
+     MISTRAL_AI_TOOL_CALL_ARGUMENTS = "mistral_ai.tool.call.arguments"
+     MISTRAL_AI_MESSAGE_ID = "mistral_ai.message.id"
+     MISTRAL_AI_OPERATION_NAME = "mistral_ai.operation.name"
+     MISTRAL_AI_OCR_USAGE_PAGES_PROCESSED = "mistral_ai.ocr.usage.pages_processed"
+     MISTRAL_AI_OCR_USAGE_DOC_SIZE_BYTES = "mistral_ai.ocr.usage.doc_size_bytes"
+     MISTRAL_AI_OPERATION_ID = "mistral_ai.operation.id"
+     MISTRAL_AI_ERROR_TYPE = "mistral_ai.error.type"
+     MISTRAL_AI_ERROR_MESSAGE = "mistral_ai.error.message"
+     MISTRAL_AI_ERROR_CODE = "mistral_ai.error.code"
+     MISTRAL_AI_FUNCTION_CALL_ARGUMENTS = "mistral_ai.function.call.arguments"
+
+ class MistralAINameValues(Enum):
+     OCR = "ocr"
+
+ class TracingErrors(Exception, Enum):
+     FAILED_TO_CREATE_SPAN_FOR_REQUEST = "Failed to create span for request."
+     FAILED_TO_ENRICH_SPAN_WITH_RESPONSE = "Failed to enrich span with response."
+     FAILED_TO_HANDLE_ERROR_IN_SPAN = "Failed to handle error in span."
+     FAILED_TO_END_SPAN = "Failed to end span."
+
+     def __str__(self):
+         return str(self.value)
+
+ class GenAISpanEnum(str, Enum):
+     CONVERSATION = "conversation"
+     CONV_REQUEST = "POST /v1/conversations"
+     EXECUTE_TOOL = "execute_tool"
+     VALIDATE_RUN = "validate_run"
+
+     @staticmethod
+     def function_call(func_name: str):
+         return f"function_call[{func_name}]"
+
+
+ def parse_time_to_nanos(ts: str) -> int:
+     dt = datetime.fromisoformat(ts.replace("Z", "+00:00")).astimezone(timezone.utc)
+     return int(dt.timestamp() * 1e9)
+
+ def set_available_attributes(span: Span, attributes: dict) -> None:
+     for attribute, value in attributes.items():
+         if value:
+             span.set_attribute(attribute, value)
+
+
+ def enrich_span_from_request(span: Span, request: httpx.Request) -> Span:
+     if not request.url.port:
+         # From httpx doc:
+         # Note that the URL class performs port normalization as per the WHATWG spec.
+         # Default ports for "http", "https", "ws", "wss", and "ftp" schemes are always treated as None.
+         # Handling default ports since most of the time we are using https
+         if request.url.scheme == "https":
+             port = 443
+         elif request.url.scheme == "http":
+             port = 80
+         else:
+             port = -1
+     else:
+         port = request.url.port
+
+     span.set_attributes({
+         http_attributes.HTTP_REQUEST_METHOD: request.method,
+         http_attributes.HTTP_URL: str(request.url),
+         server_attributes.SERVER_ADDRESS: request.headers.get("host", ""),
+         server_attributes.SERVER_PORT: port
+     })
+     if request._content:
+         request_body = json.loads(request._content)
+
+         attributes = {
+             gen_ai_attributes.GEN_AI_REQUEST_CHOICE_COUNT: request_body.get("n", None),
+             gen_ai_attributes.GEN_AI_REQUEST_ENCODING_FORMATS: request_body.get("encoding_formats", None),
+             gen_ai_attributes.GEN_AI_REQUEST_FREQUENCY_PENALTY: request_body.get("frequency_penalty", None),
+             gen_ai_attributes.GEN_AI_REQUEST_MAX_TOKENS: request_body.get("max_tokens", None),
+             gen_ai_attributes.GEN_AI_REQUEST_MODEL: request_body.get("model", None),
+             gen_ai_attributes.GEN_AI_REQUEST_PRESENCE_PENALTY: request_body.get("presence_penalty", None),
+             gen_ai_attributes.GEN_AI_REQUEST_SEED: request_body.get("random_seed", None),
+             gen_ai_attributes.GEN_AI_REQUEST_STOP_SEQUENCES: request_body.get("stop", None),
+             gen_ai_attributes.GEN_AI_REQUEST_TEMPERATURE: request_body.get("temperature", None),
+             gen_ai_attributes.GEN_AI_REQUEST_TOP_P: request_body.get("top_p", None),
+             gen_ai_attributes.GEN_AI_REQUEST_TOP_K: request_body.get("top_k", None),
+             # Input messages are likely to be large, containing user/PII data and other sensitive information.
+             # Also structured attributes are not yet supported on spans in Python.
+             # For those reasons, we will not record the input messages for now.
+             gen_ai_attributes.GEN_AI_INPUT_MESSAGES: None,
+         }
+         # Set attributes only if they are not None.
+         # From OpenTelemetry documentation: None is not a valid attribute value per spec / is not a permitted value type for an attribute.
+         set_available_attributes(span, attributes)
+     return span
+
+
+ def enrich_span_from_response(tracer: trace.Tracer, span: Span, operation_id: str, response: httpx.Response) -> None:
+     span.set_status(Status(StatusCode.OK))
+     response_data = json.loads(response.content)
+
+     # Base attributes
+     attributes: dict[str, str | int] = {
+         http_attributes.HTTP_RESPONSE_STATUS_CODE: response.status_code,
+         MistralAIAttributes.MISTRAL_AI_OPERATION_ID: operation_id,
+         gen_ai_attributes.GEN_AI_PROVIDER_NAME: gen_ai_attributes.GenAiProviderNameValues.MISTRAL_AI.value
+     }
+
+     # Add usage attributes if available
+     usage = response_data.get("usage", {})
+     if usage:
+         attributes.update({
+             gen_ai_attributes.GEN_AI_USAGE_PROMPT_TOKENS: usage.get("prompt_tokens", 0),
+             gen_ai_attributes.GEN_AI_USAGE_OUTPUT_TOKENS: usage.get("completion_tokens", 0),
+             MistralAIAttributes.MISTRAL_AI_TOTAL_TOKENS: usage.get("total_tokens", 0)
+         })
+
+     span.set_attributes(attributes)
+     if operation_id == "agents_api_v1_agents_create":
+         # Semantics from https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-agent-spans/#create-agent-span
+         agent_attributes = {
+             gen_ai_attributes.GEN_AI_OPERATION_NAME: gen_ai_attributes.GenAiOperationNameValues.CREATE_AGENT.value,
+             gen_ai_attributes.GEN_AI_AGENT_DESCRIPTION: response_data.get("description", ""),
+             gen_ai_attributes.GEN_AI_AGENT_ID: response_data.get("id", ""),
+             gen_ai_attributes.GEN_AI_AGENT_NAME: response_data.get("name", ""),
+             gen_ai_attributes.GEN_AI_REQUEST_MODEL: response_data.get("model", ""),
+             gen_ai_attributes.GEN_AI_SYSTEM_INSTRUCTIONS: response_data.get("instructions", "")
+         }
+         span.set_attributes(agent_attributes)
+     if operation_id in ["agents_api_v1_conversations_start", "agents_api_v1_conversations_append"]:
+         outputs = response_data.get("outputs", [])
+         conversation_attributes = {
+             gen_ai_attributes.GEN_AI_OPERATION_NAME: gen_ai_attributes.GenAiOperationNameValues.INVOKE_AGENT.value,
+             gen_ai_attributes.GEN_AI_CONVERSATION_ID: response_data.get("conversation_id", "")
+         }
+         span.set_attributes(conversation_attributes)
+         parent_context = set_span_in_context(span)
+
+         for output in outputs:
+             # TODO: Only enrich the spans if it's a single turn conversation.
+             # Multi turn conversations are handled in the extra.run.tools.create_function_result function
+             if output["type"] == "function.call":
+                 pass
+             if output["type"] == "tool.execution":
+                 start_ns = parse_time_to_nanos(output["created_at"])
+                 end_ns = parse_time_to_nanos(output["completed_at"])
+                 child_span = tracer.start_span("Tool Execution", start_time=start_ns, context=parent_context)
+                 child_span.set_attributes({"agent.trace.public": ""})
+                 tool_attributes = {
+                     gen_ai_attributes.GEN_AI_OPERATION_NAME: gen_ai_attributes.GenAiOperationNameValues.EXECUTE_TOOL.value,
+                     gen_ai_attributes.GEN_AI_TOOL_CALL_ID: output.get("id", ""),
+                     MistralAIAttributes.MISTRAL_AI_TOOL_CALL_ARGUMENTS: output.get("arguments", ""),
+                     gen_ai_attributes.GEN_AI_TOOL_NAME: output.get("name", "")
+                 }
+                 child_span.set_attributes(tool_attributes)
+                 child_span.end(end_time=end_ns)
+             if output["type"] == "message.output":
+                 start_ns = parse_time_to_nanos(output["created_at"])
+                 end_ns = parse_time_to_nanos(output["completed_at"])
+                 child_span = tracer.start_span("Message Output", start_time=start_ns, context=parent_context)
+                 child_span.set_attributes({"agent.trace.public": ""})
+                 message_attributes = {
+                     gen_ai_attributes.GEN_AI_OPERATION_NAME: gen_ai_attributes.GenAiOperationNameValues.CHAT.value,
+                     gen_ai_attributes.GEN_AI_PROVIDER_NAME: gen_ai_attributes.GenAiProviderNameValues.MISTRAL_AI.value,
+                     MistralAIAttributes.MISTRAL_AI_MESSAGE_ID: output.get("id", ""),
+                     gen_ai_attributes.GEN_AI_AGENT_ID: output.get("agent_id", ""),
+                     gen_ai_attributes.GEN_AI_REQUEST_MODEL: output.get("model", "")
+                 }
+                 child_span.set_attributes(message_attributes)
+                 child_span.end(end_time=end_ns)
+     if operation_id == "ocr_v1_ocr_post":
+         usage_info = response_data.get("usage_info", "")
+         ocr_attributes = {
+             MistralAIAttributes.MISTRAL_AI_OPERATION_NAME: MistralAINameValues.OCR.value,
+             MistralAIAttributes.MISTRAL_AI_OCR_USAGE_PAGES_PROCESSED: usage_info.get("pages_processed", "") if usage_info else "",
+             MistralAIAttributes.MISTRAL_AI_OCR_USAGE_DOC_SIZE_BYTES: usage_info.get("doc_size_bytes", "") if usage_info else "",
+             gen_ai_attributes.GEN_AI_REQUEST_MODEL: response_data.get("model", "")
+         }
+         span.set_attributes(ocr_attributes)
+
+
+ class GenAISpanProcessor(SpanProcessor):
+     def on_start(self, span, parent_context = None):
+         span.set_attributes({"agent.trace.public": ""})
+
+
+ def get_or_create_otel_tracer() -> tuple[bool, Tracer]:
+     """
+     Get a tracer from the current TracerProvider.
+
+     The SDK does not set up its own TracerProvider - it relies on the application
+     to configure OpenTelemetry. This follows OTEL best practices where:
+     - Libraries/SDKs get tracers from the global provider
+     - Applications configure the TracerProvider
+
+     If no TracerProvider is configured, the ProxyTracerProvider (default) will
+     return a NoOp tracer, effectively disabling tracing. Once the application
+     sets up a real TracerProvider, subsequent spans will be recorded.
+
+     Returns:
+         Tuple[bool, Tracer]: (tracing_enabled, tracer)
+             - tracing_enabled is True if a real TracerProvider is configured
+             - tracer is always valid (may be NoOp if no provider configured)
+     """
+     tracer_provider = trace.get_tracer_provider()
+     tracer = tracer_provider.get_tracer(MISTRAL_SDK_OTEL_TRACER_NAME)
+
+     # Tracing is considered enabled if we have a real TracerProvider (not the default proxy)
+     tracing_enabled = not isinstance(tracer_provider, trace.ProxyTracerProvider)
+
+     return tracing_enabled, tracer
+
+ def get_traced_request_and_span(
+     tracing_enabled: bool,
+     tracer: Tracer,
+     span: Span | None,
+     operation_id: str,
+     request: httpx.Request,
+ ) -> tuple[httpx.Request, Span | None]:
+     if not tracing_enabled:
+         return request, span
+
+     try:
+         span = tracer.start_span(name=operation_id)
+         span.set_attributes({"agent.trace.public": ""})
+         # Inject the span context into the request headers to be used by the backend service to continue the trace
+         propagate.inject(request.headers, context=set_span_in_context(span))
+         span = enrich_span_from_request(span, request)
+     except Exception:
+         logger.warning(
+             "%s %s",
+             TracingErrors.FAILED_TO_CREATE_SPAN_FOR_REQUEST,
+             traceback.format_exc() if MISTRAL_SDK_DEBUG_TRACING else DEBUG_HINT,
+         )
+         if span:
+             end_span(span=span)
+             span = None
+
+     return request, span
+
+
+ def get_traced_response(
+     tracing_enabled: bool,
+     tracer: Tracer,
+     span: Span | None,
+     operation_id: str,
+     response: httpx.Response,
+ ) -> httpx.Response:
+     if not tracing_enabled or not span:
+         return response
+     try:
+         is_stream_response = not response.is_closed and not response.is_stream_consumed
+         if is_stream_response:
+             return TracedResponse.from_response(resp=response, span=span)
+         enrich_span_from_response(
+             tracer, span, operation_id, response
+         )
+     except Exception:
+         logger.warning(
+             "%s %s",
+             TracingErrors.FAILED_TO_ENRICH_SPAN_WITH_RESPONSE,
+             traceback.format_exc() if MISTRAL_SDK_DEBUG_TRACING else DEBUG_HINT,
+         )
+     if span:
+         end_span(span=span)
+     return response
+
+ def get_response_and_error(
+     tracing_enabled: bool,
+     tracer: Tracer,
+     span: Span | None,
+     operation_id: str,
+     response: httpx.Response,
+     error: Exception | None,
+ ) -> tuple[httpx.Response, Exception | None]:
+     if not tracing_enabled or not span:
+         return response, error
+     try:
+         if error:
+             span.record_exception(error)
+             span.set_status(Status(StatusCode.ERROR, str(error)))
+         if hasattr(response, "_content") and response._content:
+             response_body = json.loads(response._content)
+             if response_body.get("object", "") == "error":
+                 if error_msg := response_body.get("message", ""):
+                     attributes = {
+                         http_attributes.HTTP_RESPONSE_STATUS_CODE: response.status_code,
+                         MistralAIAttributes.MISTRAL_AI_ERROR_TYPE: response_body.get("type", ""),
+                         MistralAIAttributes.MISTRAL_AI_ERROR_MESSAGE: error_msg,
+                         MistralAIAttributes.MISTRAL_AI_ERROR_CODE: response_body.get("code", ""),
+                     }
+                     for attribute, value in attributes.items():
+                         if value:
+                             span.set_attribute(attribute, value)
+         span.end()
+         span = None
+     except Exception:
+         logger.warning(
+             "%s %s",
+             TracingErrors.FAILED_TO_HANDLE_ERROR_IN_SPAN,
+             traceback.format_exc() if MISTRAL_SDK_DEBUG_TRACING else DEBUG_HINT,
+         )
+
+     if span:
+         span.end()
+         span = None
+     return response, error
+
+
+ def end_span(span: Span) -> None:
+     try:
+         span.end()
+     except Exception:
+         logger.warning(
+             "%s %s",
+             TracingErrors.FAILED_TO_END_SPAN,
+             traceback.format_exc() if MISTRAL_SDK_DEBUG_TRACING else DEBUG_HINT,
+         )
+
+ class TracedResponse(httpx.Response):
+     """
+     TracedResponse is a subclass of httpx.Response that ends the span when the response is closed.
+
+     This hack allows ending the span only once the stream is fully consumed.
+     """
+     def __init__(self, *args, span: Span | None, **kwargs) -> None:
+         super().__init__(*args, **kwargs)
+         self.span = span
+
+     def close(self) -> None:
+         if self.span:
+             end_span(span=self.span)
+         super().close()
+
+     async def aclose(self) -> None:
+         if self.span:
+             end_span(span=self.span)
+         await super().aclose()
+
+     @classmethod
+     def from_response(cls, resp: httpx.Response, span: Span | None) -> "TracedResponse":
+         traced_resp = cls.__new__(cls)
+         traced_resp.__dict__ = copy.copy(resp.__dict__)
+         traced_resp.span = span
+
+         # Warning: this syntax bypasses the __init__ method.
+         # If you add init logic in the TracedResponse.__init__ method, you will need to add the following line for it to execute:
+         # traced_resp.__init__(your_arguments)
+
+         return traced_resp
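
As the module's own docstring explains, these spans are only recorded once the host application installs a real TracerProvider; under the default ProxyTracerProvider the tracer handed to the SDK is a no-op. A minimal sketch of the application-side setup, assuming the standard opentelemetry-sdk package (the console exporter is purely illustrative, not something the SDK configures):

import os

# Hypothetical application code; the SDK itself never sets a provider.
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor

provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter()))
trace.set_tracer_provider(provider)

# Optional (assumption: must be set before the module above is imported,
# since MISTRAL_SDK_DEBUG_TRACING is read at import time): log full
# tracebacks from the tracing error handlers instead of the short hint.
os.environ["MISTRAL_SDK_DEBUG_TRACING"] = "true"

After this, get_or_create_otel_tracer() returns (True, tracer) and the request/response hooks above start emitting spans.
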
mistralai/extra/run/context.py
@@ -1,47 +1,41 @@
  import asyncio
  import inspect
  import typing
- from contextlib import AsyncExitStack
- from functools import wraps
  from collections.abc import Callable
-
+ from contextlib import AsyncExitStack
  from dataclasses import dataclass, field
- from typing import Union, Optional
+ from functools import wraps
+ from logging import getLogger

  import pydantic

- from mistralai.extra import (
-     response_format_from_pydantic_model,
- )
+ from mistralai.extra import response_format_from_pydantic_model
  from mistralai.extra.exceptions import RunException
  from mistralai.extra.mcp.base import MCPClientProtocol
  from mistralai.extra.run.result import RunResult
- from mistralai.types.basemodel import OptionalNullable, BaseModel, UNSET
+ from mistralai.extra.run.tools import (
+     RunCoroutine,
+     RunFunction,
+     RunMCPTool,
+     RunTool,
+     create_function_result,
+     create_tool_call,
+ )
  from mistralai.models import (
-     ResponseFormat,
-     FunctionCallEntry,
-     Tools,
-     ToolsTypedDict,
      CompletionArgs,
      CompletionArgsTypedDict,
-     FunctionResultEntry,
      ConversationInputs,
      ConversationInputsTypedDict,
+     FunctionCallEntry,
+     FunctionResultEntry,
      FunctionTool,
-     MessageInputEntry,
      InputEntries,
+     MessageInputEntry,
+     ResponseFormat,
+     Tools,
+     ToolsTypedDict,
  )
-
- from logging import getLogger
-
- from mistralai.extra.run.tools import (
-     create_function_result,
-     RunFunction,
-     create_tool_call,
-     RunTool,
-     RunMCPTool,
-     RunCoroutine,
- )
+ from mistralai.types.basemodel import BaseModel, OptionalNullable, UNSET

  if typing.TYPE_CHECKING:
      from mistralai import Beta, OptionalNullable
@@ -56,8 +50,8 @@ class AgentRequestKwargs(typing.TypedDict):
  class ModelRequestKwargs(typing.TypedDict):
      model: str
      instructions: OptionalNullable[str]
-     tools: OptionalNullable[Union[list[Tools], list[ToolsTypedDict]]]
-     completion_args: OptionalNullable[Union[CompletionArgs, CompletionArgsTypedDict]]
+     tools: OptionalNullable[list[Tools] | list[ToolsTypedDict]]
+     completion_args: OptionalNullable[CompletionArgs | CompletionArgsTypedDict]


  @dataclass
@@ -72,7 +66,7 @@ class RunContext:
              passed if the user wants to continue an existing conversation.
          model (Options[str]): The model name to be used for the conversation. Can't be used along with 'agent_id'.
          agent_id (Options[str]): The agent id to be used for the conversation. Can't be used along with 'model'.
-         output_format (Optional[type[BaseModel]]): The output format expected from the conversation. It represents
+         output_format (type[BaseModel] | None): The output format expected from the conversation. It represents
              the `response_format` which is part of the `CompletionArgs`.
          request_count (int): The number of requests made in the current `RunContext`.
          continue_on_fn_error (bool): Flag to determine if the conversation should continue when function execution
@@ -83,10 +77,10 @@ class RunContext:
      _callable_tools: dict[str, RunTool] = field(init=False, default_factory=dict)
      _mcp_clients: list[MCPClientProtocol] = field(init=False, default_factory=list)

-     conversation_id: Optional[str] = field(default=None)
-     model: Optional[str] = field(default=None)
-     agent_id: Optional[str] = field(default=None)
-     output_format: Optional[type[BaseModel]] = field(default=None)
+     conversation_id: str | None = field(default=None)
+     model: str | None = field(default=None)
+     agent_id: str | None = field(default=None)
+     output_format: type[BaseModel] | None = field(default=None)
      request_count: int = field(default=0)
      continue_on_fn_error: bool = field(default=False)

@@ -215,10 +209,8 @@ class RunContext:

      async def prepare_model_request(
          self,
-         tools: OptionalNullable[Union[list[Tools], list[ToolsTypedDict]]] = UNSET,
-         completion_args: OptionalNullable[
-             Union[CompletionArgs, CompletionArgsTypedDict]
-         ] = UNSET,
+         tools: OptionalNullable[list[Tools] | list[ToolsTypedDict]] = UNSET,
+         completion_args: OptionalNullable[CompletionArgs | CompletionArgsTypedDict] = UNSET,
          instructions: OptionalNullable[str] = None,
      ) -> ModelRequestKwargs:
          if self.model is None:
@@ -254,14 +246,12 @@ async def _validate_run(
      *,
      beta_client: "Beta",
      run_ctx: RunContext,
-     inputs: Union[ConversationInputs, ConversationInputsTypedDict],
+     inputs: ConversationInputs | ConversationInputsTypedDict,
      instructions: OptionalNullable[str] = UNSET,
-     tools: OptionalNullable[Union[list[Tools], list[ToolsTypedDict]]] = UNSET,
-     completion_args: OptionalNullable[
-         Union[CompletionArgs, CompletionArgsTypedDict]
-     ] = UNSET,
+     tools: OptionalNullable[list[Tools] | list[ToolsTypedDict]] = UNSET,
+     completion_args: OptionalNullable[CompletionArgs | CompletionArgsTypedDict] = UNSET,
  ) -> tuple[
-     Union[AgentRequestKwargs, ModelRequestKwargs], RunResult, list[InputEntries]
+     AgentRequestKwargs | ModelRequestKwargs, RunResult, list[InputEntries]
  ]:
      input_entries: list[InputEntries] = []
      if isinstance(inputs, str):
@@ -277,7 +267,7 @@ async def _validate_run(
              output_model=run_ctx.output_format,
              conversation_id=run_ctx.conversation_id,
          )
-     req: Union[AgentRequestKwargs, ModelRequestKwargs]
+     req: AgentRequestKwargs | ModelRequestKwargs
      if run_ctx.agent_id:
          if tools or completion_args:
              raise RunException("Can't set tools or completion_args when using an agent")
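
These hunks, and the result.py hunks below, migrate annotations and type aliases from typing.Union/Optional to PEP 604 pipe syntax. One compatibility note: module-level aliases such as RunOutputEntries = MessageOutputEntry | ... are evaluated at import time, so the pipe operator on plain classes requires Python 3.10+. A tiny self-contained illustration with hypothetical placeholder types:

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class Example:
    # Equivalent spellings; the pipe form is evaluated at class-definition
    # time and therefore needs a Python 3.10+ interpreter.
    conversation_id: str | None = field(default=None)
    legacy_id: Optional[str] = field(default=None)

# Runtime alias: `str | int` calls type.__or__, also a 3.10+ feature.
IdOrIndex = str | int
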
mistralai/extra/run/result.py
@@ -1,9 +1,10 @@
  import datetime
  import json
  import typing
- from typing import Union, Annotated, Optional, Literal
  from dataclasses import dataclass, field
- from pydantic import Discriminator, Tag, BaseModel
+ from typing import Annotated, Literal
+
+ from pydantic import BaseModel, Discriminator, Tag

  from mistralai.extra.utils.response_format import pydantic_model_from_json
  from mistralai.models import (
@@ -35,15 +36,15 @@ from mistralai.models import (
  )
  from mistralai.utils import get_discriminator

- RunOutputEntries = typing.Union[
-     MessageOutputEntry,
-     FunctionCallEntry,
-     FunctionResultEntry,
-     AgentHandoffEntry,
-     ToolExecutionEntry,
- ]
+ RunOutputEntries = (
+     MessageOutputEntry
+     | FunctionCallEntry
+     | FunctionResultEntry
+     | AgentHandoffEntry
+     | ToolExecutionEntry
+ )

- RunEntries = typing.Union[RunOutputEntries, MessageInputEntry]
+ RunEntries = RunOutputEntries | MessageInputEntry


  def as_text(entry: RunOutputEntries) -> str:
@@ -140,12 +141,12 @@ class RunFiles:
  @dataclass
  class RunResult:
      input_entries: list[InputEntries]
-     conversation_id: Optional[str] = field(default=None)
+     conversation_id: str | None = field(default=None)
      output_entries: list[RunOutputEntries] = field(default_factory=list)
      files: dict[str, RunFiles] = field(default_factory=dict)
-     output_model: Optional[type[BaseModel]] = field(default=None)
+     output_model: type[BaseModel] | None = field(default=None)

-     def get_file(self, file_id: str) -> Optional[RunFiles]:
+     def get_file(self, file_id: str) -> RunFiles | None:
          return self.files.get(file_id)

      @property
@@ -172,36 +173,34 @@ class RunResult:


  class FunctionResultEvent(BaseModel):
-     id: Optional[str] = None
+     id: str | None = None

-     type: Optional[Literal["function.result"]] = "function.result"
+     type: Literal["function.result"] | None = "function.result"

      result: str

      tool_call_id: str

-     created_at: Optional[datetime.datetime] = datetime.datetime.now(
+     created_at: datetime.datetime | None = datetime.datetime.now(
          tz=datetime.timezone.utc
      )

-     output_index: Optional[int] = 0
+     output_index: int | None = 0


- RunResultEventsType = typing.Union[SSETypes, Literal["function.result"]]
+ RunResultEventsType = SSETypes | Literal["function.result"]

  RunResultEventsData = typing.Annotated[
-     Union[
-         Annotated[AgentHandoffDoneEvent, Tag("agent.handoff.done")],
-         Annotated[AgentHandoffStartedEvent, Tag("agent.handoff.started")],
-         Annotated[ResponseDoneEvent, Tag("conversation.response.done")],
-         Annotated[ResponseErrorEvent, Tag("conversation.response.error")],
-         Annotated[ResponseStartedEvent, Tag("conversation.response.started")],
-         Annotated[FunctionCallEvent, Tag("function.call.delta")],
-         Annotated[MessageOutputEvent, Tag("message.output.delta")],
-         Annotated[ToolExecutionDoneEvent, Tag("tool.execution.done")],
-         Annotated[ToolExecutionStartedEvent, Tag("tool.execution.started")],
-         Annotated[FunctionResultEvent, Tag("function.result")],
-     ],
+     Annotated[AgentHandoffDoneEvent, Tag("agent.handoff.done")]
+     | Annotated[AgentHandoffStartedEvent, Tag("agent.handoff.started")]
+     | Annotated[ResponseDoneEvent, Tag("conversation.response.done")]
+     | Annotated[ResponseErrorEvent, Tag("conversation.response.error")]
+     | Annotated[ResponseStartedEvent, Tag("conversation.response.started")]
+     | Annotated[FunctionCallEvent, Tag("function.call.delta")]
+     | Annotated[MessageOutputEvent, Tag("message.output.delta")]
+     | Annotated[ToolExecutionDoneEvent, Tag("tool.execution.done")]
+     | Annotated[ToolExecutionStartedEvent, Tag("tool.execution.started")]
+     | Annotated[FunctionResultEvent, Tag("function.result")],
      Discriminator(lambda m: get_discriminator(m, "type", "type")),
  ]
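
The rebuilt RunResultEventsData union keeps the same callable Discriminator, so parsing still dispatches on each event's "type" tag. A self-contained sketch of that pydantic v2 pattern, using toy event models rather than the SDK's real classes, and a lambda that only roughly mirrors what get_discriminator does:

from typing import Annotated, Literal

from pydantic import BaseModel, Discriminator, Tag, TypeAdapter

class ToolExecutionDone(BaseModel):
    type: Literal["tool.execution.done"]
    id: str

class MessageOutputDelta(BaseModel):
    type: Literal["message.output.delta"]
    content: str

# Callable discriminator: read the tag from dicts and model instances alike.
Event = Annotated[
    Annotated[ToolExecutionDone, Tag("tool.execution.done")]
    | Annotated[MessageOutputDelta, Tag("message.output.delta")],
    Discriminator(lambda v: v.get("type") if isinstance(v, dict) else v.type),
]

adapter = TypeAdapter(Event)
event = adapter.validate_python({"type": "message.output.delta", "content": "hi"})
assert isinstance(event, MessageOutputDelta)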