mistralai 1.0.3__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mistralai/__init__.py +4 -0
- mistralai/_hooks/sdkhooks.py +23 -4
- mistralai/_hooks/types.py +27 -9
- mistralai/_version.py +12 -0
- mistralai/agents.py +334 -164
- mistralai/basesdk.py +90 -5
- mistralai/batch.py +17 -0
- mistralai/chat.py +316 -166
- mistralai/classifiers.py +396 -0
- mistralai/embeddings.py +79 -55
- mistralai/files.py +487 -194
- mistralai/fim.py +206 -132
- mistralai/fine_tuning.py +3 -2
- mistralai/jobs.py +392 -263
- mistralai/mistral_jobs.py +733 -0
- mistralai/models/__init__.py +593 -50
- mistralai/models/agentscompletionrequest.py +70 -17
- mistralai/models/agentscompletionstreamrequest.py +72 -17
- mistralai/models/apiendpoint.py +9 -0
- mistralai/models/archiveftmodelout.py +15 -5
- mistralai/models/assistantmessage.py +22 -10
- mistralai/models/{modelcard.py → basemodelcard.py} +53 -14
- mistralai/models/batcherror.py +17 -0
- mistralai/models/batchjobin.py +58 -0
- mistralai/models/batchjobout.py +117 -0
- mistralai/models/batchjobsout.py +30 -0
- mistralai/models/batchjobstatus.py +15 -0
- mistralai/models/chatclassificationrequest.py +104 -0
- mistralai/models/chatcompletionchoice.py +13 -6
- mistralai/models/chatcompletionrequest.py +86 -21
- mistralai/models/chatcompletionresponse.py +8 -4
- mistralai/models/chatcompletionstreamrequest.py +88 -21
- mistralai/models/checkpointout.py +4 -3
- mistralai/models/classificationobject.py +21 -0
- mistralai/models/classificationrequest.py +59 -0
- mistralai/models/classificationresponse.py +21 -0
- mistralai/models/completionchunk.py +12 -5
- mistralai/models/completionevent.py +2 -3
- mistralai/models/completionresponsestreamchoice.py +22 -8
- mistralai/models/contentchunk.py +13 -10
- mistralai/models/delete_model_v1_models_model_id_deleteop.py +5 -5
- mistralai/models/deletefileout.py +4 -3
- mistralai/models/deletemodelout.py +5 -4
- mistralai/models/deltamessage.py +23 -11
- mistralai/models/detailedjobout.py +70 -12
- mistralai/models/embeddingrequest.py +14 -9
- mistralai/models/embeddingresponse.py +7 -3
- mistralai/models/embeddingresponsedata.py +5 -4
- mistralai/models/eventout.py +11 -6
- mistralai/models/filepurpose.py +8 -0
- mistralai/models/files_api_routes_delete_fileop.py +5 -5
- mistralai/models/files_api_routes_download_fileop.py +16 -0
- mistralai/models/files_api_routes_list_filesop.py +96 -0
- mistralai/models/files_api_routes_retrieve_fileop.py +5 -5
- mistralai/models/files_api_routes_upload_fileop.py +33 -14
- mistralai/models/fileschema.py +22 -15
- mistralai/models/fimcompletionrequest.py +44 -16
- mistralai/models/fimcompletionresponse.py +8 -4
- mistralai/models/fimcompletionstreamrequest.py +44 -16
- mistralai/models/finetuneablemodel.py +7 -1
- mistralai/models/ftmodelcapabilitiesout.py +6 -4
- mistralai/models/ftmodelcard.py +121 -0
- mistralai/models/ftmodelout.py +39 -9
- mistralai/models/function.py +5 -4
- mistralai/models/functioncall.py +4 -3
- mistralai/models/functionname.py +17 -0
- mistralai/models/githubrepositoryin.py +24 -7
- mistralai/models/githubrepositoryout.py +24 -7
- mistralai/models/httpvalidationerror.py +1 -3
- mistralai/models/imageurl.py +47 -0
- mistralai/models/imageurlchunk.py +38 -0
- mistralai/models/jobin.py +24 -7
- mistralai/models/jobmetadataout.py +32 -8
- mistralai/models/jobout.py +65 -12
- mistralai/models/jobs_api_routes_batch_cancel_batch_jobop.py +16 -0
- mistralai/models/jobs_api_routes_batch_get_batch_jobop.py +16 -0
- mistralai/models/jobs_api_routes_batch_get_batch_jobsop.py +95 -0
- mistralai/models/jobs_api_routes_fine_tuning_archive_fine_tuned_modelop.py +5 -5
- mistralai/models/jobs_api_routes_fine_tuning_cancel_fine_tuning_jobop.py +5 -5
- mistralai/models/jobs_api_routes_fine_tuning_create_fine_tuning_jobop.py +3 -2
- mistralai/models/jobs_api_routes_fine_tuning_get_fine_tuning_jobop.py +5 -5
- mistralai/models/jobs_api_routes_fine_tuning_get_fine_tuning_jobsop.py +85 -18
- mistralai/models/jobs_api_routes_fine_tuning_start_fine_tuning_jobop.py +5 -5
- mistralai/models/jobs_api_routes_fine_tuning_unarchive_fine_tuned_modelop.py +5 -5
- mistralai/models/jobs_api_routes_fine_tuning_update_fine_tuned_modelop.py +10 -6
- mistralai/models/jobsout.py +13 -5
- mistralai/models/legacyjobmetadataout.py +55 -9
- mistralai/models/listfilesout.py +7 -3
- mistralai/models/metricout.py +12 -8
- mistralai/models/modelcapabilities.py +9 -4
- mistralai/models/modellist.py +21 -7
- mistralai/models/responseformat.py +7 -8
- mistralai/models/responseformats.py +8 -0
- mistralai/models/retrieve_model_v1_models_model_id_getop.py +25 -6
- mistralai/models/retrievefileout.py +25 -15
- mistralai/models/sampletype.py +6 -2
- mistralai/models/security.py +14 -5
- mistralai/models/source.py +3 -2
- mistralai/models/systemmessage.py +10 -9
- mistralai/models/textchunk.py +14 -5
- mistralai/models/tool.py +10 -9
- mistralai/models/toolcall.py +10 -8
- mistralai/models/toolchoice.py +29 -0
- mistralai/models/toolchoiceenum.py +7 -0
- mistralai/models/toolmessage.py +13 -6
- mistralai/models/tooltypes.py +8 -0
- mistralai/models/trainingfile.py +4 -4
- mistralai/models/trainingparameters.py +34 -8
- mistralai/models/trainingparametersin.py +36 -10
- mistralai/models/unarchiveftmodelout.py +15 -5
- mistralai/models/updateftmodelin.py +9 -6
- mistralai/models/uploadfileout.py +22 -15
- mistralai/models/usageinfo.py +4 -3
- mistralai/models/usermessage.py +42 -10
- mistralai/models/validationerror.py +5 -3
- mistralai/models/wandbintegration.py +23 -7
- mistralai/models/wandbintegrationout.py +23 -8
- mistralai/models_.py +416 -294
- mistralai/sdk.py +31 -19
- mistralai/sdkconfiguration.py +9 -11
- mistralai/utils/__init__.py +14 -1
- mistralai/utils/annotations.py +13 -2
- mistralai/utils/logger.py +4 -1
- mistralai/utils/retries.py +2 -1
- mistralai/utils/security.py +13 -6
- mistralai/utils/serializers.py +25 -0
- {mistralai-1.0.3.dist-info → mistralai-1.2.0.dist-info}/METADATA +171 -66
- mistralai-1.2.0.dist-info/RECORD +276 -0
- {mistralai-1.0.3.dist-info → mistralai-1.2.0.dist-info}/WHEEL +1 -1
- mistralai_azure/__init__.py +4 -0
- mistralai_azure/_hooks/sdkhooks.py +23 -4
- mistralai_azure/_hooks/types.py +27 -9
- mistralai_azure/_version.py +12 -0
- mistralai_azure/basesdk.py +91 -6
- mistralai_azure/chat.py +308 -166
- mistralai_azure/models/__init__.py +164 -16
- mistralai_azure/models/assistantmessage.py +29 -11
- mistralai_azure/models/chatcompletionchoice.py +15 -6
- mistralai_azure/models/chatcompletionrequest.py +94 -22
- mistralai_azure/models/chatcompletionresponse.py +8 -4
- mistralai_azure/models/chatcompletionstreamrequest.py +96 -22
- mistralai_azure/models/completionchunk.py +12 -5
- mistralai_azure/models/completionevent.py +2 -3
- mistralai_azure/models/completionresponsestreamchoice.py +19 -8
- mistralai_azure/models/contentchunk.py +4 -11
- mistralai_azure/models/deltamessage.py +30 -12
- mistralai_azure/models/function.py +5 -4
- mistralai_azure/models/functioncall.py +4 -3
- mistralai_azure/models/functionname.py +17 -0
- mistralai_azure/models/httpvalidationerror.py +1 -3
- mistralai_azure/models/responseformat.py +7 -8
- mistralai_azure/models/responseformats.py +8 -0
- mistralai_azure/models/security.py +13 -5
- mistralai_azure/models/systemmessage.py +10 -9
- mistralai_azure/models/textchunk.py +14 -5
- mistralai_azure/models/tool.py +10 -9
- mistralai_azure/models/toolcall.py +10 -8
- mistralai_azure/models/toolchoice.py +29 -0
- mistralai_azure/models/toolchoiceenum.py +7 -0
- mistralai_azure/models/toolmessage.py +20 -7
- mistralai_azure/models/tooltypes.py +8 -0
- mistralai_azure/models/usageinfo.py +4 -3
- mistralai_azure/models/usermessage.py +42 -10
- mistralai_azure/models/validationerror.py +5 -3
- mistralai_azure/sdkconfiguration.py +9 -11
- mistralai_azure/utils/__init__.py +16 -3
- mistralai_azure/utils/annotations.py +13 -2
- mistralai_azure/utils/forms.py +10 -9
- mistralai_azure/utils/headers.py +8 -8
- mistralai_azure/utils/logger.py +6 -0
- mistralai_azure/utils/queryparams.py +16 -14
- mistralai_azure/utils/retries.py +2 -1
- mistralai_azure/utils/security.py +12 -6
- mistralai_azure/utils/serializers.py +42 -8
- mistralai_azure/utils/url.py +13 -8
- mistralai_azure/utils/values.py +6 -0
- mistralai_gcp/__init__.py +4 -0
- mistralai_gcp/_hooks/sdkhooks.py +23 -4
- mistralai_gcp/_hooks/types.py +27 -9
- mistralai_gcp/_version.py +12 -0
- mistralai_gcp/basesdk.py +91 -6
- mistralai_gcp/chat.py +308 -166
- mistralai_gcp/fim.py +198 -132
- mistralai_gcp/models/__init__.py +186 -18
- mistralai_gcp/models/assistantmessage.py +29 -11
- mistralai_gcp/models/chatcompletionchoice.py +15 -6
- mistralai_gcp/models/chatcompletionrequest.py +91 -22
- mistralai_gcp/models/chatcompletionresponse.py +8 -4
- mistralai_gcp/models/chatcompletionstreamrequest.py +93 -22
- mistralai_gcp/models/completionchunk.py +12 -5
- mistralai_gcp/models/completionevent.py +2 -3
- mistralai_gcp/models/completionresponsestreamchoice.py +19 -8
- mistralai_gcp/models/contentchunk.py +4 -11
- mistralai_gcp/models/deltamessage.py +30 -12
- mistralai_gcp/models/fimcompletionrequest.py +51 -17
- mistralai_gcp/models/fimcompletionresponse.py +8 -4
- mistralai_gcp/models/fimcompletionstreamrequest.py +51 -17
- mistralai_gcp/models/function.py +5 -4
- mistralai_gcp/models/functioncall.py +4 -3
- mistralai_gcp/models/functionname.py +17 -0
- mistralai_gcp/models/httpvalidationerror.py +1 -3
- mistralai_gcp/models/responseformat.py +7 -8
- mistralai_gcp/models/responseformats.py +8 -0
- mistralai_gcp/models/security.py +13 -5
- mistralai_gcp/models/systemmessage.py +10 -9
- mistralai_gcp/models/textchunk.py +14 -5
- mistralai_gcp/models/tool.py +10 -9
- mistralai_gcp/models/toolcall.py +10 -8
- mistralai_gcp/models/toolchoice.py +29 -0
- mistralai_gcp/models/toolchoiceenum.py +7 -0
- mistralai_gcp/models/toolmessage.py +20 -7
- mistralai_gcp/models/tooltypes.py +8 -0
- mistralai_gcp/models/usageinfo.py +4 -3
- mistralai_gcp/models/usermessage.py +42 -10
- mistralai_gcp/models/validationerror.py +5 -3
- mistralai_gcp/sdk.py +6 -7
- mistralai_gcp/sdkconfiguration.py +9 -11
- mistralai_gcp/utils/__init__.py +16 -3
- mistralai_gcp/utils/annotations.py +13 -2
- mistralai_gcp/utils/forms.py +10 -9
- mistralai_gcp/utils/headers.py +8 -8
- mistralai_gcp/utils/logger.py +6 -0
- mistralai_gcp/utils/queryparams.py +16 -14
- mistralai_gcp/utils/retries.py +2 -1
- mistralai_gcp/utils/security.py +12 -6
- mistralai_gcp/utils/serializers.py +42 -8
- mistralai_gcp/utils/url.py +13 -8
- mistralai_gcp/utils/values.py +6 -0
- mistralai-1.0.3.dist-info/RECORD +0 -236
- {mistralai-1.0.3.dist-info → mistralai-1.2.0.dist-info}/LICENSE +0 -0
mistralai/sdk.py
CHANGED
|
@@ -9,7 +9,9 @@ import httpx
|
|
|
9
9
|
from mistralai import models, utils
|
|
10
10
|
from mistralai._hooks import SDKHooks
|
|
11
11
|
from mistralai.agents import Agents
|
|
12
|
+
from mistralai.batch import Batch
|
|
12
13
|
from mistralai.chat import Chat
|
|
14
|
+
from mistralai.classifiers import Classifiers
|
|
13
15
|
from mistralai.embeddings import Embeddings
|
|
14
16
|
from mistralai.files import Files
|
|
15
17
|
from mistralai.fim import Fim
|
|
@@ -18,13 +20,16 @@ from mistralai.models_ import Models
|
|
|
18
20
|
from mistralai.types import OptionalNullable, UNSET
|
|
19
21
|
from typing import Any, Callable, Dict, Optional, Union
|
|
20
22
|
|
|
23
|
+
|
|
21
24
|
class Mistral(BaseSDK):
|
|
22
25
|
r"""Mistral AI API: Our Chat Completion and Embeddings APIs specification. Create your account on [La Plateforme](https://console.mistral.ai) to get access and read the [docs](https://docs.mistral.ai) to learn how to use it."""
|
|
26
|
+
|
|
23
27
|
models: Models
|
|
24
28
|
r"""Model Management API"""
|
|
25
29
|
files: Files
|
|
26
30
|
r"""Files API"""
|
|
27
31
|
fine_tuning: FineTuning
|
|
32
|
+
batch: Batch
|
|
28
33
|
chat: Chat
|
|
29
34
|
r"""Chat Completion API."""
|
|
30
35
|
fim: Fim
|
|
@@ -33,6 +38,9 @@ class Mistral(BaseSDK):
|
|
|
33
38
|
r"""Agents API."""
|
|
34
39
|
embeddings: Embeddings
|
|
35
40
|
r"""Embeddings API."""
|
|
41
|
+
classifiers: Classifiers
|
|
42
|
+
r"""Classifiers API."""
|
|
43
|
+
|
|
36
44
|
def __init__(
|
|
37
45
|
self,
|
|
38
46
|
api_key: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None,
|
|
@@ -43,7 +51,7 @@ class Mistral(BaseSDK):
|
|
|
43
51
|
async_client: Optional[AsyncHttpClient] = None,
|
|
44
52
|
retry_config: OptionalNullable[RetryConfig] = UNSET,
|
|
45
53
|
timeout_ms: Optional[int] = None,
|
|
46
|
-
debug_logger: Optional[Logger] = None
|
|
54
|
+
debug_logger: Optional[Logger] = None,
|
|
47
55
|
) -> None:
|
|
48
56
|
r"""Instantiates the SDK configuring it with the provided parameters.
|
|
49
57
|
|
|
@@ -72,33 +80,37 @@ class Mistral(BaseSDK):
|
|
|
72
80
|
assert issubclass(
|
|
73
81
|
type(async_client), AsyncHttpClient
|
|
74
82
|
), "The provided async_client must implement the AsyncHttpClient protocol."
|
|
75
|
-
|
|
83
|
+
|
|
76
84
|
security: Any = None
|
|
77
85
|
if callable(api_key):
|
|
78
|
-
security = lambda: models.Security(api_key
|
|
86
|
+
security = lambda: models.Security(api_key=api_key()) # pylint: disable=unnecessary-lambda-assignment
|
|
79
87
|
else:
|
|
80
|
-
security = models.Security(api_key
|
|
88
|
+
security = models.Security(api_key=api_key)
|
|
81
89
|
|
|
82
90
|
if server_url is not None:
|
|
83
91
|
if url_params is not None:
|
|
84
92
|
server_url = utils.template_url(server_url, url_params)
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
93
|
+
|
|
94
|
+
BaseSDK.__init__(
|
|
95
|
+
self,
|
|
96
|
+
SDKConfiguration(
|
|
97
|
+
client=client,
|
|
98
|
+
async_client=async_client,
|
|
99
|
+
security=security,
|
|
100
|
+
server_url=server_url,
|
|
101
|
+
server=server,
|
|
102
|
+
retry_config=retry_config,
|
|
103
|
+
timeout_ms=timeout_ms,
|
|
104
|
+
debug_logger=debug_logger,
|
|
105
|
+
),
|
|
106
|
+
)
|
|
97
107
|
|
|
98
108
|
hooks = SDKHooks()
|
|
99
109
|
|
|
100
110
|
current_server_url, *_ = self.sdk_configuration.get_server_details()
|
|
101
|
-
server_url, self.sdk_configuration.client = hooks.sdk_init(
|
|
111
|
+
server_url, self.sdk_configuration.client = hooks.sdk_init(
|
|
112
|
+
current_server_url, self.sdk_configuration.client
|
|
113
|
+
)
|
|
102
114
|
if current_server_url != server_url:
|
|
103
115
|
self.sdk_configuration.server_url = server_url
|
|
104
116
|
|
|
@@ -107,13 +119,13 @@ class Mistral(BaseSDK):
|
|
|
107
119
|
|
|
108
120
|
self._init_sdks()
|
|
109
121
|
|
|
110
|
-
|
|
111
122
|
def _init_sdks(self):
|
|
112
123
|
self.models = Models(self.sdk_configuration)
|
|
113
124
|
self.files = Files(self.sdk_configuration)
|
|
114
125
|
self.fine_tuning = FineTuning(self.sdk_configuration)
|
|
126
|
+
self.batch = Batch(self.sdk_configuration)
|
|
115
127
|
self.chat = Chat(self.sdk_configuration)
|
|
116
128
|
self.fim = Fim(self.sdk_configuration)
|
|
117
129
|
self.agents = Agents(self.sdk_configuration)
|
|
118
130
|
self.embeddings = Embeddings(self.sdk_configuration)
|
|
119
|
-
|
|
131
|
+
self.classifiers = Classifiers(self.sdk_configuration)
|
mistralai/sdkconfiguration.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
"""Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT."""
|
|
2
2
|
|
|
3
|
-
|
|
4
3
|
from ._hooks import SDKHooks
|
|
5
4
|
from .httpclient import AsyncHttpClient, HttpClient
|
|
6
5
|
from .utils import Logger, RetryConfig, remove_suffix
|
|
@@ -11,10 +10,10 @@ from pydantic import Field
|
|
|
11
10
|
from typing import Callable, Dict, Optional, Tuple, Union
|
|
12
11
|
|
|
13
12
|
|
|
14
|
-
|
|
15
|
-
r"""Production server"""
|
|
13
|
+
SERVER_EU = "eu"
|
|
14
|
+
r"""EU Production server"""
|
|
16
15
|
SERVERS = {
|
|
17
|
-
|
|
16
|
+
SERVER_EU: "https://api.mistral.ai",
|
|
18
17
|
}
|
|
19
18
|
"""Contains the list of servers available to the SDK"""
|
|
20
19
|
|
|
@@ -24,14 +23,14 @@ class SDKConfiguration:
|
|
|
24
23
|
client: HttpClient
|
|
25
24
|
async_client: AsyncHttpClient
|
|
26
25
|
debug_logger: Logger
|
|
27
|
-
security: Optional[Union[models.Security,Callable[[], models.Security]]] = None
|
|
26
|
+
security: Optional[Union[models.Security, Callable[[], models.Security]]] = None
|
|
28
27
|
server_url: Optional[str] = ""
|
|
29
28
|
server: Optional[str] = ""
|
|
30
29
|
language: str = "python"
|
|
31
30
|
openapi_doc_version: str = "0.0.2"
|
|
32
|
-
sdk_version: str = "1.0
|
|
33
|
-
gen_version: str = "2.
|
|
34
|
-
user_agent: str = "speakeasy-sdk/python 1.0
|
|
31
|
+
sdk_version: str = "1.2.0"
|
|
32
|
+
gen_version: str = "2.452.0"
|
|
33
|
+
user_agent: str = "speakeasy-sdk/python 1.2.0 2.452.0 0.0.2 mistralai"
|
|
35
34
|
retry_config: OptionalNullable[RetryConfig] = Field(default_factory=lambda: UNSET)
|
|
36
35
|
timeout_ms: Optional[int] = None
|
|
37
36
|
|
|
@@ -42,13 +41,12 @@ class SDKConfiguration:
|
|
|
42
41
|
if self.server_url is not None and self.server_url:
|
|
43
42
|
return remove_suffix(self.server_url, "/"), {}
|
|
44
43
|
if not self.server:
|
|
45
|
-
self.server =
|
|
44
|
+
self.server = SERVER_EU
|
|
46
45
|
|
|
47
46
|
if self.server not in SERVERS:
|
|
48
|
-
raise ValueError(f
|
|
47
|
+
raise ValueError(f'Invalid server "{self.server}"')
|
|
49
48
|
|
|
50
49
|
return SERVERS[self.server], {}
|
|
51
50
|
|
|
52
|
-
|
|
53
51
|
def get_hooks(self) -> SDKHooks:
|
|
54
52
|
return self._hooks
|
mistralai/utils/__init__.py
CHANGED
|
@@ -28,13 +28,22 @@ from .serializers import (
|
|
|
28
28
|
serialize_float,
|
|
29
29
|
serialize_int,
|
|
30
30
|
stream_to_text,
|
|
31
|
+
stream_to_text_async,
|
|
32
|
+
stream_to_bytes,
|
|
33
|
+
stream_to_bytes_async,
|
|
34
|
+
validate_const,
|
|
31
35
|
validate_decimal,
|
|
32
36
|
validate_float,
|
|
33
37
|
validate_int,
|
|
34
38
|
validate_open_enum,
|
|
35
39
|
)
|
|
36
40
|
from .url import generate_url, template_url, remove_suffix
|
|
37
|
-
from .values import
|
|
41
|
+
from .values import (
|
|
42
|
+
get_global_from_env,
|
|
43
|
+
match_content_type,
|
|
44
|
+
match_status_codes,
|
|
45
|
+
match_response,
|
|
46
|
+
)
|
|
38
47
|
from .logger import Logger, get_body_content, get_default_logger
|
|
39
48
|
|
|
40
49
|
__all__ = [
|
|
@@ -76,10 +85,14 @@ __all__ = [
|
|
|
76
85
|
"serialize_request_body",
|
|
77
86
|
"SerializedRequestBody",
|
|
78
87
|
"stream_to_text",
|
|
88
|
+
"stream_to_text_async",
|
|
89
|
+
"stream_to_bytes",
|
|
90
|
+
"stream_to_bytes_async",
|
|
79
91
|
"template_url",
|
|
80
92
|
"unmarshal",
|
|
81
93
|
"unmarshal_json",
|
|
82
94
|
"validate_decimal",
|
|
95
|
+
"validate_const",
|
|
83
96
|
"validate_float",
|
|
84
97
|
"validate_int",
|
|
85
98
|
"validate_open_enum",
|
mistralai/utils/annotations.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
"""Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT."""
|
|
2
2
|
|
|
3
|
+
from enum import Enum
|
|
3
4
|
from typing import Any
|
|
4
5
|
|
|
5
6
|
def get_discriminator(model: Any, fieldname: str, key: str) -> str:
|
|
@@ -10,10 +11,20 @@ def get_discriminator(model: Any, fieldname: str, key: str) -> str:
|
|
|
10
11
|
raise ValueError(f'Could not find discriminator key {key} in {model}') from e
|
|
11
12
|
|
|
12
13
|
if hasattr(model, fieldname):
|
|
13
|
-
|
|
14
|
+
attr = getattr(model, fieldname)
|
|
15
|
+
|
|
16
|
+
if isinstance(attr, Enum):
|
|
17
|
+
return f'{attr.value}'
|
|
18
|
+
|
|
19
|
+
return f'{attr}'
|
|
14
20
|
|
|
15
21
|
fieldname = fieldname.upper()
|
|
16
22
|
if hasattr(model, fieldname):
|
|
17
|
-
|
|
23
|
+
attr = getattr(model, fieldname)
|
|
24
|
+
|
|
25
|
+
if isinstance(attr, Enum):
|
|
26
|
+
return f'{attr.value}'
|
|
27
|
+
|
|
28
|
+
return f'{attr}'
|
|
18
29
|
|
|
19
30
|
raise ValueError(f'Could not find discriminator field {fieldname} in {model}')
|
mistralai/utils/logger.py
CHANGED
|
@@ -5,20 +5,23 @@ import logging
|
|
|
5
5
|
import os
|
|
6
6
|
from typing import Any, Protocol
|
|
7
7
|
|
|
8
|
+
|
|
8
9
|
class Logger(Protocol):
|
|
9
10
|
def debug(self, msg: str, *args: Any, **kwargs: Any) -> None:
|
|
10
11
|
pass
|
|
11
12
|
|
|
13
|
+
|
|
12
14
|
class NoOpLogger:
|
|
13
15
|
def debug(self, msg: str, *args: Any, **kwargs: Any) -> None:
|
|
14
16
|
pass
|
|
15
17
|
|
|
18
|
+
|
|
16
19
|
def get_body_content(req: httpx.Request) -> str:
|
|
17
20
|
return "<streaming body>" if not hasattr(req, "_content") else str(req.content)
|
|
18
21
|
|
|
22
|
+
|
|
19
23
|
def get_default_logger() -> Logger:
|
|
20
24
|
if os.getenv("MISTRAL_DEBUG"):
|
|
21
25
|
logging.basicConfig(level=logging.DEBUG)
|
|
22
26
|
return logging.getLogger("mistralai")
|
|
23
27
|
return NoOpLogger()
|
|
24
|
-
|
mistralai/utils/retries.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
"""Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT."""
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
3
4
|
import random
|
|
4
5
|
import time
|
|
5
6
|
from typing import List
|
|
@@ -212,5 +213,5 @@ async def retry_with_backoff_async(
|
|
|
212
213
|
raise
|
|
213
214
|
sleep = (initial_interval / 1000) * exponent**retries + random.uniform(0, 1)
|
|
214
215
|
sleep = min(sleep, max_interval / 1000)
|
|
215
|
-
|
|
216
|
+
await asyncio.sleep(sleep)
|
|
216
217
|
retries += 1
|
mistralai/utils/security.py
CHANGED
|
@@ -44,8 +44,10 @@ def get_security(security: Any) -> Tuple[Dict[str, str], Dict[str, List[str]]]:
|
|
|
44
44
|
_parse_security_option(headers, query_params, value)
|
|
45
45
|
return headers, query_params
|
|
46
46
|
if metadata.scheme:
|
|
47
|
-
# Special case for basic auth which could be a flattened model
|
|
48
|
-
if metadata.sub_type
|
|
47
|
+
# Special case for basic auth or custom auth which could be a flattened model
|
|
48
|
+
if metadata.sub_type in ["basic", "custom"] and not isinstance(
|
|
49
|
+
value, BaseModel
|
|
50
|
+
):
|
|
49
51
|
_parse_security_scheme(headers, query_params, metadata, name, security)
|
|
50
52
|
else:
|
|
51
53
|
_parse_security_scheme(headers, query_params, metadata, name, value)
|
|
@@ -64,7 +66,7 @@ def get_security_from_env(security: Any, security_class: Any) -> Optional[BaseMo
|
|
|
64
66
|
|
|
65
67
|
if os.getenv("MISTRAL_API_KEY"):
|
|
66
68
|
security_dict["api_key"] = os.getenv("MISTRAL_API_KEY")
|
|
67
|
-
|
|
69
|
+
|
|
68
70
|
return security_class(**security_dict) if security_dict else None
|
|
69
71
|
|
|
70
72
|
|
|
@@ -97,9 +99,12 @@ def _parse_security_scheme(
|
|
|
97
99
|
sub_type = scheme_metadata.sub_type
|
|
98
100
|
|
|
99
101
|
if isinstance(scheme, BaseModel):
|
|
100
|
-
if scheme_type == "http"
|
|
101
|
-
|
|
102
|
-
|
|
102
|
+
if scheme_type == "http":
|
|
103
|
+
if sub_type == "basic":
|
|
104
|
+
_parse_basic_auth_scheme(headers, scheme)
|
|
105
|
+
return
|
|
106
|
+
if sub_type == "custom":
|
|
107
|
+
return
|
|
103
108
|
|
|
104
109
|
scheme_fields: Dict[str, FieldInfo] = scheme.__class__.model_fields
|
|
105
110
|
for name in scheme_fields:
|
|
@@ -148,6 +153,8 @@ def _parse_security_scheme_value(
|
|
|
148
153
|
elif scheme_type == "http":
|
|
149
154
|
if sub_type == "bearer":
|
|
150
155
|
headers[header_name] = _apply_bearer(value)
|
|
156
|
+
elif sub_type == "custom":
|
|
157
|
+
return
|
|
151
158
|
else:
|
|
152
159
|
raise ValueError("sub type {sub_type} not supported")
|
|
153
160
|
else:
|
mistralai/utils/serializers.py
CHANGED
|
@@ -116,6 +116,19 @@ def validate_open_enum(is_int: bool):
|
|
|
116
116
|
return validate
|
|
117
117
|
|
|
118
118
|
|
|
119
|
+
def validate_const(v):
|
|
120
|
+
def validate(c):
|
|
121
|
+
if is_optional_type(type(c)) and c is None:
|
|
122
|
+
return None
|
|
123
|
+
|
|
124
|
+
if v != c:
|
|
125
|
+
raise ValueError(f"Expected {v}")
|
|
126
|
+
|
|
127
|
+
return c
|
|
128
|
+
|
|
129
|
+
return validate
|
|
130
|
+
|
|
131
|
+
|
|
119
132
|
def unmarshal_json(raw, typ: Any) -> Any:
|
|
120
133
|
return unmarshal(from_json(raw), typ)
|
|
121
134
|
|
|
@@ -172,6 +185,18 @@ def stream_to_text(stream: httpx.Response) -> str:
|
|
|
172
185
|
return "".join(stream.iter_text())
|
|
173
186
|
|
|
174
187
|
|
|
188
|
+
async def stream_to_text_async(stream: httpx.Response) -> str:
|
|
189
|
+
return "".join([chunk async for chunk in stream.aiter_text()])
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def stream_to_bytes(stream: httpx.Response) -> bytes:
|
|
193
|
+
return stream.content
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
async def stream_to_bytes_async(stream: httpx.Response) -> bytes:
|
|
197
|
+
return await stream.aread()
|
|
198
|
+
|
|
199
|
+
|
|
175
200
|
def get_pydantic_model(data: Any, typ: Any) -> Any:
|
|
176
201
|
if not _contains_pydantic_model(data):
|
|
177
202
|
return unmarshal(data, typ)
|