mistralai-1.5.2-py3-none-any.whl → mistralai-1.6.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. mistralai/_hooks/types.py +15 -3
  2. mistralai/_version.py +3 -3
  3. mistralai/agents.py +44 -12
  4. mistralai/basesdk.py +8 -0
  5. mistralai/chat.py +44 -12
  6. mistralai/classifiers.py +36 -16
  7. mistralai/embeddings.py +16 -6
  8. mistralai/files.py +36 -0
  9. mistralai/fim.py +32 -12
  10. mistralai/httpclient.py +4 -2
  11. mistralai/jobs.py +30 -0
  12. mistralai/mistral_jobs.py +24 -0
  13. mistralai/models/agentscompletionrequest.py +4 -0
  14. mistralai/models/agentscompletionstreamrequest.py +4 -0
  15. mistralai/models/chatcompletionrequest.py +4 -0
  16. mistralai/models/chatcompletionstreamrequest.py +4 -0
  17. mistralai/models/function.py +2 -2
  18. mistralai/models/jsonschema.py +1 -1
  19. mistralai/models_.py +66 -18
  20. mistralai/ocr.py +16 -6
  21. mistralai/sdk.py +19 -3
  22. mistralai/sdkconfiguration.py +4 -2
  23. mistralai/utils/__init__.py +2 -0
  24. mistralai/utils/serializers.py +10 -6
  25. mistralai/utils/values.py +4 -1
  26. {mistralai-1.5.2.dist-info → mistralai-1.6.0.dist-info}/METADATA +63 -16
  27. {mistralai-1.5.2.dist-info → mistralai-1.6.0.dist-info}/RECORD +80 -72
  28. mistralai_azure/__init__.py +10 -1
  29. mistralai_azure/_hooks/types.py +15 -3
  30. mistralai_azure/_version.py +4 -1
  31. mistralai_azure/basesdk.py +8 -0
  32. mistralai_azure/chat.py +100 -20
  33. mistralai_azure/httpclient.py +52 -0
  34. mistralai_azure/models/__init__.py +22 -0
  35. mistralai_azure/models/assistantmessage.py +2 -0
  36. mistralai_azure/models/chatcompletionrequest.py +12 -10
  37. mistralai_azure/models/chatcompletionstreamrequest.py +12 -10
  38. mistralai_azure/models/contentchunk.py +6 -2
  39. mistralai_azure/models/function.py +4 -1
  40. mistralai_azure/models/imageurl.py +53 -0
  41. mistralai_azure/models/imageurlchunk.py +33 -0
  42. mistralai_azure/models/jsonschema.py +61 -0
  43. mistralai_azure/models/prediction.py +25 -0
  44. mistralai_azure/models/responseformat.py +42 -1
  45. mistralai_azure/models/responseformats.py +1 -1
  46. mistralai_azure/models/toolcall.py +3 -0
  47. mistralai_azure/sdk.py +56 -14
  48. mistralai_azure/sdkconfiguration.py +14 -6
  49. mistralai_azure/utils/__init__.py +2 -0
  50. mistralai_azure/utils/serializers.py +10 -6
  51. mistralai_azure/utils/values.py +4 -1
  52. mistralai_gcp/__init__.py +10 -1
  53. mistralai_gcp/_hooks/types.py +15 -3
  54. mistralai_gcp/_version.py +4 -1
  55. mistralai_gcp/basesdk.py +8 -0
  56. mistralai_gcp/chat.py +101 -21
  57. mistralai_gcp/fim.py +61 -21
  58. mistralai_gcp/httpclient.py +52 -0
  59. mistralai_gcp/models/__init__.py +22 -0
  60. mistralai_gcp/models/assistantmessage.py +2 -0
  61. mistralai_gcp/models/chatcompletionrequest.py +12 -10
  62. mistralai_gcp/models/chatcompletionstreamrequest.py +12 -10
  63. mistralai_gcp/models/contentchunk.py +6 -2
  64. mistralai_gcp/models/fimcompletionrequest.py +2 -3
  65. mistralai_gcp/models/fimcompletionstreamrequest.py +2 -3
  66. mistralai_gcp/models/function.py +4 -1
  67. mistralai_gcp/models/imageurl.py +53 -0
  68. mistralai_gcp/models/imageurlchunk.py +33 -0
  69. mistralai_gcp/models/jsonschema.py +61 -0
  70. mistralai_gcp/models/prediction.py +25 -0
  71. mistralai_gcp/models/responseformat.py +42 -1
  72. mistralai_gcp/models/responseformats.py +1 -1
  73. mistralai_gcp/models/toolcall.py +3 -0
  74. mistralai_gcp/sdk.py +63 -19
  75. mistralai_gcp/sdkconfiguration.py +14 -6
  76. mistralai_gcp/utils/__init__.py +2 -0
  77. mistralai_gcp/utils/serializers.py +10 -6
  78. mistralai_gcp/utils/values.py +4 -1
  79. {mistralai-1.5.2.dist-info → mistralai-1.6.0.dist-info}/LICENSE +0 -0
  80. {mistralai-1.5.2.dist-info → mistralai-1.6.0.dist-info}/WHEEL +0 -0
