model-library 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- model_library/__init__.py +23 -0
- model_library/base.py +814 -0
- model_library/config/ai21labs_models.yaml +99 -0
- model_library/config/alibaba_models.yaml +91 -0
- model_library/config/all_models.json +13479 -0
- model_library/config/amazon_models.yaml +276 -0
- model_library/config/anthropic_models.yaml +370 -0
- model_library/config/cohere_models.yaml +177 -0
- model_library/config/deepseek_models.yaml +47 -0
- model_library/config/dummy_model.yaml +38 -0
- model_library/config/fireworks_models.yaml +228 -0
- model_library/config/google_models.yaml +516 -0
- model_library/config/inception_models.yaml +24 -0
- model_library/config/kimi_models.yaml +34 -0
- model_library/config/mistral_models.yaml +143 -0
- model_library/config/openai_models.yaml +783 -0
- model_library/config/perplexity_models.yaml +91 -0
- model_library/config/together_models.yaml +866 -0
- model_library/config/xai_models.yaml +266 -0
- model_library/config/zai_models.yaml +65 -0
- model_library/exceptions.py +288 -0
- model_library/file_utils.py +114 -0
- model_library/model_utils.py +26 -0
- model_library/providers/ai21labs.py +193 -0
- model_library/providers/alibaba.py +147 -0
- model_library/providers/amazon.py +367 -0
- model_library/providers/anthropic.py +419 -0
- model_library/providers/azure.py +43 -0
- model_library/providers/cohere.py +100 -0
- model_library/providers/deepseek.py +115 -0
- model_library/providers/fireworks.py +133 -0
- model_library/providers/google/__init__.py +4 -0
- model_library/providers/google/batch.py +299 -0
- model_library/providers/google/google.py +467 -0
- model_library/providers/inception.py +102 -0
- model_library/providers/kimi.py +102 -0
- model_library/providers/mistral.py +299 -0
- model_library/providers/openai.py +924 -0
- model_library/providers/perplexity.py +101 -0
- model_library/providers/together.py +249 -0
- model_library/providers/vals.py +307 -0
- model_library/providers/xai.py +332 -0
- model_library/providers/zai.py +102 -0
- model_library/py.typed +0 -0
- model_library/register_models.py +385 -0
- model_library/registry_utils.py +202 -0
- model_library/settings.py +34 -0
- model_library/utils.py +151 -0
- model_library-0.1.0.dist-info/METADATA +268 -0
- model_library-0.1.0.dist-info/RECORD +53 -0
- model_library-0.1.0.dist-info/WHEEL +5 -0
- model_library-0.1.0.dist-info/licenses/LICENSE +21 -0
- model_library-0.1.0.dist-info/top_level.txt +1 -0
model_library/base.py
ADDED
@@ -0,0 +1,814 @@
import io
import logging
import time
import uuid
from abc import ABC, abstractmethod
from collections.abc import Awaitable
from pprint import pformat
from typing import (
    TYPE_CHECKING,
    Annotated,
    Any,
    Callable,
    Literal,
    Mapping,
    Sequence,
    TypeVar,
    cast,
)

from pydantic import computed_field, field_validator, model_serializer
from pydantic.fields import Field
from pydantic.main import BaseModel
from typing_extensions import override

from model_library.exceptions import (
    ImmediateRetryException,
    retry_llm_call,
)
from model_library.utils import sum_optional, truncate_str

PydanticT = TypeVar("PydanticT", bound=BaseModel)

DEFAULT_MAX_TOKENS = 2048
DEFAULT_TEMPERATURE = 0.7
DEFAULT_TOP_P = 1

if TYPE_CHECKING:
    from model_library.providers.openai import OpenAIModel

"""
--- FILES ---
"""


class FileBase(BaseModel):
    type: Literal["image", "file"]
    name: str
    mime: str

    @override
    def __repr__(self):
        attrs = vars(self).copy()
        if "base64" in attrs:
            attrs["base64"] = truncate_str(attrs["base64"])
        return f"{self.__class__.__name__}(\n{pformat(attrs, indent=2)}\n)"


class FileWithBase64(FileBase):
    append_type: Literal["base64"] = "base64"
    base64: str


class FileWithUrl(FileBase):
    append_type: Literal["url"] = "url"
    url: str


class FileWithId(FileBase):
    append_type: Literal["file_id"] = "file_id"
    file_id: str


FileInput = Annotated[
    FileWithBase64 | FileWithUrl | FileWithId,
    Field(discriminator="append_type"),
]
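
# --- Illustrative sketch (editorial, not part of the published file) ---
# FileInput is a discriminated union on "append_type", so pydantic v2 can pick
# the concrete file class from plain dict data:
from pydantic import TypeAdapter

_example_image = TypeAdapter(FileInput).validate_python(
    {
        "type": "image",
        "name": "cat.png",
        "mime": "image/png",
        "append_type": "url",
        "url": "https://example.com/cat.png",
    }
)
assert isinstance(_example_image, FileWithUrl)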


"""
--- TOOLS ---
"""


class ToolBody(BaseModel):
    name: str
    description: str
    properties: dict[str, Any]
    required: list[str]
    kwargs: dict[str, Any] = {}


class ToolDefinition(BaseModel):
    name: str  # acts as a key
    body: ToolBody | Any


class ToolCall(BaseModel):
    id: str
    call_id: str | None = None
    name: str
    args: dict[str, Any] | str
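
# --- Illustrative sketch (editorial, not part of the published file) ---
# A tool is declared as a ToolDefinition whose body carries a JSON-schema-style
# parameter description; the model answers with ToolCall objects:
_example_weather_tool = ToolDefinition(
    name="get_weather",
    body=ToolBody(
        name="get_weather",
        description="Look up the current weather for a city",
        properties={"city": {"type": "string"}},
        required=["city"],
    ),
)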


"""
--- INPUT ---
"""

RawResponse = Any


class ToolInput(BaseModel):
    tools: list[ToolDefinition] = []


class ToolResult(BaseModel):
    tool_call: ToolCall
    result: Any


class TextInput(BaseModel):
    text: str


RawInputItem = dict[
    str, Any
]  # to pass in, for example, a mock conversation with {"role": "user", "content": "Hello"}


InputItem = (
    TextInput | FileInput | ToolResult | RawInputItem | RawResponse
)  # an input item can be a prompt, a file (image or file), a tool call result, raw input, or a previous response
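
# --- Illustrative sketch (editorial, not part of the published file) ---
# A query's input is simply a list mixing these item types, e.g. a raw
# chat-style message followed by a text prompt:
_example_input: list[InputItem] = [
    {"role": "user", "content": "Hello"},  # RawInputItem
    TextInput(text="Summarize the conversation so far."),
]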


"""
--- OUTPUT ---
"""


class Citation(BaseModel):
    type: str | None = None
    title: str | None = None
    url: str | None = None
    start_index: int | None = None
    end_index: int | None = None
    file_id: str | None = None
    filename: str | None = None
    index: int | None = None
    container_id: str | None = None


class QueryResultExtras(BaseModel):
    citations: list[Citation] = Field(default_factory=list)


class QueryResultCost(BaseModel):
    """
    Cost information for a query
    Includes total cost and a structured breakdown.
    """

    input: float
    output: float
    reasoning: float | None = None
    cache_read: float | None = None
    cache_write: float | None = None

    @computed_field
    @property
    def total(self) -> float:
        return sum(
            filter(
                None,
                [
                    self.input,
                    self.output,
                    self.reasoning,
                    self.cache_read,
                    self.cache_write,
                ],
            )
        )

    @override
    def __repr__(self):
        use_cents = self.total < 1

        def format_cost(value: float | None):
            if value is None:
                return None
            return f"{value * 100:.3f} cents" if use_cents else f"${value:.2f}"

        return (
            f"{format_cost(self.total)} "
            + f"(uncached input: {format_cost(self.input)} | output: {format_cost(self.output)} | reasoning: {format_cost(self.reasoning)} | cache_read: {format_cost(self.cache_read)} | cache_write: {format_cost(self.cache_write)})"
        )


class QueryResultMetadata(BaseModel):
    """
    Metadata for a query: token usage and timing.
    """

    cost: QueryResultCost | None = None  # set post query
    duration_seconds: float | None = None  # set post query
    in_tokens: int = 0
    out_tokens: int = 0
    reasoning_tokens: int | None = None
    cache_read_tokens: int | None = None
    cache_write_tokens: int | None = None

    @property
    def default_duration_seconds(self) -> float:
        return self.duration_seconds or 0

    def __add__(self, other: "QueryResultMetadata") -> "QueryResultMetadata":
        return QueryResultMetadata(
            in_tokens=self.in_tokens + other.in_tokens,
            out_tokens=self.out_tokens + other.out_tokens,
            reasoning_tokens=sum_optional(
                self.reasoning_tokens, other.reasoning_tokens
            ),
            cache_read_tokens=sum_optional(
                self.cache_read_tokens, other.cache_read_tokens
            ),
            cache_write_tokens=sum_optional(
                self.cache_write_tokens, other.cache_write_tokens
            ),
            duration_seconds=self.default_duration_seconds
            + other.default_duration_seconds,
        )

    @override
    def __repr__(self):
        attrs = vars(self).copy()
        return f"{self.__class__.__name__}(\n{pformat(attrs, indent=2, sort_dicts=False)}\n)"


class QueryResult(BaseModel):
    """
    Result of a query
    Contains the text, reasoning, metadata, tool calls, and history
    """

    output_text: str | None = None
    reasoning: str | None = None
    metadata: QueryResultMetadata = Field(default_factory=QueryResultMetadata)
    tool_calls: list[ToolCall] = Field(default_factory=list)
    history: list[InputItem] = Field(default_factory=list)
    extras: QueryResultExtras = Field(default_factory=QueryResultExtras)
    raw: dict[str, Any] = Field(default_factory=dict)

    @property
    def output_text_str(self) -> str:
        return self.output_text or ""

    @field_validator("reasoning", mode="before")
    def default_reasoning(cls, v: str | None):
        return None if not v else v  # make reasoning None if empty

    @property
    def search_results(self) -> Any | None:
        """Expose provider-supplied search metadata without additional processing."""
        raw_dict = cast(dict[str, Any], getattr(self, "raw", {}))
        raw_candidate = raw_dict.get("search_results")
        if raw_candidate is not None:
            return raw_candidate

        return _get_from_history(self.history, "search_results")

    @override
    def __repr__(self):
        attrs = vars(self).copy()
        ordered_attrs = {
            "output_text": truncate_str(attrs.pop("output_text", None), 400),
            "reasoning": truncate_str(attrs.pop("reasoning", None), 400),
            "metadata": attrs.pop("metadata", None),
        }
        if self.tool_calls:
            ordered_attrs["tool_calls"] = self.tool_calls
        return f"{self.__class__.__name__}(\n{pformat(ordered_attrs, indent=2, sort_dicts=False)}\n)"


def _get_from_history(history: Sequence[InputItem], key: str) -> Any | None:
    for item in reversed(history):
        value = getattr(item, key, None)
        if value is not None:
            return value

        extra = getattr(item, "model_extra", None)
        if isinstance(extra, Mapping):
            value = cast(Mapping[str, Any], extra).get(key)
            if value is not None:
                return value

    return None


class ProviderConfig(BaseModel):
    """Base class for provider-specific configs. Do not use directly."""

    @model_serializer(mode="plain")
    def serialize_actual(self):
        return self.__dict__


class LLMConfig(BaseModel):
    max_tokens: int = DEFAULT_MAX_TOKENS
    temperature: float | None = None
    top_p: float | None = None
    reasoning: bool = False
    reasoning_effort: str | None = None
    supports_images: bool = False
    supports_files: bool = False
    supports_videos: bool = False
    supports_batch: bool = False
    supports_temperature: bool = True
    supports_tools: bool = False
    native: bool = True
    provider_config: ProviderConfig | None = None
    registry_key: str | None = None
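
# --- Illustrative sketch (editorial, not part of the published file) ---
# A provider model is typically constructed from an LLMConfig; anything not
# set falls back to the defaults above:
_example_config = LLMConfig(
    max_tokens=4096,
    temperature=0.2,
    supports_images=True,
    supports_tools=True,
    registry_key="openai/gpt-4o",  # hypothetical registry key
)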


RetrierType = Callable[[Callable[..., Awaitable[Any]]], Callable[..., Awaitable[Any]]]

R = TypeVar("R")  # return type


class LLM(ABC):
    """
    Base class for all LLMs
    LLM call errors should be raised as exceptions
    """

    def __init__(
        self,
        model_name: str,
        provider: str,
        *,
        config: LLMConfig | None = None,
    ):
        self.instance_id = uuid.uuid4().hex[:8]

        self.provider: str = provider
        self.model_name: str = model_name

        config = config or LLMConfig()
        self._registry_key = config.registry_key

        if config.provider_config:
            self.provider_config = config.provider_config

        self.max_tokens: int = config.max_tokens
        self.temperature: float | None = config.temperature
        self.top_p: float | None = config.top_p

        self.reasoning: bool = config.reasoning
        self.reasoning_effort: str | None = config.reasoning_effort

        self.supports_files: bool = config.supports_files
        self.supports_videos: bool = config.supports_videos
        self.supports_images: bool = config.supports_images
        self.supports_batch: bool = config.supports_batch
        self.supports_temperature: bool = config.supports_temperature
        self.supports_tools: bool = config.supports_tools

        self.native: bool = config.native
        self.delegate: "OpenAIModel | None" = None
        self.batch: LLMBatchMixin | None = None

        self.logger: logging.Logger = logging.getLogger(
            f"llm.{provider}.{model_name}<instance={self.instance_id}>"
        )
        self.custom_retrier: Callable[..., RetrierType] | None = retry_llm_call

    @override
    def __repr__(self):
        attrs = vars(self).copy()
        attrs.pop("logger", None)
        attrs.pop("custom_retrier", None)
        attrs.pop("_key", None)
        return f"{self.__class__.__name__}(\n{pformat(attrs, indent=2, sort_dicts=False)}\n)"

    @abstractmethod
    def get_client(self) -> object:
        """Return the instance of the appropriate SDK client."""
        ...

    @staticmethod
    async def timer_wrapper(func: Callable[[], Awaitable[R]]) -> tuple[R, float]:
        """
        Time the query
        """
        start = time.perf_counter()
        result = await func()
        return result, round(time.perf_counter() - start, 4)

    @staticmethod
    async def immediate_retry_wrapper(
        func: Callable[[], Awaitable[R]],
        logger: logging.Logger,
    ) -> R:
        """
        Retry the query immediately
        """
        MAX_IMMEDIATE_RETRIES = 10
        retries = 0
        while True:
            try:
                return await func()
            except ImmediateRetryException as e:
                if retries >= MAX_IMMEDIATE_RETRIES:
                    logger.error(f"Query reached max immediate retries {retries}: {e}")
                    raise Exception(
                        f"Query reached max immediate retries {retries}: {e}"
                    ) from e
                retries += 1

                logger.warning(
                    f"Query retried immediately {retries}/{MAX_IMMEDIATE_RETRIES}: {e}"
                )

    @staticmethod
    async def backoff_retry_wrapper(
        func: Callable[..., Awaitable[R]],
        backoff_retrier: RetrierType | None,
    ) -> R:
        """
        Retry the query with backoff
        """
        if not backoff_retrier:
            return await func()
        return await backoff_retrier(func)()
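
    # --- Editorial note (not part of the published file) ---
    # query() below composes these helpers from the inside out:
    #   backoff_retry_wrapper(immediate_retry_wrapper(timer_wrapper(_query_impl)))
    # so ImmediateRetryException is retried immediately (up to 10 times), while
    # the configured custom_retrier applies backoff to other retryable errors.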

    async def delegate_query(
        self,
        input: Sequence[InputItem],
        *,
        tools: list[ToolDefinition] = [],
        **kwargs: object,
    ) -> QueryResult:
        if not self.delegate:
            raise Exception("Delegate not set")
        return await self.delegate._query_impl(input, tools=tools, **kwargs)  # pyright: ignore[reportPrivateUsage]

    async def query(
        self,
        input: Sequence[InputItem] | str,
        *,
        history: Sequence[InputItem] = [],
        tools: list[ToolDefinition] = [],
        # for backwards compatibility
        files: list[FileInput] = [],
        images: list[FileInput] = [],
        **kwargs: object,
    ) -> QueryResult:
        """
        Query the model
        Join input with history
        Log, time, and retry
        """
        # format str input
        if isinstance(input, str):
            input = [TextInput(text=input)]

        # prepend files and images to input
        input = [*files, *images, *input]

        # format input info
        item_info = f"--- input ({len(input)}): {get_pretty_input_types(input)}\n"
        if history:
            item_info += (
                f"--- history({len(history)}): {get_pretty_input_types(history)}\n"
            )

        # format tool info
        tool_results = [t for t in input if isinstance(t, ToolResult)]
        tool_names = [tool.name for tool in tools or []]

        tool_info = (
            f"--- tools ({len(tools)}): {tool_names}\n"
            + f"--- tool results ({len(tool_results)}): "
            + f"{[{tool.tool_call.name: truncate_str(str(tool.result))} for tool in tool_results]}\n"
            if tools
            else ""
        )

        short_kwargs = {k: truncate_str(repr(v)) for k, v in kwargs.items()}

        # join input with history
        input = [*history, *input]

        # unique logger for the query
        query_id = uuid.uuid4().hex[:14]
        query_logger = logging.getLogger(f"{self.logger.name}<query={query_id}>")

        query_logger.info(
            "Query started:\n" + item_info + tool_info + f"--- kwargs: {short_kwargs}\n"
        )

        async def query_func() -> QueryResult:
            return await self._query_impl(input, tools=tools, **kwargs)

        async def timed_query() -> tuple[QueryResult, float]:
            return await LLM.timer_wrapper(query_func)

        async def immediate_retry() -> tuple[QueryResult, float]:
            return await LLM.immediate_retry_wrapper(timed_query, query_logger)

        async def backoff_retry() -> tuple[QueryResult, float]:
            backoff_retrier = (
                self.custom_retrier(query_logger) if self.custom_retrier else None
            )
            return await LLM.backoff_retry_wrapper(immediate_retry, backoff_retrier)

        output, duration = await backoff_retry()
        output.metadata.duration_seconds = duration
        output.metadata.cost = await self._calculate_cost(output.metadata)

        query_logger.info(f"Query completed: {repr(output)}")

        return output
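
    # --- Illustrative sketch (editorial, not part of the published file) ---
    # Typical use from a provider subclass instance (names are hypothetical):
    #
    #     model = SomeProviderModel("model-name", "provider")
    #     result = await model.query(
    #         "What is in this image?",
    #         images=[FileWithUrl(type="image", name="cat.png", mime="image/png",
    #                             url="https://example.com/cat.png")],
    #     )
    #     print(result.output_text_str, result.metadata.cost)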

    async def query_json(
        self,
        input: Sequence[InputItem],
        pydantic_model: type[PydanticT],
        **kwargs: object,
    ) -> PydanticT:
        """Query the model with a JSON response format defined by a Pydantic model.

        This is a convenience method that is not implemented for all providers.
        Only the OpenAI and Google providers currently support it.

        Args:
            input: Input items (text, files, etc.)
            pydantic_model: Pydantic model class defining the expected response structure
            **kwargs: Additional arguments passed to the query method

        Returns:
            Instance of pydantic_model populated with the model's response

        Raises:
            NotImplementedError: If the provider does not support structured JSON output
        """
        raise NotImplementedError(
            f"query_json is not implemented for {self.__class__.__name__}. "
            f"Only OpenAI and Google providers currently support this method."
        )
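
    # --- Illustrative sketch (editorial, not part of the published file) ---
    # Providers that override query_json return a validated pydantic instance:
    #
    #     class Sentiment(BaseModel):
    #         label: str
    #         confidence: float
    #
    #     sentiment = await model.query_json(
    #         [TextInput(text="I love this library!")], Sentiment
    #     )  # sentiment: Sentiment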

    async def _calculate_cost(
        self,
        metadata: QueryResultMetadata,
        batch: bool = False,
        bill_reasoning: bool = True,
    ) -> QueryResultCost | None:
        """Calculate cost for a query"""
        from model_library.registry_utils import get_model_cost

        if not self._registry_key:
            self.logger.warning("Model has no registry key, skipping cost calculation")
            return None

        costs = get_model_cost(self._registry_key)
        if not costs:
            return None

        MILLION = 1_000_000

        # base input and output
        if costs.input is None or costs.output is None:
            raise Exception("Base costs not set")
        input_cost = costs.input
        output_cost = costs.output

        # apply fixed values or discounts/markup
        # applied before other price changes
        cache_read_cost, cache_write_cost = None, None
        if metadata.cache_read_tokens or metadata.cache_write_tokens:
            if not costs.cache:
                raise Exception("Cache costs not set")
            cache_read_cost, cache_write_cost = costs.cache.get_costs(
                input_cost, output_cost
            )

        # costs for long context
        total_in = metadata.in_tokens + (metadata.cache_read_tokens or 0)
        if costs.context and total_in > costs.context.threshold:
            input_cost, output_cost = costs.context.get_costs(
                input_cost,
                output_cost,
                total_in,
            )
            if costs.context.cache:
                cache_read_cost, cache_write_cost = costs.context.cache.get_costs(
                    input_cost, output_cost
                )

        # costs for batching
        if batch:
            if not costs.batch:
                raise Exception("Batch costs not set")
            input_cost, output_cost = costs.batch.get_costs(input_cost, output_cost)

        return QueryResultCost(
            input=input_cost * metadata.in_tokens / MILLION,
            output=output_cost * metadata.out_tokens / MILLION,
            reasoning=output_cost * metadata.reasoning_tokens / MILLION
            if metadata.reasoning_tokens is not None and bill_reasoning
            else None,
            cache_read=cache_read_cost * metadata.cache_read_tokens / MILLION
            if metadata.cache_read_tokens is not None and cache_read_cost
            else None,
            cache_write=cache_write_cost * metadata.cache_write_tokens / MILLION
            if metadata.cache_write_tokens is not None and cache_write_cost
            else None,
        )
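
    # --- Worked example (editorial, not part of the published file) ---
    # With hypothetical registry prices of $2.50/M input and $10.00/M output
    # tokens and no cache, context, or batch adjustments, a query with
    # in_tokens=1_000 and out_tokens=500 costs:
    #   input  =  2.50 * 1_000 / 1_000_000 = $0.0025
    #   output = 10.00 *   500 / 1_000_000 = $0.0050
    # so QueryResultCost.total == $0.0075, which __repr__ prints as 0.750 cents.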

    @abstractmethod
    async def _query_impl(
        self,
        input: Sequence[InputItem],
        *,
        tools: list[ToolDefinition],
        **kwargs: object,  # TODO: pass in query logger
    ) -> QueryResult:
        """
        Query the model with input
        Input can consist of text, images, files, or model-specific raw responses
        Optionally pass in tools
        Kwargs will be passed to the model call (apart from exceptions like system_prompt)
        Images and files should be preprocessed according to what the model supports:
        - base64
        - url
        - file_id
        """
        ...

    @abstractmethod
    async def parse_input(
        self,
        input: Sequence[InputItem],
        **kwargs: Any,
    ) -> Any:
        """
        Parses input into the appropriate format for the model
        Handles prompts, images, and files
        Handles history and tool call results
        Calls:
        - parse_image
        - parse_file
        """
        ...

    @abstractmethod
    async def parse_image(self, image: FileInput) -> Any:
        """Parse an image into the appropriate format for the model"""
        ...

    @abstractmethod
    async def parse_file(self, file: FileInput) -> Any:
        """Parse a file into the appropriate format for the model"""
        ...

    @abstractmethod
    async def parse_tools(self, tools: list[ToolDefinition]) -> Any:
        """Parse tools into the appropriate format for the model"""
        ...

    @abstractmethod
    async def upload_file(
        self,
        name: str,
        mime: str,
        bytes: io.BytesIO,
        type: Literal["image", "file"] = "file",
    ) -> FileWithId:
        """Upload a file to the model provider"""
        ...


class BatchResult(BaseModel):
    custom_id: str
    output: QueryResult
    error_message: str | None = None


class LLMBatchMixin(ABC):
    @abstractmethod
    async def create_batch_query_request(
        self,
        custom_id: str,
        input: Sequence[InputItem],
        **kwargs: object,
    ) -> dict[str, Any]:
        """Return a single query request

        The batch API sends out a batch of query requests to various endpoints.

        For example, OpenAI can send requests to the /v1/responses or /v1/chat/completions endpoints.

        This method creates a single query request for such an endpoint.
        """
        ...

    @abstractmethod
    async def batch_query(
        self,
        batch_name: str,
        requests: list[dict[str, Any]],
    ) -> str:
        """
        Batch query the model
        Returns:
            str: batch_id
        Raises:
            Exception: If failed to batch query
        """
        ...

    @abstractmethod
    async def get_batch_results(self, batch_id: str) -> list[BatchResult]:
        """
        Returns results for batch
        Raises:
            Exception: If failed to get results
        """
        ...

    @abstractmethod
    async def get_batch_progress(self, batch_id: str) -> int:
        """
        Returns number of completed requests for batch
        Raises:
            Exception: If failed to get progress
        """
        ...

    @abstractmethod
    async def cancel_batch_request(self, batch_id: str) -> None:
        """
        Cancels batch
        Raises:
            Exception: If failed to cancel
        """
        ...

    @abstractmethod
    async def get_batch_status(
        self,
        batch_id: str,
    ) -> str:
        """
        Returns batch status
        Raises:
            Exception: If failed to get status
        """
        ...

    @classmethod
    @abstractmethod
    def is_batch_status_completed(
        cls,
        batch_status: str,
    ) -> bool:
        """
        Returns whether the batch status is completed

        A completed state is any state that is final and not in progress
        Example: failed | cancelled | expired | completed

        An incomplete state is any state that is not completed
        Example: in_progress | pending | running
        """
        ...

    @classmethod
    @abstractmethod
    def is_batch_status_failed(
        cls,
        batch_status: str,
    ) -> bool:
        """Returns whether the batch status is failed"""
        ...

    @classmethod
    @abstractmethod
    def is_batch_status_cancelled(
        cls,
        batch_status: str,
    ) -> bool:
        """Returns whether the batch status is cancelled"""
        ...
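
# --- Illustrative sketch (editorial, not part of the published file) ---
# A minimal polling loop over the abstract batch interface, assuming a provider
# implementation of LLMBatchMixin (e.g. one assigned to LLM.batch):
import asyncio


async def _example_poll_batch(batch: LLMBatchMixin, batch_id: str) -> list[BatchResult]:
    while True:
        status = await batch.get_batch_status(batch_id)
        if batch.is_batch_status_completed(status):
            break
        await asyncio.sleep(30)
    if batch.is_batch_status_failed(status):
        raise Exception(f"Batch {batch_id} failed with status {status}")
    return await batch.get_batch_results(batch_id)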


def get_pretty_input_types(input: Sequence["InputItem"]) -> str:
    # for logging
    def process_item(item: "InputItem"):
        match item:
            case TextInput():
                return truncate_str(repr(item))
            case FileBase():  # FileInput
                return repr(item)
            case ToolResult():
                return repr(item)
            case dict():
                item = cast(RawInputItem, item)
                return repr(item)
            case _:
                # RawResponse
                return repr(item)

    processed_items = [f" {process_item(item)}" for item in input]
    return "\n" + "\n".join(processed_items) if processed_items else ""