mistralai 1.5.2__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mistralai/_hooks/types.py +15 -3
- mistralai/_version.py +3 -3
- mistralai/agents.py +44 -12
- mistralai/basesdk.py +8 -0
- mistralai/chat.py +44 -12
- mistralai/classifiers.py +36 -16
- mistralai/embeddings.py +16 -6
- mistralai/files.py +36 -0
- mistralai/fim.py +32 -12
- mistralai/httpclient.py +4 -2
- mistralai/jobs.py +30 -0
- mistralai/mistral_jobs.py +24 -0
- mistralai/models/agentscompletionrequest.py +4 -0
- mistralai/models/agentscompletionstreamrequest.py +4 -0
- mistralai/models/chatcompletionrequest.py +4 -0
- mistralai/models/chatcompletionstreamrequest.py +4 -0
- mistralai/models/function.py +2 -2
- mistralai/models/jsonschema.py +1 -1
- mistralai/models_.py +66 -18
- mistralai/ocr.py +16 -6
- mistralai/sdk.py +19 -3
- mistralai/sdkconfiguration.py +4 -2
- mistralai/utils/__init__.py +2 -0
- mistralai/utils/serializers.py +10 -6
- mistralai/utils/values.py +4 -1
- {mistralai-1.5.2.dist-info → mistralai-1.6.0.dist-info}/METADATA +63 -16
- {mistralai-1.5.2.dist-info → mistralai-1.6.0.dist-info}/RECORD +80 -72
- mistralai_azure/__init__.py +10 -1
- mistralai_azure/_hooks/types.py +15 -3
- mistralai_azure/_version.py +4 -1
- mistralai_azure/basesdk.py +8 -0
- mistralai_azure/chat.py +100 -20
- mistralai_azure/httpclient.py +52 -0
- mistralai_azure/models/__init__.py +22 -0
- mistralai_azure/models/assistantmessage.py +2 -0
- mistralai_azure/models/chatcompletionrequest.py +12 -10
- mistralai_azure/models/chatcompletionstreamrequest.py +12 -10
- mistralai_azure/models/contentchunk.py +6 -2
- mistralai_azure/models/function.py +4 -1
- mistralai_azure/models/imageurl.py +53 -0
- mistralai_azure/models/imageurlchunk.py +33 -0
- mistralai_azure/models/jsonschema.py +61 -0
- mistralai_azure/models/prediction.py +25 -0
- mistralai_azure/models/responseformat.py +42 -1
- mistralai_azure/models/responseformats.py +1 -1
- mistralai_azure/models/toolcall.py +3 -0
- mistralai_azure/sdk.py +56 -14
- mistralai_azure/sdkconfiguration.py +14 -6
- mistralai_azure/utils/__init__.py +2 -0
- mistralai_azure/utils/serializers.py +10 -6
- mistralai_azure/utils/values.py +4 -1
- mistralai_gcp/__init__.py +10 -1
- mistralai_gcp/_hooks/types.py +15 -3
- mistralai_gcp/_version.py +4 -1
- mistralai_gcp/basesdk.py +8 -0
- mistralai_gcp/chat.py +101 -21
- mistralai_gcp/fim.py +61 -21
- mistralai_gcp/httpclient.py +52 -0
- mistralai_gcp/models/__init__.py +22 -0
- mistralai_gcp/models/assistantmessage.py +2 -0
- mistralai_gcp/models/chatcompletionrequest.py +12 -10
- mistralai_gcp/models/chatcompletionstreamrequest.py +12 -10
- mistralai_gcp/models/contentchunk.py +6 -2
- mistralai_gcp/models/fimcompletionrequest.py +2 -3
- mistralai_gcp/models/fimcompletionstreamrequest.py +2 -3
- mistralai_gcp/models/function.py +4 -1
- mistralai_gcp/models/imageurl.py +53 -0
- mistralai_gcp/models/imageurlchunk.py +33 -0
- mistralai_gcp/models/jsonschema.py +61 -0
- mistralai_gcp/models/prediction.py +25 -0
- mistralai_gcp/models/responseformat.py +42 -1
- mistralai_gcp/models/responseformats.py +1 -1
- mistralai_gcp/models/toolcall.py +3 -0
- mistralai_gcp/sdk.py +63 -19
- mistralai_gcp/sdkconfiguration.py +14 -6
- mistralai_gcp/utils/__init__.py +2 -0
- mistralai_gcp/utils/serializers.py +10 -6
- mistralai_gcp/utils/values.py +4 -1
- {mistralai-1.5.2.dist-info → mistralai-1.6.0.dist-info}/LICENSE +0 -0
- {mistralai-1.5.2.dist-info → mistralai-1.6.0.dist-info}/WHEEL +0 -0
mistralai/models_.py
CHANGED
|
@@ -35,6 +35,8 @@ class Models(BaseSDK):
|
|
|
35
35
|
|
|
36
36
|
if server_url is not None:
|
|
37
37
|
base_url = server_url
|
|
38
|
+
else:
|
|
39
|
+
base_url = self._get_url(base_url, url_variables)
|
|
38
40
|
req = self._build_request(
|
|
39
41
|
method="GET",
|
|
40
42
|
path="/v1/models",
|
|
@@ -61,6 +63,7 @@ class Models(BaseSDK):
|
|
|
61
63
|
|
|
62
64
|
http_res = self.do_request(
|
|
63
65
|
hook_ctx=HookContext(
|
|
66
|
+
base_url=base_url or "",
|
|
64
67
|
operation_id="list_models_v1_models_get",
|
|
65
68
|
oauth2_scopes=[],
|
|
66
69
|
security_source=get_security_from_env(
|
|
@@ -72,12 +75,14 @@ class Models(BaseSDK):
|
|
|
72
75
|
retry_config=retry_config,
|
|
73
76
|
)
|
|
74
77
|
|
|
75
|
-
|
|
78
|
+
response_data: Any = None
|
|
76
79
|
if utils.match_response(http_res, "200", "application/json"):
|
|
77
80
|
return utils.unmarshal_json(http_res.text, models.ModelList)
|
|
78
81
|
if utils.match_response(http_res, "422", "application/json"):
|
|
79
|
-
|
|
80
|
-
|
|
82
|
+
response_data = utils.unmarshal_json(
|
|
83
|
+
http_res.text, models.HTTPValidationErrorData
|
|
84
|
+
)
|
|
85
|
+
raise models.HTTPValidationError(data=response_data)
|
|
81
86
|
if utils.match_response(http_res, "4XX", "*"):
|
|
82
87
|
http_res_text = utils.stream_to_text(http_res)
|
|
83
88
|
raise models.SDKError(
|
|
@@ -122,6 +127,8 @@ class Models(BaseSDK):
|
|
|
122
127
|
|
|
123
128
|
if server_url is not None:
|
|
124
129
|
base_url = server_url
|
|
130
|
+
else:
|
|
131
|
+
base_url = self._get_url(base_url, url_variables)
|
|
125
132
|
req = self._build_request_async(
|
|
126
133
|
method="GET",
|
|
127
134
|
path="/v1/models",
|
|
@@ -148,6 +155,7 @@ class Models(BaseSDK):
|
|
|
148
155
|
|
|
149
156
|
http_res = await self.do_request_async(
|
|
150
157
|
hook_ctx=HookContext(
|
|
158
|
+
base_url=base_url or "",
|
|
151
159
|
operation_id="list_models_v1_models_get",
|
|
152
160
|
oauth2_scopes=[],
|
|
153
161
|
security_source=get_security_from_env(
|
|
@@ -159,12 +167,14 @@ class Models(BaseSDK):
|
|
|
159
167
|
retry_config=retry_config,
|
|
160
168
|
)
|
|
161
169
|
|
|
162
|
-
|
|
170
|
+
response_data: Any = None
|
|
163
171
|
if utils.match_response(http_res, "200", "application/json"):
|
|
164
172
|
return utils.unmarshal_json(http_res.text, models.ModelList)
|
|
165
173
|
if utils.match_response(http_res, "422", "application/json"):
|
|
166
|
-
|
|
167
|
-
|
|
174
|
+
response_data = utils.unmarshal_json(
|
|
175
|
+
http_res.text, models.HTTPValidationErrorData
|
|
176
|
+
)
|
|
177
|
+
raise models.HTTPValidationError(data=response_data)
|
|
168
178
|
if utils.match_response(http_res, "4XX", "*"):
|
|
169
179
|
http_res_text = await utils.stream_to_text_async(http_res)
|
|
170
180
|
raise models.SDKError(
|
|
@@ -211,6 +221,8 @@ class Models(BaseSDK):
|
|
|
211
221
|
|
|
212
222
|
if server_url is not None:
|
|
213
223
|
base_url = server_url
|
|
224
|
+
else:
|
|
225
|
+
base_url = self._get_url(base_url, url_variables)
|
|
214
226
|
|
|
215
227
|
request = models.RetrieveModelV1ModelsModelIDGetRequest(
|
|
216
228
|
model_id=model_id,
|
|
@@ -242,6 +254,7 @@ class Models(BaseSDK):
|
|
|
242
254
|
|
|
243
255
|
http_res = self.do_request(
|
|
244
256
|
hook_ctx=HookContext(
|
|
257
|
+
base_url=base_url or "",
|
|
245
258
|
operation_id="retrieve_model_v1_models__model_id__get",
|
|
246
259
|
oauth2_scopes=[],
|
|
247
260
|
security_source=get_security_from_env(
|
|
@@ -253,15 +266,17 @@ class Models(BaseSDK):
|
|
|
253
266
|
retry_config=retry_config,
|
|
254
267
|
)
|
|
255
268
|
|
|
256
|
-
|
|
269
|
+
response_data: Any = None
|
|
257
270
|
if utils.match_response(http_res, "200", "application/json"):
|
|
258
271
|
return utils.unmarshal_json(
|
|
259
272
|
http_res.text,
|
|
260
273
|
models.RetrieveModelV1ModelsModelIDGetResponseRetrieveModelV1ModelsModelIDGet,
|
|
261
274
|
)
|
|
262
275
|
if utils.match_response(http_res, "422", "application/json"):
|
|
263
|
-
|
|
264
|
-
|
|
276
|
+
response_data = utils.unmarshal_json(
|
|
277
|
+
http_res.text, models.HTTPValidationErrorData
|
|
278
|
+
)
|
|
279
|
+
raise models.HTTPValidationError(data=response_data)
|
|
265
280
|
if utils.match_response(http_res, "4XX", "*"):
|
|
266
281
|
http_res_text = utils.stream_to_text(http_res)
|
|
267
282
|
raise models.SDKError(
|
|
@@ -308,6 +323,8 @@ class Models(BaseSDK):
|
|
|
308
323
|
|
|
309
324
|
if server_url is not None:
|
|
310
325
|
base_url = server_url
|
|
326
|
+
else:
|
|
327
|
+
base_url = self._get_url(base_url, url_variables)
|
|
311
328
|
|
|
312
329
|
request = models.RetrieveModelV1ModelsModelIDGetRequest(
|
|
313
330
|
model_id=model_id,
|
|
@@ -339,6 +356,7 @@ class Models(BaseSDK):
|
|
|
339
356
|
|
|
340
357
|
http_res = await self.do_request_async(
|
|
341
358
|
hook_ctx=HookContext(
|
|
359
|
+
base_url=base_url or "",
|
|
342
360
|
operation_id="retrieve_model_v1_models__model_id__get",
|
|
343
361
|
oauth2_scopes=[],
|
|
344
362
|
security_source=get_security_from_env(
|
|
@@ -350,15 +368,17 @@ class Models(BaseSDK):
|
|
|
350
368
|
retry_config=retry_config,
|
|
351
369
|
)
|
|
352
370
|
|
|
353
|
-
|
|
371
|
+
response_data: Any = None
|
|
354
372
|
if utils.match_response(http_res, "200", "application/json"):
|
|
355
373
|
return utils.unmarshal_json(
|
|
356
374
|
http_res.text,
|
|
357
375
|
models.RetrieveModelV1ModelsModelIDGetResponseRetrieveModelV1ModelsModelIDGet,
|
|
358
376
|
)
|
|
359
377
|
if utils.match_response(http_res, "422", "application/json"):
|
|
360
|
-
|
|
361
|
-
|
|
378
|
+
response_data = utils.unmarshal_json(
|
|
379
|
+
http_res.text, models.HTTPValidationErrorData
|
|
380
|
+
)
|
|
381
|
+
raise models.HTTPValidationError(data=response_data)
|
|
362
382
|
if utils.match_response(http_res, "4XX", "*"):
|
|
363
383
|
http_res_text = await utils.stream_to_text_async(http_res)
|
|
364
384
|
raise models.SDKError(
|
|
@@ -405,6 +425,8 @@ class Models(BaseSDK):
|
|
|
405
425
|
|
|
406
426
|
if server_url is not None:
|
|
407
427
|
base_url = server_url
|
|
428
|
+
else:
|
|
429
|
+
base_url = self._get_url(base_url, url_variables)
|
|
408
430
|
|
|
409
431
|
request = models.DeleteModelV1ModelsModelIDDeleteRequest(
|
|
410
432
|
model_id=model_id,
|
|
@@ -436,6 +458,7 @@ class Models(BaseSDK):
|
|
|
436
458
|
|
|
437
459
|
http_res = self.do_request(
|
|
438
460
|
hook_ctx=HookContext(
|
|
461
|
+
base_url=base_url or "",
|
|
439
462
|
operation_id="delete_model_v1_models__model_id__delete",
|
|
440
463
|
oauth2_scopes=[],
|
|
441
464
|
security_source=get_security_from_env(
|
|
@@ -447,12 +470,14 @@ class Models(BaseSDK):
|
|
|
447
470
|
retry_config=retry_config,
|
|
448
471
|
)
|
|
449
472
|
|
|
450
|
-
|
|
473
|
+
response_data: Any = None
|
|
451
474
|
if utils.match_response(http_res, "200", "application/json"):
|
|
452
475
|
return utils.unmarshal_json(http_res.text, models.DeleteModelOut)
|
|
453
476
|
if utils.match_response(http_res, "422", "application/json"):
|
|
454
|
-
|
|
455
|
-
|
|
477
|
+
response_data = utils.unmarshal_json(
|
|
478
|
+
http_res.text, models.HTTPValidationErrorData
|
|
479
|
+
)
|
|
480
|
+
raise models.HTTPValidationError(data=response_data)
|
|
456
481
|
if utils.match_response(http_res, "4XX", "*"):
|
|
457
482
|
http_res_text = utils.stream_to_text(http_res)
|
|
458
483
|
raise models.SDKError(
|
|
@@ -499,6 +524,8 @@ class Models(BaseSDK):
|
|
|
499
524
|
|
|
500
525
|
if server_url is not None:
|
|
501
526
|
base_url = server_url
|
|
527
|
+
else:
|
|
528
|
+
base_url = self._get_url(base_url, url_variables)
|
|
502
529
|
|
|
503
530
|
request = models.DeleteModelV1ModelsModelIDDeleteRequest(
|
|
504
531
|
model_id=model_id,
|
|
@@ -530,6 +557,7 @@ class Models(BaseSDK):
|
|
|
530
557
|
|
|
531
558
|
http_res = await self.do_request_async(
|
|
532
559
|
hook_ctx=HookContext(
|
|
560
|
+
base_url=base_url or "",
|
|
533
561
|
operation_id="delete_model_v1_models__model_id__delete",
|
|
534
562
|
oauth2_scopes=[],
|
|
535
563
|
security_source=get_security_from_env(
|
|
@@ -541,12 +569,14 @@ class Models(BaseSDK):
|
|
|
541
569
|
retry_config=retry_config,
|
|
542
570
|
)
|
|
543
571
|
|
|
544
|
-
|
|
572
|
+
response_data: Any = None
|
|
545
573
|
if utils.match_response(http_res, "200", "application/json"):
|
|
546
574
|
return utils.unmarshal_json(http_res.text, models.DeleteModelOut)
|
|
547
575
|
if utils.match_response(http_res, "422", "application/json"):
|
|
548
|
-
|
|
549
|
-
|
|
576
|
+
response_data = utils.unmarshal_json(
|
|
577
|
+
http_res.text, models.HTTPValidationErrorData
|
|
578
|
+
)
|
|
579
|
+
raise models.HTTPValidationError(data=response_data)
|
|
550
580
|
if utils.match_response(http_res, "4XX", "*"):
|
|
551
581
|
http_res_text = await utils.stream_to_text_async(http_res)
|
|
552
582
|
raise models.SDKError(
|
|
@@ -597,6 +627,8 @@ class Models(BaseSDK):
|
|
|
597
627
|
|
|
598
628
|
if server_url is not None:
|
|
599
629
|
base_url = server_url
|
|
630
|
+
else:
|
|
631
|
+
base_url = self._get_url(base_url, url_variables)
|
|
600
632
|
|
|
601
633
|
request = models.JobsAPIRoutesFineTuningUpdateFineTunedModelRequest(
|
|
602
634
|
model_id=model_id,
|
|
@@ -635,6 +667,7 @@ class Models(BaseSDK):
|
|
|
635
667
|
|
|
636
668
|
http_res = self.do_request(
|
|
637
669
|
hook_ctx=HookContext(
|
|
670
|
+
base_url=base_url or "",
|
|
638
671
|
operation_id="jobs_api_routes_fine_tuning_update_fine_tuned_model",
|
|
639
672
|
oauth2_scopes=[],
|
|
640
673
|
security_source=get_security_from_env(
|
|
@@ -698,6 +731,8 @@ class Models(BaseSDK):
|
|
|
698
731
|
|
|
699
732
|
if server_url is not None:
|
|
700
733
|
base_url = server_url
|
|
734
|
+
else:
|
|
735
|
+
base_url = self._get_url(base_url, url_variables)
|
|
701
736
|
|
|
702
737
|
request = models.JobsAPIRoutesFineTuningUpdateFineTunedModelRequest(
|
|
703
738
|
model_id=model_id,
|
|
@@ -736,6 +771,7 @@ class Models(BaseSDK):
|
|
|
736
771
|
|
|
737
772
|
http_res = await self.do_request_async(
|
|
738
773
|
hook_ctx=HookContext(
|
|
774
|
+
base_url=base_url or "",
|
|
739
775
|
operation_id="jobs_api_routes_fine_tuning_update_fine_tuned_model",
|
|
740
776
|
oauth2_scopes=[],
|
|
741
777
|
security_source=get_security_from_env(
|
|
@@ -795,6 +831,8 @@ class Models(BaseSDK):
|
|
|
795
831
|
|
|
796
832
|
if server_url is not None:
|
|
797
833
|
base_url = server_url
|
|
834
|
+
else:
|
|
835
|
+
base_url = self._get_url(base_url, url_variables)
|
|
798
836
|
|
|
799
837
|
request = models.JobsAPIRoutesFineTuningArchiveFineTunedModelRequest(
|
|
800
838
|
model_id=model_id,
|
|
@@ -826,6 +864,7 @@ class Models(BaseSDK):
|
|
|
826
864
|
|
|
827
865
|
http_res = self.do_request(
|
|
828
866
|
hook_ctx=HookContext(
|
|
867
|
+
base_url=base_url or "",
|
|
829
868
|
operation_id="jobs_api_routes_fine_tuning_archive_fine_tuned_model",
|
|
830
869
|
oauth2_scopes=[],
|
|
831
870
|
security_source=get_security_from_env(
|
|
@@ -885,6 +924,8 @@ class Models(BaseSDK):
|
|
|
885
924
|
|
|
886
925
|
if server_url is not None:
|
|
887
926
|
base_url = server_url
|
|
927
|
+
else:
|
|
928
|
+
base_url = self._get_url(base_url, url_variables)
|
|
888
929
|
|
|
889
930
|
request = models.JobsAPIRoutesFineTuningArchiveFineTunedModelRequest(
|
|
890
931
|
model_id=model_id,
|
|
@@ -916,6 +957,7 @@ class Models(BaseSDK):
|
|
|
916
957
|
|
|
917
958
|
http_res = await self.do_request_async(
|
|
918
959
|
hook_ctx=HookContext(
|
|
960
|
+
base_url=base_url or "",
|
|
919
961
|
operation_id="jobs_api_routes_fine_tuning_archive_fine_tuned_model",
|
|
920
962
|
oauth2_scopes=[],
|
|
921
963
|
security_source=get_security_from_env(
|
|
@@ -975,6 +1017,8 @@ class Models(BaseSDK):
|
|
|
975
1017
|
|
|
976
1018
|
if server_url is not None:
|
|
977
1019
|
base_url = server_url
|
|
1020
|
+
else:
|
|
1021
|
+
base_url = self._get_url(base_url, url_variables)
|
|
978
1022
|
|
|
979
1023
|
request = models.JobsAPIRoutesFineTuningUnarchiveFineTunedModelRequest(
|
|
980
1024
|
model_id=model_id,
|
|
@@ -1006,6 +1050,7 @@ class Models(BaseSDK):
|
|
|
1006
1050
|
|
|
1007
1051
|
http_res = self.do_request(
|
|
1008
1052
|
hook_ctx=HookContext(
|
|
1053
|
+
base_url=base_url or "",
|
|
1009
1054
|
operation_id="jobs_api_routes_fine_tuning_unarchive_fine_tuned_model",
|
|
1010
1055
|
oauth2_scopes=[],
|
|
1011
1056
|
security_source=get_security_from_env(
|
|
@@ -1065,6 +1110,8 @@ class Models(BaseSDK):
|
|
|
1065
1110
|
|
|
1066
1111
|
if server_url is not None:
|
|
1067
1112
|
base_url = server_url
|
|
1113
|
+
else:
|
|
1114
|
+
base_url = self._get_url(base_url, url_variables)
|
|
1068
1115
|
|
|
1069
1116
|
request = models.JobsAPIRoutesFineTuningUnarchiveFineTunedModelRequest(
|
|
1070
1117
|
model_id=model_id,
|
|
@@ -1096,6 +1143,7 @@ class Models(BaseSDK):
|
|
|
1096
1143
|
|
|
1097
1144
|
http_res = await self.do_request_async(
|
|
1098
1145
|
hook_ctx=HookContext(
|
|
1146
|
+
base_url=base_url or "",
|
|
1099
1147
|
operation_id="jobs_api_routes_fine_tuning_unarchive_fine_tuned_model",
|
|
1100
1148
|
oauth2_scopes=[],
|
|
1101
1149
|
security_source=get_security_from_env(
|
mistralai/ocr.py
CHANGED
|
@@ -47,6 +47,8 @@ class Ocr(BaseSDK):
|
|
|
47
47
|
|
|
48
48
|
if server_url is not None:
|
|
49
49
|
base_url = server_url
|
|
50
|
+
else:
|
|
51
|
+
base_url = self._get_url(base_url, url_variables)
|
|
50
52
|
|
|
51
53
|
request = models.OCRRequest(
|
|
52
54
|
model=model,
|
|
@@ -87,6 +89,7 @@ class Ocr(BaseSDK):
|
|
|
87
89
|
|
|
88
90
|
http_res = self.do_request(
|
|
89
91
|
hook_ctx=HookContext(
|
|
92
|
+
base_url=base_url or "",
|
|
90
93
|
operation_id="ocr_v1_ocr_post",
|
|
91
94
|
oauth2_scopes=[],
|
|
92
95
|
security_source=get_security_from_env(
|
|
@@ -98,12 +101,14 @@ class Ocr(BaseSDK):
|
|
|
98
101
|
retry_config=retry_config,
|
|
99
102
|
)
|
|
100
103
|
|
|
101
|
-
|
|
104
|
+
response_data: Any = None
|
|
102
105
|
if utils.match_response(http_res, "200", "application/json"):
|
|
103
106
|
return utils.unmarshal_json(http_res.text, models.OCRResponse)
|
|
104
107
|
if utils.match_response(http_res, "422", "application/json"):
|
|
105
|
-
|
|
106
|
-
|
|
108
|
+
response_data = utils.unmarshal_json(
|
|
109
|
+
http_res.text, models.HTTPValidationErrorData
|
|
110
|
+
)
|
|
111
|
+
raise models.HTTPValidationError(data=response_data)
|
|
107
112
|
if utils.match_response(http_res, "4XX", "*"):
|
|
108
113
|
http_res_text = utils.stream_to_text(http_res)
|
|
109
114
|
raise models.SDKError(
|
|
@@ -160,6 +165,8 @@ class Ocr(BaseSDK):
|
|
|
160
165
|
|
|
161
166
|
if server_url is not None:
|
|
162
167
|
base_url = server_url
|
|
168
|
+
else:
|
|
169
|
+
base_url = self._get_url(base_url, url_variables)
|
|
163
170
|
|
|
164
171
|
request = models.OCRRequest(
|
|
165
172
|
model=model,
|
|
@@ -200,6 +207,7 @@ class Ocr(BaseSDK):
|
|
|
200
207
|
|
|
201
208
|
http_res = await self.do_request_async(
|
|
202
209
|
hook_ctx=HookContext(
|
|
210
|
+
base_url=base_url or "",
|
|
203
211
|
operation_id="ocr_v1_ocr_post",
|
|
204
212
|
oauth2_scopes=[],
|
|
205
213
|
security_source=get_security_from_env(
|
|
@@ -211,12 +219,14 @@ class Ocr(BaseSDK):
|
|
|
211
219
|
retry_config=retry_config,
|
|
212
220
|
)
|
|
213
221
|
|
|
214
|
-
|
|
222
|
+
response_data: Any = None
|
|
215
223
|
if utils.match_response(http_res, "200", "application/json"):
|
|
216
224
|
return utils.unmarshal_json(http_res.text, models.OCRResponse)
|
|
217
225
|
if utils.match_response(http_res, "422", "application/json"):
|
|
218
|
-
|
|
219
|
-
|
|
226
|
+
response_data = utils.unmarshal_json(
|
|
227
|
+
http_res.text, models.HTTPValidationErrorData
|
|
228
|
+
)
|
|
229
|
+
raise models.HTTPValidationError(data=response_data)
|
|
220
230
|
if utils.match_response(http_res, "4XX", "*"):
|
|
221
231
|
http_res_text = await utils.stream_to_text_async(http_res)
|
|
222
232
|
raise models.SDKError(
|
mistralai/sdk.py
CHANGED
|
@@ -68,15 +68,19 @@ class Mistral(BaseSDK):
|
|
|
68
68
|
:param retry_config: The retry configuration to use for all supported methods
|
|
69
69
|
:param timeout_ms: Optional request timeout applied to each operation in milliseconds
|
|
70
70
|
"""
|
|
71
|
+
client_supplied = True
|
|
71
72
|
if client is None:
|
|
72
73
|
client = httpx.Client()
|
|
74
|
+
client_supplied = False
|
|
73
75
|
|
|
74
76
|
assert issubclass(
|
|
75
77
|
type(client), HttpClient
|
|
76
78
|
), "The provided client must implement the HttpClient protocol."
|
|
77
79
|
|
|
80
|
+
async_client_supplied = True
|
|
78
81
|
if async_client is None:
|
|
79
82
|
async_client = httpx.AsyncClient()
|
|
83
|
+
async_client_supplied = False
|
|
80
84
|
|
|
81
85
|
if debug_logger is None:
|
|
82
86
|
debug_logger = get_default_logger()
|
|
@@ -100,7 +104,9 @@ class Mistral(BaseSDK):
|
|
|
100
104
|
self,
|
|
101
105
|
SDKConfiguration(
|
|
102
106
|
client=client,
|
|
107
|
+
client_supplied=client_supplied,
|
|
103
108
|
async_client=async_client,
|
|
109
|
+
async_client_supplied=async_client_supplied,
|
|
104
110
|
security=security,
|
|
105
111
|
server_url=server_url,
|
|
106
112
|
server=server,
|
|
@@ -114,7 +120,7 @@ class Mistral(BaseSDK):
|
|
|
114
120
|
|
|
115
121
|
current_server_url, *_ = self.sdk_configuration.get_server_details()
|
|
116
122
|
server_url, self.sdk_configuration.client = hooks.sdk_init(
|
|
117
|
-
current_server_url,
|
|
123
|
+
current_server_url, client
|
|
118
124
|
)
|
|
119
125
|
if current_server_url != server_url:
|
|
120
126
|
self.sdk_configuration.server_url = server_url
|
|
@@ -127,7 +133,9 @@ class Mistral(BaseSDK):
|
|
|
127
133
|
close_clients,
|
|
128
134
|
cast(ClientOwner, self.sdk_configuration),
|
|
129
135
|
self.sdk_configuration.client,
|
|
136
|
+
self.sdk_configuration.client_supplied,
|
|
130
137
|
self.sdk_configuration.async_client,
|
|
138
|
+
self.sdk_configuration.async_client_supplied,
|
|
131
139
|
)
|
|
132
140
|
|
|
133
141
|
self._init_sdks()
|
|
@@ -151,9 +159,17 @@ class Mistral(BaseSDK):
|
|
|
151
159
|
return self
|
|
152
160
|
|
|
153
161
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
154
|
-
if
|
|
162
|
+
if (
|
|
163
|
+
self.sdk_configuration.client is not None
|
|
164
|
+
and not self.sdk_configuration.client_supplied
|
|
165
|
+
):
|
|
155
166
|
self.sdk_configuration.client.close()
|
|
167
|
+
self.sdk_configuration.client = None
|
|
156
168
|
|
|
157
169
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
|
158
|
-
if
|
|
170
|
+
if (
|
|
171
|
+
self.sdk_configuration.async_client is not None
|
|
172
|
+
and not self.sdk_configuration.async_client_supplied
|
|
173
|
+
):
|
|
159
174
|
await self.sdk_configuration.async_client.aclose()
|
|
175
|
+
self.sdk_configuration.async_client = None
|
mistralai/sdkconfiguration.py
CHANGED
|
@@ -26,8 +26,10 @@ SERVERS = {
|
|
|
26
26
|
|
|
27
27
|
@dataclass
|
|
28
28
|
class SDKConfiguration:
|
|
29
|
-
client: HttpClient
|
|
30
|
-
|
|
29
|
+
client: Union[HttpClient, None]
|
|
30
|
+
client_supplied: bool
|
|
31
|
+
async_client: Union[AsyncHttpClient, None]
|
|
32
|
+
async_client_supplied: bool
|
|
31
33
|
debug_logger: Logger
|
|
32
34
|
security: Optional[Union[models.Security, Callable[[], models.Security]]] = None
|
|
33
35
|
server_url: Optional[str] = ""
|
mistralai/utils/__init__.py
CHANGED
|
@@ -43,6 +43,7 @@ from .values import (
|
|
|
43
43
|
match_content_type,
|
|
44
44
|
match_status_codes,
|
|
45
45
|
match_response,
|
|
46
|
+
cast_partial,
|
|
46
47
|
)
|
|
47
48
|
from .logger import Logger, get_body_content, get_default_logger
|
|
48
49
|
|
|
@@ -96,4 +97,5 @@ __all__ = [
|
|
|
96
97
|
"validate_float",
|
|
97
98
|
"validate_int",
|
|
98
99
|
"validate_open_enum",
|
|
100
|
+
"cast_partial",
|
|
99
101
|
]
|
mistralai/utils/serializers.py
CHANGED
|
@@ -7,14 +7,15 @@ import httpx
|
|
|
7
7
|
from typing_extensions import get_origin
|
|
8
8
|
from pydantic import ConfigDict, create_model
|
|
9
9
|
from pydantic_core import from_json
|
|
10
|
-
from
|
|
10
|
+
from typing_inspection.typing_objects import is_union
|
|
11
11
|
|
|
12
12
|
from ..types.basemodel import BaseModel, Nullable, OptionalNullable, Unset
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
def serialize_decimal(as_str: bool):
|
|
16
16
|
def serialize(d):
|
|
17
|
-
|
|
17
|
+
# Optional[T] is a Union[T, None]
|
|
18
|
+
if is_union(type(d)) and type(None) in get_args(type(d)) and d is None:
|
|
18
19
|
return None
|
|
19
20
|
if isinstance(d, Unset):
|
|
20
21
|
return d
|
|
@@ -42,7 +43,8 @@ def validate_decimal(d):
|
|
|
42
43
|
|
|
43
44
|
def serialize_float(as_str: bool):
|
|
44
45
|
def serialize(f):
|
|
45
|
-
|
|
46
|
+
# Optional[T] is a Union[T, None]
|
|
47
|
+
if is_union(type(f)) and type(None) in get_args(type(f)) and f is None:
|
|
46
48
|
return None
|
|
47
49
|
if isinstance(f, Unset):
|
|
48
50
|
return f
|
|
@@ -70,7 +72,8 @@ def validate_float(f):
|
|
|
70
72
|
|
|
71
73
|
def serialize_int(as_str: bool):
|
|
72
74
|
def serialize(i):
|
|
73
|
-
|
|
75
|
+
# Optional[T] is a Union[T, None]
|
|
76
|
+
if is_union(type(i)) and type(None) in get_args(type(i)) and i is None:
|
|
74
77
|
return None
|
|
75
78
|
if isinstance(i, Unset):
|
|
76
79
|
return i
|
|
@@ -118,7 +121,8 @@ def validate_open_enum(is_int: bool):
|
|
|
118
121
|
|
|
119
122
|
def validate_const(v):
|
|
120
123
|
def validate(c):
|
|
121
|
-
|
|
124
|
+
# Optional[T] is a Union[T, None]
|
|
125
|
+
if is_union(type(c)) and type(None) in get_args(type(c)) and c is None:
|
|
122
126
|
return None
|
|
123
127
|
|
|
124
128
|
if v != c:
|
|
@@ -163,7 +167,7 @@ def marshal_json(val, typ):
|
|
|
163
167
|
if len(d) == 0:
|
|
164
168
|
return ""
|
|
165
169
|
|
|
166
|
-
return json.dumps(d[next(iter(d))], separators=(",", ":")
|
|
170
|
+
return json.dumps(d[next(iter(d))], separators=(",", ":"))
|
|
167
171
|
|
|
168
172
|
|
|
169
173
|
def is_nullable(field):
|
mistralai/utils/values.py
CHANGED
|
@@ -3,8 +3,9 @@
|
|
|
3
3
|
from datetime import datetime
|
|
4
4
|
from enum import Enum
|
|
5
5
|
from email.message import Message
|
|
6
|
+
from functools import partial
|
|
6
7
|
import os
|
|
7
|
-
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
|
|
8
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union, cast
|
|
8
9
|
|
|
9
10
|
from httpx import Response
|
|
10
11
|
from pydantic import BaseModel
|
|
@@ -51,6 +52,8 @@ def match_status_codes(status_codes: List[str], status_code: int) -> bool:
|
|
|
51
52
|
|
|
52
53
|
T = TypeVar("T")
|
|
53
54
|
|
|
55
|
+
def cast_partial(typ):
|
|
56
|
+
return partial(cast, typ)
|
|
54
57
|
|
|
55
58
|
def get_global_from_env(
|
|
56
59
|
value: Optional[T], env_key: str, type_cast: Callable[[str], T]
|