mistralai 1.4.0__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. mistralai/_version.py +3 -3
  2. mistralai/chat.py +87 -5
  3. mistralai/classifiers.py +27 -25
  4. mistralai/embeddings.py +2 -8
  5. mistralai/extra/README.md +56 -0
  6. mistralai/extra/__init__.py +5 -0
  7. mistralai/extra/struct_chat.py +41 -0
  8. mistralai/extra/tests/__init__.py +0 -0
  9. mistralai/extra/tests/test_struct_chat.py +103 -0
  10. mistralai/extra/tests/test_utils.py +162 -0
  11. mistralai/extra/utils/__init__.py +3 -0
  12. mistralai/extra/utils/_pydantic_helper.py +20 -0
  13. mistralai/extra/utils/response_format.py +24 -0
  14. mistralai/fim.py +5 -5
  15. mistralai/httpclient.py +50 -0
  16. mistralai/models/__init__.py +41 -16
  17. mistralai/models/assistantmessage.py +2 -0
  18. mistralai/models/chatcompletionrequest.py +3 -10
  19. mistralai/models/chatcompletionstreamrequest.py +3 -10
  20. mistralai/models/chatmoderationrequest.py +86 -0
  21. mistralai/models/classificationrequest.py +7 -36
  22. mistralai/models/contentchunk.py +8 -1
  23. mistralai/models/documenturlchunk.py +62 -0
  24. mistralai/models/embeddingrequest.py +1 -37
  25. mistralai/models/fimcompletionrequest.py +2 -3
  26. mistralai/models/fimcompletionstreamrequest.py +2 -3
  27. mistralai/models/jsonschema.py +55 -0
  28. mistralai/models/ocrimageobject.py +77 -0
  29. mistralai/models/ocrpagedimensions.py +25 -0
  30. mistralai/models/ocrpageobject.py +64 -0
  31. mistralai/models/ocrrequest.py +97 -0
  32. mistralai/models/ocrresponse.py +26 -0
  33. mistralai/models/ocrusageinfo.py +51 -0
  34. mistralai/models/prediction.py +4 -5
  35. mistralai/models/responseformat.py +36 -1
  36. mistralai/models/responseformats.py +1 -1
  37. mistralai/ocr.py +238 -0
  38. mistralai/sdk.py +15 -2
  39. {mistralai-1.4.0.dist-info → mistralai-1.5.1.dist-info}/METADATA +37 -1
  40. {mistralai-1.4.0.dist-info → mistralai-1.5.1.dist-info}/RECORD +42 -24
  41. {mistralai-1.4.0.dist-info → mistralai-1.5.1.dist-info}/WHEEL +1 -1
  42. mistralai/models/chatclassificationrequest.py +0 -113
  43. {mistralai-1.4.0.dist-info → mistralai-1.5.1.dist-info}/LICENSE +0 -0
mistralai/_version.py CHANGED
@@ -3,10 +3,10 @@
3
3
  import importlib.metadata
4
4
 
5
5
  __title__: str = "mistralai"
6
- __version__: str = "1.4.0"
6
+ __version__: str = "1.5.1"
7
7
  __openapi_doc_version__: str = "0.0.2"
8
- __gen_version__: str = "2.493.32"
9
- __user_agent__: str = "speakeasy-sdk/python 1.4.0 2.493.32 0.0.2 mistralai"
8
+ __gen_version__: str = "2.497.0"
9
+ __user_agent__: str = "speakeasy-sdk/python 1.5.1 2.497.0 0.0.2 mistralai"
10
10
 
11
11
  try:
12
12
  if __package__ is not None:
mistralai/chat.py CHANGED
@@ -3,18 +3,100 @@
3
3
  from .basesdk import BaseSDK
4
4
  from mistralai import models, utils
5
5
  from mistralai._hooks import HookContext
6
- from mistralai.types import Nullable, OptionalNullable, UNSET
6
+ from mistralai.types import OptionalNullable, UNSET
7
7
  from mistralai.utils import eventstreaming, get_security_from_env
8
8
  from typing import Any, List, Mapping, Optional, Union
9
9
 
10
+ # region imports
11
+ from typing import Type
12
+ from mistralai.extra import (
13
+ convert_to_parsed_chat_completion_response,
14
+ response_format_from_pydantic_model,
15
+ CustomPydanticModel,
16
+ ParsedChatCompletionResponse,
17
+ )
18
+ # endregion imports
19
+
10
20
 
