mistralai 1.3.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. mistralai/__init__.py +10 -1
  2. mistralai/_version.py +4 -1
  3. mistralai/agents.py +58 -14
  4. mistralai/chat.py +140 -14
  5. mistralai/classifiers.py +32 -20
  6. mistralai/embeddings.py +16 -10
  7. mistralai/extra/README.md +56 -0
  8. mistralai/extra/__init__.py +5 -0
  9. mistralai/extra/struct_chat.py +41 -0
  10. mistralai/extra/tests/__init__.py +0 -0
  11. mistralai/extra/tests/test_struct_chat.py +103 -0
  12. mistralai/extra/tests/test_utils.py +162 -0
  13. mistralai/extra/utils/__init__.py +3 -0
  14. mistralai/extra/utils/_pydantic_helper.py +20 -0
  15. mistralai/extra/utils/response_format.py +24 -0
  16. mistralai/files.py +94 -34
  17. mistralai/fim.py +30 -14
  18. mistralai/httpclient.py +50 -0
  19. mistralai/jobs.py +80 -32
  20. mistralai/mistral_jobs.py +64 -24
  21. mistralai/models/__init__.py +8 -0
  22. mistralai/models/agentscompletionrequest.py +5 -0
  23. mistralai/models/agentscompletionstreamrequest.py +5 -0
  24. mistralai/models/chatcompletionrequest.py +5 -0
  25. mistralai/models/chatcompletionstreamrequest.py +5 -0
  26. mistralai/models/fileschema.py +3 -2
  27. mistralai/models/function.py +3 -0
  28. mistralai/models/jsonschema.py +55 -0
  29. mistralai/models/prediction.py +26 -0
  30. mistralai/models/responseformat.py +36 -1
  31. mistralai/models/responseformats.py +1 -1
  32. mistralai/models/retrievefileout.py +3 -2
  33. mistralai/models/toolcall.py +3 -0
  34. mistralai/models/uploadfileout.py +3 -2
  35. mistralai/models_.py +92 -48
  36. mistralai/sdk.py +13 -3
  37. mistralai/sdkconfiguration.py +10 -4
  38. {mistralai-1.3.1.dist-info → mistralai-1.5.0.dist-info}/METADATA +41 -42
  39. {mistralai-1.3.1.dist-info → mistralai-1.5.0.dist-info}/RECORD +43 -33
  40. {mistralai-1.3.1.dist-info → mistralai-1.5.0.dist-info}/WHEEL +1 -1
  41. mistralai_azure/_hooks/custom_user_agent.py +1 -1
  42. mistralai_gcp/sdk.py +1 -2
  43. py.typed +0 -1
  44. {mistralai-1.3.1.dist-info → mistralai-1.5.0.dist-info}/LICENSE +0 -0
mistralai/__init__.py CHANGED
@@ -1,9 +1,18 @@
 """Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT."""
 
-from ._version import __title__, __version__
+from ._version import (
+    __title__,
+    __version__,
+    __openapi_doc_version__,
+    __gen_version__,
+    __user_agent__,
+)
 from .sdk import *
 from .sdkconfiguration import *
 from .models import *
 
 
 VERSION: str = __version__
+OPENAPI_DOC_VERSION = __openapi_doc_version__
+SPEAKEASY_GENERATOR_VERSION = __gen_version__
+USER_AGENT = __user_agent__
mistralai/_version.py CHANGED
@@ -3,7 +3,10 @@
 import importlib.metadata
 
 __title__: str = "mistralai"
-__version__: str = "1.3.1"
+__version__: str = "1.5.0"
+__openapi_doc_version__: str = "0.0.2"
+__gen_version__: str = "2.497.0"
+__user_agent__: str = "speakeasy-sdk/python 1.5.0 2.497.0 0.0.2 mistralai"
 
 try:
     if __package__ is not None:
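
Taken together, the two hunks above mean the 1.5.0 wheel exposes its generator metadata at package level. A minimal sketch of what should now be importable (constant names and values are the literals shown in the diff):

    import mistralai

    print(mistralai.VERSION)                      # "1.5.0"
    print(mistralai.OPENAPI_DOC_VERSION)          # "0.0.2"
    print(mistralai.SPEAKEASY_GENERATOR_VERSION)  # "2.497.0"
    print(mistralai.USER_AGENT)  # "speakeasy-sdk/python 1.5.0 2.497.0 0.0.2 mistralai"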
mistralai/agents.py CHANGED
@@ -43,11 +43,14 @@ class Agents(BaseSDK):
         presence_penalty: Optional[float] = None,
         frequency_penalty: Optional[float] = None,
         n: OptionalNullable[int] = UNSET,
+        prediction: Optional[
+            Union[models.Prediction, models.PredictionTypedDict]
+        ] = None,
         retries: OptionalNullable[utils.RetryConfig] = UNSET,
         server_url: Optional[str] = None,
         timeout_ms: Optional[int] = None,
         http_headers: Optional[Mapping[str, str]] = None,
-    ) -> Optional[models.ChatCompletionResponse]:
+    ) -> models.ChatCompletionResponse:
         r"""Agents Completion
 
         :param messages: The prompt(s) to generate completions for, encoded as a list of dict with role and content.
@@ -62,6 +65,7 @@ class Agents(BaseSDK):
         :param presence_penalty: presence_penalty determines how much the model penalizes the repetition of words or phrases. A higher presence penalty encourages the model to use a wider variety of words and phrases, making the output more diverse and creative.
         :param frequency_penalty: frequency_penalty penalizes the repetition of words based on their frequency in the generated text. A higher frequency penalty discourages the model from repeating words that have already appeared frequently in the output, promoting diversity and reducing repetition.
         :param n: Number of completions to return for each request, input tokens are only billed once.
+        :param prediction:
         :param retries: Override the default retry configuration for this method
         :param server_url: Override the default server URL for this method
         :param timeout_ms: Override the default request timeout configuration for this method in milliseconds
@@ -93,6 +97,9 @@ class Agents(BaseSDK):
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
             n=n,
+            prediction=utils.get_pydantic_model(
+                prediction, Optional[models.Prediction]
+            ),
             agent_id=agent_id,
         )
 
@@ -138,13 +145,16 @@ class Agents(BaseSDK):
 
         data: Any = None
         if utils.match_response(http_res, "200", "application/json"):
-            return utils.unmarshal_json(
-                http_res.text, Optional[models.ChatCompletionResponse]
-            )
+            return utils.unmarshal_json(http_res.text, models.ChatCompletionResponse)
         if utils.match_response(http_res, "422", "application/json"):
             data = utils.unmarshal_json(http_res.text, models.HTTPValidationErrorData)
             raise models.HTTPValidationError(data=data)
-        if utils.match_response(http_res, ["4XX", "5XX"], "*"):
+        if utils.match_response(http_res, "4XX", "*"):
+            http_res_text = utils.stream_to_text(http_res)
+            raise models.SDKError(
+                "API error occurred", http_res.status_code, http_res_text, http_res
+            )
+        if utils.match_response(http_res, "5XX", "*"):
             http_res_text = utils.stream_to_text(http_res)
             raise models.SDKError(
                 "API error occurred", http_res.status_code, http_res_text, http_res
@@ -191,11 +201,14 @@ class Agents(BaseSDK):
         presence_penalty: Optional[float] = None,
         frequency_penalty: Optional[float] = None,
         n: OptionalNullable[int] = UNSET,
+        prediction: Optional[
+            Union[models.Prediction, models.PredictionTypedDict]
+        ] = None,
         retries: OptionalNullable[utils.RetryConfig] = UNSET,
         server_url: Optional[str] = None,
         timeout_ms: Optional[int] = None,
         http_headers: Optional[Mapping[str, str]] = None,
-    ) -> Optional[models.ChatCompletionResponse]:
+    ) -> models.ChatCompletionResponse:
         r"""Agents Completion
 
         :param messages: The prompt(s) to generate completions for, encoded as a list of dict with role and content.
@@ -210,6 +223,7 @@ class Agents(BaseSDK):
         :param presence_penalty: presence_penalty determines how much the model penalizes the repetition of words or phrases. A higher presence penalty encourages the model to use a wider variety of words and phrases, making the output more diverse and creative.
         :param frequency_penalty: frequency_penalty penalizes the repetition of words based on their frequency in the generated text. A higher frequency penalty discourages the model from repeating words that have already appeared frequently in the output, promoting diversity and reducing repetition.
         :param n: Number of completions to return for each request, input tokens are only billed once.
+        :param prediction:
         :param retries: Override the default retry configuration for this method
         :param server_url: Override the default server URL for this method
         :param timeout_ms: Override the default request timeout configuration for this method in milliseconds
@@ -241,6 +255,9 @@ class Agents(BaseSDK):
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
             n=n,
+            prediction=utils.get_pydantic_model(
+                prediction, Optional[models.Prediction]
+            ),
             agent_id=agent_id,
         )
 
@@ -286,13 +303,16 @@ class Agents(BaseSDK):
 
         data: Any = None
         if utils.match_response(http_res, "200", "application/json"):
-            return utils.unmarshal_json(
-                http_res.text, Optional[models.ChatCompletionResponse]
-            )
+            return utils.unmarshal_json(http_res.text, models.ChatCompletionResponse)
         if utils.match_response(http_res, "422", "application/json"):
             data = utils.unmarshal_json(http_res.text, models.HTTPValidationErrorData)
             raise models.HTTPValidationError(data=data)
-        if utils.match_response(http_res, ["4XX", "5XX"], "*"):
+        if utils.match_response(http_res, "4XX", "*"):
+            http_res_text = await utils.stream_to_text_async(http_res)
+            raise models.SDKError(
+                "API error occurred", http_res.status_code, http_res_text, http_res
+            )
+        if utils.match_response(http_res, "5XX", "*"):
             http_res_text = await utils.stream_to_text_async(http_res)
             raise models.SDKError(
                 "API error occurred", http_res.status_code, http_res_text, http_res
@@ -339,11 +359,14 @@ class Agents(BaseSDK):
         presence_penalty: Optional[float] = None,
         frequency_penalty: Optional[float] = None,
         n: OptionalNullable[int] = UNSET,
+        prediction: Optional[
+            Union[models.Prediction, models.PredictionTypedDict]
+        ] = None,
         retries: OptionalNullable[utils.RetryConfig] = UNSET,
         server_url: Optional[str] = None,
         timeout_ms: Optional[int] = None,
         http_headers: Optional[Mapping[str, str]] = None,
-    ) -> Optional[eventstreaming.EventStream[models.CompletionEvent]]:
+    ) -> eventstreaming.EventStream[models.CompletionEvent]:
         r"""Stream Agents completion
 
         Mistral AI provides the ability to stream responses back to a client in order to allow partial results for certain requests. Tokens will be sent as data-only server-sent events as they become available, with the stream terminated by a data: [DONE] message. Otherwise, the server will hold the request open until the timeout or until completion, with the response containing the full result as JSON.
@@ -360,6 +383,7 @@ class Agents(BaseSDK):
         :param presence_penalty: presence_penalty determines how much the model penalizes the repetition of words or phrases. A higher presence penalty encourages the model to use a wider variety of words and phrases, making the output more diverse and creative.
         :param frequency_penalty: frequency_penalty penalizes the repetition of words based on their frequency in the generated text. A higher frequency penalty discourages the model from repeating words that have already appeared frequently in the output, promoting diversity and reducing repetition.
         :param n: Number of completions to return for each request, input tokens are only billed once.
+        :param prediction:
         :param retries: Override the default retry configuration for this method
         :param server_url: Override the default server URL for this method
         :param timeout_ms: Override the default request timeout configuration for this method in milliseconds
@@ -391,6 +415,9 @@ class Agents(BaseSDK):
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
             n=n,
+            prediction=utils.get_pydantic_model(
+                prediction, Optional[models.Prediction]
+            ),
             agent_id=agent_id,
         )
 
@@ -446,7 +473,12 @@ class Agents(BaseSDK):
             http_res_text = utils.stream_to_text(http_res)
             data = utils.unmarshal_json(http_res_text, models.HTTPValidationErrorData)
             raise models.HTTPValidationError(data=data)
-        if utils.match_response(http_res, ["4XX", "5XX"], "*"):
+        if utils.match_response(http_res, "4XX", "*"):
+            http_res_text = utils.stream_to_text(http_res)
+            raise models.SDKError(
+                "API error occurred", http_res.status_code, http_res_text, http_res
+            )
+        if utils.match_response(http_res, "5XX", "*"):
             http_res_text = utils.stream_to_text(http_res)
             raise models.SDKError(
                 "API error occurred", http_res.status_code, http_res_text, http_res
@@ -493,11 +525,14 @@ class Agents(BaseSDK):
         presence_penalty: Optional[float] = None,
         frequency_penalty: Optional[float] = None,
         n: OptionalNullable[int] = UNSET,
+        prediction: Optional[
+            Union[models.Prediction, models.PredictionTypedDict]
+        ] = None,
         retries: OptionalNullable[utils.RetryConfig] = UNSET,
         server_url: Optional[str] = None,
         timeout_ms: Optional[int] = None,
         http_headers: Optional[Mapping[str, str]] = None,
-    ) -> Optional[eventstreaming.EventStreamAsync[models.CompletionEvent]]:
+    ) -> eventstreaming.EventStreamAsync[models.CompletionEvent]:
         r"""Stream Agents completion
 
         Mistral AI provides the ability to stream responses back to a client in order to allow partial results for certain requests. Tokens will be sent as data-only server-sent events as they become available, with the stream terminated by a data: [DONE] message. Otherwise, the server will hold the request open until the timeout or until completion, with the response containing the full result as JSON.
@@ -514,6 +549,7 @@ class Agents(BaseSDK):
         :param presence_penalty: presence_penalty determines how much the model penalizes the repetition of words or phrases. A higher presence penalty encourages the model to use a wider variety of words and phrases, making the output more diverse and creative.
         :param frequency_penalty: frequency_penalty penalizes the repetition of words based on their frequency in the generated text. A higher frequency penalty discourages the model from repeating words that have already appeared frequently in the output, promoting diversity and reducing repetition.
         :param n: Number of completions to return for each request, input tokens are only billed once.
+        :param prediction:
         :param retries: Override the default retry configuration for this method
         :param server_url: Override the default server URL for this method
         :param timeout_ms: Override the default request timeout configuration for this method in milliseconds
@@ -545,6 +581,9 @@ class Agents(BaseSDK):
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
             n=n,
+            prediction=utils.get_pydantic_model(
+                prediction, Optional[models.Prediction]
+            ),
             agent_id=agent_id,
         )
 
@@ -600,7 +639,12 @@ class Agents(BaseSDK):
             http_res_text = await utils.stream_to_text_async(http_res)
             data = utils.unmarshal_json(http_res_text, models.HTTPValidationErrorData)
             raise models.HTTPValidationError(data=data)
-        if utils.match_response(http_res, ["4XX", "5XX"], "*"):
+        if utils.match_response(http_res, "4XX", "*"):
+            http_res_text = await utils.stream_to_text_async(http_res)
+            raise models.SDKError(
+                "API error occurred", http_res.status_code, http_res_text, http_res
+            )
+        if utils.match_response(http_res, "5XX", "*"):
            http_res_text = await utils.stream_to_text_async(http_res)
            raise models.SDKError(
                "API error occurred", http_res.status_code, http_res_text, http_res
mistralai/chat.py CHANGED
@@ -7,10 +7,92 @@ from mistralai.types import Nullable, OptionalNullable, UNSET
 from mistralai.utils import eventstreaming, get_security_from_env
 from typing import Any, List, Mapping, Optional, Union
 
+# region imports
+from typing import Type
+from mistralai.extra import (
+    convert_to_parsed_chat_completion_response,
+    response_format_from_pydantic_model,
+    CustomPydanticModel,
+    ParsedChatCompletionResponse,
+)
+# endregion imports
+
 
 class Chat(BaseSDK):
     r"""Chat Completion API."""
 
+    # region sdk-class-body
+    # Custom .parse methods for the Structure Outputs Feature.
+
+    def parse(
+        self, response_format: Type[CustomPydanticModel], **kwargs: Any
+    ) -> ParsedChatCompletionResponse[CustomPydanticModel]:
+        """
+        Parse the response using the provided response format.
+        :param Type[CustomPydanticModel] response_format: The Pydantic model to parse the response into
+        :param Any **kwargs Additional keyword arguments to pass to the .complete method
+        :return: The parsed response
+        """
+        # Convert the input Pydantic Model to a strict JSON ready to be passed to chat.complete
+        json_response_format = response_format_from_pydantic_model(response_format)
+        # Run the inference
+        response = self.complete(**kwargs, response_format=json_response_format)
+        # Parse response back to the input pydantic model
+        parsed_response = convert_to_parsed_chat_completion_response(
+            response, response_format
+        )
+        return parsed_response
+
+    async def parse_async(
+        self, response_format: Type[CustomPydanticModel], **kwargs
+    ) -> ParsedChatCompletionResponse[CustomPydanticModel]:
+        """
+        Asynchronously parse the response using the provided response format.
+        :param Type[CustomPydanticModel] response_format: The Pydantic model to parse the response into
+        :param Any **kwargs Additional keyword arguments to pass to the .complete method
+        :return: The parsed response
+        """
+        json_response_format = response_format_from_pydantic_model(response_format)
+        response = await self.complete_async(  # pylint: disable=E1125
+            **kwargs, response_format=json_response_format
+        )
+        parsed_response = convert_to_parsed_chat_completion_response(
+            response, response_format
+        )
+        return parsed_response
+
+    def parse_stream(
+        self, response_format: Type[CustomPydanticModel], **kwargs
+    ) -> eventstreaming.EventStream[models.CompletionEvent]:
+        """
+        Parse the response using the provided response format.
+        For now the response will be in JSON format not in the input Pydantic model.
+        :param Type[CustomPydanticModel] response_format: The Pydantic model to parse the response into
+        :param Any **kwargs Additional keyword arguments to pass to the .stream method
+        :return: The JSON parsed response
+        """
+        json_response_format = response_format_from_pydantic_model(response_format)
+        response = self.stream(**kwargs, response_format=json_response_format)
+        return response
+
+    async def parse_stream_async(
+        self, response_format: Type[CustomPydanticModel], **kwargs
+    ) -> eventstreaming.EventStreamAsync[models.CompletionEvent]:
+        """
+        Asynchronously parse the response using the provided response format.
+        For now the response will be in JSON format not in the input Pydantic model.
+        :param Type[CustomPydanticModel] response_format: The Pydantic model to parse the response into
+        :param Any **kwargs Additional keyword arguments to pass to the .stream method
+        :return: The JSON parsed response
+        """
+        json_response_format = response_format_from_pydantic_model(response_format)
+        response = await self.stream_async(  # pylint: disable=E1125
+            **kwargs, response_format=json_response_format
+        )
+        return response
+
+    # endregion sdk-class-body
+
     def complete(
         self,
         *,
@@ -37,12 +119,15 @@ class Chat(BaseSDK):
         presence_penalty: Optional[float] = None,
         frequency_penalty: Optional[float] = None,
         n: OptionalNullable[int] = UNSET,
+        prediction: Optional[
+            Union[models.Prediction, models.PredictionTypedDict]
+        ] = None,
         safe_prompt: Optional[bool] = None,
         retries: OptionalNullable[utils.RetryConfig] = UNSET,
         server_url: Optional[str] = None,
         timeout_ms: Optional[int] = None,
         http_headers: Optional[Mapping[str, str]] = None,
-    ) -> Optional[models.ChatCompletionResponse]:
+    ) -> models.ChatCompletionResponse:
         r"""Chat Completion
 
         :param model: ID of the model to use. You can use the [List Available Models](/api/#tag/models/operation/list_models_v1_models_get) API to see all of your available models, or see our [Model overview](/models) for model descriptions.
@@ -59,6 +144,7 @@ class Chat(BaseSDK):
         :param presence_penalty: presence_penalty determines how much the model penalizes the repetition of words or phrases. A higher presence penalty encourages the model to use a wider variety of words and phrases, making the output more diverse and creative.
         :param frequency_penalty: frequency_penalty penalizes the repetition of words based on their frequency in the generated text. A higher frequency penalty discourages the model from repeating words that have already appeared frequently in the output, promoting diversity and reducing repetition.
         :param n: Number of completions to return for each request, input tokens are only billed once.
+        :param prediction:
         :param safe_prompt: Whether to inject a safety prompt before all conversations.
         :param retries: Override the default retry configuration for this method
         :param server_url: Override the default server URL for this method
@@ -92,6 +178,9 @@ class Chat(BaseSDK):
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
             n=n,
+            prediction=utils.get_pydantic_model(
+                prediction, Optional[models.Prediction]
+            ),
             safe_prompt=safe_prompt,
         )
 
@@ -137,13 +226,16 @@ class Chat(BaseSDK):
 
         data: Any = None
         if utils.match_response(http_res, "200", "application/json"):
-            return utils.unmarshal_json(
-                http_res.text, Optional[models.ChatCompletionResponse]
-            )
+            return utils.unmarshal_json(http_res.text, models.ChatCompletionResponse)
         if utils.match_response(http_res, "422", "application/json"):
             data = utils.unmarshal_json(http_res.text, models.HTTPValidationErrorData)
             raise models.HTTPValidationError(data=data)
-        if utils.match_response(http_res, ["4XX", "5XX"], "*"):
+        if utils.match_response(http_res, "4XX", "*"):
+            http_res_text = utils.stream_to_text(http_res)
+            raise models.SDKError(
+                "API error occurred", http_res.status_code, http_res_text, http_res
+            )
+        if utils.match_response(http_res, "5XX", "*"):
             http_res_text = utils.stream_to_text(http_res)
             raise models.SDKError(
                 "API error occurred", http_res.status_code, http_res_text, http_res
@@ -184,12 +276,15 @@ class Chat(BaseSDK):
         presence_penalty: Optional[float] = None,
         frequency_penalty: Optional[float] = None,
         n: OptionalNullable[int] = UNSET,
+        prediction: Optional[
+            Union[models.Prediction, models.PredictionTypedDict]
+        ] = None,
         safe_prompt: Optional[bool] = None,
         retries: OptionalNullable[utils.RetryConfig] = UNSET,
         server_url: Optional[str] = None,
         timeout_ms: Optional[int] = None,
         http_headers: Optional[Mapping[str, str]] = None,
-    ) -> Optional[models.ChatCompletionResponse]:
+    ) -> models.ChatCompletionResponse:
         r"""Chat Completion
 
         :param model: ID of the model to use. You can use the [List Available Models](/api/#tag/models/operation/list_models_v1_models_get) API to see all of your available models, or see our [Model overview](/models) for model descriptions.
@@ -206,6 +301,7 @@ class Chat(BaseSDK):
         :param presence_penalty: presence_penalty determines how much the model penalizes the repetition of words or phrases. A higher presence penalty encourages the model to use a wider variety of words and phrases, making the output more diverse and creative.
         :param frequency_penalty: frequency_penalty penalizes the repetition of words based on their frequency in the generated text. A higher frequency penalty discourages the model from repeating words that have already appeared frequently in the output, promoting diversity and reducing repetition.
         :param n: Number of completions to return for each request, input tokens are only billed once.
+        :param prediction:
         :param safe_prompt: Whether to inject a safety prompt before all conversations.
         :param retries: Override the default retry configuration for this method
         :param server_url: Override the default server URL for this method
@@ -239,6 +335,9 @@ class Chat(BaseSDK):
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
             n=n,
+            prediction=utils.get_pydantic_model(
+                prediction, Optional[models.Prediction]
+            ),
             safe_prompt=safe_prompt,
         )
 
@@ -284,13 +383,16 @@ class Chat(BaseSDK):
 
         data: Any = None
         if utils.match_response(http_res, "200", "application/json"):
-            return utils.unmarshal_json(
-                http_res.text, Optional[models.ChatCompletionResponse]
-            )
+            return utils.unmarshal_json(http_res.text, models.ChatCompletionResponse)
         if utils.match_response(http_res, "422", "application/json"):
             data = utils.unmarshal_json(http_res.text, models.HTTPValidationErrorData)
             raise models.HTTPValidationError(data=data)
-        if utils.match_response(http_res, ["4XX", "5XX"], "*"):
+        if utils.match_response(http_res, "4XX", "*"):
+            http_res_text = await utils.stream_to_text_async(http_res)
+            raise models.SDKError(
+                "API error occurred", http_res.status_code, http_res_text, http_res
+            )
+        if utils.match_response(http_res, "5XX", "*"):
             http_res_text = await utils.stream_to_text_async(http_res)
             raise models.SDKError(
                 "API error occurred", http_res.status_code, http_res_text, http_res
@@ -339,12 +441,15 @@ class Chat(BaseSDK):
         presence_penalty: Optional[float] = None,
         frequency_penalty: Optional[float] = None,
         n: OptionalNullable[int] = UNSET,
+        prediction: Optional[
+            Union[models.Prediction, models.PredictionTypedDict]
+        ] = None,
         safe_prompt: Optional[bool] = None,
         retries: OptionalNullable[utils.RetryConfig] = UNSET,
         server_url: Optional[str] = None,
         timeout_ms: Optional[int] = None,
         http_headers: Optional[Mapping[str, str]] = None,
-    ) -> Optional[eventstreaming.EventStream[models.CompletionEvent]]:
+    ) -> eventstreaming.EventStream[models.CompletionEvent]:
         r"""Stream chat completion
 
         Mistral AI provides the ability to stream responses back to a client in order to allow partial results for certain requests. Tokens will be sent as data-only server-sent events as they become available, with the stream terminated by a data: [DONE] message. Otherwise, the server will hold the request open until the timeout or until completion, with the response containing the full result as JSON.
@@ -363,6 +468,7 @@ class Chat(BaseSDK):
         :param presence_penalty: presence_penalty determines how much the model penalizes the repetition of words or phrases. A higher presence penalty encourages the model to use a wider variety of words and phrases, making the output more diverse and creative.
         :param frequency_penalty: frequency_penalty penalizes the repetition of words based on their frequency in the generated text. A higher frequency penalty discourages the model from repeating words that have already appeared frequently in the output, promoting diversity and reducing repetition.
         :param n: Number of completions to return for each request, input tokens are only billed once.
+        :param prediction:
         :param safe_prompt: Whether to inject a safety prompt before all conversations.
         :param retries: Override the default retry configuration for this method
         :param server_url: Override the default server URL for this method
@@ -398,6 +504,9 @@ class Chat(BaseSDK):
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
             n=n,
+            prediction=utils.get_pydantic_model(
+                prediction, Optional[models.Prediction]
+            ),
             safe_prompt=safe_prompt,
         )
 
@@ -453,7 +562,12 @@ class Chat(BaseSDK):
             http_res_text = utils.stream_to_text(http_res)
             data = utils.unmarshal_json(http_res_text, models.HTTPValidationErrorData)
             raise models.HTTPValidationError(data=data)
-        if utils.match_response(http_res, ["4XX", "5XX"], "*"):
+        if utils.match_response(http_res, "4XX", "*"):
+            http_res_text = utils.stream_to_text(http_res)
+            raise models.SDKError(
+                "API error occurred", http_res.status_code, http_res_text, http_res
+            )
+        if utils.match_response(http_res, "5XX", "*"):
             http_res_text = utils.stream_to_text(http_res)
             raise models.SDKError(
                 "API error occurred", http_res.status_code, http_res_text, http_res
@@ -502,12 +616,15 @@ class Chat(BaseSDK):
         presence_penalty: Optional[float] = None,
         frequency_penalty: Optional[float] = None,
         n: OptionalNullable[int] = UNSET,
+        prediction: Optional[
+            Union[models.Prediction, models.PredictionTypedDict]
+        ] = None,
         safe_prompt: Optional[bool] = None,
         retries: OptionalNullable[utils.RetryConfig] = UNSET,
         server_url: Optional[str] = None,
         timeout_ms: Optional[int] = None,
         http_headers: Optional[Mapping[str, str]] = None,
-    ) -> Optional[eventstreaming.EventStreamAsync[models.CompletionEvent]]:
+    ) -> eventstreaming.EventStreamAsync[models.CompletionEvent]:
         r"""Stream chat completion
 
         Mistral AI provides the ability to stream responses back to a client in order to allow partial results for certain requests. Tokens will be sent as data-only server-sent events as they become available, with the stream terminated by a data: [DONE] message. Otherwise, the server will hold the request open until the timeout or until completion, with the response containing the full result as JSON.
@@ -526,6 +643,7 @@ class Chat(BaseSDK):
         :param presence_penalty: presence_penalty determines how much the model penalizes the repetition of words or phrases. A higher presence penalty encourages the model to use a wider variety of words and phrases, making the output more diverse and creative.
         :param frequency_penalty: frequency_penalty penalizes the repetition of words based on their frequency in the generated text. A higher frequency penalty discourages the model from repeating words that have already appeared frequently in the output, promoting diversity and reducing repetition.
         :param n: Number of completions to return for each request, input tokens are only billed once.
+        :param prediction:
         :param safe_prompt: Whether to inject a safety prompt before all conversations.
         :param retries: Override the default retry configuration for this method
         :param server_url: Override the default server URL for this method
@@ -561,6 +679,9 @@ class Chat(BaseSDK):
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
             n=n,
+            prediction=utils.get_pydantic_model(
+                prediction, Optional[models.Prediction]
+            ),
             safe_prompt=safe_prompt,
         )
 
@@ -616,7 +737,12 @@ class Chat(BaseSDK):
             http_res_text = await utils.stream_to_text_async(http_res)
             data = utils.unmarshal_json(http_res_text, models.HTTPValidationErrorData)
             raise models.HTTPValidationError(data=data)
-        if utils.match_response(http_res, ["4XX", "5XX"], "*"):
+        if utils.match_response(http_res, "4XX", "*"):
+            http_res_text = await utils.stream_to_text_async(http_res)
+            raise models.SDKError(
+                "API error occurred", http_res.status_code, http_res_text, http_res
+            )
+        if utils.match_response(http_res, "5XX", "*"):
             http_res_text = await utils.stream_to_text_async(http_res)
             raise models.SDKError(
                 "API error occurred", http_res.status_code, http_res_text, http_res
mistralai/classifiers.py CHANGED
@@ -23,7 +23,7 @@ class Classifiers(BaseSDK):
         server_url: Optional[str] = None,
         timeout_ms: Optional[int] = None,
         http_headers: Optional[Mapping[str, str]] = None,
-    ) -> Optional[models.ClassificationResponse]:
+    ) -> models.ClassificationResponse:
         r"""Moderations
 
         :param inputs: Text to classify.
@@ -88,13 +88,16 @@ class Classifiers(BaseSDK):
 
         data: Any = None
         if utils.match_response(http_res, "200", "application/json"):
-            return utils.unmarshal_json(
-                http_res.text, Optional[models.ClassificationResponse]
-            )
+            return utils.unmarshal_json(http_res.text, models.ClassificationResponse)
         if utils.match_response(http_res, "422", "application/json"):
             data = utils.unmarshal_json(http_res.text, models.HTTPValidationErrorData)
             raise models.HTTPValidationError(data=data)
-        if utils.match_response(http_res, ["4XX", "5XX"], "*"):
+        if utils.match_response(http_res, "4XX", "*"):
+            http_res_text = utils.stream_to_text(http_res)
+            raise models.SDKError(
+                "API error occurred", http_res.status_code, http_res_text, http_res
+            )
+        if utils.match_response(http_res, "5XX", "*"):
             http_res_text = utils.stream_to_text(http_res)
             raise models.SDKError(
                 "API error occurred", http_res.status_code, http_res_text, http_res
@@ -121,7 +124,7 @@ class Classifiers(BaseSDK):
         server_url: Optional[str] = None,
         timeout_ms: Optional[int] = None,
         http_headers: Optional[Mapping[str, str]] = None,
-    ) -> Optional[models.ClassificationResponse]:
+    ) -> models.ClassificationResponse:
         r"""Moderations
 
         :param inputs: Text to classify.
@@ -186,13 +189,16 @@ class Classifiers(BaseSDK):
 
         data: Any = None
         if utils.match_response(http_res, "200", "application/json"):
-            return utils.unmarshal_json(
-                http_res.text, Optional[models.ClassificationResponse]
-            )
+            return utils.unmarshal_json(http_res.text, models.ClassificationResponse)
         if utils.match_response(http_res, "422", "application/json"):
             data = utils.unmarshal_json(http_res.text, models.HTTPValidationErrorData)
             raise models.HTTPValidationError(data=data)
-        if utils.match_response(http_res, ["4XX", "5XX"], "*"):
+        if utils.match_response(http_res, "4XX", "*"):
+            http_res_text = await utils.stream_to_text_async(http_res)
+            raise models.SDKError(
+                "API error occurred", http_res.status_code, http_res_text, http_res
+            )
+        if utils.match_response(http_res, "5XX", "*"):
             http_res_text = await utils.stream_to_text_async(http_res)
             raise models.SDKError(
                 "API error occurred", http_res.status_code, http_res_text, http_res
@@ -219,7 +225,7 @@ class Classifiers(BaseSDK):
         server_url: Optional[str] = None,
         timeout_ms: Optional[int] = None,
         http_headers: Optional[Mapping[str, str]] = None,
-    ) -> Optional[models.ClassificationResponse]:
+    ) -> models.ClassificationResponse:
         r"""Moderations Chat
 
         :param inputs: Chat to classify
@@ -286,13 +292,16 @@ class Classifiers(BaseSDK):
 
         data: Any = None
         if utils.match_response(http_res, "200", "application/json"):
-            return utils.unmarshal_json(
-                http_res.text, Optional[models.ClassificationResponse]
-            )
+            return utils.unmarshal_json(http_res.text, models.ClassificationResponse)
         if utils.match_response(http_res, "422", "application/json"):
             data = utils.unmarshal_json(http_res.text, models.HTTPValidationErrorData)
             raise models.HTTPValidationError(data=data)
-        if utils.match_response(http_res, ["4XX", "5XX"], "*"):
+        if utils.match_response(http_res, "4XX", "*"):
+            http_res_text = utils.stream_to_text(http_res)
+            raise models.SDKError(
+                "API error occurred", http_res.status_code, http_res_text, http_res
+            )
+        if utils.match_response(http_res, "5XX", "*"):
             http_res_text = utils.stream_to_text(http_res)
             raise models.SDKError(
                 "API error occurred", http_res.status_code, http_res_text, http_res
@@ -319,7 +328,7 @@ class Classifiers(BaseSDK):
         server_url: Optional[str] = None,
         timeout_ms: Optional[int] = None,
         http_headers: Optional[Mapping[str, str]] = None,
-    ) -> Optional[models.ClassificationResponse]:
+    ) -> models.ClassificationResponse:
         r"""Moderations Chat
 
         :param inputs: Chat to classify
@@ -386,13 +395,16 @@ class Classifiers(BaseSDK):
 
         data: Any = None
         if utils.match_response(http_res, "200", "application/json"):
-            return utils.unmarshal_json(
-                http_res.text, Optional[models.ClassificationResponse]
-            )
+            return utils.unmarshal_json(http_res.text, models.ClassificationResponse)
        if utils.match_response(http_res, "422", "application/json"):
            data = utils.unmarshal_json(http_res.text, models.HTTPValidationErrorData)
            raise models.HTTPValidationError(data=data)
-        if utils.match_response(http_res, ["4XX", "5XX"], "*"):
+        if utils.match_response(http_res, "4XX", "*"):
+            http_res_text = await utils.stream_to_text_async(http_res)
+            raise models.SDKError(
+                "API error occurred", http_res.status_code, http_res_text, http_res
+            )
+        if utils.match_response(http_res, "5XX", "*"):
             http_res_text = await utils.stream_to_text_async(http_res)
             raise models.SDKError(
                 "API error occurred", http_res.status_code, http_res_text, http_res