mistralai 1.10.0__py3-none-any.whl → 1.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. mistralai/_hooks/tracing.py +28 -3
  2. mistralai/_version.py +2 -2
  3. mistralai/classifiers.py +13 -1
  4. mistralai/embeddings.py +7 -1
  5. mistralai/extra/README.md +1 -1
  6. mistralai/extra/mcp/auth.py +10 -11
  7. mistralai/extra/mcp/base.py +17 -16
  8. mistralai/extra/mcp/sse.py +13 -15
  9. mistralai/extra/mcp/stdio.py +5 -6
  10. mistralai/extra/observability/otel.py +47 -68
  11. mistralai/extra/run/context.py +33 -43
  12. mistralai/extra/run/result.py +29 -30
  13. mistralai/extra/run/tools.py +8 -9
  14. mistralai/extra/struct_chat.py +15 -8
  15. mistralai/extra/utils/response_format.py +5 -3
  16. mistralai/mistral_jobs.py +31 -5
  17. mistralai/models/__init__.py +30 -1
  18. mistralai/models/agents_api_v1_agents_listop.py +1 -1
  19. mistralai/models/agents_api_v1_conversations_listop.py +1 -1
  20. mistralai/models/audioencoding.py +13 -0
  21. mistralai/models/audioformat.py +19 -0
  22. mistralai/models/batchjobin.py +17 -6
  23. mistralai/models/batchjobout.py +5 -0
  24. mistralai/models/batchrequest.py +48 -0
  25. mistralai/models/classificationrequest.py +37 -3
  26. mistralai/models/embeddingrequest.py +11 -3
  27. mistralai/models/jobs_api_routes_batch_get_batch_jobop.py +40 -3
  28. mistralai/models/toolfilechunk.py +11 -4
  29. mistralai/models/toolreferencechunk.py +13 -4
  30. {mistralai-1.10.0.dist-info → mistralai-1.10.1.dist-info}/METADATA +142 -150
  31. {mistralai-1.10.0.dist-info → mistralai-1.10.1.dist-info}/RECORD +122 -105
  32. {mistralai-1.10.0.dist-info → mistralai-1.10.1.dist-info}/WHEEL +1 -1
  33. mistralai_azure/_version.py +3 -3
  34. mistralai_azure/basesdk.py +15 -5
  35. mistralai_azure/chat.py +59 -98
  36. mistralai_azure/models/__init__.py +50 -3
  37. mistralai_azure/models/chatcompletionrequest.py +16 -4
  38. mistralai_azure/models/chatcompletionstreamrequest.py +16 -4
  39. mistralai_azure/models/httpvalidationerror.py +11 -6
  40. mistralai_azure/models/mistralazureerror.py +26 -0
  41. mistralai_azure/models/no_response_error.py +13 -0
  42. mistralai_azure/models/prediction.py +4 -0
  43. mistralai_azure/models/responseformat.py +4 -2
  44. mistralai_azure/models/responseformats.py +0 -1
  45. mistralai_azure/models/responsevalidationerror.py +25 -0
  46. mistralai_azure/models/sdkerror.py +30 -14
  47. mistralai_azure/models/systemmessage.py +7 -3
  48. mistralai_azure/models/systemmessagecontentchunks.py +21 -0
  49. mistralai_azure/models/thinkchunk.py +35 -0
  50. mistralai_azure/ocr.py +15 -36
  51. mistralai_azure/utils/__init__.py +18 -5
  52. mistralai_azure/utils/eventstreaming.py +10 -0
  53. mistralai_azure/utils/serializers.py +3 -2
  54. mistralai_azure/utils/unmarshal_json_response.py +24 -0
  55. mistralai_gcp/_hooks/types.py +7 -0
  56. mistralai_gcp/_version.py +4 -4
  57. mistralai_gcp/basesdk.py +27 -25
  58. mistralai_gcp/chat.py +75 -98
  59. mistralai_gcp/fim.py +39 -74
  60. mistralai_gcp/httpclient.py +6 -16
  61. mistralai_gcp/models/__init__.py +321 -116
  62. mistralai_gcp/models/assistantmessage.py +1 -1
  63. mistralai_gcp/models/chatcompletionrequest.py +36 -7
  64. mistralai_gcp/models/chatcompletionresponse.py +6 -6
  65. mistralai_gcp/models/chatcompletionstreamrequest.py +36 -7
  66. mistralai_gcp/models/completionresponsestreamchoice.py +1 -1
  67. mistralai_gcp/models/deltamessage.py +1 -1
  68. mistralai_gcp/models/fimcompletionrequest.py +3 -9
  69. mistralai_gcp/models/fimcompletionresponse.py +6 -6
  70. mistralai_gcp/models/fimcompletionstreamrequest.py +3 -9
  71. mistralai_gcp/models/httpvalidationerror.py +11 -6
  72. mistralai_gcp/models/imageurl.py +1 -1
  73. mistralai_gcp/models/jsonschema.py +1 -1
  74. mistralai_gcp/models/mistralgcperror.py +26 -0
  75. mistralai_gcp/models/mistralpromptmode.py +8 -0
  76. mistralai_gcp/models/no_response_error.py +13 -0
  77. mistralai_gcp/models/prediction.py +4 -0
  78. mistralai_gcp/models/responseformat.py +5 -3
  79. mistralai_gcp/models/responseformats.py +0 -1
  80. mistralai_gcp/models/responsevalidationerror.py +25 -0
  81. mistralai_gcp/models/sdkerror.py +30 -14
  82. mistralai_gcp/models/systemmessage.py +7 -3
  83. mistralai_gcp/models/systemmessagecontentchunks.py +21 -0
  84. mistralai_gcp/models/thinkchunk.py +35 -0
  85. mistralai_gcp/models/toolmessage.py +1 -1
  86. mistralai_gcp/models/usageinfo.py +71 -8
  87. mistralai_gcp/models/usermessage.py +1 -1
  88. mistralai_gcp/sdk.py +12 -10
  89. mistralai_gcp/sdkconfiguration.py +0 -7
  90. mistralai_gcp/types/basemodel.py +3 -3
  91. mistralai_gcp/utils/__init__.py +143 -45
  92. mistralai_gcp/utils/datetimes.py +23 -0
  93. mistralai_gcp/utils/enums.py +67 -27
  94. mistralai_gcp/utils/eventstreaming.py +10 -0
  95. mistralai_gcp/utils/forms.py +49 -28
  96. mistralai_gcp/utils/serializers.py +33 -3
  97. mistralai_gcp/utils/unmarshal_json_response.py +24 -0
  98. {mistralai-1.10.0.dist-info → mistralai-1.10.1.dist-info}/licenses/LICENSE +0 -0

mistralai/extra/observability/otel.py +47 -68

@@ -5,34 +5,23 @@ import os
 import traceback
 from datetime import datetime, timezone
 from enum import Enum
-from typing import Optional, Tuple
 
 import httpx
 import opentelemetry.semconv._incubating.attributes.gen_ai_attributes as gen_ai_attributes
 import opentelemetry.semconv._incubating.attributes.http_attributes as http_attributes
 import opentelemetry.semconv.attributes.server_attributes as server_attributes
 from opentelemetry import propagate, trace
-from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
-from opentelemetry.sdk.resources import SERVICE_NAME, Resource
-from opentelemetry.sdk.trace import SpanProcessor, TracerProvider
-from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExportResult
+from opentelemetry.sdk.trace import SpanProcessor
 from opentelemetry.trace import Span, Status, StatusCode, Tracer, set_span_in_context
 
 logger = logging.getLogger(__name__)
 
 
 OTEL_SERVICE_NAME: str = "mistralai_sdk"
-OTEL_EXPORTER_OTLP_ENDPOINT: str = os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT", "")
-OTEL_EXPORTER_OTLP_TIMEOUT: int = int(os.getenv("OTEL_EXPORTER_OTLP_TIMEOUT", "2"))
-OTEL_EXPORTER_OTLP_MAX_EXPORT_BATCH_SIZE: int = int(os.getenv("OTEL_EXPORTER_OTLP_MAX_EXPORT_BATCH_SIZE", "512"))
-OTEL_EXPORTER_OTLP_SCHEDULE_DELAY_MILLIS: int = int(os.getenv("OTEL_EXPORTER_OTLP_SCHEDULE_DELAY_MILLIS", "1000"))
-OTEL_EXPORTER_OTLP_MAX_QUEUE_SIZE: int = int(os.getenv("OTEL_EXPORTER_OTLP_MAX_QUEUE_SIZE", "2048"))
-OTEL_EXPORTER_OTLP_EXPORT_TIMEOUT_MILLIS: int = int(os.getenv("OTEL_EXPORTER_OTLP_EXPORT_TIMEOUT_MILLIS", "5000"))
-
 MISTRAL_SDK_OTEL_TRACER_NAME: str = OTEL_SERVICE_NAME + "_tracer"
 
 MISTRAL_SDK_DEBUG_TRACING: bool = os.getenv("MISTRAL_SDK_DEBUG_TRACING", "false").lower() == "true"
-DEBUG_HINT: str = "To see detailed exporter logs, set MISTRAL_SDK_DEBUG_TRACING=true."
+DEBUG_HINT: str = "To see detailed tracing logs, set MISTRAL_SDK_DEBUG_TRACING=true."
 
 
 class MistralAIAttributes:
@@ -52,13 +41,11 @@ class MistralAINameValues(Enum):
     OCR = "ocr"
 
 class TracingErrors(Exception, Enum):
-    FAILED_TO_EXPORT_OTEL_SPANS = "Failed to export OpenTelemetry (OTEL) spans."
-    FAILED_TO_INITIALIZE_OPENTELEMETRY_TRACING = "Failed to initialize OpenTelemetry tracing."
     FAILED_TO_CREATE_SPAN_FOR_REQUEST = "Failed to create span for request."
     FAILED_TO_ENRICH_SPAN_WITH_RESPONSE = "Failed to enrich span with response."
     FAILED_TO_HANDLE_ERROR_IN_SPAN = "Failed to handle error in span."
     FAILED_TO_END_SPAN = "Failed to end span."
-
+
     def __str__(self):
         return str(self.value)
 
@@ -180,6 +167,7 @@ def enrich_span_from_response(tracer: trace.Tracer, span: Span, operation_id: st
         start_ns = parse_time_to_nanos(output["created_at"])
         end_ns = parse_time_to_nanos(output["completed_at"])
         child_span = tracer.start_span("Tool Execution", start_time=start_ns, context=parent_context)
+        child_span.set_attributes({"agent.trace.public": ""})
         tool_attributes = {
             gen_ai_attributes.GEN_AI_OPERATION_NAME: gen_ai_attributes.GenAiOperationNameValues.EXECUTE_TOOL.value,
             gen_ai_attributes.GEN_AI_TOOL_CALL_ID: output.get("id", ""),
@@ -192,6 +180,7 @@ def enrich_span_from_response(tracer: trace.Tracer, span: Span, operation_id: st
         start_ns = parse_time_to_nanos(output["created_at"])
         end_ns = parse_time_to_nanos(output["completed_at"])
         child_span = tracer.start_span("Message Output", start_time=start_ns, context=parent_context)
+        child_span.set_attributes({"agent.trace.public": ""})
         message_attributes = {
             gen_ai_attributes.GEN_AI_OPERATION_NAME: gen_ai_attributes.GenAiOperationNameValues.CHAT.value,
             gen_ai_attributes.GEN_AI_PROVIDER_NAME: gen_ai_attributes.GenAiProviderNameValues.MISTRAL_AI.value,
@@ -217,70 +206,47 @@ class GenAISpanProcessor(SpanProcessor):
         span.set_attributes({"agent.trace.public": ""})
 
 
-class QuietOTLPSpanExporter(OTLPSpanExporter):
-    def export(self, spans):
-        try:
-            return super().export(spans)
-        except Exception:
-            logger.warning(f"{TracingErrors.FAILED_TO_EXPORT_OTEL_SPANS} {(traceback.format_exc() if MISTRAL_SDK_DEBUG_TRACING else DEBUG_HINT)}")
-            return SpanExportResult.FAILURE
-
-
-def get_or_create_otel_tracer() -> Tuple[bool, Tracer]:
+def get_or_create_otel_tracer() -> tuple[bool, Tracer]:
     """
-    3 possible cases:
+    Get a tracer from the current TracerProvider.
 
-    -> [SDK in a Workflow / App] If there is already a tracer provider set -> use that one
+    The SDK does not set up its own TracerProvider - it relies on the application
+    to configure OpenTelemetry. This follows OTEL best practices where:
+    - Libraries/SDKs get tracers from the global provider
+    - Applications configure the TracerProvider
 
-    -> [SDK standalone] If no tracer provider is set but the OTEL_EXPORTER_OTLP_ENDPOINT is set -> create a new tracer provider that exports to the OTEL_EXPORTER_OTLP_ENDPOINT
+    If no TracerProvider is configured, the ProxyTracerProvider (default) will
+    return a NoOp tracer, effectively disabling tracing. Once the application
+    sets up a real TracerProvider, subsequent spans will be recorded.
 
-    -> Else tracing is disabled
+    Returns:
+        Tuple[bool, Tracer]: (tracing_enabled, tracer)
+        - tracing_enabled is True if a real TracerProvider is configured
+        - tracer is always valid (may be NoOp if no provider configured)
     """
-    tracing_enabled = True
     tracer_provider = trace.get_tracer_provider()
-
-    if isinstance(tracer_provider, trace.ProxyTracerProvider):
-        if OTEL_EXPORTER_OTLP_ENDPOINT:
-            # SDK standalone: No tracer provider but OTEL_EXPORTER_OTLP_ENDPOINT is set -> create a new tracer provider that exports to the OTEL_EXPORTER_OTLP_ENDPOINT
-            try:
-                exporter = QuietOTLPSpanExporter(
-                    endpoint=OTEL_EXPORTER_OTLP_ENDPOINT,
-                    timeout=OTEL_EXPORTER_OTLP_TIMEOUT
-                )
-                resource = Resource.create(attributes={SERVICE_NAME: OTEL_SERVICE_NAME})
-                tracer_provider = TracerProvider(resource=resource)
-
-                span_processor = BatchSpanProcessor(
-                    exporter,
-                    export_timeout_millis=OTEL_EXPORTER_OTLP_EXPORT_TIMEOUT_MILLIS,
-                    max_export_batch_size=OTEL_EXPORTER_OTLP_MAX_EXPORT_BATCH_SIZE,
-                    schedule_delay_millis=OTEL_EXPORTER_OTLP_SCHEDULE_DELAY_MILLIS,
-                    max_queue_size=OTEL_EXPORTER_OTLP_MAX_QUEUE_SIZE
-                )
-
-                tracer_provider.add_span_processor(span_processor)
-                tracer_provider.add_span_processor(GenAISpanProcessor())
-                trace.set_tracer_provider(tracer_provider)
-
-            except Exception:
-                logger.warning(f"{TracingErrors.FAILED_TO_INITIALIZE_OPENTELEMETRY_TRACING} {(traceback.format_exc() if MISTRAL_SDK_DEBUG_TRACING else DEBUG_HINT)}")
-                tracing_enabled = False
-        else:
-            # No tracer provider nor OTEL_EXPORTER_OTLP_ENDPOINT set -> tracing is disabled
-            tracing_enabled = False
-
     tracer = tracer_provider.get_tracer(MISTRAL_SDK_OTEL_TRACER_NAME)
 
+    # Tracing is considered enabled if we have a real TracerProvider (not the default proxy)
+    tracing_enabled = not isinstance(tracer_provider, trace.ProxyTracerProvider)
+
     return tracing_enabled, tracer
 
-def get_traced_request_and_span(tracing_enabled: bool, tracer: Tracer, span: Optional[Span], operation_id: str, request: httpx.Request) -> Tuple[httpx.Request, Optional[Span]]:
+def get_traced_request_and_span(
+    tracing_enabled: bool,
+    tracer: Tracer,
+    span: Span | None,
+    operation_id: str,
+    request: httpx.Request,
+) -> tuple[httpx.Request, Span | None]:
     if not tracing_enabled:
         return request, span
 
     try:
         span = tracer.start_span(name=operation_id)
+        span.set_attributes({"agent.trace.public": ""})
         # Inject the span context into the request headers to be used by the backend service to continue the trace
-        propagate.inject(request.headers)
+        propagate.inject(request.headers, context=set_span_in_context(span))
        span = enrich_span_from_request(span, request)
     except Exception:
        logger.warning(
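
Note on the propagate.inject change above - a hedged sketch, not SDK code. tracer.start_span() creates a span without activating it in the ambient context, so the old propagate.inject(request.headers) could serialize whatever span happened to be current rather than the one just created. Passing context=set_span_in_context(span) pins the injected traceparent to the new span. The operation name below is illustrative:

    from opentelemetry import propagate, trace
    from opentelemetry.trace import set_span_in_context

    tracer = trace.get_tracer("example")
    span = tracer.start_span("chat.complete")  # hypothetical operation id
    headers: dict[str, str] = {}
    # With a real TracerProvider configured, this writes a W3C traceparent
    # header derived from `span`, not from the ambient current span.
    propagate.inject(headers, context=set_span_in_context(span))
    span.end()
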
@@ -295,7 +261,13 @@ def get_traced_request_and_span(tracing_enabled: bool, tracer: Tracer, span: Opt
     return request, span
 
 
-def get_traced_response(tracing_enabled: bool, tracer: Tracer, span: Optional[Span], operation_id: str, response: httpx.Response) -> httpx.Response:
+def get_traced_response(
+    tracing_enabled: bool,
+    tracer: Tracer,
+    span: Span | None,
+    operation_id: str,
+    response: httpx.Response,
+) -> httpx.Response:
     if not tracing_enabled or not span:
         return response
     try:
@@ -315,7 +287,14 @@ def get_traced_response(tracing_enabled: bool, tracer: Tracer, span: Optional[Sp
         end_span(span=span)
     return response
 
-def get_response_and_error(tracing_enabled: bool, tracer: Tracer, span: Optional[Span], operation_id: str, response: httpx.Response, error: Optional[Exception]) -> Tuple[httpx.Response, Optional[Exception]]:
+def get_response_and_error(
+    tracing_enabled: bool,
+    tracer: Tracer,
+    span: Span | None,
+    operation_id: str,
+    response: httpx.Response,
+    error: Exception | None,
+) -> tuple[httpx.Response, Exception | None]:
     if not tracing_enabled or not span:
         return response, error
     try:
@@ -366,7 +345,7 @@ class TracedResponse(httpx.Response):
 
     This hack allows ending the span only once the stream is fully consumed.
     """
-    def __init__(self, *args, span: Optional[Span], **kwargs) -> None:
+    def __init__(self, *args, span: Span | None, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self.span = span
 
@@ -381,7 +360,7 @@
         await super().aclose()
 
     @classmethod
-    def from_response(cls, resp: httpx.Response, span: Optional[Span]) -> "TracedResponse":
+    def from_response(cls, resp: httpx.Response, span: Span | None) -> "TracedResponse":
         traced_resp = cls.__new__(cls)
         traced_resp.__dict__ = copy.copy(resp.__dict__)
         traced_resp.span = span
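
With the standalone exporter path removed, the host application now owns OpenTelemetry setup. A minimal application-side sketch, assuming the opentelemetry-sdk and opentelemetry-exporter-otlp-proto-http packages are installed (the service name is illustrative):

    from opentelemetry import trace
    from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
    from opentelemetry.sdk.resources import SERVICE_NAME, Resource
    from opentelemetry.sdk.trace import TracerProvider
    from opentelemetry.sdk.trace.export import BatchSpanProcessor

    provider = TracerProvider(resource=Resource.create({SERVICE_NAME: "my-app"}))
    # The exporter reads OTEL_EXPORTER_OTLP_ENDPOINT when no endpoint is passed.
    provider.add_span_processor(BatchSpanProcessor(OTLPSpanExporter()))
    trace.set_tracer_provider(provider)

    # From here on, get_or_create_otel_tracer() sees a real (non-proxy) provider
    # and returns (True, <recording tracer>); SDK spans are exported via OTLP.
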

mistralai/extra/run/context.py +33 -43

@@ -1,47 +1,41 @@
 import asyncio
 import inspect
 import typing
-from contextlib import AsyncExitStack
-from functools import wraps
 from collections.abc import Callable
-
+from contextlib import AsyncExitStack
 from dataclasses import dataclass, field
-from typing import Union, Optional
+from functools import wraps
+from logging import getLogger
 
 import pydantic
 
-from mistralai.extra import (
-    response_format_from_pydantic_model,
-)
+from mistralai.extra import response_format_from_pydantic_model
 from mistralai.extra.exceptions import RunException
 from mistralai.extra.mcp.base import MCPClientProtocol
 from mistralai.extra.run.result import RunResult
-from mistralai.types.basemodel import OptionalNullable, BaseModel, UNSET
+from mistralai.extra.run.tools import (
+    RunCoroutine,
+    RunFunction,
+    RunMCPTool,
+    RunTool,
+    create_function_result,
+    create_tool_call,
+)
 from mistralai.models import (
-    ResponseFormat,
-    FunctionCallEntry,
-    Tools,
-    ToolsTypedDict,
     CompletionArgs,
     CompletionArgsTypedDict,
-    FunctionResultEntry,
     ConversationInputs,
     ConversationInputsTypedDict,
+    FunctionCallEntry,
+    FunctionResultEntry,
     FunctionTool,
-    MessageInputEntry,
     InputEntries,
+    MessageInputEntry,
+    ResponseFormat,
+    Tools,
+    ToolsTypedDict,
 )
-
-from logging import getLogger
-
-from mistralai.extra.run.tools import (
-    create_function_result,
-    RunFunction,
-    create_tool_call,
-    RunTool,
-    RunMCPTool,
-    RunCoroutine,
-)
+from mistralai.types.basemodel import BaseModel, OptionalNullable, UNSET
 
 if typing.TYPE_CHECKING:
     from mistralai import Beta, OptionalNullable
@@ -56,8 +50,8 @@ class AgentRequestKwargs(typing.TypedDict):
 class ModelRequestKwargs(typing.TypedDict):
     model: str
     instructions: OptionalNullable[str]
-    tools: OptionalNullable[Union[list[Tools], list[ToolsTypedDict]]]
-    completion_args: OptionalNullable[Union[CompletionArgs, CompletionArgsTypedDict]]
+    tools: OptionalNullable[list[Tools] | list[ToolsTypedDict]]
+    completion_args: OptionalNullable[CompletionArgs | CompletionArgsTypedDict]
 
 
 @dataclass
@@ -72,7 +66,7 @@ class RunContext:
            passed if the user wants to continue an existing conversation.
        model (Options[str]): The model name to be used for the conversation. Can't be used along with 'agent_id'.
        agent_id (Options[str]): The agent id to be used for the conversation. Can't be used along with 'model'.
-       output_format (Optional[type[BaseModel]]): The output format expected from the conversation. It represents
+       output_format (type[BaseModel] | None): The output format expected from the conversation. It represents
            the `response_format` which is part of the `CompletionArgs`.
        request_count (int): The number of requests made in the current `RunContext`.
        continue_on_fn_error (bool): Flag to determine if the conversation should continue when function execution
@@ -83,10 +77,10 @@
     _callable_tools: dict[str, RunTool] = field(init=False, default_factory=dict)
     _mcp_clients: list[MCPClientProtocol] = field(init=False, default_factory=list)
 
-    conversation_id: Optional[str] = field(default=None)
-    model: Optional[str] = field(default=None)
-    agent_id: Optional[str] = field(default=None)
-    output_format: Optional[type[BaseModel]] = field(default=None)
+    conversation_id: str | None = field(default=None)
+    model: str | None = field(default=None)
+    agent_id: str | None = field(default=None)
+    output_format: type[BaseModel] | None = field(default=None)
     request_count: int = field(default=0)
     continue_on_fn_error: bool = field(default=False)
 
@@ -215,10 +209,8 @@
 
     async def prepare_model_request(
         self,
-        tools: OptionalNullable[Union[list[Tools], list[ToolsTypedDict]]] = UNSET,
-        completion_args: OptionalNullable[
-            Union[CompletionArgs, CompletionArgsTypedDict]
-        ] = UNSET,
+        tools: OptionalNullable[list[Tools] | list[ToolsTypedDict]] = UNSET,
+        completion_args: OptionalNullable[CompletionArgs | CompletionArgsTypedDict] = UNSET,
         instructions: OptionalNullable[str] = None,
    ) -> ModelRequestKwargs:
        if self.model is None:
@@ -254,14 +246,12 @@
    *,
    beta_client: "Beta",
    run_ctx: RunContext,
-    inputs: Union[ConversationInputs, ConversationInputsTypedDict],
+    inputs: ConversationInputs | ConversationInputsTypedDict,
    instructions: OptionalNullable[str] = UNSET,
-    tools: OptionalNullable[Union[list[Tools], list[ToolsTypedDict]]] = UNSET,
-    completion_args: OptionalNullable[
-        Union[CompletionArgs, CompletionArgsTypedDict]
-    ] = UNSET,
+    tools: OptionalNullable[list[Tools] | list[ToolsTypedDict]] = UNSET,
+    completion_args: OptionalNullable[CompletionArgs | CompletionArgsTypedDict] = UNSET,
 ) -> tuple[
-    Union[AgentRequestKwargs, ModelRequestKwargs], RunResult, list[InputEntries]
+    AgentRequestKwargs | ModelRequestKwargs, RunResult, list[InputEntries]
 ]:
    input_entries: list[InputEntries] = []
    if isinstance(inputs, str):
@@ -277,7 +267,7 @@
        output_model=run_ctx.output_format,
        conversation_id=run_ctx.conversation_id,
    )
-    req: Union[AgentRequestKwargs, ModelRequestKwargs]
+    req: AgentRequestKwargs | ModelRequestKwargs
    if run_ctx.agent_id:
        if tools or completion_args:
            raise RunException("Can't set tools or completion_args when using an agent")

mistralai/extra/run/result.py +29 -30

@@ -1,9 +1,10 @@
 import datetime
 import json
 import typing
-from typing import Union, Annotated, Optional, Literal
 from dataclasses import dataclass, field
-from pydantic import Discriminator, Tag, BaseModel
+from typing import Annotated, Literal
+
+from pydantic import BaseModel, Discriminator, Tag
 
 from mistralai.extra.utils.response_format import pydantic_model_from_json
 from mistralai.models import (
@@ -35,15 +36,15 @@ from mistralai.models import (
 )
 from mistralai.utils import get_discriminator
 
-RunOutputEntries = typing.Union[
-    MessageOutputEntry,
-    FunctionCallEntry,
-    FunctionResultEntry,
-    AgentHandoffEntry,
-    ToolExecutionEntry,
-]
+RunOutputEntries = (
+    MessageOutputEntry
+    | FunctionCallEntry
+    | FunctionResultEntry
+    | AgentHandoffEntry
+    | ToolExecutionEntry
+)
 
-RunEntries = typing.Union[RunOutputEntries, MessageInputEntry]
+RunEntries = RunOutputEntries | MessageInputEntry
 
 
 def as_text(entry: RunOutputEntries) -> str:
@@ -140,12 +141,12 @@ class RunFiles:
 @dataclass
 class RunResult:
     input_entries: list[InputEntries]
-    conversation_id: Optional[str] = field(default=None)
+    conversation_id: str | None = field(default=None)
     output_entries: list[RunOutputEntries] = field(default_factory=list)
     files: dict[str, RunFiles] = field(default_factory=dict)
-    output_model: Optional[type[BaseModel]] = field(default=None)
+    output_model: type[BaseModel] | None = field(default=None)
 
-    def get_file(self, file_id: str) -> Optional[RunFiles]:
+    def get_file(self, file_id: str) -> RunFiles | None:
         return self.files.get(file_id)
 
     @property
@@ -172,36 +173,34 @@ class RunResult:
 
 
 class FunctionResultEvent(BaseModel):
-    id: Optional[str] = None
+    id: str | None = None
 
-    type: Optional[Literal["function.result"]] = "function.result"
+    type: Literal["function.result"] | None = "function.result"
 
     result: str
 
     tool_call_id: str
 
-    created_at: Optional[datetime.datetime] = datetime.datetime.now(
+    created_at: datetime.datetime | None = datetime.datetime.now(
         tz=datetime.timezone.utc
     )
 
-    output_index: Optional[int] = 0
+    output_index: int | None = 0
 
 
-RunResultEventsType = typing.Union[SSETypes, Literal["function.result"]]
+RunResultEventsType = SSETypes | Literal["function.result"]
 
 RunResultEventsData = typing.Annotated[
-    Union[
-        Annotated[AgentHandoffDoneEvent, Tag("agent.handoff.done")],
-        Annotated[AgentHandoffStartedEvent, Tag("agent.handoff.started")],
-        Annotated[ResponseDoneEvent, Tag("conversation.response.done")],
-        Annotated[ResponseErrorEvent, Tag("conversation.response.error")],
-        Annotated[ResponseStartedEvent, Tag("conversation.response.started")],
-        Annotated[FunctionCallEvent, Tag("function.call.delta")],
-        Annotated[MessageOutputEvent, Tag("message.output.delta")],
-        Annotated[ToolExecutionDoneEvent, Tag("tool.execution.done")],
-        Annotated[ToolExecutionStartedEvent, Tag("tool.execution.started")],
-        Annotated[FunctionResultEvent, Tag("function.result")],
-    ],
+    Annotated[AgentHandoffDoneEvent, Tag("agent.handoff.done")]
+    | Annotated[AgentHandoffStartedEvent, Tag("agent.handoff.started")]
+    | Annotated[ResponseDoneEvent, Tag("conversation.response.done")]
+    | Annotated[ResponseErrorEvent, Tag("conversation.response.error")]
+    | Annotated[ResponseStartedEvent, Tag("conversation.response.started")]
+    | Annotated[FunctionCallEvent, Tag("function.call.delta")]
+    | Annotated[MessageOutputEvent, Tag("message.output.delta")]
+    | Annotated[ToolExecutionDoneEvent, Tag("tool.execution.done")]
+    | Annotated[ToolExecutionStartedEvent, Tag("tool.execution.started")]
+    | Annotated[FunctionResultEvent, Tag("function.result")],
     Discriminator(lambda m: get_discriminator(m, "type", "type")),
 ]
 

mistralai/extra/run/tools.py +8 -9

@@ -1,14 +1,11 @@
+import inspect
 import itertools
+import json
 import logging
 from dataclasses import dataclass
-import inspect
-
-from pydantic import Field, create_model
-from pydantic.fields import FieldInfo
-import json
-from typing import cast, Callable, Sequence, Any, ForwardRef, get_type_hints, Union
+from typing import Any, Callable, ForwardRef, Sequence, cast, get_type_hints
 
-from opentelemetry import trace
+import opentelemetry.semconv._incubating.attributes.gen_ai_attributes as gen_ai_attributes
 from griffe import (
     Docstring,
     DocstringSectionKind,
@@ -16,7 +13,9 @@ from griffe (
     DocstringParameter,
     DocstringSection,
 )
-import opentelemetry.semconv._incubating.attributes.gen_ai_attributes as gen_ai_attributes
+from opentelemetry import trace
+from pydantic import Field, create_model
+from pydantic.fields import FieldInfo
 
 from mistralai.extra.exceptions import RunException
 from mistralai.extra.mcp.base import MCPClientProtocol
@@ -54,7 +53,7 @@ class RunMCPTool:
     mcp_client: MCPClientProtocol
 
 
-RunTool = Union[RunFunction, RunCoroutine, RunMCPTool]
+RunTool = RunFunction | RunCoroutine | RunMCPTool
 
 
 def _get_function_description(docstring_sections: list[DocstringSection]) -> str:
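
The recurring Optional[...] and Union[...] rewrites in these files, including the RunTool and RunOutputEntries aliases, use PEP 604 union syntax. One point worth noting: unlike annotations deferred by `from __future__ import annotations`, module-level aliases such as RunTool are evaluated at import time, so this spelling requires Python 3.10+ at runtime. A minimal illustration with throwaway names:

    from typing import Optional, Union

    OldAlias = Union[int, str]   # pre-1.10.1 spelling
    OldMaybe = Optional[int]
    NewAlias = int | str         # PEP 604; evaluated at import time, needs 3.10+
    NewMaybe = int | None

    def f(x: int | None = None) -> int | str:
        return "unset" if x is None else x
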

mistralai/extra/struct_chat.py +15 -8

@@ -1,19 +1,26 @@
-from ..models import ChatCompletionResponse, ChatCompletionChoice, AssistantMessage
-from .utils.response_format import CustomPydanticModel, pydantic_model_from_json
-from typing import List, Optional, Type, Generic
-from pydantic import BaseModel
 import json
+from typing import Generic
+
+from ..models import AssistantMessage, ChatCompletionChoice, ChatCompletionResponse
+from .utils.response_format import CustomPydanticModel, pydantic_model_from_json
+
 
 class ParsedAssistantMessage(AssistantMessage, Generic[CustomPydanticModel]):
-    parsed: Optional[CustomPydanticModel]
+    parsed: CustomPydanticModel | None
+
 
 class ParsedChatCompletionChoice(ChatCompletionChoice, Generic[CustomPydanticModel]):
-    message: Optional[ParsedAssistantMessage[CustomPydanticModel]]  # type: ignore
+    message: ParsedAssistantMessage[CustomPydanticModel] | None  # type: ignore
+
 
 class ParsedChatCompletionResponse(ChatCompletionResponse, Generic[CustomPydanticModel]):
-    choices: Optional[List[ParsedChatCompletionChoice[CustomPydanticModel]]]  # type: ignore
+    choices: list[ParsedChatCompletionChoice[CustomPydanticModel]] | None  # type: ignore
+
 
-def convert_to_parsed_chat_completion_response(response: ChatCompletionResponse, response_format: Type[BaseModel]) -> ParsedChatCompletionResponse:
+def convert_to_parsed_chat_completion_response(
+    response: ChatCompletionResponse,
+    response_format: type[CustomPydanticModel],
+) -> ParsedChatCompletionResponse[CustomPydanticModel]:
     parsed_choices = []
 
     if response.choices:
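
A hedged usage sketch for the updated helper; only names visible in this diff are used, and the Weather model plus the caller are illustrative:

    from pydantic import BaseModel

    from mistralai.extra.struct_chat import convert_to_parsed_chat_completion_response
    from mistralai.models import ChatCompletionResponse

    class Weather(BaseModel):
        city: str
        temperature_c: float

    def extract_weather(response: ChatCompletionResponse) -> Weather | None:
        # The generic now flows through the return type, so `parsed` is typed
        # as Weather | None rather than a bare BaseModel.
        parsed = convert_to_parsed_chat_completion_response(response, Weather)
        if parsed.choices and parsed.choices[0].message:
            return parsed.choices[0].message.parsed
        return None
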

mistralai/extra/utils/response_format.py +5 -3

@@ -1,5 +1,6 @@
+from typing import Any, TypeVar
+
 from pydantic import BaseModel
-from typing import TypeVar, Any, Type, Dict
 from ...models import JSONSchema, ResponseFormat
 from ._pydantic_helper import rec_strict_json_schema
 
@@ -7,7 +8,7 @@ CustomPydanticModel = TypeVar("CustomPydanticModel", bound=BaseModel)
 
 
 def response_format_from_pydantic_model(
-    model: Type[CustomPydanticModel],
+    model: type[CustomPydanticModel],
 ) -> ResponseFormat:
     """Generate a strict JSON schema from a pydantic model."""
     model_schema = rec_strict_json_schema(model.model_json_schema())
@@ -18,7 +19,8 @@ def response_format_from_pydantic_model(
 
 
 def pydantic_model_from_json(
-    json_data: Dict[str, Any], pydantic_model: Type[CustomPydanticModel]
+    json_data: dict[str, Any],
+    pydantic_model: type[CustomPydanticModel],
 ) -> CustomPydanticModel:
     """Parse a JSON schema into a pydantic model."""
     return pydantic_model.model_validate(json_data)
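
A short sketch of the two helpers above as they read after the Type/Dict to type/dict cleanup (the Invoice model is illustrative):

    from pydantic import BaseModel

    from mistralai.extra.utils.response_format import (
        pydantic_model_from_json,
        response_format_from_pydantic_model,
    )

    class Invoice(BaseModel):
        number: str
        total: float

    fmt = response_format_from_pydantic_model(Invoice)  # ResponseFormat with a strict JSON schema
    invoice = pydantic_model_from_json({"number": "A-1", "total": 12.5}, Invoice)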