mistralai 0.4.0__tar.gz → 0.4.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mistralai-0.4.0 → mistralai-0.4.2}/PKG-INFO +5 -3
- {mistralai-0.4.0 → mistralai-0.4.2}/pyproject.toml +6 -7
- {mistralai-0.4.0 → mistralai-0.4.2}/src/mistralai/async_client.py +56 -9
- {mistralai-0.4.0 → mistralai-0.4.2}/src/mistralai/client.py +60 -10
- {mistralai-0.4.0 → mistralai-0.4.2}/src/mistralai/client_base.py +45 -20
- {mistralai-0.4.0 → mistralai-0.4.2}/src/mistralai/constants.py +2 -0
- {mistralai-0.4.0 → mistralai-0.4.2}/src/mistralai/models/jobs.py +2 -0
- {mistralai-0.4.0 → mistralai-0.4.2}/src/mistralai/models/models.py +6 -0
- {mistralai-0.4.0 → mistralai-0.4.2}/LICENSE +0 -0
- {mistralai-0.4.0 → mistralai-0.4.2}/README.md +0 -0
- {mistralai-0.4.0 → mistralai-0.4.2}/src/mistralai/__init__.py +0 -0
- {mistralai-0.4.0 → mistralai-0.4.2}/src/mistralai/exceptions.py +0 -0
- {mistralai-0.4.0 → mistralai-0.4.2}/src/mistralai/files.py +0 -0
- {mistralai-0.4.0 → mistralai-0.4.2}/src/mistralai/jobs.py +0 -0
- {mistralai-0.4.0 → mistralai-0.4.2}/src/mistralai/models/__init__.py +0 -0
- {mistralai-0.4.0 → mistralai-0.4.2}/src/mistralai/models/chat_completion.py +0 -0
- {mistralai-0.4.0 → mistralai-0.4.2}/src/mistralai/models/common.py +0 -0
- {mistralai-0.4.0 → mistralai-0.4.2}/src/mistralai/models/embeddings.py +0 -0
- {mistralai-0.4.0 → mistralai-0.4.2}/src/mistralai/models/files.py +0 -0
- {mistralai-0.4.0 → mistralai-0.4.2}/src/mistralai/py.typed +0 -0
|
@@ -1,18 +1,20 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: mistralai
|
|
3
|
-
Version: 0.4.
|
|
3
|
+
Version: 0.4.2
|
|
4
4
|
Summary:
|
|
5
|
+
License: Apache 2.0 License
|
|
5
6
|
Author: Bam4d
|
|
6
7
|
Author-email: bam4d@mistral.ai
|
|
7
8
|
Requires-Python: >=3.9,<4.0
|
|
9
|
+
Classifier: License :: Other/Proprietary License
|
|
8
10
|
Classifier: Programming Language :: Python :: 3
|
|
9
11
|
Classifier: Programming Language :: Python :: 3.9
|
|
10
12
|
Classifier: Programming Language :: Python :: 3.10
|
|
11
13
|
Classifier: Programming Language :: Python :: 3.11
|
|
12
14
|
Classifier: Programming Language :: Python :: 3.12
|
|
13
|
-
Requires-Dist: httpx (>=0.25,<
|
|
15
|
+
Requires-Dist: httpx (>=0.25,<1)
|
|
14
16
|
Requires-Dist: orjson (>=3.9.10,<3.11)
|
|
15
|
-
Requires-Dist: pydantic (>=2.5.2,<3
|
|
17
|
+
Requires-Dist: pydantic (>=2.5.2,<3)
|
|
16
18
|
Description-Content-Type: text/markdown
|
|
17
19
|
|
|
18
20
|
# Mistral Python Client
|
|
@@ -1,9 +1,10 @@
|
|
|
1
1
|
[tool.poetry]
|
|
2
2
|
name = "mistralai"
|
|
3
|
-
version = "0.4.
|
|
3
|
+
version = "0.4.2"
|
|
4
4
|
description = ""
|
|
5
5
|
authors = ["Bam4d <bam4d@mistral.ai>"]
|
|
6
6
|
readme = "README.md"
|
|
7
|
+
license = "Apache 2.0 License"
|
|
7
8
|
|
|
8
9
|
[tool.ruff]
|
|
9
10
|
select = ["E", "F", "W", "Q", "I"]
|
|
@@ -23,10 +24,10 @@ exclude = ["docs", "tests", "examples", "tools", "build"]
|
|
|
23
24
|
|
|
24
25
|
|
|
25
26
|
[tool.poetry.dependencies]
|
|
26
|
-
python = "
|
|
27
|
-
orjson = "
|
|
28
|
-
pydantic = "
|
|
29
|
-
httpx = "
|
|
27
|
+
python = ">=3.9,<4.0"
|
|
28
|
+
orjson = ">=3.9.10,<3.11"
|
|
29
|
+
pydantic = ">=2.5.2,<3"
|
|
30
|
+
httpx = ">=0.25,<1"
|
|
30
31
|
|
|
31
32
|
|
|
32
33
|
[tool.poetry.group.dev.dependencies]
|
|
@@ -39,5 +40,3 @@ pytest-asyncio = "^0.23.2"
|
|
|
39
40
|
[build-system]
|
|
40
41
|
requires = ["poetry-core"]
|
|
41
42
|
build-backend = "poetry.core.masonry.api"
|
|
42
|
-
|
|
43
|
-
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import posixpath
|
|
3
3
|
from json import JSONDecodeError
|
|
4
|
-
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
|
4
|
+
from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Union
|
|
5
5
|
|
|
6
6
|
from httpx import (
|
|
7
7
|
AsyncClient,
|
|
@@ -29,7 +29,7 @@ from mistralai.models.chat_completion import (
|
|
|
29
29
|
ToolChoice,
|
|
30
30
|
)
|
|
31
31
|
from mistralai.models.embeddings import EmbeddingResponse
|
|
32
|
-
from mistralai.models.models import ModelList
|
|
32
|
+
from mistralai.models.models import ModelDeleted, ModelList
|
|
33
33
|
|
|
34
34
|
|
|
35
35
|
class MistralAsyncClient(ClientBase):
|
|
@@ -101,6 +101,7 @@ class MistralAsyncClient(ClientBase):
|
|
|
101
101
|
stream: bool = False,
|
|
102
102
|
attempt: int = 1,
|
|
103
103
|
data: Optional[Dict[str, Any]] = None,
|
|
104
|
+
check_model_deprecation_headers_callback: Optional[Callable] = None,
|
|
104
105
|
**kwargs: Any,
|
|
105
106
|
) -> AsyncGenerator[Dict[str, Any], None]:
|
|
106
107
|
accept_header = "text/event-stream" if stream else "application/json"
|
|
@@ -129,6 +130,8 @@ class MistralAsyncClient(ClientBase):
|
|
|
129
130
|
data=data,
|
|
130
131
|
**kwargs,
|
|
131
132
|
) as response:
|
|
133
|
+
if check_model_deprecation_headers_callback:
|
|
134
|
+
check_model_deprecation_headers_callback(response.headers)
|
|
132
135
|
await self._check_streaming_response(response)
|
|
133
136
|
|
|
134
137
|
async for line in response.aiter_lines():
|
|
@@ -145,7 +148,8 @@ class MistralAsyncClient(ClientBase):
|
|
|
145
148
|
data=data,
|
|
146
149
|
**kwargs,
|
|
147
150
|
)
|
|
148
|
-
|
|
151
|
+
if check_model_deprecation_headers_callback:
|
|
152
|
+
check_model_deprecation_headers_callback(response.headers)
|
|
149
153
|
yield await self._check_response(response)
|
|
150
154
|
|
|
151
155
|
except ConnectError as e:
|
|
@@ -213,7 +217,12 @@ class MistralAsyncClient(ClientBase):
|
|
|
213
217
|
response_format=response_format,
|
|
214
218
|
)
|
|
215
219
|
|
|
216
|
-
single_response = self._request(
|
|
220
|
+
single_response = self._request(
|
|
221
|
+
"post",
|
|
222
|
+
request,
|
|
223
|
+
"v1/chat/completions",
|
|
224
|
+
check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
|
|
225
|
+
)
|
|
217
226
|
|
|
218
227
|
async for response in single_response:
|
|
219
228
|
return ChatCompletionResponse(**response)
|
|
@@ -267,7 +276,13 @@ class MistralAsyncClient(ClientBase):
|
|
|
267
276
|
tool_choice=tool_choice,
|
|
268
277
|
response_format=response_format,
|
|
269
278
|
)
|
|
270
|
-
async_response = self._request(
|
|
279
|
+
async_response = self._request(
|
|
280
|
+
"post",
|
|
281
|
+
request,
|
|
282
|
+
"v1/chat/completions",
|
|
283
|
+
stream=True,
|
|
284
|
+
check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
|
|
285
|
+
)
|
|
271
286
|
|
|
272
287
|
async for json_response in async_response:
|
|
273
288
|
yield ChatCompletionStreamResponse(**json_response)
|
|
@@ -284,7 +299,12 @@ class MistralAsyncClient(ClientBase):
|
|
|
284
299
|
EmbeddingResponse: A response object containing the embeddings.
|
|
285
300
|
"""
|
|
286
301
|
request = {"model": model, "input": input}
|
|
287
|
-
single_response = self._request(
|
|
302
|
+
single_response = self._request(
|
|
303
|
+
"post",
|
|
304
|
+
request,
|
|
305
|
+
"v1/embeddings",
|
|
306
|
+
check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
|
|
307
|
+
)
|
|
288
308
|
|
|
289
309
|
async for response in single_response:
|
|
290
310
|
return EmbeddingResponse(**response)
|
|
@@ -304,6 +324,14 @@ class MistralAsyncClient(ClientBase):
|
|
|
304
324
|
|
|
305
325
|
raise MistralException("No response received")
|
|
306
326
|
|
|
327
|
+
async def delete_model(self, model_id: str) -> ModelDeleted:
|
|
328
|
+
single_response = self._request("delete", {}, f"v1/models/{model_id}")
|
|
329
|
+
|
|
330
|
+
async for response in single_response:
|
|
331
|
+
return ModelDeleted(**response)
|
|
332
|
+
|
|
333
|
+
raise MistralException("No response received")
|
|
334
|
+
|
|
307
335
|
async def completion(
|
|
308
336
|
self,
|
|
309
337
|
model: str,
|
|
@@ -333,7 +361,12 @@ class MistralAsyncClient(ClientBase):
|
|
|
333
361
|
request = self._make_completion_request(
|
|
334
362
|
prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop
|
|
335
363
|
)
|
|
336
|
-
single_response = self._request(
|
|
364
|
+
single_response = self._request(
|
|
365
|
+
"post",
|
|
366
|
+
request,
|
|
367
|
+
"v1/fim/completions",
|
|
368
|
+
check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
|
|
369
|
+
)
|
|
337
370
|
|
|
338
371
|
async for response in single_response:
|
|
339
372
|
return ChatCompletionResponse(**response)
|
|
@@ -368,9 +401,23 @@ class MistralAsyncClient(ClientBase):
|
|
|
368
401
|
Dict[str, Any]: a response object containing the generated text.
|
|
369
402
|
"""
|
|
370
403
|
request = self._make_completion_request(
|
|
371
|
-
prompt,
|
|
404
|
+
prompt,
|
|
405
|
+
model,
|
|
406
|
+
suffix,
|
|
407
|
+
temperature,
|
|
408
|
+
max_tokens,
|
|
409
|
+
top_p,
|
|
410
|
+
random_seed,
|
|
411
|
+
stop,
|
|
412
|
+
stream=True,
|
|
413
|
+
)
|
|
414
|
+
async_response = self._request(
|
|
415
|
+
"post",
|
|
416
|
+
request,
|
|
417
|
+
"v1/fim/completions",
|
|
418
|
+
stream=True,
|
|
419
|
+
check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
|
|
372
420
|
)
|
|
373
|
-
async_response = self._request("post", request, "v1/fim/completions", stream=True)
|
|
374
421
|
|
|
375
422
|
async for json_response in async_response:
|
|
376
423
|
yield ChatCompletionStreamResponse(**json_response)
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import posixpath
|
|
2
2
|
import time
|
|
3
3
|
from json import JSONDecodeError
|
|
4
|
-
from typing import Any, Dict, Iterable, Iterator, List, Optional, Union
|
|
4
|
+
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Union
|
|
5
5
|
|
|
6
6
|
from httpx import Client, ConnectError, HTTPTransport, RequestError, Response
|
|
7
7
|
|
|
@@ -22,7 +22,7 @@ from mistralai.models.chat_completion import (
|
|
|
22
22
|
ToolChoice,
|
|
23
23
|
)
|
|
24
24
|
from mistralai.models.embeddings import EmbeddingResponse
|
|
25
|
-
from mistralai.models.models import ModelList
|
|
25
|
+
from mistralai.models.models import ModelDeleted, ModelList
|
|
26
26
|
|
|
27
27
|
|
|
28
28
|
class MistralClient(ClientBase):
|
|
@@ -40,7 +40,9 @@ class MistralClient(ClientBase):
|
|
|
40
40
|
super().__init__(endpoint, api_key, max_retries, timeout)
|
|
41
41
|
|
|
42
42
|
self._client = Client(
|
|
43
|
-
follow_redirects=True,
|
|
43
|
+
follow_redirects=True,
|
|
44
|
+
timeout=self._timeout,
|
|
45
|
+
transport=HTTPTransport(retries=self._max_retries),
|
|
44
46
|
)
|
|
45
47
|
self.files = FilesClient(self)
|
|
46
48
|
self.jobs = JobsClient(self)
|
|
@@ -94,6 +96,7 @@ class MistralClient(ClientBase):
|
|
|
94
96
|
stream: bool = False,
|
|
95
97
|
attempt: int = 1,
|
|
96
98
|
data: Optional[Dict[str, Any]] = None,
|
|
99
|
+
check_model_deprecation_headers_callback: Optional[Callable] = None,
|
|
97
100
|
**kwargs: Any,
|
|
98
101
|
) -> Iterator[Dict[str, Any]]:
|
|
99
102
|
accept_header = "text/event-stream" if stream else "application/json"
|
|
@@ -122,6 +125,8 @@ class MistralClient(ClientBase):
|
|
|
122
125
|
data=data,
|
|
123
126
|
**kwargs,
|
|
124
127
|
) as response:
|
|
128
|
+
if check_model_deprecation_headers_callback:
|
|
129
|
+
check_model_deprecation_headers_callback(response.headers)
|
|
125
130
|
self._check_streaming_response(response)
|
|
126
131
|
|
|
127
132
|
for line in response.iter_lines():
|
|
@@ -138,7 +143,8 @@ class MistralClient(ClientBase):
|
|
|
138
143
|
data=data,
|
|
139
144
|
**kwargs,
|
|
140
145
|
)
|
|
141
|
-
|
|
146
|
+
if check_model_deprecation_headers_callback:
|
|
147
|
+
check_model_deprecation_headers_callback(response.headers)
|
|
142
148
|
yield self._check_response(response)
|
|
143
149
|
|
|
144
150
|
except ConnectError as e:
|
|
@@ -207,7 +213,12 @@ class MistralClient(ClientBase):
|
|
|
207
213
|
response_format=response_format,
|
|
208
214
|
)
|
|
209
215
|
|
|
210
|
-
single_response = self._request(
|
|
216
|
+
single_response = self._request(
|
|
217
|
+
"post",
|
|
218
|
+
request,
|
|
219
|
+
"v1/chat/completions",
|
|
220
|
+
check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
|
|
221
|
+
)
|
|
211
222
|
|
|
212
223
|
for response in single_response:
|
|
213
224
|
return ChatCompletionResponse(**response)
|
|
@@ -261,7 +272,13 @@ class MistralClient(ClientBase):
|
|
|
261
272
|
response_format=response_format,
|
|
262
273
|
)
|
|
263
274
|
|
|
264
|
-
response = self._request(
|
|
275
|
+
response = self._request(
|
|
276
|
+
"post",
|
|
277
|
+
request,
|
|
278
|
+
"v1/chat/completions",
|
|
279
|
+
stream=True,
|
|
280
|
+
check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
|
|
281
|
+
)
|
|
265
282
|
|
|
266
283
|
for json_streamed_response in response:
|
|
267
284
|
yield ChatCompletionStreamResponse(**json_streamed_response)
|
|
@@ -278,7 +295,12 @@ class MistralClient(ClientBase):
|
|
|
278
295
|
EmbeddingResponse: A response object containing the embeddings.
|
|
279
296
|
"""
|
|
280
297
|
request = {"model": model, "input": input}
|
|
281
|
-
singleton_response = self._request(
|
|
298
|
+
singleton_response = self._request(
|
|
299
|
+
"post",
|
|
300
|
+
request,
|
|
301
|
+
"v1/embeddings",
|
|
302
|
+
check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
|
|
303
|
+
)
|
|
282
304
|
|
|
283
305
|
for response in singleton_response:
|
|
284
306
|
return EmbeddingResponse(**response)
|
|
@@ -298,6 +320,14 @@ class MistralClient(ClientBase):
|
|
|
298
320
|
|
|
299
321
|
raise MistralException("No response received")
|
|
300
322
|
|
|
323
|
+
def delete_model(self, model_id: str) -> ModelDeleted:
|
|
324
|
+
single_response = self._request("delete", {}, f"v1/models/{model_id}")
|
|
325
|
+
|
|
326
|
+
for response in single_response:
|
|
327
|
+
return ModelDeleted(**response)
|
|
328
|
+
|
|
329
|
+
raise MistralException("No response received")
|
|
330
|
+
|
|
301
331
|
def completion(
|
|
302
332
|
self,
|
|
303
333
|
model: str,
|
|
@@ -329,7 +359,13 @@ class MistralClient(ClientBase):
|
|
|
329
359
|
prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop
|
|
330
360
|
)
|
|
331
361
|
|
|
332
|
-
single_response = self._request(
|
|
362
|
+
single_response = self._request(
|
|
363
|
+
"post",
|
|
364
|
+
request,
|
|
365
|
+
"v1/fim/completions",
|
|
366
|
+
stream=False,
|
|
367
|
+
check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
|
|
368
|
+
)
|
|
333
369
|
|
|
334
370
|
for response in single_response:
|
|
335
371
|
return ChatCompletionResponse(**response)
|
|
@@ -364,10 +400,24 @@ class MistralClient(ClientBase):
|
|
|
364
400
|
Iterable[Dict[str, Any]]: a generator that yields response objects containing the generated text.
|
|
365
401
|
"""
|
|
366
402
|
request = self._make_completion_request(
|
|
367
|
-
prompt,
|
|
403
|
+
prompt,
|
|
404
|
+
model,
|
|
405
|
+
suffix,
|
|
406
|
+
temperature,
|
|
407
|
+
max_tokens,
|
|
408
|
+
top_p,
|
|
409
|
+
random_seed,
|
|
410
|
+
stop,
|
|
411
|
+
stream=True,
|
|
368
412
|
)
|
|
369
413
|
|
|
370
|
-
response = self._request(
|
|
414
|
+
response = self._request(
|
|
415
|
+
"post",
|
|
416
|
+
request,
|
|
417
|
+
"v1/fim/completions",
|
|
418
|
+
stream=True,
|
|
419
|
+
check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
|
|
420
|
+
)
|
|
371
421
|
|
|
372
422
|
for json_streamed_response in response:
|
|
373
423
|
yield ChatCompletionStreamResponse(**json_streamed_response)
|
|
@@ -1,16 +1,21 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
import os
|
|
3
3
|
from abc import ABC
|
|
4
|
-
from typing import Any, Dict, List, Optional, Union
|
|
4
|
+
from typing import Any, Callable, Dict, List, Optional, Union
|
|
5
5
|
|
|
6
6
|
import orjson
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
7
|
+
from httpx import Headers
|
|
8
|
+
|
|
9
|
+
from mistralai.constants import HEADER_MODEL_DEPRECATION_TIMESTAMP
|
|
10
|
+
from mistralai.exceptions import MistralException
|
|
11
|
+
from mistralai.models.chat_completion import (
|
|
12
|
+
ChatMessage,
|
|
13
|
+
Function,
|
|
14
|
+
ResponseFormat,
|
|
15
|
+
ToolChoice,
|
|
10
16
|
)
|
|
11
|
-
from mistralai.models.chat_completion import ChatMessage, Function, ResponseFormat, ToolChoice
|
|
12
17
|
|
|
13
|
-
CLIENT_VERSION = "0.4.
|
|
18
|
+
CLIENT_VERSION = "0.4.2"
|
|
14
19
|
|
|
15
20
|
|
|
16
21
|
class ClientBase(ABC):
|
|
@@ -38,6 +43,14 @@ class ClientBase(ABC):
|
|
|
38
43
|
|
|
39
44
|
self._version = CLIENT_VERSION
|
|
40
45
|
|
|
46
|
+
def _get_model(self, model: Optional[str] = None) -> str:
|
|
47
|
+
if model is not None:
|
|
48
|
+
return model
|
|
49
|
+
else:
|
|
50
|
+
if self._default_model is None:
|
|
51
|
+
raise MistralException(message="model must be provided")
|
|
52
|
+
return self._default_model
|
|
53
|
+
|
|
41
54
|
def _parse_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|
42
55
|
parsed_tools: List[Dict[str, Any]] = []
|
|
43
56
|
for tool in tools:
|
|
@@ -73,6 +86,22 @@ class ClientBase(ABC):
|
|
|
73
86
|
|
|
74
87
|
return parsed_messages
|
|
75
88
|
|
|
89
|
+
def _check_model_deprecation_header_callback_factory(self, model: Optional[str] = None) -> Callable:
|
|
90
|
+
model = self._get_model(model)
|
|
91
|
+
|
|
92
|
+
def _check_model_deprecation_header_callback(
|
|
93
|
+
headers: Headers,
|
|
94
|
+
) -> None:
|
|
95
|
+
if HEADER_MODEL_DEPRECATION_TIMESTAMP in headers:
|
|
96
|
+
self._logger.warning(
|
|
97
|
+
f"WARNING: The model {model} is deprecated "
|
|
98
|
+
f"and will be removed on {headers[HEADER_MODEL_DEPRECATION_TIMESTAMP]}. "
|
|
99
|
+
"Please refer to https://docs.mistral.ai/getting-started/models/#api-versioning "
|
|
100
|
+
"for more information."
|
|
101
|
+
)
|
|
102
|
+
|
|
103
|
+
return _check_model_deprecation_header_callback
|
|
104
|
+
|
|
76
105
|
def _make_completion_request(
|
|
77
106
|
self,
|
|
78
107
|
prompt: str,
|
|
@@ -95,16 +124,14 @@ class ClientBase(ABC):
|
|
|
95
124
|
if stop is not None:
|
|
96
125
|
request_data["stop"] = stop
|
|
97
126
|
|
|
98
|
-
|
|
99
|
-
request_data["model"] = model
|
|
100
|
-
else:
|
|
101
|
-
if self._default_model is None:
|
|
102
|
-
raise MistralException(message="model must be provided")
|
|
103
|
-
request_data["model"] = self._default_model
|
|
127
|
+
request_data["model"] = self._get_model(model)
|
|
104
128
|
|
|
105
129
|
request_data.update(
|
|
106
130
|
self._build_sampling_params(
|
|
107
|
-
temperature=temperature,
|
|
131
|
+
temperature=temperature,
|
|
132
|
+
max_tokens=max_tokens,
|
|
133
|
+
top_p=top_p,
|
|
134
|
+
random_seed=random_seed,
|
|
108
135
|
)
|
|
109
136
|
)
|
|
110
137
|
|
|
@@ -148,16 +175,14 @@ class ClientBase(ABC):
|
|
|
148
175
|
"messages": self._parse_messages(messages),
|
|
149
176
|
}
|
|
150
177
|
|
|
151
|
-
|
|
152
|
-
request_data["model"] = model
|
|
153
|
-
else:
|
|
154
|
-
if self._default_model is None:
|
|
155
|
-
raise MistralException(message="model must be provided")
|
|
156
|
-
request_data["model"] = self._default_model
|
|
178
|
+
request_data["model"] = self._get_model(model)
|
|
157
179
|
|
|
158
180
|
request_data.update(
|
|
159
181
|
self._build_sampling_params(
|
|
160
|
-
temperature=temperature,
|
|
182
|
+
temperature=temperature,
|
|
183
|
+
max_tokens=max_tokens,
|
|
184
|
+
top_p=top_p,
|
|
185
|
+
random_seed=random_seed,
|
|
161
186
|
)
|
|
162
187
|
)
|
|
163
188
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|