model-library 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- model_library/base/base.py +237 -62
- model_library/base/delegate_only.py +86 -9
- model_library/base/input.py +10 -7
- model_library/base/output.py +48 -0
- model_library/base/utils.py +56 -7
- model_library/config/alibaba_models.yaml +44 -57
- model_library/config/all_models.json +253 -126
- model_library/config/kimi_models.yaml +30 -3
- model_library/config/openai_models.yaml +15 -23
- model_library/config/zai_models.yaml +24 -3
- model_library/exceptions.py +14 -77
- model_library/logging.py +6 -2
- model_library/providers/ai21labs.py +30 -14
- model_library/providers/alibaba.py +17 -8
- model_library/providers/amazon.py +119 -64
- model_library/providers/anthropic.py +184 -104
- model_library/providers/azure.py +22 -10
- model_library/providers/cohere.py +7 -7
- model_library/providers/deepseek.py +8 -8
- model_library/providers/fireworks.py +7 -8
- model_library/providers/google/batch.py +17 -13
- model_library/providers/google/google.py +130 -73
- model_library/providers/inception.py +7 -7
- model_library/providers/kimi.py +18 -8
- model_library/providers/minimax.py +30 -13
- model_library/providers/mistral.py +61 -35
- model_library/providers/openai.py +219 -93
- model_library/providers/openrouter.py +34 -0
- model_library/providers/perplexity.py +7 -7
- model_library/providers/together.py +7 -8
- model_library/providers/vals.py +16 -9
- model_library/providers/xai.py +157 -144
- model_library/providers/zai.py +38 -8
- model_library/register_models.py +4 -2
- model_library/registry_utils.py +39 -15
- model_library/retriers/__init__.py +0 -0
- model_library/retriers/backoff.py +73 -0
- model_library/retriers/base.py +225 -0
- model_library/retriers/token.py +427 -0
- model_library/retriers/utils.py +11 -0
- model_library/settings.py +1 -1
- model_library/utils.py +13 -35
- {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/METADATA +4 -3
- model_library-0.1.8.dist-info/RECORD +70 -0
- {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/WHEEL +1 -1
- model_library-0.1.6.dist-info/RECORD +0 -64
- {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/licenses/LICENSE +0 -0
- {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/top_level.txt +0 -0
model_library/providers/google/batch.py
CHANGED
```diff
@@ -2,8 +2,6 @@ import io
 import json
 from typing import TYPE_CHECKING, Any, Final, Sequence, cast

-from typing_extensions import override
-
 from google.genai.types import (
     BatchJob,
     Content,
@@ -11,6 +9,8 @@ from google.genai.types import (
     JobState,
     UploadFileConfig,
 )
+from typing_extensions import override
+
 from model_library.base import BatchResult, InputItem, LLMBatchMixin

 if TYPE_CHECKING:
@@ -24,16 +24,19 @@ from google.genai.types import (
     )


-def extract_text_from_json_response(response: dict[str, Any]) -> str:
+def extract_text_from_json_response(response: dict[str, Any]) -> tuple[str, str]:
     """Extract concatenated non-thought text from a JSON response structure."""
     # TODO: fix the typing we always ignore
     text = ""
+    reasoning = ""
     for candidate in response.get("candidates", []) or []:  # type: ignore
         content = (candidate or {}).get("content") or {}  # type: ignore
         for part in content.get("parts", []) or []:  # type: ignore
-            if
+            if part.get("thought", False):  # type: ignore
+                reasoning += part.get("text", "")  # type: ignore
+            else:
                 text += part.get("text", "")  # type: ignore
-    return text  # type: ignore
+    return text, reasoning  # type: ignore


 def parse_predictions_jsonl(jsonl: str) -> list[BatchResult]:
@@ -48,9 +51,10 @@ def parse_predictions_jsonl(jsonl: str) -> list[BatchResult]:
         custom_id = data.get("key", "unknown")
         if "response" in data:
             response = data["response"]
-            text = extract_text_from_json_response(response)
+            text, reasoning = extract_text_from_json_response(response)
             output = QueryResult()
             output.output_text = text
+            output.reasoning = reasoning
             if "usageMetadata" in response:
                 output.metadata.in_tokens = response["usageMetadata"].get(
                     "promptTokenCount", 0
@@ -144,7 +148,7 @@ class GoogleBatchMixin(LLMBatchMixin):
         **kwargs: object,
     ) -> dict[str, Any]:
         self._root.logger.debug(f"Creating batch request for custom_id: {custom_id}")
-        body = await self._root.
+        body = await self._root.build_body(input, tools=[], **kwargs)

         contents_any = body["contents"]
         serialized_contents: list[dict[str, Any]] = [
@@ -196,7 +200,7 @@ class GoogleBatchMixin(LLMBatchMixin):
             custom_id = labels.get("qa_pair_id", f"request-{i}")
             jsonl_lines.append(json.dumps({"key": custom_id, "request": request_data}))

-        batch_request_file = self._root.
+        batch_request_file = self._root.get_client().files.upload(
             file=io.StringIO("\n".join(jsonl_lines)),
             config=UploadFileConfig(mime_type="application/jsonl"),
         )
@@ -205,7 +209,7 @@ class GoogleBatchMixin(LLMBatchMixin):
             raise Exception("Failed to upload batch jsonl")

         try:
-            job: BatchJob = await self._root.
+            job: BatchJob = await self._root.get_client().aio.batches.create(
                 model=self._root.model_name,
                 src=batch_request_file.name,
                 config={"display_name": batch_name},
@@ -224,14 +228,14 @@ class GoogleBatchMixin(LLMBatchMixin):
     async def get_batch_results(self, batch_id: str) -> list[BatchResult]:
         self._root.logger.info(f"Retrieving batch results for {batch_id}")

-        job = await self._root.
+        job = await self._root.get_client().aio.batches.get(name=batch_id)

         results: list[BatchResult] = []

         if job.state == JobState.JOB_STATE_SUCCEEDED:
             if job.dest and job.dest.file_name:
                 results_file_name = job.dest.file_name
-                file_content = await self._root.
+                file_content = await self._root.get_client().aio.files.download(
                     file=results_file_name
                 )
                 decoded = file_content.decode("utf-8")
@@ -250,7 +254,7 @@ class GoogleBatchMixin(LLMBatchMixin):
     @override
     async def cancel_batch_request(self, batch_id: str):
         self._root.logger.info(f"Cancelling batch {batch_id}")
-        await self._root.
+        await self._root.get_client().aio.batches.cancel(name=batch_id)

     @override
     async def get_batch_progress(self, batch_id: str) -> int:
@@ -262,7 +266,7 @@ class GoogleBatchMixin(LLMBatchMixin):

         try:
             self._root.logger.debug(f"Checking batch status for {batch_id}")
-            job: BatchJob = await self._root.
+            job: BatchJob = await self._root.get_client().aio.batches.get(name=batch_id)
             state = job.state

             if not state:
```
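The return-type change above means batch parsing now keeps Gemini "thought" parts separate from the answer text, so `QueryResult.reasoning` can be populated for batch jobs. A minimal standalone sketch of the same split (the response dict below is an illustrative shape, not taken from the package):

```python
# Standalone sketch mirroring the updated extract_text_from_json_response behavior.
from typing import Any


def split_text_and_reasoning(response: dict[str, Any]) -> tuple[str, str]:
    text, reasoning = "", ""
    for candidate in response.get("candidates", []) or []:
        for part in (candidate.get("content") or {}).get("parts", []) or []:
            if part.get("thought", False):
                reasoning += part.get("text", "")  # thought parts -> reasoning
            else:
                text += part.get("text", "")  # everything else -> answer text
    return text, reasoning


# Illustrative response shape (not copied from the package):
sample = {
    "candidates": [
        {
            "content": {
                "parts": [
                    {"text": "Let me check the units first.", "thought": True},
                    {"text": "The answer is 42."},
                ]
            }
        }
    ]
}
print(split_text_and_reasoning(sample))
# ('The answer is 42.', 'Let me check the units first.')
```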
model_library/providers/google/google.py
CHANGED
```diff
@@ -1,13 +1,17 @@
 import base64
 import io
+import json
 import logging
+import uuid
 from typing import Any, Literal, Sequence, cast

 from google.genai import Client
 from google.genai import errors as genai_errors
 from google.genai.types import (
     Content,
+    CountTokensConfig,
     File,
+    FinishReason,
     FunctionDeclaration,
     GenerateContentConfig,
     GenerateContentResponse,
@@ -21,13 +25,14 @@ from google.genai.types import (
     Tool,
     ToolListUnion,
     UploadFileConfig,
-    FinishReason,
 )
+from google.oauth2 import service_account
 from typing_extensions import override

 from model_library import model_library_settings
 from model_library.base import (
     LLM,
+    FileBase,
     FileInput,
     FileWithBase64,
     FileWithId,
@@ -40,6 +45,8 @@ from model_library.base import (
     QueryResult,
     QueryResultCost,
     QueryResultMetadata,
+    RawInput,
+    RawResponse,
     TextInput,
     ToolBody,
     ToolCall,
@@ -54,8 +61,6 @@ from model_library.exceptions import (
 )
 from model_library.providers.google.batch import GoogleBatchMixin
 from model_library.register_models import register_provider
-from model_library.utils import normalize_tool_result
-import uuid


 def generate_tool_call_id(tool_name: str) -> str:
@@ -92,31 +97,50 @@ class GoogleModel(LLM):
         ),
     ]

-
-
-
-
-
-
-
-
-
-                "gemini-2.5-flash-preview-09-2025": "global",
-                "gemini-2.5-flash-lite-preview-09-2025": "global",
+    def _get_default_api_key(self) -> str:
+        if not self.provider_config.use_vertex:
+            return model_library_settings.GOOGLE_API_KEY
+
+        return json.dumps(
+            {
+                "GCP_REGION": model_library_settings.GCP_REGION,
+                "GCP_PROJECT_ID": model_library_settings.GCP_PROJECT_ID,
+                "GCP_CREDS": model_library_settings.GCP_CREDS,
             }
-
-            if self.model_name in MODEL_REGION_OVERRIDES:
-                region = MODEL_REGION_OVERRIDES[self.model_name]
-
-            return Client(
-                vertexai=True,
-                project=model_library_settings.GCP_PROJECT_ID,
-                location=region,
-                # Credentials object is not typed, so we have to ignore the error
-                credentials=model_library_settings.GCP_CREDS,
-            )
+        )

-
+    @override
+    def get_client(self, api_key: str | None = None) -> Client:
+        if not self.has_client():
+            assert api_key
+            if self.provider_config.use_vertex:
+                # Gemini preview releases are only server from the global Vertex region after September 2025.
+                MODEL_REGION_OVERRIDES: dict[str, str] = {
+                    "gemini-2.5-flash-preview-09-2025": "global",
+                    "gemini-2.5-flash-lite-preview-09-2025": "global",
+                    "gemini-3-flash-preview": "global",
+                    "gemini-3-pro-preview": "global",
+                }
+
+                creds = json.loads(api_key)
+
+                region = creds["GCP_REGION"]
+                if self.model_name in MODEL_REGION_OVERRIDES:
+                    region = MODEL_REGION_OVERRIDES[self.model_name]
+
+                client = Client(
+                    vertexai=True,
+                    project=creds["GCP_PROJECT_ID"],
+                    location=region,
+                    credentials=service_account.Credentials.from_service_account_info(  # type: ignore
+                        json.loads(creds["GCP_CREDS"]),
+                        scopes=["https://www.googleapis.com/auth/cloud-platform"],
+                    ),
+                )
+            else:
+                client = Client(api_key=api_key)
+            self.assign_client(client)
+        return super().get_client()

     def __init__(
         self,
@@ -138,71 +162,58 @@ class GoogleModel(LLM):
             GoogleBatchMixin(self) if self.supports_batch else None
         )

-        self.client = self.get_client()
-
     @override
     async def parse_input(
         self,
         input: Sequence[InputItem],
         **kwargs: Any,
     ) -> list[Content]:
-
-        parts: list[Part] = []
+        new_input: list[Content] = []

-
-            nonlocal parts
+        content_user: list[Part] = []

-
-
-            parts =
+        def flush_content_user():
+            if content_user:
+                new_input.append(Content(parts=content_user, role="user"))
+            content_user.clear()

         for item in input:
-
-
-
-
+            if isinstance(item, TextInput):
+                content_user.append(Part.from_text(text=item.text))
+                continue
+
+            if isinstance(item, FileBase):
+                parsed = await self.parse_file(item)
+                content_user.append(parsed)
+                continue

-
-
-                parts.append(part)
+            # non content user item
+            flush_content_user()

+            match item:
                 case ToolResult():
-
-
-                    parsed_input.append(
+                    # id check
+                    new_input.append(
                         Content(
                             role="function",
                             parts=[
                                 Part.from_function_response(
                                     name=item.tool_call.name,
-                                    response={"result":
+                                    response={"result": item.result},
                                 )
                             ],
                         )
                     )

-                case
-
-
-
-                        content0 = candidates[0].content
-                        if content0 is not None:
-                            parsed_input.append(content0)
-                        else:
-                            self.logger.debug(
-                                "GenerateContentResponse missing candidates; skipping"
-                            )
-
-                case Content():
-                    flush_parts()
-                    parsed_input.append(item)
+                case RawResponse():
+                    new_input.extend(item.response)
+                case RawInput():
+                    new_input.append(item.input)

-
-
+        # in case content user item is the last item
+        flush_content_user()

-
-
-        return parsed_input
+        return new_input

     @override
     async def parse_file(self, file: FileInput) -> Part:
@@ -268,7 +279,7 @@ class GoogleModel(LLM):
         )

         mime = f"image/{mime}" if type == "image" else mime  # TODO:
-        response: File = self.
+        response: File = self.get_client().files.upload(
             file=bytes, config=UploadFileConfig(mime_type=mime)
         )
         if not response.name:
@@ -284,7 +295,8 @@ class GoogleModel(LLM):
             mime=mime,
         )

-
+    @override
+    async def build_body(
         self,
         input: Sequence[InputItem],
         *,
@@ -337,7 +349,7 @@ class GoogleModel(LLM):
         query_logger: logging.Logger,
         **kwargs: object,
     ) -> QueryResult:
-        body: dict[str, Any] = await self.
+        body: dict[str, Any] = await self.build_body(input, tools=tools, **kwargs)

         text: str = ""
         reasoning: str = ""
@@ -345,7 +357,7 @@ class GoogleModel(LLM):

         metadata: GenerateContentResponseUsageMetadata | None = None

-        stream = await self.
+        stream = await self.get_client().aio.models.generate_content_stream(**body)
         contents: list[Content | None] = []
         finish_reason: FinishReason | None = None

@@ -395,7 +407,7 @@ class GoogleModel(LLM):
         result = QueryResult(
             output_text=text,
             reasoning=reasoning,
-            history=[*input,
+            history=[*input, RawResponse(response=contents)],
             tool_calls=tool_calls,
         )

@@ -410,6 +422,51 @@ class GoogleModel(LLM):
         )
         return result

+    @override
+    async def count_tokens(
+        self,
+        input: Sequence[InputItem],
+        *,
+        history: Sequence[InputItem] = [],
+        tools: list[ToolDefinition] = [],
+        **kwargs: object,
+    ) -> int:
+        """
+        Count the number of tokens using Google's native token counting API.
+        https://ai.google.dev/gemini-api/docs/tokens
+
+        Only Vertex AI supports system_instruction and tools in count_tokens.
+        For Gemini API, fall back to the base implementation.
+        TODO: implement token counting for non-Vertex models.
+        """
+        if not self.provider_config.use_vertex:
+            return await super().count_tokens(
+                input, history=history, tools=tools, **kwargs
+            )
+
+        input = [*history, *input]
+        if not input:
+            return 0
+
+        system_prompt = kwargs.pop("system_prompt", None)
+        contents = await self.parse_input(input, **kwargs)
+        parsed_tools = await self.parse_tools(tools) if tools else None
+        config = CountTokensConfig(
+            system_instruction=str(system_prompt) if system_prompt else None,
+            tools=parsed_tools,
+        )
+
+        response = await self.get_client().aio.models.count_tokens(
+            model=self.model_name,
+            contents=cast(Any, contents),
+            config=config,
+        )
+
+        if response.total_tokens is None:
+            raise ValueError("count_tokens returned None")
+
+        return response.total_tokens
+
     @override
     async def _calculate_cost(
         self,
@@ -446,7 +503,7 @@ class GoogleModel(LLM):
         **kwargs: object,
     ) -> PydanticT:
         # Create the request body with JSON schema
-        body: dict[str, Any] = await self.
+        body: dict[str, Any] = await self.build_body(input, tools=[], **kwargs)

         # Get the JSON schema from the Pydantic model
         json_schema = pydantic_model.model_json_schema()
@@ -465,7 +522,7 @@ class GoogleModel(LLM):
         # Make the request with retry wrapper
         async def _query():
             try:
-                return await self.
+                return await self.get_client().aio.models.generate_content(**body)
             except (genai_errors.ServerError, genai_errors.UnknownApiResponseError):
                 raise ImmediateRetryException("Failed to connect to Google API")

```
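For Vertex-backed models, the new `_get_default_api_key` packs region, project and service-account JSON into a single string, which `get_client` later unpacks, pinning the listed preview models to the `global` region. A rough illustration of that round-trip with made-up placeholder values (the real values come from `model_library_settings`):

```python
import json

# What _get_default_api_key packs for Vertex-backed models (placeholder values):
api_key = json.dumps(
    {
        "GCP_REGION": "us-central1",                     # placeholder
        "GCP_PROJECT_ID": "my-project",                  # placeholder
        "GCP_CREDS": '{"type": "service_account"}',      # serialized service-account JSON
    }
)

# What get_client unpacks before constructing the google-genai Client:
MODEL_REGION_OVERRIDES = {
    "gemini-2.5-flash-preview-09-2025": "global",
    "gemini-3-pro-preview": "global",
}

creds = json.loads(api_key)
model_name = "gemini-3-pro-preview"
region = MODEL_REGION_OVERRIDES.get(model_name, creds["GCP_REGION"])
print(region, creds["GCP_PROJECT_ID"])  # global my-project
```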
model_library/providers/inception.py
CHANGED
```diff
@@ -1,13 +1,14 @@
 from typing import Literal

+from pydantic import SecretStr
+
 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
     DelegateOnly,
     LLMConfig,
 )
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults


 @register_provider("inception")
@@ -22,13 +23,12 @@ class MercuryModel(DelegateOnly):
         super().__init__(model_name, provider, config=config)

         # https://docs.inceptionlabs.ai/get-started/get-started#external-libraries-compatibility
-        self.
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-
-            api_key=model_library_settings.MERCURY_API_KEY,
+            delegate_config=DelegateConfig(
                 base_url="https://api.inceptionlabs.ai/v1/",
+                api_key=SecretStr(model_library_settings.MERCURY_API_KEY),
             ),
             use_completions=True,
+            delegate_provider="openai",
         )
```
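As in the other delegate-based providers in this release, the API key handed to `DelegateConfig` is now wrapped in pydantic's `SecretStr`, which keeps the raw key out of reprs and logs until it is explicitly unwrapped. A quick illustration (the key below is a placeholder):

```python
from pydantic import SecretStr

api_key = SecretStr("sk-example-not-a-real-key")  # placeholder, not a real key

print(api_key)                     # **********  (masked when printed or logged)
print(repr(api_key))               # SecretStr('**********')
print(api_key.get_secret_value())  # sk-example-not-a-real-key
```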
model_library/providers/kimi.py
CHANGED
```diff
@@ -1,13 +1,16 @@
-from typing import Literal
+from typing import Any, Literal
+
+from typing_extensions import override
+
+from pydantic import SecretStr

 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
     DelegateOnly,
     LLMConfig,
 )
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults


 @register_provider("kimi")
@@ -22,13 +25,20 @@ class KimiModel(DelegateOnly):
         super().__init__(model_name, provider, config=config)

         # https://platform.moonshot.ai/docs/guide/migrating-from-openai-to-kimi#about-api-compatibility
-        self.
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-
-            api_key=model_library_settings.KIMI_API_KEY,
+            delegate_config=DelegateConfig(
                 base_url="https://api.moonshot.ai/v1/",
+                api_key=SecretStr(model_library_settings.KIMI_API_KEY),
             ),
             use_completions=True,
+            delegate_provider="openai",
         )
+
+    @override
+    def _get_extra_body(self) -> dict[str, Any]:
+        """
+        Build extra body parameters for Kimi-specific features.
+        see https://platform.moonshot.ai/docs/guide/kimi-k2-5-quickstart#parameters-differences-in-request-body
+        """
+        return {"thinking": {"type": "enabled" if self.reasoning else "disabled"}}
```
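The new `_get_extra_body` override toggles Kimi's thinking mode through a provider-specific request field. A small sketch of the payload it produces; how the delegate attaches it to the request is an assumption here (OpenAI-compatible clients usually pass such fields via `extra_body`):

```python
# Standalone sketch mirroring the dict returned by KimiModel._get_extra_body.
def build_kimi_extra_body(reasoning: bool) -> dict:
    return {"thinking": {"type": "enabled" if reasoning else "disabled"}}

print(build_kimi_extra_body(True))   # {'thinking': {'type': 'enabled'}}
print(build_kimi_extra_body(False))  # {'thinking': {'type': 'disabled'}}
```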
model_library/providers/minimax.py
CHANGED
```diff
@@ -1,12 +1,17 @@
-from typing import Literal
+from typing import Literal, Sequence
+
+from pydantic import SecretStr
+from typing_extensions import override

 from model_library import model_library_settings
-from model_library.base import
-
+from model_library.base import (
+    DelegateConfig,
+    DelegateOnly,
+    InputItem,
+    LLMConfig,
+    ToolDefinition,
+)
 from model_library.register_models import register_provider
-from model_library.utils import default_httpx_client
-
-from anthropic import AsyncAnthropic


 @register_provider("minimax")
@@ -20,14 +25,26 @@ class MinimaxModel(DelegateOnly):
     ):
         super().__init__(model_name, provider, config=config)

-        self.
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-
-            api_key=model_library_settings.MINIMAX_API_KEY,
+            delegate_config=DelegateConfig(
                 base_url="https://api.minimax.io/anthropic",
-
-                max_retries=1,
+                api_key=SecretStr(model_library_settings.MINIMAX_API_KEY),
             ),
+            delegate_provider="anthropic",
+        )
+
+    # minimax client shares anthropic's syntax
+    @override
+    async def count_tokens(
+        self,
+        input: Sequence[InputItem],
+        *,
+        history: Sequence[InputItem] = [],
+        tools: list[ToolDefinition] = [],
+        **kwargs: object,
+    ) -> int:
+        assert self.delegate
+        return await self.delegate.count_tokens(
+            input, history=history, tools=tools, **kwargs
         )
```