mistralai 1.0.3__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (230)
  1. mistralai/__init__.py +4 -0
  2. mistralai/_hooks/sdkhooks.py +23 -4
  3. mistralai/_hooks/types.py +27 -9
  4. mistralai/_version.py +12 -0
  5. mistralai/agents.py +334 -164
  6. mistralai/basesdk.py +90 -5
  7. mistralai/batch.py +17 -0
  8. mistralai/chat.py +316 -166
  9. mistralai/classifiers.py +396 -0
  10. mistralai/embeddings.py +79 -55
  11. mistralai/files.py +487 -194
  12. mistralai/fim.py +206 -132
  13. mistralai/fine_tuning.py +3 -2
  14. mistralai/jobs.py +392 -263
  15. mistralai/mistral_jobs.py +733 -0
  16. mistralai/models/__init__.py +593 -50
  17. mistralai/models/agentscompletionrequest.py +70 -17
  18. mistralai/models/agentscompletionstreamrequest.py +72 -17
  19. mistralai/models/apiendpoint.py +9 -0
  20. mistralai/models/archiveftmodelout.py +15 -5
  21. mistralai/models/assistantmessage.py +22 -10
  22. mistralai/models/{modelcard.py → basemodelcard.py} +53 -14
  23. mistralai/models/batcherror.py +17 -0
  24. mistralai/models/batchjobin.py +58 -0
  25. mistralai/models/batchjobout.py +117 -0
  26. mistralai/models/batchjobsout.py +30 -0
  27. mistralai/models/batchjobstatus.py +15 -0
  28. mistralai/models/chatclassificationrequest.py +104 -0
  29. mistralai/models/chatcompletionchoice.py +13 -6
  30. mistralai/models/chatcompletionrequest.py +86 -21
  31. mistralai/models/chatcompletionresponse.py +8 -4
  32. mistralai/models/chatcompletionstreamrequest.py +88 -21
  33. mistralai/models/checkpointout.py +4 -3
  34. mistralai/models/classificationobject.py +21 -0
  35. mistralai/models/classificationrequest.py +59 -0
  36. mistralai/models/classificationresponse.py +21 -0
  37. mistralai/models/completionchunk.py +12 -5
  38. mistralai/models/completionevent.py +2 -3
  39. mistralai/models/completionresponsestreamchoice.py +22 -8
  40. mistralai/models/contentchunk.py +13 -10
  41. mistralai/models/delete_model_v1_models_model_id_deleteop.py +5 -5
  42. mistralai/models/deletefileout.py +4 -3
  43. mistralai/models/deletemodelout.py +5 -4
  44. mistralai/models/deltamessage.py +23 -11
  45. mistralai/models/detailedjobout.py +70 -12
  46. mistralai/models/embeddingrequest.py +14 -9
  47. mistralai/models/embeddingresponse.py +7 -3
  48. mistralai/models/embeddingresponsedata.py +5 -4
  49. mistralai/models/eventout.py +11 -6
  50. mistralai/models/filepurpose.py +8 -0
  51. mistralai/models/files_api_routes_delete_fileop.py +5 -5
  52. mistralai/models/files_api_routes_download_fileop.py +16 -0
  53. mistralai/models/files_api_routes_list_filesop.py +96 -0
  54. mistralai/models/files_api_routes_retrieve_fileop.py +5 -5
  55. mistralai/models/files_api_routes_upload_fileop.py +33 -14
  56. mistralai/models/fileschema.py +22 -15
  57. mistralai/models/fimcompletionrequest.py +44 -16
  58. mistralai/models/fimcompletionresponse.py +8 -4
  59. mistralai/models/fimcompletionstreamrequest.py +44 -16
  60. mistralai/models/finetuneablemodel.py +7 -1
  61. mistralai/models/ftmodelcapabilitiesout.py +6 -4
  62. mistralai/models/ftmodelcard.py +121 -0
  63. mistralai/models/ftmodelout.py +39 -9
  64. mistralai/models/function.py +5 -4
  65. mistralai/models/functioncall.py +4 -3
  66. mistralai/models/functionname.py +17 -0
  67. mistralai/models/githubrepositoryin.py +24 -7
  68. mistralai/models/githubrepositoryout.py +24 -7
  69. mistralai/models/httpvalidationerror.py +1 -3
  70. mistralai/models/imageurl.py +47 -0
  71. mistralai/models/imageurlchunk.py +38 -0
  72. mistralai/models/jobin.py +24 -7
  73. mistralai/models/jobmetadataout.py +32 -8
  74. mistralai/models/jobout.py +65 -12
  75. mistralai/models/jobs_api_routes_batch_cancel_batch_jobop.py +16 -0
  76. mistralai/models/jobs_api_routes_batch_get_batch_jobop.py +16 -0
  77. mistralai/models/jobs_api_routes_batch_get_batch_jobsop.py +95 -0
  78. mistralai/models/jobs_api_routes_fine_tuning_archive_fine_tuned_modelop.py +5 -5
  79. mistralai/models/jobs_api_routes_fine_tuning_cancel_fine_tuning_jobop.py +5 -5
  80. mistralai/models/jobs_api_routes_fine_tuning_create_fine_tuning_jobop.py +3 -2
  81. mistralai/models/jobs_api_routes_fine_tuning_get_fine_tuning_jobop.py +5 -5
  82. mistralai/models/jobs_api_routes_fine_tuning_get_fine_tuning_jobsop.py +85 -18
  83. mistralai/models/jobs_api_routes_fine_tuning_start_fine_tuning_jobop.py +5 -5
  84. mistralai/models/jobs_api_routes_fine_tuning_unarchive_fine_tuned_modelop.py +5 -5
  85. mistralai/models/jobs_api_routes_fine_tuning_update_fine_tuned_modelop.py +10 -6
  86. mistralai/models/jobsout.py +13 -5
  87. mistralai/models/legacyjobmetadataout.py +55 -9
  88. mistralai/models/listfilesout.py +7 -3
  89. mistralai/models/metricout.py +12 -8
  90. mistralai/models/modelcapabilities.py +9 -4
  91. mistralai/models/modellist.py +21 -7
  92. mistralai/models/responseformat.py +7 -8
  93. mistralai/models/responseformats.py +8 -0
  94. mistralai/models/retrieve_model_v1_models_model_id_getop.py +25 -6
  95. mistralai/models/retrievefileout.py +25 -15
  96. mistralai/models/sampletype.py +6 -2
  97. mistralai/models/security.py +14 -5
  98. mistralai/models/source.py +3 -2
  99. mistralai/models/systemmessage.py +10 -9
  100. mistralai/models/textchunk.py +14 -5
  101. mistralai/models/tool.py +10 -9
  102. mistralai/models/toolcall.py +10 -8
  103. mistralai/models/toolchoice.py +29 -0
  104. mistralai/models/toolchoiceenum.py +7 -0
  105. mistralai/models/toolmessage.py +13 -6
  106. mistralai/models/tooltypes.py +8 -0
  107. mistralai/models/trainingfile.py +4 -4
  108. mistralai/models/trainingparameters.py +34 -8
  109. mistralai/models/trainingparametersin.py +36 -10
  110. mistralai/models/unarchiveftmodelout.py +15 -5
  111. mistralai/models/updateftmodelin.py +9 -6
  112. mistralai/models/uploadfileout.py +22 -15
  113. mistralai/models/usageinfo.py +4 -3
  114. mistralai/models/usermessage.py +42 -10
  115. mistralai/models/validationerror.py +5 -3
  116. mistralai/models/wandbintegration.py +23 -7
  117. mistralai/models/wandbintegrationout.py +23 -8
  118. mistralai/models_.py +416 -294
  119. mistralai/sdk.py +31 -19
  120. mistralai/sdkconfiguration.py +9 -11
  121. mistralai/utils/__init__.py +14 -1
  122. mistralai/utils/annotations.py +13 -2
  123. mistralai/utils/logger.py +4 -1
  124. mistralai/utils/retries.py +2 -1
  125. mistralai/utils/security.py +13 -6
  126. mistralai/utils/serializers.py +25 -0
  127. {mistralai-1.0.3.dist-info → mistralai-1.2.0.dist-info}/METADATA +171 -66
  128. mistralai-1.2.0.dist-info/RECORD +276 -0
  129. {mistralai-1.0.3.dist-info → mistralai-1.2.0.dist-info}/WHEEL +1 -1
  130. mistralai_azure/__init__.py +4 -0
  131. mistralai_azure/_hooks/sdkhooks.py +23 -4
  132. mistralai_azure/_hooks/types.py +27 -9
  133. mistralai_azure/_version.py +12 -0
  134. mistralai_azure/basesdk.py +91 -6
  135. mistralai_azure/chat.py +308 -166
  136. mistralai_azure/models/__init__.py +164 -16
  137. mistralai_azure/models/assistantmessage.py +29 -11
  138. mistralai_azure/models/chatcompletionchoice.py +15 -6
  139. mistralai_azure/models/chatcompletionrequest.py +94 -22
  140. mistralai_azure/models/chatcompletionresponse.py +8 -4
  141. mistralai_azure/models/chatcompletionstreamrequest.py +96 -22
  142. mistralai_azure/models/completionchunk.py +12 -5
  143. mistralai_azure/models/completionevent.py +2 -3
  144. mistralai_azure/models/completionresponsestreamchoice.py +19 -8
  145. mistralai_azure/models/contentchunk.py +4 -11
  146. mistralai_azure/models/deltamessage.py +30 -12
  147. mistralai_azure/models/function.py +5 -4
  148. mistralai_azure/models/functioncall.py +4 -3
  149. mistralai_azure/models/functionname.py +17 -0
  150. mistralai_azure/models/httpvalidationerror.py +1 -3
  151. mistralai_azure/models/responseformat.py +7 -8
  152. mistralai_azure/models/responseformats.py +8 -0
  153. mistralai_azure/models/security.py +13 -5
  154. mistralai_azure/models/systemmessage.py +10 -9
  155. mistralai_azure/models/textchunk.py +14 -5
  156. mistralai_azure/models/tool.py +10 -9
  157. mistralai_azure/models/toolcall.py +10 -8
  158. mistralai_azure/models/toolchoice.py +29 -0
  159. mistralai_azure/models/toolchoiceenum.py +7 -0
  160. mistralai_azure/models/toolmessage.py +20 -7
  161. mistralai_azure/models/tooltypes.py +8 -0
  162. mistralai_azure/models/usageinfo.py +4 -3
  163. mistralai_azure/models/usermessage.py +42 -10
  164. mistralai_azure/models/validationerror.py +5 -3
  165. mistralai_azure/sdkconfiguration.py +9 -11
  166. mistralai_azure/utils/__init__.py +16 -3
  167. mistralai_azure/utils/annotations.py +13 -2
  168. mistralai_azure/utils/forms.py +10 -9
  169. mistralai_azure/utils/headers.py +8 -8
  170. mistralai_azure/utils/logger.py +6 -0
  171. mistralai_azure/utils/queryparams.py +16 -14
  172. mistralai_azure/utils/retries.py +2 -1
  173. mistralai_azure/utils/security.py +12 -6
  174. mistralai_azure/utils/serializers.py +42 -8
  175. mistralai_azure/utils/url.py +13 -8
  176. mistralai_azure/utils/values.py +6 -0
  177. mistralai_gcp/__init__.py +4 -0
  178. mistralai_gcp/_hooks/sdkhooks.py +23 -4
  179. mistralai_gcp/_hooks/types.py +27 -9
  180. mistralai_gcp/_version.py +12 -0
  181. mistralai_gcp/basesdk.py +91 -6
  182. mistralai_gcp/chat.py +308 -166
  183. mistralai_gcp/fim.py +198 -132
  184. mistralai_gcp/models/__init__.py +186 -18
  185. mistralai_gcp/models/assistantmessage.py +29 -11
  186. mistralai_gcp/models/chatcompletionchoice.py +15 -6
  187. mistralai_gcp/models/chatcompletionrequest.py +91 -22
  188. mistralai_gcp/models/chatcompletionresponse.py +8 -4
  189. mistralai_gcp/models/chatcompletionstreamrequest.py +93 -22
  190. mistralai_gcp/models/completionchunk.py +12 -5
  191. mistralai_gcp/models/completionevent.py +2 -3
  192. mistralai_gcp/models/completionresponsestreamchoice.py +19 -8
  193. mistralai_gcp/models/contentchunk.py +4 -11
  194. mistralai_gcp/models/deltamessage.py +30 -12
  195. mistralai_gcp/models/fimcompletionrequest.py +51 -17
  196. mistralai_gcp/models/fimcompletionresponse.py +8 -4
  197. mistralai_gcp/models/fimcompletionstreamrequest.py +51 -17
  198. mistralai_gcp/models/function.py +5 -4
  199. mistralai_gcp/models/functioncall.py +4 -3
  200. mistralai_gcp/models/functionname.py +17 -0
  201. mistralai_gcp/models/httpvalidationerror.py +1 -3
  202. mistralai_gcp/models/responseformat.py +7 -8
  203. mistralai_gcp/models/responseformats.py +8 -0
  204. mistralai_gcp/models/security.py +13 -5
  205. mistralai_gcp/models/systemmessage.py +10 -9
  206. mistralai_gcp/models/textchunk.py +14 -5
  207. mistralai_gcp/models/tool.py +10 -9
  208. mistralai_gcp/models/toolcall.py +10 -8
  209. mistralai_gcp/models/toolchoice.py +29 -0
  210. mistralai_gcp/models/toolchoiceenum.py +7 -0
  211. mistralai_gcp/models/toolmessage.py +20 -7
  212. mistralai_gcp/models/tooltypes.py +8 -0
  213. mistralai_gcp/models/usageinfo.py +4 -3
  214. mistralai_gcp/models/usermessage.py +42 -10
  215. mistralai_gcp/models/validationerror.py +5 -3
  216. mistralai_gcp/sdk.py +6 -7
  217. mistralai_gcp/sdkconfiguration.py +9 -11
  218. mistralai_gcp/utils/__init__.py +16 -3
  219. mistralai_gcp/utils/annotations.py +13 -2
  220. mistralai_gcp/utils/forms.py +10 -9
  221. mistralai_gcp/utils/headers.py +8 -8
  222. mistralai_gcp/utils/logger.py +6 -0
  223. mistralai_gcp/utils/queryparams.py +16 -14
  224. mistralai_gcp/utils/retries.py +2 -1
  225. mistralai_gcp/utils/security.py +12 -6
  226. mistralai_gcp/utils/serializers.py +42 -8
  227. mistralai_gcp/utils/url.py +13 -8
  228. mistralai_gcp/utils/values.py +6 -0
  229. mistralai-1.0.3.dist-info/RECORD +0 -236
  230. {mistralai-1.0.3.dist-info → mistralai-1.2.0.dist-info}/LICENSE +0 -0