mistralai/models_.py CHANGED
@@ -35,6 +35,8 @@ class Models(BaseSDK):
35
35
 
36
36
  if server_url is not None:
37
37
  base_url = server_url
38
+ else:
39
+ base_url = self._get_url(base_url, url_variables)
38
40
  req = self._build_request(
39
41
  method="GET",
40
42
  path="/v1/models",
@@ -61,6 +63,7 @@ class Models(BaseSDK):
61
63
 
62
64
  http_res = self.do_request(
63
65
  hook_ctx=HookContext(
66
+ base_url=base_url or "",
64
67
  operation_id="list_models_v1_models_get",
65
68
  oauth2_scopes=[],
66
69
  security_source=get_security_from_env(
@@ -72,12 +75,14 @@ class Models(BaseSDK):
72
75
  retry_config=retry_config,
73
76
  )
74
77
 
75
- data: Any = None
78
+ response_data: Any = None
76
79
  if utils.match_response(http_res, "200", "application/json"):
77
80
  return utils.unmarshal_json(http_res.text, models.ModelList)
78
81
  if utils.match_response(http_res, "422", "application/json"):
79
- data = utils.unmarshal_json(http_res.text, models.HTTPValidationErrorData)
80
- raise models.HTTPValidationError(data=data)
82
+ response_data = utils.unmarshal_json(
83
+ http_res.text, models.HTTPValidationErrorData
84
+ )
85
+ raise models.HTTPValidationError(data=response_data)
81
86
  if utils.match_response(http_res, "4XX", "*"):
82
87
  http_res_text = utils.stream_to_text(http_res)
83
88
  raise models.SDKError(
@@ -122,6 +127,8 @@ class Models(BaseSDK):
122
127
 
123
128
  if server_url is not None:
124
129
  base_url = server_url
130
+ else:
131
+ base_url = self._get_url(base_url, url_variables)
125
132
  req = self._build_request_async(
126
133
  method="GET",
127
134
  path="/v1/models",
@@ -148,6 +155,7 @@ class Models(BaseSDK):
148
155
 
149
156
  http_res = await self.do_request_async(
150
157
  hook_ctx=HookContext(
158
+ base_url=base_url or "",
151
159
  operation_id="list_models_v1_models_get",
152
160
  oauth2_scopes=[],
153
161
  security_source=get_security_from_env(
@@ -159,12 +167,14 @@ class Models(BaseSDK):
159
167
  retry_config=retry_config,
160
168
  )
161
169
 
162
- data: Any = None
170
+ response_data: Any = None
163
171
  if utils.match_response(http_res, "200", "application/json"):
164
172
  return utils.unmarshal_json(http_res.text, models.ModelList)
165
173
  if utils.match_response(http_res, "422", "application/json"):
166
- data = utils.unmarshal_json(http_res.text, models.HTTPValidationErrorData)
167
- raise models.HTTPValidationError(data=data)
174
+ response_data = utils.unmarshal_json(
175
+ http_res.text, models.HTTPValidationErrorData
176
+ )
177
+ raise models.HTTPValidationError(data=response_data)
168
178
  if utils.match_response(http_res, "4XX", "*"):
169
179
  http_res_text = await utils.stream_to_text_async(http_res)
170
180
  raise models.SDKError(
@@ -211,6 +221,8 @@ class Models(BaseSDK):
211
221
 
212
222
  if server_url is not None:
213
223
  base_url = server_url
224
+ else:
225
+ base_url = self._get_url(base_url, url_variables)
214
226
 
215
227
  request = models.RetrieveModelV1ModelsModelIDGetRequest(
216
228
  model_id=model_id,
@@ -242,6 +254,7 @@ class Models(BaseSDK):
242
254
 
243
255
  http_res = self.do_request(
244
256
  hook_ctx=HookContext(
257
+ base_url=base_url or "",
245
258
  operation_id="retrieve_model_v1_models__model_id__get",
246
259
  oauth2_scopes=[],
247
260
  security_source=get_security_from_env(
@@ -253,15 +266,17 @@ class Models(BaseSDK):
253
266
  retry_config=retry_config,
254
267
  )
255
268
 
256
- data: Any = None
269
+ response_data: Any = None
257
270
  if utils.match_response(http_res, "200", "application/json"):
258
271
  return utils.unmarshal_json(
259
272
  http_res.text,
260
273
  models.RetrieveModelV1ModelsModelIDGetResponseRetrieveModelV1ModelsModelIDGet,
261
274
  )
262
275
  if utils.match_response(http_res, "422", "application/json"):
263
- data = utils.unmarshal_json(http_res.text, models.HTTPValidationErrorData)
264
- raise models.HTTPValidationError(data=data)
276
+ response_data = utils.unmarshal_json(
277
+ http_res.text, models.HTTPValidationErrorData
278
+ )
279
+ raise models.HTTPValidationError(data=response_data)
265
280
  if utils.match_response(http_res, "4XX", "*"):
266
281
  http_res_text = utils.stream_to_text(http_res)
267
282
  raise models.SDKError(
@@ -308,6 +323,8 @@ class Models(BaseSDK):
308
323
 
309
324
  if server_url is not None:
310
325
  base_url = server_url
326
+ else:
327
+ base_url = self._get_url(base_url, url_variables)
311
328
 
312
329
  request = models.RetrieveModelV1ModelsModelIDGetRequest(
313
330
  model_id=model_id,
@@ -339,6 +356,7 @@ class Models(BaseSDK):
339
356
 
340
357
  http_res = await self.do_request_async(
341
358
  hook_ctx=HookContext(
359
+ base_url=base_url or "",
342
360
  operation_id="retrieve_model_v1_models__model_id__get",
343
361
  oauth2_scopes=[],
344
362
  security_source=get_security_from_env(
@@ -350,15 +368,17 @@ class Models(BaseSDK):
350
368
  retry_config=retry_config,
351
369
  )
352
370
 
353
- data: Any = None
371
+ response_data: Any = None
354
372
  if utils.match_response(http_res, "200", "application/json"):
355
373
  return utils.unmarshal_json(
356
374
  http_res.text,
357
375
  models.RetrieveModelV1ModelsModelIDGetResponseRetrieveModelV1ModelsModelIDGet,
358
376
  )
359
377
  if utils.match_response(http_res, "422", "application/json"):
360
- data = utils.unmarshal_json(http_res.text, models.HTTPValidationErrorData)
361
- raise models.HTTPValidationError(data=data)
378
+ response_data = utils.unmarshal_json(
379
+ http_res.text, models.HTTPValidationErrorData
380
+ )
381
+ raise models.HTTPValidationError(data=response_data)
362
382
  if utils.match_response(http_res, "4XX", "*"):
363
383
  http_res_text = await utils.stream_to_text_async(http_res)
364
384
  raise models.SDKError(
@@ -405,6 +425,8 @@ class Models(BaseSDK):
405
425
 
406
426
  if server_url is not None:
407
427
  base_url = server_url
428
+ else:
429
+ base_url = self._get_url(base_url, url_variables)
408
430
 
409
431
  request = models.DeleteModelV1ModelsModelIDDeleteRequest(
410
432
  model_id=model_id,
@@ -436,6 +458,7 @@ class Models(BaseSDK):
436
458
 
437
459
  http_res = self.do_request(
438
460
  hook_ctx=HookContext(
461
+ base_url=base_url or "",
439
462
  operation_id="delete_model_v1_models__model_id__delete",
440
463
  oauth2_scopes=[],
441
464
  security_source=get_security_from_env(
@@ -447,12 +470,14 @@ class Models(BaseSDK):
447
470
  retry_config=retry_config,
448
471
  )
449
472
 
450
- data: Any = None
473
+ response_data: Any = None
451
474
  if utils.match_response(http_res, "200", "application/json"):
452
475
  return utils.unmarshal_json(http_res.text, models.DeleteModelOut)
453
476
  if utils.match_response(http_res, "422", "application/json"):
454
- data = utils.unmarshal_json(http_res.text, models.HTTPValidationErrorData)
455
- raise models.HTTPValidationError(data=data)
477
+ response_data = utils.unmarshal_json(
478
+ http_res.text, models.HTTPValidationErrorData
479
+ )
480
+ raise models.HTTPValidationError(data=response_data)
456
481
  if utils.match_response(http_res, "4XX", "*"):
457
482
  http_res_text = utils.stream_to_text(http_res)
458
483
  raise models.SDKError(
@@ -499,6 +524,8 @@ class Models(BaseSDK):
499
524
 
500
525
  if server_url is not None:
501
526
  base_url = server_url
527
+ else:
528
+ base_url = self._get_url(base_url, url_variables)
502
529
 
503
530
  request = models.DeleteModelV1ModelsModelIDDeleteRequest(
504
531
  model_id=model_id,
@@ -530,6 +557,7 @@ class Models(BaseSDK):
530
557
 
531
558
  http_res = await self.do_request_async(
532
559
  hook_ctx=HookContext(
560
+ base_url=base_url or "",
533
561
  operation_id="delete_model_v1_models__model_id__delete",
534
562
  oauth2_scopes=[],
535
563
  security_source=get_security_from_env(
@@ -541,12 +569,14 @@ class Models(BaseSDK):
541
569
  retry_config=retry_config,
542
570
  )
543
571
 
544
- data: Any = None
572
+ response_data: Any = None
545
573
  if utils.match_response(http_res, "200", "application/json"):
546
574
  return utils.unmarshal_json(http_res.text, models.DeleteModelOut)
547
575
  if utils.match_response(http_res, "422", "application/json"):
548
- data = utils.unmarshal_json(http_res.text, models.HTTPValidationErrorData)
549
- raise models.HTTPValidationError(data=data)
576
+ response_data = utils.unmarshal_json(
577
+ http_res.text, models.HTTPValidationErrorData
578
+ )
579
+ raise models.HTTPValidationError(data=response_data)
550
580
  if utils.match_response(http_res, "4XX", "*"):
551
581
  http_res_text = await utils.stream_to_text_async(http_res)
552
582
  raise models.SDKError(
@@ -597,6 +627,8 @@ class Models(BaseSDK):
597
627
 
598
628
  if server_url is not None:
599
629
  base_url = server_url
630
+ else:
631
+ base_url = self._get_url(base_url, url_variables)
600
632
 
601
633
  request = models.JobsAPIRoutesFineTuningUpdateFineTunedModelRequest(
602
634
  model_id=model_id,
@@ -635,6 +667,7 @@ class Models(BaseSDK):
635
667
 
636
668
  http_res = self.do_request(
637
669
  hook_ctx=HookContext(
670
+ base_url=base_url or "",
638
671
  operation_id="jobs_api_routes_fine_tuning_update_fine_tuned_model",
639
672
  oauth2_scopes=[],
640
673
  security_source=get_security_from_env(
@@ -698,6 +731,8 @@ class Models(BaseSDK):
698
731
 
699
732
  if server_url is not None:
700
733
  base_url = server_url
734
+ else:
735
+ base_url = self._get_url(base_url, url_variables)
701
736
 
702
737
  request = models.JobsAPIRoutesFineTuningUpdateFineTunedModelRequest(
703
738
  model_id=model_id,
@@ -736,6 +771,7 @@ class Models(BaseSDK):
736
771
 
737
772
  http_res = await self.do_request_async(
738
773
  hook_ctx=HookContext(
774
+ base_url=base_url or "",
739
775
  operation_id="jobs_api_routes_fine_tuning_update_fine_tuned_model",
740
776
  oauth2_scopes=[],
741
777
  security_source=get_security_from_env(
@@ -795,6 +831,8 @@ class Models(BaseSDK):
795
831
 
796
832
  if server_url is not None:
797
833
  base_url = server_url
834
+ else:
835
+ base_url = self._get_url(base_url, url_variables)
798
836
 
799
837
  request = models.JobsAPIRoutesFineTuningArchiveFineTunedModelRequest(
800
838
  model_id=model_id,
@@ -826,6 +864,7 @@ class Models(BaseSDK):
826
864
 
827
865
  http_res = self.do_request(
828
866
  hook_ctx=HookContext(
867
+ base_url=base_url or "",
829
868
  operation_id="jobs_api_routes_fine_tuning_archive_fine_tuned_model",
830
869
  oauth2_scopes=[],
831
870
  security_source=get_security_from_env(
@@ -885,6 +924,8 @@ class Models(BaseSDK):
885
924
 
886
925
  if server_url is not None:
887
926
  base_url = server_url
927
+ else:
928
+ base_url = self._get_url(base_url, url_variables)
888
929
 
889
930
  request = models.JobsAPIRoutesFineTuningArchiveFineTunedModelRequest(
890
931
  model_id=model_id,
@@ -916,6 +957,7 @@ class Models(BaseSDK):
916
957
 
917
958
  http_res = await self.do_request_async(
918
959
  hook_ctx=HookContext(
960
+ base_url=base_url or "",
919
961
  operation_id="jobs_api_routes_fine_tuning_archive_fine_tuned_model",
920
962
  oauth2_scopes=[],
921
963
  security_source=get_security_from_env(
@@ -975,6 +1017,8 @@ class Models(BaseSDK):
975
1017
 
976
1018
  if server_url is not None:
977
1019
  base_url = server_url
1020
+ else:
1021
+ base_url = self._get_url(base_url, url_variables)
978
1022
 
979
1023
  request = models.JobsAPIRoutesFineTuningUnarchiveFineTunedModelRequest(
980
1024
  model_id=model_id,
@@ -1006,6 +1050,7 @@ class Models(BaseSDK):
1006
1050
 
1007
1051
  http_res = self.do_request(
1008
1052
  hook_ctx=HookContext(
1053
+ base_url=base_url or "",
1009
1054
  operation_id="jobs_api_routes_fine_tuning_unarchive_fine_tuned_model",
1010
1055
  oauth2_scopes=[],
1011
1056
  security_source=get_security_from_env(
@@ -1065,6 +1110,8 @@ class Models(BaseSDK):
1065
1110
 
1066
1111
  if server_url is not None:
1067
1112
  base_url = server_url
1113
+ else:
1114
+ base_url = self._get_url(base_url, url_variables)
1068
1115
 
1069
1116
  request = models.JobsAPIRoutesFineTuningUnarchiveFineTunedModelRequest(
1070
1117
  model_id=model_id,
@@ -1096,6 +1143,7 @@ class Models(BaseSDK):
1096
1143
 
1097
1144
  http_res = await self.do_request_async(
1098
1145
  hook_ctx=HookContext(
1146
+ base_url=base_url or "",
1099
1147
  operation_id="jobs_api_routes_fine_tuning_unarchive_fine_tuned_model",
1100
1148
  oauth2_scopes=[],
1101
1149
  security_source=get_security_from_env(
mistralai/ocr.py CHANGED
@@ -47,6 +47,8 @@ class Ocr(BaseSDK):
47
47
 
48
48
  if server_url is not None:
49
49
  base_url = server_url
50
+ else:
51
+ base_url = self._get_url(base_url, url_variables)
50
52
 
51
53
  request = models.OCRRequest(
52
54
  model=model,
@@ -87,6 +89,7 @@ class Ocr(BaseSDK):
87
89
 
88
90
  http_res = self.do_request(
89
91
  hook_ctx=HookContext(
92
+ base_url=base_url or "",
90
93
  operation_id="ocr_v1_ocr_post",
91
94
  oauth2_scopes=[],
92
95
  security_source=get_security_from_env(
@@ -98,12 +101,14 @@ class Ocr(BaseSDK):
98
101
  retry_config=retry_config,
99
102
  )
100
103
 
101
- data: Any = None
104
+ response_data: Any = None
102
105
  if utils.match_response(http_res, "200", "application/json"):
103
106
  return utils.unmarshal_json(http_res.text, models.OCRResponse)
104
107
  if utils.match_response(http_res, "422", "application/json"):
105
- data = utils.unmarshal_json(http_res.text, models.HTTPValidationErrorData)
106
- raise models.HTTPValidationError(data=data)
108
+ response_data = utils.unmarshal_json(
109
+ http_res.text, models.HTTPValidationErrorData
110
+ )
111
+ raise models.HTTPValidationError(data=response_data)
107
112
  if utils.match_response(http_res, "4XX", "*"):
108
113
  http_res_text = utils.stream_to_text(http_res)
109
114
  raise models.SDKError(
@@ -160,6 +165,8 @@ class Ocr(BaseSDK):
160
165
 
161
166
  if server_url is not None:
162
167
  base_url = server_url
168
+ else:
169
+ base_url = self._get_url(base_url, url_variables)
163
170
 
164
171
  request = models.OCRRequest(
165
172
  model=model,
@@ -200,6 +207,7 @@ class Ocr(BaseSDK):
200
207
 
201
208
  http_res = await self.do_request_async(
202
209
  hook_ctx=HookContext(
210
+ base_url=base_url or "",
203
211
  operation_id="ocr_v1_ocr_post",
204
212
  oauth2_scopes=[],
205
213
  security_source=get_security_from_env(
@@ -211,12 +219,14 @@ class Ocr(BaseSDK):
211
219
  retry_config=retry_config,
212
220
  )
213
221
 
214
- data: Any = None
222
+ response_data: Any = None
215
223
  if utils.match_response(http_res, "200", "application/json"):
216
224
  return utils.unmarshal_json(http_res.text, models.OCRResponse)
217
225
  if utils.match_response(http_res, "422", "application/json"):
218
- data = utils.unmarshal_json(http_res.text, models.HTTPValidationErrorData)
219
- raise models.HTTPValidationError(data=data)
226
+ response_data = utils.unmarshal_json(
227
+ http_res.text, models.HTTPValidationErrorData
228
+ )
229
+ raise models.HTTPValidationError(data=response_data)
220
230
  if utils.match_response(http_res, "4XX", "*"):
221
231
  http_res_text = await utils.stream_to_text_async(http_res)
222
232
  raise models.SDKError(
mistralai/sdk.py CHANGED
@@ -68,15 +68,19 @@ class Mistral(BaseSDK):
68
68
  :param retry_config: The retry configuration to use for all supported methods
69
69
  :param timeout_ms: Optional request timeout applied to each operation in milliseconds
70
70
  """
71
+ client_supplied = True
71
72
  if client is None:
72
73
  client = httpx.Client()
74
+ client_supplied = False
73
75
 
74
76
  assert issubclass(
75
77
  type(client), HttpClient
76
78
  ), "The provided client must implement the HttpClient protocol."
77
79
 
80
+ async_client_supplied = True
78
81
  if async_client is None:
79
82
  async_client = httpx.AsyncClient()
83
+ async_client_supplied = False
80
84
 
81
85
  if debug_logger is None:
82
86
  debug_logger = get_default_logger()
@@ -100,7 +104,9 @@ class Mistral(BaseSDK):
100
104
  self,
101
105
  SDKConfiguration(
102
106
  client=client,
107
+ client_supplied=client_supplied,
103
108
  async_client=async_client,
109
+ async_client_supplied=async_client_supplied,
104
110
  security=security,
105
111
  server_url=server_url,
106
112
  server=server,
@@ -114,7 +120,7 @@ class Mistral(BaseSDK):
114
120
 
115
121
  current_server_url, *_ = self.sdk_configuration.get_server_details()
116
122
  server_url, self.sdk_configuration.client = hooks.sdk_init(
117
- current_server_url, self.sdk_configuration.client
123
+ current_server_url, client
118
124
  )
119
125
  if current_server_url != server_url:
120
126
  self.sdk_configuration.server_url = server_url
@@ -127,7 +133,9 @@ class Mistral(BaseSDK):
127
133
  close_clients,
128
134
  cast(ClientOwner, self.sdk_configuration),
129
135
  self.sdk_configuration.client,
136
+ self.sdk_configuration.client_supplied,
130
137
  self.sdk_configuration.async_client,
138
+ self.sdk_configuration.async_client_supplied,
131
139
  )
132
140
 
133
141
  self._init_sdks()
@@ -151,9 +159,17 @@ class Mistral(BaseSDK):
151
159
  return self
152
160
 
153
161
  def __exit__(self, exc_type, exc_val, exc_tb):
154
- if self.sdk_configuration.client is not None:
162
+ if (
163
+ self.sdk_configuration.client is not None
164
+ and not self.sdk_configuration.client_supplied
165
+ ):
155
166
  self.sdk_configuration.client.close()
167
+ self.sdk_configuration.client = None
156
168
 
157
169
  async def __aexit__(self, exc_type, exc_val, exc_tb):
158
- if self.sdk_configuration.async_client is not None:
170
+ if (
171
+ self.sdk_configuration.async_client is not None
172
+ and not self.sdk_configuration.async_client_supplied
173
+ ):
159
174
  await self.sdk_configuration.async_client.aclose()
175
+ self.sdk_configuration.async_client = None
@@ -26,8 +26,10 @@ SERVERS = {
26
26
 
27
27
  @dataclass
28
28
  class SDKConfiguration:
29
- client: HttpClient
30
- async_client: AsyncHttpClient
29
+ client: Union[HttpClient, None]
30
+ client_supplied: bool
31
+ async_client: Union[AsyncHttpClient, None]
32
+ async_client_supplied: bool
31
33
  debug_logger: Logger
32
34
  security: Optional[Union[models.Security, Callable[[], models.Security]]] = None
33
35
  server_url: Optional[str] = ""
@@ -43,6 +43,7 @@ from .values import (
43
43
  match_content_type,
44
44
  match_status_codes,
45
45
  match_response,
46
+ cast_partial,
46
47
  )
47
48
  from .logger import Logger, get_body_content, get_default_logger
48
49
 
@@ -96,4 +97,5 @@ __all__ = [
96
97
  "validate_float",
97
98
  "validate_int",
98
99
  "validate_open_enum",
100
+ "cast_partial",
99
101
  ]
@@ -7,14 +7,15 @@ import httpx
7
7
  from typing_extensions import get_origin
8
8
  from pydantic import ConfigDict, create_model
9
9
  from pydantic_core import from_json
10
- from typing_inspect import is_optional_type
10
+ from typing_inspection.typing_objects import is_union
11
11
 
12
12
  from ..types.basemodel import BaseModel, Nullable, OptionalNullable, Unset
13
13
 
14
14
 
15
15
  def serialize_decimal(as_str: bool):
16
16
  def serialize(d):
17
- if is_optional_type(type(d)) and d is None:
17
+ # Optional[T] is a Union[T, None]
18
+ if is_union(type(d)) and type(None) in get_args(type(d)) and d is None:
18
19
  return None
19
20
  if isinstance(d, Unset):
20
21
  return d
@@ -42,7 +43,8 @@ def validate_decimal(d):
42
43
 
43
44
  def serialize_float(as_str: bool):
44
45
  def serialize(f):
45
- if is_optional_type(type(f)) and f is None:
46
+ # Optional[T] is a Union[T, None]
47
+ if is_union(type(f)) and type(None) in get_args(type(f)) and f is None:
46
48
  return None
47
49
  if isinstance(f, Unset):
48
50
  return f
@@ -70,7 +72,8 @@ def validate_float(f):
70
72
 
71
73
  def serialize_int(as_str: bool):
72
74
  def serialize(i):
73
- if is_optional_type(type(i)) and i is None:
75
+ # Optional[T] is a Union[T, None]
76
+ if is_union(type(i)) and type(None) in get_args(type(i)) and i is None:
74
77
  return None
75
78
  if isinstance(i, Unset):
76
79
  return i
@@ -118,7 +121,8 @@ def validate_open_enum(is_int: bool):
118
121
 
119
122
  def validate_const(v):
120
123
  def validate(c):
121
- if is_optional_type(type(c)) and c is None:
124
+ # Optional[T] is a Union[T, None]
125
+ if is_union(type(c)) and type(None) in get_args(type(c)) and c is None:
122
126
  return None
123
127
 
124
128
  if v != c:
@@ -163,7 +167,7 @@ def marshal_json(val, typ):
163
167
  if len(d) == 0:
164
168
  return ""
165
169
 
166
- return json.dumps(d[next(iter(d))], separators=(",", ":"), sort_keys=True)
170
+ return json.dumps(d[next(iter(d))], separators=(",", ":"))
167
171
 
168
172
 
169
173
  def is_nullable(field):
mistralai/utils/values.py CHANGED
@@ -3,8 +3,9 @@
3
3
  from datetime import datetime
4
4
  from enum import Enum
5
5
  from email.message import Message
6
+ from functools import partial
6
7
  import os
7
- from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
8
+ from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union, cast
8
9
 
9
10
  from httpx import Response
10
11
  from pydantic import BaseModel
@@ -51,6 +52,8 @@ def match_status_codes(status_codes: List[str], status_code: int) -> bool:
51
52
 
52
53
  T = TypeVar("T")
53
54
 
55
+ def cast_partial(typ):
56
+ return partial(cast, typ)
54
57
 
55
58
  def get_global_from_env(
56
59
  value: Optional[T], env_key: str, type_cast: Callable[[str], T]