model_library-0.1.0-py3-none-any.whl → model_library-0.1.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- model_library/__init__.py +7 -3
- model_library/base/__init__.py +7 -0
- model_library/{base.py → base/base.py} +47 -423
- model_library/base/batch.py +121 -0
- model_library/base/delegate_only.py +94 -0
- model_library/base/input.py +100 -0
- model_library/base/output.py +175 -0
- model_library/base/utils.py +42 -0
- model_library/config/all_models.json +164 -2
- model_library/config/anthropic_models.yaml +4 -0
- model_library/config/deepseek_models.yaml +3 -1
- model_library/config/openai_models.yaml +48 -0
- model_library/exceptions.py +2 -0
- model_library/logging.py +30 -0
- model_library/providers/__init__.py +0 -0
- model_library/providers/ai21labs.py +2 -0
- model_library/providers/alibaba.py +16 -78
- model_library/providers/amazon.py +3 -0
- model_library/providers/anthropic.py +213 -2
- model_library/providers/azure.py +2 -0
- model_library/providers/cohere.py +14 -80
- model_library/providers/deepseek.py +14 -90
- model_library/providers/fireworks.py +17 -81
- model_library/providers/google/google.py +22 -20
- model_library/providers/inception.py +15 -83
- model_library/providers/kimi.py +15 -83
- model_library/providers/mistral.py +2 -0
- model_library/providers/openai.py +2 -0
- model_library/providers/perplexity.py +12 -79
- model_library/providers/together.py +2 -0
- model_library/providers/vals.py +2 -0
- model_library/providers/xai.py +2 -0
- model_library/providers/zai.py +15 -83
- model_library/register_models.py +75 -55
- model_library/registry_utils.py +5 -5
- model_library/utils.py +3 -28
- {model_library-0.1.0.dist-info → model_library-0.1.2.dist-info}/METADATA +36 -7
- model_library-0.1.2.dist-info/RECORD +61 -0
- model_library-0.1.0.dist-info/RECORD +0 -53
- {model_library-0.1.0.dist-info → model_library-0.1.2.dist-info}/WHEEL +0 -0
- {model_library-0.1.0.dist-info → model_library-0.1.2.dist-info}/licenses/LICENSE +0 -0
- {model_library-0.1.0.dist-info → model_library-0.1.2.dist-info}/top_level.txt +0 -0
model_library/__init__.py
CHANGED
@@ -1,8 +1,11 @@
-from model_library.settings import ModelLibrarySettings
 from model_library.base import LLM, LLMConfig
+from model_library.logging import set_logging
+from model_library.settings import ModelLibrarySettings
 
 model_library_settings: ModelLibrarySettings = ModelLibrarySettings()
 
+set_logging()
+
 
 def model(model_str: str, override_config: LLMConfig | None = None) -> LLM:
     from model_library.registry_utils import get_registry_model
@@ -10,14 +13,15 @@ def model(model_str: str, override_config: LLMConfig | None = None) -> LLM:
     return get_registry_model(model_str, override_config)
 
 
-def raw_model(model_str: str,
+def raw_model(model_str: str, config: LLMConfig | None = None) -> LLM:
     from model_library.registry_utils import get_raw_model
 
-    return get_raw_model(model_str, config=
+    return get_raw_model(model_str, config=config)
 
 
 __all__ = [
     "model_library_settings",
     "model",
     "raw_model",
+    "set_logging",
 ]
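Net effect on the top-level API: imports are sorted, logging is configured once at import time via the new `set_logging()`, and `raw_model` forwards its `config` argument to `get_raw_model`. A minimal usage sketch of the 0.1.2 surface; the registry key `"openai/gpt-4o"` is a hypothetical placeholder, not a key confirmed by this diff:

# Sketch against the 0.1.2 top-level API shown above.
# "openai/gpt-4o" is a hypothetical registry key used for illustration.
from model_library import model, raw_model, set_logging
from model_library.base import LLMConfig

set_logging()  # optional; 0.1.2 already calls this once at import time

# Registry lookup, optionally overriding the registered config
llm = model("openai/gpt-4o", override_config=LLMConfig(max_tokens=1024))

# Raw construction outside the registry; config is forwarded as of 0.1.2
raw = raw_model("openai/gpt-4o", config=LLMConfig(max_tokens=1024))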
model_library/base/__init__.py
ADDED
@@ -0,0 +1,7 @@
+# ruff: noqa: F403,F401
+from model_library.base.base import *
+from model_library.base.batch import *
+from model_library.base.delegate_only import *
+from model_library.base.input import *
+from model_library.base.output import *
+from model_library.base.utils import *
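The wildcard re-exports make the module-to-package conversion import-compatible: anything 0.1.0 code imported from `model_library.base` should still resolve. A sketch, with names taken from imports elsewhere in this diff:

# Unchanged import surface after the base.py -> base/ package split.
from model_library.base import (
    LLM,            # still defined in model_library/base/base.py
    LLMConfig,      # still defined in model_library/base/base.py
    TextInput,      # moved to model_library/base/input.py
    QueryResult,    # moved to model_library/base/output.py
    LLMBatchMixin,  # moved to model_library/base/batch.py
)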
model_library/{base.py → base/base.py}
RENAMED
@@ -7,293 +7,46 @@ from collections.abc import Awaitable
 from pprint import pformat
 from typing import (
     TYPE_CHECKING,
-    Annotated,
     Any,
     Callable,
     Literal,
-    Mapping,
     Sequence,
     TypeVar,
-    cast,
 )
 
-from pydantic import
-from pydantic.fields import Field
+from pydantic import model_serializer
 from pydantic.main import BaseModel
 from typing_extensions import override
 
+from model_library.base.batch import (
+    LLMBatchMixin,
+)
+from model_library.base.input import (
+    FileInput,
+    FileWithId,
+    InputItem,
+    TextInput,
+    ToolDefinition,
+    ToolResult,
+)
+from model_library.base.output import (
+    QueryResult,
+    QueryResultCost,
+    QueryResultMetadata,
+)
+from model_library.base.utils import (
+    get_pretty_input_types,
+)
 from model_library.exceptions import (
     ImmediateRetryException,
     retry_llm_call,
 )
-from model_library.utils import
-
-PydanticT = TypeVar("PydanticT", bound=BaseModel)
-
-DEFAULT_MAX_TOKENS = 2048
-DEFAULT_TEMPERATURE = 0.7
-DEFAULT_TOP_P = 1
+from model_library.utils import truncate_str
 
 if TYPE_CHECKING:
     from model_library.providers.openai import OpenAIModel
 
-"""
---- FILES ---
-"""
-
-
-class FileBase(BaseModel):
-    type: Literal["image", "file"]
-    name: str
-    mime: str
-
-    @override
-    def __repr__(self):
-        attrs = vars(self).copy()
-        if "base64" in attrs:
-            attrs["base64"] = truncate_str(attrs["base64"])
-        return f"{self.__class__.__name__}(\n{pformat(attrs, indent=2)}\n)"
-
-
-class FileWithBase64(FileBase):
-    append_type: Literal["base64"] = "base64"
-    base64: str
-
-
-class FileWithUrl(FileBase):
-    append_type: Literal["url"] = "url"
-    url: str
-
-
-class FileWithId(FileBase):
-    append_type: Literal["file_id"] = "file_id"
-    file_id: str
-
-
-FileInput = Annotated[
-    FileWithBase64 | FileWithUrl | FileWithId,
-    Field(discriminator="append_type"),
-]
-
-
-"""
---- TOOLS ---
-"""
-
-
-class ToolBody(BaseModel):
-    name: str
-    description: str
-    properties: dict[str, Any]
-    required: list[str]
-    kwargs: dict[str, Any] = {}
-
-
-class ToolDefinition(BaseModel):
-    name: str  # acts as a key
-    body: ToolBody | Any
-
-
-class ToolCall(BaseModel):
-    id: str
-    call_id: str | None = None
-    name: str
-    args: dict[str, Any] | str
-
-
-"""
---- INPUT ---
-"""
-
-RawResponse = Any
-
-
-class ToolInput(BaseModel):
-    tools: list[ToolDefinition] = []
-
-
-class ToolResult(BaseModel):
-    tool_call: ToolCall
-    result: Any
-
-
-class TextInput(BaseModel):
-    text: str
-
-
-RawInputItem = dict[
-    str, Any
-]  # to pass in, for example, a mock conversation with {"role": "user", "content": "Hello"}
-
-
-InputItem = (
-    TextInput | FileInput | ToolResult | RawInputItem | RawResponse
-)  # input item can either be a prompt, a file (image or file), a tool call result, raw input, or a previous response
-
-
-"""
---- OUTPUT ---
-"""
-
-
-class Citation(BaseModel):
-    type: str | None = None
-    title: str | None = None
-    url: str | None = None
-    start_index: int | None = None
-    end_index: int | None = None
-    file_id: str | None = None
-    filename: str | None = None
-    index: int | None = None
-    container_id: str | None = None
-
-
-class QueryResultExtras(BaseModel):
-    citations: list[Citation] = Field(default_factory=list)
-
-
-class QueryResultCost(BaseModel):
-    """
-    Cost information for a query
-    Includes total cost and a structured breakdown.
-    """
-
-    input: float
-    output: float
-    reasoning: float | None = None
-    cache_read: float | None = None
-    cache_write: float | None = None
-
-    @computed_field
-    @property
-    def total(self) -> float:
-        return sum(
-            filter(
-                None,
-                [
-                    self.input,
-                    self.output,
-                    self.reasoning,
-                    self.cache_read,
-                    self.cache_write,
-                ],
-            )
-        )
-
-    @override
-    def __repr__(self):
-        use_cents = self.total < 1
-
-        def format_cost(value: float | None):
-            if value is None:
-                return None
-            return f"{value * 100:.3f} cents" if use_cents else f"${value:.2f}"
-
-        return (
-            f"{format_cost(self.total)} "
-            + f"(uncached input: {format_cost(self.input)} | output: {format_cost(self.output)} | reasoning: {format_cost(self.reasoning)} | cache_read: {format_cost(self.cache_read)} | cache_write: {format_cost(self.cache_write)})"
-        )
-
-
-class QueryResultMetadata(BaseModel):
-    """
-    Metadata for a query: token usage and timing.
-    """
-
-    cost: QueryResultCost | None = None  # set post query
-    duration_seconds: float | None = None  # set post query
-    in_tokens: int = 0
-    out_tokens: int = 0
-    reasoning_tokens: int | None = None
-    cache_read_tokens: int | None = None
-    cache_write_tokens: int | None = None
-
-    @property
-    def default_duration_seconds(self) -> float:
-        return self.duration_seconds or 0
-
-    def __add__(self, other: "QueryResultMetadata") -> "QueryResultMetadata":
-        return QueryResultMetadata(
-            in_tokens=self.in_tokens + other.in_tokens,
-            out_tokens=self.out_tokens + other.out_tokens,
-            reasoning_tokens=sum_optional(
-                self.reasoning_tokens, other.reasoning_tokens
-            ),
-            cache_read_tokens=sum_optional(
-                self.cache_read_tokens, other.cache_read_tokens
-            ),
-            cache_write_tokens=sum_optional(
-                self.cache_write_tokens, other.cache_write_tokens
-            ),
-            duration_seconds=self.default_duration_seconds
-            + other.default_duration_seconds,
-        )
-
-    @override
-    def __repr__(self):
-        attrs = vars(self).copy()
-        return f"{self.__class__.__name__}(\n{pformat(attrs, indent=2, sort_dicts=False)}\n)"
-
-
-class QueryResult(BaseModel):
-    """
-    Result of a query
-    Contains the text, reasoning, metadata, tool calls, and history
-    """
-
-    output_text: str | None = None
-    reasoning: str | None = None
-    metadata: QueryResultMetadata = Field(default_factory=QueryResultMetadata)
-    tool_calls: list[ToolCall] = Field(default_factory=list)
-    history: list[InputItem] = Field(default_factory=list)
-    extras: QueryResultExtras = Field(default_factory=QueryResultExtras)
-    raw: dict[str, Any] = Field(default_factory=dict)
-
-    @property
-    def output_text_str(self) -> str:
-        return self.output_text or ""
-
-    @field_validator("reasoning", mode="before")
-    def default_reasoning(cls, v: str | None):
-        return None if not v else v  # make reasoning None if empty
-
-    @property
-    def search_results(self) -> Any | None:
-        """Expose provider-supplied search metadata without additional processing."""
-        raw_dict = cast(dict[str, Any], getattr(self, "raw", {}))
-        raw_candidate = raw_dict.get("search_results")
-        if raw_candidate is not None:
-            return raw_candidate
-
-        return _get_from_history(self.history, "search_results")
-
-    @override
-    def __repr__(self):
-        attrs = vars(self).copy()
-        ordered_attrs = {
-            "output_text": truncate_str(attrs.pop("output_text", None), 400),
-            "reasoning": truncate_str(attrs.pop("reasoning", None), 400),
-            "metadata": attrs.pop("metadata", None),
-        }
-        if self.tool_calls:
-            ordered_attrs["tool_calls"] = self.tool_calls
-        return f"{self.__class__.__name__}(\n{pformat(ordered_attrs, indent=2, sort_dicts=False)}\n)"
-
-
-def _get_from_history(history: Sequence[InputItem], key: str) -> Any | None:
-    for item in reversed(history):
-        value = getattr(item, key, None)
-        if value is not None:
-            return value
-
-        extra = getattr(item, "model_extra", None)
-        if isinstance(extra, Mapping):
-            value = cast(Mapping[str, Any], extra).get(key)
-            if value is not None:
-                return value
-
-    return None
+PydanticT = TypeVar("PydanticT", bound=BaseModel)
 
 
 class ProviderConfig(BaseModel):
@@ -304,6 +57,9 @@ class ProviderConfig(BaseModel):
         return self.__dict__
 
 
+DEFAULT_MAX_TOKENS = 2048
+
+
 class LLMConfig(BaseModel):
     max_tokens: int = DEFAULT_MAX_TOKENS
     temperature: float | None = None
@@ -347,9 +103,6 @@ class LLM(ABC):
         config = config or LLMConfig()
         self._registry_key = config.registry_key
 
-        if config.provider_config:
-            self.provider_config = config.provider_config
-
         self.max_tokens: int = config.max_tokens
         self.temperature: float | None = config.temperature
         self.top_p: float | None = config.top_p
@@ -368,6 +121,12 @@ class LLM(ABC):
         self.delegate: "OpenAIModel | None" = None
         self.batch: LLMBatchMixin | None = None
 
+        if config.provider_config:
+            if isinstance(
+                config.provider_config, type(getattr(self, "provider_config"))
+            ):
+                self.provider_config = config.provider_config
+
         self.logger: logging.Logger = logging.getLogger(
             f"llm.{provider}.{model_name}<instance={self.instance_id}>"
         )
@@ -521,33 +280,6 @@ class LLM(ABC):
 
         return output
 
-    async def query_json(
-        self,
-        input: Sequence[InputItem],
-        pydantic_model: type[PydanticT],
-        **kwargs: object,
-    ) -> PydanticT:
-        """Query the model with JSON response format using Pydantic model.
-
-        This is a convenience method that is not implemented for all providers.
-        Only OpenAI and Google providers currently support this method.
-
-        Args:
-            input: Input items (text, files, etc.)
-            pydantic_model: Pydantic model class defining the expected response structure
-            **kwargs: Additional arguments passed to the query method
-
-        Returns:
-            Instance of the pydantic_model with the model's response
-
-        Raises:
-            NotImplementedError: If the provider does not support structured JSON output
-        """
-        raise NotImplementedError(
-            f"query_json is not implemented for {self.__class__.__name__}. "
-            f"Only OpenAI and Google providers currently support this method."
-        )
-
     async def _calculate_cost(
         self,
         metadata: QueryResultMetadata,
@@ -678,137 +410,29 @@ class LLM(ABC):
         """Upload a file to the model provider"""
         ...
 
-
-class BatchResult(BaseModel):
-    custom_id: str
-    output: QueryResult
-    error_message: str | None = None
-
-
-class LLMBatchMixin(ABC):
-    @abstractmethod
-    async def create_batch_query_request(
+    async def query_json(
         self,
-        custom_id: str,
         input: Sequence[InputItem],
+        pydantic_model: type[PydanticT],
         **kwargs: object,
-    ) -> dict[str, Any]:
-        """
-
-        The batch api sends out a batch of query requests to various endpoints.
+    ) -> PydanticT:
+        """Query the model with JSON response format using Pydantic model.
 
-
+        This is a convenience method that is not implemented for all providers.
+        Only OpenAI and Google providers currently support this method.
 
-
-
-
+        Args:
+            input: Input items (text, files, etc.)
+            pydantic_model: Pydantic model class defining the expected response structure
+            **kwargs: Additional arguments passed to the query method
 
-    @abstractmethod
-    async def batch_query(
-        self,
-        batch_name: str,
-        requests: list[dict[str, Any]],
-    ) -> str:
-        """
-        Batch query the model
         Returns:
-
-        Raises:
-            Exception: If failed to batch query
-        """
-        ...
-
-    @abstractmethod
-    async def get_batch_results(self, batch_id: str) -> list[BatchResult]:
-        """
-        Returns results for batch
-        Raises:
-            Exception: If failed to get results
-        """
-        ...
-
-    @abstractmethod
-    async def get_batch_progress(self, batch_id: str) -> int:
-        """
-        Returns number of completed requests for batch
-        Raises:
-            Exception: If failed to get progress
-        """
-        ...
-
-    @abstractmethod
-    async def cancel_batch_request(self, batch_id: str) -> None:
-        """
-        Cancels batch
-        Raises:
-            Exception: If failed to cancel
-        """
-        ...
+            Instance of the pydantic_model with the model's response
 
-    @abstractmethod
-    async def get_batch_status(
-        self,
-        batch_id: str,
-    ) -> str:
-        """
-        Returns batch status
         Raises:
-            Exception: If failed to get status
-        """
-        ...
-
-    @classmethod
-    @abstractmethod
-    def is_batch_status_completed(
-        cls,
-        batch_status: str,
-    ) -> bool:
-        """
-        Returns if batch status is completed
-
-        A completed state is any state that is final and not in-progress
-        Example: failed | cancelled | expired | completed
-
-        An incompleted state is any state that is not completed
-        Example: in_progress | pending | running
+            NotImplementedError: If the provider does not support structured JSON output
         """
-        ...
-
-    @classmethod
-    @abstractmethod
-    def is_batch_status_failed(
-        cls,
-        batch_status: str,
-    ) -> bool:
-        """Returns if batch status is failed"""
-        ...
-
-    @classmethod
-    @abstractmethod
-    def is_batch_status_cancelled(
-        cls,
-        batch_status: str,
-    ) -> bool:
-        """Returns if batch status is cancelled"""
-        ...
-
-
-def get_pretty_input_types(input: Sequence["InputItem"]) -> str:
-    # for logging
-    def process_item(item: "InputItem"):
-        match item:
-            case TextInput():
-                return truncate_str(repr(item))
-            case FileBase():  # FileInput
-                return repr(item)
-            case ToolResult():
-                return repr(item)
-            case dict():
-                item = cast(RawInputItem, item)
-                return repr(item)
-            case _:
-                # RawResponse
-                return repr(item)
-
-    processed_items = [f" {process_item(item)}" for item in input]
-    return "\n" + "\n".join(processed_items) if processed_items else ""
+        raise NotImplementedError(
+            f"query_json is not implemented for {self.__class__.__name__}. "
+            f"Only OpenAI and Google providers currently support this method."
+        )
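Apart from moving to the end of `LLM`, `query_json` is unchanged; per its docstring only the OpenAI and Google providers override the NotImplementedError stub. A hedged sketch of a call site, where the registry key and the response schema are hypothetical:

# Hypothetical query_json call; only OpenAI and Google providers
# implement it, everything else raises NotImplementedError.
import asyncio

from pydantic import BaseModel

from model_library import model
from model_library.base import TextInput


class CityFacts(BaseModel):
    city: str
    population: int


async def main() -> None:
    llm = model("openai/gpt-4o")  # hypothetical registry key
    facts = await llm.query_json(
        [TextInput(text="Give population facts about Paris as JSON.")],
        pydantic_model=CityFacts,
    )
    print(facts.city, facts.population)


asyncio.run(main())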
model_library/base/batch.py
ADDED
@@ -0,0 +1,121 @@
+from abc import ABC, abstractmethod
+from typing import Any, Sequence
+
+from pydantic import BaseModel
+
+from model_library.base.input import InputItem
+from model_library.base.output import QueryResult
+
+
+class BatchResult(BaseModel):
+    custom_id: str
+    output: QueryResult
+    error_message: str | None = None
+
+
+class LLMBatchMixin(ABC):
+    @abstractmethod
+    async def create_batch_query_request(
+        self,
+        custom_id: str,
+        input: Sequence[InputItem],
+        **kwargs: object,
+    ) -> dict[str, Any]:
+        """Return a single query request
+
+        The batch api sends out a batch of query requests to various endpoints.
+
+        For example OpenAI can send requests to the /v1/responses or /v1/chat/completions endpoints.
+
+        This method creates a query request for such endpoints.
+        """
+        ...
+
+    @abstractmethod
+    async def batch_query(
+        self,
+        batch_name: str,
+        requests: list[dict[str, Any]],
+    ) -> str:
+        """
+        Batch query the model
+        Returns:
+            str: batch_id
+        Raises:
+            Exception: If failed to batch query
+        """
+        ...
+
+    @abstractmethod
+    async def get_batch_results(self, batch_id: str) -> list[BatchResult]:
+        """
+        Returns results for batch
+        Raises:
+            Exception: If failed to get results
+        """
+        ...
+
+    @abstractmethod
+    async def get_batch_progress(self, batch_id: str) -> int:
+        """
+        Returns number of completed requests for batch
+        Raises:
+            Exception: If failed to get progress
+        """
+        ...
+
+    @abstractmethod
+    async def cancel_batch_request(self, batch_id: str) -> None:
+        """
+        Cancels batch
+        Raises:
+            Exception: If failed to cancel
+        """
+        ...
+
+    @abstractmethod
+    async def get_batch_status(
+        self,
+        batch_id: str,
+    ) -> str:
+        """
+        Returns batch status
+        Raises:
+            Exception: If failed to get status
+        """
+        ...
+
+    @classmethod
+    @abstractmethod
+    def is_batch_status_completed(
+        cls,
+        batch_status: str,
+    ) -> bool:
+        """
+        Returns if batch status is completed
+
+        A completed state is any state that is final and not in-progress
+        Example: failed | cancelled | expired | completed
+
+        An incompleted state is any state that is not completed
+        Example: in_progress | pending | running
+        """
+        ...
+
+    @classmethod
+    @abstractmethod
+    def is_batch_status_failed(
+        cls,
+        batch_status: str,
+    ) -> bool:
+        """Returns if batch status is failed"""
+        ...
+
+    @classmethod
+    @abstractmethod
+    def is_batch_status_cancelled(
+        cls,
+        batch_status: str,
+    ) -> bool:
+        """Returns if batch status is cancelled"""
+        ...