mistralai 0.4.2__py3-none-any.whl → 0.5.5a50__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (240) hide show
  1. mistralai/__init__.py +5 -0
  2. mistralai/_hooks/__init__.py +5 -0
  3. mistralai/_hooks/custom_user_agent.py +16 -0
  4. mistralai/_hooks/deprecation_warning.py +26 -0
  5. mistralai/_hooks/registration.py +17 -0
  6. mistralai/_hooks/sdkhooks.py +57 -0
  7. mistralai/_hooks/types.py +76 -0
  8. mistralai/async_client.py +5 -413
  9. mistralai/basesdk.py +216 -0
  10. mistralai/chat.py +475 -0
  11. mistralai/client.py +5 -414
  12. mistralai/embeddings.py +182 -0
  13. mistralai/files.py +600 -84
  14. mistralai/fim.py +439 -0
  15. mistralai/fine_tuning.py +855 -0
  16. mistralai/httpclient.py +78 -0
  17. mistralai/models/__init__.py +80 -0
  18. mistralai/models/archiveftmodelout.py +19 -0
  19. mistralai/models/assistantmessage.py +58 -0
  20. mistralai/models/chatcompletionchoice.py +33 -0
  21. mistralai/models/chatcompletionrequest.py +114 -0
  22. mistralai/models/chatcompletionresponse.py +27 -0
  23. mistralai/models/chatcompletionstreamrequest.py +112 -0
  24. mistralai/models/checkpointout.py +25 -0
  25. mistralai/models/completionchunk.py +27 -0
  26. mistralai/models/completionevent.py +15 -0
  27. mistralai/models/completionresponsestreamchoice.py +53 -0
  28. mistralai/models/contentchunk.py +17 -0
  29. mistralai/models/delete_model_v1_models_model_id_deleteop.py +16 -0
  30. mistralai/models/deletefileout.py +24 -0
  31. mistralai/models/deletemodelout.py +25 -0
  32. mistralai/models/deltamessage.py +52 -0
  33. mistralai/models/detailedjobout.py +96 -0
  34. mistralai/models/embeddingrequest.py +66 -0
  35. mistralai/models/embeddingresponse.py +24 -0
  36. mistralai/models/embeddingresponsedata.py +19 -0
  37. mistralai/models/eventout.py +55 -0
  38. mistralai/models/files_api_routes_delete_fileop.py +16 -0
  39. mistralai/models/files_api_routes_retrieve_fileop.py +16 -0
  40. mistralai/models/files_api_routes_upload_fileop.py +51 -0
  41. mistralai/models/fileschema.py +76 -0
  42. mistralai/models/fimcompletionrequest.py +99 -0
  43. mistralai/models/fimcompletionresponse.py +27 -0
  44. mistralai/models/fimcompletionstreamrequest.py +97 -0
  45. mistralai/models/finetuneablemodel.py +8 -0
  46. mistralai/models/ftmodelcapabilitiesout.py +21 -0
  47. mistralai/models/ftmodelout.py +70 -0
  48. mistralai/models/function.py +19 -0
  49. mistralai/models/functioncall.py +16 -0
  50. mistralai/models/githubrepositoryin.py +57 -0
  51. mistralai/models/githubrepositoryout.py +57 -0
  52. mistralai/models/httpvalidationerror.py +23 -0
  53. mistralai/models/jobin.py +78 -0
  54. mistralai/models/jobmetadataout.py +59 -0
  55. mistralai/models/jobout.py +112 -0
  56. mistralai/models/jobs_api_routes_fine_tuning_archive_fine_tuned_modelop.py +16 -0
  57. mistralai/models/jobs_api_routes_fine_tuning_cancel_fine_tuning_jobop.py +18 -0
  58. mistralai/models/jobs_api_routes_fine_tuning_create_fine_tuning_jobop.py +73 -0
  59. mistralai/models/jobs_api_routes_fine_tuning_get_fine_tuning_jobop.py +18 -0
  60. mistralai/models/jobs_api_routes_fine_tuning_get_fine_tuning_jobsop.py +86 -0
  61. mistralai/models/jobs_api_routes_fine_tuning_start_fine_tuning_jobop.py +16 -0
  62. mistralai/models/jobs_api_routes_fine_tuning_unarchive_fine_tuned_modelop.py +16 -0
  63. mistralai/models/jobs_api_routes_fine_tuning_update_fine_tuned_modelop.py +19 -0
  64. mistralai/models/jobsout.py +20 -0
  65. mistralai/models/legacyjobmetadataout.py +85 -0
  66. mistralai/models/listfilesout.py +17 -0
  67. mistralai/models/metricout.py +55 -0
  68. mistralai/models/modelcapabilities.py +21 -0
  69. mistralai/models/modelcard.py +71 -0
  70. mistralai/models/modellist.py +18 -0
  71. mistralai/models/responseformat.py +18 -0
  72. mistralai/models/retrieve_model_v1_models_model_id_getop.py +16 -0
  73. mistralai/models/retrievefileout.py +76 -0
  74. mistralai/models/sampletype.py +7 -0
  75. mistralai/models/sdkerror.py +22 -0
  76. mistralai/models/security.py +16 -0
  77. mistralai/models/source.py +7 -0
  78. mistralai/models/systemmessage.py +26 -0
  79. mistralai/models/textchunk.py +17 -0
  80. mistralai/models/tool.py +18 -0
  81. mistralai/models/toolcall.py +20 -0
  82. mistralai/models/toolmessage.py +55 -0
  83. mistralai/models/trainingfile.py +17 -0
  84. mistralai/models/trainingparameters.py +53 -0
  85. mistralai/models/trainingparametersin.py +61 -0
  86. mistralai/models/unarchiveftmodelout.py +19 -0
  87. mistralai/models/updateftmodelin.py +49 -0
  88. mistralai/models/uploadfileout.py +76 -0
  89. mistralai/models/usageinfo.py +18 -0
  90. mistralai/models/usermessage.py +26 -0
  91. mistralai/models/validationerror.py +24 -0
  92. mistralai/models/wandbintegration.py +61 -0
  93. mistralai/models/wandbintegrationout.py +57 -0
  94. mistralai/models_.py +928 -0
  95. mistralai/py.typed +1 -0
  96. mistralai/sdk.py +111 -0
  97. mistralai/sdkconfiguration.py +53 -0
  98. mistralai/types/__init__.py +21 -0
  99. mistralai/types/basemodel.py +35 -0
  100. mistralai/utils/__init__.py +82 -0
  101. mistralai/utils/annotations.py +19 -0
  102. mistralai/utils/enums.py +34 -0
  103. mistralai/utils/eventstreaming.py +179 -0
  104. mistralai/utils/forms.py +207 -0
  105. mistralai/utils/headers.py +136 -0
  106. mistralai/utils/metadata.py +118 -0
  107. mistralai/utils/queryparams.py +203 -0
  108. mistralai/utils/requestbodies.py +66 -0
  109. mistralai/utils/retries.py +216 -0
  110. mistralai/utils/security.py +182 -0
  111. mistralai/utils/serializers.py +181 -0
  112. mistralai/utils/url.py +150 -0
  113. mistralai/utils/values.py +128 -0
  114. {mistralai-0.4.2.dist-info → mistralai-0.5.5a50.dist-info}/LICENSE +1 -1
  115. mistralai-0.5.5a50.dist-info/METADATA +626 -0
  116. mistralai-0.5.5a50.dist-info/RECORD +228 -0
  117. mistralai_azure/__init__.py +5 -0
  118. mistralai_azure/_hooks/__init__.py +5 -0
  119. mistralai_azure/_hooks/custom_user_agent.py +16 -0
  120. mistralai_azure/_hooks/registration.py +15 -0
  121. mistralai_azure/_hooks/sdkhooks.py +57 -0
  122. mistralai_azure/_hooks/types.py +76 -0
  123. mistralai_azure/basesdk.py +215 -0
  124. mistralai_azure/chat.py +475 -0
  125. mistralai_azure/httpclient.py +78 -0
  126. mistralai_azure/models/__init__.py +28 -0
  127. mistralai_azure/models/assistantmessage.py +58 -0
  128. mistralai_azure/models/chatcompletionchoice.py +33 -0
  129. mistralai_azure/models/chatcompletionrequest.py +114 -0
  130. mistralai_azure/models/chatcompletionresponse.py +27 -0
  131. mistralai_azure/models/chatcompletionstreamrequest.py +112 -0
  132. mistralai_azure/models/completionchunk.py +27 -0
  133. mistralai_azure/models/completionevent.py +15 -0
  134. mistralai_azure/models/completionresponsestreamchoice.py +53 -0
  135. mistralai_azure/models/contentchunk.py +17 -0
  136. mistralai_azure/models/deltamessage.py +52 -0
  137. mistralai_azure/models/function.py +19 -0
  138. mistralai_azure/models/functioncall.py +16 -0
  139. mistralai_azure/models/httpvalidationerror.py +23 -0
  140. mistralai_azure/models/responseformat.py +18 -0
  141. mistralai_azure/models/sdkerror.py +22 -0
  142. mistralai_azure/models/security.py +16 -0
  143. mistralai_azure/models/systemmessage.py +26 -0
  144. mistralai_azure/models/textchunk.py +17 -0
  145. mistralai_azure/models/tool.py +18 -0
  146. mistralai_azure/models/toolcall.py +20 -0
  147. mistralai_azure/models/toolmessage.py +55 -0
  148. mistralai_azure/models/usageinfo.py +18 -0
  149. mistralai_azure/models/usermessage.py +26 -0
  150. mistralai_azure/models/validationerror.py +24 -0
  151. mistralai_azure/py.typed +1 -0
  152. mistralai_azure/sdk.py +102 -0
  153. mistralai_azure/sdkconfiguration.py +53 -0
  154. mistralai_azure/types/__init__.py +21 -0
  155. mistralai_azure/types/basemodel.py +35 -0
  156. mistralai_azure/utils/__init__.py +80 -0
  157. mistralai_azure/utils/annotations.py +19 -0
  158. mistralai_azure/utils/enums.py +34 -0
  159. mistralai_azure/utils/eventstreaming.py +179 -0
  160. mistralai_azure/utils/forms.py +207 -0
  161. mistralai_azure/utils/headers.py +136 -0
  162. mistralai_azure/utils/metadata.py +118 -0
  163. mistralai_azure/utils/queryparams.py +203 -0
  164. mistralai_azure/utils/requestbodies.py +66 -0
  165. mistralai_azure/utils/retries.py +216 -0
  166. mistralai_azure/utils/security.py +168 -0
  167. mistralai_azure/utils/serializers.py +181 -0
  168. mistralai_azure/utils/url.py +150 -0
  169. mistralai_azure/utils/values.py +128 -0
  170. mistralai_gcp/__init__.py +5 -0
  171. mistralai_gcp/_hooks/__init__.py +5 -0
  172. mistralai_gcp/_hooks/custom_user_agent.py +16 -0
  173. mistralai_gcp/_hooks/registration.py +15 -0
  174. mistralai_gcp/_hooks/sdkhooks.py +57 -0
  175. mistralai_gcp/_hooks/types.py +76 -0
  176. mistralai_gcp/basesdk.py +215 -0
  177. mistralai_gcp/chat.py +463 -0
  178. mistralai_gcp/fim.py +439 -0
  179. mistralai_gcp/httpclient.py +78 -0
  180. mistralai_gcp/models/__init__.py +31 -0
  181. mistralai_gcp/models/assistantmessage.py +58 -0
  182. mistralai_gcp/models/chatcompletionchoice.py +33 -0
  183. mistralai_gcp/models/chatcompletionrequest.py +110 -0
  184. mistralai_gcp/models/chatcompletionresponse.py +27 -0
  185. mistralai_gcp/models/chatcompletionstreamrequest.py +108 -0
  186. mistralai_gcp/models/completionchunk.py +27 -0
  187. mistralai_gcp/models/completionevent.py +15 -0
  188. mistralai_gcp/models/completionresponsestreamchoice.py +53 -0
  189. mistralai_gcp/models/contentchunk.py +17 -0
  190. mistralai_gcp/models/deltamessage.py +52 -0
  191. mistralai_gcp/models/fimcompletionrequest.py +99 -0
  192. mistralai_gcp/models/fimcompletionresponse.py +27 -0
  193. mistralai_gcp/models/fimcompletionstreamrequest.py +97 -0
  194. mistralai_gcp/models/function.py +19 -0
  195. mistralai_gcp/models/functioncall.py +16 -0
  196. mistralai_gcp/models/httpvalidationerror.py +23 -0
  197. mistralai_gcp/models/responseformat.py +18 -0
  198. mistralai_gcp/models/sdkerror.py +22 -0
  199. mistralai_gcp/models/security.py +16 -0
  200. mistralai_gcp/models/systemmessage.py +26 -0
  201. mistralai_gcp/models/textchunk.py +17 -0
  202. mistralai_gcp/models/tool.py +18 -0
  203. mistralai_gcp/models/toolcall.py +20 -0
  204. mistralai_gcp/models/toolmessage.py +55 -0
  205. mistralai_gcp/models/usageinfo.py +18 -0
  206. mistralai_gcp/models/usermessage.py +26 -0
  207. mistralai_gcp/models/validationerror.py +24 -0
  208. mistralai_gcp/py.typed +1 -0
  209. mistralai_gcp/sdk.py +165 -0
  210. mistralai_gcp/sdkconfiguration.py +53 -0
  211. mistralai_gcp/types/__init__.py +21 -0
  212. mistralai_gcp/types/basemodel.py +35 -0
  213. mistralai_gcp/utils/__init__.py +80 -0
  214. mistralai_gcp/utils/annotations.py +19 -0
  215. mistralai_gcp/utils/enums.py +34 -0
  216. mistralai_gcp/utils/eventstreaming.py +179 -0
  217. mistralai_gcp/utils/forms.py +207 -0
  218. mistralai_gcp/utils/headers.py +136 -0
  219. mistralai_gcp/utils/metadata.py +118 -0
  220. mistralai_gcp/utils/queryparams.py +203 -0
  221. mistralai_gcp/utils/requestbodies.py +66 -0
  222. mistralai_gcp/utils/retries.py +216 -0
  223. mistralai_gcp/utils/security.py +168 -0
  224. mistralai_gcp/utils/serializers.py +181 -0
  225. mistralai_gcp/utils/url.py +150 -0
  226. mistralai_gcp/utils/values.py +128 -0
  227. py.typed +1 -0
  228. mistralai/client_base.py +0 -211
  229. mistralai/constants.py +0 -5
  230. mistralai/exceptions.py +0 -54
  231. mistralai/jobs.py +0 -172
  232. mistralai/models/chat_completion.py +0 -93
  233. mistralai/models/common.py +0 -9
  234. mistralai/models/embeddings.py +0 -19
  235. mistralai/models/files.py +0 -23
  236. mistralai/models/jobs.py +0 -100
  237. mistralai/models/models.py +0 -39
  238. mistralai-0.4.2.dist-info/METADATA +0 -82
  239. mistralai-0.4.2.dist-info/RECORD +0 -20
  240. {mistralai-0.4.2.dist-info → mistralai-0.5.5a50.dist-info}/WHEEL +0 -0
