mistralai 1.5.1__py3-none-any.whl → 1.5.2rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mistralai/_hooks/types.py +15 -3
- mistralai/_version.py +3 -3
- mistralai/agents.py +32 -12
- mistralai/basesdk.py +8 -0
- mistralai/chat.py +32 -12
- mistralai/classifiers.py +32 -12
- mistralai/embeddings.py +20 -10
- mistralai/extra/utils/response_format.py +3 -3
- mistralai/files.py +36 -0
- mistralai/fim.py +32 -12
- mistralai/httpclient.py +4 -2
- mistralai/jobs.py +30 -0
- mistralai/mistral_jobs.py +24 -0
- mistralai/models/__init__.py +6 -1
- mistralai/models/documenturlchunk.py +8 -14
- mistralai/models/embeddingrequest.py +7 -7
- mistralai/models/filepurpose.py +1 -1
- mistralai/models_.py +66 -18
- mistralai/ocr.py +16 -6
- mistralai/sdk.py +19 -3
- mistralai/sdkconfiguration.py +4 -2
- mistralai/utils/__init__.py +2 -0
- mistralai/utils/serializers.py +10 -6
- mistralai/utils/values.py +4 -1
- {mistralai-1.5.1.dist-info → mistralai-1.5.2rc1.dist-info}/METADATA +66 -19
- {mistralai-1.5.1.dist-info → mistralai-1.5.2rc1.dist-info}/RECORD +73 -69
- mistralai_azure/__init__.py +10 -1
- mistralai_azure/_hooks/types.py +15 -3
- mistralai_azure/_version.py +3 -0
- mistralai_azure/basesdk.py +8 -0
- mistralai_azure/chat.py +88 -20
- mistralai_azure/httpclient.py +52 -0
- mistralai_azure/models/__init__.py +7 -0
- mistralai_azure/models/assistantmessage.py +2 -0
- mistralai_azure/models/chatcompletionrequest.py +8 -10
- mistralai_azure/models/chatcompletionstreamrequest.py +8 -10
- mistralai_azure/models/function.py +3 -0
- mistralai_azure/models/jsonschema.py +61 -0
- mistralai_azure/models/prediction.py +25 -0
- mistralai_azure/models/responseformat.py +42 -1
- mistralai_azure/models/responseformats.py +1 -1
- mistralai_azure/models/toolcall.py +3 -0
- mistralai_azure/sdk.py +56 -14
- mistralai_azure/sdkconfiguration.py +14 -6
- mistralai_azure/utils/__init__.py +2 -0
- mistralai_azure/utils/serializers.py +10 -6
- mistralai_azure/utils/values.py +4 -1
- mistralai_gcp/__init__.py +10 -1
- mistralai_gcp/_hooks/types.py +15 -3
- mistralai_gcp/_version.py +3 -0
- mistralai_gcp/basesdk.py +8 -0
- mistralai_gcp/chat.py +89 -21
- mistralai_gcp/fim.py +61 -21
- mistralai_gcp/httpclient.py +52 -0
- mistralai_gcp/models/__init__.py +7 -0
- mistralai_gcp/models/assistantmessage.py +2 -0
- mistralai_gcp/models/chatcompletionrequest.py +8 -10
- mistralai_gcp/models/chatcompletionstreamrequest.py +8 -10
- mistralai_gcp/models/fimcompletionrequest.py +2 -3
- mistralai_gcp/models/fimcompletionstreamrequest.py +2 -3
- mistralai_gcp/models/function.py +3 -0
- mistralai_gcp/models/jsonschema.py +61 -0
- mistralai_gcp/models/prediction.py +25 -0
- mistralai_gcp/models/responseformat.py +42 -1
- mistralai_gcp/models/responseformats.py +1 -1
- mistralai_gcp/models/toolcall.py +3 -0
- mistralai_gcp/sdk.py +63 -19
- mistralai_gcp/sdkconfiguration.py +14 -6
- mistralai_gcp/utils/__init__.py +2 -0
- mistralai_gcp/utils/serializers.py +10 -6
- mistralai_gcp/utils/values.py +4 -1
- {mistralai-1.5.1.dist-info → mistralai-1.5.2rc1.dist-info}/LICENSE +0 -0
- {mistralai-1.5.1.dist-info → mistralai-1.5.2rc1.dist-info}/WHEEL +0 -0
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
from .assistantmessage import AssistantMessage, AssistantMessageTypedDict
|
|
5
|
+
from .prediction import Prediction, PredictionTypedDict
|
|
5
6
|
from .responseformat import ResponseFormat, ResponseFormatTypedDict
|
|
6
7
|
from .systemmessage import SystemMessage, SystemMessageTypedDict
|
|
7
8
|
from .tool import Tool, ToolTypedDict
|
|
@@ -68,7 +69,7 @@ ChatCompletionRequestToolChoice = TypeAliasType(
|
|
|
68
69
|
|
|
69
70
|
|
|
70
71
|
class ChatCompletionRequestTypedDict(TypedDict):
|
|
71
|
-
model:
|
|
72
|
+
model: str
|
|
72
73
|
r"""ID of the model to use. You can use the [List Available Models](/api/#tag/models/operation/list_models_v1_models_get) API to see all of your available models, or see our [Model overview](/models) for model descriptions."""
|
|
73
74
|
messages: List[ChatCompletionRequestMessagesTypedDict]
|
|
74
75
|
r"""The prompt(s) to generate completions for, encoded as a list of dict with role and content."""
|
|
@@ -93,10 +94,11 @@ class ChatCompletionRequestTypedDict(TypedDict):
|
|
|
93
94
|
r"""frequency_penalty penalizes the repetition of words based on their frequency in the generated text. A higher frequency penalty discourages the model from repeating words that have already appeared frequently in the output, promoting diversity and reducing repetition."""
|
|
94
95
|
n: NotRequired[Nullable[int]]
|
|
95
96
|
r"""Number of completions to return for each request, input tokens are only billed once."""
|
|
97
|
+
prediction: NotRequired[PredictionTypedDict]
|
|
96
98
|
|
|
97
99
|
|
|
98
100
|
class ChatCompletionRequest(BaseModel):
|
|
99
|
-
model:
|
|
101
|
+
model: str
|
|
100
102
|
r"""ID of the model to use. You can use the [List Available Models](/api/#tag/models/operation/list_models_v1_models_get) API to see all of your available models, or see our [Model overview](/models) for model descriptions."""
|
|
101
103
|
|
|
102
104
|
messages: List[ChatCompletionRequestMessages]
|
|
@@ -135,6 +137,8 @@ class ChatCompletionRequest(BaseModel):
|
|
|
135
137
|
n: OptionalNullable[int] = UNSET
|
|
136
138
|
r"""Number of completions to return for each request, input tokens are only billed once."""
|
|
137
139
|
|
|
140
|
+
prediction: Optional[Prediction] = None
|
|
141
|
+
|
|
138
142
|
@model_serializer(mode="wrap")
|
|
139
143
|
def serialize_model(self, handler):
|
|
140
144
|
optional_fields = [
|
|
@@ -150,15 +154,9 @@ class ChatCompletionRequest(BaseModel):
|
|
|
150
154
|
"presence_penalty",
|
|
151
155
|
"frequency_penalty",
|
|
152
156
|
"n",
|
|
157
|
+
"prediction",
|
|
153
158
|
]
|
|
154
|
-
nullable_fields = [
|
|
155
|
-
"model",
|
|
156
|
-
"temperature",
|
|
157
|
-
"max_tokens",
|
|
158
|
-
"random_seed",
|
|
159
|
-
"tools",
|
|
160
|
-
"n",
|
|
161
|
-
]
|
|
159
|
+
nullable_fields = ["temperature", "max_tokens", "random_seed", "tools", "n"]
|
|
162
160
|
null_default_fields = []
|
|
163
161
|
|
|
164
162
|
serialized = handler(self)
|
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
from .assistantmessage import AssistantMessage, AssistantMessageTypedDict
|
|
5
|
+
from .prediction import Prediction, PredictionTypedDict
|
|
5
6
|
from .responseformat import ResponseFormat, ResponseFormatTypedDict
|
|
6
7
|
from .systemmessage import SystemMessage, SystemMessageTypedDict
|
|
7
8
|
from .tool import Tool, ToolTypedDict
|
|
@@ -64,7 +65,7 @@ ChatCompletionStreamRequestToolChoice = TypeAliasType(
|
|
|
64
65
|
|
|
65
66
|
|
|
66
67
|
class ChatCompletionStreamRequestTypedDict(TypedDict):
|
|
67
|
-
model:
|
|
68
|
+
model: str
|
|
68
69
|
r"""ID of the model to use. You can use the [List Available Models](/api/#tag/models/operation/list_models_v1_models_get) API to see all of your available models, or see our [Model overview](/models) for model descriptions."""
|
|
69
70
|
messages: List[MessagesTypedDict]
|
|
70
71
|
r"""The prompt(s) to generate completions for, encoded as a list of dict with role and content."""
|
|
@@ -88,10 +89,11 @@ class ChatCompletionStreamRequestTypedDict(TypedDict):
|
|
|
88
89
|
r"""frequency_penalty penalizes the repetition of words based on their frequency in the generated text. A higher frequency penalty discourages the model from repeating words that have already appeared frequently in the output, promoting diversity and reducing repetition."""
|
|
89
90
|
n: NotRequired[Nullable[int]]
|
|
90
91
|
r"""Number of completions to return for each request, input tokens are only billed once."""
|
|
92
|
+
prediction: NotRequired[PredictionTypedDict]
|
|
91
93
|
|
|
92
94
|
|
|
93
95
|
class ChatCompletionStreamRequest(BaseModel):
|
|
94
|
-
model:
|
|
96
|
+
model: str
|
|
95
97
|
r"""ID of the model to use. You can use the [List Available Models](/api/#tag/models/operation/list_models_v1_models_get) API to see all of your available models, or see our [Model overview](/models) for model descriptions."""
|
|
96
98
|
|
|
97
99
|
messages: List[Messages]
|
|
@@ -129,6 +131,8 @@ class ChatCompletionStreamRequest(BaseModel):
|
|
|
129
131
|
n: OptionalNullable[int] = UNSET
|
|
130
132
|
r"""Number of completions to return for each request, input tokens are only billed once."""
|
|
131
133
|
|
|
134
|
+
prediction: Optional[Prediction] = None
|
|
135
|
+
|
|
132
136
|
@model_serializer(mode="wrap")
|
|
133
137
|
def serialize_model(self, handler):
|
|
134
138
|
optional_fields = [
|
|
@@ -144,15 +148,9 @@ class ChatCompletionStreamRequest(BaseModel):
|
|
|
144
148
|
"presence_penalty",
|
|
145
149
|
"frequency_penalty",
|
|
146
150
|
"n",
|
|
151
|
+
"prediction",
|
|
147
152
|
]
|
|
148
|
-
nullable_fields = [
|
|
149
|
-
"model",
|
|
150
|
-
"temperature",
|
|
151
|
-
"max_tokens",
|
|
152
|
-
"random_seed",
|
|
153
|
-
"tools",
|
|
154
|
-
"n",
|
|
155
|
-
]
|
|
153
|
+
nullable_fields = ["temperature", "max_tokens", "random_seed", "tools", "n"]
|
|
156
154
|
null_default_fields = []
|
|
157
155
|
|
|
158
156
|
serialized = handler(self)
|
|
@@ -26,7 +26,7 @@ r"""Stop generation if this token is detected. Or if one of these tokens is dete
|
|
|
26
26
|
|
|
27
27
|
|
|
28
28
|
class FIMCompletionRequestTypedDict(TypedDict):
|
|
29
|
-
model:
|
|
29
|
+
model: str
|
|
30
30
|
r"""ID of the model to use. Only compatible for now with:
|
|
31
31
|
- `codestral-2405`
|
|
32
32
|
- `codestral-latest`
|
|
@@ -52,7 +52,7 @@ class FIMCompletionRequestTypedDict(TypedDict):
|
|
|
52
52
|
|
|
53
53
|
|
|
54
54
|
class FIMCompletionRequest(BaseModel):
|
|
55
|
-
model:
|
|
55
|
+
model: str
|
|
56
56
|
r"""ID of the model to use. Only compatible for now with:
|
|
57
57
|
- `codestral-2405`
|
|
58
58
|
- `codestral-latest`
|
|
@@ -98,7 +98,6 @@ class FIMCompletionRequest(BaseModel):
|
|
|
98
98
|
"min_tokens",
|
|
99
99
|
]
|
|
100
100
|
nullable_fields = [
|
|
101
|
-
"model",
|
|
102
101
|
"temperature",
|
|
103
102
|
"max_tokens",
|
|
104
103
|
"random_seed",
|
|
@@ -26,7 +26,7 @@ r"""Stop generation if this token is detected. Or if one of these tokens is dete
|
|
|
26
26
|
|
|
27
27
|
|
|
28
28
|
class FIMCompletionStreamRequestTypedDict(TypedDict):
|
|
29
|
-
model:
|
|
29
|
+
model: str
|
|
30
30
|
r"""ID of the model to use. Only compatible for now with:
|
|
31
31
|
- `codestral-2405`
|
|
32
32
|
- `codestral-latest`
|
|
@@ -51,7 +51,7 @@ class FIMCompletionStreamRequestTypedDict(TypedDict):
|
|
|
51
51
|
|
|
52
52
|
|
|
53
53
|
class FIMCompletionStreamRequest(BaseModel):
|
|
54
|
-
model:
|
|
54
|
+
model: str
|
|
55
55
|
r"""ID of the model to use. Only compatible for now with:
|
|
56
56
|
- `codestral-2405`
|
|
57
57
|
- `codestral-latest`
|
|
@@ -96,7 +96,6 @@ class FIMCompletionStreamRequest(BaseModel):
|
|
|
96
96
|
"min_tokens",
|
|
97
97
|
]
|
|
98
98
|
nullable_fields = [
|
|
99
|
-
"model",
|
|
100
99
|
"temperature",
|
|
101
100
|
"max_tokens",
|
|
102
101
|
"random_seed",
|
mistralai_gcp/models/function.py
CHANGED
|
@@ -10,6 +10,7 @@ class FunctionTypedDict(TypedDict):
|
|
|
10
10
|
name: str
|
|
11
11
|
parameters: Dict[str, Any]
|
|
12
12
|
description: NotRequired[str]
|
|
13
|
+
strict: NotRequired[bool]
|
|
13
14
|
|
|
14
15
|
|
|
15
16
|
class Function(BaseModel):
|
|
@@ -18,3 +19,5 @@ class Function(BaseModel):
|
|
|
18
19
|
parameters: Dict[str, Any]
|
|
19
20
|
|
|
20
21
|
description: Optional[str] = ""
|
|
22
|
+
|
|
23
|
+
strict: Optional[bool] = False
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
"""Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
from mistralai_gcp.types import (
|
|
5
|
+
BaseModel,
|
|
6
|
+
Nullable,
|
|
7
|
+
OptionalNullable,
|
|
8
|
+
UNSET,
|
|
9
|
+
UNSET_SENTINEL,
|
|
10
|
+
)
|
|
11
|
+
import pydantic
|
|
12
|
+
from pydantic import model_serializer
|
|
13
|
+
from typing import Any, Dict, Optional
|
|
14
|
+
from typing_extensions import Annotated, NotRequired, TypedDict
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class JSONSchemaTypedDict(TypedDict):
|
|
18
|
+
name: str
|
|
19
|
+
schema_definition: Dict[str, Any]
|
|
20
|
+
description: NotRequired[Nullable[str]]
|
|
21
|
+
strict: NotRequired[bool]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class JSONSchema(BaseModel):
|
|
25
|
+
name: str
|
|
26
|
+
|
|
27
|
+
schema_definition: Annotated[Dict[str, Any], pydantic.Field(alias="schema")]
|
|
28
|
+
|
|
29
|
+
description: OptionalNullable[str] = UNSET
|
|
30
|
+
|
|
31
|
+
strict: Optional[bool] = False
|
|
32
|
+
|
|
33
|
+
@model_serializer(mode="wrap")
|
|
34
|
+
def serialize_model(self, handler):
|
|
35
|
+
optional_fields = ["description", "strict"]
|
|
36
|
+
nullable_fields = ["description"]
|
|
37
|
+
null_default_fields = []
|
|
38
|
+
|
|
39
|
+
serialized = handler(self)
|
|
40
|
+
|
|
41
|
+
m = {}
|
|
42
|
+
|
|
43
|
+
for n, f in self.model_fields.items():
|
|
44
|
+
k = f.alias or n
|
|
45
|
+
val = serialized.get(k)
|
|
46
|
+
serialized.pop(k, None)
|
|
47
|
+
|
|
48
|
+
optional_nullable = k in optional_fields and k in nullable_fields
|
|
49
|
+
is_set = (
|
|
50
|
+
self.__pydantic_fields_set__.intersection({n})
|
|
51
|
+
or k in null_default_fields
|
|
52
|
+
) # pylint: disable=no-member
|
|
53
|
+
|
|
54
|
+
if val is not None and val != UNSET_SENTINEL:
|
|
55
|
+
m[k] = val
|
|
56
|
+
elif val != UNSET_SENTINEL and (
|
|
57
|
+
not k in optional_fields or (optional_nullable and is_set)
|
|
58
|
+
):
|
|
59
|
+
m[k] = val
|
|
60
|
+
|
|
61
|
+
return m
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
"""Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
from mistralai_gcp.types import BaseModel
|
|
5
|
+
from mistralai_gcp.utils import validate_const
|
|
6
|
+
import pydantic
|
|
7
|
+
from pydantic.functional_validators import AfterValidator
|
|
8
|
+
from typing import Literal, Optional
|
|
9
|
+
from typing_extensions import Annotated, NotRequired, TypedDict
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class PredictionTypedDict(TypedDict):
|
|
13
|
+
type: Literal["content"]
|
|
14
|
+
content: NotRequired[str]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Prediction(BaseModel):
|
|
18
|
+
TYPE: Annotated[
|
|
19
|
+
Annotated[
|
|
20
|
+
Optional[Literal["content"]], AfterValidator(validate_const("content"))
|
|
21
|
+
],
|
|
22
|
+
pydantic.Field(alias="type"),
|
|
23
|
+
] = "content"
|
|
24
|
+
|
|
25
|
+
content: Optional[str] = ""
|
|
@@ -1,8 +1,16 @@
|
|
|
1
1
|
"""Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT."""
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
|
+
from .jsonschema import JSONSchema, JSONSchemaTypedDict
|
|
4
5
|
from .responseformats import ResponseFormats
|
|
5
|
-
from mistralai_gcp.types import
|
|
6
|
+
from mistralai_gcp.types import (
|
|
7
|
+
BaseModel,
|
|
8
|
+
Nullable,
|
|
9
|
+
OptionalNullable,
|
|
10
|
+
UNSET,
|
|
11
|
+
UNSET_SENTINEL,
|
|
12
|
+
)
|
|
13
|
+
from pydantic import model_serializer
|
|
6
14
|
from typing import Optional
|
|
7
15
|
from typing_extensions import NotRequired, TypedDict
|
|
8
16
|
|
|
@@ -10,8 +18,41 @@ from typing_extensions import NotRequired, TypedDict
|
|
|
10
18
|
class ResponseFormatTypedDict(TypedDict):
|
|
11
19
|
type: NotRequired[ResponseFormats]
|
|
12
20
|
r"""An object specifying the format that the model must output. Setting to `{ \"type\": \"json_object\" }` enables JSON mode, which guarantees the message the model generates is in JSON. When using JSON mode you MUST also instruct the model to produce JSON yourself with a system or a user message."""
|
|
21
|
+
json_schema: NotRequired[Nullable[JSONSchemaTypedDict]]
|
|
13
22
|
|
|
14
23
|
|
|
15
24
|
class ResponseFormat(BaseModel):
|
|
16
25
|
type: Optional[ResponseFormats] = None
|
|
17
26
|
r"""An object specifying the format that the model must output. Setting to `{ \"type\": \"json_object\" }` enables JSON mode, which guarantees the message the model generates is in JSON. When using JSON mode you MUST also instruct the model to produce JSON yourself with a system or a user message."""
|
|
27
|
+
|
|
28
|
+
json_schema: OptionalNullable[JSONSchema] = UNSET
|
|
29
|
+
|
|
30
|
+
@model_serializer(mode="wrap")
|
|
31
|
+
def serialize_model(self, handler):
|
|
32
|
+
optional_fields = ["type", "json_schema"]
|
|
33
|
+
nullable_fields = ["json_schema"]
|
|
34
|
+
null_default_fields = []
|
|
35
|
+
|
|
36
|
+
serialized = handler(self)
|
|
37
|
+
|
|
38
|
+
m = {}
|
|
39
|
+
|
|
40
|
+
for n, f in self.model_fields.items():
|
|
41
|
+
k = f.alias or n
|
|
42
|
+
val = serialized.get(k)
|
|
43
|
+
serialized.pop(k, None)
|
|
44
|
+
|
|
45
|
+
optional_nullable = k in optional_fields and k in nullable_fields
|
|
46
|
+
is_set = (
|
|
47
|
+
self.__pydantic_fields_set__.intersection({n})
|
|
48
|
+
or k in null_default_fields
|
|
49
|
+
) # pylint: disable=no-member
|
|
50
|
+
|
|
51
|
+
if val is not None and val != UNSET_SENTINEL:
|
|
52
|
+
m[k] = val
|
|
53
|
+
elif val != UNSET_SENTINEL and (
|
|
54
|
+
not k in optional_fields or (optional_nullable and is_set)
|
|
55
|
+
):
|
|
56
|
+
m[k] = val
|
|
57
|
+
|
|
58
|
+
return m
|
|
@@ -4,5 +4,5 @@ from __future__ import annotations
|
|
|
4
4
|
from typing import Literal
|
|
5
5
|
|
|
6
6
|
|
|
7
|
-
ResponseFormats = Literal["text", "json_object"]
|
|
7
|
+
ResponseFormats = Literal["text", "json_object", "json_schema"]
|
|
8
8
|
r"""An object specifying the format that the model must output. Setting to `{ \"type\": \"json_object\" }` enables JSON mode, which guarantees the message the model generates is in JSON. When using JSON mode you MUST also instruct the model to produce JSON yourself with a system or a user message."""
|
mistralai_gcp/models/toolcall.py
CHANGED
|
@@ -14,6 +14,7 @@ class ToolCallTypedDict(TypedDict):
|
|
|
14
14
|
function: FunctionCallTypedDict
|
|
15
15
|
id: NotRequired[str]
|
|
16
16
|
type: NotRequired[ToolTypes]
|
|
17
|
+
index: NotRequired[int]
|
|
17
18
|
|
|
18
19
|
|
|
19
20
|
class ToolCall(BaseModel):
|
|
@@ -24,3 +25,5 @@ class ToolCall(BaseModel):
|
|
|
24
25
|
type: Annotated[Optional[ToolTypes], PlainValidator(validate_open_enum(False))] = (
|
|
25
26
|
None
|
|
26
27
|
)
|
|
28
|
+
|
|
29
|
+
index: Optional[int] = 0
|
mistralai_gcp/sdk.py
CHANGED
|
@@ -1,23 +1,25 @@
|
|
|
1
|
-
"""Code generated by Speakeasy (https://
|
|
1
|
+
"""Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT."""
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
|
-
|
|
4
|
+
import weakref
|
|
5
|
+
from typing import Any, Optional, cast
|
|
5
6
|
|
|
6
7
|
import google.auth
|
|
7
8
|
import google.auth.credentials
|
|
8
9
|
import google.auth.transport
|
|
9
10
|
import google.auth.transport.requests
|
|
10
11
|
import httpx
|
|
12
|
+
|
|
11
13
|
from mistralai_gcp import models
|
|
12
14
|
from mistralai_gcp._hooks import BeforeRequestHook, SDKHooks
|
|
13
15
|
from mistralai_gcp.chat import Chat
|
|
14
16
|
from mistralai_gcp.fim import Fim
|
|
15
|
-
from mistralai_gcp.types import
|
|
17
|
+
from mistralai_gcp.types import UNSET, OptionalNullable
|
|
16
18
|
|
|
17
19
|
from .basesdk import BaseSDK
|
|
18
|
-
from .httpclient import AsyncHttpClient, HttpClient
|
|
20
|
+
from .httpclient import AsyncHttpClient, ClientOwner, HttpClient, close_clients
|
|
19
21
|
from .sdkconfiguration import SDKConfiguration
|
|
20
|
-
from .utils.logger import Logger,
|
|
22
|
+
from .utils.logger import Logger, get_default_logger
|
|
21
23
|
from .utils.retries import RetryConfig
|
|
22
24
|
|
|
23
25
|
LEGACY_MODEL_ID_FORMAT = {
|
|
@@ -26,20 +28,21 @@ LEGACY_MODEL_ID_FORMAT = {
|
|
|
26
28
|
"mistral-nemo-2407": "mistral-nemo@2407",
|
|
27
29
|
}
|
|
28
30
|
|
|
29
|
-
|
|
31
|
+
|
|
32
|
+
def get_model_info(model: str) -> tuple[str, str]:
|
|
30
33
|
# if the model requiers the legacy fomat, use it, else do nothing.
|
|
31
34
|
if model in LEGACY_MODEL_ID_FORMAT:
|
|
32
35
|
return "-".join(model.split("-")[:-1]), LEGACY_MODEL_ID_FORMAT[model]
|
|
33
36
|
return model, model
|
|
34
37
|
|
|
35
38
|
|
|
36
|
-
|
|
37
39
|
class MistralGoogleCloud(BaseSDK):
|
|
38
40
|
r"""Mistral AI API: Our Chat Completion and Embeddings APIs specification. Create your account on [La Plateforme](https://console.mistral.ai) to get access and read the [docs](https://docs.mistral.ai) to learn how to use it."""
|
|
39
41
|
|
|
40
42
|
chat: Chat
|
|
43
|
+
r"""Chat Completion API."""
|
|
41
44
|
fim: Fim
|
|
42
|
-
r"""
|
|
45
|
+
r"""Fill-in-the-middle API."""
|
|
43
46
|
|
|
44
47
|
def __init__(
|
|
45
48
|
self,
|
|
@@ -48,16 +51,20 @@ class MistralGoogleCloud(BaseSDK):
|
|
|
48
51
|
access_token: Optional[str] = None,
|
|
49
52
|
client: Optional[HttpClient] = None,
|
|
50
53
|
async_client: Optional[AsyncHttpClient] = None,
|
|
51
|
-
retry_config:
|
|
54
|
+
retry_config: OptionalNullable[RetryConfig] = UNSET,
|
|
55
|
+
timeout_ms: Optional[int] = None,
|
|
52
56
|
debug_logger: Optional[Logger] = None,
|
|
53
57
|
) -> None:
|
|
54
58
|
r"""Instantiates the SDK configuring it with the provided parameters.
|
|
55
59
|
|
|
56
|
-
:param
|
|
57
|
-
:param
|
|
60
|
+
:param api_key: The api_key required for authentication
|
|
61
|
+
:param server: The server by name to use for all methods
|
|
62
|
+
:param server_url: The server URL to use for all methods
|
|
63
|
+
:param url_params: Parameters to optionally template the server URL with
|
|
58
64
|
:param client: The HTTP client to use for all synchronous methods
|
|
59
65
|
:param async_client: The Async HTTP client to use for all asynchronous methods
|
|
60
66
|
:param retry_config: The retry configuration to use for all supported methods
|
|
67
|
+
:param timeout_ms: Optional request timeout applied to each operation in milliseconds
|
|
61
68
|
"""
|
|
62
69
|
|
|
63
70
|
if not access_token:
|
|
@@ -72,36 +79,42 @@ class MistralGoogleCloud(BaseSDK):
|
|
|
72
79
|
)
|
|
73
80
|
|
|
74
81
|
project_id = project_id or loaded_project_id
|
|
82
|
+
|
|
75
83
|
if project_id is None:
|
|
76
84
|
raise models.SDKError("project_id must be provided")
|
|
77
85
|
|
|
78
86
|
def auth_token() -> str:
|
|
79
87
|
if access_token:
|
|
80
88
|
return access_token
|
|
89
|
+
|
|
81
90
|
credentials.refresh(google.auth.transport.requests.Request())
|
|
82
91
|
token = credentials.token
|
|
83
92
|
if not token:
|
|
84
93
|
raise models.SDKError("Failed to get token from credentials")
|
|
85
94
|
return token
|
|
86
95
|
|
|
96
|
+
client_supplied = True
|
|
87
97
|
if client is None:
|
|
88
98
|
client = httpx.Client()
|
|
99
|
+
client_supplied = False
|
|
89
100
|
|
|
90
101
|
assert issubclass(
|
|
91
102
|
type(client), HttpClient
|
|
92
103
|
), "The provided client must implement the HttpClient protocol."
|
|
93
104
|
|
|
105
|
+
async_client_supplied = True
|
|
94
106
|
if async_client is None:
|
|
95
107
|
async_client = httpx.AsyncClient()
|
|
108
|
+
async_client_supplied = False
|
|
96
109
|
|
|
97
110
|
if debug_logger is None:
|
|
98
|
-
debug_logger =
|
|
111
|
+
debug_logger = get_default_logger()
|
|
99
112
|
|
|
100
113
|
assert issubclass(
|
|
101
114
|
type(async_client), AsyncHttpClient
|
|
102
115
|
), "The provided async_client must implement the AsyncHttpClient protocol."
|
|
103
116
|
|
|
104
|
-
security = None
|
|
117
|
+
security: Any = None
|
|
105
118
|
if callable(auth_token):
|
|
106
119
|
security = lambda: models.Security( # pylint: disable=unnecessary-lambda-assignment
|
|
107
120
|
api_key=auth_token()
|
|
@@ -113,23 +126,24 @@ class MistralGoogleCloud(BaseSDK):
|
|
|
113
126
|
self,
|
|
114
127
|
SDKConfiguration(
|
|
115
128
|
client=client,
|
|
129
|
+
client_supplied=client_supplied,
|
|
116
130
|
async_client=async_client,
|
|
131
|
+
async_client_supplied=async_client_supplied,
|
|
117
132
|
security=security,
|
|
118
133
|
server_url=f"https://{region}-aiplatform.googleapis.com",
|
|
119
134
|
server=None,
|
|
120
135
|
retry_config=retry_config,
|
|
136
|
+
timeout_ms=timeout_ms,
|
|
121
137
|
debug_logger=debug_logger,
|
|
122
138
|
),
|
|
123
139
|
)
|
|
124
140
|
|
|
125
141
|
hooks = SDKHooks()
|
|
126
|
-
|
|
127
142
|
hook = GoogleCloudBeforeRequestHook(region, project_id)
|
|
128
143
|
hooks.register_before_request_hook(hook)
|
|
129
|
-
|
|
130
144
|
current_server_url, *_ = self.sdk_configuration.get_server_details()
|
|
131
145
|
server_url, self.sdk_configuration.client = hooks.sdk_init(
|
|
132
|
-
current_server_url,
|
|
146
|
+
current_server_url, client
|
|
133
147
|
)
|
|
134
148
|
if current_server_url != server_url:
|
|
135
149
|
self.sdk_configuration.server_url = server_url
|
|
@@ -137,22 +151,53 @@ class MistralGoogleCloud(BaseSDK):
|
|
|
137
151
|
# pylint: disable=protected-access
|
|
138
152
|
self.sdk_configuration.__dict__["_hooks"] = hooks
|
|
139
153
|
|
|
154
|
+
weakref.finalize(
|
|
155
|
+
self,
|
|
156
|
+
close_clients,
|
|
157
|
+
cast(ClientOwner, self.sdk_configuration),
|
|
158
|
+
self.sdk_configuration.client,
|
|
159
|
+
self.sdk_configuration.client_supplied,
|
|
160
|
+
self.sdk_configuration.async_client,
|
|
161
|
+
self.sdk_configuration.async_client_supplied,
|
|
162
|
+
)
|
|
163
|
+
|
|
140
164
|
self._init_sdks()
|
|
141
165
|
|
|
142
166
|
def _init_sdks(self):
|
|
143
167
|
self.chat = Chat(self.sdk_configuration)
|
|
144
168
|
self.fim = Fim(self.sdk_configuration)
|
|
145
169
|
|
|
170
|
+
def __enter__(self):
|
|
171
|
+
return self
|
|
146
172
|
|
|
147
|
-
|
|
173
|
+
async def __aenter__(self):
|
|
174
|
+
return self
|
|
175
|
+
|
|
176
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
177
|
+
if (
|
|
178
|
+
self.sdk_configuration.client is not None
|
|
179
|
+
and not self.sdk_configuration.client_supplied
|
|
180
|
+
):
|
|
181
|
+
self.sdk_configuration.client.close()
|
|
182
|
+
self.sdk_configuration.client = None
|
|
148
183
|
|
|
184
|
+
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
|
185
|
+
if (
|
|
186
|
+
self.sdk_configuration.async_client is not None
|
|
187
|
+
and not self.sdk_configuration.async_client_supplied
|
|
188
|
+
):
|
|
189
|
+
await self.sdk_configuration.async_client.aclose()
|
|
190
|
+
self.sdk_configuration.async_client = None
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
class GoogleCloudBeforeRequestHook(BeforeRequestHook):
|
|
149
194
|
def __init__(self, region: str, project_id: str):
|
|
150
195
|
self.region = region
|
|
151
196
|
self.project_id = project_id
|
|
152
197
|
|
|
153
198
|
def before_request(
|
|
154
199
|
self, hook_ctx, request: httpx.Request
|
|
155
|
-
) ->
|
|
200
|
+
) -> httpx.Request | Exception:
|
|
156
201
|
# The goal of this function is to template in the region, project and model into the URL path
|
|
157
202
|
# We do this here so that the API remains more user-friendly
|
|
158
203
|
model_id = None
|
|
@@ -167,7 +212,6 @@ class GoogleCloudBeforeRequestHook(BeforeRequestHook):
|
|
|
167
212
|
if model_id == "":
|
|
168
213
|
raise models.SDKError("model must be provided")
|
|
169
214
|
|
|
170
|
-
|
|
171
215
|
stream = "streamRawPredict" in request.url.path
|
|
172
216
|
specifier = "streamRawPredict" if stream else "rawPredict"
|
|
173
217
|
url = f"/v1/projects/{self.project_id}/locations/{self.region}/publishers/mistralai/models/{model_id}:{specifier}"
|
|
@@ -1,6 +1,12 @@
|
|
|
1
1
|
"""Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT."""
|
|
2
2
|
|
|
3
3
|
from ._hooks import SDKHooks
|
|
4
|
+
from ._version import (
|
|
5
|
+
__gen_version__,
|
|
6
|
+
__openapi_doc_version__,
|
|
7
|
+
__user_agent__,
|
|
8
|
+
__version__,
|
|
9
|
+
)
|
|
4
10
|
from .httpclient import AsyncHttpClient, HttpClient
|
|
5
11
|
from .utils import Logger, RetryConfig, remove_suffix
|
|
6
12
|
from dataclasses import dataclass
|
|
@@ -20,17 +26,19 @@ SERVERS = {
|
|
|
20
26
|
|
|
21
27
|
@dataclass
|
|
22
28
|
class SDKConfiguration:
|
|
23
|
-
client: HttpClient
|
|
24
|
-
|
|
29
|
+
client: Union[HttpClient, None]
|
|
30
|
+
client_supplied: bool
|
|
31
|
+
async_client: Union[AsyncHttpClient, None]
|
|
32
|
+
async_client_supplied: bool
|
|
25
33
|
debug_logger: Logger
|
|
26
34
|
security: Optional[Union[models.Security, Callable[[], models.Security]]] = None
|
|
27
35
|
server_url: Optional[str] = ""
|
|
28
36
|
server: Optional[str] = ""
|
|
29
37
|
language: str = "python"
|
|
30
|
-
openapi_doc_version: str =
|
|
31
|
-
sdk_version: str =
|
|
32
|
-
gen_version: str =
|
|
33
|
-
user_agent: str =
|
|
38
|
+
openapi_doc_version: str = __openapi_doc_version__
|
|
39
|
+
sdk_version: str = __version__
|
|
40
|
+
gen_version: str = __gen_version__
|
|
41
|
+
user_agent: str = __user_agent__
|
|
34
42
|
retry_config: OptionalNullable[RetryConfig] = Field(default_factory=lambda: UNSET)
|
|
35
43
|
timeout_ms: Optional[int] = None
|
|
36
44
|
|
mistralai_gcp/utils/__init__.py
CHANGED
|
@@ -42,6 +42,7 @@ from .values import (
|
|
|
42
42
|
match_content_type,
|
|
43
43
|
match_status_codes,
|
|
44
44
|
match_response,
|
|
45
|
+
cast_partial,
|
|
45
46
|
)
|
|
46
47
|
from .logger import Logger, get_body_content, get_default_logger
|
|
47
48
|
|
|
@@ -94,4 +95,5 @@ __all__ = [
|
|
|
94
95
|
"validate_float",
|
|
95
96
|
"validate_int",
|
|
96
97
|
"validate_open_enum",
|
|
98
|
+
"cast_partial",
|
|
97
99
|
]
|
|
@@ -7,14 +7,15 @@ import httpx
|
|
|
7
7
|
from typing_extensions import get_origin
|
|
8
8
|
from pydantic import ConfigDict, create_model
|
|
9
9
|
from pydantic_core import from_json
|
|
10
|
-
from
|
|
10
|
+
from typing_inspection.typing_objects import is_union
|
|
11
11
|
|
|
12
12
|
from ..types.basemodel import BaseModel, Nullable, OptionalNullable, Unset
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
def serialize_decimal(as_str: bool):
|
|
16
16
|
def serialize(d):
|
|
17
|
-
|
|
17
|
+
# Optional[T] is a Union[T, None]
|
|
18
|
+
if is_union(type(d)) and type(None) in get_args(type(d)) and d is None:
|
|
18
19
|
return None
|
|
19
20
|
if isinstance(d, Unset):
|
|
20
21
|
return d
|
|
@@ -42,7 +43,8 @@ def validate_decimal(d):
|
|
|
42
43
|
|
|
43
44
|
def serialize_float(as_str: bool):
|
|
44
45
|
def serialize(f):
|
|
45
|
-
|
|
46
|
+
# Optional[T] is a Union[T, None]
|
|
47
|
+
if is_union(type(f)) and type(None) in get_args(type(f)) and f is None:
|
|
46
48
|
return None
|
|
47
49
|
if isinstance(f, Unset):
|
|
48
50
|
return f
|
|
@@ -70,7 +72,8 @@ def validate_float(f):
|
|
|
70
72
|
|
|
71
73
|
def serialize_int(as_str: bool):
|
|
72
74
|
def serialize(i):
|
|
73
|
-
|
|
75
|
+
# Optional[T] is a Union[T, None]
|
|
76
|
+
if is_union(type(i)) and type(None) in get_args(type(i)) and i is None:
|
|
74
77
|
return None
|
|
75
78
|
if isinstance(i, Unset):
|
|
76
79
|
return i
|
|
@@ -118,7 +121,8 @@ def validate_open_enum(is_int: bool):
|
|
|
118
121
|
|
|
119
122
|
def validate_const(v):
|
|
120
123
|
def validate(c):
|
|
121
|
-
|
|
124
|
+
# Optional[T] is a Union[T, None]
|
|
125
|
+
if is_union(type(c)) and type(None) in get_args(type(c)) and c is None:
|
|
122
126
|
return None
|
|
123
127
|
|
|
124
128
|
if v != c:
|
|
@@ -163,7 +167,7 @@ def marshal_json(val, typ):
|
|
|
163
167
|
if len(d) == 0:
|
|
164
168
|
return ""
|
|
165
169
|
|
|
166
|
-
return json.dumps(d[next(iter(d))], separators=(",", ":")
|
|
170
|
+
return json.dumps(d[next(iter(d))], separators=(",", ":"))
|
|
167
171
|
|
|
168
172
|
|
|
169
173
|
def is_nullable(field):
|