mistralai 0.4.2__py3-none-any.whl → 0.5.5a50__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mistralai/__init__.py +5 -0
- mistralai/_hooks/__init__.py +5 -0
- mistralai/_hooks/custom_user_agent.py +16 -0
- mistralai/_hooks/deprecation_warning.py +26 -0
- mistralai/_hooks/registration.py +17 -0
- mistralai/_hooks/sdkhooks.py +57 -0
- mistralai/_hooks/types.py +76 -0
- mistralai/async_client.py +5 -413
- mistralai/basesdk.py +216 -0
- mistralai/chat.py +475 -0
- mistralai/client.py +5 -414
- mistralai/embeddings.py +182 -0
- mistralai/files.py +600 -84
- mistralai/fim.py +439 -0
- mistralai/fine_tuning.py +855 -0
- mistralai/httpclient.py +78 -0
- mistralai/models/__init__.py +80 -0
- mistralai/models/archiveftmodelout.py +19 -0
- mistralai/models/assistantmessage.py +58 -0
- mistralai/models/chatcompletionchoice.py +33 -0
- mistralai/models/chatcompletionrequest.py +114 -0
- mistralai/models/chatcompletionresponse.py +27 -0
- mistralai/models/chatcompletionstreamrequest.py +112 -0
- mistralai/models/checkpointout.py +25 -0
- mistralai/models/completionchunk.py +27 -0
- mistralai/models/completionevent.py +15 -0
- mistralai/models/completionresponsestreamchoice.py +53 -0
- mistralai/models/contentchunk.py +17 -0
- mistralai/models/delete_model_v1_models_model_id_deleteop.py +16 -0
- mistralai/models/deletefileout.py +24 -0
- mistralai/models/deletemodelout.py +25 -0
- mistralai/models/deltamessage.py +52 -0
- mistralai/models/detailedjobout.py +96 -0
- mistralai/models/embeddingrequest.py +66 -0
- mistralai/models/embeddingresponse.py +24 -0
- mistralai/models/embeddingresponsedata.py +19 -0
- mistralai/models/eventout.py +55 -0
- mistralai/models/files_api_routes_delete_fileop.py +16 -0
- mistralai/models/files_api_routes_retrieve_fileop.py +16 -0
- mistralai/models/files_api_routes_upload_fileop.py +51 -0
- mistralai/models/fileschema.py +76 -0
- mistralai/models/fimcompletionrequest.py +99 -0
- mistralai/models/fimcompletionresponse.py +27 -0
- mistralai/models/fimcompletionstreamrequest.py +97 -0
- mistralai/models/finetuneablemodel.py +8 -0
- mistralai/models/ftmodelcapabilitiesout.py +21 -0
- mistralai/models/ftmodelout.py +70 -0
- mistralai/models/function.py +19 -0
- mistralai/models/functioncall.py +16 -0
- mistralai/models/githubrepositoryin.py +57 -0
- mistralai/models/githubrepositoryout.py +57 -0
- mistralai/models/httpvalidationerror.py +23 -0
- mistralai/models/jobin.py +78 -0
- mistralai/models/jobmetadataout.py +59 -0
- mistralai/models/jobout.py +112 -0
- mistralai/models/jobs_api_routes_fine_tuning_archive_fine_tuned_modelop.py +16 -0
- mistralai/models/jobs_api_routes_fine_tuning_cancel_fine_tuning_jobop.py +18 -0
- mistralai/models/jobs_api_routes_fine_tuning_create_fine_tuning_jobop.py +73 -0
- mistralai/models/jobs_api_routes_fine_tuning_get_fine_tuning_jobop.py +18 -0
- mistralai/models/jobs_api_routes_fine_tuning_get_fine_tuning_jobsop.py +86 -0
- mistralai/models/jobs_api_routes_fine_tuning_start_fine_tuning_jobop.py +16 -0
- mistralai/models/jobs_api_routes_fine_tuning_unarchive_fine_tuned_modelop.py +16 -0
- mistralai/models/jobs_api_routes_fine_tuning_update_fine_tuned_modelop.py +19 -0
- mistralai/models/jobsout.py +20 -0
- mistralai/models/legacyjobmetadataout.py +85 -0
- mistralai/models/listfilesout.py +17 -0
- mistralai/models/metricout.py +55 -0
- mistralai/models/modelcapabilities.py +21 -0
- mistralai/models/modelcard.py +71 -0
- mistralai/models/modellist.py +18 -0
- mistralai/models/responseformat.py +18 -0
- mistralai/models/retrieve_model_v1_models_model_id_getop.py +16 -0
- mistralai/models/retrievefileout.py +76 -0
- mistralai/models/sampletype.py +7 -0
- mistralai/models/sdkerror.py +22 -0
- mistralai/models/security.py +16 -0
- mistralai/models/source.py +7 -0
- mistralai/models/systemmessage.py +26 -0
- mistralai/models/textchunk.py +17 -0
- mistralai/models/tool.py +18 -0
- mistralai/models/toolcall.py +20 -0
- mistralai/models/toolmessage.py +55 -0
- mistralai/models/trainingfile.py +17 -0
- mistralai/models/trainingparameters.py +53 -0
- mistralai/models/trainingparametersin.py +61 -0
- mistralai/models/unarchiveftmodelout.py +19 -0
- mistralai/models/updateftmodelin.py +49 -0
- mistralai/models/uploadfileout.py +76 -0
- mistralai/models/usageinfo.py +18 -0
- mistralai/models/usermessage.py +26 -0
- mistralai/models/validationerror.py +24 -0
- mistralai/models/wandbintegration.py +61 -0
- mistralai/models/wandbintegrationout.py +57 -0
- mistralai/models_.py +928 -0
- mistralai/py.typed +1 -0
- mistralai/sdk.py +111 -0
- mistralai/sdkconfiguration.py +53 -0
- mistralai/types/__init__.py +21 -0
- mistralai/types/basemodel.py +35 -0
- mistralai/utils/__init__.py +82 -0
- mistralai/utils/annotations.py +19 -0
- mistralai/utils/enums.py +34 -0
- mistralai/utils/eventstreaming.py +179 -0
- mistralai/utils/forms.py +207 -0
- mistralai/utils/headers.py +136 -0
- mistralai/utils/metadata.py +118 -0
- mistralai/utils/queryparams.py +203 -0
- mistralai/utils/requestbodies.py +66 -0
- mistralai/utils/retries.py +216 -0
- mistralai/utils/security.py +182 -0
- mistralai/utils/serializers.py +181 -0
- mistralai/utils/url.py +150 -0
- mistralai/utils/values.py +128 -0
- {mistralai-0.4.2.dist-info → mistralai-0.5.5a50.dist-info}/LICENSE +1 -1
- mistralai-0.5.5a50.dist-info/METADATA +626 -0
- mistralai-0.5.5a50.dist-info/RECORD +228 -0
- mistralai_azure/__init__.py +5 -0
- mistralai_azure/_hooks/__init__.py +5 -0
- mistralai_azure/_hooks/custom_user_agent.py +16 -0
- mistralai_azure/_hooks/registration.py +15 -0
- mistralai_azure/_hooks/sdkhooks.py +57 -0
- mistralai_azure/_hooks/types.py +76 -0
- mistralai_azure/basesdk.py +215 -0
- mistralai_azure/chat.py +475 -0
- mistralai_azure/httpclient.py +78 -0
- mistralai_azure/models/__init__.py +28 -0
- mistralai_azure/models/assistantmessage.py +58 -0
- mistralai_azure/models/chatcompletionchoice.py +33 -0
- mistralai_azure/models/chatcompletionrequest.py +114 -0
- mistralai_azure/models/chatcompletionresponse.py +27 -0
- mistralai_azure/models/chatcompletionstreamrequest.py +112 -0
- mistralai_azure/models/completionchunk.py +27 -0
- mistralai_azure/models/completionevent.py +15 -0
- mistralai_azure/models/completionresponsestreamchoice.py +53 -0
- mistralai_azure/models/contentchunk.py +17 -0
- mistralai_azure/models/deltamessage.py +52 -0
- mistralai_azure/models/function.py +19 -0
- mistralai_azure/models/functioncall.py +16 -0
- mistralai_azure/models/httpvalidationerror.py +23 -0
- mistralai_azure/models/responseformat.py +18 -0
- mistralai_azure/models/sdkerror.py +22 -0
- mistralai_azure/models/security.py +16 -0
- mistralai_azure/models/systemmessage.py +26 -0
- mistralai_azure/models/textchunk.py +17 -0
- mistralai_azure/models/tool.py +18 -0
- mistralai_azure/models/toolcall.py +20 -0
- mistralai_azure/models/toolmessage.py +55 -0
- mistralai_azure/models/usageinfo.py +18 -0
- mistralai_azure/models/usermessage.py +26 -0
- mistralai_azure/models/validationerror.py +24 -0
- mistralai_azure/py.typed +1 -0
- mistralai_azure/sdk.py +102 -0
- mistralai_azure/sdkconfiguration.py +53 -0
- mistralai_azure/types/__init__.py +21 -0
- mistralai_azure/types/basemodel.py +35 -0
- mistralai_azure/utils/__init__.py +80 -0
- mistralai_azure/utils/annotations.py +19 -0
- mistralai_azure/utils/enums.py +34 -0
- mistralai_azure/utils/eventstreaming.py +179 -0
- mistralai_azure/utils/forms.py +207 -0
- mistralai_azure/utils/headers.py +136 -0
- mistralai_azure/utils/metadata.py +118 -0
- mistralai_azure/utils/queryparams.py +203 -0
- mistralai_azure/utils/requestbodies.py +66 -0
- mistralai_azure/utils/retries.py +216 -0
- mistralai_azure/utils/security.py +168 -0
- mistralai_azure/utils/serializers.py +181 -0
- mistralai_azure/utils/url.py +150 -0
- mistralai_azure/utils/values.py +128 -0
- mistralai_gcp/__init__.py +5 -0
- mistralai_gcp/_hooks/__init__.py +5 -0
- mistralai_gcp/_hooks/custom_user_agent.py +16 -0
- mistralai_gcp/_hooks/registration.py +15 -0
- mistralai_gcp/_hooks/sdkhooks.py +57 -0
- mistralai_gcp/_hooks/types.py +76 -0
- mistralai_gcp/basesdk.py +215 -0
- mistralai_gcp/chat.py +463 -0
- mistralai_gcp/fim.py +439 -0
- mistralai_gcp/httpclient.py +78 -0
- mistralai_gcp/models/__init__.py +31 -0
- mistralai_gcp/models/assistantmessage.py +58 -0
- mistralai_gcp/models/chatcompletionchoice.py +33 -0
- mistralai_gcp/models/chatcompletionrequest.py +110 -0
- mistralai_gcp/models/chatcompletionresponse.py +27 -0
- mistralai_gcp/models/chatcompletionstreamrequest.py +108 -0
- mistralai_gcp/models/completionchunk.py +27 -0
- mistralai_gcp/models/completionevent.py +15 -0
- mistralai_gcp/models/completionresponsestreamchoice.py +53 -0
- mistralai_gcp/models/contentchunk.py +17 -0
- mistralai_gcp/models/deltamessage.py +52 -0
- mistralai_gcp/models/fimcompletionrequest.py +99 -0
- mistralai_gcp/models/fimcompletionresponse.py +27 -0
- mistralai_gcp/models/fimcompletionstreamrequest.py +97 -0
- mistralai_gcp/models/function.py +19 -0
- mistralai_gcp/models/functioncall.py +16 -0
- mistralai_gcp/models/httpvalidationerror.py +23 -0
- mistralai_gcp/models/responseformat.py +18 -0
- mistralai_gcp/models/sdkerror.py +22 -0
- mistralai_gcp/models/security.py +16 -0
- mistralai_gcp/models/systemmessage.py +26 -0
- mistralai_gcp/models/textchunk.py +17 -0
- mistralai_gcp/models/tool.py +18 -0
- mistralai_gcp/models/toolcall.py +20 -0
- mistralai_gcp/models/toolmessage.py +55 -0
- mistralai_gcp/models/usageinfo.py +18 -0
- mistralai_gcp/models/usermessage.py +26 -0
- mistralai_gcp/models/validationerror.py +24 -0
- mistralai_gcp/py.typed +1 -0
- mistralai_gcp/sdk.py +165 -0
- mistralai_gcp/sdkconfiguration.py +53 -0
- mistralai_gcp/types/__init__.py +21 -0
- mistralai_gcp/types/basemodel.py +35 -0
- mistralai_gcp/utils/__init__.py +80 -0
- mistralai_gcp/utils/annotations.py +19 -0
- mistralai_gcp/utils/enums.py +34 -0
- mistralai_gcp/utils/eventstreaming.py +179 -0
- mistralai_gcp/utils/forms.py +207 -0
- mistralai_gcp/utils/headers.py +136 -0
- mistralai_gcp/utils/metadata.py +118 -0
- mistralai_gcp/utils/queryparams.py +203 -0
- mistralai_gcp/utils/requestbodies.py +66 -0
- mistralai_gcp/utils/retries.py +216 -0
- mistralai_gcp/utils/security.py +168 -0
- mistralai_gcp/utils/serializers.py +181 -0
- mistralai_gcp/utils/url.py +150 -0
- mistralai_gcp/utils/values.py +128 -0
- py.typed +1 -0
- mistralai/client_base.py +0 -211
- mistralai/constants.py +0 -5
- mistralai/exceptions.py +0 -54
- mistralai/jobs.py +0 -172
- mistralai/models/chat_completion.py +0 -93
- mistralai/models/common.py +0 -9
- mistralai/models/embeddings.py +0 -19
- mistralai/models/files.py +0 -23
- mistralai/models/jobs.py +0 -100
- mistralai/models/models.py +0 -39
- mistralai-0.4.2.dist-info/METADATA +0 -82
- mistralai-0.4.2.dist-info/RECORD +0 -20
- {mistralai-0.4.2.dist-info → mistralai-0.5.5a50.dist-info}/WHEEL +0 -0
mistralai/client_base.py
DELETED
|
@@ -1,211 +0,0 @@
|
|
|
1
|
-
import logging
|
|
2
|
-
import os
|
|
3
|
-
from abc import ABC
|
|
4
|
-
from typing import Any, Callable, Dict, List, Optional, Union
|
|
5
|
-
|
|
6
|
-
import orjson
|
|
7
|
-
from httpx import Headers
|
|
8
|
-
|
|
9
|
-
from mistralai.constants import HEADER_MODEL_DEPRECATION_TIMESTAMP
|
|
10
|
-
from mistralai.exceptions import MistralException
|
|
11
|
-
from mistralai.models.chat_completion import (
|
|
12
|
-
ChatMessage,
|
|
13
|
-
Function,
|
|
14
|
-
ResponseFormat,
|
|
15
|
-
ToolChoice,
|
|
16
|
-
)
|
|
17
|
-
|
|
18
|
-
# SDK version string; ClientBase.__init__ copies it onto each client as ``_version``.
CLIENT_VERSION = "0.4.2"
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
class ClientBase(ABC):
    """Configuration and request-payload helpers shared by the Mistral clients.

    This base class performs no network I/O itself. It resolves the API key,
    endpoint and default model, and converts user-facing arguments (pydantic
    models or plain dicts/strings) into JSON-ready request dictionaries.
    """

    def __init__(
        self,
        endpoint: str,
        api_key: Optional[str] = None,
        max_retries: int = 5,
        timeout: int = 120,
    ):
        """Store connection settings.

        Args:
            endpoint: Base URL of the API.
            api_key: API key; when ``None`` it is read from ``MISTRAL_API_KEY``.
            max_retries: Stored as ``_max_retries`` for subclasses.
            timeout: Stored as ``_timeout`` for subclasses.

        Raises:
            MistralException: If no key is given and ``MISTRAL_API_KEY`` is unset.
        """
        self._max_retries = max_retries
        self._timeout = timeout

        if api_key is None:
            api_key = os.environ.get("MISTRAL_API_KEY")
        if api_key is None:
            raise MistralException(message="API key not provided. Please set MISTRAL_API_KEY environment variable.")
        self._api_key = api_key
        self._endpoint = endpoint
        self._logger = logging.getLogger(__name__)

        # For azure endpoints, we default to the mistral model
        if "inference.azure.com" in self._endpoint:
            self._default_model = "mistral"

        self._version = CLIENT_VERSION

    def _get_model(self, model: Optional[str] = None) -> str:
        """Return ``model`` if given, otherwise the client's default model.

        Raises:
            MistralException: If no model is given and no default is configured.
        """
        if model is not None:
            return model
        # ``_default_model`` is only assigned in ``__init__`` for Azure endpoints.
        # Read it defensively so every other client gets the intended
        # MistralException instead of an AttributeError.
        default_model = getattr(self, "_default_model", None)
        if default_model is None:
            raise MistralException(message="model must be provided")
        return default_model

    def _parse_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Normalize tool definitions into JSON-ready dicts.

        Only entries whose ``type`` is ``"function"`` are kept. A ``Function``
        model is dumped without ``None`` fields; any other value under
        ``"function"`` is passed through unchanged.
        """
        parsed_tools: List[Dict[str, Any]] = []
        for tool in tools:
            # Entries of any other type are skipped silently.
            if tool["type"] != "function":
                continue
            function = tool["function"]
            if isinstance(function, Function):
                function = function.model_dump(exclude_none=True)
            parsed_tools.append({"type": tool["type"], "function": function})
        return parsed_tools

    def _parse_tool_choice(self, tool_choice: Union[str, ToolChoice]) -> str:
        """Return the wire value of ``tool_choice`` (enum members are unwrapped)."""
        if isinstance(tool_choice, ToolChoice):
            return tool_choice.value
        return tool_choice

    def _parse_response_format(self, response_format: Union[Dict[str, Any], ResponseFormat]) -> Dict[str, Any]:
        """Return ``response_format`` as a dict, dumping a ``ResponseFormat`` model if needed."""
        if isinstance(response_format, ResponseFormat):
            return response_format.model_dump(exclude_none=True)
        return response_format

    def _parse_messages(self, messages: List[Any]) -> List[Dict[str, Any]]:
        """Dump ``ChatMessage`` models (without ``None`` fields); pass other items through."""
        return [
            message.model_dump(exclude_none=True) if isinstance(message, ChatMessage) else message
            for message in messages
        ]

    def _check_model_deprecation_header_callback_factory(self, model: Optional[str] = None) -> Callable:
        """Build a callback that logs a warning when response headers flag the model as deprecated.

        The model name is resolved eagerly, so this raises ``MistralException``
        immediately when no model is available.
        """
        model = self._get_model(model)

        def _check_model_deprecation_header_callback(
            headers: Headers,
        ) -> None:
            if HEADER_MODEL_DEPRECATION_TIMESTAMP in headers:
                self._logger.warning(
                    f"WARNING: The model {model} is deprecated "
                    f"and will be removed on {headers[HEADER_MODEL_DEPRECATION_TIMESTAMP]}. "
                    "Please refer to https://docs.mistral.ai/getting-started/models/#api-versioning "
                    "for more information."
                )

        return _check_model_deprecation_header_callback

    def _make_completion_request(
        self,
        prompt: str,
        model: Optional[str] = None,
        suffix: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        random_seed: Optional[int] = None,
        stop: Optional[List[str]] = None,
        stream: Optional[bool] = False,
    ) -> Dict[str, Any]:
        """Build the JSON body for a completion (prompt/suffix) request.

        ``prompt``, ``suffix``, ``model`` and ``stream`` are always present,
        even when ``suffix`` is ``None``. ``stop`` and the sampling parameters
        are included only when not ``None``.

        Raises:
            MistralException: If no model is given and no default is configured.
        """
        request_data: Dict[str, Any] = {
            "prompt": prompt,
            "suffix": suffix,
            "model": self._get_model(model),
            "stream": stream,
        }

        if stop is not None:
            request_data["stop"] = stop

        request_data.update(
            self._build_sampling_params(
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                random_seed=random_seed,
            )
        )

        self._logger.debug(f"Completion request: {request_data}")

        return request_data

    def _build_sampling_params(
        self,
        max_tokens: Optional[int],
        random_seed: Optional[int],
        temperature: Optional[float],
        top_p: Optional[float],
    ) -> Dict[str, Any]:
        """Return only the sampling parameters that were set (not ``None``)."""
        candidates: Dict[str, Any] = {
            "temperature": temperature,
            "max_tokens": max_tokens,
            "top_p": top_p,
            "random_seed": random_seed,
        }
        return {name: value for name, value in candidates.items() if value is not None}

    def _make_chat_request(
        self,
        messages: List[Any],
        model: Optional[str] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        random_seed: Optional[int] = None,
        stream: Optional[bool] = None,
        safe_prompt: Optional[bool] = False,
        tool_choice: Optional[Union[str, ToolChoice]] = None,
        response_format: Optional[Union[Dict[str, str], ResponseFormat]] = None,
    ) -> Dict[str, Any]:
        """Build the JSON body for a chat request.

        Optional fields are added only when set. ``safe_prompt`` is added only
        when truthy.

        Raises:
            MistralException: If no model is given and no default is configured.
        """
        request_data: Dict[str, Any] = {
            "messages": self._parse_messages(messages),
        }

        request_data["model"] = self._get_model(model)

        request_data.update(
            self._build_sampling_params(
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                random_seed=random_seed,
            )
        )

        if safe_prompt:
            request_data["safe_prompt"] = safe_prompt
        if tools is not None:
            request_data["tools"] = self._parse_tools(tools)
        if stream is not None:
            request_data["stream"] = stream

        if tool_choice is not None:
            request_data["tool_choice"] = self._parse_tool_choice(tool_choice)
        if response_format is not None:
            request_data["response_format"] = self._parse_response_format(response_format)

        self._logger.debug(f"Chat request: {request_data}")

        return request_data

    def _process_line(self, line: str) -> Optional[Dict[str, Any]]:
        """Decode one server-sent-events line.

        Returns the JSON payload of a ``data: `` line, or ``None`` for the
        ``[DONE]`` sentinel and for any line without the ``data: `` prefix.
        """
        if line.startswith("data: "):
            line = line[6:].strip()
            if line != "[DONE]":
                json_streamed_response: Dict[str, Any] = orjson.loads(line)
                return json_streamed_response
        return None
|
mistralai/constants.py
DELETED
mistralai/exceptions.py
DELETED
|
@@ -1,54 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
from typing import Any, Dict, Optional
|
|
4
|
-
|
|
5
|
-
from httpx import Response
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
class MistralException(Exception):
    """Root of the SDK's exception hierarchy, used when no more specific error applies."""

    def __init__(self, message: Optional[str] = None) -> None:
        # Forward the message to Exception so ``args`` holds it as well.
        super().__init__(message)
        self.message = message

    def __str__(self) -> str:
        # Show a placeholder instead of an empty or missing message.
        return self.message if self.message else "<empty message>"

    def __repr__(self) -> str:
        return "{}(message={})".format(type(self).__name__, str(self))
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
class MistralAPIException(MistralException):
    """Raised when the API answers with an error; carries the HTTP status and response headers."""

    def __init__(
        self,
        message: Optional[str] = None,
        http_status: Optional[int] = None,
        headers: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__(message)
        self.http_status = http_status
        # A missing or empty mapping becomes a fresh dict for this instance.
        self.headers = {} if not headers else headers

    @classmethod
    def from_response(cls, response: Response, message: Optional[str] = None) -> MistralAPIException:
        """Build an exception from an HTTP response, defaulting the message to the body text."""
        text = message if message else response.text
        return cls(message=text, http_status=response.status_code, headers=dict(response.headers))

    def __repr__(self) -> str:
        return f"{type(self).__name__}(message={str(self)}, http_status={self.http_status})"
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
class MistralAPIStatusException(MistralAPIException):
    """Raised for a non-200 API response that the SDK treats as retryable."""
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
class MistralConnectionException(MistralException):
    """Raised when the API server cannot be reached, whatever the underlying cause."""
|
mistralai/jobs.py
DELETED
|
@@ -1,172 +0,0 @@
|
|
|
1
|
-
from datetime import datetime
|
|
2
|
-
from typing import Any, Optional, Union
|
|
3
|
-
|
|
4
|
-
from mistralai.exceptions import (
|
|
5
|
-
MistralException,
|
|
6
|
-
)
|
|
7
|
-
from mistralai.models.jobs import DetailedJob, IntegrationIn, Job, JobMetadata, JobQueryFilter, Jobs, TrainingParameters
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
class JobsClient:
    """Synchronous wrapper around the ``v1/fine_tuning/jobs`` endpoints.

    ``client`` is the owning Mistral client. Every call goes through its
    ``_request`` method, whose result is iterated and only the first item is used.
    """

    def __init__(self, client: Any):
        self.client = client

    @staticmethod
    def _first_response(responses: Any) -> Any:
        """Return the first item yielded by ``responses``.

        Raises:
            MistralException: If ``responses`` yields nothing.
        """
        for response in responses:
            return response
        raise MistralException("No response received")

    @staticmethod
    def _serialize_integrations(integrations: Any) -> Any:
        """Convert integration models into plain dicts so the request body is JSON-serializable.

        Sets and pydantic models cannot be JSON-encoded directly, so the
        collection becomes a list and each model is dumped. Items without
        ``model_dump`` (e.g. plain dicts) are passed through unchanged. ``None``
        is preserved, so the payload is unchanged when no integrations are given.
        """
        if integrations is None:
            return None
        return [item.model_dump() if hasattr(item, "model_dump") else item for item in integrations]

    def create(
        self,
        model: str,
        training_files: Union[list[str], None] = None,
        validation_files: Union[list[str], None] = None,
        hyperparameters: Optional[TrainingParameters] = None,
        suffix: Union[str, None] = None,
        integrations: Union[set[IntegrationIn], list[IntegrationIn], None] = None,
        training_file: Union[str, None] = None,  # Deprecated: Added for compatibility with OpenAI API
        validation_file: Union[str, None] = None,  # Deprecated: Added for compatibility with OpenAI API
        dry_run: bool = False,
    ) -> Union[Job, JobMetadata]:
        """Create a fine-tuning job, or request a cost/duration estimate when ``dry_run`` is true.

        Args:
            model: Base model to fine-tune.
            training_files: File ids used for training.
            validation_files: File ids used for validation.
            hyperparameters: Training parameters; defaults to 1800 steps at
                learning rate 1e-4.
            suffix: Optional suffix forwarded to the API.
            integrations: Integration models (or plain dicts) attached to the job.
            training_file: Deprecated single-file alias, used only when
                ``training_files`` is empty.
            validation_file: Deprecated single-file alias, used only when
                ``validation_files`` is empty.
            dry_run: Sent as a query parameter; selects the ``JobMetadata``
                return type.

        Returns:
            ``Job`` normally, ``JobMetadata`` when ``dry_run`` is true.

        Raises:
            MistralException: If the transport yields no response.
        """
        # Handle deprecated arguments
        if not training_files and training_file:
            training_files = [training_file]
        if not validation_files and validation_file:
            validation_files = [validation_file]
        # Build the default per call instead of sharing one model instance
        # created once at definition time (mutable default argument).
        if hyperparameters is None:
            hyperparameters = TrainingParameters(training_steps=1800, learning_rate=1.0e-4)
        single_response = self.client._request(
            method="post",
            json={
                "model": model,
                "training_files": training_files,
                "validation_files": validation_files,
                "hyperparameters": hyperparameters.model_dump(),
                "suffix": suffix,
                "integrations": self._serialize_integrations(integrations),
            },
            path="v1/fine_tuning/jobs",
            params={"dry_run": dry_run},
        )
        response = self._first_response(single_response)
        return Job(**response) if not dry_run else JobMetadata(**response)

    def retrieve(self, job_id: str) -> DetailedJob:
        """Fetch one job, including its events and checkpoints.

        Raises:
            MistralException: If the transport yields no response.
        """
        single_response = self.client._request(method="get", path=f"v1/fine_tuning/jobs/{job_id}", json={})
        return DetailedJob(**self._first_response(single_response))

    def list(
        self,
        page: int = 0,
        page_size: int = 10,
        model: Optional[str] = None,
        created_after: Optional[datetime] = None,
        created_by_me: Optional[bool] = None,
        status: Optional[str] = None,
        wandb_project: Optional[str] = None,
        wandb_name: Optional[str] = None,
        suffix: Optional[str] = None,
    ) -> Jobs:
        """List jobs. Filters left as ``None`` are omitted from the query string.

        Raises:
            MistralException: If the transport yields no response.
        """
        query_params = JobQueryFilter(
            page=page,
            page_size=page_size,
            model=model,
            created_after=created_after,
            created_by_me=created_by_me,
            status=status,
            wandb_project=wandb_project,
            wandb_name=wandb_name,
            suffix=suffix,
        ).model_dump(exclude_none=True)
        single_response = self.client._request(method="get", params=query_params, path="v1/fine_tuning/jobs", json={})
        return Jobs(**self._first_response(single_response))

    def cancel(self, job_id: str) -> DetailedJob:
        """Request cancellation of a job and return its updated state.

        Raises:
            MistralException: If the transport yields no response.
        """
        single_response = self.client._request(method="post", path=f"v1/fine_tuning/jobs/{job_id}/cancel", json={})
        return DetailedJob(**self._first_response(single_response))
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
class JobsAsyncClient:
    """Asynchronous wrapper around the ``v1/fine_tuning/jobs`` endpoints.

    ``client`` is the owning async Mistral client. Its ``_request`` result is
    consumed with ``async for`` and only the first item is used.
    """

    def __init__(self, client: Any):
        self.client = client

    @staticmethod
    async def _first_response(responses: Any) -> Any:
        """Return the first item yielded by the async iterable ``responses``.

        Raises:
            MistralException: If ``responses`` yields nothing.
        """
        async for response in responses:
            return response
        raise MistralException("No response received")

    @staticmethod
    def _serialize_integrations(integrations: Any) -> Any:
        """Convert integration models into plain dicts so the request body is JSON-serializable.

        Sets and pydantic models cannot be JSON-encoded directly, so the
        collection becomes a list and each model is dumped. Items without
        ``model_dump`` (e.g. plain dicts) are passed through unchanged. ``None``
        is preserved, so the payload is unchanged when no integrations are given.
        """
        if integrations is None:
            return None
        return [item.model_dump() if hasattr(item, "model_dump") else item for item in integrations]

    async def create(
        self,
        model: str,
        training_files: Union[list[str], None] = None,
        validation_files: Union[list[str], None] = None,
        hyperparameters: Optional[TrainingParameters] = None,
        suffix: Union[str, None] = None,
        integrations: Union[set[IntegrationIn], list[IntegrationIn], None] = None,
        training_file: Union[str, None] = None,  # Deprecated: Added for compatibility with OpenAI API
        validation_file: Union[str, None] = None,  # Deprecated: Added for compatibility with OpenAI API
        dry_run: bool = False,
    ) -> Union[Job, JobMetadata]:
        """Create a fine-tuning job, or request a cost/duration estimate when ``dry_run`` is true.

        Args:
            model: Base model to fine-tune.
            training_files: File ids used for training.
            validation_files: File ids used for validation.
            hyperparameters: Training parameters; defaults to 1800 steps at
                learning rate 1e-4.
            suffix: Optional suffix forwarded to the API.
            integrations: Integration models (or plain dicts) attached to the job.
            training_file: Deprecated single-file alias, used only when
                ``training_files`` is empty.
            validation_file: Deprecated single-file alias, used only when
                ``validation_files`` is empty.
            dry_run: Sent as a query parameter; selects the ``JobMetadata``
                return type.

        Returns:
            ``Job`` normally, ``JobMetadata`` when ``dry_run`` is true.

        Raises:
            MistralException: If the transport yields no response.
        """
        # Handle deprecated arguments
        if not training_files and training_file:
            training_files = [training_file]
        if not validation_files and validation_file:
            validation_files = [validation_file]
        # Build the default per call instead of sharing one model instance
        # created once at definition time (mutable default argument).
        if hyperparameters is None:
            hyperparameters = TrainingParameters(training_steps=1800, learning_rate=1.0e-4)

        single_response = self.client._request(
            method="post",
            json={
                "model": model,
                "training_files": training_files,
                "validation_files": validation_files,
                "hyperparameters": hyperparameters.model_dump(),
                "suffix": suffix,
                "integrations": self._serialize_integrations(integrations),
            },
            path="v1/fine_tuning/jobs",
            params={"dry_run": dry_run},
        )
        response = await self._first_response(single_response)
        return Job(**response) if not dry_run else JobMetadata(**response)

    async def retrieve(self, job_id: str) -> DetailedJob:
        """Fetch one job, including its events and checkpoints.

        Raises:
            MistralException: If the transport yields no response.
        """
        single_response = self.client._request(method="get", path=f"v1/fine_tuning/jobs/{job_id}", json={})
        return DetailedJob(**await self._first_response(single_response))

    async def list(
        self,
        page: int = 0,
        page_size: int = 10,
        model: Optional[str] = None,
        created_after: Optional[datetime] = None,
        created_by_me: Optional[bool] = None,
        status: Optional[str] = None,
        wandb_project: Optional[str] = None,
        wandb_name: Optional[str] = None,
        suffix: Optional[str] = None,
    ) -> Jobs:
        """List jobs. Filters left as ``None`` are omitted from the query string.

        Raises:
            MistralException: If the transport yields no response.
        """
        query_params = JobQueryFilter(
            page=page,
            page_size=page_size,
            model=model,
            created_after=created_after,
            created_by_me=created_by_me,
            status=status,
            wandb_project=wandb_project,
            wandb_name=wandb_name,
            suffix=suffix,
        ).model_dump(exclude_none=True)
        single_response = self.client._request(method="get", path="v1/fine_tuning/jobs", params=query_params, json={})
        return Jobs(**await self._first_response(single_response))

    async def cancel(self, job_id: str) -> DetailedJob:
        """Request cancellation of a job and return its updated state.

        Raises:
            MistralException: If the transport yields no response.
        """
        single_response = self.client._request(method="post", path=f"v1/fine_tuning/jobs/{job_id}/cancel", json={})
        return DetailedJob(**await self._first_response(single_response))
|
|
mistralai/models/chat_completion.py
DELETED
|
@@ -1,93 +0,0 @@
|
|
|
1
|
-
from enum import Enum
|
|
2
|
-
from typing import List, Optional
|
|
3
|
-
|
|
4
|
-
from pydantic import BaseModel
|
|
5
|
-
|
|
6
|
-
from mistralai.models.common import UsageInfo
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
class Function(BaseModel):
    # A tool the model may call: its name, a human-readable description,
    # and the argument specification passed through as a plain dict.
    name: str
    description: str
    parameters: dict
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
class ToolType(str, Enum):
    """Kinds of tool a ToolCall can target; only functions are defined."""

    function = "function"
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
class FunctionCall(BaseModel):
    # A concrete function invocation: target name plus its arguments held as
    # one string (presumably JSON-encoded — confirm against the API).
    name: str
    arguments: str
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
class ToolCall(BaseModel):
    # One tool invocation emitted by the model.
    # ``id`` falls back to the literal string "null" when absent.
    id: str = "null"
    # Defaults to a function tool, the only ToolType member.
    type: ToolType = ToolType.function
    function: FunctionCall
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
class ResponseFormats(str, Enum):
    """Output formats accepted in a chat request's ``response_format.type``."""

    # Members are plain assignments: annotating enum members (``text: str = ...``)
    # is disallowed by the typing spec for enums and flagged by type checkers.
    # Runtime values are unchanged.
    text = "text"
    json_object = "json_object"
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
class ToolChoice(str, Enum):
    """Values accepted for a chat request's ``tool_choice`` field."""

    # Members are plain assignments: annotating enum members (``auto: str = ...``)
    # is disallowed by the typing spec for enums and flagged by type checkers.
    # Runtime values are unchanged.
    auto = "auto"
    any = "any"
    none = "none"
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
class ResponseFormat(BaseModel):
    # Requested output format; plain text unless explicitly overridden.
    type: ResponseFormats = ResponseFormats.text
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
class ChatMessage(BaseModel):
    # One message in a conversation. ``role`` and ``content`` are required;
    # the rest are optional and dropped by model_dump(exclude_none=True)
    # when the client serializes requests.
    role: str
    content: str
    name: Optional[str] = None
    tool_calls: Optional[List[ToolCall]] = None
    tool_call_id: Optional[str] = None
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
class DeltaMessage(BaseModel):
    # Incremental message fragment carried by a streaming choice;
    # every field is optional because a chunk may carry only part of a message.
    role: Optional[str] = None
    content: Optional[str] = None
    tool_calls: Optional[List[ToolCall]] = None
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
class FinishReason(str, Enum):
    """Why the model stopped producing a choice."""

    stop = "stop"
    length = "length"
    error = "error"
    tool_calls = "tool_calls"
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
class ChatCompletionResponseStreamChoice(BaseModel):
    # One choice inside a streamed chunk.
    index: int
    delta: DeltaMessage
    # Required key (no default) whose value may be null.
    finish_reason: Optional[FinishReason]
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
class ChatCompletionStreamResponse(BaseModel):
    # One chunk of a streamed chat completion.
    # Required: id, model and the list of choices.
    id: str
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
    # Optional metadata that a chunk may omit.
    created: Optional[int] = None
    object: Optional[str] = None
    usage: Optional[UsageInfo] = None
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
class ChatCompletionResponseChoice(BaseModel):
    # One complete (non-streamed) choice.
    index: int
    message: ChatMessage
    # Required key (no default) whose value may be null.
    finish_reason: Optional[FinishReason]
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
class ChatCompletionResponse(BaseModel):
    # Full, non-streamed chat completion; every field is required.
    id: str
    object: str
    created: int
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo
|
mistralai/models/common.py
DELETED
mistralai/models/embeddings.py
DELETED
|
@@ -1,19 +0,0 @@
|
|
|
1
|
-
from typing import List
|
|
2
|
-
|
|
3
|
-
from pydantic import BaseModel
|
|
4
|
-
|
|
5
|
-
from mistralai.models.common import UsageInfo
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
class EmbeddingObject(BaseModel):
    # A single embedding vector together with its ``index`` in the response data.
    object: str
    embedding: List[float]
    index: int
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
class EmbeddingResponse(BaseModel):
    # Envelope returned by the embeddings call: the vectors plus
    # model name and usage accounting. All fields are required.
    id: str
    object: str
    data: List[EmbeddingObject]
    model: str
    usage: UsageInfo
|
mistralai/models/files.py
DELETED
|
@@ -1,23 +0,0 @@
|
|
|
1
|
-
from typing import Literal, Optional
|
|
2
|
-
|
|
3
|
-
from pydantic import BaseModel
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
class FileObject(BaseModel):
    # Metadata describing an uploaded file. The field name ``bytes`` mirrors the
    # API payload; it shadows the builtin only inside this class body.
    id: str
    object: str
    bytes: int
    created_at: int
    filename: str
    # Only "fine-tune" is accepted; it is also the default.
    purpose: Optional[Literal["fine-tune"]] = "fine-tune"
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
class FileDeleted(BaseModel):
    # Acknowledgement of a file deletion: the file id and a success flag.
    id: str
    object: str
    deleted: bool
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
class Files(BaseModel):
    # List envelope for file metadata; ``object`` must be the literal "list".
    data: list[FileObject]
    object: Literal["list"]
|
mistralai/models/jobs.py
DELETED
|
@@ -1,100 +0,0 @@
|
|
|
1
|
-
from datetime import datetime
|
|
2
|
-
from typing import Annotated, List, Literal, Optional, Union
|
|
3
|
-
|
|
4
|
-
from pydantic import BaseModel, Field
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
class TrainingParameters(BaseModel):
    # Fine-tuning hyperparameters with validated ranges:
    #   training_steps in [1, 10000] (default 1800)
    #   learning_rate  in [1e-8, 1]  (default 1e-4)
    training_steps: int = Field(default=1800, ge=1, le=10000)
    learning_rate: float = Field(default=1.0e-4, ge=1.0e-8, le=1)
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
class WandbIntegration(BaseModel):
    # Weights & Biases integration settings. ``type`` is the discriminator
    # tag used by the Integration alias; only ``project`` is required.
    type: Literal["wandb"] = "wandb"
    project: str
    name: Optional[str] = None
    run_name: Optional[str] = None
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
class WandbIntegrationIn(WandbIntegration):
    # Input-side variant of the W&B integration that also carries the API key.
    api_key: str
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
# Integration unions discriminated on the ``type`` field. Each currently has a
# single variant; the Union wrapper keeps room for further integration kinds.
Integration = Annotated[Union[WandbIntegration], Field(discriminator="type")]  # as returned on jobs
IntegrationIn = Annotated[Union[WandbIntegrationIn], Field(discriminator="type")]  # as sent when creating
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
class JobMetadata(BaseModel):
    # Returned in place of a Job when a job is created with dry_run=True.
    object: Literal["job.metadata"] = "job.metadata"
    training_steps: int
    train_tokens_per_step: int
    data_tokens: int
    train_tokens: int
    epochs: float
    # Required key (no default) whose value may be null.
    expected_duration_seconds: Optional[int]
    cost: Optional[float] = None
    cost_currency: Optional[str] = None
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
class Job(BaseModel):
    # A fine-tuning job as returned by the jobs endpoints.
    id: str
    hyperparameters: TrainingParameters
    # Required key; null presumably until a model has been produced.
    fine_tuned_model: Optional[str]
    model: str
    # Closed set of lifecycle states accepted by validation.
    status: Literal[
        "QUEUED",
        "STARTED",
        "RUNNING",
        "FAILED",
        "SUCCESS",
        "CANCELLED",
        "CANCELLATION_REQUESTED",
    ]
    job_type: str
    created_at: int
    modified_at: int
    training_files: list[str]
    validation_files: Optional[list[str]] = []
    object: Literal["job"]
    integrations: List[Integration] = []
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
class Event(BaseModel):
    # Named event recorded on a job, with an optional free-form payload.
    name: str
    data: Optional[dict] = None
    created_at: int
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
class Metric(BaseModel):
    # Training/validation metrics attached to a checkpoint; any may be absent.
    train_loss: Optional[float] = None
    valid_loss: Optional[float] = None
    valid_mean_token_accuracy: Optional[float] = None
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
class Checkpoint(BaseModel):
    # Snapshot of a job at a given step, with the metrics recorded there.
    metrics: Metric
    step_number: int
    created_at: int
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
class JobQueryFilter(BaseModel):
    # Query-string filters for listing jobs; the jobs clients serialize it
    # with model_dump(exclude_none=True), so unset filters are omitted.
    # Pagination:
    page: int = 0
    page_size: int = 100  # NOTE(review): JobsClient.list passes 10 explicitly, so this default is unused there.
    # Optional filters:
    model: Optional[str] = None
    created_after: Optional[datetime] = None
    created_by_me: Optional[bool] = None
    status: Optional[str] = None
    wandb_project: Optional[str] = None
    wandb_name: Optional[str] = None
    suffix: Optional[str] = None
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
class DetailedJob(Job):
    # A Job extended with its event history and checkpoints
    # (returned by the retrieve and cancel calls).
    events: list[Event] = []
    checkpoints: list[Checkpoint] = []
    estimated_start_time: Optional[int] = None
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
class Jobs(BaseModel):
    # List envelope for jobs; ``object`` must be the literal "list".
    data: list[Job] = []
    object: Literal["list"]
|