mistralai 0.4.2__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mistralai/__init__.py +5 -0
- mistralai/_hooks/__init__.py +5 -0
- mistralai/_hooks/custom_user_agent.py +16 -0
- mistralai/_hooks/deprecation_warning.py +26 -0
- mistralai/_hooks/registration.py +17 -0
- mistralai/_hooks/sdkhooks.py +57 -0
- mistralai/_hooks/types.py +76 -0
- mistralai/agents.py +434 -0
- mistralai/async_client.py +5 -413
- mistralai/basesdk.py +253 -0
- mistralai/chat.py +470 -0
- mistralai/client.py +5 -414
- mistralai/embeddings.py +182 -0
- mistralai/files.py +600 -84
- mistralai/fim.py +438 -0
- mistralai/fine_tuning.py +16 -0
- mistralai/httpclient.py +78 -0
- mistralai/jobs.py +822 -150
- mistralai/models/__init__.py +82 -0
- mistralai/models/agentscompletionrequest.py +96 -0
- mistralai/models/agentscompletionstreamrequest.py +92 -0
- mistralai/models/archiveftmodelout.py +19 -0
- mistralai/models/assistantmessage.py +53 -0
- mistralai/models/chatcompletionchoice.py +22 -0
- mistralai/models/chatcompletionrequest.py +109 -0
- mistralai/models/chatcompletionresponse.py +27 -0
- mistralai/models/chatcompletionstreamrequest.py +107 -0
- mistralai/models/checkpointout.py +25 -0
- mistralai/models/completionchunk.py +27 -0
- mistralai/models/completionevent.py +15 -0
- mistralai/models/completionresponsestreamchoice.py +48 -0
- mistralai/models/contentchunk.py +17 -0
- mistralai/models/delete_model_v1_models_model_id_deleteop.py +18 -0
- mistralai/models/deletefileout.py +24 -0
- mistralai/models/deletemodelout.py +25 -0
- mistralai/models/deltamessage.py +47 -0
- mistralai/models/detailedjobout.py +91 -0
- mistralai/models/embeddingrequest.py +61 -0
- mistralai/models/embeddingresponse.py +24 -0
- mistralai/models/embeddingresponsedata.py +19 -0
- mistralai/models/eventout.py +50 -0
- mistralai/models/files_api_routes_delete_fileop.py +16 -0
- mistralai/models/files_api_routes_retrieve_fileop.py +16 -0
- mistralai/models/files_api_routes_upload_fileop.py +51 -0
- mistralai/models/fileschema.py +71 -0
- mistralai/models/fimcompletionrequest.py +94 -0
- mistralai/models/fimcompletionresponse.py +27 -0
- mistralai/models/fimcompletionstreamrequest.py +92 -0
- mistralai/models/finetuneablemodel.py +8 -0
- mistralai/models/ftmodelcapabilitiesout.py +21 -0
- mistralai/models/ftmodelout.py +65 -0
- mistralai/models/function.py +19 -0
- mistralai/models/functioncall.py +22 -0
- mistralai/models/githubrepositoryin.py +52 -0
- mistralai/models/githubrepositoryout.py +52 -0
- mistralai/models/httpvalidationerror.py +23 -0
- mistralai/models/jobin.py +73 -0
- mistralai/models/jobmetadataout.py +54 -0
- mistralai/models/jobout.py +107 -0
- mistralai/models/jobs_api_routes_fine_tuning_archive_fine_tuned_modelop.py +18 -0
- mistralai/models/jobs_api_routes_fine_tuning_cancel_fine_tuning_jobop.py +18 -0
- mistralai/models/jobs_api_routes_fine_tuning_create_fine_tuning_jobop.py +15 -0
- mistralai/models/jobs_api_routes_fine_tuning_get_fine_tuning_jobop.py +18 -0
- mistralai/models/jobs_api_routes_fine_tuning_get_fine_tuning_jobsop.py +81 -0
- mistralai/models/jobs_api_routes_fine_tuning_start_fine_tuning_jobop.py +16 -0
- mistralai/models/jobs_api_routes_fine_tuning_unarchive_fine_tuned_modelop.py +18 -0
- mistralai/models/jobs_api_routes_fine_tuning_update_fine_tuned_modelop.py +21 -0
- mistralai/models/jobsout.py +20 -0
- mistralai/models/legacyjobmetadataout.py +80 -0
- mistralai/models/listfilesout.py +17 -0
- mistralai/models/metricout.py +50 -0
- mistralai/models/modelcapabilities.py +21 -0
- mistralai/models/modelcard.py +66 -0
- mistralai/models/modellist.py +18 -0
- mistralai/models/responseformat.py +18 -0
- mistralai/models/retrieve_model_v1_models_model_id_getop.py +18 -0
- mistralai/models/retrievefileout.py +71 -0
- mistralai/models/sampletype.py +7 -0
- mistralai/models/sdkerror.py +22 -0
- mistralai/models/security.py +16 -0
- mistralai/models/source.py +7 -0
- mistralai/models/systemmessage.py +26 -0
- mistralai/models/textchunk.py +17 -0
- mistralai/models/tool.py +18 -0
- mistralai/models/toolcall.py +20 -0
- mistralai/models/toolmessage.py +50 -0
- mistralai/models/trainingfile.py +17 -0
- mistralai/models/trainingparameters.py +48 -0
- mistralai/models/trainingparametersin.py +56 -0
- mistralai/models/unarchiveftmodelout.py +19 -0
- mistralai/models/updateftmodelin.py +44 -0
- mistralai/models/uploadfileout.py +71 -0
- mistralai/models/usageinfo.py +18 -0
- mistralai/models/usermessage.py +26 -0
- mistralai/models/validationerror.py +24 -0
- mistralai/models/wandbintegration.py +56 -0
- mistralai/models/wandbintegrationout.py +52 -0
- mistralai/models_.py +928 -0
- mistralai/py.typed +1 -0
- mistralai/sdk.py +119 -0
- mistralai/sdkconfiguration.py +54 -0
- mistralai/types/__init__.py +21 -0
- mistralai/types/basemodel.py +39 -0
- mistralai/utils/__init__.py +86 -0
- mistralai/utils/annotations.py +19 -0
- mistralai/utils/enums.py +34 -0
- mistralai/utils/eventstreaming.py +178 -0
- mistralai/utils/forms.py +207 -0
- mistralai/utils/headers.py +136 -0
- mistralai/utils/logger.py +16 -0
- mistralai/utils/metadata.py +118 -0
- mistralai/utils/queryparams.py +203 -0
- mistralai/utils/requestbodies.py +66 -0
- mistralai/utils/retries.py +216 -0
- mistralai/utils/security.py +185 -0
- mistralai/utils/serializers.py +181 -0
- mistralai/utils/url.py +150 -0
- mistralai/utils/values.py +128 -0
- {mistralai-0.4.2.dist-info → mistralai-1.0.0.dist-info}/LICENSE +1 -1
- mistralai-1.0.0.dist-info/METADATA +695 -0
- mistralai-1.0.0.dist-info/RECORD +235 -0
- mistralai_azure/__init__.py +5 -0
- mistralai_azure/_hooks/__init__.py +5 -0
- mistralai_azure/_hooks/custom_user_agent.py +16 -0
- mistralai_azure/_hooks/registration.py +15 -0
- mistralai_azure/_hooks/sdkhooks.py +57 -0
- mistralai_azure/_hooks/types.py +76 -0
- mistralai_azure/basesdk.py +253 -0
- mistralai_azure/chat.py +470 -0
- mistralai_azure/httpclient.py +78 -0
- mistralai_azure/models/__init__.py +28 -0
- mistralai_azure/models/assistantmessage.py +53 -0
- mistralai_azure/models/chatcompletionchoice.py +22 -0
- mistralai_azure/models/chatcompletionrequest.py +109 -0
- mistralai_azure/models/chatcompletionresponse.py +27 -0
- mistralai_azure/models/chatcompletionstreamrequest.py +107 -0
- mistralai_azure/models/completionchunk.py +27 -0
- mistralai_azure/models/completionevent.py +15 -0
- mistralai_azure/models/completionresponsestreamchoice.py +48 -0
- mistralai_azure/models/contentchunk.py +17 -0
- mistralai_azure/models/deltamessage.py +47 -0
- mistralai_azure/models/function.py +19 -0
- mistralai_azure/models/functioncall.py +22 -0
- mistralai_azure/models/httpvalidationerror.py +23 -0
- mistralai_azure/models/responseformat.py +18 -0
- mistralai_azure/models/sdkerror.py +22 -0
- mistralai_azure/models/security.py +16 -0
- mistralai_azure/models/systemmessage.py +26 -0
- mistralai_azure/models/textchunk.py +17 -0
- mistralai_azure/models/tool.py +18 -0
- mistralai_azure/models/toolcall.py +20 -0
- mistralai_azure/models/toolmessage.py +50 -0
- mistralai_azure/models/usageinfo.py +18 -0
- mistralai_azure/models/usermessage.py +26 -0
- mistralai_azure/models/validationerror.py +24 -0
- mistralai_azure/py.typed +1 -0
- mistralai_azure/sdk.py +107 -0
- mistralai_azure/sdkconfiguration.py +54 -0
- mistralai_azure/types/__init__.py +21 -0
- mistralai_azure/types/basemodel.py +39 -0
- mistralai_azure/utils/__init__.py +84 -0
- mistralai_azure/utils/annotations.py +19 -0
- mistralai_azure/utils/enums.py +34 -0
- mistralai_azure/utils/eventstreaming.py +178 -0
- mistralai_azure/utils/forms.py +207 -0
- mistralai_azure/utils/headers.py +136 -0
- mistralai_azure/utils/logger.py +16 -0
- mistralai_azure/utils/metadata.py +118 -0
- mistralai_azure/utils/queryparams.py +203 -0
- mistralai_azure/utils/requestbodies.py +66 -0
- mistralai_azure/utils/retries.py +216 -0
- mistralai_azure/utils/security.py +168 -0
- mistralai_azure/utils/serializers.py +181 -0
- mistralai_azure/utils/url.py +150 -0
- mistralai_azure/utils/values.py +128 -0
- mistralai_gcp/__init__.py +5 -0
- mistralai_gcp/_hooks/__init__.py +5 -0
- mistralai_gcp/_hooks/custom_user_agent.py +16 -0
- mistralai_gcp/_hooks/registration.py +15 -0
- mistralai_gcp/_hooks/sdkhooks.py +57 -0
- mistralai_gcp/_hooks/types.py +76 -0
- mistralai_gcp/basesdk.py +253 -0
- mistralai_gcp/chat.py +458 -0
- mistralai_gcp/fim.py +438 -0
- mistralai_gcp/httpclient.py +78 -0
- mistralai_gcp/models/__init__.py +31 -0
- mistralai_gcp/models/assistantmessage.py +53 -0
- mistralai_gcp/models/chatcompletionchoice.py +22 -0
- mistralai_gcp/models/chatcompletionrequest.py +105 -0
- mistralai_gcp/models/chatcompletionresponse.py +27 -0
- mistralai_gcp/models/chatcompletionstreamrequest.py +103 -0
- mistralai_gcp/models/completionchunk.py +27 -0
- mistralai_gcp/models/completionevent.py +15 -0
- mistralai_gcp/models/completionresponsestreamchoice.py +48 -0
- mistralai_gcp/models/contentchunk.py +17 -0
- mistralai_gcp/models/deltamessage.py +47 -0
- mistralai_gcp/models/fimcompletionrequest.py +94 -0
- mistralai_gcp/models/fimcompletionresponse.py +27 -0
- mistralai_gcp/models/fimcompletionstreamrequest.py +92 -0
- mistralai_gcp/models/function.py +19 -0
- mistralai_gcp/models/functioncall.py +22 -0
- mistralai_gcp/models/httpvalidationerror.py +23 -0
- mistralai_gcp/models/responseformat.py +18 -0
- mistralai_gcp/models/sdkerror.py +22 -0
- mistralai_gcp/models/security.py +16 -0
- mistralai_gcp/models/systemmessage.py +26 -0
- mistralai_gcp/models/textchunk.py +17 -0
- mistralai_gcp/models/tool.py +18 -0
- mistralai_gcp/models/toolcall.py +20 -0
- mistralai_gcp/models/toolmessage.py +50 -0
- mistralai_gcp/models/usageinfo.py +18 -0
- mistralai_gcp/models/usermessage.py +26 -0
- mistralai_gcp/models/validationerror.py +24 -0
- mistralai_gcp/py.typed +1 -0
- mistralai_gcp/sdk.py +174 -0
- mistralai_gcp/sdkconfiguration.py +54 -0
- mistralai_gcp/types/__init__.py +21 -0
- mistralai_gcp/types/basemodel.py +39 -0
- mistralai_gcp/utils/__init__.py +84 -0
- mistralai_gcp/utils/annotations.py +19 -0
- mistralai_gcp/utils/enums.py +34 -0
- mistralai_gcp/utils/eventstreaming.py +178 -0
- mistralai_gcp/utils/forms.py +207 -0
- mistralai_gcp/utils/headers.py +136 -0
- mistralai_gcp/utils/logger.py +16 -0
- mistralai_gcp/utils/metadata.py +118 -0
- mistralai_gcp/utils/queryparams.py +203 -0
- mistralai_gcp/utils/requestbodies.py +66 -0
- mistralai_gcp/utils/retries.py +216 -0
- mistralai_gcp/utils/security.py +168 -0
- mistralai_gcp/utils/serializers.py +181 -0
- mistralai_gcp/utils/url.py +150 -0
- mistralai_gcp/utils/values.py +128 -0
- py.typed +1 -0
- mistralai/client_base.py +0 -211
- mistralai/constants.py +0 -5
- mistralai/exceptions.py +0 -54
- mistralai/models/chat_completion.py +0 -93
- mistralai/models/common.py +0 -9
- mistralai/models/embeddings.py +0 -19
- mistralai/models/files.py +0 -23
- mistralai/models/jobs.py +0 -100
- mistralai/models/models.py +0 -39
- mistralai-0.4.2.dist-info/METADATA +0 -82
- mistralai-0.4.2.dist-info/RECORD +0 -20
- {mistralai-0.4.2.dist-info → mistralai-1.0.0.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
"""Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT."""
|
|
2
|
+
|
|
3
|
+
from decimal import Decimal
|
|
4
|
+
import json
|
|
5
|
+
from typing import Any, Dict, List, Union, get_args
|
|
6
|
+
import httpx
|
|
7
|
+
from typing_extensions import get_origin
|
|
8
|
+
from pydantic import ConfigDict, create_model
|
|
9
|
+
from pydantic_core import from_json
|
|
10
|
+
from typing_inspect import is_optional_type
|
|
11
|
+
|
|
12
|
+
from ..types.basemodel import BaseModel, Nullable, OptionalNullable
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def serialize_decimal(as_str: bool):
    """Build a serializer that renders a Decimal as str or float.

    When ``as_str`` is true the returned callable renders the value as a
    string; otherwise it renders it as a float. ``None`` passes through
    for optional fields.
    """

    def _convert(value):
        if value is None and is_optional_type(type(value)):
            return None
        if isinstance(value, Decimal):
            return str(value) if as_str else float(value)
        raise ValueError("Expected Decimal object")

    return _convert
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def validate_decimal(d):
    """Coerce a wire value (str/int/float) into a Decimal; None passes through."""
    if d is None:
        return None
    if isinstance(d, Decimal):
        return d
    if isinstance(d, (str, int, float)):
        # Round-trip through str so floats keep their repr, not binary noise.
        return Decimal(str(d))
    raise ValueError("Expected string, int or float")
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def serialize_float(as_str: bool):
    """Build a serializer for float fields; stringifies when ``as_str``."""

    def _convert(value):
        if value is None and is_optional_type(type(value)):
            return None
        if isinstance(value, float):
            return str(value) if as_str else value
        raise ValueError("Expected float")

    return _convert
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def validate_float(f):
    """Parse a float that may arrive as a string; None passes through."""
    if f is None:
        return None
    if isinstance(f, float):
        return f
    if isinstance(f, str):
        return float(f)
    raise ValueError("Expected string")
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def serialize_int(as_str: bool):
    """Build a serializer for int fields; stringifies when ``as_str``."""

    def _convert(value):
        if value is None and is_optional_type(type(value)):
            return None
        if isinstance(value, int):
            return str(value) if as_str else value
        raise ValueError("Expected int")

    return _convert
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def validate_int(b):
    """Parse an int that may arrive as a string; None passes through."""
    if b is None:
        return None
    if isinstance(b, int):
        return b
    if isinstance(b, str):
        return int(b)
    raise ValueError("Expected string")
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def validate_open_enum(is_int: bool):
    """Build a validator for open enums backed by int (``is_int``) or str.

    Open enums accept values outside the declared members, so the check is
    purely on the underlying primitive type; ``None`` passes through.
    """

    def _check(member):
        if member is None:
            return None
        if is_int and not isinstance(member, int):
            raise ValueError("Expected int")
        if not is_int and not isinstance(member, str):
            raise ValueError("Expected string")
        return member

    return _check
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def unmarshal_json(raw, typ: Any) -> Any:
    """Decode a raw JSON payload and validate the result as ``typ``."""
    parsed = from_json(raw)
    return unmarshal(parsed, typ)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def unmarshal(val, typ: Any) -> Any:
    """Validate an already-parsed value against ``typ`` and return the typed result."""
    # Wrap the target type in a throwaway single-field pydantic model so
    # pydantic performs coercion/validation for arbitrary annotations
    # (unions, aliased fields, project models, ...).
    unmarshaller = create_model(
        "Unmarshaller",
        body=(typ, ...),
        __config__=ConfigDict(populate_by_name=True, arbitrary_types_allowed=True),
    )

    m = unmarshaller(body=val)

    # pyright: ignore[reportAttributeAccessIssue]
    return m.body  # type: ignore
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def marshal_json(val, typ):
    """Serialize ``val`` to a JSON string using ``typ`` for pydantic field semantics."""
    # A nullable type explicitly set to None marshals to the JSON literal null.
    if is_nullable(typ) and val is None:
        return "null"

    # Mirror of unmarshal(): a throwaway one-field model drives serialization
    # (aliases, exclude_none) for any annotation.
    marshaller = create_model(
        "Marshaller",
        body=(typ, ...),
        __config__=ConfigDict(populate_by_name=True, arbitrary_types_allowed=True),
    )

    m = marshaller(body=val)

    d = m.model_dump(by_alias=True, mode="json", exclude_none=True)

    # exclude_none may drop the wrapper field entirely; treat that as empty.
    if len(d) == 0:
        return ""

    # Unwrap the single wrapper field and emit compact, key-sorted JSON.
    return json.dumps(d[next(iter(d))], separators=(",", ":"), sort_keys=True)
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def is_nullable(field):
    """Return True when ``field``'s annotation is Nullable/OptionalNullable.

    Detects both a direct ``Nullable[...]``/``OptionalNullable[...]``
    annotation and an Optional-style Union wrapping one.
    """
    origin = get_origin(field)
    if origin is Nullable or origin is OptionalNullable:
        return True

    # Idiom fix: the original used `not origin is Union` (flagged by E714);
    # `is not` is equivalent and readable. Only a Union containing NoneType
    # can wrap a Nullable member.
    if origin is not Union or type(None) not in get_args(field):
        return False

    return any(
        get_origin(arg) is Nullable or get_origin(arg) is OptionalNullable
        for arg in get_args(field)
    )
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def stream_to_text(stream: httpx.Response) -> str:
    """Drain an httpx response's text chunks into a single string."""
    chunks = stream.iter_text()
    return "".join(chunks)
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def get_pydantic_model(data: Any, typ: Any) -> Any:
    """Return ``data`` as ``typ``, validating only when it is still plain data.

    If ``data`` already holds pydantic model instances it is returned
    untouched; otherwise it is unmarshalled into ``typ``.
    """
    if _contains_pydantic_model(data):
        return data
    return unmarshal(data, typ)
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def _contains_pydantic_model(data: Any) -> bool:
    """Recursively check whether ``data`` already contains pydantic models."""
    if isinstance(data, BaseModel):
        return True
    # Use builtin types in isinstance: passing typing.List/typing.Dict is
    # deprecated and behaves identically to list/dict here.
    if isinstance(data, list):
        return any(_contains_pydantic_model(item) for item in data)
    if isinstance(data, dict):
        return any(_contains_pydantic_model(value) for value in data.values())

    return False
|
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
"""Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT."""
|
|
2
|
+
|
|
3
|
+
from decimal import Decimal
|
|
4
|
+
from typing import (
|
|
5
|
+
Any,
|
|
6
|
+
Dict,
|
|
7
|
+
get_type_hints,
|
|
8
|
+
List,
|
|
9
|
+
Optional,
|
|
10
|
+
Union,
|
|
11
|
+
get_args,
|
|
12
|
+
get_origin,
|
|
13
|
+
)
|
|
14
|
+
from pydantic import BaseModel
|
|
15
|
+
from pydantic.fields import FieldInfo
|
|
16
|
+
|
|
17
|
+
from .metadata import (
|
|
18
|
+
PathParamMetadata,
|
|
19
|
+
find_field_metadata,
|
|
20
|
+
)
|
|
21
|
+
from .values import _get_serialized_params, _populate_from_globals, _val_to_string
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def generate_url(
    server_url: str,
    path: str,
    path_params: Any,
    gbls: Optional[Any] = None,
) -> str:
    """Expand ``{param}`` placeholders in ``path`` and join onto ``server_url``.

    Values come from ``path_params`` first; parameters satisfied from the
    SDK-level globals ``gbls`` are filled in by a second pass that skips
    anything the first pass already resolved.
    """
    values: Dict[str, str] = {}

    filled_from_globals = _populate_path_params(path_params, gbls, values, [])
    if gbls is not None:
        _populate_path_params(gbls, None, values, filled_from_globals)

    resolved = path
    for name, rendered in values.items():
        resolved = resolved.replace("{" + name + "}", rendered, 1)

    return remove_suffix(server_url, "/") + resolved
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _populate_path_params(
    path_params: Any,
    gbls: Any,
    path_param_values: Dict[str, str],
    skip_fields: List[str],
) -> List[str]:
    """Render ``path_params`` fields into ``path_param_values`` (mutated in place).

    Returns the names of fields whose values came from the globals object
    ``gbls`` so a second pass can skip them via ``skip_fields``.
    """
    globals_already_populated: List[str] = []

    # Non-model inputs carry no annotated path parameters.
    if not isinstance(path_params, BaseModel):
        return globals_already_populated

    path_param_fields: Dict[str, FieldInfo] = path_params.__class__.model_fields
    path_param_field_types = get_type_hints(path_params.__class__)
    for name in path_param_fields:
        if name in skip_fields:
            continue

        field = path_param_fields[name]

        # Only fields tagged with PathParamMetadata participate.
        param_metadata = find_field_metadata(field, PathParamMetadata)
        if param_metadata is None:
            continue

        param = getattr(path_params, name) if path_params is not None else None
        param, global_found = _populate_from_globals(
            name, param, PathParamMetadata, gbls
        )
        if global_found:
            globals_already_populated.append(name)

        if param is None:
            continue

        f_name = field.alias if field.alias is not None else name
        serialization = param_metadata.serialization
        if serialization is not None:
            # Custom serialization (e.g. "json") may expand to several params.
            serialized_params = _get_serialized_params(
                param_metadata, f_name, param, path_param_field_types[name]
            )
            for key, value in serialized_params.items():
                path_param_values[key] = value
        else:
            pp_vals: List[str] = []
            if param_metadata.style == "simple":
                if isinstance(param, List):
                    # Lists render comma-joined; None entries are dropped.
                    for pp_val in param:
                        if pp_val is None:
                            continue
                        pp_vals.append(_val_to_string(pp_val))
                    path_param_values[f_name] = ",".join(pp_vals)
                elif isinstance(param, Dict):
                    # Dicts render "k=v" (explode) or "k,v" pairs, comma-joined.
                    for pp_key in param:
                        if param[pp_key] is None:
                            continue
                        if param_metadata.explode:
                            pp_vals.append(f"{pp_key}={_val_to_string(param[pp_key])}")
                        else:
                            pp_vals.append(f"{pp_key},{_val_to_string(param[pp_key])}")
                    path_param_values[f_name] = ",".join(pp_vals)
                elif not isinstance(param, (str, int, float, complex, bool, Decimal)):
                    # Nested model: render its own PathParamMetadata fields.
                    # NOTE(review): this inner loop reuses the outer loop
                    # variable ``name``; harmless today because f_name is
                    # already bound, but worth renaming to avoid confusion.
                    param_fields: Dict[str, FieldInfo] = param.__class__.model_fields
                    for name in param_fields:
                        param_field = param_fields[name]

                        param_value_metadata = find_field_metadata(
                            param_field, PathParamMetadata
                        )
                        if param_value_metadata is None:
                            continue

                        param_name = (
                            param_field.alias if param_field.alias is not None else name
                        )

                        param_field_val = getattr(param, name)
                        if param_field_val is None:
                            continue
                        if param_metadata.explode:
                            pp_vals.append(
                                f"{param_name}={_val_to_string(param_field_val)}"
                            )
                        else:
                            pp_vals.append(
                                f"{param_name},{_val_to_string(param_field_val)}"
                            )
                    path_param_values[f_name] = ",".join(pp_vals)
                else:
                    # Plain scalar value.
                    path_param_values[f_name] = _val_to_string(param)

    return globals_already_populated
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def is_optional(field):
    """Return True when ``field`` is a Union annotation that admits None."""
    if get_origin(field) is not Union:
        return False
    return type(None) in get_args(field)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def template_url(url_with_params: str, params: Dict[str, str]) -> str:
    """Substitute every ``{key}`` placeholder in the URL with its value."""
    result = url_with_params
    for key, value in params.items():
        result = result.replace(f"{{{key}}}", value)
    return result
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def remove_suffix(input_string, suffix):
    """Strip a single trailing ``suffix`` from ``input_string`` when present."""
    if not suffix:
        return input_string
    if not input_string.endswith(suffix):
        return input_string
    return input_string[: -len(suffix)]
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
"""Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT."""
|
|
2
|
+
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
from enum import Enum
|
|
5
|
+
from email.message import Message
|
|
6
|
+
import os
|
|
7
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
|
|
8
|
+
|
|
9
|
+
from httpx import Response
|
|
10
|
+
from pydantic import BaseModel
|
|
11
|
+
from pydantic.fields import FieldInfo
|
|
12
|
+
|
|
13
|
+
from .serializers import marshal_json
|
|
14
|
+
|
|
15
|
+
from .metadata import ParamMetadata, find_field_metadata
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def match_content_type(content_type: str, pattern: str) -> bool:
    """Check ``content_type`` against ``pattern`` (exact, ``*``, or type/sub wildcards)."""
    # Fast path: exact match or universal wildcard.
    if pattern in (content_type, "*", "*/*"):
        return True

    # Parse via email.message to strip parameters such as "; charset=utf-8".
    parsed = Message()
    parsed["content-type"] = content_type
    media_type = parsed.get_content_type()

    if media_type == pattern:
        return True

    pieces = media_type.split("/")
    if len(pieces) == 2:
        major, minor = pieces
        if pattern in (f"{major}/*", f"*/{minor}"):
            return True

    return False
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def match_status_codes(status_codes: List[str], status_code: int) -> bool:
    """Test whether ``status_code`` matches any pattern ("default", "404", "4XX")."""
    if "default" in status_codes:
        return True

    actual = str(status_code)
    for pattern in status_codes:
        if pattern == actual:
            return True
        # Range pattern like "4XX" matches on the leading digit.
        if pattern.endswith("XX") and pattern.startswith(actual[:1]):
            return True
    return False
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
T = TypeVar("T")


def get_global_from_env(
    value: Optional[T], env_key: str, type_cast: Callable[[str], T]
) -> Optional[T]:
    """Resolve a global setting: an explicit ``value`` wins, else cast ``env_key``.

    Returns None when the variable is unset or fails to cast — unparseable
    environment values are ignored rather than fatal.
    """
    if value is not None:
        return value
    raw = os.getenv(env_key)
    if raw is None:
        return None
    try:
        return type_cast(raw)
    except ValueError:
        return None
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def match_response(
    response: Response, code: Union[str, List[str]], content_type: str
) -> bool:
    """True when the response status matches ``code`` and its content type matches."""
    if isinstance(code, list):
        codes = code
    else:
        codes = [code]
    if not match_status_codes(codes, response.status_code):
        return False
    # Responses with no content-type header are treated as opaque bytes.
    actual_type = response.headers.get("content-type", "application/octet-stream")
    return match_content_type(actual_type, content_type)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _populate_from_globals(
|
|
77
|
+
param_name: str, value: Any, param_metadata_type: type, gbls: Any
|
|
78
|
+
) -> Tuple[Any, bool]:
|
|
79
|
+
if gbls is None:
|
|
80
|
+
return value, False
|
|
81
|
+
|
|
82
|
+
if not isinstance(gbls, BaseModel):
|
|
83
|
+
raise TypeError("globals must be a pydantic model")
|
|
84
|
+
|
|
85
|
+
global_fields: Dict[str, FieldInfo] = gbls.__class__.model_fields
|
|
86
|
+
found = False
|
|
87
|
+
for name in global_fields:
|
|
88
|
+
field = global_fields[name]
|
|
89
|
+
if name is not param_name:
|
|
90
|
+
continue
|
|
91
|
+
|
|
92
|
+
found = True
|
|
93
|
+
|
|
94
|
+
if value is not None:
|
|
95
|
+
return value, True
|
|
96
|
+
|
|
97
|
+
global_value = getattr(gbls, name)
|
|
98
|
+
|
|
99
|
+
param_metadata = find_field_metadata(field, param_metadata_type)
|
|
100
|
+
if param_metadata is None:
|
|
101
|
+
return value, True
|
|
102
|
+
|
|
103
|
+
return global_value, True
|
|
104
|
+
|
|
105
|
+
return value, found
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def _val_to_string(val) -> str:
|
|
109
|
+
if isinstance(val, bool):
|
|
110
|
+
return str(val).lower()
|
|
111
|
+
if isinstance(val, datetime):
|
|
112
|
+
return str(val.isoformat().replace("+00:00", "Z"))
|
|
113
|
+
if isinstance(val, Enum):
|
|
114
|
+
return str(val.value)
|
|
115
|
+
|
|
116
|
+
return str(val)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def _get_serialized_params(
    metadata: ParamMetadata, field_name: str, obj: Any, typ: type
) -> Dict[str, str]:
    """Serialize ``obj`` per the metadata's serialization mode.

    Only the "json" mode is supported; any other mode yields an empty map.
    """
    out: Dict[str, str] = {}
    if metadata.serialization == "json":
        out[field_name] = marshal_json(obj, typ)
    return out
|
mistralai/client_base.py
DELETED
|
@@ -1,211 +0,0 @@
|
|
|
1
|
-
import logging
|
|
2
|
-
import os
|
|
3
|
-
from abc import ABC
|
|
4
|
-
from typing import Any, Callable, Dict, List, Optional, Union
|
|
5
|
-
|
|
6
|
-
import orjson
|
|
7
|
-
from httpx import Headers
|
|
8
|
-
|
|
9
|
-
from mistralai.constants import HEADER_MODEL_DEPRECATION_TIMESTAMP
|
|
10
|
-
from mistralai.exceptions import MistralException
|
|
11
|
-
from mistralai.models.chat_completion import (
|
|
12
|
-
ChatMessage,
|
|
13
|
-
Function,
|
|
14
|
-
ResponseFormat,
|
|
15
|
-
ToolChoice,
|
|
16
|
-
)
|
|
17
|
-
|
|
18
|
-
CLIENT_VERSION = "0.4.2"
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
class ClientBase(ABC):
|
|
22
|
-
def __init__(
|
|
23
|
-
self,
|
|
24
|
-
endpoint: str,
|
|
25
|
-
api_key: Optional[str] = None,
|
|
26
|
-
max_retries: int = 5,
|
|
27
|
-
timeout: int = 120,
|
|
28
|
-
):
|
|
29
|
-
self._max_retries = max_retries
|
|
30
|
-
self._timeout = timeout
|
|
31
|
-
|
|
32
|
-
if api_key is None:
|
|
33
|
-
api_key = os.environ.get("MISTRAL_API_KEY")
|
|
34
|
-
if api_key is None:
|
|
35
|
-
raise MistralException(message="API key not provided. Please set MISTRAL_API_KEY environment variable.")
|
|
36
|
-
self._api_key = api_key
|
|
37
|
-
self._endpoint = endpoint
|
|
38
|
-
self._logger = logging.getLogger(__name__)
|
|
39
|
-
|
|
40
|
-
# For azure endpoints, we default to the mistral model
|
|
41
|
-
if "inference.azure.com" in self._endpoint:
|
|
42
|
-
self._default_model = "mistral"
|
|
43
|
-
|
|
44
|
-
self._version = CLIENT_VERSION
|
|
45
|
-
|
|
46
|
-
def _get_model(self, model: Optional[str] = None) -> str:
|
|
47
|
-
if model is not None:
|
|
48
|
-
return model
|
|
49
|
-
else:
|
|
50
|
-
if self._default_model is None:
|
|
51
|
-
raise MistralException(message="model must be provided")
|
|
52
|
-
return self._default_model
|
|
53
|
-
|
|
54
|
-
def _parse_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|
55
|
-
parsed_tools: List[Dict[str, Any]] = []
|
|
56
|
-
for tool in tools:
|
|
57
|
-
if tool["type"] == "function":
|
|
58
|
-
parsed_function = {}
|
|
59
|
-
parsed_function["type"] = tool["type"]
|
|
60
|
-
if isinstance(tool["function"], Function):
|
|
61
|
-
parsed_function["function"] = tool["function"].model_dump(exclude_none=True)
|
|
62
|
-
else:
|
|
63
|
-
parsed_function["function"] = tool["function"]
|
|
64
|
-
|
|
65
|
-
parsed_tools.append(parsed_function)
|
|
66
|
-
|
|
67
|
-
return parsed_tools
|
|
68
|
-
|
|
69
|
-
def _parse_tool_choice(self, tool_choice: Union[str, ToolChoice]) -> str:
|
|
70
|
-
if isinstance(tool_choice, ToolChoice):
|
|
71
|
-
return tool_choice.value
|
|
72
|
-
return tool_choice
|
|
73
|
-
|
|
74
|
-
def _parse_response_format(self, response_format: Union[Dict[str, Any], ResponseFormat]) -> Dict[str, Any]:
|
|
75
|
-
if isinstance(response_format, ResponseFormat):
|
|
76
|
-
return response_format.model_dump(exclude_none=True)
|
|
77
|
-
return response_format
|
|
78
|
-
|
|
79
|
-
def _parse_messages(self, messages: List[Any]) -> List[Dict[str, Any]]:
|
|
80
|
-
parsed_messages: List[Dict[str, Any]] = []
|
|
81
|
-
for message in messages:
|
|
82
|
-
if isinstance(message, ChatMessage):
|
|
83
|
-
parsed_messages.append(message.model_dump(exclude_none=True))
|
|
84
|
-
else:
|
|
85
|
-
parsed_messages.append(message)
|
|
86
|
-
|
|
87
|
-
return parsed_messages
|
|
88
|
-
|
|
89
|
-
def _check_model_deprecation_header_callback_factory(self, model: Optional[str] = None) -> Callable:
|
|
90
|
-
model = self._get_model(model)
|
|
91
|
-
|
|
92
|
-
def _check_model_deprecation_header_callback(
|
|
93
|
-
headers: Headers,
|
|
94
|
-
) -> None:
|
|
95
|
-
if HEADER_MODEL_DEPRECATION_TIMESTAMP in headers:
|
|
96
|
-
self._logger.warning(
|
|
97
|
-
f"WARNING: The model {model} is deprecated "
|
|
98
|
-
f"and will be removed on {headers[HEADER_MODEL_DEPRECATION_TIMESTAMP]}. "
|
|
99
|
-
"Please refer to https://docs.mistral.ai/getting-started/models/#api-versioning "
|
|
100
|
-
"for more information."
|
|
101
|
-
)
|
|
102
|
-
|
|
103
|
-
return _check_model_deprecation_header_callback
|
|
104
|
-
|
|
105
|
-
def _make_completion_request(
|
|
106
|
-
self,
|
|
107
|
-
prompt: str,
|
|
108
|
-
model: Optional[str] = None,
|
|
109
|
-
suffix: Optional[str] = None,
|
|
110
|
-
temperature: Optional[float] = None,
|
|
111
|
-
max_tokens: Optional[int] = None,
|
|
112
|
-
top_p: Optional[float] = None,
|
|
113
|
-
random_seed: Optional[int] = None,
|
|
114
|
-
stop: Optional[List[str]] = None,
|
|
115
|
-
stream: Optional[bool] = False,
|
|
116
|
-
) -> Dict[str, Any]:
|
|
117
|
-
request_data: Dict[str, Any] = {
|
|
118
|
-
"prompt": prompt,
|
|
119
|
-
"suffix": suffix,
|
|
120
|
-
"model": model,
|
|
121
|
-
"stream": stream,
|
|
122
|
-
}
|
|
123
|
-
|
|
124
|
-
if stop is not None:
|
|
125
|
-
request_data["stop"] = stop
|
|
126
|
-
|
|
127
|
-
request_data["model"] = self._get_model(model)
|
|
128
|
-
|
|
129
|
-
request_data.update(
|
|
130
|
-
self._build_sampling_params(
|
|
131
|
-
temperature=temperature,
|
|
132
|
-
max_tokens=max_tokens,
|
|
133
|
-
top_p=top_p,
|
|
134
|
-
random_seed=random_seed,
|
|
135
|
-
)
|
|
136
|
-
)
|
|
137
|
-
|
|
138
|
-
self._logger.debug(f"Completion request: {request_data}")
|
|
139
|
-
|
|
140
|
-
return request_data
|
|
141
|
-
|
|
142
|
-
def _build_sampling_params(
|
|
143
|
-
self,
|
|
144
|
-
max_tokens: Optional[int],
|
|
145
|
-
random_seed: Optional[int],
|
|
146
|
-
temperature: Optional[float],
|
|
147
|
-
top_p: Optional[float],
|
|
148
|
-
) -> Dict[str, Any]:
|
|
149
|
-
params = {}
|
|
150
|
-
if temperature is not None:
|
|
151
|
-
params["temperature"] = temperature
|
|
152
|
-
if max_tokens is not None:
|
|
153
|
-
params["max_tokens"] = max_tokens
|
|
154
|
-
if top_p is not None:
|
|
155
|
-
params["top_p"] = top_p
|
|
156
|
-
if random_seed is not None:
|
|
157
|
-
params["random_seed"] = random_seed
|
|
158
|
-
return params
|
|
159
|
-
|
|
160
|
-
def _make_chat_request(
|
|
161
|
-
self,
|
|
162
|
-
messages: List[Any],
|
|
163
|
-
model: Optional[str] = None,
|
|
164
|
-
tools: Optional[List[Dict[str, Any]]] = None,
|
|
165
|
-
temperature: Optional[float] = None,
|
|
166
|
-
max_tokens: Optional[int] = None,
|
|
167
|
-
top_p: Optional[float] = None,
|
|
168
|
-
random_seed: Optional[int] = None,
|
|
169
|
-
stream: Optional[bool] = None,
|
|
170
|
-
safe_prompt: Optional[bool] = False,
|
|
171
|
-
tool_choice: Optional[Union[str, ToolChoice]] = None,
|
|
172
|
-
response_format: Optional[Union[Dict[str, str], ResponseFormat]] = None,
|
|
173
|
-
) -> Dict[str, Any]:
|
|
174
|
-
request_data: Dict[str, Any] = {
|
|
175
|
-
"messages": self._parse_messages(messages),
|
|
176
|
-
}
|
|
177
|
-
|
|
178
|
-
request_data["model"] = self._get_model(model)
|
|
179
|
-
|
|
180
|
-
request_data.update(
|
|
181
|
-
self._build_sampling_params(
|
|
182
|
-
temperature=temperature,
|
|
183
|
-
max_tokens=max_tokens,
|
|
184
|
-
top_p=top_p,
|
|
185
|
-
random_seed=random_seed,
|
|
186
|
-
)
|
|
187
|
-
)
|
|
188
|
-
|
|
189
|
-
if safe_prompt:
|
|
190
|
-
request_data["safe_prompt"] = safe_prompt
|
|
191
|
-
if tools is not None:
|
|
192
|
-
request_data["tools"] = self._parse_tools(tools)
|
|
193
|
-
if stream is not None:
|
|
194
|
-
request_data["stream"] = stream
|
|
195
|
-
|
|
196
|
-
if tool_choice is not None:
|
|
197
|
-
request_data["tool_choice"] = self._parse_tool_choice(tool_choice)
|
|
198
|
-
if response_format is not None:
|
|
199
|
-
request_data["response_format"] = self._parse_response_format(response_format)
|
|
200
|
-
|
|
201
|
-
self._logger.debug(f"Chat request: {request_data}")
|
|
202
|
-
|
|
203
|
-
return request_data
|
|
204
|
-
|
|
205
|
-
def _process_line(self, line: str) -> Optional[Dict[str, Any]]:
|
|
206
|
-
if line.startswith("data: "):
|
|
207
|
-
line = line[6:].strip()
|
|
208
|
-
if line != "[DONE]":
|
|
209
|
-
json_streamed_response: Dict[str, Any] = orjson.loads(line)
|
|
210
|
-
return json_streamed_response
|
|
211
|
-
return None
|