mistralai 1.10.0__py3-none-any.whl → 1.10.1__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. mistralai/_hooks/tracing.py +28 -3
  2. mistralai/_version.py +2 -2
  3. mistralai/classifiers.py +13 -1
  4. mistralai/embeddings.py +7 -1
  5. mistralai/extra/README.md +1 -1
  6. mistralai/extra/mcp/auth.py +10 -11
  7. mistralai/extra/mcp/base.py +17 -16
  8. mistralai/extra/mcp/sse.py +13 -15
  9. mistralai/extra/mcp/stdio.py +5 -6
  10. mistralai/extra/observability/otel.py +47 -68
  11. mistralai/extra/run/context.py +33 -43
  12. mistralai/extra/run/result.py +29 -30
  13. mistralai/extra/run/tools.py +8 -9
  14. mistralai/extra/struct_chat.py +15 -8
  15. mistralai/extra/utils/response_format.py +5 -3
  16. mistralai/mistral_jobs.py +31 -5
  17. mistralai/models/__init__.py +30 -1
  18. mistralai/models/agents_api_v1_agents_listop.py +1 -1
  19. mistralai/models/agents_api_v1_conversations_listop.py +1 -1
  20. mistralai/models/audioencoding.py +13 -0
  21. mistralai/models/audioformat.py +19 -0
  22. mistralai/models/batchjobin.py +17 -6
  23. mistralai/models/batchjobout.py +5 -0
  24. mistralai/models/batchrequest.py +48 -0
  25. mistralai/models/classificationrequest.py +37 -3
  26. mistralai/models/embeddingrequest.py +11 -3
  27. mistralai/models/jobs_api_routes_batch_get_batch_jobop.py +40 -3
  28. mistralai/models/toolfilechunk.py +11 -4
  29. mistralai/models/toolreferencechunk.py +13 -4
  30. {mistralai-1.10.0.dist-info → mistralai-1.10.1.dist-info}/METADATA +142 -150
  31. {mistralai-1.10.0.dist-info → mistralai-1.10.1.dist-info}/RECORD +122 -105
  32. {mistralai-1.10.0.dist-info → mistralai-1.10.1.dist-info}/WHEEL +1 -1
  33. mistralai_azure/_version.py +3 -3
  34. mistralai_azure/basesdk.py +15 -5
  35. mistralai_azure/chat.py +59 -98
  36. mistralai_azure/models/__init__.py +50 -3
  37. mistralai_azure/models/chatcompletionrequest.py +16 -4
  38. mistralai_azure/models/chatcompletionstreamrequest.py +16 -4
  39. mistralai_azure/models/httpvalidationerror.py +11 -6
  40. mistralai_azure/models/mistralazureerror.py +26 -0
  41. mistralai_azure/models/no_response_error.py +13 -0
  42. mistralai_azure/models/prediction.py +4 -0
  43. mistralai_azure/models/responseformat.py +4 -2
  44. mistralai_azure/models/responseformats.py +0 -1
  45. mistralai_azure/models/responsevalidationerror.py +25 -0
  46. mistralai_azure/models/sdkerror.py +30 -14
  47. mistralai_azure/models/systemmessage.py +7 -3
  48. mistralai_azure/models/systemmessagecontentchunks.py +21 -0
  49. mistralai_azure/models/thinkchunk.py +35 -0
  50. mistralai_azure/ocr.py +15 -36
  51. mistralai_azure/utils/__init__.py +18 -5
  52. mistralai_azure/utils/eventstreaming.py +10 -0
  53. mistralai_azure/utils/serializers.py +3 -2
  54. mistralai_azure/utils/unmarshal_json_response.py +24 -0
  55. mistralai_gcp/_hooks/types.py +7 -0
  56. mistralai_gcp/_version.py +4 -4
  57. mistralai_gcp/basesdk.py +27 -25
  58. mistralai_gcp/chat.py +75 -98
  59. mistralai_gcp/fim.py +39 -74
  60. mistralai_gcp/httpclient.py +6 -16
  61. mistralai_gcp/models/__init__.py +321 -116
  62. mistralai_gcp/models/assistantmessage.py +1 -1
  63. mistralai_gcp/models/chatcompletionrequest.py +36 -7
  64. mistralai_gcp/models/chatcompletionresponse.py +6 -6
  65. mistralai_gcp/models/chatcompletionstreamrequest.py +36 -7
  66. mistralai_gcp/models/completionresponsestreamchoice.py +1 -1
  67. mistralai_gcp/models/deltamessage.py +1 -1
  68. mistralai_gcp/models/fimcompletionrequest.py +3 -9
  69. mistralai_gcp/models/fimcompletionresponse.py +6 -6
  70. mistralai_gcp/models/fimcompletionstreamrequest.py +3 -9
  71. mistralai_gcp/models/httpvalidationerror.py +11 -6
  72. mistralai_gcp/models/imageurl.py +1 -1
  73. mistralai_gcp/models/jsonschema.py +1 -1
  74. mistralai_gcp/models/mistralgcperror.py +26 -0
  75. mistralai_gcp/models/mistralpromptmode.py +8 -0
  76. mistralai_gcp/models/no_response_error.py +13 -0
  77. mistralai_gcp/models/prediction.py +4 -0
  78. mistralai_gcp/models/responseformat.py +5 -3
  79. mistralai_gcp/models/responseformats.py +0 -1
  80. mistralai_gcp/models/responsevalidationerror.py +25 -0
  81. mistralai_gcp/models/sdkerror.py +30 -14
  82. mistralai_gcp/models/systemmessage.py +7 -3
  83. mistralai_gcp/models/systemmessagecontentchunks.py +21 -0
  84. mistralai_gcp/models/thinkchunk.py +35 -0
  85. mistralai_gcp/models/toolmessage.py +1 -1
  86. mistralai_gcp/models/usageinfo.py +71 -8
  87. mistralai_gcp/models/usermessage.py +1 -1
  88. mistralai_gcp/sdk.py +12 -10
  89. mistralai_gcp/sdkconfiguration.py +0 -7
  90. mistralai_gcp/types/basemodel.py +3 -3
  91. mistralai_gcp/utils/__init__.py +143 -45
  92. mistralai_gcp/utils/datetimes.py +23 -0
  93. mistralai_gcp/utils/enums.py +67 -27
  94. mistralai_gcp/utils/eventstreaming.py +10 -0
  95. mistralai_gcp/utils/forms.py +49 -28
  96. mistralai_gcp/utils/serializers.py +33 -3
  97. mistralai_gcp/utils/unmarshal_json_response.py +24 -0
  98. {mistralai-1.10.0.dist-info → mistralai-1.10.1.dist-info}/licenses/LICENSE +0 -0
