mistralai 1.9.11__py3-none-any.whl → 1.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mistralai/_hooks/registration.py +5 -0
- mistralai/_hooks/tracing.py +75 -0
- mistralai/_version.py +2 -2
- mistralai/accesses.py +8 -8
- mistralai/agents.py +29 -17
- mistralai/chat.py +41 -29
- mistralai/classifiers.py +13 -1
- mistralai/conversations.py +294 -62
- mistralai/documents.py +19 -3
- mistralai/embeddings.py +13 -7
- mistralai/extra/README.md +1 -1
- mistralai/extra/mcp/auth.py +10 -11
- mistralai/extra/mcp/base.py +17 -16
- mistralai/extra/mcp/sse.py +13 -15
- mistralai/extra/mcp/stdio.py +5 -6
- mistralai/extra/observability/__init__.py +15 -0
- mistralai/extra/observability/otel.py +372 -0
- mistralai/extra/run/context.py +33 -43
- mistralai/extra/run/result.py +29 -30
- mistralai/extra/run/tools.py +34 -23
- mistralai/extra/struct_chat.py +15 -8
- mistralai/extra/utils/response_format.py +5 -3
- mistralai/files.py +6 -0
- mistralai/fim.py +17 -5
- mistralai/mistral_agents.py +229 -1
- mistralai/mistral_jobs.py +39 -13
- mistralai/models/__init__.py +99 -3
- mistralai/models/agent.py +15 -2
- mistralai/models/agentconversation.py +11 -3
- mistralai/models/agentcreationrequest.py +6 -2
- mistralai/models/agents_api_v1_agents_deleteop.py +16 -0
- mistralai/models/agents_api_v1_agents_getop.py +40 -3
- mistralai/models/agents_api_v1_agents_listop.py +72 -2
- mistralai/models/agents_api_v1_conversations_deleteop.py +18 -0
- mistralai/models/agents_api_v1_conversations_listop.py +39 -2
- mistralai/models/agentscompletionrequest.py +21 -6
- mistralai/models/agentscompletionstreamrequest.py +21 -6
- mistralai/models/agentupdaterequest.py +18 -2
- mistralai/models/audioencoding.py +13 -0
- mistralai/models/audioformat.py +19 -0
- mistralai/models/audiotranscriptionrequest.py +2 -0
- mistralai/models/batchjobin.py +26 -5
- mistralai/models/batchjobout.py +5 -0
- mistralai/models/batchrequest.py +48 -0
- mistralai/models/chatcompletionrequest.py +22 -5
- mistralai/models/chatcompletionstreamrequest.py +22 -5
- mistralai/models/classificationrequest.py +37 -3
- mistralai/models/conversationrequest.py +15 -4
- mistralai/models/conversationrestartrequest.py +50 -2
- mistralai/models/conversationrestartstreamrequest.py +50 -2
- mistralai/models/conversationstreamrequest.py +15 -4
- mistralai/models/documentout.py +26 -10
- mistralai/models/documentupdatein.py +24 -3
- mistralai/models/embeddingrequest.py +19 -11
- mistralai/models/files_api_routes_list_filesop.py +7 -0
- mistralai/models/fimcompletionrequest.py +8 -9
- mistralai/models/fimcompletionstreamrequest.py +8 -9
- mistralai/models/jobs_api_routes_batch_get_batch_jobop.py +40 -3
- mistralai/models/libraries_documents_list_v1op.py +15 -2
- mistralai/models/libraryout.py +10 -7
- mistralai/models/listfilesout.py +35 -4
- mistralai/models/modelcapabilities.py +13 -4
- mistralai/models/modelconversation.py +8 -2
- mistralai/models/ocrpageobject.py +26 -5
- mistralai/models/ocrrequest.py +17 -1
- mistralai/models/ocrtableobject.py +31 -0
- mistralai/models/prediction.py +4 -0
- mistralai/models/requestsource.py +7 -0
- mistralai/models/responseformat.py +4 -2
- mistralai/models/responseformats.py +0 -1
- mistralai/models/sharingdelete.py +36 -5
- mistralai/models/sharingin.py +36 -5
- mistralai/models/sharingout.py +3 -3
- mistralai/models/toolexecutiondeltaevent.py +13 -4
- mistralai/models/toolexecutiondoneevent.py +13 -4
- mistralai/models/toolexecutionentry.py +9 -4
- mistralai/models/toolexecutionstartedevent.py +13 -4
- mistralai/models/toolfilechunk.py +11 -4
- mistralai/models/toolreferencechunk.py +13 -4
- mistralai/models_.py +2 -14
- mistralai/ocr.py +18 -0
- mistralai/transcriptions.py +4 -4
- {mistralai-1.9.11.dist-info → mistralai-1.10.1.dist-info}/METADATA +162 -152
- {mistralai-1.9.11.dist-info → mistralai-1.10.1.dist-info}/RECORD +168 -144
- {mistralai-1.9.11.dist-info → mistralai-1.10.1.dist-info}/WHEEL +1 -1
- mistralai_azure/_version.py +3 -3
- mistralai_azure/basesdk.py +15 -5
- mistralai_azure/chat.py +59 -98
- mistralai_azure/models/__init__.py +50 -3
- mistralai_azure/models/chatcompletionrequest.py +16 -4
- mistralai_azure/models/chatcompletionstreamrequest.py +16 -4
- mistralai_azure/models/httpvalidationerror.py +11 -6
- mistralai_azure/models/mistralazureerror.py +26 -0
- mistralai_azure/models/no_response_error.py +13 -0
- mistralai_azure/models/prediction.py +4 -0
- mistralai_azure/models/responseformat.py +4 -2
- mistralai_azure/models/responseformats.py +0 -1
- mistralai_azure/models/responsevalidationerror.py +25 -0
- mistralai_azure/models/sdkerror.py +30 -14
- mistralai_azure/models/systemmessage.py +7 -3
- mistralai_azure/models/systemmessagecontentchunks.py +21 -0
- mistralai_azure/models/thinkchunk.py +35 -0
- mistralai_azure/ocr.py +15 -36
- mistralai_azure/utils/__init__.py +18 -5
- mistralai_azure/utils/eventstreaming.py +10 -0
- mistralai_azure/utils/serializers.py +3 -2
- mistralai_azure/utils/unmarshal_json_response.py +24 -0
- mistralai_gcp/_hooks/types.py +7 -0
- mistralai_gcp/_version.py +4 -4
- mistralai_gcp/basesdk.py +27 -25
- mistralai_gcp/chat.py +75 -98
- mistralai_gcp/fim.py +39 -74
- mistralai_gcp/httpclient.py +6 -16
- mistralai_gcp/models/__init__.py +321 -116
- mistralai_gcp/models/assistantmessage.py +1 -1
- mistralai_gcp/models/chatcompletionrequest.py +36 -7
- mistralai_gcp/models/chatcompletionresponse.py +6 -6
- mistralai_gcp/models/chatcompletionstreamrequest.py +36 -7
- mistralai_gcp/models/completionresponsestreamchoice.py +1 -1
- mistralai_gcp/models/deltamessage.py +1 -1
- mistralai_gcp/models/fimcompletionrequest.py +3 -9
- mistralai_gcp/models/fimcompletionresponse.py +6 -6
- mistralai_gcp/models/fimcompletionstreamrequest.py +3 -9
- mistralai_gcp/models/httpvalidationerror.py +11 -6
- mistralai_gcp/models/imageurl.py +1 -1
- mistralai_gcp/models/jsonschema.py +1 -1
- mistralai_gcp/models/mistralgcperror.py +26 -0
- mistralai_gcp/models/mistralpromptmode.py +8 -0
- mistralai_gcp/models/no_response_error.py +13 -0
- mistralai_gcp/models/prediction.py +4 -0
- mistralai_gcp/models/responseformat.py +5 -3
- mistralai_gcp/models/responseformats.py +0 -1
- mistralai_gcp/models/responsevalidationerror.py +25 -0
- mistralai_gcp/models/sdkerror.py +30 -14
- mistralai_gcp/models/systemmessage.py +7 -3
- mistralai_gcp/models/systemmessagecontentchunks.py +21 -0
- mistralai_gcp/models/thinkchunk.py +35 -0
- mistralai_gcp/models/toolmessage.py +1 -1
- mistralai_gcp/models/usageinfo.py +71 -8
- mistralai_gcp/models/usermessage.py +1 -1
- mistralai_gcp/sdk.py +12 -10
- mistralai_gcp/sdkconfiguration.py +0 -7
- mistralai_gcp/types/basemodel.py +3 -3
- mistralai_gcp/utils/__init__.py +143 -45
- mistralai_gcp/utils/datetimes.py +23 -0
- mistralai_gcp/utils/enums.py +67 -27
- mistralai_gcp/utils/eventstreaming.py +10 -0
- mistralai_gcp/utils/forms.py +49 -28
- mistralai_gcp/utils/serializers.py +33 -3
- mistralai_gcp/utils/unmarshal_json_response.py +24 -0
- {mistralai-1.9.11.dist-info → mistralai-1.10.1.dist-info}/licenses/LICENSE +0 -0
mistralai/documents.py
CHANGED
|
@@ -6,7 +6,7 @@ from mistralai._hooks import HookContext
|
|
|
6
6
|
from mistralai.types import OptionalNullable, UNSET
|
|
7
7
|
from mistralai.utils import get_security_from_env
|
|
8
8
|
from mistralai.utils.unmarshal_json_response import unmarshal_json_response
|
|
9
|
-
from typing import Any, Mapping, Optional, Union
|
|
9
|
+
from typing import Any, Dict, Mapping, Optional, Union
|
|
10
10
|
|
|
11
11
|
|
|
12
12
|
class Documents(BaseSDK):
|
|
@@ -19,6 +19,7 @@ class Documents(BaseSDK):
|
|
|
19
19
|
search: OptionalNullable[str] = UNSET,
|
|
20
20
|
page_size: Optional[int] = 100,
|
|
21
21
|
page: Optional[int] = 0,
|
|
22
|
+
filters_attributes: OptionalNullable[str] = UNSET,
|
|
22
23
|
sort_by: Optional[str] = "created_at",
|
|
23
24
|
sort_order: Optional[str] = "desc",
|
|
24
25
|
retries: OptionalNullable[utils.RetryConfig] = UNSET,
|
|
@@ -26,7 +27,7 @@ class Documents(BaseSDK):
|
|
|
26
27
|
timeout_ms: Optional[int] = None,
|
|
27
28
|
http_headers: Optional[Mapping[str, str]] = None,
|
|
28
29
|
) -> models.ListDocumentOut:
|
|
29
|
-
r"""List
|
|
30
|
+
r"""List documents in a given library.
|
|
30
31
|
|
|
31
32
|
Given a library, lists the document that have been uploaded to that library.
|
|
32
33
|
|
|
@@ -34,6 +35,7 @@ class Documents(BaseSDK):
|
|
|
34
35
|
:param search:
|
|
35
36
|
:param page_size:
|
|
36
37
|
:param page:
|
|
38
|
+
:param filters_attributes:
|
|
37
39
|
:param sort_by:
|
|
38
40
|
:param sort_order:
|
|
39
41
|
:param retries: Override the default retry configuration for this method
|
|
@@ -56,6 +58,7 @@ class Documents(BaseSDK):
|
|
|
56
58
|
search=search,
|
|
57
59
|
page_size=page_size,
|
|
58
60
|
page=page,
|
|
61
|
+
filters_attributes=filters_attributes,
|
|
59
62
|
sort_by=sort_by,
|
|
60
63
|
sort_order=sort_order,
|
|
61
64
|
)
|
|
@@ -123,6 +126,7 @@ class Documents(BaseSDK):
|
|
|
123
126
|
search: OptionalNullable[str] = UNSET,
|
|
124
127
|
page_size: Optional[int] = 100,
|
|
125
128
|
page: Optional[int] = 0,
|
|
129
|
+
filters_attributes: OptionalNullable[str] = UNSET,
|
|
126
130
|
sort_by: Optional[str] = "created_at",
|
|
127
131
|
sort_order: Optional[str] = "desc",
|
|
128
132
|
retries: OptionalNullable[utils.RetryConfig] = UNSET,
|
|
@@ -130,7 +134,7 @@ class Documents(BaseSDK):
|
|
|
130
134
|
timeout_ms: Optional[int] = None,
|
|
131
135
|
http_headers: Optional[Mapping[str, str]] = None,
|
|
132
136
|
) -> models.ListDocumentOut:
|
|
133
|
-
r"""List
|
|
137
|
+
r"""List documents in a given library.
|
|
134
138
|
|
|
135
139
|
Given a library, lists the document that have been uploaded to that library.
|
|
136
140
|
|
|
@@ -138,6 +142,7 @@ class Documents(BaseSDK):
|
|
|
138
142
|
:param search:
|
|
139
143
|
:param page_size:
|
|
140
144
|
:param page:
|
|
145
|
+
:param filters_attributes:
|
|
141
146
|
:param sort_by:
|
|
142
147
|
:param sort_order:
|
|
143
148
|
:param retries: Override the default retry configuration for this method
|
|
@@ -160,6 +165,7 @@ class Documents(BaseSDK):
|
|
|
160
165
|
search=search,
|
|
161
166
|
page_size=page_size,
|
|
162
167
|
page=page,
|
|
168
|
+
filters_attributes=filters_attributes,
|
|
163
169
|
sort_by=sort_by,
|
|
164
170
|
sort_order=sort_order,
|
|
165
171
|
)
|
|
@@ -612,6 +618,9 @@ class Documents(BaseSDK):
|
|
|
612
618
|
library_id: str,
|
|
613
619
|
document_id: str,
|
|
614
620
|
name: OptionalNullable[str] = UNSET,
|
|
621
|
+
attributes: OptionalNullable[
|
|
622
|
+
Union[Dict[str, models.Attributes], Dict[str, models.AttributesTypedDict]]
|
|
623
|
+
] = UNSET,
|
|
615
624
|
retries: OptionalNullable[utils.RetryConfig] = UNSET,
|
|
616
625
|
server_url: Optional[str] = None,
|
|
617
626
|
timeout_ms: Optional[int] = None,
|
|
@@ -624,6 +633,7 @@ class Documents(BaseSDK):
|
|
|
624
633
|
:param library_id:
|
|
625
634
|
:param document_id:
|
|
626
635
|
:param name:
|
|
636
|
+
:param attributes:
|
|
627
637
|
:param retries: Override the default retry configuration for this method
|
|
628
638
|
:param server_url: Override the default server URL for this method
|
|
629
639
|
:param timeout_ms: Override the default request timeout configuration for this method in milliseconds
|
|
@@ -644,6 +654,7 @@ class Documents(BaseSDK):
|
|
|
644
654
|
document_id=document_id,
|
|
645
655
|
document_update_in=models.DocumentUpdateIn(
|
|
646
656
|
name=name,
|
|
657
|
+
attributes=attributes,
|
|
647
658
|
),
|
|
648
659
|
)
|
|
649
660
|
|
|
@@ -716,6 +727,9 @@ class Documents(BaseSDK):
|
|
|
716
727
|
library_id: str,
|
|
717
728
|
document_id: str,
|
|
718
729
|
name: OptionalNullable[str] = UNSET,
|
|
730
|
+
attributes: OptionalNullable[
|
|
731
|
+
Union[Dict[str, models.Attributes], Dict[str, models.AttributesTypedDict]]
|
|
732
|
+
] = UNSET,
|
|
719
733
|
retries: OptionalNullable[utils.RetryConfig] = UNSET,
|
|
720
734
|
server_url: Optional[str] = None,
|
|
721
735
|
timeout_ms: Optional[int] = None,
|
|
@@ -728,6 +742,7 @@ class Documents(BaseSDK):
|
|
|
728
742
|
:param library_id:
|
|
729
743
|
:param document_id:
|
|
730
744
|
:param name:
|
|
745
|
+
:param attributes:
|
|
731
746
|
:param retries: Override the default retry configuration for this method
|
|
732
747
|
:param server_url: Override the default server URL for this method
|
|
733
748
|
:param timeout_ms: Override the default request timeout configuration for this method in milliseconds
|
|
@@ -748,6 +763,7 @@ class Documents(BaseSDK):
|
|
|
748
763
|
document_id=document_id,
|
|
749
764
|
document_update_in=models.DocumentUpdateIn(
|
|
750
765
|
name=name,
|
|
766
|
+
attributes=attributes,
|
|
751
767
|
),
|
|
752
768
|
)
|
|
753
769
|
|
mistralai/embeddings.py
CHANGED
|
@@ -6,7 +6,7 @@ from mistralai._hooks import HookContext
|
|
|
6
6
|
from mistralai.types import OptionalNullable, UNSET
|
|
7
7
|
from mistralai.utils import get_security_from_env
|
|
8
8
|
from mistralai.utils.unmarshal_json_response import unmarshal_json_response
|
|
9
|
-
from typing import Any, Mapping, Optional, Union
|
|
9
|
+
from typing import Any, Dict, Mapping, Optional, Union
|
|
10
10
|
|
|
11
11
|
|
|
12
12
|
class Embeddings(BaseSDK):
|
|
@@ -19,6 +19,7 @@ class Embeddings(BaseSDK):
|
|
|
19
19
|
inputs: Union[
|
|
20
20
|
models.EmbeddingRequestInputs, models.EmbeddingRequestInputsTypedDict
|
|
21
21
|
],
|
|
22
|
+
metadata: OptionalNullable[Dict[str, Any]] = UNSET,
|
|
22
23
|
output_dimension: OptionalNullable[int] = UNSET,
|
|
23
24
|
output_dtype: Optional[models.EmbeddingDtype] = None,
|
|
24
25
|
encoding_format: Optional[models.EncodingFormat] = None,
|
|
@@ -31,9 +32,10 @@ class Embeddings(BaseSDK):
|
|
|
31
32
|
|
|
32
33
|
Embeddings
|
|
33
34
|
|
|
34
|
-
:param model: ID of the model to
|
|
35
|
-
:param inputs:
|
|
36
|
-
:param
|
|
35
|
+
:param model: The ID of the model to be used for embedding.
|
|
36
|
+
:param inputs: The text content to be embedded, can be a string or an array of strings for fast processing in bulk.
|
|
37
|
+
:param metadata:
|
|
38
|
+
:param output_dimension: The dimension of the output embeddings when feature available. If not provided, a default output dimension will be used.
|
|
37
39
|
:param output_dtype:
|
|
38
40
|
:param encoding_format:
|
|
39
41
|
:param retries: Override the default retry configuration for this method
|
|
@@ -53,6 +55,7 @@ class Embeddings(BaseSDK):
|
|
|
53
55
|
|
|
54
56
|
request = models.EmbeddingRequest(
|
|
55
57
|
model=model,
|
|
58
|
+
metadata=metadata,
|
|
56
59
|
inputs=inputs,
|
|
57
60
|
output_dimension=output_dimension,
|
|
58
61
|
output_dtype=output_dtype,
|
|
@@ -125,6 +128,7 @@ class Embeddings(BaseSDK):
|
|
|
125
128
|
inputs: Union[
|
|
126
129
|
models.EmbeddingRequestInputs, models.EmbeddingRequestInputsTypedDict
|
|
127
130
|
],
|
|
131
|
+
metadata: OptionalNullable[Dict[str, Any]] = UNSET,
|
|
128
132
|
output_dimension: OptionalNullable[int] = UNSET,
|
|
129
133
|
output_dtype: Optional[models.EmbeddingDtype] = None,
|
|
130
134
|
encoding_format: Optional[models.EncodingFormat] = None,
|
|
@@ -137,9 +141,10 @@ class Embeddings(BaseSDK):
|
|
|
137
141
|
|
|
138
142
|
Embeddings
|
|
139
143
|
|
|
140
|
-
:param model: ID of the model to
|
|
141
|
-
:param inputs:
|
|
142
|
-
:param
|
|
144
|
+
:param model: The ID of the model to be used for embedding.
|
|
145
|
+
:param inputs: The text content to be embedded, can be a string or an array of strings for fast processing in bulk.
|
|
146
|
+
:param metadata:
|
|
147
|
+
:param output_dimension: The dimension of the output embeddings when feature available. If not provided, a default output dimension will be used.
|
|
143
148
|
:param output_dtype:
|
|
144
149
|
:param encoding_format:
|
|
145
150
|
:param retries: Override the default retry configuration for this method
|
|
@@ -159,6 +164,7 @@ class Embeddings(BaseSDK):
|
|
|
159
164
|
|
|
160
165
|
request = models.EmbeddingRequest(
|
|
161
166
|
model=model,
|
|
167
|
+
metadata=metadata,
|
|
162
168
|
inputs=inputs,
|
|
163
169
|
output_dimension=output_dimension,
|
|
164
170
|
output_dtype=output_dtype,
|
mistralai/extra/README.md
CHANGED
|
@@ -34,7 +34,7 @@ class Chat(BaseSDK):
|
|
|
34
34
|
|
|
35
35
|
3. Now build the SDK with the custom code:
|
|
36
36
|
```bash
|
|
37
|
-
rm -rf dist;
|
|
37
|
+
rm -rf dist; uv build; uv pip install --reinstall ~/client-python/dist/mistralai-1.4.1-py3-none-any.whl
|
|
38
38
|
```
|
|
39
39
|
|
|
40
40
|
4. And now you should be able to call the custom method:
|
mistralai/extra/mcp/auth.py
CHANGED
|
@@ -1,9 +1,8 @@
|
|
|
1
|
-
|
|
1
|
+
import logging
|
|
2
2
|
|
|
3
|
-
from authlib.oauth2.rfc8414 import AuthorizationServerMetadata
|
|
4
|
-
from authlib.integrations.httpx_client import AsyncOAuth2Client as AsyncOAuth2ClientBase
|
|
5
3
|
import httpx
|
|
6
|
-
import
|
|
4
|
+
from authlib.integrations.httpx_client import AsyncOAuth2Client as AsyncOAuth2ClientBase
|
|
5
|
+
from authlib.oauth2.rfc8414 import AuthorizationServerMetadata
|
|
7
6
|
|
|
8
7
|
from mistralai.types import BaseModel
|
|
9
8
|
|
|
@@ -16,8 +15,8 @@ class Oauth2AuthorizationScheme(BaseModel):
|
|
|
16
15
|
authorization_url: str
|
|
17
16
|
token_url: str
|
|
18
17
|
scope: list[str]
|
|
19
|
-
description:
|
|
20
|
-
refresh_url:
|
|
18
|
+
description: str | None = None
|
|
19
|
+
refresh_url: str | None = None
|
|
21
20
|
|
|
22
21
|
|
|
23
22
|
class OAuthParams(BaseModel):
|
|
@@ -42,7 +41,7 @@ class AsyncOAuth2Client(AsyncOAuth2ClientBase):
|
|
|
42
41
|
|
|
43
42
|
async def get_well_known_authorization_server_metadata(
|
|
44
43
|
server_url: str,
|
|
45
|
-
) ->
|
|
44
|
+
) -> AuthorizationServerMetadata | None:
|
|
46
45
|
"""Fetch the metadata from the well-known location.
|
|
47
46
|
|
|
48
47
|
This should be available on MCP servers as described by the specification:
|
|
@@ -123,10 +122,10 @@ async def dynamic_client_registration(
|
|
|
123
122
|
async def build_oauth_params(
|
|
124
123
|
server_url: str,
|
|
125
124
|
redirect_url: str,
|
|
126
|
-
client_id:
|
|
127
|
-
client_secret:
|
|
128
|
-
scope:
|
|
129
|
-
async_client:
|
|
125
|
+
client_id: str | None = None,
|
|
126
|
+
client_secret: str | None = None,
|
|
127
|
+
scope: list[str] | None = None,
|
|
128
|
+
async_client: httpx.AsyncClient | None = None,
|
|
130
129
|
) -> OAuthParams:
|
|
131
130
|
"""Get issuer metadata and build the oauth required params."""
|
|
132
131
|
metadata = await get_oauth_server_metadata(server_url=server_url)
|
mistralai/extra/mcp/base.py
CHANGED
|
@@ -1,11 +1,14 @@
|
|
|
1
|
-
from typing import Optional, Union
|
|
2
1
|
import logging
|
|
3
2
|
import typing
|
|
3
|
+
from collections.abc import Sequence
|
|
4
4
|
from contextlib import AsyncExitStack
|
|
5
|
-
from typing import
|
|
5
|
+
from typing import Any, Protocol
|
|
6
6
|
|
|
7
|
-
from mcp import ClientSession
|
|
8
|
-
from mcp.types import
|
|
7
|
+
from mcp import ClientSession # pyright: ignore[reportMissingImports]
|
|
8
|
+
from mcp.types import ( # pyright: ignore[reportMissingImports]
|
|
9
|
+
ContentBlock,
|
|
10
|
+
ListPromptsResult,
|
|
11
|
+
)
|
|
9
12
|
|
|
10
13
|
from mistralai.extra.exceptions import MCPException
|
|
11
14
|
from mistralai.models import (
|
|
@@ -20,8 +23,8 @@ logger = logging.getLogger(__name__)
|
|
|
20
23
|
|
|
21
24
|
|
|
22
25
|
class MCPSystemPrompt(typing.TypedDict):
|
|
23
|
-
description:
|
|
24
|
-
messages: list[
|
|
26
|
+
description: str | None
|
|
27
|
+
messages: list[SystemMessageTypedDict | AssistantMessageTypedDict]
|
|
25
28
|
|
|
26
29
|
|
|
27
30
|
class MCPClientProtocol(Protocol):
|
|
@@ -29,7 +32,7 @@ class MCPClientProtocol(Protocol):
|
|
|
29
32
|
|
|
30
33
|
_name: str
|
|
31
34
|
|
|
32
|
-
async def initialize(self, exit_stack:
|
|
35
|
+
async def initialize(self, exit_stack: AsyncExitStack | None) -> None:
|
|
33
36
|
...
|
|
34
37
|
|
|
35
38
|
async def aclose(self) -> None:
|
|
@@ -39,7 +42,7 @@ class MCPClientProtocol(Protocol):
|
|
|
39
42
|
...
|
|
40
43
|
|
|
41
44
|
async def execute_tool(
|
|
42
|
-
self, name: str, arguments: dict
|
|
45
|
+
self, name: str, arguments: dict[str, Any]
|
|
43
46
|
) -> list[TextChunkTypedDict]:
|
|
44
47
|
...
|
|
45
48
|
|
|
@@ -57,20 +60,18 @@ class MCPClientBase(MCPClientProtocol):
|
|
|
57
60
|
|
|
58
61
|
_session: ClientSession
|
|
59
62
|
|
|
60
|
-
def __init__(self, name:
|
|
63
|
+
def __init__(self, name: str | None = None):
|
|
61
64
|
self._name = name or self.__class__.__name__
|
|
62
|
-
self._exit_stack:
|
|
65
|
+
self._exit_stack: AsyncExitStack | None = None
|
|
63
66
|
self._is_initialized = False
|
|
64
67
|
|
|
65
|
-
def _convert_content(
|
|
66
|
-
self, mcp_content: Union[TextContent, ImageContent, EmbeddedResource]
|
|
67
|
-
) -> TextChunkTypedDict:
|
|
68
|
+
def _convert_content(self, mcp_content: ContentBlock) -> TextChunkTypedDict:
|
|
68
69
|
if not mcp_content.type == "text":
|
|
69
70
|
raise MCPException("Only supporting text tool responses for now.")
|
|
70
71
|
return {"type": "text", "text": mcp_content.text}
|
|
71
72
|
|
|
72
73
|
def _convert_content_list(
|
|
73
|
-
self, mcp_contents:
|
|
74
|
+
self, mcp_contents: Sequence[ContentBlock]
|
|
74
75
|
) -> list[TextChunkTypedDict]:
|
|
75
76
|
content_chunks = []
|
|
76
77
|
for mcp_content in mcp_contents:
|
|
@@ -108,7 +109,7 @@ class MCPClientBase(MCPClientProtocol):
|
|
|
108
109
|
"description": prompt_result.description,
|
|
109
110
|
"messages": [
|
|
110
111
|
typing.cast(
|
|
111
|
-
|
|
112
|
+
SystemMessageTypedDict | AssistantMessageTypedDict,
|
|
112
113
|
{
|
|
113
114
|
"role": message.role,
|
|
114
115
|
"content": self._convert_content(mcp_content=message.content),
|
|
@@ -121,7 +122,7 @@ class MCPClientBase(MCPClientProtocol):
|
|
|
121
122
|
async def list_system_prompts(self) -> ListPromptsResult:
|
|
122
123
|
return await self._session.list_prompts()
|
|
123
124
|
|
|
124
|
-
async def initialize(self, exit_stack:
|
|
125
|
+
async def initialize(self, exit_stack: AsyncExitStack | None = None) -> None:
|
|
125
126
|
"""Initialize the MCP session."""
|
|
126
127
|
# client is already initialized so return
|
|
127
128
|
if self._is_initialized:
|
mistralai/extra/mcp/sse.py
CHANGED
|
@@ -1,22 +1,20 @@
|
|
|
1
1
|
import http
|
|
2
2
|
import logging
|
|
3
|
-
import typing
|
|
4
|
-
from typing import Any, Optional
|
|
5
3
|
from contextlib import AsyncExitStack
|
|
6
4
|
from functools import cached_property
|
|
5
|
+
from typing import Any
|
|
7
6
|
|
|
8
7
|
import httpx
|
|
8
|
+
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
|
9
|
+
from authlib.oauth2.rfc6749 import OAuth2Token
|
|
10
|
+
from mcp.client.sse import sse_client # pyright: ignore[reportMissingImports]
|
|
11
|
+
from mcp.shared.message import SessionMessage # pyright: ignore[reportMissingImports]
|
|
9
12
|
|
|
10
13
|
from mistralai.extra.exceptions import MCPAuthException
|
|
11
14
|
from mistralai.extra.mcp.base import (
|
|
12
15
|
MCPClientBase,
|
|
13
16
|
)
|
|
14
17
|
from mistralai.extra.mcp.auth import OAuthParams, AsyncOAuth2Client
|
|
15
|
-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
|
16
|
-
|
|
17
|
-
from mcp.client.sse import sse_client
|
|
18
|
-
from mcp.shared.message import SessionMessage
|
|
19
|
-
from authlib.oauth2.rfc6749 import OAuth2Token
|
|
20
18
|
|
|
21
19
|
from mistralai.types import BaseModel
|
|
22
20
|
|
|
@@ -27,7 +25,7 @@ class SSEServerParams(BaseModel):
|
|
|
27
25
|
"""Parameters required for a MCPClient with SSE transport"""
|
|
28
26
|
|
|
29
27
|
url: str
|
|
30
|
-
headers:
|
|
28
|
+
headers: dict[str, Any] | None = None
|
|
31
29
|
timeout: float = 5
|
|
32
30
|
sse_read_timeout: float = 60 * 5
|
|
33
31
|
|
|
@@ -41,20 +39,20 @@ class MCPClientSSE(MCPClientBase):
|
|
|
41
39
|
This is possibly going to change in the future since the protocol has ongoing discussions.
|
|
42
40
|
"""
|
|
43
41
|
|
|
44
|
-
_oauth_params:
|
|
42
|
+
_oauth_params: OAuthParams | None
|
|
45
43
|
_sse_params: SSEServerParams
|
|
46
44
|
|
|
47
45
|
def __init__(
|
|
48
46
|
self,
|
|
49
47
|
sse_params: SSEServerParams,
|
|
50
|
-
name:
|
|
51
|
-
oauth_params:
|
|
52
|
-
auth_token:
|
|
48
|
+
name: str | None = None,
|
|
49
|
+
oauth_params: OAuthParams | None = None,
|
|
50
|
+
auth_token: OAuth2Token | None = None,
|
|
53
51
|
):
|
|
54
52
|
super().__init__(name=name)
|
|
55
53
|
self._sse_params = sse_params
|
|
56
|
-
self._oauth_params:
|
|
57
|
-
self._auth_token:
|
|
54
|
+
self._oauth_params: OAuthParams | None = oauth_params
|
|
55
|
+
self._auth_token: OAuth2Token | None = auth_token
|
|
58
56
|
|
|
59
57
|
@cached_property
|
|
60
58
|
def base_url(self) -> str:
|
|
@@ -142,7 +140,7 @@ class MCPClientSSE(MCPClientBase):
|
|
|
142
140
|
async def _get_transport(
|
|
143
141
|
self, exit_stack: AsyncExitStack
|
|
144
142
|
) -> tuple[
|
|
145
|
-
MemoryObjectReceiveStream[
|
|
143
|
+
MemoryObjectReceiveStream[SessionMessage | Exception],
|
|
146
144
|
MemoryObjectSendStream[SessionMessage],
|
|
147
145
|
]:
|
|
148
146
|
try:
|
mistralai/extra/mcp/stdio.py
CHANGED
|
@@ -1,12 +1,9 @@
|
|
|
1
|
-
from typing import Optional
|
|
2
1
|
import logging
|
|
3
2
|
from contextlib import AsyncExitStack
|
|
4
3
|
|
|
5
|
-
from
|
|
6
|
-
MCPClientBase,
|
|
7
|
-
)
|
|
4
|
+
from mcp import StdioServerParameters, stdio_client # pyright: ignore[reportMissingImports]
|
|
8
5
|
|
|
9
|
-
from mcp import
|
|
6
|
+
from mistralai.extra.mcp.base import MCPClientBase
|
|
10
7
|
|
|
11
8
|
logger = logging.getLogger(__name__)
|
|
12
9
|
|
|
@@ -14,7 +11,9 @@ logger = logging.getLogger(__name__)
|
|
|
14
11
|
class MCPClientSTDIO(MCPClientBase):
|
|
15
12
|
"""MCP client that uses stdio for communication."""
|
|
16
13
|
|
|
17
|
-
def __init__(
|
|
14
|
+
def __init__(
|
|
15
|
+
self, stdio_params: StdioServerParameters, name: str | None = None
|
|
16
|
+
):
|
|
18
17
|
super().__init__(name=name)
|
|
19
18
|
self._stdio_params = stdio_params
|
|
20
19
|
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from contextlib import contextmanager
|
|
2
|
+
|
|
3
|
+
from opentelemetry import trace as otel_trace
|
|
4
|
+
|
|
5
|
+
from .otel import MISTRAL_SDK_OTEL_TRACER_NAME
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@contextmanager
|
|
9
|
+
def trace(name: str, **kwargs):
|
|
10
|
+
tracer = otel_trace.get_tracer(MISTRAL_SDK_OTEL_TRACER_NAME)
|
|
11
|
+
with tracer.start_as_current_span(name, **kwargs) as span:
|
|
12
|
+
yield span
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
__all__ = ["trace"]
|