mistralai/sdk.py CHANGED
@@ -9,7 +9,9 @@ import httpx
9
9
  from mistralai import models, utils
10
10
  from mistralai._hooks import SDKHooks
11
11
  from mistralai.agents import Agents
12
+ from mistralai.batch import Batch
12
13
  from mistralai.chat import Chat
14
+ from mistralai.classifiers import Classifiers
13
15
  from mistralai.embeddings import Embeddings
14
16
  from mistralai.files import Files
15
17
  from mistralai.fim import Fim
@@ -18,13 +20,16 @@ from mistralai.models_ import Models
18
20
  from mistralai.types import OptionalNullable, UNSET
19
21
  from typing import Any, Callable, Dict, Optional, Union
20
22
 
23
+
21
24
  class Mistral(BaseSDK):
22
25
  r"""Mistral AI API: Our Chat Completion and Embeddings APIs specification. Create your account on [La Plateforme](https://console.mistral.ai) to get access and read the [docs](https://docs.mistral.ai) to learn how to use it."""
26
+
23
27
  models: Models
24
28
  r"""Model Management API"""
25
29
  files: Files
26
30
  r"""Files API"""
27
31
  fine_tuning: FineTuning
32
+ batch: Batch
28
33
  chat: Chat
29
34
  r"""Chat Completion API."""