@@ -30,13 +30,30 @@ class TracingHook(BeforeRequestHook, AfterSuccessHook, AfterErrorHook):
30
30
  def before_request(
31
31
  self, hook_ctx: BeforeRequestContext, request: httpx.Request
32
32
  ) -> Union[httpx.Request, Exception]:
33
- request, self.request_span = get_traced_request_and_span(tracing_enabled=self.tracing_enabled, tracer=self.tracer, span=self.request_span, operation_id=hook_ctx.operation_id, request=request)
33
+ # Refresh tracer/provider per request so tracing can be enabled if the
34
+ # application configures OpenTelemetry after the client is instantiated.
35
+ self.tracing_enabled, self.tracer = get_or_create_otel_tracer()
36
+ self.request_span = None
37
+ request, self.request_span = get_traced_request_and_span(
38
+ tracing_enabled=self.tracing_enabled,
39
+ tracer=self.tracer,
40
+ span=self.request_span,
41
+ operation_id=hook_ctx.operation_id,
42
+ request=request,
43
+ )
34
44
  return request
35
45
 
36
46
  def after_success(
37
47
  self, hook_ctx: AfterSuccessContext, response: httpx.Response
38
48
  ) -> Union[httpx.Response, Exception]:
39
- response = get_traced_response(tracing_enabled=self.tracing_enabled, tracer=self.tracer, span=self.request_span, operation_id=hook_ctx.operation_id, response=response)
49
+ response = get_traced_response(
50
+ tracing_enabled=self.tracing_enabled,
51
+ tracer=self.tracer,
52
+ span=self.request_span,
53
+ operation_id=hook_ctx.operation_id,
54
+ response=response,
55
+ )
56
+ self.request_span = None
40
57
  return response
41
58
 