mistralai/client_base.py DELETED
@@ -1,211 +0,0 @@
1
- import logging
2
- import os
3
- from abc import ABC
4
- from typing import Any, Callable, Dict, List, Optional, Union
5
-
6
- import orjson
7
- from httpx import Headers
8
-
9
- from mistralai.constants import HEADER_MODEL_DEPRECATION_TIMESTAMP
10
- from mistralai.exceptions import MistralException
11
- from mistralai.models.chat_completion import (
12
- ChatMessage,
13
- Function,
14
- ResponseFormat,
15
- ToolChoice,
16
- )
17
-
18
- CLIENT_VERSION = "0.4.2"
19
-
20
-
21
- class ClientBase(ABC):
22
- def __init__(
23
- self,
24
- endpoint: str,
25
- api_key: Optional[str] = None,
26
- max_retries: int = 5,
27
- timeout: int = 120,
28
- ):
29
- self._max_retries = max_retries
30
- self._timeout = timeout
31
-
32
- if api_key is None:
33
- api_key = os.environ.get("MISTRAL_API_KEY")
34
- if api_key is None:
35
- raise MistralException(message="API key not provided. Please set MISTRAL_API_KEY environment variable.")
36
- self._api_key = api_key
37
- self._endpoint = endpoint
38
- self._logger = logging.getLogger(__name__)
39
-
40
- # For azure endpoints, we default to the mistral model
41
- if "inference.azure.com" in self._endpoint:
42
- self._default_model = "mistral"
43
-
44
- self._version = CLIENT_VERSION
45
-
46
- def _get_model(self, model: Optional[str] = None) -> str:
47
- if model is not None:
48
- return model
49
- else:
50
- if self._default_model is None:
51
- raise MistralException(message="model must be provided")
52
- return self._default_model
53
-
54
- def _parse_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
55
- parsed_tools: List[Dict[str, Any]] = []
56
- for tool in tools:
57
- if tool["type"] == "function":
58
- parsed_function = {}
59
- parsed_function["type"] = tool["type"]
60
- if isinstance(tool["function"], Function):
61
- parsed_function["function"] = tool["function"].model_dump(exclude_none=True)
62
- else:
63
- parsed_function["function"] = tool["function"]
64
-
65
- parsed_tools.append(parsed_function)
66
-
67
- return parsed_tools
68
-
69
- def _parse_tool_choice(self, tool_choice: Union[str, ToolChoice]) -> str:
70
- if isinstance(tool_choice, ToolChoice):
71
- return tool_choice.value
72
- return tool_choice
73
-
74
- def _parse_response_format(self, response_format: Union[Dict[str, Any], ResponseFormat]) -> Dict[str, Any]:
75
- if isinstance(response_format, ResponseFormat):
76
- return response_format.model_dump(exclude_none=True)
77
- return response_format
78
-
79
- def _parse_messages(self, messages: List[Any]) -> List[Dict[str, Any]]:
80
- parsed_messages: List[Dict[str, Any]] = []
81
- for message in messages:
82
- if isinstance(message, ChatMessage):
83
- parsed_messages.append(message.model_dump(exclude_none=True))
84
- else:
85
- parsed_messages.append(message)
86
-
87
- return parsed_messages
88
-
89
- def _check_model_deprecation_header_callback_factory(self, model: Optional[str] = None) -> Callable:
90
- model = self._get_model(model)
91
-
92
- def _check_model_deprecation_header_callback(
93
- headers: Headers,
94
- ) -> None:
95
- if HEADER_MODEL_DEPRECATION_TIMESTAMP in headers:
96
- self._logger.warning(
97
- f"WARNING: The model {model} is deprecated "
98
- f"and will be removed on {headers[HEADER_MODEL_DEPRECATION_TIMESTAMP]}. "
99
- "Please refer to https://docs.mistral.ai/getting-started/models/#api-versioning "
100
- "for more information."
101
- )
102
-
103
- return _check_model_deprecation_header_callback
104
-
105
- def _make_completion_request(
106
- self,
107
- prompt: str,
108
- model: Optional[str] = None,
109
- suffix: Optional[str] = None,
110
- temperature: Optional[float] = None,
111
- max_tokens: Optional[int] = None,
112
- top_p: Optional[float] = None,
113
- random_seed: Optional[int] = None,
114
- stop: Optional[List[str]] = None,
115
- stream: Optional[bool] = False,
116
- ) -> Dict[str, Any]:
117
- request_data: Dict[str, Any] = {
118
- "prompt": prompt,
119
- "suffix": suffix,
120
- "model": model,
121
- "stream": stream,
122
- }
123
-
124
- if stop is not None:
125
- request_data["stop"] = stop
126
-
127
- request_data["model"] = self._get_model(model)
128
-
129
- request_data.update(
130
- self._build_sampling_params(
131
- temperature=temperature,
132
- max_tokens=max_tokens,
133
- top_p=top_p,
134
- random_seed=random_seed,
135
- )
136
- )
137
-
138
- self._logger.debug(f"Completion request: {request_data}")
139
-
140
- return request_data
141
-
142
- def _build_sampling_params(
143
- self,
144
- max_tokens: Optional[int],
145
- random_seed: Optional[int],
146
- temperature: Optional[float],
147
- top_p: Optional[float],
148
- ) -> Dict[str, Any]:
149
- params = {}
150
- if temperature is not None:
151
- params["temperature"] = temperature
152
- if max_tokens is not None:
153
- params["max_tokens"] = max_tokens
154
- if top_p is not None:
155
- params["top_p"] = top_p
156
- if random_seed is not None:
157
- params["random_seed"] = random_seed
158
- return params
159
-
160
- def _make_chat_request(
161
- self,
162
- messages: List[Any],
163
- model: Optional[str] = None,
164
- tools: Optional[List[Dict[str, Any]]] = None,
165
- temperature: Optional[float] = None,
166
- max_tokens: Optional[int] = None,
167
- top_p: Optional[float] = None,
168
- random_seed: Optional[int] = None,
169
- stream: Optional[bool] = None,
170
- safe_prompt: Optional[bool] = False,
171
- tool_choice: Optional[Union[str, ToolChoice]] = None,
172
- response_format: Optional[Union[Dict[str, str], ResponseFormat]] = None,
173
- ) -> Dict[str, Any]:
174
- request_data: Dict[str, Any] = {
175
- "messages": self._parse_messages(messages),
176
- }
177
-
178
- request_data["model"] = self._get_model(model)
179
-
180
- request_data.update(
181
- self._build_sampling_params(
182
- temperature=temperature,
183
- max_tokens=max_tokens,
184
- top_p=top_p,
185
- random_seed=random_seed,
186
- )
187
- )
188
-
189
- if safe_prompt:
190
- request_data["safe_prompt"] = safe_prompt
191
- if tools is not None:
192
- request_data["tools"] = self._parse_tools(tools)
193
- if stream is not None:
194
- request_data["stream"] = stream
195
-
196
- if tool_choice is not None:
197
- request_data["tool_choice"] = self._parse_tool_choice(tool_choice)
198
- if response_format is not None:
199
- request_data["response_format"] = self._parse_response_format(response_format)
200
-
201
- self._logger.debug(f"Chat request: {request_data}")
202
-
203
- return request_data
204
-
205
- def _process_line(self, line: str) -> Optional[Dict[str, Any]]:
206
- if line.startswith("data: "):
207
- line = line[6:].strip()
208
- if line != "[DONE]":
209
- json_streamed_response: Dict[str, Any] = orjson.loads(line)
210
- return json_streamed_response
211
- return None
mistralai/constants.py DELETED
@@ -1,5 +0,0 @@
1
- RETRY_STATUS_CODES = {429, 500, 502, 503, 504}
2
-
3
- ENDPOINT = "https://api.mistral.ai"
4
-
5
- HEADER_MODEL_DEPRECATION_TIMESTAMP = "x-model-deprecation-timestamp"
mistralai/exceptions.py DELETED
@@ -1,54 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from typing import Any, Dict, Optional
4
-
5
- from httpx import Response
6
-
7
-
8
- class MistralException(Exception):
9
- """Base Exception class, returned when nothing more specific applies"""
10
-
11
- def __init__(self, message: Optional[str] = None) -> None:
12
- super(MistralException, self).__init__(message)
13
-
14
- self.message = message
15
-
16
- def __str__(self) -> str:
17
- msg = self.message or "<empty message>"
18
- return msg
19
-
20
- def __repr__(self) -> str:
21
- return f"{self.__class__.__name__}(message={str(self)})"
22
-
23
-
24
- class MistralAPIException(MistralException):
25
- """Returned when the API responds with an error message"""
26
-
27
- def __init__(
28
- self,
29
- message: Optional[str] = None,
30
- http_status: Optional[int] = None,
31
- headers: Optional[Dict[str, Any]] = None,
32
- ) -> None:
33
- super().__init__(message)
34
- self.http_status = http_status
35
- self.headers = headers or {}
36
-
37
- @classmethod
38
- def from_response(cls, response: Response, message: Optional[str] = None) -> MistralAPIException:
39
- return cls(
40
- message=message or response.text,
41
- http_status=response.status_code,
42
- headers=dict(response.headers),
43
- )
44
-
45
- def __repr__(self) -> str:
46
- return f"{self.__class__.__name__}(message={str(self)}, http_status={self.http_status})"
47
-
48
-
49
- class MistralAPIStatusException(MistralAPIException):
50
- """Returned when we receive a non-200 response from the API that we should retry"""
51
-
52
-
53
- class MistralConnectionException(MistralException):
54
- """Returned when the SDK can not reach the API server for any reason"""
mistralai/jobs.py DELETED
@@ -1,172 +0,0 @@
1
- from datetime import datetime
2
- from typing import Any, Optional, Union
3
-
4
- from mistralai.exceptions import (
5
- MistralException,
6
- )
7
- from mistralai.models.jobs import DetailedJob, IntegrationIn, Job, JobMetadata, JobQueryFilter, Jobs, TrainingParameters
8
-
9
-
10
- class JobsClient:
11
- def __init__(self, client: Any):
12
- self.client = client
13
-
14
- def create(
15
- self,
16
- model: str,
17
- training_files: Union[list[str], None] = None,
18
- validation_files: Union[list[str], None] = None,
19
- hyperparameters: TrainingParameters = TrainingParameters(
20
- training_steps=1800,
21
- learning_rate=1.0e-4,
22
- ),
23
- suffix: Union[str, None] = None,
24
- integrations: Union[set[IntegrationIn], None] = None,
25
- training_file: Union[str, None] = None, # Deprecated: Added for compatibility with OpenAI API
26
- validation_file: Union[str, None] = None, # Deprecated: Added for compatibility with OpenAI API
27
- dry_run: bool = False,
28
- ) -> Union[Job, JobMetadata]:
29
- # Handle deprecated arguments
30
- if not training_files and training_file:
31
- training_files = [training_file]
32
- if not validation_files and validation_file:
33
- validation_files = [validation_file]
34
- single_response = self.client._request(
35
- method="post",
36
- json={
37
- "model": model,
38
- "training_files": training_files,
39
- "validation_files": validation_files,
40
- "hyperparameters": hyperparameters.dict(),
41
- "suffix": suffix,
42
- "integrations": integrations,
43
- },
44
- path="v1/fine_tuning/jobs",
45
- params={"dry_run": dry_run},
46
- )
47
- for response in single_response:
48
- return Job(**response) if not dry_run else JobMetadata(**response)
49
- raise MistralException("No response received")
50
-
51
- def retrieve(self, job_id: str) -> DetailedJob:
52
- single_response = self.client._request(method="get", path=f"v1/fine_tuning/jobs/{job_id}", json={})
53
- for response in single_response:
54
- return DetailedJob(**response)
55
- raise MistralException("No response received")
56
-
57
- def list(
58
- self,
59
- page: int = 0,
60
- page_size: int = 10,
61
- model: Optional[str] = None,
62
- created_after: Optional[datetime] = None,
63
- created_by_me: Optional[bool] = None,
64
- status: Optional[str] = None,
65
- wandb_project: Optional[str] = None,
66
- wandb_name: Optional[str] = None,
67
- suffix: Optional[str] = None,
68
- ) -> Jobs:
69
- query_params = JobQueryFilter(
70
- page=page,
71
- page_size=page_size,
72
- model=model,
73
- created_after=created_after,
74
- created_by_me=created_by_me,
75
- status=status,
76
- wandb_project=wandb_project,
77
- wandb_name=wandb_name,
78
- suffix=suffix,
79
- ).model_dump(exclude_none=True)
80
- single_response = self.client._request(method="get", params=query_params, path="v1/fine_tuning/jobs", json={})
81
- for response in single_response:
82
- return Jobs(**response)
83
- raise MistralException("No response received")
84
-
85
- def cancel(self, job_id: str) -> DetailedJob:
86
- single_response = self.client._request(method="post", path=f"v1/fine_tuning/jobs/{job_id}/cancel", json={})
87
- for response in single_response:
88
- return DetailedJob(**response)
89
- raise MistralException("No response received")
90
-
91
-
92
- class JobsAsyncClient:
93
- def __init__(self, client: Any):
94
- self.client = client
95
-
96
- async def create(
97
- self,
98
- model: str,
99
- training_files: Union[list[str], None] = None,
100
- validation_files: Union[list[str], None] = None,
101
- hyperparameters: TrainingParameters = TrainingParameters(
102
- training_steps=1800,
103
- learning_rate=1.0e-4,
104
- ),
105
- suffix: Union[str, None] = None,
106
- integrations: Union[set[IntegrationIn], None] = None,
107
- training_file: Union[str, None] = None, # Deprecated: Added for compatibility with OpenAI API
108
- validation_file: Union[str, None] = None, # Deprecated: Added for compatibility with OpenAI API
109
- dry_run: bool = False,
110
- ) -> Union[Job, JobMetadata]:
111
- # Handle deprecated arguments
112
- if not training_files and training_file:
113
- training_files = [training_file]
114
- if not validation_files and validation_file:
115
- validation_files = [validation_file]
116
-
117
- single_response = self.client._request(
118
- method="post",
119
- json={
120
- "model": model,
121
- "training_files": training_files,
122
- "validation_files": validation_files,
123
- "hyperparameters": hyperparameters.dict(),
124
- "suffix": suffix,
125
- "integrations": integrations,
126
- },
127
- path="v1/fine_tuning/jobs",
128
- params={"dry_run": dry_run},
129
- )
130
- async for response in single_response:
131
- return Job(**response) if not dry_run else JobMetadata(**response)
132
- raise MistralException("No response received")
133
-
134
- async def retrieve(self, job_id: str) -> DetailedJob:
135
- single_response = self.client._request(method="get", path=f"v1/fine_tuning/jobs/{job_id}", json={})
136
- async for response in single_response:
137
- return DetailedJob(**response)
138
- raise MistralException("No response received")
139
-
140
- async def list(
141
- self,
142
- page: int = 0,
143
- page_size: int = 10,
144
- model: Optional[str] = None,
145
- created_after: Optional[datetime] = None,
146
- created_by_me: Optional[bool] = None,
147
- status: Optional[str] = None,
148
- wandb_project: Optional[str] = None,
149
- wandb_name: Optional[str] = None,
150
- suffix: Optional[str] = None,
151
- ) -> Jobs:
152
- query_params = JobQueryFilter(
153
- page=page,
154
- page_size=page_size,
155
- model=model,
156
- created_after=created_after,
157
- created_by_me=created_by_me,
158
- status=status,
159
- wandb_project=wandb_project,
160
- wandb_name=wandb_name,
161
- suffix=suffix,
162
- ).model_dump(exclude_none=True)
163
- single_response = self.client._request(method="get", path="v1/fine_tuning/jobs", params=query_params, json={})
164
- async for response in single_response:
165
- return Jobs(**response)
166
- raise MistralException("No response received")
167
-
168
- async def cancel(self, job_id: str) -> DetailedJob:
169
- single_response = self.client._request(method="post", path=f"v1/fine_tuning/jobs/{job_id}/cancel", json={})
170
- async for response in single_response:
171
- return DetailedJob(**response)
172
- raise MistralException("No response received")
@@ -1,93 +0,0 @@
1
- from enum import Enum
2
- from typing import List, Optional
3
-
4
- from pydantic import BaseModel
5
-
6
- from mistralai.models.common import UsageInfo
7
-
8
-
9
- class Function(BaseModel):
10
- name: str
11
- description: str
12
- parameters: dict
13
-
14
-
15
- class ToolType(str, Enum):
16
- function = "function"
17
-
18
-
19
- class FunctionCall(BaseModel):
20
- name: str
21
- arguments: str
22
-
23
-
24
- class ToolCall(BaseModel):
25
- id: str = "null"
26
- type: ToolType = ToolType.function
27
- function: FunctionCall
28
-
29
-
30
- class ResponseFormats(str, Enum):
31
- text: str = "text"
32
- json_object: str = "json_object"
33
-
34
-
35
- class ToolChoice(str, Enum):
36
- auto: str = "auto"
37
- any: str = "any"
38
- none: str = "none"
39
-
40
-
41
- class ResponseFormat(BaseModel):
42
- type: ResponseFormats = ResponseFormats.text
43
-
44
-
45
- class ChatMessage(BaseModel):
46
- role: str
47
- content: str
48
- name: Optional[str] = None
49
- tool_calls: Optional[List[ToolCall]] = None
50
- tool_call_id: Optional[str] = None
51
-
52
-
53
- class DeltaMessage(BaseModel):
54
- role: Optional[str] = None
55
- content: Optional[str] = None
56
- tool_calls: Optional[List[ToolCall]] = None
57
-
58
-
59
- class FinishReason(str, Enum):
60
- stop = "stop"
61
- length = "length"
62
- error = "error"
63
- tool_calls = "tool_calls"
64
-
65
-
66
- class ChatCompletionResponseStreamChoice(BaseModel):
67
- index: int
68
- delta: DeltaMessage
69
- finish_reason: Optional[FinishReason]
70
-
71
-
72
- class ChatCompletionStreamResponse(BaseModel):
73
- id: str
74
- model: str
75
- choices: List[ChatCompletionResponseStreamChoice]
76
- created: Optional[int] = None
77
- object: Optional[str] = None
78
- usage: Optional[UsageInfo] = None
79
-
80
-
81
- class ChatCompletionResponseChoice(BaseModel):
82
- index: int
83
- message: ChatMessage
84
- finish_reason: Optional[FinishReason]
85
-
86
-
87
- class ChatCompletionResponse(BaseModel):
88
- id: str
89
- object: str
90
- created: int
91
- model: str
92
- choices: List[ChatCompletionResponseChoice]
93
- usage: UsageInfo
@@ -1,9 +0,0 @@
1
- from typing import Optional
2
-
3
- from pydantic import BaseModel
4
-
5
-
6
- class UsageInfo(BaseModel):
7
- prompt_tokens: int
8
- total_tokens: int
9
- completion_tokens: Optional[int]
@@ -1,19 +0,0 @@
1
- from typing import List
2
-
3
- from pydantic import BaseModel
4
-
5
- from mistralai.models.common import UsageInfo
6
-
7
-
8
- class EmbeddingObject(BaseModel):
9
- object: str
10
- embedding: List[float]
11
- index: int
12
-
13
-
14
- class EmbeddingResponse(BaseModel):
15
- id: str
16
- object: str
17
- data: List[EmbeddingObject]
18
- model: str
19
- usage: UsageInfo
mistralai/models/files.py DELETED
@@ -1,23 +0,0 @@
1
- from typing import Literal, Optional
2
-
3
- from pydantic import BaseModel
4
-
5
-
6
- class FileObject(BaseModel):
7
- id: str
8
- object: str
9
- bytes: int
10
- created_at: int
11
- filename: str
12
- purpose: Optional[Literal["fine-tune"]] = "fine-tune"
13
-
14
-
15
- class FileDeleted(BaseModel):
16
- id: str
17
- object: str
18
- deleted: bool
19
-
20
-
21
- class Files(BaseModel):
22
- data: list[FileObject]
23
- object: Literal["list"]
mistralai/models/jobs.py DELETED
@@ -1,100 +0,0 @@
1
- from datetime import datetime
2
- from typing import Annotated, List, Literal, Optional, Union
3
-
4
- from pydantic import BaseModel, Field
5
-
6
-
7
- class TrainingParameters(BaseModel):
8
- training_steps: int = Field(1800, le=10000, ge=1)
9
- learning_rate: float = Field(1.0e-4, le=1, ge=1.0e-8)
10
-
11
-
12
- class WandbIntegration(BaseModel):
13
- type: Literal["wandb"] = "wandb"
14
- project: str
15
- name: Union[str, None] = None
16
- run_name: Union[str, None] = None
17
-
18
-
19
- class WandbIntegrationIn(WandbIntegration):
20
- api_key: str
21
-
22
-
23
- Integration = Annotated[Union[WandbIntegration], Field(discriminator="type")]
24
- IntegrationIn = Annotated[Union[WandbIntegrationIn], Field(discriminator="type")]
25
-
26
-
27
- class JobMetadata(BaseModel):
28
- object: Literal["job.metadata"] = "job.metadata"
29
- training_steps: int
30
- train_tokens_per_step: int
31
- data_tokens: int
32
- train_tokens: int
33
- epochs: float
34
- expected_duration_seconds: Optional[int]
35
- cost: Optional[float] = None
36
- cost_currency: Optional[str] = None
37
-
38
-
39
- class Job(BaseModel):
40
- id: str
41
- hyperparameters: TrainingParameters
42
- fine_tuned_model: Union[str, None]
43
- model: str
44
- status: Literal[
45
- "QUEUED",
46
- "STARTED",
47
- "RUNNING",
48
- "FAILED",
49
- "SUCCESS",
50
- "CANCELLED",
51
- "CANCELLATION_REQUESTED",
52
- ]
53
- job_type: str
54
- created_at: int
55
- modified_at: int
56
- training_files: list[str]
57
- validation_files: Union[list[str], None] = []
58
- object: Literal["job"]
59
- integrations: List[Integration] = []
60
-
61
-
62
- class Event(BaseModel):
63
- name: str
64
- data: Union[dict, None] = None
65
- created_at: int
66
-
67
-
68
- class Metric(BaseModel):
69
- train_loss: Union[float, None] = None
70
- valid_loss: Union[float, None] = None
71
- valid_mean_token_accuracy: Union[float, None] = None
72
-
73
-
74
- class Checkpoint(BaseModel):
75
- metrics: Metric
76
- step_number: int
77
- created_at: int
78
-
79
-
80
- class JobQueryFilter(BaseModel):
81
- page: int = 0
82
- page_size: int = 100
83
- model: Optional[str] = None
84
- created_after: Optional[datetime] = None
85
- created_by_me: Optional[bool] = None
86
- status: Optional[str] = None
87
- wandb_project: Optional[str] = None
88
- wandb_name: Optional[str] = None
89
- suffix: Optional[str] = None
90
-
91
-
92
- class DetailedJob(Job):
93
- events: list[Event] = []
94
- checkpoints: list[Checkpoint] = []
95
- estimated_start_time: Optional[int] = None
96
-
97
-
98
- class Jobs(BaseModel):
99
- data: list[Job] = []
100
- object: Literal["list"]