model-library 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registries, and is provided for informational purposes only.
- model_library/base/base.py +114 -12
- model_library/base/delegate_only.py +15 -1
- model_library/base/input.py +10 -7
- model_library/base/output.py +5 -0
- model_library/base/utils.py +21 -7
- model_library/config/all_models.json +92 -1
- model_library/config/fireworks_models.yaml +2 -0
- model_library/config/minimax_models.yaml +18 -0
- model_library/config/zai_models.yaml +14 -0
- model_library/exceptions.py +11 -0
- model_library/logging.py +6 -2
- model_library/providers/ai21labs.py +20 -6
- model_library/providers/amazon.py +72 -48
- model_library/providers/anthropic.py +138 -85
- model_library/providers/google/batch.py +3 -3
- model_library/providers/google/google.py +92 -46
- model_library/providers/minimax.py +29 -10
- model_library/providers/mistral.py +42 -26
- model_library/providers/openai.py +131 -77
- model_library/providers/vals.py +6 -3
- model_library/providers/xai.py +125 -113
- model_library/register_models.py +5 -3
- model_library/utils.py +0 -35
- {model_library-0.1.5.dist-info → model_library-0.1.7.dist-info}/METADATA +3 -3
- {model_library-0.1.5.dist-info → model_library-0.1.7.dist-info}/RECORD +28 -28
- {model_library-0.1.5.dist-info → model_library-0.1.7.dist-info}/WHEEL +0 -0
- {model_library-0.1.5.dist-info → model_library-0.1.7.dist-info}/licenses/LICENSE +0 -0
- {model_library-0.1.5.dist-info → model_library-0.1.7.dist-info}/top_level.txt +0 -0
model_library/base/base.py
CHANGED
@@ -6,7 +6,6 @@ from abc import ABC, abstractmethod
 from collections.abc import Awaitable
 from pprint import pformat
 from typing import (
-    TYPE_CHECKING,
     Any,
     Callable,
     Literal,
@@ -14,8 +13,10 @@ from typing import (
     TypeVar,
 )
 
+import tiktoken
 from pydantic import model_serializer
 from pydantic.main import BaseModel
+from tiktoken.core import Encoding
 from typing_extensions import override
 
 from model_library.base.batch import (
@@ -36,6 +37,7 @@ from model_library.base.output import (
 )
 from model_library.base.utils import (
     get_pretty_input_types,
+    serialize_for_tokenizing,
 )
 from model_library.exceptions import (
     ImmediateRetryException,
@@ -43,9 +45,6 @@ from model_library.exceptions import (
 )
 from model_library.utils import truncate_str
 
-if TYPE_CHECKING:
-    from model_library.providers.openai import OpenAIModel
-
 PydanticT = TypeVar("PydanticT", bound=BaseModel)
 
 
@@ -66,7 +65,7 @@ class LLMConfig(BaseModel):
     top_p: float | None = None
     top_k: int | None = None
     reasoning: bool = False
-    reasoning_effort: str | None = None
+    reasoning_effort: str | bool | None = None
     supports_images: bool = False
     supports_files: bool = False
     supports_videos: bool = False
@@ -110,7 +109,7 @@ class LLM(ABC):
         self.top_k: int | None = config.top_k
 
         self.reasoning: bool = config.reasoning
-        self.reasoning_effort: str | None = config.reasoning_effort
+        self.reasoning_effort: str | bool | None = config.reasoning_effort
 
         self.supports_files: bool = config.supports_files
         self.supports_videos: bool = config.supports_videos
@@ -120,7 +119,7 @@ class LLM(ABC):
         self.supports_tools: bool = config.supports_tools
 
         self.native: bool = config.native
-        self.delegate: "OpenAIModel | None" = None
+        self.delegate: "LLM | None" = None
         self.batch: LLMBatchMixin | None = None
 
         if config.provider_config:
@@ -198,11 +197,14 @@
         input: Sequence[InputItem],
         *,
         tools: list[ToolDefinition] = [],
+        query_logger: logging.Logger,
         **kwargs: object,
     ) -> QueryResult:
         if not self.delegate:
            raise Exception("Delegate not set")
-        return await self.delegate._query_impl(input, tools=tools, **kwargs)
+        return await self.delegate._query_impl(  # pyright: ignore[reportPrivateUsage]
+            input, tools=tools, query_logger=query_logger, **kwargs
+        )
 
     async def query(
         self,
@@ -213,6 +215,7 @@
         # for backwards compatibility
         files: list[FileInput] = [],
         images: list[FileInput] = [],
+        query_logger: logging.Logger | None = None,
        **kwargs: object,
     ) -> QueryResult:
         """
@@ -256,15 +259,18 @@
         input = [*history, *input]
 
         # unique logger for the query
-        query_id = uuid.uuid4().hex[:14]
-        query_logger = self.logger.getChild(f"query={query_id}")
+        if not query_logger:
+            query_id = uuid.uuid4().hex[:14]
+            query_logger = self.logger.getChild(f"query={query_id}")
 
         query_logger.info(
             "Query started:\n" + item_info + tool_info + f"--- kwargs: {short_kwargs}\n"
         )
 
         async def query_func() -> QueryResult:
-            return await self._query_impl(input, tools=tools, **kwargs)
+            return await self._query_impl(
+                input, tools=tools, query_logger=query_logger, **kwargs
+            )
 
         async def timed_query() -> tuple[QueryResult, float]:
             return await LLM.timer_wrapper(query_func)
@@ -361,7 +367,8 @@
         input: Sequence[InputItem],
         *,
         tools: list[ToolDefinition],
-
+        query_logger: logging.Logger,
+        **kwargs: object,
     ) -> QueryResult:
         """
         Query the model with input
@@ -375,6 +382,20 @@
         """
         ...
 
+    @abstractmethod
+    async def build_body(
+        self,
+        input: Sequence[InputItem],
+        *,
+        tools: list[ToolDefinition],
+        **kwargs: Any,
+    ) -> dict[str, Any]:
+        """
+        Builds the body of the request to the model provider
+        Calls parse_input
+        """
+        ...
+
     @abstractmethod
     async def parse_input(
         self,
@@ -417,6 +438,87 @@
         """Upload a file to the model provider"""
         ...
 
+    async def get_encoding(self) -> Encoding:
+        """Get the appropriate tokenizer"""
+
+        model = self.model_name.lower()
+
+        if any(x in model for x in ["gpt-4o", "o1", "o3", "gpt-4.1", "gpt-5"]):
+            return tiktoken.get_encoding("o200k_base")
+        elif "gpt-4" in model or "gpt-3.5" in model:
+            try:
+                return tiktoken.encoding_for_model(self.model_name)
+            except KeyError:
+                return tiktoken.get_encoding("cl100k_base")
+        elif "claude" in model:
+            return tiktoken.get_encoding("cl100k_base")
+        elif "gemini" in model:
+            return tiktoken.get_encoding("o200k_base")
+        elif "llama" in model or "mistral" in model:
+            return tiktoken.get_encoding("cl100k_base")
+        else:
+            return tiktoken.get_encoding("cl100k_base")
+
+    async def stringify_input(
+        self,
+        input: Sequence[InputItem],
+        *,
+        history: Sequence[InputItem] = [],
+        tools: list[ToolDefinition] = [],
+        **kwargs: object,
+    ) -> str:
+        input = [*history, *input]
+
+        system_prompt = kwargs.pop(
+            "system_prompt", ""
+        )  # TODO: refactor along with system prompt arg change
+
+        # special case if using a delegate
+        # don't inherit method override by default
+        if self.delegate:
+            parsed_input = await self.delegate.parse_input(input, **kwargs)
+            parsed_tools = await self.delegate.parse_tools(tools)
+        else:
+            parsed_input = await self.parse_input(input, **kwargs)
+            parsed_tools = await self.parse_tools(tools)
+
+        serialized_input = serialize_for_tokenizing(parsed_input)
+        serialized_tools = serialize_for_tokenizing(parsed_tools)
+
+        combined = f"{system_prompt}\n{serialized_input}\n{serialized_tools}"
+
+        return combined
+
+    async def count_tokens(
+        self,
+        input: Sequence[InputItem],
+        *,
+        history: Sequence[InputItem] = [],
+        tools: list[ToolDefinition] = [],
+        **kwargs: object,
+    ) -> int:
+        """
+        Count the number of tokens for a query.
+        Combines parsed input and tools, then tokenizes the result.
+        """
+
+        if not input and not history:
+            return 0
+
+        if self.delegate:
+            encoding = await self.delegate.get_encoding()
+        else:
+            encoding = await self.get_encoding()
+        self.logger.debug(f"Token Count Encoding: {encoding}")
+
+        string_input = await self.stringify_input(
+            input, history=history, tools=tools, **kwargs
+        )
+
+        count = len(encoding.encode(string_input, disallowed_special=()))
+        self.logger.debug(f"Combined Token Count Input: {count}")
+        return count
+
     async def query_json(
         self,
         input: Sequence[InputItem],
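The new token-counting helpers (`get_encoding`, `stringify_input`, `count_tokens`) give every `LLM` a provider-agnostic token estimate. A minimal usage sketch, assuming `model` is a concrete `LLM` subclass instance obtained elsewhere (construction is not shown in this diff):

```python
import asyncio

from model_library.base.input import TextInput


async def estimate(model) -> int:
    # `model` stands in for any concrete LLM subclass instance (assumed).
    items = [TextInput(text="Summarize the 0.1.7 release notes.")]
    # count_tokens() parses input and tools, serializes them with
    # serialize_for_tokenizing(), and encodes the result with the tiktoken
    # encoding chosen by get_encoding() (o200k_base, cl100k_base, ...).
    return await model.count_tokens(items, system_prompt="Be concise.")


# asyncio.run(estimate(some_model))  # returns an int token count
```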
model_library/base/delegate_only.py
CHANGED
@@ -1,4 +1,5 @@
 import io
+import logging
 from typing import Any, Literal, Sequence
 
 from typing_extensions import override
@@ -48,11 +49,24 @@ class DelegateOnly(LLM):
         input: Sequence[InputItem],
         *,
         tools: list[ToolDefinition],
+        query_logger: logging.Logger,
         **kwargs: object,
     ) -> QueryResult:
         assert self.delegate
 
-        return await self.delegate_query(input, tools=tools, **kwargs)
+        return await self.delegate_query(
+            input, tools=tools, query_logger=query_logger, **kwargs
+        )
+
+    @override
+    async def build_body(
+        self,
+        input: Sequence[InputItem],
+        *,
+        tools: list[ToolDefinition],
+        **kwargs: object,
+    ) -> dict[str, Any]:
+        raise DelegateOnlyException()
 
     @override
     async def parse_input(
model_library/base/input.py
CHANGED
@@ -74,8 +74,6 @@ class ToolCall(BaseModel):
     --- INPUT ---
     """
 
-RawResponse = Any
-
 
 class ToolInput(BaseModel):
     tools: list[ToolDefinition] = []
@@ -90,11 +88,16 @@ class TextInput(BaseModel):
     text: str
 
 
-
-
-
+class RawResponse(BaseModel):
+    # used to store a received response
+    response: Any
+
+
+class RawInput(BaseModel):
+    # used to pass in anything provider specific (e.g. a mock conversation)
+    input: Any
 
 
 InputItem = (
-    TextInput | FileInput | ToolResult |
-)  # input item can either be a prompt, a file (image or file), a tool call result,
+    TextInput | FileInput | ToolResult | RawInput | RawResponse
+)  # input item can either be a prompt, a file (image or file), a tool call result, a previous response, or raw input
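`RawResponse` and `RawInput` replace the old `RawResponse = Any` alias with explicit pydantic wrappers, so provider payloads can travel through the `InputItem` union. A short sketch; the dict payloads below are made-up, provider-specific shapes:

```python
from model_library.base.input import RawInput, RawResponse, TextInput

# RawResponse stores a response the provider already returned (e.g. for replaying
# history); RawInput passes provider-specific data straight through.
history = [
    TextInput(text="What is the capital of France?"),
    RawResponse(response={"role": "assistant", "content": "Paris."}),  # payload shape is illustrative
]
next_turn = [
    TextInput(text="And its population?"),
    RawInput(input={"mock_conversation_id": "abc123"}),  # hypothetical provider payload
]
```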
model_library/base/output.py
CHANGED
@@ -24,6 +24,11 @@ class Citation(BaseModel):
     index: int | None = None
     container_id: str | None = None
 
+    @override
+    def __repr__(self):
+        attrs = vars(self).copy()
+        return f"{self.__class__.__name__}(\n{pformat(attrs, indent=2)}\n)"
+
 
 class QueryResultExtras(BaseModel):
     citations: list[Citation] = Field(default_factory=list)
model_library/base/utils.py
CHANGED
@@ -1,18 +1,34 @@
-
+import json
+from typing import Any, Sequence, TypeVar
+
+from pydantic import BaseModel
 
 from model_library.base.input import (
     FileBase,
     InputItem,
-
+    RawInput,
+    RawResponse,
     TextInput,
     ToolResult,
 )
 from model_library.utils import truncate_str
-from pydantic import BaseModel
 
 T = TypeVar("T", bound=BaseModel)
 
 
+def serialize_for_tokenizing(content: Any) -> str:
+    """
+    Serialize parsed content into a string for tokenization
+    """
+    parts: list[str] = []
+    if content:
+        if isinstance(content, str):
+            parts.append(content)
+        else:
+            parts.append(json.dumps(content, default=str))
+    return "\n".join(parts)
+
+
 def add_optional(
     a: int | float | T | None, b: int | float | T | None
 ) -> int | float | T | None:
@@ -54,11 +70,9 @@ def get_pretty_input_types(input: Sequence["InputItem"], verbose: bool = False)
             return repr(item)
         case ToolResult():
             return repr(item)
-        case
-            item = cast(RawInputItem, item)
+        case RawInput():
             return repr(item)
-        case
-            # RawResponse
+        case RawResponse():
             return repr(item)
 
     processed_items = [f" {process_item(item)}" for item in input]
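`serialize_for_tokenizing` is the bridge between parsed provider payloads and the tokenizer: strings pass through unchanged and everything else is JSON-encoded. For example:

```python
from model_library.base.utils import serialize_for_tokenizing

print(serialize_for_tokenizing("plain text"))          # -> plain text
print(serialize_for_tokenizing([{"role": "user", "content": "hi"}]))
# -> [{"role": "user", "content": "hi"}]   (non-strings go through json.dumps with default=str)
print(serialize_for_tokenizing(None))                  # -> "" (falsy content yields an empty string)
```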
model_library/config/all_models.json
CHANGED
@@ -1,4 +1,94 @@
 {
+  "minimax/MiniMax-M2.1": {
+    "company": "MiniMax",
+    "label": "MiniMax-M2.1",
+    "description": null,
+    "release_date": "2025-12-23",
+    "open_source": true,
+    "documentation_url": "https://platform.minimax.io/docs",
+    "properties": {
+      "context_window": 204800,
+      "max_tokens": 131000,
+      "training_cutoff": null,
+      "reasoning_model": true
+    },
+    "supports": {
+      "images": false,
+      "files": false,
+      "temperature": true,
+      "tools": true
+    },
+    "metadata": {
+      "deprecated": false,
+      "available_for_everyone": true,
+      "available_as_evaluator": false,
+      "ignored_for_cost": false
+    },
+    "provider_properties": {},
+    "costs_per_million_token": {
+      "input": 0.3,
+      "output": 1.2,
+      "cache": {
+        "read": 0.03,
+        "write": 0.375,
+        "write_markup": 1.0
+      }
+    },
+    "alternative_keys": [],
+    "default_parameters": {
+      "temperature": 1.0,
+      "top_p": 0.95
+    },
+    "provider_endpoint": "MiniMax-M2.1",
+    "provider_name": "minimax",
+    "full_key": "minimax/MiniMax-M2.1",
+    "slug": "minimax_MiniMax-M2.1"
+  },
+  "zai/glm-4.7": {
+    "company": "zAI",
+    "label": "GLM 4.7",
+    "description": "Latest model from ZAI",
+    "release_date": "2025-12-22",
+    "open_source": true,
+    "documentation_url": "https://docs.z.ai/",
+    "properties": {
+      "context_window": 200000,
+      "max_tokens": 128000,
+      "training_cutoff": null,
+      "reasoning_model": true
+    },
+    "supports": {
+      "images": false,
+      "files": false,
+      "temperature": true,
+      "tools": true
+    },
+    "metadata": {
+      "deprecated": false,
+      "available_for_everyone": true,
+      "available_as_evaluator": false,
+      "ignored_for_cost": false
+    },
+    "provider_properties": {},
+    "costs_per_million_token": {
+      "input": 0.6,
+      "output": 2.2,
+      "cache": {
+        "read": 0.11,
+        "read_discount": 1.0,
+        "write_markup": 1.0
+      }
+    },
+    "alternative_keys": [],
+    "default_parameters": {
+      "temperature": 1.0,
+      "top_p": 1.0
+    },
+    "provider_endpoint": "glm-4.7",
+    "provider_name": "zai",
+    "full_key": "zai/glm-4.7",
+    "slug": "zai_glm-4.7"
+  },
   "google/gemini-3-flash-preview": {
     "company": "Google",
     "label": "Gemini 3 Flash (12/25)",
@@ -504,7 +594,8 @@
       }
     ],
     "default_parameters": {
-      "temperature": 1.0
+      "temperature": 1.0,
+      "reasoning_effort": "none"
     },
     "provider_endpoint": "deepseek-v3p2",
     "provider_name": "fireworks",
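Each entry records per-million-token pricing, so request cost can be estimated straight from the config. A sketch using the new MiniMax entry (the file path is assumed from the wheel layout above):

```python
import json

# Hedged sketch: estimate a request's cost from an all_models.json entry.
with open("model_library/config/all_models.json") as f:
    models = json.load(f)

costs = models["minimax/MiniMax-M2.1"]["costs_per_million_token"]
in_tokens, out_tokens = 12_000, 800
usd = in_tokens / 1e6 * costs["input"] + out_tokens / 1e6 * costs["output"]
print(f"~${usd:.5f}")  # 12k input at $0.30/M plus 800 output at $1.20/M ≈ $0.0046
```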
model_library/config/minimax_models.yaml
CHANGED
@@ -16,6 +16,24 @@ base-config:
 
 minimax-m2-models:
 
+  minimax/MiniMax-M2.1:
+    label: MiniMax-M2.1
+    release_date: 2025-12-23
+    properties:
+      context_window: 204_800
+      max_tokens: 131_000
+      reasoning_model: true
+      training_cutoff: null
+    default_parameters:
+      temperature: 1.0
+      top_p: 0.95
+    costs_per_million_token:
+      input: 0.30
+      output: 1.20
+      cache:
+        read: 0.03
+        write: 0.375
+
   minimax/MiniMax-M2:
     label: MiniMax-M2
     description: MiniMax-M2 is a cost-efficient open-source model optimized for agentic applications and coding in particular.
model_library/config/zai_models.yaml
CHANGED
@@ -18,6 +18,20 @@ base-config:
         write_markup: 1
 
 zai-models:
+  zai/glm-4.7:
+    label: GLM 4.7
+    description: "Latest model from ZAI"
+    release_date: 2025-12-22
+    properties:
+      context_window: 200_000
+      max_tokens: 128_000
+    costs_per_million_token:
+      input: 0.6
+      output: 2.2
+      cache:
+        read: 0.11
+    default_parameters:
+      temperature: 1
   zai/glm-4.5:
     label: GLM 4.5
     description: "z.AI old model"
model_library/exceptions.py
CHANGED
@@ -146,6 +146,17 @@ class BadInputError(Exception):
         super().__init__(message or BadInputError.DEFAULT_MESSAGE)
 
 
+class NoMatchingToolCallError(Exception):
+    """
+    Raised when a tool call result is provided with no matching tool call
+    """
+
+    DEFAULT_MESSAGE: str = "Tool call result provided with no matching tool call"
+
+    def __init__(self, message: str | None = None):
+        super().__init__(message or NoMatchingToolCallError.DEFAULT_MESSAGE)
+
+
 # Add more retriable exceptions as needed
 # Providers that don't have an explicit rate limit error are handled manually
 # by wrapping errored Http/gRPC requests with a BackoffRetryException
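The new `NoMatchingToolCallError` gives providers a dedicated exception for orphaned tool results. A hedged sketch of how calling code might use it (the `issued_ids` bookkeeping is illustrative, not part of the library):

```python
from model_library.exceptions import NoMatchingToolCallError


def check_tool_result(call_id: str, issued_ids: set[str]) -> None:
    # Illustrative check: reject ToolResults that reference a tool call
    # the model never issued.
    if call_id not in issued_ids:
        raise NoMatchingToolCallError(f"no pending tool call with id {call_id!r}")


try:
    check_tool_result("call_42", issued_ids=set())
except NoMatchingToolCallError as exc:
    print(exc)  # prints the custom message; DEFAULT_MESSAGE is used when none is given
```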
model_library/logging.py
CHANGED
@@ -6,7 +6,11 @@ from rich.logging import RichHandler
 _llm_logger = logging.getLogger("llm")
 
 
-def set_logging(enable: bool = True, handler: logging.Handler | None = None):
+def set_logging(
+    enable: bool = True,
+    level: int = logging.INFO,
+    handler: logging.Handler | None = None,
+):
     """
     Sets up logging for the model library
 
@@ -15,7 +19,7 @@ def set_logging(enable: bool = True, handler: logging.Handler | None = None):
         handler (logging.Handler, optional): A custom logging handler. Defaults to RichHandler.
     """
     if enable:
-        _llm_logger.setLevel(logging.INFO)
+        _llm_logger.setLevel(level)
     else:
         _llm_logger.setLevel(logging.CRITICAL)
 
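`set_logging` now takes an explicit `level`, so verbosity is configurable rather than fixed when logging is enabled:

```python
import logging

from model_library.logging import set_logging

set_logging(level=logging.DEBUG)     # verbose output via the default RichHandler
set_logging(enable=False)            # silences the "llm" logger (CRITICAL only)
set_logging(level=logging.WARNING, handler=logging.StreamHandler())  # custom handler
```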
model_library/providers/ai21labs.py
CHANGED
@@ -1,4 +1,5 @@
 import io
+import logging
 from typing import Any, Literal, Sequence
 
 from ai21 import AsyncAI21Client
@@ -21,6 +22,7 @@ from model_library.base import (
     ToolDefinition,
     ToolResult,
 )
+from model_library.base.input import RawResponse
 from model_library.exceptions import (
     BadInputError,
     MaxOutputTokensExceededError,
@@ -64,8 +66,6 @@ class AI21LabsModel(LLM):
             match item:
                 case TextInput():
                     new_input.append(ChatMessage(role="user", content=item.text))
-                case AssistantMessage():
-                    new_input.append(item)
                 case ToolResult():
                     new_input.append(
                         ToolMessage(
@@ -73,7 +73,9 @@ class AI21LabsModel(LLM):
                             content=item.result,
                             tool_call_id=item.tool_call.id,
                         )
-                    )
+                    )  # TODO: tool calling metadata and test
+                case RawResponse():
+                    new_input.append(item.response)
                 case _:
                     raise BadInputError("Unsupported input type")
         return new_input
@@ -132,13 +134,13 @@ class AI21LabsModel(LLM):
         raise NotImplementedError()
 
     @override
-    async def _query_impl(
+    async def build_body(
         self,
         input: Sequence[InputItem],
         *,
         tools: list[ToolDefinition],
         **kwargs: object,
-    ) -> QueryResult:
+    ) -> dict[str, Any]:
         messages: list[ChatMessage] = []
         if "system_prompt" in kwargs:
             messages.append(
@@ -160,6 +162,18 @@ class AI21LabsModel(LLM):
             body["top_p"] = self.top_p
 
         body.update(kwargs)
+        return body
+
+    @override
+    async def _query_impl(
+        self,
+        input: Sequence[InputItem],
+        *,
+        tools: list[ToolDefinition],
+        query_logger: logging.Logger,
+        **kwargs: object,
+    ) -> QueryResult:
+        body = await self.build_body(input, tools=tools, **kwargs)
 
         response: ChatCompletionResponse = (
             await self.get_client().chat.completions.create(**body, stream=False)  # pyright: ignore[reportAny, reportUnknownMemberType]
@@ -184,7 +198,7 @@ class AI21LabsModel(LLM):
 
         output = QueryResult(
             output_text=choice.message.content,
-            history=[*input, choice.message],
+            history=[*input, RawResponse(response=choice.message)],
             metadata=QueryResultMetadata(
                 in_tokens=response.usage.prompt_tokens,
                 out_tokens=response.usage.completion_tokens,