42
59
  def after_error(
@@ -46,5 +63,13 @@ class TracingHook(BeforeRequestHook, AfterSuccessHook, AfterErrorHook):
46
63
  error: Optional[Exception],
47
64
  ) -> Union[Tuple[Optional[httpx.Response], Optional[Exception]], Exception]:
48
65
  if response:
49
- response, error = get_response_and_error(tracing_enabled=self.tracing_enabled, tracer=self.tracer, span=self.request_span, operation_id=hook_ctx.operation_id, response=response, error=error)
66
+ response, error = get_response_and_error(
67
+ tracing_enabled=self.tracing_enabled,
68
+ tracer=self.tracer,
69
+ span=self.request_span,
70
+ operation_id=hook_ctx.operation_id,
71
+ response=response,
72
+ error=error,
73
+ )
74
+ self.request_span = None
50
75
  return response, error
mistralai/_version.py CHANGED
@@ -3,10 +3,10 @@
3
3
  import importlib.metadata
4
4
 
5
5
  __title__: str = "mistralai"
6
- __version__: str = "1.10.0"
6
+ __version__: str = "1.10.1"
7
7
  __openapi_doc_version__: str = "1.0.0"
8
8
  __gen_version__: str = "2.687.13"
9
- __user_agent__: str = "speakeasy-sdk/python 1.10.0 2.687.13 1.0.0 mistralai"
9
+ __user_agent__: str = "speakeasy-sdk/python 1.10.1 2.687.13 1.0.0 mistralai"
10
10
 
11
11
  try:
12
12
  if __package__ is not None:
mistralai/classifiers.py CHANGED
@@ -6,7 +6,7 @@ from mistralai._hooks import HookContext
6
6
  from mistralai.types import OptionalNullable, UNSET
7
7
  from mistralai.utils import get_security_from_env
8
8
  from mistralai.utils.unmarshal_json_response import unmarshal_json_response
9
- from typing import Any, Mapping, Optional, Union
9
+ from typing import Any, Dict, Mapping, Optional, Union
10
10
 
11
11
 
12
12
  class Classifiers(BaseSDK):
@@ -20,6 +20,7 @@ class Classifiers(BaseSDK):
20
20
  models.ClassificationRequestInputs,
21
21
  models.ClassificationRequestInputsTypedDict,
22
22
  ],
23
+ metadata: OptionalNullable[Dict[str, Any]] = UNSET,
23
24
  retries: OptionalNullable[utils.RetryConfig] = UNSET,
24
25
  server_url: Optional[str] = None,
25
26
  timeout_ms: Optional[int] = None,
@@ -29,6 +30,7 @@ class Classifiers(BaseSDK):
29
30
 
30
31
  :param model: ID of the model to use.
31
32
  :param inputs: Text to classify.
33
+ :param metadata:
32
34
  :param retries: Override the default retry configuration for this method
33
35
  :param server_url: Override the default server URL for this method
34
36
  :param timeout_ms: Override the default request timeout configuration for this method in milliseconds
@@ -46,6 +48,7 @@ class Classifiers(BaseSDK):
46
48
 
47
49
  request = models.ClassificationRequest(
48
50
  model=model,
51
+ metadata=metadata,
49
52
  inputs=inputs,
50
53
  )
51
54
 
@@ -116,6 +119,7 @@ class Classifiers(BaseSDK):
116
119
  models.ClassificationRequestInputs,
117
120
  models.ClassificationRequestInputsTypedDict,
118
121
  ],
122
+ metadata: OptionalNullable[Dict[str, Any]] = UNSET,
119
123
  retries: OptionalNullable[utils.RetryConfig] = UNSET,
120
124
  server_url: Optional[str] = None,
121
125
  timeout_ms: Optional[int] = None,
@@ -125,6 +129,7 @@ class Classifiers(BaseSDK):
125
129
 
126
130
  :param model: ID of the model to use.
127
131
  :param inputs: Text to classify.
132
+ :param metadata:
128
133
  :param retries: Override the default retry configuration for this method
129
134
  :param server_url: Override the default server URL for this method
130
135
  :param timeout_ms: Override the default request timeout configuration for this method in milliseconds
@@ -142,6 +147,7 @@ class Classifiers(BaseSDK):
142
147
 
143
148
  request = models.ClassificationRequest(
144
149
  model=model,
150
+ metadata=metadata,
145
151
  inputs=inputs,
146
152
  )
