model-library 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- model_library/base/__init__.py +7 -0
- model_library/{base.py → base/base.py} +58 -429
- model_library/base/batch.py +121 -0
- model_library/base/delegate_only.py +94 -0
- model_library/base/input.py +100 -0
- model_library/base/output.py +229 -0
- model_library/base/utils.py +43 -0
- model_library/config/ai21labs_models.yaml +1 -0
- model_library/config/all_models.json +461 -36
- model_library/config/anthropic_models.yaml +30 -3
- model_library/config/deepseek_models.yaml +3 -1
- model_library/config/google_models.yaml +49 -0
- model_library/config/openai_models.yaml +43 -4
- model_library/config/together_models.yaml +1 -0
- model_library/config/xai_models.yaml +63 -3
- model_library/exceptions.py +8 -2
- model_library/file_utils.py +1 -1
- model_library/providers/__init__.py +0 -0
- model_library/providers/ai21labs.py +2 -0
- model_library/providers/alibaba.py +16 -78
- model_library/providers/amazon.py +3 -0
- model_library/providers/anthropic.py +215 -8
- model_library/providers/azure.py +2 -0
- model_library/providers/cohere.py +14 -80
- model_library/providers/deepseek.py +14 -90
- model_library/providers/fireworks.py +17 -81
- model_library/providers/google/google.py +55 -47
- model_library/providers/inception.py +15 -83
- model_library/providers/kimi.py +15 -83
- model_library/providers/mistral.py +2 -0
- model_library/providers/openai.py +10 -2
- model_library/providers/perplexity.py +12 -79
- model_library/providers/together.py +19 -210
- model_library/providers/vals.py +2 -0
- model_library/providers/xai.py +2 -0
- model_library/providers/zai.py +15 -83
- model_library/register_models.py +75 -57
- model_library/registry_utils.py +5 -5
- model_library/utils.py +3 -28
- {model_library-0.1.1.dist-info → model_library-0.1.3.dist-info}/METADATA +2 -3
- model_library-0.1.3.dist-info/RECORD +61 -0
- model_library-0.1.1.dist-info/RECORD +0 -54
- {model_library-0.1.1.dist-info → model_library-0.1.3.dist-info}/WHEEL +0 -0
- {model_library-0.1.1.dist-info → model_library-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {model_library-0.1.1.dist-info → model_library-0.1.3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
import io
|
|
2
|
+
from typing import Any, Literal, Sequence
|
|
3
|
+
|
|
4
|
+
from typing_extensions import override
|
|
5
|
+
|
|
6
|
+
from model_library.base import (
|
|
7
|
+
LLM,
|
|
8
|
+
FileInput,
|
|
9
|
+
FileWithId,
|
|
10
|
+
InputItem,
|
|
11
|
+
LLMConfig,
|
|
12
|
+
QueryResult,
|
|
13
|
+
ToolDefinition,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class DelegateOnlyException(Exception):
|
|
18
|
+
"""
|
|
19
|
+
Raised when native model functionality is performed on a
|
|
20
|
+
delegate-only model.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
DEFAULT_MESSAGE: str = "This model supports only delegate-only functionality. Only the query() method should be used."
|
|
24
|
+
|
|
25
|
+
def __init__(self, message: str | None = None):
|
|
26
|
+
super().__init__(message or DelegateOnlyException.DEFAULT_MESSAGE)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class DelegateOnly(LLM):
    """LLM that answers queries only through its delegate.

    The constructor forces ``config.native`` to False. ``_query_impl``
    forwards to ``delegate_query``; every other hook overridden here
    (client creation, input/image/file/tool parsing, file upload) raises
    DelegateOnlyException.
    """

    def __init__(
        self,
        model_name: str,
        provider: str,
        *,
        config: LLMConfig | None = None,
    ):
        # NOTE(review): when the caller passes a config, it is mutated in
        # place (native is set to False on the caller's object).
        settings = config if config else LLMConfig()
        settings.native = False
        super().__init__(model_name, provider, config=settings)

    @override
    def get_client(self) -> None:
        """Always raises: no native client is exposed."""
        raise DelegateOnlyException()

    @override
    async def _query_impl(
        self,
        input: Sequence[InputItem],
        *,
        tools: list[ToolDefinition],
        **kwargs: object,
    ) -> QueryResult:
        """Forward the query, tools and extra kwargs to ``delegate_query``."""
        assert self.delegate

        result = await self.delegate_query(input, tools=tools, **kwargs)
        return result

    @override
    async def parse_input(
        self,
        input: Sequence[InputItem],
        **kwargs: Any,
    ) -> Any:
        """Always raises DelegateOnlyException."""
        raise DelegateOnlyException()

    @override
    async def parse_image(
        self,
        image: FileInput,
    ) -> Any:
        """Always raises DelegateOnlyException."""
        raise DelegateOnlyException()

    @override
    async def parse_file(
        self,
        file: FileInput,
    ) -> Any:
        """Always raises DelegateOnlyException."""
        raise DelegateOnlyException()

    @override
    async def parse_tools(
        self,
        tools: list[ToolDefinition],
    ) -> Any:
        """Always raises DelegateOnlyException."""
        raise DelegateOnlyException()

    @override
    async def upload_file(
        self,
        name: str,
        mime: str,
        bytes: io.BytesIO,
        type: Literal["image", "file"] = "file",
    ) -> FileWithId:
        """Always raises DelegateOnlyException."""
        raise DelegateOnlyException()
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
from pprint import pformat
|
|
2
|
+
from typing import Annotated, Any, Literal
|
|
3
|
+
|
|
4
|
+
from pydantic import BaseModel, Field
|
|
5
|
+
from typing_extensions import override
|
|
6
|
+
|
|
7
|
+
from model_library.utils import truncate_str
|
|
8
|
+
|
|
9
|
+
"""
|
|
10
|
+
--- FILES ---
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class FileBase(BaseModel):
    """Fields shared by every file input: kind, name and MIME type."""

    type: Literal["image", "file"]
    name: str
    mime: str

    @override
    def __repr__(self):
        # Shorten a "base64" attribute (if present) so the repr stays readable;
        # all other attributes are shown unchanged, in their original order.
        shown = {
            key: truncate_str(value) if key == "base64" else value
            for key, value in vars(self).items()
        }
        body = pformat(shown, indent=2)
        return f"{self.__class__.__name__}(\n{body}\n)"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class FileWithBase64(FileBase):
    """File input whose content is carried in the ``base64`` string field."""

    # Fixed tag; FileInput uses "append_type" as its discriminator.
    append_type: Literal["base64"] = "base64"
    base64: str
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class FileWithUrl(FileBase):
    """File input referenced through the ``url`` string field."""

    # Fixed tag; FileInput uses "append_type" as its discriminator.
    append_type: Literal["url"] = "url"
    url: str
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class FileWithId(FileBase):
    """File input referenced through the ``file_id`` string field."""

    # Fixed tag; FileInput uses "append_type" as its discriminator.
    append_type: Literal["file_id"] = "file_id"
    file_id: str
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# Union of the three file representations, discriminated on the
# "append_type" field each of them declares.
FileInput = Annotated[
    FileWithBase64 | FileWithUrl | FileWithId,
    Field(discriminator="append_type"),
]
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
"""
|
|
49
|
+
--- TOOLS ---
|
|
50
|
+
"""
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class ToolBody(BaseModel):
    """Structured description of a tool: name, description and parameters."""

    name: str
    description: str
    properties: dict[str, Any]
    required: list[str]
    # Optional extra settings; defaults to an empty dict.
    kwargs: dict[str, Any] = {}
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class ToolDefinition(BaseModel):
    """A named tool whose body is either a ToolBody or any other value."""

    name: str  # acts as a key
    body: ToolBody | Any
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class ToolCall(BaseModel):
    """A tool invocation: identifiers, tool name and arguments."""

    id: str
    call_id: str | None = None
    name: str
    # Arguments are either a dict or a plain string.
    args: dict[str, Any] | str
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
"""
|
|
74
|
+
--- INPUT ---
|
|
75
|
+
"""
|
|
76
|
+
|
|
77
|
+
# Untyped alias; appears in InputItem as the "previous response" variant.
RawResponse = Any
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class ToolInput(BaseModel):
    """Container for a list of tool definitions (empty by default)."""

    tools: list[ToolDefinition] = []
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class ToolResult(BaseModel):
    """Pairs a ToolCall with the (untyped) result it produced."""

    tool_call: ToolCall
    result: Any
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class TextInput(BaseModel):
    """Plain text input item."""

    text: str
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
RawInputItem = dict[
|
|
94
|
+
str, Any
|
|
95
|
+
] # to pass in, for example, a mock conversation with {"role": "user", "content": "Hello"}
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
# input item can either be a prompt, a file (image or file), a tool call
# result, raw input, or a previous response
InputItem = TextInput | FileInput | ToolResult | RawInputItem | RawResponse
|
|
@@ -0,0 +1,229 @@
|
|
|
1
|
+
"""
|
|
2
|
+
--- OUTPUT ---
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from pprint import pformat
|
|
6
|
+
from typing import Any, Mapping, Sequence, cast
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel, Field, computed_field, field_validator
|
|
9
|
+
from typing_extensions import override
|
|
10
|
+
|
|
11
|
+
from model_library.base.input import InputItem, ToolCall
|
|
12
|
+
from model_library.base.utils import (
|
|
13
|
+
sum_optional,
|
|
14
|
+
)
|
|
15
|
+
from model_library.utils import truncate_str
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class Citation(BaseModel):
    """Citation record; every field is optional and defaults to None."""

    type: str | None = None
    title: str | None = None
    url: str | None = None
    start_index: int | None = None
    end_index: int | None = None
    file_id: str | None = None
    filename: str | None = None
    index: int | None = None
    container_id: str | None = None
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class QueryResultExtras(BaseModel):
    """Extra query output; currently a list of citations (empty by default)."""

    citations: list[Citation] = Field(default_factory=list)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class QueryResultCost(BaseModel):
    """
    Cost information for a query
    Includes total cost and a structured breakdown.
    """

    input: float
    output: float
    reasoning: float | None = None
    cache_read: float | None = None
    cache_write: float | None = None

    @computed_field
    @property
    def total(self) -> float:
        """Sum of all cost components; None and zero components are skipped."""
        components = (
            self.input,
            self.output,
            self.reasoning,
            self.cache_read,
            self.cache_write,
        )
        return sum(part for part in components if part)

    @computed_field
    @property
    def total_input(self) -> float:
        """Sum of input, cache-read and cache-write costs (None/zero skipped)."""
        components = (self.input, self.cache_read, self.cache_write)
        return sum(part for part in components if part)

    @computed_field
    @property
    def total_output(self) -> float:
        """Sum of output and reasoning costs (None/zero skipped)."""
        components = (self.output, self.reasoning)
        return sum(part for part in components if part)

    @override
    def __repr__(self):
        # Totals under 1 are shown in cents (3 decimals), otherwise in dollars
        # (2 decimals); the same unit is used for every component.
        in_cents = self.total < 1

        def show(amount: float | None):
            if amount is None:
                return None
            if in_cents:
                return f"{amount * 100:.3f} cents"
            return f"${amount:.2f}"

        labelled = (
            ("uncached input", self.input),
            ("output", self.output),
            ("reasoning", self.reasoning),
            ("cache_read", self.cache_read),
            ("cache_write", self.cache_write),
        )
        breakdown = " | ".join(f"{label}: {show(amount)}" for label, amount in labelled)
        return f"{show(self.total)} ({breakdown})"
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class QueryResultMetadata(BaseModel):
    """
    Metadata for a query: token usage and timing.
    """

    cost: QueryResultCost | None = None  # set post query
    duration_seconds: float | None = None  # set post query
    in_tokens: int = 0
    out_tokens: int = 0
    reasoning_tokens: int | None = None
    cache_read_tokens: int | None = None
    cache_write_tokens: int | None = None

    @property
    def default_duration_seconds(self) -> float:
        """Duration in seconds, with an unset duration reported as 0."""
        return self.duration_seconds if self.duration_seconds else 0

    @computed_field
    @property
    def total_input_tokens(self) -> int:
        """Input plus cache-read plus cache-write tokens (None/zero skipped)."""
        counts = (self.in_tokens, self.cache_read_tokens, self.cache_write_tokens)
        return sum(count for count in counts if count)

    @computed_field
    @property
    def total_output_tokens(self) -> int:
        """Output plus reasoning tokens (None/zero skipped)."""
        counts = (self.out_tokens, self.reasoning_tokens)
        return sum(count for count in counts if count)

    def __add__(self, other: "QueryResultMetadata") -> "QueryResultMetadata":
        """Return a new metadata object with token counts and durations summed.

        Optional token counts go through ``sum_optional``. ``cost`` is not
        passed on, so the result has ``cost=None``; an unset duration counts
        as 0.
        """
        reasoning = sum_optional(self.reasoning_tokens, other.reasoning_tokens)
        cache_read = sum_optional(self.cache_read_tokens, other.cache_read_tokens)
        cache_write = sum_optional(self.cache_write_tokens, other.cache_write_tokens)
        elapsed = self.default_duration_seconds + other.default_duration_seconds
        return QueryResultMetadata(
            in_tokens=self.in_tokens + other.in_tokens,
            out_tokens=self.out_tokens + other.out_tokens,
            reasoning_tokens=reasoning,
            cache_read_tokens=cache_read,
            cache_write_tokens=cache_write,
            duration_seconds=elapsed,
        )

    @override
    def __repr__(self):
        # Pretty-print the instance attributes in declaration order.
        snapshot = dict(vars(self))
        body = pformat(snapshot, indent=2, sort_dicts=False)
        return f"{self.__class__.__name__}(\n{body}\n)"
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
class QueryResult(BaseModel):
    """
    Result of a query
    Contains the text, reasoning, metadata, tool calls, and history
    """

    output_text: str | None = None
    reasoning: str | None = None
    metadata: QueryResultMetadata = Field(default_factory=QueryResultMetadata)
    tool_calls: list[ToolCall] = Field(default_factory=list)
    history: list[InputItem] = Field(default_factory=list)
    extras: QueryResultExtras = Field(default_factory=QueryResultExtras)
    raw: dict[str, Any] = Field(default_factory=dict)

    @property
    def output_text_str(self) -> str:
        """The output text, with None replaced by an empty string."""
        return self.output_text if self.output_text else ""

    @field_validator("reasoning", mode="before")
    def default_reasoning(cls, v: str | None):
        # make reasoning None if empty
        return v or None

    @property
    def search_results(self) -> Any | None:
        """Expose provider-supplied search metadata without additional processing."""
        # Prefer the value stored in ``raw``; otherwise look through history.
        raw_payload = cast(dict[str, Any], getattr(self, "raw", {}))
        from_raw = raw_payload.get("search_results")
        if from_raw is None:
            return _get_from_history(self.history, "search_results")
        return from_raw

    @override
    def __repr__(self):
        # Show only text, reasoning (both truncated to 400) and metadata, plus
        # tool calls when there are any.
        fields = vars(self)
        summary = {
            "output_text": truncate_str(fields.get("output_text"), 400),
            "reasoning": truncate_str(fields.get("reasoning"), 400),
            "metadata": fields.get("metadata"),
        }
        if self.tool_calls:
            summary["tool_calls"] = self.tool_calls
        body = pformat(summary, indent=2, sort_dicts=False)
        return f"{self.__class__.__name__}(\n{body}\n)"
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def _get_from_history(history: Sequence[InputItem], key: str) -> Any | None:
    """Return the most recent non-None value for ``key`` found in ``history``.

    Items are scanned from last to first. For each item the attribute named
    ``key`` is tried first, then the same key inside the item's
    ``model_extra`` attribute when that is a mapping. Returns None when no
    item yields a value.
    """
    for entry in reversed(history):
        direct = getattr(entry, key, None)
        if direct is not None:
            return direct

        extras = getattr(entry, "model_extra", None)
        if not isinstance(extras, Mapping):
            continue
        fallback = cast(Mapping[str, Any], extras).get(key)
        if fallback is not None:
            return fallback

    return None
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from typing import Sequence, cast
|
|
2
|
+
|
|
3
|
+
from model_library.base.input import (
|
|
4
|
+
FileBase,
|
|
5
|
+
InputItem,
|
|
6
|
+
RawInputItem,
|
|
7
|
+
TextInput,
|
|
8
|
+
ToolResult,
|
|
9
|
+
)
|
|
10
|
+
from model_library.utils import truncate_str
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def sum_optional(a: int | None, b: int | None) -> int | None:
|
|
14
|
+
"""Sum two optional integers, returning None if both are None.
|
|
15
|
+
|
|
16
|
+
Preserves None to indicate "unknown/not provided" when both inputs are None,
|
|
17
|
+
otherwise treats None as 0 for summation.
|
|
18
|
+
"""
|
|
19
|
+
if a is None and b is None:
|
|
20
|
+
return None
|
|
21
|
+
return (a or 0) + (b or 0)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def get_pretty_input_types(input: Sequence["InputItem"], verbose: bool = False) -> str:
    """Render input items for logging, one per line.

    Every item is shown through ``repr``. TextInput items are additionally
    passed through ``truncate_str`` unless ``verbose`` is True. Returns an
    empty string for empty input; otherwise the result starts with a newline.
    """

    def render(entry: "InputItem") -> str:
        text = repr(entry)
        # Only text prompts are shortened; files, tool results, raw dicts and
        # raw responses are always shown in full.
        if isinstance(entry, TextInput) and not verbose:
            return truncate_str(text)
        return text

    lines = [f" {render(entry)}" for entry in input]
    if not lines:
        return ""
    return "\n" + "\n".join(lines)
|