mistralai 1.10.0__py3-none-any.whl → 1.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mistralai/_hooks/tracing.py +28 -3
- mistralai/_version.py +2 -2
- mistralai/classifiers.py +13 -1
- mistralai/embeddings.py +7 -1
- mistralai/extra/README.md +1 -1
- mistralai/extra/mcp/auth.py +10 -11
- mistralai/extra/mcp/base.py +17 -16
- mistralai/extra/mcp/sse.py +13 -15
- mistralai/extra/mcp/stdio.py +5 -6
- mistralai/extra/observability/otel.py +47 -68
- mistralai/extra/run/context.py +33 -43
- mistralai/extra/run/result.py +29 -30
- mistralai/extra/run/tools.py +8 -9
- mistralai/extra/struct_chat.py +15 -8
- mistralai/extra/utils/response_format.py +5 -3
- mistralai/mistral_jobs.py +31 -5
- mistralai/models/__init__.py +30 -1
- mistralai/models/agents_api_v1_agents_listop.py +1 -1
- mistralai/models/agents_api_v1_conversations_listop.py +1 -1
- mistralai/models/audioencoding.py +13 -0
- mistralai/models/audioformat.py +19 -0
- mistralai/models/batchjobin.py +17 -6
- mistralai/models/batchjobout.py +5 -0
- mistralai/models/batchrequest.py +48 -0
- mistralai/models/classificationrequest.py +37 -3
- mistralai/models/embeddingrequest.py +11 -3
- mistralai/models/jobs_api_routes_batch_get_batch_jobop.py +40 -3
- mistralai/models/toolfilechunk.py +11 -4
- mistralai/models/toolreferencechunk.py +13 -4
- {mistralai-1.10.0.dist-info → mistralai-1.10.1.dist-info}/METADATA +142 -150
- {mistralai-1.10.0.dist-info → mistralai-1.10.1.dist-info}/RECORD +122 -105
- {mistralai-1.10.0.dist-info → mistralai-1.10.1.dist-info}/WHEEL +1 -1
- mistralai_azure/_version.py +3 -3
- mistralai_azure/basesdk.py +15 -5
- mistralai_azure/chat.py +59 -98
- mistralai_azure/models/__init__.py +50 -3
- mistralai_azure/models/chatcompletionrequest.py +16 -4
- mistralai_azure/models/chatcompletionstreamrequest.py +16 -4
- mistralai_azure/models/httpvalidationerror.py +11 -6
- mistralai_azure/models/mistralazureerror.py +26 -0
- mistralai_azure/models/no_response_error.py +13 -0
- mistralai_azure/models/prediction.py +4 -0
- mistralai_azure/models/responseformat.py +4 -2
- mistralai_azure/models/responseformats.py +0 -1
- mistralai_azure/models/responsevalidationerror.py +25 -0
- mistralai_azure/models/sdkerror.py +30 -14
- mistralai_azure/models/systemmessage.py +7 -3
- mistralai_azure/models/systemmessagecontentchunks.py +21 -0
- mistralai_azure/models/thinkchunk.py +35 -0
- mistralai_azure/ocr.py +15 -36
- mistralai_azure/utils/__init__.py +18 -5
- mistralai_azure/utils/eventstreaming.py +10 -0
- mistralai_azure/utils/serializers.py +3 -2
- mistralai_azure/utils/unmarshal_json_response.py +24 -0
- mistralai_gcp/_hooks/types.py +7 -0
- mistralai_gcp/_version.py +4 -4
- mistralai_gcp/basesdk.py +27 -25
- mistralai_gcp/chat.py +75 -98
- mistralai_gcp/fim.py +39 -74
- mistralai_gcp/httpclient.py +6 -16
- mistralai_gcp/models/__init__.py +321 -116
- mistralai_gcp/models/assistantmessage.py +1 -1
- mistralai_gcp/models/chatcompletionrequest.py +36 -7
- mistralai_gcp/models/chatcompletionresponse.py +6 -6
- mistralai_gcp/models/chatcompletionstreamrequest.py +36 -7
- mistralai_gcp/models/completionresponsestreamchoice.py +1 -1
- mistralai_gcp/models/deltamessage.py +1 -1
- mistralai_gcp/models/fimcompletionrequest.py +3 -9
- mistralai_gcp/models/fimcompletionresponse.py +6 -6
- mistralai_gcp/models/fimcompletionstreamrequest.py +3 -9
- mistralai_gcp/models/httpvalidationerror.py +11 -6
- mistralai_gcp/models/imageurl.py +1 -1
- mistralai_gcp/models/jsonschema.py +1 -1
- mistralai_gcp/models/mistralgcperror.py +26 -0
- mistralai_gcp/models/mistralpromptmode.py +8 -0
- mistralai_gcp/models/no_response_error.py +13 -0
- mistralai_gcp/models/prediction.py +4 -0
- mistralai_gcp/models/responseformat.py +5 -3
- mistralai_gcp/models/responseformats.py +0 -1
- mistralai_gcp/models/responsevalidationerror.py +25 -0
- mistralai_gcp/models/sdkerror.py +30 -14
- mistralai_gcp/models/systemmessage.py +7 -3
- mistralai_gcp/models/systemmessagecontentchunks.py +21 -0
- mistralai_gcp/models/thinkchunk.py +35 -0
- mistralai_gcp/models/toolmessage.py +1 -1
- mistralai_gcp/models/usageinfo.py +71 -8
- mistralai_gcp/models/usermessage.py +1 -1
- mistralai_gcp/sdk.py +12 -10
- mistralai_gcp/sdkconfiguration.py +0 -7
- mistralai_gcp/types/basemodel.py +3 -3
- mistralai_gcp/utils/__init__.py +143 -45
- mistralai_gcp/utils/datetimes.py +23 -0
- mistralai_gcp/utils/enums.py +67 -27
- mistralai_gcp/utils/eventstreaming.py +10 -0
- mistralai_gcp/utils/forms.py +49 -28
- mistralai_gcp/utils/serializers.py +33 -3
- mistralai_gcp/utils/unmarshal_json_response.py +24 -0
- {mistralai-1.10.0.dist-info → mistralai-1.10.1.dist-info}/licenses/LICENSE +0 -0
mistralai/_hooks/tracing.py
CHANGED
|
@@ -30,13 +30,30 @@ class TracingHook(BeforeRequestHook, AfterSuccessHook, AfterErrorHook):
|
|
|
30
30
|
def before_request(
|
|
31
31
|
self, hook_ctx: BeforeRequestContext, request: httpx.Request
|
|
32
32
|
) -> Union[httpx.Request, Exception]:
|
|
33
|
-
request
|
|
33
|
+
# Refresh tracer/provider per request so tracing can be enabled if the
|
|
34
|
+
# application configures OpenTelemetry after the client is instantiated.
|
|
35
|
+
self.tracing_enabled, self.tracer = get_or_create_otel_tracer()
|
|
36
|
+
self.request_span = None
|
|
37
|
+
request, self.request_span = get_traced_request_and_span(
|
|
38
|
+
tracing_enabled=self.tracing_enabled,
|
|
39
|
+
tracer=self.tracer,
|
|
40
|
+
span=self.request_span,
|
|
41
|
+
operation_id=hook_ctx.operation_id,
|
|
42
|
+
request=request,
|
|
43
|
+
)
|
|
34
44
|
return request
|
|
35
45
|
|
|
36
46
|
def after_success(
|
|
37
47
|
self, hook_ctx: AfterSuccessContext, response: httpx.Response
|
|
38
48
|
) -> Union[httpx.Response, Exception]:
|
|
39
|
-
response = get_traced_response(
|
|
49
|
+
response = get_traced_response(
|
|
50
|
+
tracing_enabled=self.tracing_enabled,
|
|
51
|
+
tracer=self.tracer,
|
|
52
|
+
span=self.request_span,
|
|
53
|
+
operation_id=hook_ctx.operation_id,
|
|
54
|
+
response=response,
|
|
55
|
+
)
|
|
56
|
+
self.request_span = None
|
|
40
57
|
return response
|
|
41
58
|
|
|
42
59
|
def after_error(
|
|
@@ -46,5 +63,13 @@ class TracingHook(BeforeRequestHook, AfterSuccessHook, AfterErrorHook):
|
|
|
46
63
|
error: Optional[Exception],
|
|
47
64
|
) -> Union[Tuple[Optional[httpx.Response], Optional[Exception]], Exception]:
|
|
48
65
|
if response:
|
|
49
|
-
response, error = get_response_and_error(
|
|
66
|
+
response, error = get_response_and_error(
|
|
67
|
+
tracing_enabled=self.tracing_enabled,
|
|
68
|
+
tracer=self.tracer,
|
|
69
|
+
span=self.request_span,
|
|
70
|
+
operation_id=hook_ctx.operation_id,
|
|
71
|
+
response=response,
|
|
72
|
+
error=error,
|
|
73
|
+
)
|
|
74
|
+
self.request_span = None
|
|
50
75
|
return response, error
|
mistralai/_version.py
CHANGED
|
@@ -3,10 +3,10 @@
|
|
|
3
3
|
import importlib.metadata
|
|
4
4
|
|
|
5
5
|
__title__: str = "mistralai"
|
|
6
|
-
__version__: str = "1.10.0"
|
|
6
|
+
__version__: str = "1.10.1"
|
|
7
7
|
__openapi_doc_version__: str = "1.0.0"
|
|
8
8
|
__gen_version__: str = "2.687.13"
|
|
9
|
-
__user_agent__: str = "speakeasy-sdk/python 1.10.0 2.687.13 1.0.0 mistralai"
|
|
9
|
+
__user_agent__: str = "speakeasy-sdk/python 1.10.1 2.687.13 1.0.0 mistralai"
|
|
10
10
|
|
|
11
11
|
try:
|
|
12
12
|
if __package__ is not None:
|
mistralai/classifiers.py
CHANGED
|
@@ -6,7 +6,7 @@ from mistralai._hooks import HookContext
|
|
|
6
6
|
from mistralai.types import OptionalNullable, UNSET
|
|
7
7
|
from mistralai.utils import get_security_from_env
|
|
8
8
|
from mistralai.utils.unmarshal_json_response import unmarshal_json_response
|
|
9
|
-
from typing import Any, Mapping, Optional, Union
|
|
9
|
+
from typing import Any, Dict, Mapping, Optional, Union
|
|
10
10
|
|
|
11
11
|
|
|
12
12
|
class Classifiers(BaseSDK):
|
|
@@ -20,6 +20,7 @@ class Classifiers(BaseSDK):
|
|
|
20
20
|
models.ClassificationRequestInputs,
|
|
21
21
|
models.ClassificationRequestInputsTypedDict,
|
|
22
22
|
],
|
|
23
|
+
metadata: OptionalNullable[Dict[str, Any]] = UNSET,
|
|
23
24
|
retries: OptionalNullable[utils.RetryConfig] = UNSET,
|
|
24
25
|
server_url: Optional[str] = None,
|
|
25
26
|
timeout_ms: Optional[int] = None,
|
|
@@ -29,6 +30,7 @@ class Classifiers(BaseSDK):
|
|
|
29
30
|
|
|
30
31
|
:param model: ID of the model to use.
|
|
31
32
|
:param inputs: Text to classify.
|
|
33
|
+
:param metadata:
|
|
32
34
|
:param retries: Override the default retry configuration for this method
|
|
33
35
|
:param server_url: Override the default server URL for this method
|
|
34
36
|
:param timeout_ms: Override the default request timeout configuration for this method in milliseconds
|
|
@@ -46,6 +48,7 @@ class Classifiers(BaseSDK):
|
|
|
46
48
|
|
|
47
49
|
request = models.ClassificationRequest(
|
|
48
50
|
model=model,
|
|
51
|
+
metadata=metadata,
|
|
49
52
|
inputs=inputs,
|
|
50
53
|
)
|
|
51
54
|
|
|
@@ -116,6 +119,7 @@ class Classifiers(BaseSDK):
|
|
|
116
119
|
models.ClassificationRequestInputs,
|
|
117
120
|
models.ClassificationRequestInputsTypedDict,
|
|
118
121
|
],
|
|
122
|
+
metadata: OptionalNullable[Dict[str, Any]] = UNSET,
|
|
119
123
|
retries: OptionalNullable[utils.RetryConfig] = UNSET,
|
|
120
124
|
server_url: Optional[str] = None,
|
|
121
125
|
timeout_ms: Optional[int] = None,
|
|
@@ -125,6 +129,7 @@ class Classifiers(BaseSDK):
|
|
|
125
129
|
|
|
126
130
|
:param model: ID of the model to use.
|
|
127
131
|
:param inputs: Text to classify.
|
|
132
|
+
:param metadata:
|
|
128
133
|
:param retries: Override the default retry configuration for this method
|
|
129
134
|
:param server_url: Override the default server URL for this method
|
|
130
135
|
:param timeout_ms: Override the default request timeout configuration for this method in milliseconds
|
|
@@ -142,6 +147,7 @@ class Classifiers(BaseSDK):
|
|
|
142
147
|
|
|
143
148
|
request = models.ClassificationRequest(
|
|
144
149
|
model=model,
|
|
150
|
+
metadata=metadata,
|
|
145
151
|
inputs=inputs,
|
|
146
152
|
)
|
|
147
153
|
|
|
@@ -404,6 +410,7 @@ class Classifiers(BaseSDK):
|
|
|
404
410
|
models.ClassificationRequestInputs,
|
|
405
411
|
models.ClassificationRequestInputsTypedDict,
|
|
406
412
|
],
|
|
413
|
+
metadata: OptionalNullable[Dict[str, Any]] = UNSET,
|
|
407
414
|
retries: OptionalNullable[utils.RetryConfig] = UNSET,
|
|
408
415
|
server_url: Optional[str] = None,
|
|
409
416
|
timeout_ms: Optional[int] = None,
|
|
@@ -413,6 +420,7 @@ class Classifiers(BaseSDK):
|
|
|
413
420
|
|
|
414
421
|
:param model: ID of the model to use.
|
|
415
422
|
:param inputs: Text to classify.
|
|
423
|
+
:param metadata:
|
|
416
424
|
:param retries: Override the default retry configuration for this method
|
|
417
425
|
:param server_url: Override the default server URL for this method
|
|
418
426
|
:param timeout_ms: Override the default request timeout configuration for this method in milliseconds
|
|
@@ -430,6 +438,7 @@ class Classifiers(BaseSDK):
|
|
|
430
438
|
|
|
431
439
|
request = models.ClassificationRequest(
|
|
432
440
|
model=model,
|
|
441
|
+
metadata=metadata,
|
|
433
442
|
inputs=inputs,
|
|
434
443
|
)
|
|
435
444
|
|
|
@@ -500,6 +509,7 @@ class Classifiers(BaseSDK):
|
|
|
500
509
|
models.ClassificationRequestInputs,
|
|
501
510
|
models.ClassificationRequestInputsTypedDict,
|
|
502
511
|
],
|
|
512
|
+
metadata: OptionalNullable[Dict[str, Any]] = UNSET,
|
|
503
513
|
retries: OptionalNullable[utils.RetryConfig] = UNSET,
|
|
504
514
|
server_url: Optional[str] = None,
|
|
505
515
|
timeout_ms: Optional[int] = None,
|
|
@@ -509,6 +519,7 @@ class Classifiers(BaseSDK):
|
|
|
509
519
|
|
|
510
520
|
:param model: ID of the model to use.
|
|
511
521
|
:param inputs: Text to classify.
|
|
522
|
+
:param metadata:
|
|
512
523
|
:param retries: Override the default retry configuration for this method
|
|
513
524
|
:param server_url: Override the default server URL for this method
|
|
514
525
|
:param timeout_ms: Override the default request timeout configuration for this method in milliseconds
|
|
@@ -526,6 +537,7 @@ class Classifiers(BaseSDK):
|
|
|
526
537
|
|
|
527
538
|
request = models.ClassificationRequest(
|
|
528
539
|
model=model,
|
|
540
|
+
metadata=metadata,
|
|
529
541
|
inputs=inputs,
|
|
530
542
|
)
|
|
531
543
|
|
mistralai/embeddings.py
CHANGED
|
@@ -6,7 +6,7 @@ from mistralai._hooks import HookContext
|
|
|
6
6
|
from mistralai.types import OptionalNullable, UNSET
|
|
7
7
|
from mistralai.utils import get_security_from_env
|
|
8
8
|
from mistralai.utils.unmarshal_json_response import unmarshal_json_response
|
|
9
|
-
from typing import Any, Mapping, Optional, Union
|
|
9
|
+
from typing import Any, Dict, Mapping, Optional, Union
|
|
10
10
|
|
|
11
11
|
|
|
12
12
|
class Embeddings(BaseSDK):
|
|
@@ -19,6 +19,7 @@ class Embeddings(BaseSDK):
|
|
|
19
19
|
inputs: Union[
|
|
20
20
|
models.EmbeddingRequestInputs, models.EmbeddingRequestInputsTypedDict
|
|
21
21
|
],
|
|
22
|
+
metadata: OptionalNullable[Dict[str, Any]] = UNSET,
|
|
22
23
|
output_dimension: OptionalNullable[int] = UNSET,
|
|
23
24
|
output_dtype: Optional[models.EmbeddingDtype] = None,
|
|
24
25
|
encoding_format: Optional[models.EncodingFormat] = None,
|
|
@@ -33,6 +34,7 @@ class Embeddings(BaseSDK):
|
|
|
33
34
|
|
|
34
35
|
:param model: The ID of the model to be used for embedding.
|
|
35
36
|
:param inputs: The text content to be embedded, can be a string or an array of strings for fast processing in bulk.
|
|
37
|
+
:param metadata:
|
|
36
38
|
:param output_dimension: The dimension of the output embeddings when feature available. If not provided, a default output dimension will be used.
|
|
37
39
|
:param output_dtype:
|
|
38
40
|
:param encoding_format:
|
|
@@ -53,6 +55,7 @@ class Embeddings(BaseSDK):
|
|
|
53
55
|
|
|
54
56
|
request = models.EmbeddingRequest(
|
|
55
57
|
model=model,
|
|
58
|
+
metadata=metadata,
|
|
56
59
|
inputs=inputs,
|
|
57
60
|
output_dimension=output_dimension,
|
|
58
61
|
output_dtype=output_dtype,
|
|
@@ -125,6 +128,7 @@ class Embeddings(BaseSDK):
|
|
|
125
128
|
inputs: Union[
|
|
126
129
|
models.EmbeddingRequestInputs, models.EmbeddingRequestInputsTypedDict
|
|
127
130
|
],
|
|
131
|
+
metadata: OptionalNullable[Dict[str, Any]] = UNSET,
|
|
128
132
|
output_dimension: OptionalNullable[int] = UNSET,
|
|
129
133
|
output_dtype: Optional[models.EmbeddingDtype] = None,
|
|
130
134
|
encoding_format: Optional[models.EncodingFormat] = None,
|
|
@@ -139,6 +143,7 @@ class Embeddings(BaseSDK):
|
|
|
139
143
|
|
|
140
144
|
:param model: The ID of the model to be used for embedding.
|
|
141
145
|
:param inputs: The text content to be embedded, can be a string or an array of strings for fast processing in bulk.
|
|
146
|
+
:param metadata:
|
|
142
147
|
:param output_dimension: The dimension of the output embeddings when feature available. If not provided, a default output dimension will be used.
|
|
143
148
|
:param output_dtype:
|
|
144
149
|
:param encoding_format:
|
|
@@ -159,6 +164,7 @@ class Embeddings(BaseSDK):
|
|
|
159
164
|
|
|
160
165
|
request = models.EmbeddingRequest(
|
|
161
166
|
model=model,
|
|
167
|
+
metadata=metadata,
|
|
162
168
|
inputs=inputs,
|
|
163
169
|
output_dimension=output_dimension,
|
|
164
170
|
output_dtype=output_dtype,
|
mistralai/extra/README.md
CHANGED
|
@@ -34,7 +34,7 @@ class Chat(BaseSDK):
|
|
|
34
34
|
|
|
35
35
|
3. Now build the SDK with the custom code:
|
|
36
36
|
```bash
|
|
37
|
-
rm -rf dist;
|
|
37
|
+
rm -rf dist; uv build; uv pip install --reinstall ~/client-python/dist/mistralai-1.4.1-py3-none-any.whl
|
|
38
38
|
```
|
|
39
39
|
|
|
40
40
|
4. And now you should be able to call the custom method:
|
mistralai/extra/mcp/auth.py
CHANGED
|
@@ -1,9 +1,8 @@
|
|
|
1
|
-
|
|
1
|
+
import logging
|
|
2
2
|
|
|
3
|
-
from authlib.oauth2.rfc8414 import AuthorizationServerMetadata
|
|
4
|
-
from authlib.integrations.httpx_client import AsyncOAuth2Client as AsyncOAuth2ClientBase
|
|
5
3
|
import httpx
|
|
6
|
-
import
|
|
4
|
+
from authlib.integrations.httpx_client import AsyncOAuth2Client as AsyncOAuth2ClientBase
|
|
5
|
+
from authlib.oauth2.rfc8414 import AuthorizationServerMetadata
|
|
7
6
|
|
|
8
7
|
from mistralai.types import BaseModel
|
|
9
8
|
|
|
@@ -16,8 +15,8 @@ class Oauth2AuthorizationScheme(BaseModel):
|
|
|
16
15
|
authorization_url: str
|
|
17
16
|
token_url: str
|
|
18
17
|
scope: list[str]
|
|
19
|
-
description:
|
|
20
|
-
refresh_url:
|
|
18
|
+
description: str | None = None
|
|
19
|
+
refresh_url: str | None = None
|
|
21
20
|
|
|
22
21
|
|
|
23
22
|
class OAuthParams(BaseModel):
|
|
@@ -42,7 +41,7 @@ class AsyncOAuth2Client(AsyncOAuth2ClientBase):
|
|
|
42
41
|
|
|
43
42
|
async def get_well_known_authorization_server_metadata(
|
|
44
43
|
server_url: str,
|
|
45
|
-
) ->
|
|
44
|
+
) -> AuthorizationServerMetadata | None:
|
|
46
45
|
"""Fetch the metadata from the well-known location.
|
|
47
46
|
|
|
48
47
|
This should be available on MCP servers as described by the specification:
|
|
@@ -123,10 +122,10 @@ async def dynamic_client_registration(
|
|
|
123
122
|
async def build_oauth_params(
|
|
124
123
|
server_url: str,
|
|
125
124
|
redirect_url: str,
|
|
126
|
-
client_id:
|
|
127
|
-
client_secret:
|
|
128
|
-
scope:
|
|
129
|
-
async_client:
|
|
125
|
+
client_id: str | None = None,
|
|
126
|
+
client_secret: str | None = None,
|
|
127
|
+
scope: list[str] | None = None,
|
|
128
|
+
async_client: httpx.AsyncClient | None = None,
|
|
130
129
|
) -> OAuthParams:
|
|
131
130
|
"""Get issuer metadata and build the oauth required params."""
|
|
132
131
|
metadata = await get_oauth_server_metadata(server_url=server_url)
|
mistralai/extra/mcp/base.py
CHANGED
|
@@ -1,11 +1,14 @@
|
|
|
1
|
-
from typing import Optional, Union
|
|
2
1
|
import logging
|
|
3
2
|
import typing
|
|
3
|
+
from collections.abc import Sequence
|
|
4
4
|
from contextlib import AsyncExitStack
|
|
5
|
-
from typing import
|
|
5
|
+
from typing import Any, Protocol
|
|
6
6
|
|
|
7
|
-
from mcp import ClientSession
|
|
8
|
-
from mcp.types import
|
|
7
|
+
from mcp import ClientSession # pyright: ignore[reportMissingImports]
|
|
8
|
+
from mcp.types import ( # pyright: ignore[reportMissingImports]
|
|
9
|
+
ContentBlock,
|
|
10
|
+
ListPromptsResult,
|
|
11
|
+
)
|
|
9
12
|
|
|
10
13
|
from mistralai.extra.exceptions import MCPException
|
|
11
14
|
from mistralai.models import (
|
|
@@ -20,8 +23,8 @@ logger = logging.getLogger(__name__)
|
|
|
20
23
|
|
|
21
24
|
|
|
22
25
|
class MCPSystemPrompt(typing.TypedDict):
|
|
23
|
-
description:
|
|
24
|
-
messages: list[
|
|
26
|
+
description: str | None
|
|
27
|
+
messages: list[SystemMessageTypedDict | AssistantMessageTypedDict]
|
|
25
28
|
|
|
26
29
|
|
|
27
30
|
class MCPClientProtocol(Protocol):
|
|
@@ -29,7 +32,7 @@ class MCPClientProtocol(Protocol):
|
|
|
29
32
|
|
|
30
33
|
_name: str
|
|
31
34
|
|
|
32
|
-
async def initialize(self, exit_stack:
|
|
35
|
+
async def initialize(self, exit_stack: AsyncExitStack | None) -> None:
|
|
33
36
|
...
|
|
34
37
|
|
|
35
38
|
async def aclose(self) -> None:
|
|
@@ -39,7 +42,7 @@ class MCPClientProtocol(Protocol):
|
|
|
39
42
|
...
|
|
40
43
|
|
|
41
44
|
async def execute_tool(
|
|
42
|
-
self, name: str, arguments: dict
|
|
45
|
+
self, name: str, arguments: dict[str, Any]
|
|
43
46
|
) -> list[TextChunkTypedDict]:
|
|
44
47
|
...
|
|
45
48
|
|
|
@@ -57,20 +60,18 @@ class MCPClientBase(MCPClientProtocol):
|
|
|
57
60
|
|
|
58
61
|
_session: ClientSession
|
|
59
62
|
|
|
60
|
-
def __init__(self, name:
|
|
63
|
+
def __init__(self, name: str | None = None):
|
|
61
64
|
self._name = name or self.__class__.__name__
|
|
62
|
-
self._exit_stack:
|
|
65
|
+
self._exit_stack: AsyncExitStack | None = None
|
|
63
66
|
self._is_initialized = False
|
|
64
67
|
|
|
65
|
-
def _convert_content(
|
|
66
|
-
self, mcp_content: Union[TextContent, ImageContent, EmbeddedResource]
|
|
67
|
-
) -> TextChunkTypedDict:
|
|
68
|
+
def _convert_content(self, mcp_content: ContentBlock) -> TextChunkTypedDict:
|
|
68
69
|
if not mcp_content.type == "text":
|
|
69
70
|
raise MCPException("Only supporting text tool responses for now.")
|
|
70
71
|
return {"type": "text", "text": mcp_content.text}
|
|
71
72
|
|
|
72
73
|
def _convert_content_list(
|
|
73
|
-
self, mcp_contents:
|
|
74
|
+
self, mcp_contents: Sequence[ContentBlock]
|
|
74
75
|
) -> list[TextChunkTypedDict]:
|
|
75
76
|
content_chunks = []
|
|
76
77
|
for mcp_content in mcp_contents:
|
|
@@ -108,7 +109,7 @@ class MCPClientBase(MCPClientProtocol):
|
|
|
108
109
|
"description": prompt_result.description,
|
|
109
110
|
"messages": [
|
|
110
111
|
typing.cast(
|
|
111
|
-
|
|
112
|
+
SystemMessageTypedDict | AssistantMessageTypedDict,
|
|
112
113
|
{
|
|
113
114
|
"role": message.role,
|
|
114
115
|
"content": self._convert_content(mcp_content=message.content),
|
|
@@ -121,7 +122,7 @@ class MCPClientBase(MCPClientProtocol):
|
|
|
121
122
|
async def list_system_prompts(self) -> ListPromptsResult:
|
|
122
123
|
return await self._session.list_prompts()
|
|
123
124
|
|
|
124
|
-
async def initialize(self, exit_stack:
|
|
125
|
+
async def initialize(self, exit_stack: AsyncExitStack | None = None) -> None:
|
|
125
126
|
"""Initialize the MCP session."""
|
|
126
127
|
# client is already initialized so return
|
|
127
128
|
if self._is_initialized:
|
mistralai/extra/mcp/sse.py
CHANGED
|
@@ -1,22 +1,20 @@
|
|
|
1
1
|
import http
|
|
2
2
|
import logging
|
|
3
|
-
import typing
|
|
4
|
-
from typing import Any, Optional
|
|
5
3
|
from contextlib import AsyncExitStack
|
|
6
4
|
from functools import cached_property
|
|
5
|
+
from typing import Any
|
|
7
6
|
|
|
8
7
|
import httpx
|
|
8
|
+
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
|
9
|
+
from authlib.oauth2.rfc6749 import OAuth2Token
|
|
10
|
+
from mcp.client.sse import sse_client # pyright: ignore[reportMissingImports]
|
|
11
|
+
from mcp.shared.message import SessionMessage # pyright: ignore[reportMissingImports]
|
|
9
12
|
|
|
10
13
|
from mistralai.extra.exceptions import MCPAuthException
|
|
11
14
|
from mistralai.extra.mcp.base import (
|
|
12
15
|
MCPClientBase,
|
|
13
16
|
)
|
|
14
17
|
from mistralai.extra.mcp.auth import OAuthParams, AsyncOAuth2Client
|
|
15
|
-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
|
16
|
-
|
|
17
|
-
from mcp.client.sse import sse_client
|
|
18
|
-
from mcp.shared.message import SessionMessage
|
|
19
|
-
from authlib.oauth2.rfc6749 import OAuth2Token
|
|
20
18
|
|
|
21
19
|
from mistralai.types import BaseModel
|
|
22
20
|
|
|
@@ -27,7 +25,7 @@ class SSEServerParams(BaseModel):
|
|
|
27
25
|
"""Parameters required for a MCPClient with SSE transport"""
|
|
28
26
|
|
|
29
27
|
url: str
|
|
30
|
-
headers:
|
|
28
|
+
headers: dict[str, Any] | None = None
|
|
31
29
|
timeout: float = 5
|
|
32
30
|
sse_read_timeout: float = 60 * 5
|
|
33
31
|
|
|
@@ -41,20 +39,20 @@ class MCPClientSSE(MCPClientBase):
|
|
|
41
39
|
This is possibly going to change in the future since the protocol has ongoing discussions.
|
|
42
40
|
"""
|
|
43
41
|
|
|
44
|
-
_oauth_params:
|
|
42
|
+
_oauth_params: OAuthParams | None
|
|
45
43
|
_sse_params: SSEServerParams
|
|
46
44
|
|
|
47
45
|
def __init__(
|
|
48
46
|
self,
|
|
49
47
|
sse_params: SSEServerParams,
|
|
50
|
-
name:
|
|
51
|
-
oauth_params:
|
|
52
|
-
auth_token:
|
|
48
|
+
name: str | None = None,
|
|
49
|
+
oauth_params: OAuthParams | None = None,
|
|
50
|
+
auth_token: OAuth2Token | None = None,
|
|
53
51
|
):
|
|
54
52
|
super().__init__(name=name)
|
|
55
53
|
self._sse_params = sse_params
|
|
56
|
-
self._oauth_params:
|
|
57
|
-
self._auth_token:
|
|
54
|
+
self._oauth_params: OAuthParams | None = oauth_params
|
|
55
|
+
self._auth_token: OAuth2Token | None = auth_token
|
|
58
56
|
|
|
59
57
|
@cached_property
|
|
60
58
|
def base_url(self) -> str:
|
|
@@ -142,7 +140,7 @@ class MCPClientSSE(MCPClientBase):
|
|
|
142
140
|
async def _get_transport(
|
|
143
141
|
self, exit_stack: AsyncExitStack
|
|
144
142
|
) -> tuple[
|
|
145
|
-
MemoryObjectReceiveStream[
|
|
143
|
+
MemoryObjectReceiveStream[SessionMessage | Exception],
|
|
146
144
|
MemoryObjectSendStream[SessionMessage],
|
|
147
145
|
]:
|
|
148
146
|
try:
|
mistralai/extra/mcp/stdio.py
CHANGED
|
@@ -1,12 +1,9 @@
|
|
|
1
|
-
from typing import Optional
|
|
2
1
|
import logging
|
|
3
2
|
from contextlib import AsyncExitStack
|
|
4
3
|
|
|
5
|
-
from
|
|
6
|
-
MCPClientBase,
|
|
7
|
-
)
|
|
4
|
+
from mcp import StdioServerParameters, stdio_client # pyright: ignore[reportMissingImports]
|
|
8
5
|
|
|
9
|
-
from mcp import
|
|
6
|
+
from mistralai.extra.mcp.base import MCPClientBase
|
|
10
7
|
|
|
11
8
|
logger = logging.getLogger(__name__)
|
|
12
9
|
|
|
@@ -14,7 +11,9 @@ logger = logging.getLogger(__name__)
|
|
|
14
11
|
class MCPClientSTDIO(MCPClientBase):
|
|
15
12
|
"""MCP client that uses stdio for communication."""
|
|
16
13
|
|
|
17
|
-
def __init__(
|
|
14
|
+
def __init__(
|
|
15
|
+
self, stdio_params: StdioServerParameters, name: str | None = None
|
|
16
|
+
):
|
|
18
17
|
super().__init__(name=name)
|
|
19
18
|
self._stdio_params = stdio_params
|
|
20
19
|
|