147
153
 
@@ -404,6 +410,7 @@ class Classifiers(BaseSDK):
404
410
  models.ClassificationRequestInputs,
405
411
  models.ClassificationRequestInputsTypedDict,
406
412
  ],
413
+ metadata: OptionalNullable[Dict[str, Any]] = UNSET,
407
414
  retries: OptionalNullable[utils.RetryConfig] = UNSET,
408
415
  server_url: Optional[str] = None,
409
416
  timeout_ms: Optional[int] = None,
@@ -413,6 +420,7 @@ class Classifiers(BaseSDK):
413
420
 
414
421
  :param model: ID of the model to use.
415
422
  :param inputs: Text to classify.
423
+ :param metadata:
416
424
  :param retries: Override the default retry configuration for this method
417
425
  :param server_url: Override the default server URL for this method
418
426
  :param timeout_ms: Override the default request timeout configuration for this method in milliseconds
@@ -430,6 +438,7 @@ class Classifiers(BaseSDK):
430
438
 
431
439
  request = models.ClassificationRequest(
432
440
  model=model,
441
+ metadata=metadata,
433
442
  inputs=inputs,
434
443
  )
435
444
 
@@ -500,6 +509,7 @@ class Classifiers(BaseSDK):
500
509
  models.ClassificationRequestInputs,
501
510
  models.ClassificationRequestInputsTypedDict,
502
511
  ],
512
+ metadata: OptionalNullable[Dict[str, Any]] = UNSET,
503
513
  retries: OptionalNullable[utils.RetryConfig] = UNSET,
504
514
  server_url: Optional[str] = None,
505
515
  timeout_ms: Optional[int] = None,
@@ -509,6 +519,7 @@ class Classifiers(BaseSDK):
509
519
 
510
520
  :param model: ID of the model to use.
511
521
  :param inputs: Text to classify.
522
+ :param metadata:
512
523
  :param retries: Override the default retry configuration for this method
513
524
  :param server_url: Override the default server URL for this method
514
525
  :param timeout_ms: Override the default request timeout configuration for this method in milliseconds
@@ -526,6 +537,7 @@ class Classifiers(BaseSDK):
526
537
 
527
538
  request = models.ClassificationRequest(
528
539
  model=model,
540
+ metadata=metadata,
529
541
  inputs=inputs,
530
542
  )
531
543
 
mistralai/embeddings.py CHANGED
@@ -6,7 +6,7 @@ from mistralai._hooks import HookContext
6
6
  from mistralai.types import OptionalNullable, UNSET
7
7
  from mistralai.utils import get_security_from_env
8
8
  from mistralai.utils.unmarshal_json_response import unmarshal_json_response
9
- from typing import Any, Mapping, Optional, Union
9
+ from typing import Any, Dict, Mapping, Optional, Union
10
10
 
11
11
 
12
12
  class Embeddings(BaseSDK):
@@ -19,6 +19,7 @@ class Embeddings(BaseSDK):
19
19
  inputs: Union[
20
20
  models.EmbeddingRequestInputs, models.EmbeddingRequestInputsTypedDict
21
21
  ],
22
+ metadata: OptionalNullable[Dict[str, Any]] = UNSET,
22
23
  output_dimension: OptionalNullable[int] = UNSET,
23
24
  output_dtype: Optional[models.EmbeddingDtype] = None,
24
25
  encoding_format: Optional[models.EncodingFormat] = None,
@@ -33,6 +34,7 @@ class Embeddings(BaseSDK):
33
34
 
34
35
  :param model: The ID of the model to be used for embedding.
35
36
  :param inputs: The text content to be embedded, can be a string or an array of strings for fast processing in bulk.
37
+ :param metadata:
36
38
  :param output_dimension: The dimension of the output embeddings when feature available. If not provided, a default output dimension will be used.
37
39
  :param output_dtype:
38
40
  :param encoding_format:
@@ -53,6 +55,7 @@ class Embeddings(BaseSDK):
53
55
 
