mistralai 1.5.0__py3-none-any.whl → 1.5.2rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mistralai/_hooks/types.py +15 -3
- mistralai/_version.py +3 -3
- mistralai/agents.py +32 -12
- mistralai/basesdk.py +8 -0
- mistralai/chat.py +37 -17
- mistralai/classifiers.py +59 -37
- mistralai/embeddings.py +22 -18
- mistralai/extra/utils/response_format.py +3 -3
- mistralai/files.py +36 -0
- mistralai/fim.py +37 -17
- mistralai/httpclient.py +4 -2
- mistralai/jobs.py +30 -0
- mistralai/mistral_jobs.py +24 -0
- mistralai/models/__init__.py +43 -16
- mistralai/models/assistantmessage.py +2 -0
- mistralai/models/chatcompletionrequest.py +3 -10
- mistralai/models/chatcompletionstreamrequest.py +3 -10
- mistralai/models/chatmoderationrequest.py +86 -0
- mistralai/models/classificationrequest.py +7 -36
- mistralai/models/contentchunk.py +8 -1
- mistralai/models/documenturlchunk.py +56 -0
- mistralai/models/embeddingrequest.py +8 -44
- mistralai/models/filepurpose.py +1 -1
- mistralai/models/fimcompletionrequest.py +2 -3
- mistralai/models/fimcompletionstreamrequest.py +2 -3
- mistralai/models/ocrimageobject.py +77 -0
- mistralai/models/ocrpagedimensions.py +25 -0
- mistralai/models/ocrpageobject.py +64 -0
- mistralai/models/ocrrequest.py +97 -0
- mistralai/models/ocrresponse.py +26 -0
- mistralai/models/ocrusageinfo.py +51 -0
- mistralai/models/prediction.py +4 -5
- mistralai/models_.py +66 -18
- mistralai/ocr.py +248 -0
- mistralai/sdk.py +23 -3
- mistralai/sdkconfiguration.py +4 -2
- mistralai/utils/__init__.py +2 -0
- mistralai/utils/serializers.py +10 -6
- mistralai/utils/values.py +4 -1
- {mistralai-1.5.0.dist-info → mistralai-1.5.2rc1.dist-info}/METADATA +70 -19
- {mistralai-1.5.0.dist-info → mistralai-1.5.2rc1.dist-info}/RECORD +88 -76
- {mistralai-1.5.0.dist-info → mistralai-1.5.2rc1.dist-info}/WHEEL +1 -1
- mistralai_azure/__init__.py +10 -1
- mistralai_azure/_hooks/types.py +15 -3
- mistralai_azure/_version.py +3 -0
- mistralai_azure/basesdk.py +8 -0
- mistralai_azure/chat.py +88 -20
- mistralai_azure/httpclient.py +52 -0
- mistralai_azure/models/__init__.py +7 -0
- mistralai_azure/models/assistantmessage.py +2 -0
- mistralai_azure/models/chatcompletionrequest.py +8 -10
- mistralai_azure/models/chatcompletionstreamrequest.py +8 -10
- mistralai_azure/models/function.py +3 -0
- mistralai_azure/models/jsonschema.py +61 -0
- mistralai_azure/models/prediction.py +25 -0
- mistralai_azure/models/responseformat.py +42 -1
- mistralai_azure/models/responseformats.py +1 -1
- mistralai_azure/models/toolcall.py +3 -0
- mistralai_azure/sdk.py +56 -14
- mistralai_azure/sdkconfiguration.py +14 -6
- mistralai_azure/utils/__init__.py +2 -0
- mistralai_azure/utils/serializers.py +10 -6
- mistralai_azure/utils/values.py +4 -1
- mistralai_gcp/__init__.py +10 -1
- mistralai_gcp/_hooks/types.py +15 -3
- mistralai_gcp/_version.py +3 -0
- mistralai_gcp/basesdk.py +8 -0
- mistralai_gcp/chat.py +89 -21
- mistralai_gcp/fim.py +61 -21
- mistralai_gcp/httpclient.py +52 -0
- mistralai_gcp/models/__init__.py +7 -0
- mistralai_gcp/models/assistantmessage.py +2 -0
- mistralai_gcp/models/chatcompletionrequest.py +8 -10
- mistralai_gcp/models/chatcompletionstreamrequest.py +8 -10
- mistralai_gcp/models/fimcompletionrequest.py +2 -3
- mistralai_gcp/models/fimcompletionstreamrequest.py +2 -3
- mistralai_gcp/models/function.py +3 -0
- mistralai_gcp/models/jsonschema.py +61 -0
- mistralai_gcp/models/prediction.py +25 -0
- mistralai_gcp/models/responseformat.py +42 -1
- mistralai_gcp/models/responseformats.py +1 -1
- mistralai_gcp/models/toolcall.py +3 -0
- mistralai_gcp/sdk.py +63 -19
- mistralai_gcp/sdkconfiguration.py +14 -6
- mistralai_gcp/utils/__init__.py +2 -0
- mistralai_gcp/utils/serializers.py +10 -6
- mistralai_gcp/utils/values.py +4 -1
- mistralai/models/chatclassificationrequest.py +0 -113
- {mistralai-1.5.0.dist-info → mistralai-1.5.2rc1.dist-info}/LICENSE +0 -0
mistralai/classifiers.py
CHANGED
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
from .basesdk import BaseSDK
|
|
4
4
|
from mistralai import models, utils
|
|
5
5
|
from mistralai._hooks import HookContext
|
|
6
|
-
from mistralai.types import
|
|
6
|
+
from mistralai.types import OptionalNullable, UNSET
|
|
7
7
|
from mistralai.utils import get_security_from_env
|
|
8
8
|
from typing import Any, Mapping, Optional, Union
|
|
9
9
|
|
|
@@ -14,11 +14,11 @@ class Classifiers(BaseSDK):
|
|
|
14
14
|
def moderate(
|
|
15
15
|
self,
|
|
16
16
|
*,
|
|
17
|
+
model: str,
|
|
17
18
|
inputs: Union[
|
|
18
19
|
models.ClassificationRequestInputs,
|
|
19
20
|
models.ClassificationRequestInputsTypedDict,
|
|
20
21
|
],
|
|
21
|
-
model: OptionalNullable[str] = UNSET,
|
|
22
22
|
retries: OptionalNullable[utils.RetryConfig] = UNSET,
|
|
23
23
|
server_url: Optional[str] = None,
|
|
24
24
|
timeout_ms: Optional[int] = None,
|
|
@@ -26,8 +26,8 @@ class Classifiers(BaseSDK):
|
|
|
26
26
|
) -> models.ClassificationResponse:
|
|
27
27
|
r"""Moderations
|
|
28
28
|
|
|
29
|
+
:param model: ID of the model to use.
|
|
29
30
|
:param inputs: Text to classify.
|
|
30
|
-
:param model:
|
|
31
31
|
:param retries: Override the default retry configuration for this method
|
|
32
32
|
:param server_url: Override the default server URL for this method
|
|
33
33
|
:param timeout_ms: Override the default request timeout configuration for this method in milliseconds
|
|
@@ -40,10 +40,12 @@ class Classifiers(BaseSDK):
|
|
|
40
40
|
|
|
41
41
|
if server_url is not None:
|
|
42
42
|
base_url = server_url
|
|
43
|
+
else:
|
|
44
|
+
base_url = self._get_url(base_url, url_variables)
|
|
43
45
|
|
|
44
46
|
request = models.ClassificationRequest(
|
|
45
|
-
inputs=inputs,
|
|
46
47
|
model=model,
|
|
48
|
+
inputs=inputs,
|
|
47
49
|
)
|
|
48
50
|
|
|
49
51
|
req = self._build_request(
|
|
@@ -75,6 +77,7 @@ class Classifiers(BaseSDK):
|
|
|
75
77
|
|
|
76
78
|
http_res = self.do_request(
|
|
77
79
|
hook_ctx=HookContext(
|
|
80
|
+
base_url=base_url or "",
|
|
78
81
|
operation_id="moderations_v1_moderations_post",
|
|
79
82
|
oauth2_scopes=[],
|
|
80
83
|
security_source=get_security_from_env(
|
|
@@ -86,12 +89,14 @@ class Classifiers(BaseSDK):
|
|
|
86
89
|
retry_config=retry_config,
|
|
87
90
|
)
|
|
88
91
|
|
|
89
|
-
|
|
92
|
+
response_data: Any = None
|
|
90
93
|
if utils.match_response(http_res, "200", "application/json"):
|
|
91
94
|
return utils.unmarshal_json(http_res.text, models.ClassificationResponse)
|
|
92
95
|
if utils.match_response(http_res, "422", "application/json"):
|
|
93
|
-
|
|
94
|
-
|
|
96
|
+
response_data = utils.unmarshal_json(
|
|
97
|
+
http_res.text, models.HTTPValidationErrorData
|
|
98
|
+
)
|
|
99
|
+
raise models.HTTPValidationError(data=response_data)
|
|
95
100
|
if utils.match_response(http_res, "4XX", "*"):
|
|
96
101
|
http_res_text = utils.stream_to_text(http_res)
|
|
97
102
|
raise models.SDKError(
|
|
@@ -115,11 +120,11 @@ class Classifiers(BaseSDK):
|
|
|
115
120
|
async def moderate_async(
|
|
116
121
|
self,
|
|
117
122
|
*,
|
|
123
|
+
model: str,
|
|
118
124
|
inputs: Union[
|
|
119
125
|
models.ClassificationRequestInputs,
|
|
120
126
|
models.ClassificationRequestInputsTypedDict,
|
|
121
127
|
],
|
|
122
|
-
model: OptionalNullable[str] = UNSET,
|
|
123
128
|
retries: OptionalNullable[utils.RetryConfig] = UNSET,
|
|
124
129
|
server_url: Optional[str] = None,
|
|
125
130
|
timeout_ms: Optional[int] = None,
|
|
@@ -127,8 +132,8 @@ class Classifiers(BaseSDK):
|
|
|
127
132
|
) -> models.ClassificationResponse:
|
|
128
133
|
r"""Moderations
|
|
129
134
|
|
|
135
|
+
:param model: ID of the model to use.
|
|
130
136
|
:param inputs: Text to classify.
|
|
131
|
-
:param model:
|
|
132
137
|
:param retries: Override the default retry configuration for this method
|
|
133
138
|
:param server_url: Override the default server URL for this method
|
|
134
139
|
:param timeout_ms: Override the default request timeout configuration for this method in milliseconds
|
|
@@ -141,10 +146,12 @@ class Classifiers(BaseSDK):
|
|
|
141
146
|
|
|
142
147
|
if server_url is not None:
|
|
143
148
|
base_url = server_url
|
|
149
|
+
else:
|
|
150
|
+
base_url = self._get_url(base_url, url_variables)
|
|
144
151
|
|
|
145
152
|
request = models.ClassificationRequest(
|
|
146
|
-
inputs=inputs,
|
|
147
153
|
model=model,
|
|
154
|
+
inputs=inputs,
|
|
148
155
|
)
|
|
149
156
|
|
|
150
157
|
req = self._build_request_async(
|
|
@@ -176,6 +183,7 @@ class Classifiers(BaseSDK):
|
|
|
176
183
|
|
|
177
184
|
http_res = await self.do_request_async(
|
|
178
185
|
hook_ctx=HookContext(
|
|
186
|
+
base_url=base_url or "",
|
|
179
187
|
operation_id="moderations_v1_moderations_post",
|
|
180
188
|
oauth2_scopes=[],
|
|
181
189
|
security_source=get_security_from_env(
|
|
@@ -187,12 +195,14 @@ class Classifiers(BaseSDK):
|
|
|
187
195
|
retry_config=retry_config,
|
|
188
196
|
)
|
|
189
197
|
|
|
190
|
-
|
|
198
|
+
response_data: Any = None
|
|
191
199
|
if utils.match_response(http_res, "200", "application/json"):
|
|
192
200
|
return utils.unmarshal_json(http_res.text, models.ClassificationResponse)
|
|
193
201
|
if utils.match_response(http_res, "422", "application/json"):
|
|
194
|
-
|
|
195
|
-
|
|
202
|
+
response_data = utils.unmarshal_json(
|
|
203
|
+
http_res.text, models.HTTPValidationErrorData
|
|
204
|
+
)
|
|
205
|
+
raise models.HTTPValidationError(data=response_data)
|
|
196
206
|
if utils.match_response(http_res, "4XX", "*"):
|
|
197
207
|
http_res_text = await utils.stream_to_text_async(http_res)
|
|
198
208
|
raise models.SDKError(
|
|
@@ -216,11 +226,12 @@ class Classifiers(BaseSDK):
|
|
|
216
226
|
def moderate_chat(
|
|
217
227
|
self,
|
|
218
228
|
*,
|
|
229
|
+
model: str,
|
|
219
230
|
inputs: Union[
|
|
220
|
-
models.
|
|
221
|
-
models.
|
|
231
|
+
models.ChatModerationRequestInputs,
|
|
232
|
+
models.ChatModerationRequestInputsTypedDict,
|
|
222
233
|
],
|
|
223
|
-
|
|
234
|
+
truncate_for_context_length: Optional[bool] = False,
|
|
224
235
|
retries: OptionalNullable[utils.RetryConfig] = UNSET,
|
|
225
236
|
server_url: Optional[str] = None,
|
|
226
237
|
timeout_ms: Optional[int] = None,
|
|
@@ -228,8 +239,9 @@ class Classifiers(BaseSDK):
|
|
|
228
239
|
) -> models.ClassificationResponse:
|
|
229
240
|
r"""Moderations Chat
|
|
230
241
|
|
|
231
|
-
:param inputs: Chat to classify
|
|
232
242
|
:param model:
|
|
243
|
+
:param inputs: Chat to classify
|
|
244
|
+
:param truncate_for_context_length:
|
|
233
245
|
:param retries: Override the default retry configuration for this method
|
|
234
246
|
:param server_url: Override the default server URL for this method
|
|
235
247
|
:param timeout_ms: Override the default request timeout configuration for this method in milliseconds
|
|
@@ -242,12 +254,13 @@ class Classifiers(BaseSDK):
|
|
|
242
254
|
|
|
243
255
|
if server_url is not None:
|
|
244
256
|
base_url = server_url
|
|
257
|
+
else:
|
|
258
|
+
base_url = self._get_url(base_url, url_variables)
|
|
245
259
|
|
|
246
|
-
request = models.ChatClassificationRequest(
|
|
247
|
-
inputs=utils.get_pydantic_model(
|
|
248
|
-
inputs, models.ChatClassificationRequestInputs
|
|
249
|
-
),
|
|
260
|
+
request = models.ChatModerationRequest(
|
|
250
261
|
model=model,
|
|
262
|
+
inputs=utils.get_pydantic_model(inputs, models.ChatModerationRequestInputs),
|
|
263
|
+
truncate_for_context_length=truncate_for_context_length,
|
|
251
264
|
)
|
|
252
265
|
|
|
253
266
|
req = self._build_request(
|
|
@@ -264,7 +277,7 @@ class Classifiers(BaseSDK):
|
|
|
264
277
|
http_headers=http_headers,
|
|
265
278
|
security=self.sdk_configuration.security,
|
|
266
279
|
get_serialized_body=lambda: utils.serialize_request_body(
|
|
267
|
-
request, False, False, "json", models.ChatClassificationRequest
|
|
280
|
+
request, False, False, "json", models.ChatModerationRequest
|
|
268
281
|
),
|
|
269
282
|
timeout_ms=timeout_ms,
|
|
270
283
|
)
|
|
@@ -279,6 +292,7 @@ class Classifiers(BaseSDK):
|
|
|
279
292
|
|
|
280
293
|
http_res = self.do_request(
|
|
281
294
|
hook_ctx=HookContext(
|
|
295
|
+
base_url=base_url or "",
|
|
282
296
|
operation_id="moderations_chat_v1_chat_moderations_post",
|
|
283
297
|
oauth2_scopes=[],
|
|
284
298
|
security_source=get_security_from_env(
|
|
@@ -290,12 +304,14 @@ class Classifiers(BaseSDK):
|
|
|
290
304
|
retry_config=retry_config,
|
|
291
305
|
)
|
|
292
306
|
|
|
293
|
-
|
|
307
|
+
response_data: Any = None
|
|
294
308
|
if utils.match_response(http_res, "200", "application/json"):
|
|
295
309
|
return utils.unmarshal_json(http_res.text, models.ClassificationResponse)
|
|
296
310
|
if utils.match_response(http_res, "422", "application/json"):
|
|
297
|
-
|
|
298
|
-
|
|
311
|
+
response_data = utils.unmarshal_json(
|
|
312
|
+
http_res.text, models.HTTPValidationErrorData
|
|
313
|
+
)
|
|
314
|
+
raise models.HTTPValidationError(data=response_data)
|
|
299
315
|
if utils.match_response(http_res, "4XX", "*"):
|
|
300
316
|
http_res_text = utils.stream_to_text(http_res)
|
|
301
317
|
raise models.SDKError(
|
|
@@ -319,11 +335,12 @@ class Classifiers(BaseSDK):
|
|
|
319
335
|
async def moderate_chat_async(
|
|
320
336
|
self,
|
|
321
337
|
*,
|
|
338
|
+
model: str,
|
|
322
339
|
inputs: Union[
|
|
323
|
-
models.
|
|
324
|
-
models.
|
|
340
|
+
models.ChatModerationRequestInputs,
|
|
341
|
+
models.ChatModerationRequestInputsTypedDict,
|
|
325
342
|
],
|
|
326
|
-
|
|
343
|
+
truncate_for_context_length: Optional[bool] = False,
|
|
327
344
|
retries: OptionalNullable[utils.RetryConfig] = UNSET,
|
|
328
345
|
server_url: Optional[str] = None,
|
|
329
346
|
timeout_ms: Optional[int] = None,
|
|
@@ -331,8 +348,9 @@ class Classifiers(BaseSDK):
|
|
|
331
348
|
) -> models.ClassificationResponse:
|
|
332
349
|
r"""Moderations Chat
|
|
333
350
|
|
|
334
|
-
:param inputs: Chat to classify
|
|
335
351
|
:param model:
|
|
352
|
+
:param inputs: Chat to classify
|
|
353
|
+
:param truncate_for_context_length:
|
|
336
354
|
:param retries: Override the default retry configuration for this method
|
|
337
355
|
:param server_url: Override the default server URL for this method
|
|
338
356
|
:param timeout_ms: Override the default request timeout configuration for this method in milliseconds
|
|
@@ -345,12 +363,13 @@ class Classifiers(BaseSDK):
|
|
|
345
363
|
|
|
346
364
|
if server_url is not None:
|
|
347
365
|
base_url = server_url
|
|
366
|
+
else:
|
|
367
|
+
base_url = self._get_url(base_url, url_variables)
|
|
348
368
|
|
|
349
|
-
request = models.ChatClassificationRequest(
|
|
350
|
-
inputs=utils.get_pydantic_model(
|
|
351
|
-
inputs, models.ChatClassificationRequestInputs
|
|
352
|
-
),
|
|
369
|
+
request = models.ChatModerationRequest(
|
|
353
370
|
model=model,
|
|
371
|
+
inputs=utils.get_pydantic_model(inputs, models.ChatModerationRequestInputs),
|
|
372
|
+
truncate_for_context_length=truncate_for_context_length,
|
|
354
373
|
)
|
|
355
374
|
|
|
356
375
|
req = self._build_request_async(
|
|
@@ -367,7 +386,7 @@ class Classifiers(BaseSDK):
|
|
|
367
386
|
http_headers=http_headers,
|
|
368
387
|
security=self.sdk_configuration.security,
|
|
369
388
|
get_serialized_body=lambda: utils.serialize_request_body(
|
|
370
|
-
request, False, False, "json", models.ChatClassificationRequest
|
|
389
|
+
request, False, False, "json", models.ChatModerationRequest
|
|
371
390
|
),
|
|
372
391
|
timeout_ms=timeout_ms,
|
|
373
392
|
)
|
|
@@ -382,6 +401,7 @@ class Classifiers(BaseSDK):
|
|
|
382
401
|
|
|
383
402
|
http_res = await self.do_request_async(
|
|
384
403
|
hook_ctx=HookContext(
|
|
404
|
+
base_url=base_url or "",
|
|
385
405
|
operation_id="moderations_chat_v1_chat_moderations_post",
|
|
386
406
|
oauth2_scopes=[],
|
|
387
407
|
security_source=get_security_from_env(
|
|
@@ -393,12 +413,14 @@ class Classifiers(BaseSDK):
|
|
|
393
413
|
retry_config=retry_config,
|
|
394
414
|
)
|
|
395
415
|
|
|
396
|
-
|
|
416
|
+
response_data: Any = None
|
|
397
417
|
if utils.match_response(http_res, "200", "application/json"):
|
|
398
418
|
return utils.unmarshal_json(http_res.text, models.ClassificationResponse)
|
|
399
419
|
if utils.match_response(http_res, "422", "application/json"):
|
|
400
|
-
|
|
401
|
-
|
|
420
|
+
response_data = utils.unmarshal_json(
|
|
421
|
+
http_res.text, models.HTTPValidationErrorData
|
|
422
|
+
)
|
|
423
|
+
raise models.HTTPValidationError(data=response_data)
|
|
402
424
|
if utils.match_response(http_res, "4XX", "*"):
|
|
403
425
|
http_res_text = await utils.stream_to_text_async(http_res)
|
|
404
426
|
raise models.SDKError(
|
mistralai/embeddings.py
CHANGED
|
@@ -14,9 +14,8 @@ class Embeddings(BaseSDK):
|
|
|
14
14
|
def create(
|
|
15
15
|
self,
|
|
16
16
|
*,
|
|
17
|
+
model: str,
|
|
17
18
|
inputs: Union[models.Inputs, models.InputsTypedDict],
|
|
18
|
-
model: Optional[str] = "mistral-embed",
|
|
19
|
-
encoding_format: OptionalNullable[str] = UNSET,
|
|
20
19
|
retries: OptionalNullable[utils.RetryConfig] = UNSET,
|
|
21
20
|
server_url: Optional[str] = None,
|
|
22
21
|
timeout_ms: Optional[int] = None,
|
|
@@ -26,9 +25,8 @@ class Embeddings(BaseSDK):
|
|
|
26
25
|
|
|
27
26
|
Embeddings
|
|
28
27
|
|
|
29
|
-
:param inputs: Text to embed.
|
|
30
28
|
:param model: ID of the model to use.
|
|
31
|
-
:param
|
|
29
|
+
:param inputs: Text to embed.
|
|
32
30
|
:param retries: Override the default retry configuration for this method
|
|
33
31
|
:param server_url: Override the default server URL for this method
|
|
34
32
|
:param timeout_ms: Override the default request timeout configuration for this method in milliseconds
|
|
@@ -41,11 +39,12 @@ class Embeddings(BaseSDK):
|
|
|
41
39
|
|
|
42
40
|
if server_url is not None:
|
|
43
41
|
base_url = server_url
|
|
42
|
+
else:
|
|
43
|
+
base_url = self._get_url(base_url, url_variables)
|
|
44
44
|
|
|
45
45
|
request = models.EmbeddingRequest(
|
|
46
|
-
inputs=inputs,
|
|
47
46
|
model=model,
|
|
48
|
-
|
|
47
|
+
inputs=inputs,
|
|
49
48
|
)
|
|
50
49
|
|
|
51
50
|
req = self._build_request(
|
|
@@ -77,6 +76,7 @@ class Embeddings(BaseSDK):
|
|
|
77
76
|
|
|
78
77
|
http_res = self.do_request(
|
|
79
78
|
hook_ctx=HookContext(
|
|
79
|
+
base_url=base_url or "",
|
|
80
80
|
operation_id="embeddings_v1_embeddings_post",
|
|
81
81
|
oauth2_scopes=[],
|
|
82
82
|
security_source=get_security_from_env(
|
|
@@ -88,12 +88,14 @@ class Embeddings(BaseSDK):
|
|
|
88
88
|
retry_config=retry_config,
|
|
89
89
|
)
|
|
90
90
|
|
|
91
|
-
|
|
91
|
+
response_data: Any = None
|
|
92
92
|
if utils.match_response(http_res, "200", "application/json"):
|
|
93
93
|
return utils.unmarshal_json(http_res.text, models.EmbeddingResponse)
|
|
94
94
|
if utils.match_response(http_res, "422", "application/json"):
|
|
95
|
-
|
|
96
|
-
|
|
95
|
+
response_data = utils.unmarshal_json(
|
|
96
|
+
http_res.text, models.HTTPValidationErrorData
|
|
97
|
+
)
|
|
98
|
+
raise models.HTTPValidationError(data=response_data)
|
|
97
99
|
if utils.match_response(http_res, "4XX", "*"):
|
|
98
100
|
http_res_text = utils.stream_to_text(http_res)
|
|
99
101
|
raise models.SDKError(
|
|
@@ -117,9 +119,8 @@ class Embeddings(BaseSDK):
|
|
|
117
119
|
async def create_async(
|
|
118
120
|
self,
|
|
119
121
|
*,
|
|
122
|
+
model: str,
|
|
120
123
|
inputs: Union[models.Inputs, models.InputsTypedDict],
|
|
121
|
-
model: Optional[str] = "mistral-embed",
|
|
122
|
-
encoding_format: OptionalNullable[str] = UNSET,
|
|
123
124
|
retries: OptionalNullable[utils.RetryConfig] = UNSET,
|
|
124
125
|
server_url: Optional[str] = None,
|
|
125
126
|
timeout_ms: Optional[int] = None,
|
|
@@ -129,9 +130,8 @@ class Embeddings(BaseSDK):
|
|
|
129
130
|
|
|
130
131
|
Embeddings
|
|
131
132
|
|
|
132
|
-
:param inputs: Text to embed.
|
|
133
133
|
:param model: ID of the model to use.
|
|
134
|
-
:param
|
|
134
|
+
:param inputs: Text to embed.
|
|
135
135
|
:param retries: Override the default retry configuration for this method
|
|
136
136
|
:param server_url: Override the default server URL for this method
|
|
137
137
|
:param timeout_ms: Override the default request timeout configuration for this method in milliseconds
|
|
@@ -144,11 +144,12 @@ class Embeddings(BaseSDK):
|
|
|
144
144
|
|
|
145
145
|
if server_url is not None:
|
|
146
146
|
base_url = server_url
|
|
147
|
+
else:
|
|
148
|
+
base_url = self._get_url(base_url, url_variables)
|
|
147
149
|
|
|
148
150
|
request = models.EmbeddingRequest(
|
|
149
|
-
inputs=inputs,
|
|
150
151
|
model=model,
|
|
151
|
-
|
|
152
|
+
inputs=inputs,
|
|
152
153
|
)
|
|
153
154
|
|
|
154
155
|
req = self._build_request_async(
|
|
@@ -180,6 +181,7 @@ class Embeddings(BaseSDK):
|
|
|
180
181
|
|
|
181
182
|
http_res = await self.do_request_async(
|
|
182
183
|
hook_ctx=HookContext(
|
|
184
|
+
base_url=base_url or "",
|
|
183
185
|
operation_id="embeddings_v1_embeddings_post",
|
|
184
186
|
oauth2_scopes=[],
|
|
185
187
|
security_source=get_security_from_env(
|
|
@@ -191,12 +193,14 @@ class Embeddings(BaseSDK):
|
|
|
191
193
|
retry_config=retry_config,
|
|
192
194
|
)
|
|
193
195
|
|
|
194
|
-
|
|
196
|
+
response_data: Any = None
|
|
195
197
|
if utils.match_response(http_res, "200", "application/json"):
|
|
196
198
|
return utils.unmarshal_json(http_res.text, models.EmbeddingResponse)
|
|
197
199
|
if utils.match_response(http_res, "422", "application/json"):
|
|
198
|
-
|
|
199
|
-
|
|
200
|
+
response_data = utils.unmarshal_json(
|
|
201
|
+
http_res.text, models.HTTPValidationErrorData
|
|
202
|
+
)
|
|
203
|
+
raise models.HTTPValidationError(data=response_data)
|
|
200
204
|
if utils.match_response(http_res, "4XX", "*"):
|
|
201
205
|
http_res_text = await utils.stream_to_text_async(http_res)
|
|
202
206
|
raise models.SDKError(
|
mistralai/extra/utils/response_format.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from pydantic import BaseModel
|
|
2
|
-
from typing import TypeVar, Any, Type
|
|
2
|
+
from typing import TypeVar, Any, Type, Dict
|
|
3
3
|
from ...models import JSONSchema, ResponseFormat
|
|
4
4
|
from ._pydantic_helper import rec_strict_json_schema
|
|
5
5
|
|
|
@@ -7,7 +7,7 @@ CustomPydanticModel = TypeVar("CustomPydanticModel", bound=BaseModel)
|
|
|
7
7
|
|
|
8
8
|
|
|
9
9
|
def response_format_from_pydantic_model(
|
|
10
|
-
model:
|
|
10
|
+
model: Type[CustomPydanticModel],
|
|
11
11
|
) -> ResponseFormat:
|
|
12
12
|
"""Generate a strict JSON schema from a pydantic model."""
|
|
13
13
|
model_schema = rec_strict_json_schema(model.model_json_schema())
|
|
@@ -18,7 +18,7 @@ def response_format_from_pydantic_model(
|
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
def pydantic_model_from_json(
|
|
21
|
-
json_data:
|
|
21
|
+
json_data: Dict[str, Any], pydantic_model: Type[CustomPydanticModel]
|
|
22
22
|
) -> CustomPydanticModel:
|
|
23
23
|
"""Parse a JSON schema into a pydantic model."""
|
|
24
24
|
return pydantic_model.model_validate(json_data)
|
mistralai/files.py
CHANGED
|
@@ -44,6 +44,8 @@ class Files(BaseSDK):
|
|
|
44
44
|
|
|
45
45
|
if server_url is not None:
|
|
46
46
|
base_url = server_url
|
|
47
|
+
else:
|
|
48
|
+
base_url = self._get_url(base_url, url_variables)
|
|
47
49
|
|
|
48
50
|
request = models.FilesAPIRoutesUploadFileMultiPartBodyParams(
|
|
49
51
|
file=utils.get_pydantic_model(file, models.File),
|
|
@@ -83,6 +85,7 @@ class Files(BaseSDK):
|
|
|
83
85
|
|
|
84
86
|
http_res = self.do_request(
|
|
85
87
|
hook_ctx=HookContext(
|
|
88
|
+
base_url=base_url or "",
|
|
86
89
|
operation_id="files_api_routes_upload_file",
|
|
87
90
|
oauth2_scopes=[],
|
|
88
91
|
security_source=get_security_from_env(
|
|
@@ -148,6 +151,8 @@ class Files(BaseSDK):
|
|
|
148
151
|
|
|
149
152
|
if server_url is not None:
|
|
150
153
|
base_url = server_url
|
|
154
|
+
else:
|
|
155
|
+
base_url = self._get_url(base_url, url_variables)
|
|
151
156
|
|
|
152
157
|
request = models.FilesAPIRoutesUploadFileMultiPartBodyParams(
|
|
153
158
|
file=utils.get_pydantic_model(file, models.File),
|
|
@@ -187,6 +192,7 @@ class Files(BaseSDK):
|
|
|
187
192
|
|
|
188
193
|
http_res = await self.do_request_async(
|
|
189
194
|
hook_ctx=HookContext(
|
|
195
|
+
base_url=base_url or "",
|
|
190
196
|
operation_id="files_api_routes_upload_file",
|
|
191
197
|
oauth2_scopes=[],
|
|
192
198
|
security_source=get_security_from_env(
|
|
@@ -256,6 +262,8 @@ class Files(BaseSDK):
|
|
|
256
262
|
|
|
257
263
|
if server_url is not None:
|
|
258
264
|
base_url = server_url
|
|
265
|
+
else:
|
|
266
|
+
base_url = self._get_url(base_url, url_variables)
|
|
259
267
|
|
|
260
268
|
request = models.FilesAPIRoutesListFilesRequest(
|
|
261
269
|
page=page,
|
|
@@ -292,6 +300,7 @@ class Files(BaseSDK):
|
|
|
292
300
|
|
|
293
301
|
http_res = self.do_request(
|
|
294
302
|
hook_ctx=HookContext(
|
|
303
|
+
base_url=base_url or "",
|
|
295
304
|
operation_id="files_api_routes_list_files",
|
|
296
305
|
oauth2_scopes=[],
|
|
297
306
|
security_source=get_security_from_env(
|
|
@@ -361,6 +370,8 @@ class Files(BaseSDK):
|
|
|
361
370
|
|
|
362
371
|
if server_url is not None:
|
|
363
372
|
base_url = server_url
|
|
373
|
+
else:
|
|
374
|
+
base_url = self._get_url(base_url, url_variables)
|
|
364
375
|
|
|
365
376
|
request = models.FilesAPIRoutesListFilesRequest(
|
|
366
377
|
page=page,
|
|
@@ -397,6 +408,7 @@ class Files(BaseSDK):
|
|
|
397
408
|
|
|
398
409
|
http_res = await self.do_request_async(
|
|
399
410
|
hook_ctx=HookContext(
|
|
411
|
+
base_url=base_url or "",
|
|
400
412
|
operation_id="files_api_routes_list_files",
|
|
401
413
|
oauth2_scopes=[],
|
|
402
414
|
security_source=get_security_from_env(
|
|
@@ -456,6 +468,8 @@ class Files(BaseSDK):
|
|
|
456
468
|
|
|
457
469
|
if server_url is not None:
|
|
458
470
|
base_url = server_url
|
|
471
|
+
else:
|
|
472
|
+
base_url = self._get_url(base_url, url_variables)
|
|
459
473
|
|
|
460
474
|
request = models.FilesAPIRoutesRetrieveFileRequest(
|
|
461
475
|
file_id=file_id,
|
|
@@ -487,6 +501,7 @@ class Files(BaseSDK):
|
|
|
487
501
|
|
|
488
502
|
http_res = self.do_request(
|
|
489
503
|
hook_ctx=HookContext(
|
|
504
|
+
base_url=base_url or "",
|
|
490
505
|
operation_id="files_api_routes_retrieve_file",
|
|
491
506
|
oauth2_scopes=[],
|
|
492
507
|
security_source=get_security_from_env(
|
|
@@ -546,6 +561,8 @@ class Files(BaseSDK):
|
|
|
546
561
|
|
|
547
562
|
if server_url is not None:
|
|
548
563
|
base_url = server_url
|
|
564
|
+
else:
|
|
565
|
+
base_url = self._get_url(base_url, url_variables)
|
|
549
566
|
|
|
550
567
|
request = models.FilesAPIRoutesRetrieveFileRequest(
|
|
551
568
|
file_id=file_id,
|
|
@@ -577,6 +594,7 @@ class Files(BaseSDK):
|
|
|
577
594
|
|
|
578
595
|
http_res = await self.do_request_async(
|
|
579
596
|
hook_ctx=HookContext(
|
|
597
|
+
base_url=base_url or "",
|
|
580
598
|
operation_id="files_api_routes_retrieve_file",
|
|
581
599
|
oauth2_scopes=[],
|
|
582
600
|
security_source=get_security_from_env(
|
|
@@ -636,6 +654,8 @@ class Files(BaseSDK):
|
|
|
636
654
|
|
|
637
655
|
if server_url is not None:
|
|
638
656
|
base_url = server_url
|
|
657
|
+
else:
|
|
658
|
+
base_url = self._get_url(base_url, url_variables)
|
|
639
659
|
|
|
640
660
|
request = models.FilesAPIRoutesDeleteFileRequest(
|
|
641
661
|
file_id=file_id,
|
|
@@ -667,6 +687,7 @@ class Files(BaseSDK):
|
|
|
667
687
|
|
|
668
688
|
http_res = self.do_request(
|
|
669
689
|
hook_ctx=HookContext(
|
|
690
|
+
base_url=base_url or "",
|
|
670
691
|
operation_id="files_api_routes_delete_file",
|
|
671
692
|
oauth2_scopes=[],
|
|
672
693
|
security_source=get_security_from_env(
|
|
@@ -726,6 +747,8 @@ class Files(BaseSDK):
|
|
|
726
747
|
|
|
727
748
|
if server_url is not None:
|
|
728
749
|
base_url = server_url
|
|
750
|
+
else:
|
|
751
|
+
base_url = self._get_url(base_url, url_variables)
|
|
729
752
|
|
|
730
753
|
request = models.FilesAPIRoutesDeleteFileRequest(
|
|
731
754
|
file_id=file_id,
|
|
@@ -757,6 +780,7 @@ class Files(BaseSDK):
|
|
|
757
780
|
|
|
758
781
|
http_res = await self.do_request_async(
|
|
759
782
|
hook_ctx=HookContext(
|
|
783
|
+
base_url=base_url or "",
|
|
760
784
|
operation_id="files_api_routes_delete_file",
|
|
761
785
|
oauth2_scopes=[],
|
|
762
786
|
security_source=get_security_from_env(
|
|
@@ -816,6 +840,8 @@ class Files(BaseSDK):
|
|
|
816
840
|
|
|
817
841
|
if server_url is not None:
|
|
818
842
|
base_url = server_url
|
|
843
|
+
else:
|
|
844
|
+
base_url = self._get_url(base_url, url_variables)
|
|
819
845
|
|
|
820
846
|
request = models.FilesAPIRoutesDownloadFileRequest(
|
|
821
847
|
file_id=file_id,
|
|
@@ -847,6 +873,7 @@ class Files(BaseSDK):
|
|
|
847
873
|
|
|
848
874
|
http_res = self.do_request(
|
|
849
875
|
hook_ctx=HookContext(
|
|
876
|
+
base_url=base_url or "",
|
|
850
877
|
operation_id="files_api_routes_download_file",
|
|
851
878
|
oauth2_scopes=[],
|
|
852
879
|
security_source=get_security_from_env(
|
|
@@ -907,6 +934,8 @@ class Files(BaseSDK):
|
|
|
907
934
|
|
|
908
935
|
if server_url is not None:
|
|
909
936
|
base_url = server_url
|
|
937
|
+
else:
|
|
938
|
+
base_url = self._get_url(base_url, url_variables)
|
|
910
939
|
|
|
911
940
|
request = models.FilesAPIRoutesDownloadFileRequest(
|
|
912
941
|
file_id=file_id,
|
|
@@ -938,6 +967,7 @@ class Files(BaseSDK):
|
|
|
938
967
|
|
|
939
968
|
http_res = await self.do_request_async(
|
|
940
969
|
hook_ctx=HookContext(
|
|
970
|
+
base_url=base_url or "",
|
|
941
971
|
operation_id="files_api_routes_download_file",
|
|
942
972
|
oauth2_scopes=[],
|
|
943
973
|
security_source=get_security_from_env(
|
|
@@ -998,6 +1028,8 @@ class Files(BaseSDK):
|
|
|
998
1028
|
|
|
999
1029
|
if server_url is not None:
|
|
1000
1030
|
base_url = server_url
|
|
1031
|
+
else:
|
|
1032
|
+
base_url = self._get_url(base_url, url_variables)
|
|
1001
1033
|
|
|
1002
1034
|
request = models.FilesAPIRoutesGetSignedURLRequest(
|
|
1003
1035
|
file_id=file_id,
|
|
@@ -1030,6 +1062,7 @@ class Files(BaseSDK):
|
|
|
1030
1062
|
|
|
1031
1063
|
http_res = self.do_request(
|
|
1032
1064
|
hook_ctx=HookContext(
|
|
1065
|
+
base_url=base_url or "",
|
|
1033
1066
|
operation_id="files_api_routes_get_signed_url",
|
|
1034
1067
|
oauth2_scopes=[],
|
|
1035
1068
|
security_source=get_security_from_env(
|
|
@@ -1089,6 +1122,8 @@ class Files(BaseSDK):
|
|
|
1089
1122
|
|
|
1090
1123
|
if server_url is not None:
|
|
1091
1124
|
base_url = server_url
|
|
1125
|
+
else:
|
|
1126
|
+
base_url = self._get_url(base_url, url_variables)
|
|
1092
1127
|
|
|
1093
1128
|
request = models.FilesAPIRoutesGetSignedURLRequest(
|
|
1094
1129
|
file_id=file_id,
|
|
@@ -1121,6 +1156,7 @@ class Files(BaseSDK):
|
|
|
1121
1156
|
|
|
1122
1157
|
http_res = await self.do_request_async(
|
|
1123
1158
|
hook_ctx=HookContext(
|
|
1159
|
+
base_url=base_url or "",
|
|
1124
1160
|
operation_id="files_api_routes_get_signed_url",
|
|
1125
1161
|
oauth2_scopes=[],
|
|
1126
1162
|
security_source=get_security_from_env(
|