model-library 0.1.7__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- model_library/base/base.py +139 -62
- model_library/base/delegate_only.py +77 -10
- model_library/base/output.py +43 -0
- model_library/base/utils.py +35 -0
- model_library/config/alibaba_models.yaml +44 -57
- model_library/config/all_models.json +253 -126
- model_library/config/kimi_models.yaml +30 -3
- model_library/config/openai_models.yaml +15 -23
- model_library/config/zai_models.yaml +24 -3
- model_library/exceptions.py +3 -77
- model_library/providers/ai21labs.py +12 -8
- model_library/providers/alibaba.py +17 -8
- model_library/providers/amazon.py +49 -16
- model_library/providers/anthropic.py +93 -40
- model_library/providers/azure.py +22 -10
- model_library/providers/cohere.py +7 -7
- model_library/providers/deepseek.py +8 -8
- model_library/providers/fireworks.py +7 -8
- model_library/providers/google/batch.py +14 -10
- model_library/providers/google/google.py +48 -29
- model_library/providers/inception.py +7 -7
- model_library/providers/kimi.py +18 -8
- model_library/providers/minimax.py +15 -17
- model_library/providers/mistral.py +20 -8
- model_library/providers/openai.py +99 -22
- model_library/providers/openrouter.py +34 -0
- model_library/providers/perplexity.py +7 -7
- model_library/providers/together.py +7 -8
- model_library/providers/vals.py +12 -6
- model_library/providers/xai.py +47 -42
- model_library/providers/zai.py +38 -8
- model_library/registry_utils.py +39 -15
- model_library/retriers/__init__.py +0 -0
- model_library/retriers/backoff.py +73 -0
- model_library/retriers/base.py +225 -0
- model_library/retriers/token.py +427 -0
- model_library/retriers/utils.py +11 -0
- model_library/settings.py +1 -1
- model_library/utils.py +13 -0
- {model_library-0.1.7.dist-info → model_library-0.1.8.dist-info}/METADATA +2 -1
- model_library-0.1.8.dist-info/RECORD +70 -0
- {model_library-0.1.7.dist-info → model_library-0.1.8.dist-info}/WHEEL +1 -1
- model_library-0.1.7.dist-info/RECORD +0 -64
- {model_library-0.1.7.dist-info → model_library-0.1.8.dist-info}/licenses/LICENSE +0 -0
- {model_library-0.1.7.dist-info → model_library-0.1.8.dist-info}/top_level.txt +0 -0
model_library/providers/fireworks.py
CHANGED
@@ -1,18 +1,18 @@
 from typing import Literal
 
+from pydantic import SecretStr
 from typing_extensions import override
 
 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
+    DelegateOnly,
     LLMConfig,
     ProviderConfig,
     QueryResultCost,
     QueryResultMetadata,
 )
-from model_library.base.delegate_only import DelegateOnly
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults
 
 
 class FireworksConfig(ProviderConfig):
@@ -38,15 +38,14 @@ class FireworksModel(DelegateOnly):
         self.model_name = "accounts/rayan-936e28/deployedModels/" + self.model_name
 
         # https://docs.fireworks.ai/tools-sdks/openai-compatibility
-        self.
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-
-            api_key=model_library_settings.FIREWORKS_API_KEY,
+            delegate_config=DelegateConfig(
                 base_url="https://api.fireworks.ai/inference/v1",
+                api_key=SecretStr(model_library_settings.FIREWORKS_API_KEY),
             ),
             use_completions=True,
+            delegate_provider="openai",
         )
 
     @override
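The recurring change in this release's provider files (fireworks here, and the inception, kimi, and minimax hunks below) swaps hand-wired OpenAI clients for `self.init_delegate(...)` taking a `DelegateConfig` whose API key is wrapped in pydantic's `SecretStr`. A minimal sketch of why that wrapping matters, using a hypothetical stand-in class since the full `DelegateConfig` definition is not part of this diff:

```python
from pydantic import BaseModel, SecretStr

# Hypothetical stand-in: only base_url and api_key are visible in the hunks,
# so the real DelegateConfig almost certainly has more fields.
class DelegateConfigSketch(BaseModel):
    base_url: str
    api_key: SecretStr

cfg = DelegateConfigSketch(
    base_url="https://api.fireworks.ai/inference/v1",
    api_key=SecretStr("fw-example-key"),
)

print(cfg)                             # api_key is masked: SecretStr('**********')
print(cfg.api_key.get_secret_value())  # raw key only on explicit request
```

The practical effect is that keys no longer leak into reprs or logs of the delegate configuration.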
model_library/providers/google/batch.py
CHANGED
@@ -24,16 +24,19 @@ from google.genai.types import (
 )
 
 
-def extract_text_from_json_response(response: dict[str, Any]) -> str:
+def extract_text_from_json_response(response: dict[str, Any]) -> tuple[str, str]:
     """Extract concatenated non-thought text from a JSON response structure."""
     # TODO: fix the typing we always ignore
     text = ""
+    reasoning = ""
     for candidate in response.get("candidates", []) or []:  # type: ignore
         content = (candidate or {}).get("content") or {}  # type: ignore
         for part in content.get("parts", []) or []:  # type: ignore
-            if
+            if part.get("thought", False):  # type: ignore
+                reasoning += part.get("text", "")  # type: ignore
+            else:
                 text += part.get("text", "")  # type: ignore
-    return text  # type: ignore
+    return text, reasoning  # type: ignore
 
 
 def parse_predictions_jsonl(jsonl: str) -> list[BatchResult]:
@@ -48,9 +51,10 @@ def parse_predictions_jsonl(jsonl: str) -> list[BatchResult]:
         custom_id = data.get("key", "unknown")
         if "response" in data:
             response = data["response"]
-            text = extract_text_from_json_response(response)
+            text, reasoning = extract_text_from_json_response(response)
             output = QueryResult()
             output.output_text = text
+            output.reasoning = reasoning
             if "usageMetadata" in response:
                 output.metadata.in_tokens = response["usageMetadata"].get(
                     "promptTokenCount", 0
@@ -196,7 +200,7 @@ class GoogleBatchMixin(LLMBatchMixin):
             custom_id = labels.get("qa_pair_id", f"request-{i}")
             jsonl_lines.append(json.dumps({"key": custom_id, "request": request_data}))
 
-        batch_request_file = self._root.
+        batch_request_file = self._root.get_client().files.upload(
             file=io.StringIO("\n".join(jsonl_lines)),
             config=UploadFileConfig(mime_type="application/jsonl"),
         )
@@ -205,7 +209,7 @@ class GoogleBatchMixin(LLMBatchMixin):
             raise Exception("Failed to upload batch jsonl")
 
         try:
-            job: BatchJob = await self._root.
+            job: BatchJob = await self._root.get_client().aio.batches.create(
                 model=self._root.model_name,
                 src=batch_request_file.name,
                 config={"display_name": batch_name},
@@ -224,14 +228,14 @@ class GoogleBatchMixin(LLMBatchMixin):
     async def get_batch_results(self, batch_id: str) -> list[BatchResult]:
         self._root.logger.info(f"Retrieving batch results for {batch_id}")
 
-        job = await self._root.
+        job = await self._root.get_client().aio.batches.get(name=batch_id)
 
         results: list[BatchResult] = []
 
         if job.state == JobState.JOB_STATE_SUCCEEDED:
             if job.dest and job.dest.file_name:
                 results_file_name = job.dest.file_name
-                file_content = await self._root.
+                file_content = await self._root.get_client().aio.files.download(
                     file=results_file_name
                 )
                 decoded = file_content.decode("utf-8")
@@ -250,7 +254,7 @@ class GoogleBatchMixin(LLMBatchMixin):
     @override
     async def cancel_batch_request(self, batch_id: str):
         self._root.logger.info(f"Cancelling batch {batch_id}")
-        await self._root.
+        await self._root.get_client().aio.batches.cancel(name=batch_id)
 
     @override
     async def get_batch_progress(self, batch_id: str) -> int:
@@ -262,7 +266,7 @@ class GoogleBatchMixin(LLMBatchMixin):
 
         try:
             self._root.logger.debug(f"Checking batch status for {batch_id}")
-            job: BatchJob = await self._root.
+            job: BatchJob = await self._root.get_client().aio.batches.get(name=batch_id)
             state = job.state
 
             if not state:
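The first two hunks above change `extract_text_from_json_response` to return a `(text, reasoning)` tuple so that Gemini "thought" parts land in `QueryResult.reasoning` rather than in the answer text. A self-contained re-implementation of that return shape, for illustration only; the response layout is assumed from the hunk:

```python
from typing import Any


def split_text_and_reasoning(response: dict[str, Any]) -> tuple[str, str]:
    """Mirror of the new behaviour: thought parts go to reasoning, the rest to text."""
    text, reasoning = "", ""
    for candidate in response.get("candidates", []) or []:
        content = (candidate or {}).get("content") or {}
        for part in content.get("parts", []) or []:
            if part.get("thought", False):
                reasoning += part.get("text", "")
            else:
                text += part.get("text", "")
    return text, reasoning


sample = {
    "candidates": [
        {"content": {"parts": [
            {"text": "let me think...", "thought": True},
            {"text": "final answer"},
        ]}}
    ]
}
assert split_text_and_reasoning(sample) == ("final answer", "let me think...")
```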
model_library/providers/google/google.py
CHANGED
@@ -1,5 +1,6 @@
 import base64
 import io
+import json
 import logging
 import uuid
 from typing import Any, Literal, Sequence, cast
@@ -25,6 +26,7 @@ from google.genai.types import (
     ToolListUnion,
     UploadFileConfig,
 )
+from google.oauth2 import service_account
 from typing_extensions import override
 
 from model_library import model_library_settings
@@ -95,31 +97,50 @@ class GoogleModel(LLM):
         ),
     ]
 
-
-
-
-
-
-
-
-
-
-        "gemini-2.5-flash-preview-09-2025": "global",
-        "gemini-2.5-flash-lite-preview-09-2025": "global",
+    def _get_default_api_key(self) -> str:
+        if not self.provider_config.use_vertex:
+            return model_library_settings.GOOGLE_API_KEY
+
+        return json.dumps(
+            {
+                "GCP_REGION": model_library_settings.GCP_REGION,
+                "GCP_PROJECT_ID": model_library_settings.GCP_PROJECT_ID,
+                "GCP_CREDS": model_library_settings.GCP_CREDS,
             }
-
-        if self.model_name in MODEL_REGION_OVERRIDES:
-            region = MODEL_REGION_OVERRIDES[self.model_name]
-
-        return Client(
-            vertexai=True,
-            project=model_library_settings.GCP_PROJECT_ID,
-            location=region,
-            # Credentials object is not typed, so we have to ignore the error
-            credentials=model_library_settings.GCP_CREDS,
-        )
+        )
 
-
+    @override
+    def get_client(self, api_key: str | None = None) -> Client:
+        if not self.has_client():
+            assert api_key
+            if self.provider_config.use_vertex:
+                # Gemini preview releases are only server from the global Vertex region after September 2025.
+                MODEL_REGION_OVERRIDES: dict[str, str] = {
+                    "gemini-2.5-flash-preview-09-2025": "global",
+                    "gemini-2.5-flash-lite-preview-09-2025": "global",
+                    "gemini-3-flash-preview": "global",
+                    "gemini-3-pro-preview": "global",
+                }
+
+                creds = json.loads(api_key)
+
+                region = creds["GCP_REGION"]
+                if self.model_name in MODEL_REGION_OVERRIDES:
+                    region = MODEL_REGION_OVERRIDES[self.model_name]
+
+                client = Client(
+                    vertexai=True,
+                    project=creds["GCP_PROJECT_ID"],
+                    location=region,
+                    credentials=service_account.Credentials.from_service_account_info(  # type: ignore
+                        json.loads(creds["GCP_CREDS"]),
+                        scopes=["https://www.googleapis.com/auth/cloud-platform"],
+                    ),
+                )
+            else:
+                client = Client(api_key=api_key)
+            self.assign_client(client)
+        return super().get_client()
 
     def __init__(
         self,
@@ -141,8 +162,6 @@ class GoogleModel(LLM):
             GoogleBatchMixin(self) if self.supports_batch else None
         )
 
-        self.client = self.get_client()
-
     @override
     async def parse_input(
         self,
@@ -260,7 +279,7 @@ class GoogleModel(LLM):
         )
 
         mime = f"image/{mime}" if type == "image" else mime  # TODO:
-        response: File = self.
+        response: File = self.get_client().files.upload(
            file=bytes, config=UploadFileConfig(mime_type=mime)
         )
         if not response.name:
@@ -338,7 +357,7 @@ class GoogleModel(LLM):
 
         metadata: GenerateContentResponseUsageMetadata | None = None
 
-        stream = await self.
+        stream = await self.get_client().aio.models.generate_content_stream(**body)
         contents: list[Content | None] = []
         finish_reason: FinishReason | None = None
 
@@ -437,7 +456,7 @@ class GoogleModel(LLM):
             tools=parsed_tools,
         )
 
-        response = await self.
+        response = await self.get_client().aio.models.count_tokens(
             model=self.model_name,
             contents=cast(Any, contents),
             config=config,
@@ -503,7 +522,7 @@ class GoogleModel(LLM):
         # Make the request with retry wrapper
         async def _query():
            try:
-                return await self.
+                return await self.get_client().aio.models.generate_content(**body)
             except (genai_errors.ServerError, genai_errors.UnknownApiResponseError):
                 raise ImmediateRetryException("Failed to connect to Google API")
 
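The largest hunk above moves Google client construction out of `__init__` into a lazy `get_client(api_key=...)`. On Vertex, the "api key" is actually the JSON blob built by `_get_default_api_key()` (carrying `GCP_REGION`, `GCP_PROJECT_ID`, and `GCP_CREDS`), and a small table pins certain Gemini preview models to the `global` region. A sketch of just the region-selection step; the override table is copied from the hunk, everything else is an assumption:

```python
import json

# Copied from the hunk above: preview models routed to the global Vertex region.
MODEL_REGION_OVERRIDES: dict[str, str] = {
    "gemini-2.5-flash-preview-09-2025": "global",
    "gemini-2.5-flash-lite-preview-09-2025": "global",
    "gemini-3-flash-preview": "global",
    "gemini-3-pro-preview": "global",
}


def pick_region(model_name: str, api_key_blob: str) -> str:
    """Decode the JSON 'api key' and apply the per-model region override."""
    creds = json.loads(api_key_blob)
    return MODEL_REGION_OVERRIDES.get(model_name, creds["GCP_REGION"])


blob = json.dumps(
    {"GCP_REGION": "us-central1", "GCP_PROJECT_ID": "my-proj", "GCP_CREDS": "{}"}
)
assert pick_region("gemini-3-pro-preview", blob) == "global"
assert pick_region("gemini-2.0-flash", blob) == "us-central1"
```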
model_library/providers/inception.py
CHANGED
@@ -1,13 +1,14 @@
 from typing import Literal
 
+from pydantic import SecretStr
+
 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
     DelegateOnly,
     LLMConfig,
 )
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults
 
 
 @register_provider("inception")
@@ -22,13 +23,12 @@ class MercuryModel(DelegateOnly):
         super().__init__(model_name, provider, config=config)
 
         # https://docs.inceptionlabs.ai/get-started/get-started#external-libraries-compatibility
-        self.
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-
-            api_key=model_library_settings.MERCURY_API_KEY,
+            delegate_config=DelegateConfig(
                 base_url="https://api.inceptionlabs.ai/v1/",
+                api_key=SecretStr(model_library_settings.MERCURY_API_KEY),
             ),
             use_completions=True,
+            delegate_provider="openai",
         )
model_library/providers/kimi.py
CHANGED
@@ -1,13 +1,16 @@
-from typing import Literal
+from typing import Any, Literal
+
+from typing_extensions import override
+
+from pydantic import SecretStr
 
 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
     DelegateOnly,
     LLMConfig,
 )
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults
 
 
 @register_provider("kimi")
@@ -22,13 +25,20 @@ class KimiModel(DelegateOnly):
         super().__init__(model_name, provider, config=config)
 
         # https://platform.moonshot.ai/docs/guide/migrating-from-openai-to-kimi#about-api-compatibility
-        self.
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-
-            api_key=model_library_settings.KIMI_API_KEY,
+            delegate_config=DelegateConfig(
                 base_url="https://api.moonshot.ai/v1/",
+                api_key=SecretStr(model_library_settings.KIMI_API_KEY),
             ),
             use_completions=True,
+            delegate_provider="openai",
         )
+
+    @override
+    def _get_extra_body(self) -> dict[str, Any]:
+        """
+        Build extra body parameters for Kimi-specific features.
+        see https://platform.moonshot.ai/docs/guide/kimi-k2-5-quickstart#parameters-differences-in-request-body
+        """
+        return {"thinking": {"type": "enabled" if self.reasoning else "disabled"}}
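The new `_get_extra_body` override shows how reasoning is toggled for Kimi: the Moonshot endpoint expects a `thinking` object in the request body rather than an OpenAI-style reasoning parameter. A trivially runnable sketch of that mapping, with the surrounding request plumbing assumed:

```python
from typing import Any


def kimi_extra_body(reasoning: bool) -> dict[str, Any]:
    """Extra request-body fields for Moonshot/Kimi, mirroring _get_extra_body()."""
    return {"thinking": {"type": "enabled" if reasoning else "disabled"}}


assert kimi_extra_body(True) == {"thinking": {"type": "enabled"}}
assert kimi_extra_body(False) == {"thinking": {"type": "disabled"}}
```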
model_library/providers/minimax.py
CHANGED
@@ -1,16 +1,17 @@
-from typing import Literal
+from typing import Literal, Sequence
+
+from pydantic import SecretStr
+from typing_extensions import override
 
 from model_library import model_library_settings
-from model_library.base import
-
-
+from model_library.base import (
+    DelegateConfig,
+    DelegateOnly,
+    InputItem,
+    LLMConfig,
+    ToolDefinition,
+)
 from model_library.register_models import register_provider
-from model_library.utils import default_httpx_client
-
-from anthropic import AsyncAnthropic
-
-from typing import Sequence
-from typing_extensions import override
 
 
 @register_provider("minimax")
@@ -24,16 +25,13 @@ class MinimaxModel(DelegateOnly):
     ):
         super().__init__(model_name, provider, config=config)
 
-        self.
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-
-            api_key=model_library_settings.MINIMAX_API_KEY,
+            delegate_config=DelegateConfig(
                 base_url="https://api.minimax.io/anthropic",
-
-                max_retries=1,
+                api_key=SecretStr(model_library_settings.MINIMAX_API_KEY),
             ),
+            delegate_provider="anthropic",
         )
 
         # minimax client shares anthropic's syntax
model_library/providers/mistral.py
CHANGED
@@ -3,7 +3,13 @@ import logging
 from collections.abc import Sequence
 from typing import Any, Literal
 
-from mistralai import
+from mistralai import (
+    AssistantMessage,
+    ContentChunk,
+    Mistral,
+    TextChunk,
+    ThinkChunk,
+)
 from mistralai.models.completionevent import CompletionEvent
 from mistralai.models.toolcall import ToolCall as MistralToolCall
 from mistralai.utils.eventstreaming import EventStreamAsync
@@ -40,16 +46,20 @@ from model_library.utils import default_httpx_client
 
 @register_provider("mistralai")
 class MistralModel(LLM):
-
+    @override
+    def _get_default_api_key(self) -> str:
+        return model_library_settings.MISTRAL_API_KEY
 
     @override
-    def get_client(self) -> Mistral:
-        if not
-
-
+    def get_client(self, api_key: str | None = None) -> Mistral:
+        if not self.has_client():
+            assert api_key
+            client = Mistral(
+                api_key=api_key,
                 async_client=default_httpx_client(),
             )
-
+            self.assign_client(client)
+        return super().get_client()
 
     def __init__(
         self,
@@ -198,12 +208,14 @@ class MistralModel(LLM):
 
         body: dict[str, Any] = {
             "model": self.model_name,
-            "max_tokens": self.max_tokens,
             "messages": messages,
             "prompt_mode": "reasoning" if self.reasoning else None,
             "tools": tools,
         }
 
+        if self.max_tokens:
+            body["max_tokens"] = self.max_tokens
+
         if self.supports_temperature:
             if self.temperature is not None:
                 body["temperature"] = self.temperature