11
21
  class Chat(BaseSDK):
12
22
  r"""Chat Completion API."""
13
23
 
24
+ # region sdk-class-body
25
+ # Custom .parse methods for the Structure Outputs Feature.
26
+
27
+ def parse(
28
+ self, response_format: Type[CustomPydanticModel], **kwargs: Any
29
+ ) -> ParsedChatCompletionResponse[CustomPydanticModel]:
30
+ """
31
+ Parse the response using the provided response format.
32
+ :param Type[CustomPydanticModel] response_format: The Pydantic model to parse the response into
33
+ :param Any **kwargs: Additional keyword arguments to pass to the .complete method
34
+ :return: The parsed response
35
+ """
36
+ # Convert the input Pydantic Model to a strict JSON ready to be passed to chat.complete
37
+ json_response_format = response_format_from_pydantic_model(response_format)
38
+ # Run the inference
39
+ response = self.complete(**kwargs, response_format=json_response_format)
40
+ # Parse response back to the input pydantic model
41
+ parsed_response = convert_to_parsed_chat_completion_response(
42
+ response, response_format
43
+ )
44
+ return parsed_response
45
+
46
+ async def parse_async(
47
+ self, response_format: Type[CustomPydanticModel], **kwargs
48
+ ) -> ParsedChatCompletionResponse[CustomPydanticModel]:
49
+ """
50
+ Asynchronously parse the response using the provided response format.
51
+ :param Type[CustomPydanticModel] response_format: The Pydantic model to parse the response into
52
+ :param Any **kwargs: Additional keyword arguments to pass to the .complete method
53
+ :return: The parsed response
54
+ """
55
+ json_response_format = response_format_from_pydantic_model(response_format)
56
+ response = await self.complete_async( # pylint: disable=E1125
57
+ **kwargs, response_format=json_response_format
58
+ )
59
+ parsed_response = convert_to_parsed_chat_completion_response(
60
+ response, response_format
61
+ )
62
+ return parsed_response
63
+
64
+ def parse_stream(
65
+ self, response_format: Type[CustomPydanticModel], **kwargs
66
+ ) -> eventstreaming.EventStream[models.CompletionEvent]:
67
+ """
68
+ Parse the response using the provided response format.
69
+ For now the response will be in JSON format not in the input Pydantic model.
70
+ :param Type[CustomPydanticModel] response_format: The Pydantic model to parse the response into
71
+ :param Any **kwargs: Additional keyword arguments to pass to the .stream method
72
+ :return: The JSON parsed response
73
+ """
74
+ json_response_format = response_format_from_pydantic_model(response_format)
75
+ response = self.stream(**kwargs, response_format=json_response_format)
76
+ return response
77
+
78
+ async def parse_stream_async(
79
+ self, response_format: Type[CustomPydanticModel], **kwargs
80
+ ) -> eventstreaming.EventStreamAsync[models.CompletionEvent]:
81
+ """
82
+ Asynchronously parse the response using the provided response format.
83
+ For now the response will be in JSON format not in the input Pydantic model.
84
+ :param Type[CustomPydanticModel] response_format: The Pydantic model to parse the response into
85
+ :param Any **kwargs: Additional keyword arguments to pass to the .stream method
86
+ :return: The JSON parsed response
87
+ """
88
+ json_response_format = response_format_from_pydantic_model(response_format)
89
+ response = await self.stream_async( # pylint: disable=E1125
90
+ **kwargs, response_format=json_response_format
91
+ )
92
+ return response
93
+
94
+ # endregion sdk-class-body
95
+
14
96
  def complete(
15
97
  self,
16
98
  *,
17
- model: Nullable[str],
99
+ model: str,
18
100
  messages: Union[List[models.Messages], List[models.MessagesTypedDict]],
19
101
  temperature: OptionalNullable[float] = UNSET,
20
102
  top_p: Optional[float] = None,
@@ -171,7 +253,7 @@ class Chat(BaseSDK):
171
253
  async def complete_async(
172
254
  self,
173
255
  *,
174
- model: Nullable[str],
256
+ model: str,
175
257
  messages: Union[List[models.Messages], List[models.MessagesTypedDict]],
176
258
  temperature: OptionalNullable[float] = UNSET,
177
259
  top_p: Optional[float] = None,
@@ -328,7 +410,7 @@ class Chat(BaseSDK):
328
410
  def stream(
329
411
  self,
330
412
  *,
331
- model: Nullable[str],
413
+ model: str,
332
414
  messages: Union[
333
415
  List[models.ChatCompletionStreamRequestMessages],
334
416
  List[models.ChatCompletionStreamRequestMessagesTypedDict],
@@ -503,7 +585,7 @@ class Chat(BaseSDK):
503
585
  async def stream_async(
504
586
  self,
505
587
  *,
506
- model: Nullable[str],
588
+ model: str,
507
589
  messages: Union[
508
590
  List[models.ChatCompletionStreamRequestMessages],
509
591
  List[models.ChatCompletionStreamRequestMessagesTypedDict],
mistralai/classifiers.py CHANGED
@@ -3,7 +3,7 @@
3
3
  from .basesdk import BaseSDK
4
4
  from mistralai import models, utils
5
5
  from mistralai._hooks import HookContext
6
- from mistralai.types import Nullable, OptionalNullable, UNSET
6
+ from mistralai.types import OptionalNullable, UNSET
7
7
  from mistralai.utils import get_security_from_env
8
8
  from typing import Any, Mapping, Optional, Union
9
9
 
@@ -14,11 +14,11 @@ class Classifiers(BaseSDK):
14
14
  def moderate(
15
15
  self,
16
16
  *,
17
+ model: str,
17
18
  inputs: Union[
18
19
  models.ClassificationRequestInputs,
19
20
  models.ClassificationRequestInputsTypedDict,
20
21
  ],
21
- model: OptionalNullable[str] = UNSET,
22
22
  retries: OptionalNullable[utils.RetryConfig] = UNSET,
23
23
  server_url: Optional[str] = None,
24
24
  timeout_ms: Optional[int] = None,
@@ -26,8 +26,8 @@ class Classifiers(BaseSDK):
26
26
  ) -> models.ClassificationResponse:
27
27
  r"""Moderations
28
28
 
29
+ :param model: ID of the model to use.
29
30
  :param inputs: Text to classify.
30
- :param model:
31
31
  :param retries: Override the default retry configuration for this method
32
32
  :param server_url: Override the default server URL for this method
33
33
  :param timeout_ms: Override the default request timeout configuration for this method in milliseconds
@@ -42,8 +42,8 @@ class Classifiers(BaseSDK):
42
42
  base_url = server_url
43
43
 
44
44
  request = models.ClassificationRequest(
45
- inputs=inputs,
46
45
  model=model,
46
+ inputs=inputs,
47
47
  )
48
48
 
49
49
  req = self._build_request(
@@ -115,11 +115,11 @@ class Classifiers(BaseSDK):
115
115
  async def moderate_async(
116
116
  self,
117
117
  *,
118
+ model: str,
118
119
  inputs: Union[
119
120
  models.ClassificationRequestInputs,
120
121
  models.ClassificationRequestInputsTypedDict,
121
122
  ],
122
- model: OptionalNullable[str] = UNSET,
123
123
  retries: OptionalNullable[utils.RetryConfig] = UNSET,
124
124
  server_url: Optional[str] = None,
125
125
  timeout_ms: Optional[int] = None,
@@ -127,8 +127,8 @@ class Classifiers(BaseSDK):
127
127
  ) -> models.ClassificationResponse:
128
128
  r"""Moderations
129
129
 
130
+ :param model: ID of the model to use.
130
131
  :param inputs: Text to classify.
131
- :param model:
132
132
  :param retries: Override the default retry configuration for this method
133
133
  :param server_url: Override the default server URL for this method
134
134
  :param timeout_ms: Override the default request timeout configuration for this method in milliseconds
@@ -143,8 +143,8 @@ class Classifiers(BaseSDK):
143
143
  base_url = server_url
144
144
 
145
145
  request = models.ClassificationRequest(
146
- inputs=inputs,
147
146
  model=model,
147
+ inputs=inputs,
148
148
  )
149
149
 
150
150
  req = self._build_request_async(
@@ -216,11 +216,12 @@ class Classifiers(BaseSDK):
216
216
  def moderate_chat(
217
217
  self,
218
218
  *,
219
+ model: str,
219
220
  inputs: Union[
220
- models.ChatClassificationRequestInputs,
221
- models.ChatClassificationRequestInputsTypedDict,
221
+ models.ChatModerationRequestInputs,
222
+ models.ChatModerationRequestInputsTypedDict,
222
223
  ],
223
- model: Nullable[str],
224
+ truncate_for_context_length: Optional[bool] = False,
224
225
  retries: OptionalNullable[utils.RetryConfig] = UNSET,
225
226
  server_url: Optional[str] = None,
226
227
  timeout_ms: Optional[int] = None,
@@ -228,8 +229,9 @@ class Classifiers(BaseSDK):
228
229
  ) -> models.ClassificationResponse:
229
230
  r"""Moderations Chat
230
231
 
231
- :param inputs: Chat to classify
232
232
  :param model:
233
+ :param inputs: Chat to classify
234
+ :param truncate_for_context_length:
233
235
  :param retries: Override the default retry configuration for this method
234
236
  :param server_url: Override the default server URL for this method
235
237
  :param timeout_ms: Override the default request timeout configuration for this method in milliseconds
@@ -243,11 +245,10 @@ class Classifiers(BaseSDK):
243
245
  if server_url is not None:
244
246
  base_url = server_url
245
247
 
246
- request = models.ChatClassificationRequest(
247
- inputs=utils.get_pydantic_model(
248
- inputs, models.ChatClassificationRequestInputs
249
- ),
248
+ request = models.ChatModerationRequest(
250
249
  model=model,
250
+ inputs=utils.get_pydantic_model(inputs, models.ChatModerationRequestInputs),
251
+ truncate_for_context_length=truncate_for_context_length,
251
252
  )
252
253
 
253
254
  req = self._build_request(
@@ -264,7 +265,7 @@ class Classifiers(BaseSDK):
264
265
  http_headers=http_headers,
265
266
  security=self.sdk_configuration.security,
266
267
  get_serialized_body=lambda: utils.serialize_request_body(
267
- request, False, False, "json", models.ChatClassificationRequest
268
+ request, False, False, "json", models.ChatModerationRequest
268
269
  ),
269
270
  timeout_ms=timeout_ms,
270
271
  )
@@ -319,11 +320,12 @@ class Classifiers(BaseSDK):
319
320
  async def moderate_chat_async(
320
321
  self,
321
322
  *,
323
+ model: str,
322
324
  inputs: Union[
323
- models.ChatClassificationRequestInputs,
324
- models.ChatClassificationRequestInputsTypedDict,
325
+ models.ChatModerationRequestInputs,
326
+ models.ChatModerationRequestInputsTypedDict,
325
327
  ],
326
- model: Nullable[str],
328
+ truncate_for_context_length: Optional[bool] = False,
327
329
  retries: OptionalNullable[utils.RetryConfig] = UNSET,
328
330
  server_url: Optional[str] = None,
329
331
  timeout_ms: Optional[int] = None,
@@ -331,8 +333,9 @@ class Classifiers(BaseSDK):
331
333
  ) -> models.ClassificationResponse:
332
334
  r"""Moderations Chat
333
335
 
334
- :param inputs: Chat to classify
335
336
  :param model:
337
+ :param inputs: Chat to classify
338
+ :param truncate_for_context_length:
336
339
  :param retries: Override the default retry configuration for this method
337
340
  :param server_url: Override the default server URL for this method
338
341
  :param timeout_ms: Override the default request timeout configuration for this method in milliseconds
@@ -346,11 +349,10 @@ class Classifiers(BaseSDK):
346
349
  if server_url is not None:
347
350
  base_url = server_url
348
351
 
349
- request = models.ChatClassificationRequest(
350
- inputs=utils.get_pydantic_model(
351
- inputs, models.ChatClassificationRequestInputs
352
- ),
352
+ request = models.ChatModerationRequest(
353
353
  model=model,
354
+ inputs=utils.get_pydantic_model(inputs, models.ChatModerationRequestInputs),
355
+ truncate_for_context_length=truncate_for_context_length,
354
356
  )
355
357
 
356
358
  req = self._build_request_async(
@@ -367,7 +369,7 @@ class Classifiers(BaseSDK):
367
369
  http_headers=http_headers,
368
370
  security=self.sdk_configuration.security,
369
371
  get_serialized_body=lambda: utils.serialize_request_body(
370
- request, False, False, "json", models.ChatClassificationRequest
372
+ request, False, False, "json", models.ChatModerationRequest
371
373
  ),
372
374
  timeout_ms=timeout_ms,
373
375
  )
mistralai/embeddings.py CHANGED
@@ -16,7 +16,6 @@ class Embeddings(BaseSDK):
16
16
  *,
17
17
  inputs: Union[models.Inputs, models.InputsTypedDict],
18
18
  model: Optional[str] = "mistral-embed",
19
- encoding_format: OptionalNullable[str] = UNSET,
20
19
  retries: OptionalNullable[utils.RetryConfig] = UNSET,
21
20
  server_url: Optional[str] = None,
22
21
  timeout_ms: Optional[int] = None,
@@ -28,7 +27,6 @@ class Embeddings(BaseSDK):
28
27
 
29
28
  :param inputs: Text to embed.
30
29
  :param model: ID of the model to use.
31
- :param encoding_format: The format to return the embeddings in.
32
30
  :param retries: Override the default retry configuration for this method
33
31
  :param server_url: Override the default server URL for this method
34
32
  :param timeout_ms: Override the default request timeout configuration for this method in milliseconds
@@ -43,9 +41,8 @@ class Embeddings(BaseSDK):
43
41
  base_url = server_url
44
42
 
45
43
  request = models.EmbeddingRequest(
46
- inputs=inputs,
47
44
  model=model,
48
- encoding_format=encoding_format,
45
+ inputs=inputs,
49
46
  )
50
47
 
51
48
  req = self._build_request(
@@ -119,7 +116,6 @@ class Embeddings(BaseSDK):
119
116
  *,
120
117
  inputs: Union[models.Inputs, models.InputsTypedDict],
121
118
  model: Optional[str] = "mistral-embed",
122
- encoding_format: OptionalNullable[str] = UNSET,
123
119
  retries: OptionalNullable[utils.RetryConfig] = UNSET,
124
120
  server_url: Optional[str] = None,
125
121
  timeout_ms: Optional[int] = None,
@@ -131,7 +127,6 @@ class Embeddings(BaseSDK):
131
127
 
132
128
  :param inputs: Text to embed.
133
129
  :param model: ID of the model to use.
134
- :param encoding_format: The format to return the embeddings in.
135
130
  :param retries: Override the default retry configuration for this method
136
131
  :param server_url: Override the default server URL for this method
137
132
  :param timeout_ms: Override the default request timeout configuration for this method in milliseconds
@@ -146,9 +141,8 @@ class Embeddings(BaseSDK):
146
141
  base_url = server_url
147
142
 
148
143
  request = models.EmbeddingRequest(
149
- inputs=inputs,
150
144
  model=model,
151
- encoding_format=encoding_format,
145
+ inputs=inputs,
152
146
  )
153
147
 
154
148
  req = self._build_request_async(
@@ -0,0 +1,56 @@
1
+ ## Context
2
+
3
+ The extra package contains the custom logic which is too complex to be generated by Speakeasy from the OpenAPI specs. It was introduced to add the Structured Outputs feature.
4
+
5
+ ## Development / Contributing
6
+
7
+ To add custom code in the SDK, you need to use [Speakeasy custom code regions](https://www.speakeasy.com/docs/customize/code/code-regions/overview) as below.
8
+
9
+ ### Runbook of SDK customization
10
+
11
+ 1. Add the code you want to import in the `src/mistralai/extra/` package. To have it importable from the SDK, you need to add it in the `__init__.py` file:
12
+ ```python
13
+ from .my_custom_file import my_custom_function
14
+
15
+ __all__ = ["my_custom_function"]
16
+ ```
17
+
18
+ 2. Add a new custom code region in the SDK files, e.g. in `src/mistralai/chat.py`:
19
+ ```python
20
+ # region imports
21
+ from typing import Type
22
+ from mistralai.extra import my_custom_function
23
+ # endregion imports
24
+
25
+ class Chat(BaseSDK):
26
+ r"""Chat Completion API."""
27
+
28
+ # region sdk-class-body
29
+ def my_custom_method(self, param: str) -> Type[some_type]:
30
+ output = my_custom_function(param1)
31
+ return output
32
+ # endregion sdk-class-body
33
+ ```
34
+
35
+ 3. Now build the SDK with the custom code:
36
+ ```bash
37
+ rm -rf dist; poetry build; python3 -m pip install ~/client-python/dist/mistralai-1.4.1-py3-none-any.whl --force-reinstall
38
+ ```
39
+
40
+ 4. And now you should be able to call the custom method:
41
+ ```python
42
+ import os
43
+ from mistralai import Mistral
44
+
45
+ api_key = os.environ["MISTRAL_API_KEY"]
46
+ client = Mistral(api_key=api_key)
47
+
48
+ client.chat.my_custom_method(param="test")
49
+ ```
50
+
51
+ ### Run the unit tests
52
+
53
+ To run the unit tests for the `extra` package, you can run the following command from the root of the repository:
54
+ ```bash
55
+ python3.12 -m unittest discover -s src/mistralai/extra/tests -t src
56
+ ```
@@ -0,0 +1,5 @@
1
+ from .struct_chat import ParsedChatCompletionResponse, convert_to_parsed_chat_completion_response
2
+ from .utils import response_format_from_pydantic_model
3
+ from .utils.response_format import CustomPydanticModel
4
+
5
+ __all__ = ["convert_to_parsed_chat_completion_response", "response_format_from_pydantic_model", "CustomPydanticModel", "ParsedChatCompletionResponse"]
@@ -0,0 +1,41 @@
1
+ from ..models import ChatCompletionResponse, ChatCompletionChoice, AssistantMessage
2
+ from .utils.response_format import CustomPydanticModel, pydantic_model_from_json
3
+ from typing import List, Optional, Type, Generic
4
+ from pydantic import BaseModel
5
+ import json
6
+
7
+ class ParsedAssistantMessage(AssistantMessage, Generic[CustomPydanticModel]):
8
+ parsed: Optional[CustomPydanticModel]
9
+
10
+ class ParsedChatCompletionChoice(ChatCompletionChoice, Generic[CustomPydanticModel]):
11
+ message: Optional[ParsedAssistantMessage[CustomPydanticModel]] # type: ignore
12
+
13
+ class ParsedChatCompletionResponse(ChatCompletionResponse, Generic[CustomPydanticModel]):
14
+ choices: Optional[List[ParsedChatCompletionChoice[CustomPydanticModel]]] # type: ignore
15
+
16
+ def convert_to_parsed_chat_completion_response(response: ChatCompletionResponse, response_format: Type[BaseModel]) -> ParsedChatCompletionResponse:
17
+ parsed_choices = []
18
+
19
+ if response.choices:
20
+ for choice in response.choices:
21
+ if choice.message:
22
+ parsed_message: ParsedAssistantMessage = ParsedAssistantMessage(
23
+ **choice.message.model_dump(),
24
+ parsed=None
25
+ )
26
+ if isinstance(parsed_message.content, str):
27
+ parsed_message.parsed = pydantic_model_from_json(json.loads(parsed_message.content), response_format)
28
+ elif parsed_message.content is None:
29
+ parsed_message.parsed = None
30
+ else:
31
+ raise TypeError(f"Unexpected type for message.content: {type(parsed_message.content)}")
32
+ choice_dict = choice.model_dump()
33
+ choice_dict["message"] = parsed_message
34
+ parsed_choice: ParsedChatCompletionChoice = ParsedChatCompletionChoice(**choice_dict)
35
+ parsed_choices.append(parsed_choice)
36
+ else:
37
+ parsed_choice = ParsedChatCompletionChoice(**choice.model_dump())
38
+ parsed_choices.append(parsed_choice)
39
+ response_dict = response.model_dump()
40
+ response_dict["choices"] = parsed_choices
41
+ return ParsedChatCompletionResponse(**response_dict)
File without changes
@@ -0,0 +1,103 @@
1
+ import unittest
2
+ from ..struct_chat import (
3
+ convert_to_parsed_chat_completion_response,
4
+ ParsedChatCompletionResponse,
5
+ ParsedChatCompletionChoice,
6
+ ParsedAssistantMessage,
7
+ )
8
+ from ...models import (
9
+ ChatCompletionResponse,
10
+ UsageInfo,
11
+ ChatCompletionChoice,
12
+ AssistantMessage,
13
+ )
14
+ from pydantic import BaseModel
15
+
16
+
17
+ class Explanation(BaseModel):
18
+ explanation: str
19
+ output: str
20
+
21
+
22
+ class MathDemonstration(BaseModel):
23
+ steps: list[Explanation]
24
+ final_answer: str
25
+
26
+
27
+ mock_cc_response = ChatCompletionResponse(
28
+ id="c0271b2098954c6094231703875ca0bc",
29
+ object="chat.completion",
30
+ model="mistral-large-latest",
31
+ usage=UsageInfo(prompt_tokens=75, completion_tokens=220, total_tokens=295),
32
+ created=1737727558,
33
+ choices=[
34
+ ChatCompletionChoice(
35
+ index=0,
36
+ message=AssistantMessage(
37
+ content='{\n "final_answer": "x = -4",\n "steps": [\n {\n "explanation": "Start with the given equation.",\n "output": "8x + 7 = -23"\n },\n {\n "explanation": "Subtract 7 from both sides to isolate the term with x.",\n "output": "8x = -23 - 7"\n },\n {\n "explanation": "Simplify the right side of the equation.",\n "output": "8x = -30"\n },\n {\n "explanation": "Divide both sides by 8 to solve for x.",\n "output": "x = -30 / 8"\n },\n {\n "explanation": "Simplify the fraction to get the final answer.",\n "output": "x = -4"\n }\n ]\n}',
38
+ tool_calls=None,
39
+ prefix=False,
40
+ role="assistant",
41
+ ),
42
+ finish_reason="stop",
43
+ )
44
+ ],
45
+ )
46
+
47
+
48
+ expected_response = ParsedChatCompletionResponse(
49
+ choices=[
50
+ ParsedChatCompletionChoice(
51
+ index=0,
52
+ message=ParsedAssistantMessage(
53
+ content='{\n "final_answer": "x = -4",\n "steps": [\n {\n "explanation": "Start with the given equation.",\n "output": "8x + 7 = -23"\n },\n {\n "explanation": "Subtract 7 from both sides to isolate the term with x.",\n "output": "8x = -23 - 7"\n },\n {\n "explanation": "Simplify the right side of the equation.",\n "output": "8x = -30"\n },\n {\n "explanation": "Divide both sides by 8 to solve for x.",\n "output": "x = -30 / 8"\n },\n {\n "explanation": "Simplify the fraction to get the final answer.",\n "output": "x = -4"\n }\n ]\n}',
54
+ tool_calls=None,
55
+ prefix=False,
56
+ role="assistant",
57
+ parsed=MathDemonstration(
58
+ steps=[
59
+ Explanation(
60
+ explanation="Start with the given equation.",
61
+ output="8x + 7 = -23",
62
+ ),
63
+ Explanation(
64
+ explanation="Subtract 7 from both sides to isolate the term with x.",
65
+ output="8x = -23 - 7",
66
+ ),
67
+ Explanation(
68
+ explanation="Simplify the right side of the equation.",
69
+ output="8x = -30",
70
+ ),
71
+ Explanation(
72
+ explanation="Divide both sides by 8 to solve for x.",
73
+ output="x = -30 / 8",
74
+ ),
75
+ Explanation(
76
+ explanation="Simplify the fraction to get the final answer.",
77
+ output="x = -4",
78
+ ),
79
+ ],
80
+ final_answer="x = -4",
81
+ ),
82
+ ),
83
+ finish_reason="stop",
84
+ )
85
+ ],
86
+ created=1737727558,
87
+ id="c0271b2098954c6094231703875ca0bc",
88
+ model="mistral-large-latest",
89
+ object="chat.completion",
90
+ usage=UsageInfo(prompt_tokens=75, completion_tokens=220, total_tokens=295),
91
+ )
92
+
93
+
94
+ class TestConvertToParsedChatCompletionResponse(unittest.TestCase):
95
+ def test_convert_to_parsed_chat_completion_response(self):
96
+ output = convert_to_parsed_chat_completion_response(
97
+ mock_cc_response, MathDemonstration
98
+ )
99
+ self.assertEqual(output, expected_response)
100
+
101
+
102
+ if __name__ == "__main__":
103
+ unittest.main()