mistralai 1.3.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mistralai/__init__.py +10 -1
- mistralai/_version.py +4 -1
- mistralai/agents.py +58 -14
- mistralai/chat.py +140 -14
- mistralai/classifiers.py +32 -20
- mistralai/embeddings.py +16 -10
- mistralai/extra/README.md +56 -0
- mistralai/extra/__init__.py +5 -0
- mistralai/extra/struct_chat.py +41 -0
- mistralai/extra/tests/__init__.py +0 -0
- mistralai/extra/tests/test_struct_chat.py +103 -0
- mistralai/extra/tests/test_utils.py +162 -0
- mistralai/extra/utils/__init__.py +3 -0
- mistralai/extra/utils/_pydantic_helper.py +20 -0
- mistralai/extra/utils/response_format.py +24 -0
- mistralai/files.py +94 -34
- mistralai/fim.py +30 -14
- mistralai/httpclient.py +50 -0
- mistralai/jobs.py +80 -32
- mistralai/mistral_jobs.py +64 -24
- mistralai/models/__init__.py +8 -0
- mistralai/models/agentscompletionrequest.py +5 -0
- mistralai/models/agentscompletionstreamrequest.py +5 -0
- mistralai/models/chatcompletionrequest.py +5 -0
- mistralai/models/chatcompletionstreamrequest.py +5 -0
- mistralai/models/fileschema.py +3 -2
- mistralai/models/function.py +3 -0
- mistralai/models/jsonschema.py +55 -0
- mistralai/models/prediction.py +26 -0
- mistralai/models/responseformat.py +36 -1
- mistralai/models/responseformats.py +1 -1
- mistralai/models/retrievefileout.py +3 -2
- mistralai/models/toolcall.py +3 -0
- mistralai/models/uploadfileout.py +3 -2
- mistralai/models_.py +92 -48
- mistralai/sdk.py +13 -3
- mistralai/sdkconfiguration.py +10 -4
- {mistralai-1.3.1.dist-info → mistralai-1.5.0.dist-info}/METADATA +41 -42
- {mistralai-1.3.1.dist-info → mistralai-1.5.0.dist-info}/RECORD +43 -33
- {mistralai-1.3.1.dist-info → mistralai-1.5.0.dist-info}/WHEEL +1 -1
- mistralai_azure/_hooks/custom_user_agent.py +1 -1
- mistralai_gcp/sdk.py +1 -2
- py.typed +0 -1
- {mistralai-1.3.1.dist-info → mistralai-1.5.0.dist-info}/LICENSE +0 -0
mistralai/embeddings.py
CHANGED
|
@@ -21,7 +21,7 @@ class Embeddings(BaseSDK):
|
|
|
21
21
|
server_url: Optional[str] = None,
|
|
22
22
|
timeout_ms: Optional[int] = None,
|
|
23
23
|
http_headers: Optional[Mapping[str, str]] = None,
|
|
24
|
-
) ->
|
|
24
|
+
) -> models.EmbeddingResponse:
|
|
25
25
|
r"""Embeddings
|
|
26
26
|
|
|
27
27
|
Embeddings
|
|
@@ -90,13 +90,16 @@ class Embeddings(BaseSDK):
|
|
|
90
90
|
|
|
91
91
|
data: Any = None
|
|
92
92
|
if utils.match_response(http_res, "200", "application/json"):
|
|
93
|
-
return utils.unmarshal_json(
|
|
94
|
-
http_res.text, Optional[models.EmbeddingResponse]
|
|
95
|
-
)
|
|
93
|
+
return utils.unmarshal_json(http_res.text, models.EmbeddingResponse)
|
|
96
94
|
if utils.match_response(http_res, "422", "application/json"):
|
|
97
95
|
data = utils.unmarshal_json(http_res.text, models.HTTPValidationErrorData)
|
|
98
96
|
raise models.HTTPValidationError(data=data)
|
|
99
|
-
if utils.match_response(http_res,
|
|
97
|
+
if utils.match_response(http_res, "4XX", "*"):
|
|
98
|
+
http_res_text = utils.stream_to_text(http_res)
|
|
99
|
+
raise models.SDKError(
|
|
100
|
+
"API error occurred", http_res.status_code, http_res_text, http_res
|
|
101
|
+
)
|
|
102
|
+
if utils.match_response(http_res, "5XX", "*"):
|
|
100
103
|
http_res_text = utils.stream_to_text(http_res)
|
|
101
104
|
raise models.SDKError(
|
|
102
105
|
"API error occurred", http_res.status_code, http_res_text, http_res
|
|
@@ -121,7 +124,7 @@ class Embeddings(BaseSDK):
|
|
|
121
124
|
server_url: Optional[str] = None,
|
|
122
125
|
timeout_ms: Optional[int] = None,
|
|
123
126
|
http_headers: Optional[Mapping[str, str]] = None,
|
|
124
|
-
) ->
|
|
127
|
+
) -> models.EmbeddingResponse:
|
|
125
128
|
r"""Embeddings
|
|
126
129
|
|
|
127
130
|
Embeddings
|
|
@@ -190,13 +193,16 @@ class Embeddings(BaseSDK):
|
|
|
190
193
|
|
|
191
194
|
data: Any = None
|
|
192
195
|
if utils.match_response(http_res, "200", "application/json"):
|
|
193
|
-
return utils.unmarshal_json(
|
|
194
|
-
http_res.text, Optional[models.EmbeddingResponse]
|
|
195
|
-
)
|
|
196
|
+
return utils.unmarshal_json(http_res.text, models.EmbeddingResponse)
|
|
196
197
|
if utils.match_response(http_res, "422", "application/json"):
|
|
197
198
|
data = utils.unmarshal_json(http_res.text, models.HTTPValidationErrorData)
|
|
198
199
|
raise models.HTTPValidationError(data=data)
|
|
199
|
-
if utils.match_response(http_res,
|
|
200
|
+
if utils.match_response(http_res, "4XX", "*"):
|
|
201
|
+
http_res_text = await utils.stream_to_text_async(http_res)
|
|
202
|
+
raise models.SDKError(
|
|
203
|
+
"API error occurred", http_res.status_code, http_res_text, http_res
|
|
204
|
+
)
|
|
205
|
+
if utils.match_response(http_res, "5XX", "*"):
|
|
200
206
|
http_res_text = await utils.stream_to_text_async(http_res)
|
|
201
207
|
raise models.SDKError(
|
|
202
208
|
"API error occurred", http_res.status_code, http_res_text, http_res
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
## Context
|
|
2
|
+
|
|
3
|
+
The `extra` package contains custom logic that is too complex to be generated by Speakeasy from the OpenAPI specs. It was introduced to add the Structured Outputs feature.
|
|
4
|
+
|
|
5
|
+
## Development / Contributing
|
|
6
|
+
|
|
7
|
+
To add custom code to the SDK, use [Speakeasy custom code regions](https://www.speakeasy.com/docs/customize/code/code-regions/overview) as described below.
|
|
8
|
+
|
|
9
|
+
### Runbook of SDK customization
|
|
10
|
+
|
|
11
|
+
1. Add the code you want to expose to the `src/mistralai/extra/` package. To make it importable from the SDK, export it in the package's `__init__.py` file:
|
|
12
|
+
```python
|
|
13
|
+
from .my_custom_file import my_custom_function
|
|
14
|
+
|
|
15
|
+
__all__ = ["my_custom_function"]
|
|
16
|
+
```
|
|
17
|
+
|
|
18
|
+
2. Add a new custom code region in the SDK files, e.g. in `src/mistralai/chat.py`:
|
|
19
|
+
```python
|
|
20
|
+
# region imports
|
|
21
|
+
from typing import Type
|
|
22
|
+
from mistralai.extra import my_custom_function
|
|
23
|
+
# endregion imports
|
|
24
|
+
|
|
25
|
+
class Chat(BaseSDK):
|
|
26
|
+
r"""Chat Completion API."""
|
|
27
|
+
|
|
28
|
+
# region sdk-class-body
|
|
29
|
+
def my_custom_method(self, param: str) -> Type[some_type]:
|
|
30
|
+
output = my_custom_function(param)
|
|
31
|
+
return output
|
|
32
|
+
# endregion sdk-class-body
|
|
33
|
+
```
|
|
34
|
+
|
|
35
|
+
3. Now build the SDK with the custom code:
|
|
36
|
+
```bash
|
|
37
|
+
rm -rf dist; poetry build; python3 -m pip install ~/client-python/dist/mistralai-1.4.1-py3-none-any.whl --force-reinstall
|
|
38
|
+
```
|
|
39
|
+
|
|
40
|
+
4. And now you should be able to call the custom method:
|
|
41
|
+
```python
|
|
42
|
+
import os
|
|
43
|
+
from mistralai import Mistral
|
|
44
|
+
|
|
45
|
+
api_key = os.environ["MISTRAL_API_KEY"]
|
|
46
|
+
client = Mistral(api_key=api_key)
|
|
47
|
+
|
|
48
|
+
client.chat.my_custom_method(param="test")
|
|
49
|
+
```
|
|
50
|
+
|
|
51
|
+
### Run the unit tests
|
|
52
|
+
|
|
53
|
+
To run the unit tests for the `extra` package, you can run the following command from the root of the repository:
|
|
54
|
+
```bash
|
|
55
|
+
python3.12 -m unittest discover -s src/mistralai/extra/tests -t src
|
|
56
|
+
```
|
|
@@ -0,0 +1,5 @@
|
|
|
1
|
+
from .struct_chat import ParsedChatCompletionResponse, convert_to_parsed_chat_completion_response
|
|
2
|
+
from .utils import response_format_from_pydantic_model
|
|
3
|
+
from .utils.response_format import CustomPydanticModel
|
|
4
|
+
|
|
5
|
+
__all__ = ["convert_to_parsed_chat_completion_response", "response_format_from_pydantic_model", "CustomPydanticModel", "ParsedChatCompletionResponse"]
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
from ..models import ChatCompletionResponse, ChatCompletionChoice, AssistantMessage
|
|
2
|
+
from .utils.response_format import CustomPydanticModel, pydantic_model_from_json
|
|
3
|
+
from typing import List, Optional, Type, Generic
|
|
4
|
+
from pydantic import BaseModel
|
|
5
|
+
import json
|
|
6
|
+
|
|
7
|
+
# AssistantMessage extended with a `parsed` field whose type is the generic parameter CustomPydanticModel.
class ParsedAssistantMessage(AssistantMessage, Generic[CustomPydanticModel]):
|
|
8
|
+
# Filled in by convert_to_parsed_chat_completion_response: the model validated from a string `content`, or None when `content` is None.
parsed: Optional[CustomPydanticModel]
|
|
9
|
+
|
|
10
|
+
# ChatCompletionChoice whose `message` is re-declared as a ParsedAssistantMessage for the same generic parameter.
class ParsedChatCompletionChoice(ChatCompletionChoice, Generic[CustomPydanticModel]):
|
|
11
|
+
# The field is re-declared with a different type than on the parent class, hence the `type: ignore`.
message: Optional[ParsedAssistantMessage[CustomPydanticModel]] # type: ignore
|
|
12
|
+
|
|
13
|
+
# ChatCompletionResponse whose `choices` are re-declared as ParsedChatCompletionChoice items for the same generic parameter.
class ParsedChatCompletionResponse(ChatCompletionResponse, Generic[CustomPydanticModel]):
|
|
14
|
+
# The field is re-declared with a different type than on the parent class, hence the `type: ignore`.
choices: Optional[List[ParsedChatCompletionChoice[CustomPydanticModel]]] # type: ignore
|
|
15
|
+
|
|
16
|
+
def convert_to_parsed_chat_completion_response(response: ChatCompletionResponse, response_format: Type[BaseModel]) -> ParsedChatCompletionResponse:
    """Rebuild a chat completion response with each message content parsed.

    Every choice of ``response`` is copied into a ``ParsedChatCompletionChoice``.
    When a choice carries a message, the message is copied into a
    ``ParsedAssistantMessage`` whose ``parsed`` attribute holds the string
    content decoded with ``json.loads`` and validated against
    ``response_format``; ``parsed`` stays ``None`` when the content is ``None``.

    Args:
        response: The chat completion response to convert.
        response_format: The pydantic model the message content is validated
            against.

    Returns:
        A ``ParsedChatCompletionResponse`` built from the dumped fields of
        ``response`` with its ``choices`` replaced by the parsed choices (an
        empty list when ``response.choices`` is empty or ``None``).

    Raises:
        TypeError: If a message content is neither a string nor ``None``.

    Errors raised by ``json.loads`` or by the model validation are not caught
    here and propagate to the caller.
    """
    converted_choices = []

    for source_choice in response.choices or []:
        message_override = None
        if source_choice.message:
            message_override = ParsedAssistantMessage(
                **source_choice.message.model_dump(),
                parsed=None,
            )
            raw_content = message_override.content
            if isinstance(raw_content, str):
                message_override.parsed = pydantic_model_from_json(json.loads(raw_content), response_format)
            elif raw_content is not None:
                # `parsed` is already None for a None content; anything else
                # (e.g. a list of chunks) cannot be decoded as JSON text.
                raise TypeError(f"Unexpected type for message.content: {type(message_override.content)}")

        choice_fields = source_choice.model_dump()
        if message_override is not None:
            choice_fields["message"] = message_override
        converted_choices.append(ParsedChatCompletionChoice(**choice_fields))

    response_fields = response.model_dump()
    response_fields["choices"] = converted_choices
    return ParsedChatCompletionResponse(**response_fields)
|
|
File without changes
|
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
import unittest
|
|
2
|
+
from ..struct_chat import (
|
|
3
|
+
convert_to_parsed_chat_completion_response,
|
|
4
|
+
ParsedChatCompletionResponse,
|
|
5
|
+
ParsedChatCompletionChoice,
|
|
6
|
+
ParsedAssistantMessage,
|
|
7
|
+
)
|
|
8
|
+
from ...models import (
|
|
9
|
+
ChatCompletionResponse,
|
|
10
|
+
UsageInfo,
|
|
11
|
+
ChatCompletionChoice,
|
|
12
|
+
AssistantMessage,
|
|
13
|
+
)
|
|
14
|
+
from pydantic import BaseModel
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# Test model: element type of MathDemonstration.steps, with two required string fields.
class Explanation(BaseModel):
|
|
18
|
+
explanation: str
|
|
19
|
+
output: str
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# Test model passed as the response format in the conversion test: a list of Explanation steps plus a final answer string.
class MathDemonstration(BaseModel):
|
|
23
|
+
steps: list[Explanation]
|
|
24
|
+
final_answer: str
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# Raw assistant message content used by both fixtures below: a JSON document
# with a "final_answer" string and a list of "steps".
_RAW_CONTENT = '{\n "final_answer": "x = -4",\n "steps": [\n {\n "explanation": "Start with the given equation.",\n "output": "8x + 7 = -23"\n },\n {\n "explanation": "Subtract 7 from both sides to isolate the term with x.",\n "output": "8x = -23 - 7"\n },\n {\n "explanation": "Simplify the right side of the equation.",\n "output": "8x = -30"\n },\n {\n "explanation": "Divide both sides by 8 to solve for x.",\n "output": "x = -30 / 8"\n },\n {\n "explanation": "Simplify the fraction to get the final answer.",\n "output": "x = -4"\n }\n ]\n}'

# (explanation, output) pairs encoded in _RAW_CONTENT, in document order.
_EXPECTED_STEPS = [
    ("Start with the given equation.", "8x + 7 = -23"),
    ("Subtract 7 from both sides to isolate the term with x.", "8x = -23 - 7"),
    ("Simplify the right side of the equation.", "8x = -30"),
    ("Divide both sides by 8 to solve for x.", "x = -30 / 8"),
    ("Simplify the fraction to get the final answer.", "x = -4"),
]

# Plain (unparsed) chat completion response handed to the converter under test.
mock_cc_response = ChatCompletionResponse(
    id="c0271b2098954c6094231703875ca0bc",
    object="chat.completion",
    model="mistral-large-latest",
    created=1737727558,
    usage=UsageInfo(prompt_tokens=75, completion_tokens=220, total_tokens=295),
    choices=[
        ChatCompletionChoice(
            index=0,
            finish_reason="stop",
            message=AssistantMessage(
                role="assistant",
                content=_RAW_CONTENT,
                tool_calls=None,
                prefix=False,
            ),
        )
    ],
)

# Expected conversion result: same fields, with the message content also
# available as a validated MathDemonstration in `parsed`.
expected_response = ParsedChatCompletionResponse(
    id="c0271b2098954c6094231703875ca0bc",
    object="chat.completion",
    model="mistral-large-latest",
    created=1737727558,
    usage=UsageInfo(prompt_tokens=75, completion_tokens=220, total_tokens=295),
    choices=[
        ParsedChatCompletionChoice(
            index=0,
            finish_reason="stop",
            message=ParsedAssistantMessage(
                role="assistant",
                content=_RAW_CONTENT,
                tool_calls=None,
                prefix=False,
                parsed=MathDemonstration(
                    steps=[
                        Explanation(explanation=step_text, output=step_output)
                        for step_text, step_output in _EXPECTED_STEPS
                    ],
                    final_answer="x = -4",
                ),
            ),
        )
    ],
)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class TestConvertToParsedChatCompletionResponse(unittest.TestCase):
    """Checks the conversion of a plain chat completion into its parsed form."""

    def test_convert_to_parsed_chat_completion_response(self):
        # The mock response carries JSON text matching MathDemonstration, so
        # the converted response must equal the hand-built expected fixture.
        actual = convert_to_parsed_chat_completion_response(
            mock_cc_response,
            MathDemonstration,
        )
        self.assertEqual(actual, expected_response)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
if __name__ == "__main__":
|
|
103
|
+
unittest.main()
|
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
from ..utils.response_format import (
|
|
2
|
+
pydantic_model_from_json,
|
|
3
|
+
response_format_from_pydantic_model,
|
|
4
|
+
rec_strict_json_schema,
|
|
5
|
+
)
|
|
6
|
+
from pydantic import BaseModel, ValidationError
|
|
7
|
+
|
|
8
|
+
from ...models import ResponseFormat, JSONSchema
|
|
9
|
+
from ...types.basemodel import Unset
|
|
10
|
+
|
|
11
|
+
import unittest
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# Flat test model with two required fields, used to exercise pydantic_model_from_json.
class Student(BaseModel):
|
|
15
|
+
name: str
|
|
16
|
+
age: int
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Test model: element type of MathDemonstration.steps; its expected JSON schema is the "$defs"/"Explanation" entry of mathdemo_schema.
# NOTE(review): documented with comments on purpose — pydantic is expected to copy a class docstring into the schema "description", which would break the schema comparison; confirm before adding one.
class Explanation(BaseModel):
|
|
20
|
+
explanation: str
|
|
21
|
+
output: str
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Nested test model; response_format_from_pydantic_model(MathDemonstration) is compared against mathdemo_response_format in the tests.
# NOTE(review): documented with comments on purpose — pydantic is expected to copy a class docstring into the schema "description", which would break that comparison; confirm before adding one.
class MathDemonstration(BaseModel):
|
|
25
|
+
steps: list[Explanation]
|
|
26
|
+
final_answer: str
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# Non-strict JSON schema expected for the MathDemonstration test model
# (no "additionalProperties" entries anywhere).
mathdemo_schema = {
    "$defs": {
        "Explanation": {
            "properties": {
                "explanation": {"title": "Explanation", "type": "string"},
                "output": {"title": "Output", "type": "string"},
            },
            "required": ["explanation", "output"],
            "title": "Explanation",
            "type": "object",
        }
    },
    "properties": {
        "steps": {
            "items": {"$ref": "#/$defs/Explanation"},
            "title": "Steps",
            "type": "array",
        },
        "final_answer": {"title": "Final Answer", "type": "string"},
    },
    "required": ["steps", "final_answer"],
    "title": "MathDemonstration",
    "type": "object",
}

# Strict variant: every "object" node forbids additional properties.
# It is assembled from fresh dicts instead of `mathdemo_schema.copy()`: a
# shallow copy shares the nested "$defs" dicts, so marking the nested
# Explanation schema as strict would also modify `mathdemo_schema` and the
# nested part of the strictness check would pass without testing anything.
mathdemo_strict_schema = {
    **mathdemo_schema,
    "$defs": {
        "Explanation": {
            **mathdemo_schema["$defs"]["Explanation"],  # type: ignore
            "additionalProperties": False,
        }
    },
    "additionalProperties": False,
}
|
|
57
|
+
|
|
58
|
+
# Expected ResponseFormat for MathDemonstration: a "json_schema" format wrapping the strict schema, named after the model class.
mathdemo_response_format = ResponseFormat(
|
|
59
|
+
type="json_schema",
|
|
60
|
+
json_schema=JSONSchema(
|
|
61
|
+
name="MathDemonstration",
|
|
62
|
+
schema_definition=mathdemo_strict_schema,
|
|
63
|
+
# No description is supplied when the format is built from a model, so the field holds the Unset sentinel here.
description=Unset(),
|
|
64
|
+
strict=True,
|
|
65
|
+
),
|
|
66
|
+
)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class TestResponseFormat(unittest.TestCase):
    """Unit tests for the response-format helpers of the ``extra`` package."""

    def test_pydantic_model_from_json(self):
        # Flat payloads for the Student model.
        payload_missing_age = {"name": "Jean Dupont"}
        payload_complete = {"name": "Jean Dupont", "age": 25}
        payload_with_extra = {
            "name": "Jean Dupont",
            "age": 25,
            "extra_field": "extra_value",
        }
        expected_student = Student(name="Jean Dupont", age=25)

        # Nested payload for MathDemonstration, built from (explanation, output) pairs.
        step_pairs = [
            ("Start with the given equation.", "8x + 7 = -23"),
            ("Subtract 7 from both sides to isolate the term with x.", "8x = -23 - 7"),
            ("Simplify the right side of the equation.", "8x = -30"),
            ("Divide both sides by 8 to solve for x.", "x = -30 / 8"),
            ("Simplify the fraction to get the final answer.", "x = -4"),
        ]
        payload_nested = {
            "final_answer": "x = -4",
            "steps": [
                {"explanation": step_text, "output": step_output}
                for step_text, step_output in step_pairs
            ],
        }
        expected_demo = MathDemonstration(
            steps=[
                Explanation(explanation=step_text, output=step_output)
                for step_text, step_output in step_pairs
            ],
            final_answer="x = -4",
        )

        self.assertEqual(
            pydantic_model_from_json(payload_complete, Student), expected_student
        )
        # The unknown key does not appear on the resulting model.
        self.assertEqual(
            pydantic_model_from_json(payload_with_extra, Student), expected_student
        )
        self.assertEqual(
            pydantic_model_from_json(payload_nested, MathDemonstration),
            expected_demo,
        )

        # Check it raises a validation error
        with self.assertRaises(ValidationError):
            pydantic_model_from_json(payload_missing_age, Student)  # type: ignore

    def test_response_format_from_pydantic_model(self):
        # The generated format must match the hand-written strict fixture.
        generated_format = response_format_from_pydantic_model(MathDemonstration)
        self.assertEqual(generated_format, mathdemo_response_format)

    def test_rec_strict_json_schema(self):
        # Built before the helper runs on mathdemo_schema; the integer value
        # is not an accepted schema node type.
        invalid_schema = mathdemo_schema | {"wrong_value": 1}

        strict_result = rec_strict_json_schema(mathdemo_schema)
        self.assertEqual(strict_result, mathdemo_strict_schema)

        with self.assertRaises(ValueError):
            rec_strict_json_schema(invalid_schema)
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
if __name__ == "__main__":
|
|
162
|
+
unittest.main()
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
def rec_strict_json_schema(schema_node: Any) -> Any:
    """
    Make a JSON Schema strict by forbidding additional properties everywhere.

    The node is walked recursively. Every dict whose ``"type"`` is
    ``"object"`` gets ``"additionalProperties": False``. Dicts and lists are
    modified in place and the very same node is returned; strings and
    booleans are returned untouched.

    Raises:
        ValueError: if a node is neither a str, a bool, a dict nor a list
            (for example a number or ``None``).
    """
    # Leaves that are allowed to appear as-is in the schema.
    if isinstance(schema_node, (bool, str)):
        return schema_node

    if isinstance(schema_node, list):
        schema_node[:] = [rec_strict_json_schema(item) for item in schema_node]
        return schema_node

    if not isinstance(schema_node, dict):
        raise ValueError(f"Unexpected type: {schema_node}")

    if schema_node.get("type") == "object":
        schema_node["additionalProperties"] = False
    # Iterate over a snapshot of the keys: values are reassigned while looping.
    for name in list(schema_node):
        schema_node[name] = rec_strict_json_schema(schema_node[name])
    return schema_node
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from pydantic import BaseModel
|
|
2
|
+
from typing import TypeVar, Any, Type
|
|
3
|
+
from ...models import JSONSchema, ResponseFormat
|
|
4
|
+
from ._pydantic_helper import rec_strict_json_schema
|
|
5
|
+
|
|
6
|
+
CustomPydanticModel = TypeVar("CustomPydanticModel", bound=BaseModel)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def response_format_from_pydantic_model(
    model: Type[CustomPydanticModel],
) -> ResponseFormat:
    """Build a strict ``json_schema`` response format from a pydantic model.

    The JSON schema produced by ``model.model_json_schema()`` is made strict
    with ``rec_strict_json_schema`` and wrapped in a ``JSONSchema`` named after
    the model class, with ``strict`` set to ``True``.

    Args:
        model: The pydantic model class describing the expected output.

    Returns:
        A ``ResponseFormat`` of type ``"json_schema"`` carrying that schema.

    Errors raised by ``rec_strict_json_schema`` (a ``ValueError`` for schema
    nodes it does not accept) or by ``JSONSchema.model_validate`` propagate to
    the caller.
    """
    # `Type[...]` (already imported and used by pydantic_model_from_json) is
    # used instead of the builtin `type[...]` for consistency within the file.
    model_schema = rec_strict_json_schema(model.model_json_schema())
    json_schema = JSONSchema.model_validate(
        {"name": model.__name__, "schema": model_schema, "strict": True}
    )
    return ResponseFormat(type="json_schema", json_schema=json_schema)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def pydantic_model_from_json(
    json_data: dict[str, Any], pydantic_model: Type[CustomPydanticModel]
) -> CustomPydanticModel:
    """Validate already-decoded JSON data against a pydantic model.

    Args:
        json_data: The decoded JSON object (a dict), not a JSON string.
        pydantic_model: The model class to validate the data against.

    Returns:
        The instance returned by ``pydantic_model.model_validate(json_data)``.

    Any error raised by ``model_validate`` (pydantic's ``ValidationError`` for
    data that does not match the model) propagates to the caller.
    """
    validated_instance = pydantic_model.model_validate(json_data)
    return validated_instance
|