30
35
  fim: Fim
@@ -33,6 +38,9 @@ class Mistral(BaseSDK):
33
38
  r"""Agents API."""
34
39
  embeddings: Embeddings
35
40
  r"""Embeddings API."""
41
+ classifiers: Classifiers
42
+ r"""Classifiers API."""
43
+
36
44
  def __init__(
37
45
  self,
38
46
  api_key: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None,
@@ -43,7 +51,7 @@ class Mistral(BaseSDK):
43
51
  async_client: Optional[AsyncHttpClient] = None,
44
52
  retry_config: OptionalNullable[RetryConfig] = UNSET,
45
53
  timeout_ms: Optional[int] = None,
46
- debug_logger: Optional[Logger] = None
54
+ debug_logger: Optional[Logger] = None,
47
55
  ) -> None:
48
56
  r"""Instantiates the SDK configuring it with the provided parameters.
49
57
 
@@ -72,33 +80,37 @@ class Mistral(BaseSDK):
72
80
  assert issubclass(
73
81
  type(async_client), AsyncHttpClient
74
82
  ), "The provided async_client must implement the AsyncHttpClient protocol."
75
-
83
+
76
84
  security: Any = None
77
85
  if callable(api_key):
78
- security = lambda: models.Security(api_key = api_key()) # pylint: disable=unnecessary-lambda-assignment
86
+ security = lambda: models.Security(api_key=api_key()) # pylint: disable=unnecessary-lambda-assignment
79
87
  else:
80
- security = models.Security(api_key = api_key)
88
+ security = models.Security(api_key=api_key)
81
89
 
82
90
  if server_url is not None:
83
91
  if url_params is not None:
84
92
  server_url = utils.template_url(server_url, url_params)
85
-
86
-
87
- BaseSDK.__init__(self, SDKConfiguration(
88
- client=client,
89
- async_client=async_client,
90
- security=security,
91
- server_url=server_url,
92
- server=server,
93
- retry_config=retry_config,
94
- timeout_ms=timeout_ms,
95
- debug_logger=debug_logger
96
- ))
93
+
94
+ BaseSDK.__init__(
95
+ self,
96
+ SDKConfiguration(
97
+ client=client,
98
+ async_client=async_client,
99
+ security=security,
100
+ server_url=server_url,
101
+ server=server,
102
+ retry_config=retry_config,
103
+ timeout_ms=timeout_ms,
104
+ debug_logger=debug_logger,
105
+ ),
106
+ )
97
107
 
98
108
  hooks = SDKHooks()
99
109
 
100
110
  current_server_url, *_ = self.sdk_configuration.get_server_details()
101
- server_url, self.sdk_configuration.client = hooks.sdk_init(current_server_url, self.sdk_configuration.client)
111
+ server_url, self.sdk_configuration.client = hooks.sdk_init(
112
+ current_server_url, self.sdk_configuration.client
113
+ )
102
114
  if current_server_url != server_url:
103
115
  self.sdk_configuration.server_url = server_url
104
116
 
@@ -107,13 +119,13 @@ class Mistral(BaseSDK):
107
119
 
108
120
  self._init_sdks()
109
121
 
110
-
111
122
  def _init_sdks(self):
112
123
  self.models = Models(self.sdk_configuration)
113
124
  self.files = Files(self.sdk_configuration)
114
125
  self.fine_tuning = FineTuning(self.sdk_configuration)
126
+ self.batch = Batch(self.sdk_configuration)
115
127
  self.chat = Chat(self.sdk_configuration)
116
128
  self.fim = Fim(self.sdk_configuration)
117
129
  self.agents = Agents(self.sdk_configuration)
118
130
  self.embeddings = Embeddings(self.sdk_configuration)
119
-
131
+ self.classifiers = Classifiers(self.sdk_configuration)
mistralai/sdkconfiguration.py CHANGED
@@ -1,6 +1,5 @@
1
1
  """Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT."""
2
2
 
3
-
4
3
  from ._hooks import SDKHooks
5
4
  from .httpclient import AsyncHttpClient, HttpClient
6
5
  from .utils import Logger, RetryConfig, remove_suffix
@@ -11,10 +10,10 @@ from pydantic import Field
11
10
  from typing import Callable, Dict, Optional, Tuple, Union
12
11
 
13
12
 
14
- SERVER_PROD = "prod"
15
- r"""Production server"""
13
+ SERVER_EU = "eu"
14
+ r"""EU Production server"""
16
15
  SERVERS = {
17
- SERVER_PROD: "https://api.mistral.ai",
16
+ SERVER_EU: "https://api.mistral.ai",
18
17
  }
19
18
  """Contains the list of servers available to the SDK"""
20
19
 
@@ -24,14 +23,14 @@ class SDKConfiguration:
24
23
  client: HttpClient
25
24
  async_client: AsyncHttpClient
26
25
  debug_logger: Logger
27
- security: Optional[Union[models.Security,Callable[[], models.Security]]] = None
26
+ security: Optional[Union[models.Security, Callable[[], models.Security]]] = None
28
27
  server_url: Optional[str] = ""
29
28
  server: Optional[str] = ""
30
29
  language: str = "python"
31
30
  openapi_doc_version: str = "0.0.2"
32
- sdk_version: str = "1.0.3"
33
- gen_version: str = "2.404.11"
34
- user_agent: str = "speakeasy-sdk/python 1.0.3 2.404.11 0.0.2 mistralai"
31
+ sdk_version: str = "1.2.0"
32
+ gen_version: str = "2.452.0"
33
+ user_agent: str = "speakeasy-sdk/python 1.2.0 2.452.0 0.0.2 mistralai"
35
34
  retry_config: OptionalNullable[RetryConfig] = Field(default_factory=lambda: UNSET)
36
35
  timeout_ms: Optional[int] = None
37
36
 
@@ -42,13 +41,12 @@ class SDKConfiguration:
42
41
  if self.server_url is not None and self.server_url:
43
42
  return remove_suffix(self.server_url, "/"), {}
44
43
  if not self.server:
45
- self.server = SERVER_PROD
44
+ self.server = SERVER_EU
46
45
 
47
46
  if self.server not in SERVERS:
48
- raise ValueError(f"Invalid server \"{self.server}\"")
47
+ raise ValueError(f'Invalid server "{self.server}"')
49
48
 
50
49
  return SERVERS[self.server], {}
51
50
 
52
-
53
51
  def get_hooks(self) -> SDKHooks:
54
52
  return self._hooks
mistralai/utils/__init__.py CHANGED
@@ -28,13 +28,22 @@ from .serializers import (
28
28
  serialize_float,
29
29
  serialize_int,
30
30
  stream_to_text,
31
+ stream_to_text_async,
32
+ stream_to_bytes,
33
+ stream_to_bytes_async,
34
+ validate_const,
31
35
  validate_decimal,
32
36
  validate_float,
33
37
  validate_int,
34
38
  validate_open_enum,
35
39
  )
36
40
  from .url import generate_url, template_url, remove_suffix
37
- from .values import get_global_from_env, match_content_type, match_status_codes, match_response
41
+ from .values import (
42
+ get_global_from_env,
43
+ match_content_type,
44
+ match_status_codes,
45
+ match_response,
46
+ )
38
47
  from .logger import Logger, get_body_content, get_default_logger
39
48
 
40
49
  __all__ = [
@@ -76,10 +85,14 @@ __all__ = [
76
85
  "serialize_request_body",
77
86
  "SerializedRequestBody",
78
87
  "stream_to_text",
88
+ "stream_to_text_async",
89
+ "stream_to_bytes",
90
+ "stream_to_bytes_async",
79
91
  "template_url",
80
92
  "unmarshal",
81
93
  "unmarshal_json",
82
94
  "validate_decimal",
95
+ "validate_const",
83
96
  "validate_float",
84
97
  "validate_int",
85
98
  "validate_open_enum",
mistralai/utils/annotations.py CHANGED
@@ -1,5 +1,6 @@
1
1
  """Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT."""
2
2
 
3
+ from enum import Enum
3
4
  from typing import Any
4
5
 
5
6
  def get_discriminator(model: Any, fieldname: str, key: str) -> str:
@@ -10,10 +11,20 @@ def get_discriminator(model: Any, fieldname: str, key: str) -> str:
10
11
  raise ValueError(f'Could not find discriminator key {key} in {model}') from e
11
12
 
12
13
  if hasattr(model, fieldname):
13
- return f'{getattr(model, fieldname)}'
14
+ attr = getattr(model, fieldname)
15
+
16
+ if isinstance(attr, Enum):
17
+ return f'{attr.value}'
18
+
19
+ return f'{attr}'
14
20
 
15
21
  fieldname = fieldname.upper()
16
22
  if hasattr(model, fieldname):
17
- return f'{getattr(model, fieldname)}'
23
+ attr = getattr(model, fieldname)
24
+
25
+ if isinstance(attr, Enum):
26
+ return f'{attr.value}'
27
+
28
+ return f'{attr}'
18
29
 
19
30
  raise ValueError(f'Could not find discriminator field {fieldname} in {model}')
mistralai/utils/logger.py CHANGED
@@ -5,20 +5,23 @@ import logging
5
5
  import os
6
6
  from typing import Any, Protocol
7
7
 
8
+
8
9
  class Logger(Protocol):
9
10
  def debug(self, msg: str, *args: Any, **kwargs: Any) -> None:
10
11
  pass
11
12
 
13
+
12
14
  class NoOpLogger:
13
15
  def debug(self, msg: str, *args: Any, **kwargs: Any) -> None:
14
16
  pass
15
17
 
18
+
16
19
  def get_body_content(req: httpx.Request) -> str:
17
20
  return "<streaming body>" if not hasattr(req, "_content") else str(req.content)
18
21
 
22
+
19
23
  def get_default_logger() -> Logger:
20
24
  if os.getenv("MISTRAL_DEBUG"):
21
25
  logging.basicConfig(level=logging.DEBUG)
22
26
  return logging.getLogger("mistralai")
23
27
  return NoOpLogger()
24
-
mistralai/utils/retries.py CHANGED
@@ -1,5 +1,6 @@
1
1
  """Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT."""
2
2
 
3
+ import asyncio
3
4
  import random
4
5
  import time
5
6
  from typing import List
@@ -212,5 +213,5 @@ async def retry_with_backoff_async(
212
213
  raise
213
214
  sleep = (initial_interval / 1000) * exponent**retries + random.uniform(0, 1)
214
215
  sleep = min(sleep, max_interval / 1000)
215
- time.sleep(sleep)
216
+ await asyncio.sleep(sleep)
216
217
  retries += 1
mistralai/utils/security.py CHANGED
@@ -44,8 +44,10 @@ def get_security(security: Any) -> Tuple[Dict[str, str], Dict[str, List[str]]]:
44
44
  _parse_security_option(headers, query_params, value)
45
45
  return headers, query_params
46
46
  if metadata.scheme:
47
- # Special case for basic auth which could be a flattened model
48
- if metadata.sub_type == "basic" and not isinstance(value, BaseModel):
47
+ # Special case for basic auth or custom auth which could be a flattened model
48
+ if metadata.sub_type in ["basic", "custom"] and not isinstance(
49
+ value, BaseModel
50
+ ):
49
51
  _parse_security_scheme(headers, query_params, metadata, name, security)
50
52
  else:
51
53
  _parse_security_scheme(headers, query_params, metadata, name, value)
@@ -64,7 +66,7 @@ def get_security_from_env(security: Any, security_class: Any) -> Optional[BaseMo
64
66
 
65
67
  if os.getenv("MISTRAL_API_KEY"):
66
68
  security_dict["api_key"] = os.getenv("MISTRAL_API_KEY")
67
-
69
+
68
70
  return security_class(**security_dict) if security_dict else None
69
71
 
70
72
 
@@ -97,9 +99,12 @@ def _parse_security_scheme(
97
99
  sub_type = scheme_metadata.sub_type
98
100
 
99
101
  if isinstance(scheme, BaseModel):
100
- if scheme_type == "http" and sub_type == "basic":
101
- _parse_basic_auth_scheme(headers, scheme)
102
- return
102
+ if scheme_type == "http":
103
+ if sub_type == "basic":
104
+ _parse_basic_auth_scheme(headers, scheme)
105
+ return
106
+ if sub_type == "custom":
107
+ return
103
108
 
104
109
  scheme_fields: Dict[str, FieldInfo] = scheme.__class__.model_fields
105
110
  for name in scheme_fields:
@@ -148,6 +153,8 @@ def _parse_security_scheme_value(
148
153
  elif scheme_type == "http":
149
154
  if sub_type == "bearer":
150
155
  headers[header_name] = _apply_bearer(value)
156
+ elif sub_type == "custom":
157
+ return
151
158
  else:
152
159
  raise ValueError("sub type {sub_type} not supported")
153
160
  else:
mistralai/utils/serializers.py CHANGED
@@ -116,6 +116,19 @@ def validate_open_enum(is_int: bool):
116
116
  return validate
117
117
 
118
118
 
119
+ def validate_const(v):
120
+ def validate(c):
121
+ if is_optional_type(type(c)) and c is None:
122
+ return None
123
+
124
+ if v != c:
125
+ raise ValueError(f"Expected {v}")
126
+
127
+ return c
128
+
129
+ return validate
130
+
131
+
119
132
  def unmarshal_json(raw, typ: Any) -> Any:
120
133
  return unmarshal(from_json(raw), typ)
121
134
 
@@ -172,6 +185,18 @@ def stream_to_text(stream: httpx.Response) -> str:
172
185
  return "".join(stream.iter_text())
173
186
 
174
187
 
188
+ async def stream_to_text_async(stream: httpx.Response) -> str:
189
+ return "".join([chunk async for chunk in stream.aiter_text()])
190
+
191
+
192
+ def stream_to_bytes(stream: httpx.Response) -> bytes:
193
+ return stream.content
194
+
195
+
196
+ async def stream_to_bytes_async(stream: httpx.Response) -> bytes:
197
+ return await stream.aread()
198
+
199
+
175
200
  def get_pydantic_model(data: Any, typ: Any) -> Any:
176
201
  if not _contains_pydantic_model(data):
177
202
  return unmarshal(data, typ)