model-library 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. model_library/base/__init__.py +7 -0
  2. model_library/{base.py → base/base.py} +58 -429
  3. model_library/base/batch.py +121 -0
  4. model_library/base/delegate_only.py +94 -0
  5. model_library/base/input.py +100 -0
  6. model_library/base/output.py +229 -0
  7. model_library/base/utils.py +43 -0
  8. model_library/config/ai21labs_models.yaml +1 -0
  9. model_library/config/all_models.json +461 -36
  10. model_library/config/anthropic_models.yaml +30 -3
  11. model_library/config/deepseek_models.yaml +3 -1
  12. model_library/config/google_models.yaml +49 -0
  13. model_library/config/openai_models.yaml +43 -4
  14. model_library/config/together_models.yaml +1 -0
  15. model_library/config/xai_models.yaml +63 -3
  16. model_library/exceptions.py +8 -2
  17. model_library/file_utils.py +1 -1
  18. model_library/providers/__init__.py +0 -0
  19. model_library/providers/ai21labs.py +2 -0
  20. model_library/providers/alibaba.py +16 -78
  21. model_library/providers/amazon.py +3 -0
  22. model_library/providers/anthropic.py +215 -8
  23. model_library/providers/azure.py +2 -0
  24. model_library/providers/cohere.py +14 -80
  25. model_library/providers/deepseek.py +14 -90
  26. model_library/providers/fireworks.py +17 -81
  27. model_library/providers/google/google.py +55 -47
  28. model_library/providers/inception.py +15 -83
  29. model_library/providers/kimi.py +15 -83
  30. model_library/providers/mistral.py +2 -0
  31. model_library/providers/openai.py +10 -2
  32. model_library/providers/perplexity.py +12 -79
  33. model_library/providers/together.py +19 -210
  34. model_library/providers/vals.py +2 -0
  35. model_library/providers/xai.py +2 -0
  36. model_library/providers/zai.py +15 -83
  37. model_library/register_models.py +75 -57
  38. model_library/registry_utils.py +5 -5
  39. model_library/utils.py +3 -28
  40. {model_library-0.1.1.dist-info → model_library-0.1.3.dist-info}/METADATA +2 -3
  41. model_library-0.1.3.dist-info/RECORD +61 -0
  42. model_library-0.1.1.dist-info/RECORD +0 -54
  43. {model_library-0.1.1.dist-info → model_library-0.1.3.dist-info}/WHEEL +0 -0
  44. {model_library-0.1.1.dist-info → model_library-0.1.3.dist-info}/licenses/LICENSE +0 -0
  45. {model_library-0.1.1.dist-info → model_library-0.1.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,94 @@
1
+ import io
2
+ from typing import Any, Literal, Sequence
3
+
4
+ from typing_extensions import override
5
+
6
+ from model_library.base import (
7
+ LLM,
8
+ FileInput,
9
+ FileWithId,
10
+ InputItem,
11
+ LLMConfig,
12
+ QueryResult,
13
+ ToolDefinition,
14
+ )
15
+
16
+
17
+ class DelegateOnlyException(Exception):
18
+ """
19
+ Raised when native model functionality is performed on a
20
+ delegate-only model.
21
+ """
22
+
23
+ DEFAULT_MESSAGE: str = "This model supports only delegate-only functionality. Only the query() method should be used."
24
+
25
+ def __init__(self, message: str | None = None):
26
+ super().__init__(message or DelegateOnlyException.DEFAULT_MESSAGE)
27
+
28
+
29
class DelegateOnly(LLM):
    """
    Base class for models that cannot be queried natively and must route
    all work through a configured delegate.

    Only query() is supported: _query_impl forwards to delegate_query().
    Every native hook (client access, input/image/file/tool parsing,
    uploads) raises DelegateOnlyException.
    """

    @override
    def get_client(self) -> None:
        # No native client exists for a delegate-only model.
        raise DelegateOnlyException()

    def __init__(
        self,
        model_name: str,
        provider: str,
        *,
        config: LLMConfig | None = None,
    ):
        config = config or LLMConfig()
        # Force non-native mode so the base class routes via the delegate.
        config.native = False
        super().__init__(model_name, provider, config=config)

    @override
    async def _query_impl(
        self,
        input: Sequence[InputItem],
        *,
        tools: list[ToolDefinition],
        **kwargs: object,
    ) -> QueryResult:
        # A delegate must be configured; the query is forwarded verbatim.
        assert self.delegate

        return await self.delegate_query(input, tools=tools, **kwargs)

    @override
    async def parse_input(
        self,
        input: Sequence[InputItem],
        **kwargs: Any,
    ) -> Any:
        # Native input parsing is not supported on delegate-only models.
        raise DelegateOnlyException()

    @override
    async def parse_image(
        self,
        image: FileInput,
    ) -> Any:
        raise DelegateOnlyException()

    @override
    async def parse_file(
        self,
        file: FileInput,
    ) -> Any:
        raise DelegateOnlyException()

    @override
    async def parse_tools(
        self,
        tools: list[ToolDefinition],
    ) -> Any:
        raise DelegateOnlyException()

    @override
    async def upload_file(
        self,
        name: str,
        mime: str,
        bytes: io.BytesIO,
        type: Literal["image", "file"] = "file",
    ) -> FileWithId:
        # Uploads require a native client, which this model lacks.
        raise DelegateOnlyException()
@@ -0,0 +1,100 @@
1
+ from pprint import pformat
2
+ from typing import Annotated, Any, Literal
3
+
4
+ from pydantic import BaseModel, Field
5
+ from typing_extensions import override
6
+
7
+ from model_library.utils import truncate_str
8
+
9
"""
--- FILES ---
"""


class FileBase(BaseModel):
    """Common fields shared by every file-input variant."""

    # "image" vs generic "file" — subclasses/providers distinguish handling.
    type: Literal["image", "file"]
    name: str
    # MIME type of the file contents.
    mime: str

    @override
    def __repr__(self):
        # Truncate base64 payloads so reprs stay log-friendly.
        attrs = vars(self).copy()
        if "base64" in attrs:
            attrs["base64"] = truncate_str(attrs["base64"])
        return f"{self.__class__.__name__}(\n{pformat(attrs, indent=2)}\n)"
25
+
26
+
27
class FileWithBase64(FileBase):
    """File content supplied inline as a base64-encoded string."""

    append_type: Literal["base64"] = "base64"  # union discriminator value
    base64: str


class FileWithUrl(FileBase):
    """File content referenced by URL."""

    append_type: Literal["url"] = "url"  # union discriminator value
    url: str


class FileWithId(FileBase):
    """File previously uploaded to a provider, referenced by its id."""

    append_type: Literal["file_id"] = "file_id"  # union discriminator value
    file_id: str


# Discriminated union: pydantic picks the concrete class via "append_type".
FileInput = Annotated[
    FileWithBase64 | FileWithUrl | FileWithId,
    Field(discriminator="append_type"),
]
46
+
47
+
48
"""
--- TOOLS ---
"""


class ToolBody(BaseModel):
    """JSON-schema-style description of a tool and its parameters."""

    name: str
    description: str
    properties: dict[str, Any]  # parameter name -> schema fragment
    required: list[str]  # names of required parameters
    kwargs: dict[str, Any] = {}  # extra provider-specific options


class ToolDefinition(BaseModel):
    """A named tool; body may be a structured ToolBody or a raw payload."""

    name: str  # acts as a key
    body: ToolBody | Any


class ToolCall(BaseModel):
    """A tool invocation requested by the model."""

    id: str
    # NOTE(review): presumably a secondary provider-specific id — confirm.
    call_id: str | None = None
    name: str
    args: dict[str, Any] | str  # parsed arguments, or a raw string
71
+
72
+
73
"""
--- INPUT ---
"""

# A provider's raw response object, carried through untouched.
RawResponse = Any


class ToolInput(BaseModel):
    # Tools made available to the model for a query.
    tools: list[ToolDefinition] = []


class ToolResult(BaseModel):
    """The result of executing a tool call, fed back to the model."""

    tool_call: ToolCall
    result: Any


class TextInput(BaseModel):
    """A plain text prompt."""

    text: str


RawInputItem = dict[
    str, Any
]  # to pass in, for example, a mock conversation with {"role": "user", "content": "Hello"}


InputItem = (
    TextInput | FileInput | ToolResult | RawInputItem | RawResponse
)  # input item can either be a prompt, a file (image or file), a tool call result, raw input, or a previous response
@@ -0,0 +1,229 @@
1
+ """
2
+ --- OUTPUT ---
3
+ """
4
+
5
+ from pprint import pformat
6
+ from typing import Any, Mapping, Sequence, cast
7
+
8
+ from pydantic import BaseModel, Field, computed_field, field_validator
9
+ from typing_extensions import override
10
+
11
+ from model_library.base.input import InputItem, ToolCall
12
+ from model_library.base.utils import (
13
+ sum_optional,
14
+ )
15
+ from model_library.utils import truncate_str
16
+
17
+
18
class Citation(BaseModel):
    """A single source citation attached to a response.

    Every field is optional — providers populate different subsets.
    """

    type: str | None = None
    title: str | None = None
    url: str | None = None
    # NOTE(review): presumably character offsets into the output text — confirm.
    start_index: int | None = None
    end_index: int | None = None
    file_id: str | None = None
    filename: str | None = None
    index: int | None = None
    container_id: str | None = None
28
+
29
+
30
class QueryResultExtras(BaseModel):
    """Optional provider extras attached to a result (currently citations)."""

    citations: list[Citation] = Field(default_factory=list)
32
+
33
+
34
class QueryResultCost(BaseModel):
    """
    Cost information for a query
    Includes total cost and a structured breakdown.
    """

    # Values are dollar amounts (see __repr__); optional components are
    # None when the provider did not report them.
    input: float
    output: float
    reasoning: float | None = None
    cache_read: float | None = None
    cache_write: float | None = None

    @computed_field
    @property
    def total(self) -> float:
        # filter(None, ...) drops None components (zeros are dropped too,
        # which does not change the sum).
        return sum(
            filter(
                None,
                [
                    self.input,
                    self.output,
                    self.reasoning,
                    self.cache_read,
                    self.cache_write,
                ],
            )
        )

    @computed_field
    @property
    def total_input(self) -> float:
        # Uncached input plus cache read/write costs.
        return sum(
            filter(
                None,
                [
                    self.input,
                    self.cache_read,
                    self.cache_write,
                ],
            )
        )

    @computed_field
    @property
    def total_output(self) -> float:
        return sum(
            filter(
                None,
                [
                    self.output,
                    self.reasoning,
                ],
            )
        )

    @override
    def __repr__(self):
        # Sub-dollar totals read better in cents.
        use_cents = self.total < 1

        def format_cost(value: float | None):
            if value is None:
                return None
            return f"{value * 100:.3f} cents" if use_cents else f"${value:.2f}"

        return (
            f"{format_cost(self.total)} "
            + f"(uncached input: {format_cost(self.input)} | output: {format_cost(self.output)} | reasoning: {format_cost(self.reasoning)} | cache_read: {format_cost(self.cache_read)} | cache_write: {format_cost(self.cache_write)})"
        )
102
+
103
+
104
class QueryResultMetadata(BaseModel):
    """
    Metadata for a query: token usage and timing.

    """

    cost: QueryResultCost | None = None  # set post query
    duration_seconds: float | None = None  # set post query
    in_tokens: int = 0
    out_tokens: int = 0
    reasoning_tokens: int | None = None
    cache_read_tokens: int | None = None
    cache_write_tokens: int | None = None

    @property
    def default_duration_seconds(self) -> float:
        # Duration with None normalized to 0 for arithmetic.
        return self.duration_seconds or 0

    @computed_field
    @property
    def total_input_tokens(self) -> int:
        # filter(None, ...) drops None (and zero) entries before summing.
        return sum(
            filter(
                None,
                [
                    self.in_tokens,
                    self.cache_read_tokens,
                    self.cache_write_tokens,
                ],
            )
        )

    @computed_field
    @property
    def total_output_tokens(self) -> int:
        return sum(
            filter(
                None,
                [
                    self.out_tokens,
                    self.reasoning_tokens,
                ],
            )
        )

    def __add__(self, other: "QueryResultMetadata") -> "QueryResultMetadata":
        # Aggregate token counts and durations across queries.
        # NOTE(review): cost is not carried into the sum (it is "set post
        # query" per the field comment) — confirm this is intentional.
        return QueryResultMetadata(
            in_tokens=self.in_tokens + other.in_tokens,
            out_tokens=self.out_tokens + other.out_tokens,
            reasoning_tokens=sum_optional(
                self.reasoning_tokens, other.reasoning_tokens
            ),
            cache_read_tokens=sum_optional(
                self.cache_read_tokens, other.cache_read_tokens
            ),
            cache_write_tokens=sum_optional(
                self.cache_write_tokens, other.cache_write_tokens
            ),
            duration_seconds=self.default_duration_seconds
            + other.default_duration_seconds,
        )

    @override
    def __repr__(self):
        attrs = vars(self).copy()
        return f"{self.__class__.__name__}(\n{pformat(attrs, indent=2, sort_dicts=False)}\n)"
170
+
171
+
172
class QueryResult(BaseModel):
    """
    Result of a query
    Contains the text, reasoning, metadata, tool calls, and history
    """

    output_text: str | None = None
    reasoning: str | None = None
    metadata: QueryResultMetadata = Field(default_factory=QueryResultMetadata)
    tool_calls: list[ToolCall] = Field(default_factory=list)
    history: list[InputItem] = Field(default_factory=list)
    extras: QueryResultExtras = Field(default_factory=QueryResultExtras)
    raw: dict[str, Any] = Field(default_factory=dict)

    @property
    def output_text_str(self) -> str:
        # Convenience accessor that is never None.
        return self.output_text or ""

    @field_validator("reasoning", mode="before")
    def default_reasoning(cls, v: str | None):
        return None if not v else v  # make reasoning None if empty

    @property
    def search_results(self) -> Any | None:
        """Expose provider-supplied search metadata without additional processing."""
        # Prefer an explicit entry in raw; otherwise scan history,
        # newest item first.
        raw_dict = cast(dict[str, Any], getattr(self, "raw", {}))
        raw_candidate = raw_dict.get("search_results")
        if raw_candidate is not None:
            return raw_candidate

        return _get_from_history(self.history, "search_results")

    @override
    def __repr__(self):
        # Most useful fields first; long text is truncated for readability.
        attrs = vars(self).copy()
        ordered_attrs = {
            "output_text": truncate_str(attrs.pop("output_text", None), 400),
            "reasoning": truncate_str(attrs.pop("reasoning", None), 400),
            "metadata": attrs.pop("metadata", None),
        }
        if self.tool_calls:
            ordered_attrs["tool_calls"] = self.tool_calls
        return f"{self.__class__.__name__}(\n{pformat(ordered_attrs, indent=2, sort_dicts=False)}\n)"
215
+
216
+
217
def _get_from_history(history: Sequence[InputItem], key: str) -> Any | None:
    """Search *history* newest-first for *key*.

    Checks each item's attribute first, then its pydantic ``model_extra``
    mapping. Returns the first non-None value found, else None.
    """
    for entry in reversed(history):
        found = getattr(entry, key, None)
        if found is not None:
            return found

        extras = getattr(entry, "model_extra", None)
        if not isinstance(extras, Mapping):
            continue
        found = cast(Mapping[str, Any], extras).get(key)
        if found is not None:
            return found

    return None
@@ -0,0 +1,43 @@
1
+ from typing import Sequence, cast
2
+
3
+ from model_library.base.input import (
4
+ FileBase,
5
+ InputItem,
6
+ RawInputItem,
7
+ TextInput,
8
+ ToolResult,
9
+ )
10
+ from model_library.utils import truncate_str
11
+
12
+
13
+ def sum_optional(a: int | None, b: int | None) -> int | None:
14
+ """Sum two optional integers, returning None if both are None.
15
+
16
+ Preserves None to indicate "unknown/not provided" when both inputs are None,
17
+ otherwise treats None as 0 for summation.
18
+ """
19
+ if a is None and b is None:
20
+ return None
21
+ return (a or 0) + (b or 0)
22
+
23
+
24
def get_pretty_input_types(input: Sequence["InputItem"], verbose: bool = False) -> str:
    """Render each input item on its own indented line for logging.

    Text prompts are truncated unless *verbose*; all other item kinds are
    shown via repr(). Returns "" for an empty sequence.
    """

    def render(entry: "InputItem") -> str:
        if isinstance(entry, TextInput):
            text = repr(entry)
            return text if verbose else truncate_str(text)
        if isinstance(entry, FileBase):  # FileInput
            return repr(entry)
        if isinstance(entry, ToolResult):
            return repr(entry)
        if isinstance(entry, dict):
            return repr(cast(RawInputItem, entry))
        # RawResponse fallback
        return repr(entry)

    rendered = [f" {render(entry)}" for entry in input]
    return "\n" + "\n".join(rendered) if rendered else ""
@@ -19,6 +19,7 @@ ai21labs-models:
19
19
  supports_temperature: true
20
20
  default_parameters:
21
21
  temperature: 0.4
22
+ max_output_tokens: 4096
22
23
 
23
24
  ai21labs/jamba-large-1.7:
24
25
  label: Jamba 1.7 Large