54
56
  request = models.EmbeddingRequest(
55
57
  model=model,
58
+ metadata=metadata,
56
59
  inputs=inputs,
57
60
  output_dimension=output_dimension,
58
61
  output_dtype=output_dtype,
@@ -125,6 +128,7 @@ class Embeddings(BaseSDK):
125
128
  inputs: Union[
126
129
  models.EmbeddingRequestInputs, models.EmbeddingRequestInputsTypedDict
127
130
  ],
131
+ metadata: OptionalNullable[Dict[str, Any]] = UNSET,
128
132
  output_dimension: OptionalNullable[int] = UNSET,
129
133
  output_dtype: Optional[models.EmbeddingDtype] = None,
130
134
  encoding_format: Optional[models.EncodingFormat] = None,
@@ -139,6 +143,7 @@ class Embeddings(BaseSDK):
139
143
 
140
144
  :param model: The ID of the model to be used for embedding.
141
145
  :param inputs: The text content to be embedded, can be a string or an array of strings for fast processing in bulk.
146
+ :param metadata:
142
147
  :param output_dimension: The dimension of the output embeddings when feature available. If not provided, a default output dimension will be used.
143
148
  :param output_dtype:
144
149
  :param encoding_format:
@@ -159,6 +164,7 @@ class Embeddings(BaseSDK):
159
164
 
160
165
  request = models.EmbeddingRequest(
161
166
  model=model,
167
+ metadata=metadata,
162
168
  inputs=inputs,
163
169
  output_dimension=output_dimension,
164
170
  output_dtype=output_dtype,
mistralai/extra/README.md CHANGED
@@ -34,7 +34,7 @@ class Chat(BaseSDK):
34
34
 
35
35
  3. Now build the SDK with the custom code:
36
36
  ```bash
37
- rm -rf dist; poetry build; python3 -m pip install ~/client-python/dist/mistralai-1.4.1-py3-none-any.whl --force-reinstall
37
+ rm -rf dist; uv build; uv pip install --reinstall ~/client-python/dist/mistralai-1.4.1-py3-none-any.whl
38
38
  ```
39
39
 
40
40
  4. And now you should be able to call the custom method:
@@ -1,9 +1,8 @@
1
- from typing import Optional
1
+ import logging
2
2
 
3
- from authlib.oauth2.rfc8414 import AuthorizationServerMetadata
4
- from authlib.integrations.httpx_client import AsyncOAuth2Client as AsyncOAuth2ClientBase
5
3
  import httpx
6
- import logging
4
+ from authlib.integrations.httpx_client import AsyncOAuth2Client as AsyncOAuth2ClientBase
5
+ from authlib.oauth2.rfc8414 import AuthorizationServerMetadata
7
6
 
8
7
  from mistralai.types import BaseModel
9
8
 
@@ -16,8 +15,8 @@ class Oauth2AuthorizationScheme(BaseModel):
16
15
  authorization_url: str
17
16
  token_url: str
18
17
  scope: list[str]
19
- description: Optional[str] = None
20
- refresh_url: Optional[str] = None
18
+ description: str | None = None
19
+ refresh_url: str | None = None
21
20
 
22
21
 
23
22
  class OAuthParams(BaseModel):
@@ -42,7 +41,7 @@ class AsyncOAuth2Client(AsyncOAuth2ClientBase):
42
41
 
43
42
  async def get_well_known_authorization_server_metadata(
44
43
  server_url: str,
45
- ) -> Optional[AuthorizationServerMetadata]:
44
+ ) -> AuthorizationServerMetadata | None:
46
45
  """Fetch the metadata from the well-known location.
47
46
 
48
47
  This should be available on MCP servers as described by the specification:
@@ -123,10 +122,10 @@ async def dynamic_client_registration(
123
122
  async def build_oauth_params(
124
123
  server_url: str,
125
124
  redirect_url: str,
126
- client_id: Optional[str] = None,
127
- client_secret: Optional[str] = None,
128
- scope: Optional[list[str]] = None,
129
- async_client: Optional[httpx.AsyncClient] = None,
125
+ client_id: str | None = None,
126
+ client_secret: str | None = None,
127
+ scope: list[str] | None = None,
128
+ async_client: httpx.AsyncClient | None = None,
130
129
  ) -> OAuthParams:
131
130
  """Get issuer metadata and build the oauth required params."""
132
131
  metadata = await get_oauth_server_metadata(server_url=server_url)
@@ -1,11 +1,14 @@
1
- from typing import Optional, Union
2
1
  import logging
3
2
  import typing
3
+ from collections.abc import Sequence
4
4
  from contextlib import AsyncExitStack
5
- from typing import Protocol, Any
5
+ from typing import Any, Protocol
6
6
 
7
- from mcp import ClientSession
8
- from mcp.types import ListPromptsResult, EmbeddedResource, ImageContent, TextContent
7
+ from mcp import ClientSession # pyright: ignore[reportMissingImports]
8
+ from mcp.types import ( # pyright: ignore[reportMissingImports]
9
+ ContentBlock,
10
+ ListPromptsResult,
11
+ )
9
12
 
10
13
  from mistralai.extra.exceptions import MCPException
11
14
  from mistralai.models import (
@@ -20,8 +23,8 @@ logger = logging.getLogger(__name__)
20
23
 
21
24
 
22
25
  class MCPSystemPrompt(typing.TypedDict):
23
- description: Optional[str]
24
- messages: list[Union[SystemMessageTypedDict, AssistantMessageTypedDict]]
26
+ description: str | None
27
+ messages: list[SystemMessageTypedDict | AssistantMessageTypedDict]
25
28
 
26
29
 
27
30
  class MCPClientProtocol(Protocol):
@@ -29,7 +32,7 @@ class MCPClientProtocol(Protocol):
29
32
 
30
33
  _name: str
31
34
 
32
- async def initialize(self, exit_stack: Optional[AsyncExitStack]) -> None:
35
+ async def initialize(self, exit_stack: AsyncExitStack | None) -> None:
33
36
  ...
34
37
 
35
38
  async def aclose(self) -> None:
@@ -39,7 +42,7 @@ class MCPClientProtocol(Protocol):
39
42
  ...
40
43
 
41
44
  async def execute_tool(
42
- self, name: str, arguments: dict
45
+ self, name: str, arguments: dict[str, Any]
43
46
  ) -> list[TextChunkTypedDict]:
44
47
  ...
45
48
 
@@ -57,20 +60,18 @@ class MCPClientBase(MCPClientProtocol):
57
60
 
58
61
  _session: ClientSession
59
62
 
60
- def __init__(self, name: Optional[str] = None):
63
+ def __init__(self, name: str | None = None):
61
64
  self._name = name or self.__class__.__name__
62
- self._exit_stack: Optional[AsyncExitStack] = None
65
+ self._exit_stack: AsyncExitStack | None = None
63
66
  self._is_initialized = False
64
67
 
65
- def _convert_content(
66
- self, mcp_content: Union[TextContent, ImageContent, EmbeddedResource]
67
- ) -> TextChunkTypedDict:
68
+ def _convert_content(self, mcp_content: ContentBlock) -> TextChunkTypedDict:
68
69
  if not mcp_content.type == "text":
69
70
  raise MCPException("Only supporting text tool responses for now.")
70
71
  return {"type": "text", "text": mcp_content.text}
71
72
 
72
73
  def _convert_content_list(
73
- self, mcp_contents: list[Union[TextContent, ImageContent, EmbeddedResource]]
74
+ self, mcp_contents: Sequence[ContentBlock]
74
75
  ) -> list[TextChunkTypedDict]:
75
76
  content_chunks = []
76
77
  for mcp_content in mcp_contents:
@@ -108,7 +109,7 @@ class MCPClientBase(MCPClientProtocol):
108
109
  "description": prompt_result.description,
109
110
  "messages": [
110
111
  typing.cast(
111
- Union[SystemMessageTypedDict, AssistantMessageTypedDict],
112
+ SystemMessageTypedDict | AssistantMessageTypedDict,
112
113
  {
113
114
  "role": message.role,
114
115
  "content": self._convert_content(mcp_content=message.content),
@@ -121,7 +122,7 @@ class MCPClientBase(MCPClientProtocol):
121
122
  async def list_system_prompts(self) -> ListPromptsResult:
122
123
  return await self._session.list_prompts()
123
124
 
124
- async def initialize(self, exit_stack: Optional[AsyncExitStack] = None) -> None:
125
+ async def initialize(self, exit_stack: AsyncExitStack | None = None) -> None:
125
126
  """Initialize the MCP session."""
126
127
  # client is already initialized so return
127
128
  if self._is_initialized:
@@ -1,22 +1,20 @@
1
1
  import http
2
2
  import logging
3
- import typing
4
- from typing import Any, Optional
5
3
  from contextlib import AsyncExitStack
6
4
  from functools import cached_property
5
+ from typing import Any
7
6
 
8
7
  import httpx
8
+ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
9
+ from authlib.oauth2.rfc6749 import OAuth2Token
10
+ from mcp.client.sse import sse_client # pyright: ignore[reportMissingImports]
11
+ from mcp.shared.message import SessionMessage # pyright: ignore[reportMissingImports]
9
12
 
10
13
  from mistralai.extra.exceptions import MCPAuthException
11
14
  from mistralai.extra.mcp.base import (
12
15
  MCPClientBase,
13
16
  )
14
17
  from mistralai.extra.mcp.auth import OAuthParams, AsyncOAuth2Client
15
- from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
16
-
17
- from mcp.client.sse import sse_client
18
- from mcp.shared.message import SessionMessage
19
- from authlib.oauth2.rfc6749 import OAuth2Token
20
18
 
21
19
  from mistralai.types import BaseModel
22
20
 
@@ -27,7 +25,7 @@ class SSEServerParams(BaseModel):
27
25
  """Parameters required for a MCPClient with SSE transport"""
28
26
 
29
27
  url: str
30
- headers: Optional[dict[str, Any]] = None
28
+ headers: dict[str, Any] | None = None
31
29
  timeout: float = 5
32
30
  sse_read_timeout: float = 60 * 5
33
31
 
@@ -41,20 +39,20 @@ class MCPClientSSE(MCPClientBase):
41
39
  This is possibly going to change in the future since the protocol has ongoing discussions.
42
40
  """
43
41
 
44
- _oauth_params: Optional[OAuthParams]
42
+ _oauth_params: OAuthParams | None
45
43
  _sse_params: SSEServerParams
46
44
 
47
45
  def __init__(
48
46
  self,
49
47
  sse_params: SSEServerParams,
50
- name: Optional[str] = None,
51
- oauth_params: Optional[OAuthParams] = None,
52
- auth_token: Optional[OAuth2Token] = None,
48
+ name: str | None = None,
49
+ oauth_params: OAuthParams | None = None,
50
+ auth_token: OAuth2Token | None = None,
53
51
  ):
54
52
  super().__init__(name=name)
55
53
  self._sse_params = sse_params
56
- self._oauth_params: Optional[OAuthParams] = oauth_params
57
- self._auth_token: Optional[OAuth2Token] = auth_token
54
+ self._oauth_params: OAuthParams | None = oauth_params
55
+ self._auth_token: OAuth2Token | None = auth_token
58
56
 
59
57
  @cached_property
60
58
  def base_url(self) -> str:
@@ -142,7 +140,7 @@ class MCPClientSSE(MCPClientBase):
142
140
  async def _get_transport(
143
141
  self, exit_stack: AsyncExitStack
144
142
  ) -> tuple[
145
- MemoryObjectReceiveStream[typing.Union[SessionMessage, Exception]],
143
+ MemoryObjectReceiveStream[SessionMessage | Exception],
146
144
  MemoryObjectSendStream[SessionMessage],
147
145
  ]:
148
146
  try:
@@ -1,12 +1,9 @@
1
- from typing import Optional
2
1
  import logging
3
2
  from contextlib import AsyncExitStack
4
3
 
5
- from mistralai.extra.mcp.base import (
6
- MCPClientBase,
7
- )
4
+ from mcp import StdioServerParameters, stdio_client # pyright: ignore[reportMissingImports]
8
5
 
9
- from mcp import stdio_client, StdioServerParameters
6
+ from mistralai.extra.mcp.base import MCPClientBase
10
7
 
11
8
  logger = logging.getLogger(__name__)
12
9
 
@@ -14,7 +11,9 @@ logger = logging.getLogger(__name__)
14
11
  class MCPClientSTDIO(MCPClientBase):
15
12
  """MCP client that uses stdio for communication."""
16
13
 
17
- def __init__(self, stdio_params: StdioServerParameters, name: Optional[str] = None):
14
+ def __init__(
15
+ self, stdio_params: StdioServerParameters, name: str | None = None
16
+ ):
18
17
  super().__init__(name=name)
19
18
  self._stdio_params = stdio_params
20
19