mistralai 0.4.1__py3-none-any.whl → 0.5.5a50__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (240) hide show
  1. mistralai/__init__.py +5 -0
  2. mistralai/_hooks/__init__.py +5 -0
  3. mistralai/_hooks/custom_user_agent.py +16 -0
  4. mistralai/_hooks/deprecation_warning.py +26 -0
  5. mistralai/_hooks/registration.py +17 -0
  6. mistralai/_hooks/sdkhooks.py +57 -0
  7. mistralai/_hooks/types.py +76 -0
  8. mistralai/async_client.py +5 -374
  9. mistralai/basesdk.py +216 -0
  10. mistralai/chat.py +475 -0
  11. mistralai/client.py +5 -372
  12. mistralai/embeddings.py +182 -0
  13. mistralai/files.py +600 -84
  14. mistralai/fim.py +439 -0
  15. mistralai/fine_tuning.py +855 -0
  16. mistralai/httpclient.py +78 -0
  17. mistralai/models/__init__.py +80 -0
  18. mistralai/models/archiveftmodelout.py +19 -0
  19. mistralai/models/assistantmessage.py +58 -0
  20. mistralai/models/chatcompletionchoice.py +33 -0
  21. mistralai/models/chatcompletionrequest.py +114 -0
  22. mistralai/models/chatcompletionresponse.py +27 -0
  23. mistralai/models/chatcompletionstreamrequest.py +112 -0
  24. mistralai/models/checkpointout.py +25 -0
  25. mistralai/models/completionchunk.py +27 -0
  26. mistralai/models/completionevent.py +15 -0
  27. mistralai/models/completionresponsestreamchoice.py +53 -0
  28. mistralai/models/contentchunk.py +17 -0
  29. mistralai/models/delete_model_v1_models_model_id_deleteop.py +16 -0
  30. mistralai/models/deletefileout.py +24 -0
  31. mistralai/models/deletemodelout.py +25 -0
  32. mistralai/models/deltamessage.py +52 -0
  33. mistralai/models/detailedjobout.py +96 -0
  34. mistralai/models/embeddingrequest.py +66 -0
  35. mistralai/models/embeddingresponse.py +24 -0
  36. mistralai/models/embeddingresponsedata.py +19 -0
  37. mistralai/models/eventout.py +55 -0
  38. mistralai/models/files_api_routes_delete_fileop.py +16 -0
  39. mistralai/models/files_api_routes_retrieve_fileop.py +16 -0
  40. mistralai/models/files_api_routes_upload_fileop.py +51 -0
  41. mistralai/models/fileschema.py +76 -0
  42. mistralai/models/fimcompletionrequest.py +99 -0
  43. mistralai/models/fimcompletionresponse.py +27 -0
  44. mistralai/models/fimcompletionstreamrequest.py +97 -0
  45. mistralai/models/finetuneablemodel.py +8 -0
  46. mistralai/models/ftmodelcapabilitiesout.py +21 -0
  47. mistralai/models/ftmodelout.py +70 -0
  48. mistralai/models/function.py +19 -0
  49. mistralai/models/functioncall.py +16 -0
  50. mistralai/models/githubrepositoryin.py +57 -0
  51. mistralai/models/githubrepositoryout.py +57 -0
  52. mistralai/models/httpvalidationerror.py +23 -0
  53. mistralai/models/jobin.py +78 -0
  54. mistralai/models/jobmetadataout.py +59 -0
  55. mistralai/models/jobout.py +112 -0
  56. mistralai/models/jobs_api_routes_fine_tuning_archive_fine_tuned_modelop.py +16 -0
  57. mistralai/models/jobs_api_routes_fine_tuning_cancel_fine_tuning_jobop.py +18 -0
  58. mistralai/models/jobs_api_routes_fine_tuning_create_fine_tuning_jobop.py +73 -0
  59. mistralai/models/jobs_api_routes_fine_tuning_get_fine_tuning_jobop.py +18 -0
  60. mistralai/models/jobs_api_routes_fine_tuning_get_fine_tuning_jobsop.py +86 -0
  61. mistralai/models/jobs_api_routes_fine_tuning_start_fine_tuning_jobop.py +16 -0
  62. mistralai/models/jobs_api_routes_fine_tuning_unarchive_fine_tuned_modelop.py +16 -0
  63. mistralai/models/jobs_api_routes_fine_tuning_update_fine_tuned_modelop.py +19 -0
  64. mistralai/models/jobsout.py +20 -0
  65. mistralai/models/legacyjobmetadataout.py +85 -0
  66. mistralai/models/listfilesout.py +17 -0
  67. mistralai/models/metricout.py +55 -0
  68. mistralai/models/modelcapabilities.py +21 -0
  69. mistralai/models/modelcard.py +71 -0
  70. mistralai/models/modellist.py +18 -0
  71. mistralai/models/responseformat.py +18 -0
  72. mistralai/models/retrieve_model_v1_models_model_id_getop.py +16 -0
  73. mistralai/models/retrievefileout.py +76 -0
  74. mistralai/models/sampletype.py +7 -0
  75. mistralai/models/sdkerror.py +22 -0
  76. mistralai/models/security.py +16 -0
  77. mistralai/models/source.py +7 -0
  78. mistralai/models/systemmessage.py +26 -0
  79. mistralai/models/textchunk.py +17 -0
  80. mistralai/models/tool.py +18 -0
  81. mistralai/models/toolcall.py +20 -0
  82. mistralai/models/toolmessage.py +55 -0
  83. mistralai/models/trainingfile.py +17 -0
  84. mistralai/models/trainingparameters.py +53 -0
  85. mistralai/models/trainingparametersin.py +61 -0
  86. mistralai/models/unarchiveftmodelout.py +19 -0
  87. mistralai/models/updateftmodelin.py +49 -0
  88. mistralai/models/uploadfileout.py +76 -0
  89. mistralai/models/usageinfo.py +18 -0
  90. mistralai/models/usermessage.py +26 -0
  91. mistralai/models/validationerror.py +24 -0
  92. mistralai/models/wandbintegration.py +61 -0
  93. mistralai/models/wandbintegrationout.py +57 -0
  94. mistralai/models_.py +928 -0
  95. mistralai/py.typed +1 -0
  96. mistralai/sdk.py +111 -0
  97. mistralai/sdkconfiguration.py +53 -0
  98. mistralai/types/__init__.py +21 -0
  99. mistralai/types/basemodel.py +35 -0
  100. mistralai/utils/__init__.py +82 -0
  101. mistralai/utils/annotations.py +19 -0
  102. mistralai/utils/enums.py +34 -0
  103. mistralai/utils/eventstreaming.py +179 -0
  104. mistralai/utils/forms.py +207 -0
  105. mistralai/utils/headers.py +136 -0
  106. mistralai/utils/metadata.py +118 -0
  107. mistralai/utils/queryparams.py +203 -0
  108. mistralai/utils/requestbodies.py +66 -0
  109. mistralai/utils/retries.py +216 -0
  110. mistralai/utils/security.py +182 -0
  111. mistralai/utils/serializers.py +181 -0
  112. mistralai/utils/url.py +150 -0
  113. mistralai/utils/values.py +128 -0
  114. {mistralai-0.4.1.dist-info → mistralai-0.5.5a50.dist-info}/LICENSE +1 -1
  115. mistralai-0.5.5a50.dist-info/METADATA +626 -0
  116. mistralai-0.5.5a50.dist-info/RECORD +228 -0
  117. mistralai_azure/__init__.py +5 -0
  118. mistralai_azure/_hooks/__init__.py +5 -0
  119. mistralai_azure/_hooks/custom_user_agent.py +16 -0
  120. mistralai_azure/_hooks/registration.py +15 -0
  121. mistralai_azure/_hooks/sdkhooks.py +57 -0
  122. mistralai_azure/_hooks/types.py +76 -0
  123. mistralai_azure/basesdk.py +215 -0
  124. mistralai_azure/chat.py +475 -0
  125. mistralai_azure/httpclient.py +78 -0
  126. mistralai_azure/models/__init__.py +28 -0
  127. mistralai_azure/models/assistantmessage.py +58 -0
  128. mistralai_azure/models/chatcompletionchoice.py +33 -0
  129. mistralai_azure/models/chatcompletionrequest.py +114 -0
  130. mistralai_azure/models/chatcompletionresponse.py +27 -0
  131. mistralai_azure/models/chatcompletionstreamrequest.py +112 -0
  132. mistralai_azure/models/completionchunk.py +27 -0
  133. mistralai_azure/models/completionevent.py +15 -0
  134. mistralai_azure/models/completionresponsestreamchoice.py +53 -0
  135. mistralai_azure/models/contentchunk.py +17 -0
  136. mistralai_azure/models/deltamessage.py +52 -0
  137. mistralai_azure/models/function.py +19 -0
  138. mistralai_azure/models/functioncall.py +16 -0
  139. mistralai_azure/models/httpvalidationerror.py +23 -0
  140. mistralai_azure/models/responseformat.py +18 -0
  141. mistralai_azure/models/sdkerror.py +22 -0
  142. mistralai_azure/models/security.py +16 -0
  143. mistralai_azure/models/systemmessage.py +26 -0
  144. mistralai_azure/models/textchunk.py +17 -0
  145. mistralai_azure/models/tool.py +18 -0
  146. mistralai_azure/models/toolcall.py +20 -0
  147. mistralai_azure/models/toolmessage.py +55 -0
  148. mistralai_azure/models/usageinfo.py +18 -0
  149. mistralai_azure/models/usermessage.py +26 -0
  150. mistralai_azure/models/validationerror.py +24 -0
  151. mistralai_azure/py.typed +1 -0
  152. mistralai_azure/sdk.py +102 -0
  153. mistralai_azure/sdkconfiguration.py +53 -0
  154. mistralai_azure/types/__init__.py +21 -0
  155. mistralai_azure/types/basemodel.py +35 -0
  156. mistralai_azure/utils/__init__.py +80 -0
  157. mistralai_azure/utils/annotations.py +19 -0
  158. mistralai_azure/utils/enums.py +34 -0
  159. mistralai_azure/utils/eventstreaming.py +179 -0
  160. mistralai_azure/utils/forms.py +207 -0
  161. mistralai_azure/utils/headers.py +136 -0
  162. mistralai_azure/utils/metadata.py +118 -0
  163. mistralai_azure/utils/queryparams.py +203 -0
  164. mistralai_azure/utils/requestbodies.py +66 -0
  165. mistralai_azure/utils/retries.py +216 -0
  166. mistralai_azure/utils/security.py +168 -0
  167. mistralai_azure/utils/serializers.py +181 -0
  168. mistralai_azure/utils/url.py +150 -0
  169. mistralai_azure/utils/values.py +128 -0
  170. mistralai_gcp/__init__.py +5 -0
  171. mistralai_gcp/_hooks/__init__.py +5 -0
  172. mistralai_gcp/_hooks/custom_user_agent.py +16 -0
  173. mistralai_gcp/_hooks/registration.py +15 -0
  174. mistralai_gcp/_hooks/sdkhooks.py +57 -0
  175. mistralai_gcp/_hooks/types.py +76 -0
  176. mistralai_gcp/basesdk.py +215 -0
  177. mistralai_gcp/chat.py +463 -0
  178. mistralai_gcp/fim.py +439 -0
  179. mistralai_gcp/httpclient.py +78 -0
  180. mistralai_gcp/models/__init__.py +31 -0
  181. mistralai_gcp/models/assistantmessage.py +58 -0
  182. mistralai_gcp/models/chatcompletionchoice.py +33 -0
  183. mistralai_gcp/models/chatcompletionrequest.py +110 -0
  184. mistralai_gcp/models/chatcompletionresponse.py +27 -0
  185. mistralai_gcp/models/chatcompletionstreamrequest.py +108 -0
  186. mistralai_gcp/models/completionchunk.py +27 -0
  187. mistralai_gcp/models/completionevent.py +15 -0
  188. mistralai_gcp/models/completionresponsestreamchoice.py +53 -0
  189. mistralai_gcp/models/contentchunk.py +17 -0
  190. mistralai_gcp/models/deltamessage.py +52 -0
  191. mistralai_gcp/models/fimcompletionrequest.py +99 -0
  192. mistralai_gcp/models/fimcompletionresponse.py +27 -0
  193. mistralai_gcp/models/fimcompletionstreamrequest.py +97 -0
  194. mistralai_gcp/models/function.py +19 -0
  195. mistralai_gcp/models/functioncall.py +16 -0
  196. mistralai_gcp/models/httpvalidationerror.py +23 -0
  197. mistralai_gcp/models/responseformat.py +18 -0
  198. mistralai_gcp/models/sdkerror.py +22 -0
  199. mistralai_gcp/models/security.py +16 -0
  200. mistralai_gcp/models/systemmessage.py +26 -0
  201. mistralai_gcp/models/textchunk.py +17 -0
  202. mistralai_gcp/models/tool.py +18 -0
  203. mistralai_gcp/models/toolcall.py +20 -0
  204. mistralai_gcp/models/toolmessage.py +55 -0
  205. mistralai_gcp/models/usageinfo.py +18 -0
  206. mistralai_gcp/models/usermessage.py +26 -0
  207. mistralai_gcp/models/validationerror.py +24 -0
  208. mistralai_gcp/py.typed +1 -0
  209. mistralai_gcp/sdk.py +165 -0
  210. mistralai_gcp/sdkconfiguration.py +53 -0
  211. mistralai_gcp/types/__init__.py +21 -0
  212. mistralai_gcp/types/basemodel.py +35 -0
  213. mistralai_gcp/utils/__init__.py +80 -0
  214. mistralai_gcp/utils/annotations.py +19 -0
  215. mistralai_gcp/utils/enums.py +34 -0
  216. mistralai_gcp/utils/eventstreaming.py +179 -0
  217. mistralai_gcp/utils/forms.py +207 -0
  218. mistralai_gcp/utils/headers.py +136 -0
  219. mistralai_gcp/utils/metadata.py +118 -0
  220. mistralai_gcp/utils/queryparams.py +203 -0
  221. mistralai_gcp/utils/requestbodies.py +66 -0
  222. mistralai_gcp/utils/retries.py +216 -0
  223. mistralai_gcp/utils/security.py +168 -0
  224. mistralai_gcp/utils/serializers.py +181 -0
  225. mistralai_gcp/utils/url.py +150 -0
  226. mistralai_gcp/utils/values.py +128 -0
  227. py.typed +1 -0
  228. mistralai/client_base.py +0 -186
  229. mistralai/constants.py +0 -3
  230. mistralai/exceptions.py +0 -54
  231. mistralai/jobs.py +0 -172
  232. mistralai/models/chat_completion.py +0 -93
  233. mistralai/models/common.py +0 -9
  234. mistralai/models/embeddings.py +0 -19
  235. mistralai/models/files.py +0 -23
  236. mistralai/models/jobs.py +0 -98
  237. mistralai/models/models.py +0 -39
  238. mistralai-0.4.1.dist-info/METADATA +0 -80
  239. mistralai-0.4.1.dist-info/RECORD +0 -20
  240. {mistralai-0.4.1.dist-info → mistralai-0.5.5a50.dist-info}/WHEEL +0 -0
mistralai/client_base.py DELETED
@@ -1,186 +0,0 @@
1
- import logging
2
- import os
3
- from abc import ABC
4
- from typing import Any, Dict, List, Optional, Union
5
-
6
- import orjson
7
-
8
- from mistralai.exceptions import (
9
- MistralException,
10
- )
11
- from mistralai.models.chat_completion import ChatMessage, Function, ResponseFormat, ToolChoice
12
-
13
- CLIENT_VERSION = "0.4.1"
14
-
15
-
16
- class ClientBase(ABC):
17
- def __init__(
18
- self,
19
- endpoint: str,
20
- api_key: Optional[str] = None,
21
- max_retries: int = 5,
22
- timeout: int = 120,
23
- ):
24
- self._max_retries = max_retries
25
- self._timeout = timeout
26
-
27
- if api_key is None:
28
- api_key = os.environ.get("MISTRAL_API_KEY")
29
- if api_key is None:
30
- raise MistralException(message="API key not provided. Please set MISTRAL_API_KEY environment variable.")
31
- self._api_key = api_key
32
- self._endpoint = endpoint
33
- self._logger = logging.getLogger(__name__)
34
-
35
- # For azure endpoints, we default to the mistral model
36
- if "inference.azure.com" in self._endpoint:
37
- self._default_model = "mistral"
38
-
39
- self._version = CLIENT_VERSION
40
-
41
- def _parse_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
42
- parsed_tools: List[Dict[str, Any]] = []
43
- for tool in tools:
44
- if tool["type"] == "function":
45
- parsed_function = {}
46
- parsed_function["type"] = tool["type"]
47
- if isinstance(tool["function"], Function):
48
- parsed_function["function"] = tool["function"].model_dump(exclude_none=True)
49
- else:
50
- parsed_function["function"] = tool["function"]
51
-
52
- parsed_tools.append(parsed_function)
53
-
54
- return parsed_tools
55
-
56
- def _parse_tool_choice(self, tool_choice: Union[str, ToolChoice]) -> str:
57
- if isinstance(tool_choice, ToolChoice):
58
- return tool_choice.value
59
- return tool_choice
60
-
61
- def _parse_response_format(self, response_format: Union[Dict[str, Any], ResponseFormat]) -> Dict[str, Any]:
62
- if isinstance(response_format, ResponseFormat):
63
- return response_format.model_dump(exclude_none=True)
64
- return response_format
65
-
66
- def _parse_messages(self, messages: List[Any]) -> List[Dict[str, Any]]:
67
- parsed_messages: List[Dict[str, Any]] = []
68
- for message in messages:
69
- if isinstance(message, ChatMessage):
70
- parsed_messages.append(message.model_dump(exclude_none=True))
71
- else:
72
- parsed_messages.append(message)
73
-
74
- return parsed_messages
75
-
76
- def _make_completion_request(
77
- self,
78
- prompt: str,
79
- model: Optional[str] = None,
80
- suffix: Optional[str] = None,
81
- temperature: Optional[float] = None,
82
- max_tokens: Optional[int] = None,
83
- top_p: Optional[float] = None,
84
- random_seed: Optional[int] = None,
85
- stop: Optional[List[str]] = None,
86
- stream: Optional[bool] = False,
87
- ) -> Dict[str, Any]:
88
- request_data: Dict[str, Any] = {
89
- "prompt": prompt,
90
- "suffix": suffix,
91
- "model": model,
92
- "stream": stream,
93
- }
94
-
95
- if stop is not None:
96
- request_data["stop"] = stop
97
-
98
- if model is not None:
99
- request_data["model"] = model
100
- else:
101
- if self._default_model is None:
102
- raise MistralException(message="model must be provided")
103
- request_data["model"] = self._default_model
104
-
105
- request_data.update(
106
- self._build_sampling_params(
107
- temperature=temperature, max_tokens=max_tokens, top_p=top_p, random_seed=random_seed
108
- )
109
- )
110
-
111
- self._logger.debug(f"Completion request: {request_data}")
112
-
113
- return request_data
114
-
115
- def _build_sampling_params(
116
- self,
117
- max_tokens: Optional[int],
118
- random_seed: Optional[int],
119
- temperature: Optional[float],
120
- top_p: Optional[float],
121
- ) -> Dict[str, Any]:
122
- params = {}
123
- if temperature is not None:
124
- params["temperature"] = temperature
125
- if max_tokens is not None:
126
- params["max_tokens"] = max_tokens
127
- if top_p is not None:
128
- params["top_p"] = top_p
129
- if random_seed is not None:
130
- params["random_seed"] = random_seed
131
- return params
132
-
133
- def _make_chat_request(
134
- self,
135
- messages: List[Any],
136
- model: Optional[str] = None,
137
- tools: Optional[List[Dict[str, Any]]] = None,
138
- temperature: Optional[float] = None,
139
- max_tokens: Optional[int] = None,
140
- top_p: Optional[float] = None,
141
- random_seed: Optional[int] = None,
142
- stream: Optional[bool] = None,
143
- safe_prompt: Optional[bool] = False,
144
- tool_choice: Optional[Union[str, ToolChoice]] = None,
145
- response_format: Optional[Union[Dict[str, str], ResponseFormat]] = None,
146
- ) -> Dict[str, Any]:
147
- request_data: Dict[str, Any] = {
148
- "messages": self._parse_messages(messages),
149
- }
150
-
151
- if model is not None:
152
- request_data["model"] = model
153
- else:
154
- if self._default_model is None:
155
- raise MistralException(message="model must be provided")
156
- request_data["model"] = self._default_model
157
-
158
- request_data.update(
159
- self._build_sampling_params(
160
- temperature=temperature, max_tokens=max_tokens, top_p=top_p, random_seed=random_seed
161
- )
162
- )
163
-
164
- if safe_prompt:
165
- request_data["safe_prompt"] = safe_prompt
166
- if tools is not None:
167
- request_data["tools"] = self._parse_tools(tools)
168
- if stream is not None:
169
- request_data["stream"] = stream
170
-
171
- if tool_choice is not None:
172
- request_data["tool_choice"] = self._parse_tool_choice(tool_choice)
173
- if response_format is not None:
174
- request_data["response_format"] = self._parse_response_format(response_format)
175
-
176
- self._logger.debug(f"Chat request: {request_data}")
177
-
178
- return request_data
179
-
180
- def _process_line(self, line: str) -> Optional[Dict[str, Any]]:
181
- if line.startswith("data: "):
182
- line = line[6:].strip()
183
- if line != "[DONE]":
184
- json_streamed_response: Dict[str, Any] = orjson.loads(line)
185
- return json_streamed_response
186
- return None
mistralai/constants.py DELETED
@@ -1,3 +0,0 @@
1
- RETRY_STATUS_CODES = {429, 500, 502, 503, 504}
2
-
3
- ENDPOINT = "https://api.mistral.ai"
mistralai/exceptions.py DELETED
@@ -1,54 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from typing import Any, Dict, Optional
4
-
5
- from httpx import Response
6
-
7
-
8
- class MistralException(Exception):
9
- """Base Exception class, returned when nothing more specific applies"""
10
-
11
- def __init__(self, message: Optional[str] = None) -> None:
12
- super(MistralException, self).__init__(message)
13
-
14
- self.message = message
15
-
16
- def __str__(self) -> str:
17
- msg = self.message or "<empty message>"
18
- return msg
19
-
20
- def __repr__(self) -> str:
21
- return f"{self.__class__.__name__}(message={str(self)})"
22
-
23
-
24
- class MistralAPIException(MistralException):
25
- """Returned when the API responds with an error message"""
26
-
27
- def __init__(
28
- self,
29
- message: Optional[str] = None,
30
- http_status: Optional[int] = None,
31
- headers: Optional[Dict[str, Any]] = None,
32
- ) -> None:
33
- super().__init__(message)
34
- self.http_status = http_status
35
- self.headers = headers or {}
36
-
37
- @classmethod
38
- def from_response(cls, response: Response, message: Optional[str] = None) -> MistralAPIException:
39
- return cls(
40
- message=message or response.text,
41
- http_status=response.status_code,
42
- headers=dict(response.headers),
43
- )
44
-
45
- def __repr__(self) -> str:
46
- return f"{self.__class__.__name__}(message={str(self)}, http_status={self.http_status})"
47
-
48
-
49
- class MistralAPIStatusException(MistralAPIException):
50
- """Returned when we receive a non-200 response from the API that we should retry"""
51
-
52
-
53
- class MistralConnectionException(MistralException):
54
- """Returned when the SDK can not reach the API server for any reason"""
mistralai/jobs.py DELETED
@@ -1,172 +0,0 @@
1
- from datetime import datetime
2
- from typing import Any, Optional, Union
3
-
4
- from mistralai.exceptions import (
5
- MistralException,
6
- )
7
- from mistralai.models.jobs import DetailedJob, IntegrationIn, Job, JobMetadata, JobQueryFilter, Jobs, TrainingParameters
8
-
9
-
10
- class JobsClient:
11
- def __init__(self, client: Any):
12
- self.client = client
13
-
14
- def create(
15
- self,
16
- model: str,
17
- training_files: Union[list[str], None] = None,
18
- validation_files: Union[list[str], None] = None,
19
- hyperparameters: TrainingParameters = TrainingParameters(
20
- training_steps=1800,
21
- learning_rate=1.0e-4,
22
- ),
23
- suffix: Union[str, None] = None,
24
- integrations: Union[set[IntegrationIn], None] = None,
25
- training_file: Union[str, None] = None, # Deprecated: Added for compatibility with OpenAI API
26
- validation_file: Union[str, None] = None, # Deprecated: Added for compatibility with OpenAI API
27
- dry_run: bool = False,
28
- ) -> Union[Job, JobMetadata]:
29
- # Handle deprecated arguments
30
- if not training_files and training_file:
31
- training_files = [training_file]
32
- if not validation_files and validation_file:
33
- validation_files = [validation_file]
34
- single_response = self.client._request(
35
- method="post",
36
- json={
37
- "model": model,
38
- "training_files": training_files,
39
- "validation_files": validation_files,
40
- "hyperparameters": hyperparameters.dict(),
41
- "suffix": suffix,
42
- "integrations": integrations,
43
- },
44
- path="v1/fine_tuning/jobs",
45
- params={"dry_run": dry_run},
46
- )
47
- for response in single_response:
48
- return Job(**response) if not dry_run else JobMetadata(**response)
49
- raise MistralException("No response received")
50
-
51
- def retrieve(self, job_id: str) -> DetailedJob:
52
- single_response = self.client._request(method="get", path=f"v1/fine_tuning/jobs/{job_id}", json={})
53
- for response in single_response:
54
- return DetailedJob(**response)
55
- raise MistralException("No response received")
56
-
57
- def list(
58
- self,
59
- page: int = 0,
60
- page_size: int = 10,
61
- model: Optional[str] = None,
62
- created_after: Optional[datetime] = None,
63
- created_by_me: Optional[bool] = None,
64
- status: Optional[str] = None,
65
- wandb_project: Optional[str] = None,
66
- wandb_name: Optional[str] = None,
67
- suffix: Optional[str] = None,
68
- ) -> Jobs:
69
- query_params = JobQueryFilter(
70
- page=page,
71
- page_size=page_size,
72
- model=model,
73
- created_after=created_after,
74
- created_by_me=created_by_me,
75
- status=status,
76
- wandb_project=wandb_project,
77
- wandb_name=wandb_name,
78
- suffix=suffix,
79
- ).model_dump(exclude_none=True)
80
- single_response = self.client._request(method="get", params=query_params, path="v1/fine_tuning/jobs", json={})
81
- for response in single_response:
82
- return Jobs(**response)
83
- raise MistralException("No response received")
84
-
85
- def cancel(self, job_id: str) -> DetailedJob:
86
- single_response = self.client._request(method="post", path=f"v1/fine_tuning/jobs/{job_id}/cancel", json={})
87
- for response in single_response:
88
- return DetailedJob(**response)
89
- raise MistralException("No response received")
90
-
91
-
92
- class JobsAsyncClient:
93
- def __init__(self, client: Any):
94
- self.client = client
95
-
96
- async def create(
97
- self,
98
- model: str,
99
- training_files: Union[list[str], None] = None,
100
- validation_files: Union[list[str], None] = None,
101
- hyperparameters: TrainingParameters = TrainingParameters(
102
- training_steps=1800,
103
- learning_rate=1.0e-4,
104
- ),
105
- suffix: Union[str, None] = None,
106
- integrations: Union[set[IntegrationIn], None] = None,
107
- training_file: Union[str, None] = None, # Deprecated: Added for compatibility with OpenAI API
108
- validation_file: Union[str, None] = None, # Deprecated: Added for compatibility with OpenAI API
109
- dry_run: bool = False,
110
- ) -> Union[Job, JobMetadata]:
111
- # Handle deprecated arguments
112
- if not training_files and training_file:
113
- training_files = [training_file]
114
- if not validation_files and validation_file:
115
- validation_files = [validation_file]
116
-
117
- single_response = self.client._request(
118
- method="post",
119
- json={
120
- "model": model,
121
- "training_files": training_files,
122
- "validation_files": validation_files,
123
- "hyperparameters": hyperparameters.dict(),
124
- "suffix": suffix,
125
- "integrations": integrations,
126
- },
127
- path="v1/fine_tuning/jobs",
128
- params={"dry_run": dry_run},
129
- )
130
- async for response in single_response:
131
- return Job(**response) if not dry_run else JobMetadata(**response)
132
- raise MistralException("No response received")
133
-
134
- async def retrieve(self, job_id: str) -> DetailedJob:
135
- single_response = self.client._request(method="get", path=f"v1/fine_tuning/jobs/{job_id}", json={})
136
- async for response in single_response:
137
- return DetailedJob(**response)
138
- raise MistralException("No response received")
139
-
140
- async def list(
141
- self,
142
- page: int = 0,
143
- page_size: int = 10,
144
- model: Optional[str] = None,
145
- created_after: Optional[datetime] = None,
146
- created_by_me: Optional[bool] = None,
147
- status: Optional[str] = None,
148
- wandb_project: Optional[str] = None,
149
- wandb_name: Optional[str] = None,
150
- suffix: Optional[str] = None,
151
- ) -> Jobs:
152
- query_params = JobQueryFilter(
153
- page=page,
154
- page_size=page_size,
155
- model=model,
156
- created_after=created_after,
157
- created_by_me=created_by_me,
158
- status=status,
159
- wandb_project=wandb_project,
160
- wandb_name=wandb_name,
161
- suffix=suffix,
162
- ).model_dump(exclude_none=True)
163
- single_response = self.client._request(method="get", path="v1/fine_tuning/jobs", params=query_params, json={})
164
- async for response in single_response:
165
- return Jobs(**response)
166
- raise MistralException("No response received")
167
-
168
- async def cancel(self, job_id: str) -> DetailedJob:
169
- single_response = self.client._request(method="post", path=f"v1/fine_tuning/jobs/{job_id}/cancel", json={})
170
- async for response in single_response:
171
- return DetailedJob(**response)
172
- raise MistralException("No response received")
@@ -1,93 +0,0 @@
1
- from enum import Enum
2
- from typing import List, Optional
3
-
4
- from pydantic import BaseModel
5
-
6
- from mistralai.models.common import UsageInfo
7
-
8
-
9
- class Function(BaseModel):
10
- name: str
11
- description: str
12
- parameters: dict
13
-
14
-
15
- class ToolType(str, Enum):
16
- function = "function"
17
-
18
-
19
- class FunctionCall(BaseModel):
20
- name: str
21
- arguments: str
22
-
23
-
24
- class ToolCall(BaseModel):
25
- id: str = "null"
26
- type: ToolType = ToolType.function
27
- function: FunctionCall
28
-
29
-
30
- class ResponseFormats(str, Enum):
31
- text: str = "text"
32
- json_object: str = "json_object"
33
-
34
-
35
- class ToolChoice(str, Enum):
36
- auto: str = "auto"
37
- any: str = "any"
38
- none: str = "none"
39
-
40
-
41
- class ResponseFormat(BaseModel):
42
- type: ResponseFormats = ResponseFormats.text
43
-
44
-
45
- class ChatMessage(BaseModel):
46
- role: str
47
- content: str
48
- name: Optional[str] = None
49
- tool_calls: Optional[List[ToolCall]] = None
50
- tool_call_id: Optional[str] = None
51
-
52
-
53
- class DeltaMessage(BaseModel):
54
- role: Optional[str] = None
55
- content: Optional[str] = None
56
- tool_calls: Optional[List[ToolCall]] = None
57
-
58
-
59
- class FinishReason(str, Enum):
60
- stop = "stop"
61
- length = "length"
62
- error = "error"
63
- tool_calls = "tool_calls"
64
-
65
-
66
- class ChatCompletionResponseStreamChoice(BaseModel):
67
- index: int
68
- delta: DeltaMessage
69
- finish_reason: Optional[FinishReason]
70
-
71
-
72
- class ChatCompletionStreamResponse(BaseModel):
73
- id: str
74
- model: str
75
- choices: List[ChatCompletionResponseStreamChoice]
76
- created: Optional[int] = None
77
- object: Optional[str] = None
78
- usage: Optional[UsageInfo] = None
79
-
80
-
81
- class ChatCompletionResponseChoice(BaseModel):
82
- index: int
83
- message: ChatMessage
84
- finish_reason: Optional[FinishReason]
85
-
86
-
87
- class ChatCompletionResponse(BaseModel):
88
- id: str
89
- object: str
90
- created: int
91
- model: str
92
- choices: List[ChatCompletionResponseChoice]
93
- usage: UsageInfo
@@ -1,9 +0,0 @@
1
- from typing import Optional
2
-
3
- from pydantic import BaseModel
4
-
5
-
6
- class UsageInfo(BaseModel):
7
- prompt_tokens: int
8
- total_tokens: int
9
- completion_tokens: Optional[int]
@@ -1,19 +0,0 @@
1
- from typing import List
2
-
3
- from pydantic import BaseModel
4
-
5
- from mistralai.models.common import UsageInfo
6
-
7
-
8
- class EmbeddingObject(BaseModel):
9
- object: str
10
- embedding: List[float]
11
- index: int
12
-
13
-
14
- class EmbeddingResponse(BaseModel):
15
- id: str
16
- object: str
17
- data: List[EmbeddingObject]
18
- model: str
19
- usage: UsageInfo
mistralai/models/files.py DELETED
@@ -1,23 +0,0 @@
1
- from typing import Literal, Optional
2
-
3
- from pydantic import BaseModel
4
-
5
-
6
- class FileObject(BaseModel):
7
- id: str
8
- object: str
9
- bytes: int
10
- created_at: int
11
- filename: str
12
- purpose: Optional[Literal["fine-tune"]] = "fine-tune"
13
-
14
-
15
- class FileDeleted(BaseModel):
16
- id: str
17
- object: str
18
- deleted: bool
19
-
20
-
21
- class Files(BaseModel):
22
- data: list[FileObject]
23
- object: Literal["list"]
mistralai/models/jobs.py DELETED
@@ -1,98 +0,0 @@
1
- from datetime import datetime
2
- from typing import Annotated, List, Literal, Optional, Union
3
-
4
- from pydantic import BaseModel, Field
5
-
6
-
7
- class TrainingParameters(BaseModel):
8
- training_steps: int = Field(1800, le=10000, ge=1)
9
- learning_rate: float = Field(1.0e-4, le=1, ge=1.0e-8)
10
-
11
-
12
- class WandbIntegration(BaseModel):
13
- type: Literal["wandb"] = "wandb"
14
- project: str
15
- name: Union[str, None] = None
16
- run_name: Union[str, None] = None
17
-
18
-
19
- class WandbIntegrationIn(WandbIntegration):
20
- api_key: str
21
-
22
-
23
- Integration = Annotated[Union[WandbIntegration], Field(discriminator="type")]
24
- IntegrationIn = Annotated[Union[WandbIntegrationIn], Field(discriminator="type")]
25
-
26
-
27
- class JobMetadata(BaseModel):
28
- object: Literal["job.metadata"] = "job.metadata"
29
- training_steps: int
30
- train_tokens_per_step: int
31
- data_tokens: int
32
- train_tokens: int
33
- epochs: float
34
- expected_duration_seconds: Optional[int]
35
-
36
-
37
- class Job(BaseModel):
38
- id: str
39
- hyperparameters: TrainingParameters
40
- fine_tuned_model: Union[str, None]
41
- model: str
42
- status: Literal[
43
- "QUEUED",
44
- "STARTED",
45
- "RUNNING",
46
- "FAILED",
47
- "SUCCESS",
48
- "CANCELLED",
49
- "CANCELLATION_REQUESTED",
50
- ]
51
- job_type: str
52
- created_at: int
53
- modified_at: int
54
- training_files: list[str]
55
- validation_files: Union[list[str], None] = []
56
- object: Literal["job"]
57
- integrations: List[Integration] = []
58
-
59
-
60
- class Event(BaseModel):
61
- name: str
62
- data: Union[dict, None] = None
63
- created_at: int
64
-
65
-
66
- class Metric(BaseModel):
67
- train_loss: Union[float, None] = None
68
- valid_loss: Union[float, None] = None
69
- valid_mean_token_accuracy: Union[float, None] = None
70
-
71
-
72
- class Checkpoint(BaseModel):
73
- metrics: Metric
74
- step_number: int
75
- created_at: int
76
-
77
-
78
- class JobQueryFilter(BaseModel):
79
- page: int = 0
80
- page_size: int = 100
81
- model: Optional[str] = None
82
- created_after: Optional[datetime] = None
83
- created_by_me: Optional[bool] = None
84
- status: Optional[str] = None
85
- wandb_project: Optional[str] = None
86
- wandb_name: Optional[str] = None
87
- suffix: Optional[str] = None
88
-
89
-
90
- class DetailedJob(Job):
91
- events: list[Event] = []
92
- checkpoints: list[Checkpoint] = []
93
- estimated_start_time: Optional[int] = None
94
-
95
-
96
- class Jobs(BaseModel):
97
- data: list[Job] = []
98
- object: Literal["list"]
@@ -1,39 +0,0 @@
1
- from typing import List, Optional
2
-
3
- from pydantic import BaseModel
4
-
5
-
6
- class ModelPermission(BaseModel):
7
- id: str
8
- object: str
9
- created: int
10
- allow_create_engine: Optional[bool] = False
11
- allow_sampling: bool = True
12
- allow_logprobs: bool = True
13
- allow_search_indices: Optional[bool] = False
14
- allow_view: bool = True
15
- allow_fine_tuning: bool = False
16
- organization: str = "*"
17
- group: Optional[str] = None
18
- is_blocking: Optional[bool] = False
19
-
20
-
21
- class ModelCard(BaseModel):
22
- id: str
23
- object: str
24
- created: int
25
- owned_by: str
26
- root: Optional[str] = None
27
- parent: Optional[str] = None
28
- permission: List[ModelPermission] = []
29
-
30
-
31
- class ModelList(BaseModel):
32
- object: str
33
- data: List[ModelCard]
34
-
35
-
36
- class ModelDeleted(BaseModel):
37
- id: str
38
- object: str
39
- deleted: bool