model-library 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- model_library/base/base.py +237 -62
- model_library/base/delegate_only.py +86 -9
- model_library/base/input.py +10 -7
- model_library/base/output.py +48 -0
- model_library/base/utils.py +56 -7
- model_library/config/alibaba_models.yaml +44 -57
- model_library/config/all_models.json +253 -126
- model_library/config/kimi_models.yaml +30 -3
- model_library/config/openai_models.yaml +15 -23
- model_library/config/zai_models.yaml +24 -3
- model_library/exceptions.py +14 -77
- model_library/logging.py +6 -2
- model_library/providers/ai21labs.py +30 -14
- model_library/providers/alibaba.py +17 -8
- model_library/providers/amazon.py +119 -64
- model_library/providers/anthropic.py +184 -104
- model_library/providers/azure.py +22 -10
- model_library/providers/cohere.py +7 -7
- model_library/providers/deepseek.py +8 -8
- model_library/providers/fireworks.py +7 -8
- model_library/providers/google/batch.py +17 -13
- model_library/providers/google/google.py +130 -73
- model_library/providers/inception.py +7 -7
- model_library/providers/kimi.py +18 -8
- model_library/providers/minimax.py +30 -13
- model_library/providers/mistral.py +61 -35
- model_library/providers/openai.py +219 -93
- model_library/providers/openrouter.py +34 -0
- model_library/providers/perplexity.py +7 -7
- model_library/providers/together.py +7 -8
- model_library/providers/vals.py +16 -9
- model_library/providers/xai.py +157 -144
- model_library/providers/zai.py +38 -8
- model_library/register_models.py +4 -2
- model_library/registry_utils.py +39 -15
- model_library/retriers/__init__.py +0 -0
- model_library/retriers/backoff.py +73 -0
- model_library/retriers/base.py +225 -0
- model_library/retriers/token.py +427 -0
- model_library/retriers/utils.py +11 -0
- model_library/settings.py +1 -1
- model_library/utils.py +13 -35
- {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/METADATA +4 -3
- model_library-0.1.8.dist-info/RECORD +70 -0
- {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/WHEEL +1 -1
- model_library-0.1.6.dist-info/RECORD +0 -64
- {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/licenses/LICENSE +0 -0
- {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/top_level.txt +0 -0
model_library/base/base.py
CHANGED
|
@@ -1,9 +1,12 @@
|
|
|
1
|
+
import hashlib
|
|
1
2
|
import io
|
|
2
3
|
import logging
|
|
4
|
+
import threading
|
|
3
5
|
import time
|
|
4
6
|
import uuid
|
|
5
7
|
from abc import ABC, abstractmethod
|
|
6
8
|
from collections.abc import Awaitable
|
|
9
|
+
from math import ceil
|
|
7
10
|
from pprint import pformat
|
|
8
11
|
from typing import (
|
|
9
12
|
Any,
|
|
@@ -13,8 +16,10 @@ from typing import (
|
|
|
13
16
|
TypeVar,
|
|
14
17
|
)
|
|
15
18
|
|
|
16
|
-
|
|
19
|
+
import tiktoken
|
|
20
|
+
from pydantic import SecretStr, model_serializer
|
|
17
21
|
from pydantic.main import BaseModel
|
|
22
|
+
from tiktoken.core import Encoding
|
|
18
23
|
from typing_extensions import override
|
|
19
24
|
|
|
20
25
|
from model_library.base.batch import (
|
|
@@ -32,14 +37,15 @@ from model_library.base.output import (
|
|
|
32
37
|
QueryResult,
|
|
33
38
|
QueryResultCost,
|
|
34
39
|
QueryResultMetadata,
|
|
40
|
+
RateLimit,
|
|
35
41
|
)
|
|
36
42
|
from model_library.base.utils import (
|
|
37
43
|
get_pretty_input_types,
|
|
44
|
+
serialize_for_tokenizing,
|
|
38
45
|
)
|
|
39
|
-
from model_library.
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
)
|
|
46
|
+
from model_library.retriers.backoff import ExponentialBackoffRetrier
|
|
47
|
+
from model_library.retriers.base import BaseRetrier, R, RetrierType, retry_decorator
|
|
48
|
+
from model_library.retriers.token import TokenRetrier
|
|
43
49
|
from model_library.utils import truncate_str
|
|
44
50
|
|
|
45
51
|
PydanticT = TypeVar("PydanticT", bound=BaseModel)
|
|
@@ -53,11 +59,18 @@ class ProviderConfig(BaseModel):
|
|
|
53
59
|
return self.__dict__
|
|
54
60
|
|
|
55
61
|
|
|
56
|
-
|
|
62
|
+
class TokenRetryParams(BaseModel):
|
|
63
|
+
input_modifier: float
|
|
64
|
+
output_modifier: float
|
|
65
|
+
|
|
66
|
+
use_dynamic_estimate: bool = True
|
|
67
|
+
|
|
68
|
+
limit: int
|
|
69
|
+
limit_refresh_seconds: Literal[60] = 60
|
|
57
70
|
|
|
58
71
|
|
|
59
72
|
class LLMConfig(BaseModel):
|
|
60
|
-
max_tokens: int =
|
|
73
|
+
max_tokens: int | None = None
|
|
61
74
|
temperature: float | None = None
|
|
62
75
|
top_p: float | None = None
|
|
63
76
|
top_k: int | None = None
|
|
@@ -72,11 +85,18 @@ class LLMConfig(BaseModel):
|
|
|
72
85
|
native: bool = True
|
|
73
86
|
provider_config: ProviderConfig | None = None
|
|
74
87
|
registry_key: str | None = None
|
|
88
|
+
custom_api_key: SecretStr | None = None
|
|
75
89
|
|
|
76
90
|
|
|
77
|
-
|
|
91
|
+
class DelegateConfig(BaseModel):
|
|
92
|
+
base_url: str
|
|
93
|
+
api_key: SecretStr
|
|
78
94
|
|
|
79
|
-
|
|
95
|
+
|
|
96
|
+
# shared across all subclasses and instances
|
|
97
|
+
# hash(provider + api_key) -> client
|
|
98
|
+
client_registry_lock = threading.Lock()
|
|
99
|
+
client_registry: dict[tuple[str, str], Any] = {}
|
|
80
100
|
|
|
81
101
|
|
|
82
102
|
class LLM(ABC):
|
|
@@ -85,6 +105,34 @@ class LLM(ABC):
|
|
|
85
105
|
LLM call errors should be raised as exceptions
|
|
86
106
|
"""
|
|
87
107
|
|
|
108
|
+
@abstractmethod
|
|
109
|
+
def get_client(self, api_key: str | None = None) -> Any:
|
|
110
|
+
"""
|
|
111
|
+
Returns the cached instance of the appropriate SDK client.
|
|
112
|
+
Sublasses should implement this method and:
|
|
113
|
+
- if api_key is provided, initialize their client and call assing_client(client).
|
|
114
|
+
- else return super().get_client()
|
|
115
|
+
"""
|
|
116
|
+
global client_registry
|
|
117
|
+
return client_registry[self._client_registry_key]
|
|
118
|
+
|
|
119
|
+
def assign_client(self, client: object) -> None:
|
|
120
|
+
"""Thread-safe assignment to the client registry"""
|
|
121
|
+
global client_registry
|
|
122
|
+
|
|
123
|
+
if self._client_registry_key not in client_registry:
|
|
124
|
+
with client_registry_lock:
|
|
125
|
+
if self._client_registry_key not in client_registry:
|
|
126
|
+
client_registry[self._client_registry_key] = client
|
|
127
|
+
|
|
128
|
+
def has_client(self) -> bool:
|
|
129
|
+
return self._client_registry_key in client_registry
|
|
130
|
+
|
|
131
|
+
@abstractmethod
|
|
132
|
+
def _get_default_api_key(self) -> str:
|
|
133
|
+
"""Return the api key from model_library.settings"""
|
|
134
|
+
...
|
|
135
|
+
|
|
88
136
|
def __init__(
|
|
89
137
|
self,
|
|
90
138
|
model_name: str,
|
|
@@ -100,7 +148,7 @@ class LLM(ABC):
|
|
|
100
148
|
config = config or LLMConfig()
|
|
101
149
|
self._registry_key = config.registry_key
|
|
102
150
|
|
|
103
|
-
self.max_tokens: int = config.max_tokens
|
|
151
|
+
self.max_tokens: int | None = config.max_tokens
|
|
104
152
|
self.temperature: float | None = config.temperature
|
|
105
153
|
self.top_p: float | None = config.top_p
|
|
106
154
|
self.top_k: int | None = config.top_k
|
|
@@ -128,21 +176,33 @@ class LLM(ABC):
|
|
|
128
176
|
self.logger: logging.Logger = logging.getLogger(
|
|
129
177
|
f"llm.{provider}.{model_name}<instance={self.instance_id}>"
|
|
130
178
|
)
|
|
131
|
-
self.custom_retrier:
|
|
179
|
+
self.custom_retrier: RetrierType | None = None
|
|
180
|
+
|
|
181
|
+
self.token_retry_params = None
|
|
182
|
+
# set _client_registry_key after initializing delegate
|
|
183
|
+
if not self.native:
|
|
184
|
+
return
|
|
185
|
+
|
|
186
|
+
if config.custom_api_key:
|
|
187
|
+
raw_key = config.custom_api_key.get_secret_value()
|
|
188
|
+
else:
|
|
189
|
+
raw_key = self._get_default_api_key()
|
|
190
|
+
|
|
191
|
+
key_hash = hashlib.sha256(raw_key.encode()).hexdigest()
|
|
192
|
+
self._client_registry_key = (self.provider, key_hash)
|
|
193
|
+
self._client_registry_key_model_specific = (
|
|
194
|
+
f"{self.provider}.{self.model_name}",
|
|
195
|
+
key_hash,
|
|
196
|
+
)
|
|
197
|
+
self.get_client(api_key=raw_key)
|
|
132
198
|
|
|
133
199
|
@override
|
|
134
200
|
def __repr__(self):
|
|
135
201
|
attrs = vars(self).copy()
|
|
136
202
|
attrs.pop("logger", None)
|
|
137
203
|
attrs.pop("custom_retrier", None)
|
|
138
|
-
attrs.pop("_key", None)
|
|
139
204
|
return f"{self.__class__.__name__}(\n{pformat(attrs, indent=2, sort_dicts=False)}\n)"
|
|
140
205
|
|
|
141
|
-
@abstractmethod
|
|
142
|
-
def get_client(self) -> object:
|
|
143
|
-
"""Return the instance of the appropriate SDK client."""
|
|
144
|
-
...
|
|
145
|
-
|
|
146
206
|
@staticmethod
|
|
147
207
|
async def timer_wrapper(func: Callable[[], Awaitable[R]]) -> tuple[R, float]:
|
|
148
208
|
"""
|
|
@@ -152,43 +212,6 @@ class LLM(ABC):
|
|
|
152
212
|
result = await func()
|
|
153
213
|
return result, round(time.perf_counter() - start, 4)
|
|
154
214
|
|
|
155
|
-
@staticmethod
|
|
156
|
-
async def immediate_retry_wrapper(
|
|
157
|
-
func: Callable[[], Awaitable[R]],
|
|
158
|
-
logger: logging.Logger,
|
|
159
|
-
) -> R:
|
|
160
|
-
"""
|
|
161
|
-
Retry the query immediately
|
|
162
|
-
"""
|
|
163
|
-
MAX_IMMEDIATE_RETRIES = 10
|
|
164
|
-
retries = 0
|
|
165
|
-
while True:
|
|
166
|
-
try:
|
|
167
|
-
return await func()
|
|
168
|
-
except ImmediateRetryException as e:
|
|
169
|
-
if retries >= MAX_IMMEDIATE_RETRIES:
|
|
170
|
-
logger.error(f"Query reached max immediate retries {retries}: {e}")
|
|
171
|
-
raise Exception(
|
|
172
|
-
f"Query reached max immediate retries {retries}: {e}"
|
|
173
|
-
) from e
|
|
174
|
-
retries += 1
|
|
175
|
-
|
|
176
|
-
logger.warning(
|
|
177
|
-
f"Query retried immediately {retries}/{MAX_IMMEDIATE_RETRIES}: {e}"
|
|
178
|
-
)
|
|
179
|
-
|
|
180
|
-
@staticmethod
|
|
181
|
-
async def backoff_retry_wrapper(
|
|
182
|
-
func: Callable[..., Awaitable[R]],
|
|
183
|
-
backoff_retrier: RetrierType | None,
|
|
184
|
-
) -> R:
|
|
185
|
-
"""
|
|
186
|
-
Retry the query with backoff
|
|
187
|
-
"""
|
|
188
|
-
if not backoff_retrier:
|
|
189
|
-
return await func()
|
|
190
|
-
return await backoff_retrier(func)()
|
|
191
|
-
|
|
192
215
|
async def delegate_query(
|
|
193
216
|
self,
|
|
194
217
|
input: Sequence[InputItem],
|
|
@@ -273,15 +296,38 @@ class LLM(ABC):
|
|
|
273
296
|
return await LLM.timer_wrapper(query_func)
|
|
274
297
|
|
|
275
298
|
async def immediate_retry() -> tuple[QueryResult, float]:
|
|
276
|
-
return await
|
|
277
|
-
|
|
278
|
-
async def
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
299
|
+
return await BaseRetrier.immediate_retry_wrapper(timed_query, query_logger)
|
|
300
|
+
|
|
301
|
+
async def default_retry() -> tuple[QueryResult, float]:
|
|
302
|
+
if self.token_retry_params:
|
|
303
|
+
(
|
|
304
|
+
estimate_input_tokens,
|
|
305
|
+
estimate_output_tokens,
|
|
306
|
+
) = await self.estimate_query_tokens(
|
|
307
|
+
input,
|
|
308
|
+
tools=tools,
|
|
309
|
+
**kwargs,
|
|
310
|
+
)
|
|
311
|
+
retrier = TokenRetrier(
|
|
312
|
+
logger=query_logger,
|
|
313
|
+
client_registry_key=self._client_registry_key_model_specific,
|
|
314
|
+
estimate_input_tokens=estimate_input_tokens,
|
|
315
|
+
estimate_output_tokens=estimate_output_tokens,
|
|
316
|
+
dynamic_estimate_instance_id=self.instance_id
|
|
317
|
+
if self.token_retry_params.use_dynamic_estimate
|
|
318
|
+
else None,
|
|
319
|
+
)
|
|
320
|
+
else:
|
|
321
|
+
retrier = ExponentialBackoffRetrier(logger=query_logger)
|
|
322
|
+
return await retry_decorator(retrier)(immediate_retry)()
|
|
323
|
+
|
|
324
|
+
run_with_retry = (
|
|
325
|
+
default_retry
|
|
326
|
+
if not self.custom_retrier
|
|
327
|
+
else self.custom_retrier(immediate_retry)
|
|
328
|
+
)
|
|
283
329
|
|
|
284
|
-
output, duration = await
|
|
330
|
+
output, duration = await run_with_retry()
|
|
285
331
|
output.metadata.duration_seconds = duration
|
|
286
332
|
output.metadata.cost = await self._calculate_cost(output.metadata)
|
|
287
333
|
|
|
@@ -290,6 +336,16 @@ class LLM(ABC):
|
|
|
290
336
|
|
|
291
337
|
return output
|
|
292
338
|
|
|
339
|
+
async def init_token_retry(self, token_retry_params: TokenRetryParams) -> None:
|
|
340
|
+
self.token_retry_params = token_retry_params
|
|
341
|
+
await TokenRetrier.init_remaining_tokens(
|
|
342
|
+
client_registry_key=self._client_registry_key_model_specific,
|
|
343
|
+
limit=self.token_retry_params.limit,
|
|
344
|
+
limit_refresh_seconds=self.token_retry_params.limit_refresh_seconds,
|
|
345
|
+
get_rate_limit_func=self.get_rate_limit,
|
|
346
|
+
logger=self.logger,
|
|
347
|
+
)
|
|
348
|
+
|
|
293
349
|
async def _calculate_cost(
|
|
294
350
|
self,
|
|
295
351
|
metadata: QueryResultMetadata,
|
|
@@ -379,6 +435,20 @@ class LLM(ABC):
|
|
|
379
435
|
"""
|
|
380
436
|
...
|
|
381
437
|
|
|
438
|
+
@abstractmethod
|
|
439
|
+
async def build_body(
|
|
440
|
+
self,
|
|
441
|
+
input: Sequence[InputItem],
|
|
442
|
+
*,
|
|
443
|
+
tools: list[ToolDefinition],
|
|
444
|
+
**kwargs: Any,
|
|
445
|
+
) -> dict[str, Any]:
|
|
446
|
+
"""
|
|
447
|
+
Builds the body of the request to the model provider
|
|
448
|
+
Calls parse_input
|
|
449
|
+
"""
|
|
450
|
+
...
|
|
451
|
+
|
|
382
452
|
@abstractmethod
|
|
383
453
|
async def parse_input(
|
|
384
454
|
self,
|
|
@@ -421,6 +491,111 @@ class LLM(ABC):
|
|
|
421
491
|
"""Upload a file to the model provider"""
|
|
422
492
|
...
|
|
423
493
|
|
|
494
|
+
async def get_rate_limit(self) -> RateLimit | None:
|
|
495
|
+
"""Get the rate limit for the model provider"""
|
|
496
|
+
return None
|
|
497
|
+
|
|
498
|
+
async def estimate_query_tokens(
|
|
499
|
+
self,
|
|
500
|
+
input: Sequence[InputItem],
|
|
501
|
+
*,
|
|
502
|
+
tools: list[ToolDefinition] = [],
|
|
503
|
+
**kwargs: object,
|
|
504
|
+
) -> tuple[int, int]:
|
|
505
|
+
"""Pessimistically estimate the number of tokens required for a query"""
|
|
506
|
+
assert self.token_retry_params
|
|
507
|
+
|
|
508
|
+
# TODO: when passing in images and files, we really need to take that into account when calculating the output tokens!!
|
|
509
|
+
|
|
510
|
+
input_tokens = (
|
|
511
|
+
await self.count_tokens(input, history=[], tools=tools, **kwargs)
|
|
512
|
+
* self.token_retry_params.input_modifier
|
|
513
|
+
)
|
|
514
|
+
|
|
515
|
+
output_tokens = input_tokens * self.token_retry_params.output_modifier
|
|
516
|
+
return ceil(input_tokens), ceil(output_tokens)
|
|
517
|
+
|
|
518
|
+
async def get_encoding(self) -> Encoding:
|
|
519
|
+
"""Get the appropriate tokenizer"""
|
|
520
|
+
|
|
521
|
+
model = self.model_name.lower()
|
|
522
|
+
|
|
523
|
+
if any(x in model for x in ["gpt-4o", "o1", "o3", "gpt-4.1", "gpt-5"]):
|
|
524
|
+
return tiktoken.get_encoding("o200k_base")
|
|
525
|
+
elif "gpt-4" in model or "gpt-3.5" in model:
|
|
526
|
+
try:
|
|
527
|
+
return tiktoken.encoding_for_model(self.model_name)
|
|
528
|
+
except KeyError:
|
|
529
|
+
return tiktoken.get_encoding("cl100k_base")
|
|
530
|
+
elif "claude" in model:
|
|
531
|
+
return tiktoken.get_encoding("cl100k_base")
|
|
532
|
+
elif "gemini" in model:
|
|
533
|
+
return tiktoken.get_encoding("o200k_base")
|
|
534
|
+
elif "llama" in model or "mistral" in model:
|
|
535
|
+
return tiktoken.get_encoding("cl100k_base")
|
|
536
|
+
else:
|
|
537
|
+
return tiktoken.get_encoding("cl100k_base")
|
|
538
|
+
|
|
539
|
+
async def stringify_input(
|
|
540
|
+
self,
|
|
541
|
+
input: Sequence[InputItem],
|
|
542
|
+
*,
|
|
543
|
+
history: Sequence[InputItem] = [],
|
|
544
|
+
tools: list[ToolDefinition] = [],
|
|
545
|
+
**kwargs: object,
|
|
546
|
+
) -> str:
|
|
547
|
+
input = [*history, *input]
|
|
548
|
+
|
|
549
|
+
system_prompt = kwargs.pop(
|
|
550
|
+
"system_prompt", ""
|
|
551
|
+
) # TODO: refactor along with system prompt arg change
|
|
552
|
+
|
|
553
|
+
# special case if using a delegate
|
|
554
|
+
# don't inherit method override by default
|
|
555
|
+
if self.delegate:
|
|
556
|
+
parsed_input = await self.delegate.parse_input(input, **kwargs)
|
|
557
|
+
parsed_tools = await self.delegate.parse_tools(tools)
|
|
558
|
+
else:
|
|
559
|
+
parsed_input = await self.parse_input(input, **kwargs)
|
|
560
|
+
parsed_tools = await self.parse_tools(tools)
|
|
561
|
+
|
|
562
|
+
serialized_input = serialize_for_tokenizing(parsed_input)
|
|
563
|
+
serialized_tools = serialize_for_tokenizing(parsed_tools)
|
|
564
|
+
|
|
565
|
+
combined = f"{system_prompt}\n{serialized_input}\n{serialized_tools}"
|
|
566
|
+
|
|
567
|
+
return combined
|
|
568
|
+
|
|
569
|
+
async def count_tokens(
|
|
570
|
+
self,
|
|
571
|
+
input: Sequence[InputItem],
|
|
572
|
+
*,
|
|
573
|
+
history: Sequence[InputItem] = [],
|
|
574
|
+
tools: list[ToolDefinition] = [],
|
|
575
|
+
**kwargs: object,
|
|
576
|
+
) -> int:
|
|
577
|
+
"""
|
|
578
|
+
Count the number of tokens for a query.
|
|
579
|
+
Combines parsed input and tools, then tokenizes the result.
|
|
580
|
+
"""
|
|
581
|
+
|
|
582
|
+
if not input and not history:
|
|
583
|
+
return 0
|
|
584
|
+
|
|
585
|
+
if self.delegate:
|
|
586
|
+
encoding = await self.delegate.get_encoding()
|
|
587
|
+
else:
|
|
588
|
+
encoding = await self.get_encoding()
|
|
589
|
+
self.logger.debug(f"Token Count Encoding: {encoding}")
|
|
590
|
+
|
|
591
|
+
string_input = await self.stringify_input(
|
|
592
|
+
input, history=history, tools=tools, **kwargs
|
|
593
|
+
)
|
|
594
|
+
|
|
595
|
+
count = len(encoding.encode(string_input, disallowed_special=()))
|
|
596
|
+
self.logger.debug(f"Combined Token Count Input: {count}")
|
|
597
|
+
return count
|
|
598
|
+
|
|
424
599
|
async def query_json(
|
|
425
600
|
self,
|
|
426
601
|
input: Sequence[InputItem],
|
|
@@ -13,6 +13,7 @@ from model_library.base import (
|
|
|
13
13
|
QueryResult,
|
|
14
14
|
ToolDefinition,
|
|
15
15
|
)
|
|
16
|
+
from model_library.base.base import DelegateConfig
|
|
16
17
|
|
|
17
18
|
|
|
18
19
|
class DelegateOnlyException(Exception):
|
|
@@ -21,17 +22,51 @@ class DelegateOnlyException(Exception):
|
|
|
21
22
|
delegate-only model.
|
|
22
23
|
"""
|
|
23
24
|
|
|
24
|
-
DEFAULT_MESSAGE: str = "This model
|
|
25
|
+
DEFAULT_MESSAGE: str = "This model is running in delegate-only mode, certain functionality is not supported."
|
|
25
26
|
|
|
26
27
|
def __init__(self, message: str | None = None):
|
|
27
28
|
super().__init__(message or DelegateOnlyException.DEFAULT_MESSAGE)
|
|
28
29
|
|
|
29
30
|
|
|
30
31
|
class DelegateOnly(LLM):
|
|
31
|
-
|
|
32
|
-
def get_client(self) -> None:
|
|
32
|
+
def _get_default_api_key(self) -> str:
|
|
33
33
|
raise DelegateOnlyException()
|
|
34
34
|
|
|
35
|
+
@override
|
|
36
|
+
def get_client(self, api_key: str | None = None) -> None:
|
|
37
|
+
assert self.delegate
|
|
38
|
+
return self.delegate.get_client()
|
|
39
|
+
|
|
40
|
+
def init_delegate(
|
|
41
|
+
self,
|
|
42
|
+
config: LLMConfig | None,
|
|
43
|
+
delegate_config: DelegateConfig,
|
|
44
|
+
delegate_provider: Literal["openai", "anthropic"],
|
|
45
|
+
use_completions: bool = True,
|
|
46
|
+
) -> None:
|
|
47
|
+
from model_library.providers.anthropic import AnthropicModel
|
|
48
|
+
from model_library.providers.openai import OpenAIModel
|
|
49
|
+
|
|
50
|
+
match delegate_provider:
|
|
51
|
+
case "openai":
|
|
52
|
+
self.delegate = OpenAIModel(
|
|
53
|
+
model_name=self.model_name,
|
|
54
|
+
provider=self.provider,
|
|
55
|
+
config=config,
|
|
56
|
+
use_completions=use_completions,
|
|
57
|
+
delegate_config=delegate_config,
|
|
58
|
+
)
|
|
59
|
+
case "anthropic":
|
|
60
|
+
self.delegate = AnthropicModel(
|
|
61
|
+
model_name=self.model_name,
|
|
62
|
+
provider=self.provider,
|
|
63
|
+
config=config,
|
|
64
|
+
delegate_config=delegate_config,
|
|
65
|
+
)
|
|
66
|
+
self._client_registry_key_model_specific = (
|
|
67
|
+
self.delegate._client_registry_key_model_specific
|
|
68
|
+
)
|
|
69
|
+
|
|
35
70
|
def __init__(
|
|
36
71
|
self,
|
|
37
72
|
model_name: str,
|
|
@@ -42,6 +77,11 @@ class DelegateOnly(LLM):
|
|
|
42
77
|
config = config or LLMConfig()
|
|
43
78
|
config.native = False
|
|
44
79
|
super().__init__(model_name, provider, config=config)
|
|
80
|
+
config.native = True
|
|
81
|
+
|
|
82
|
+
def _get_extra_body(self) -> dict[str, Any]:
|
|
83
|
+
"""Build extra body parameters for delegate-specific features."""
|
|
84
|
+
return {}
|
|
45
85
|
|
|
46
86
|
@override
|
|
47
87
|
async def _query_impl(
|
|
@@ -53,39 +93,57 @@ class DelegateOnly(LLM):
|
|
|
53
93
|
**kwargs: object,
|
|
54
94
|
) -> QueryResult:
|
|
55
95
|
assert self.delegate
|
|
56
|
-
|
|
57
96
|
return await self.delegate_query(
|
|
58
|
-
input,
|
|
97
|
+
input,
|
|
98
|
+
tools=tools,
|
|
99
|
+
query_logger=query_logger,
|
|
100
|
+
extra_body=self._get_extra_body(),
|
|
101
|
+
**kwargs,
|
|
59
102
|
)
|
|
60
103
|
|
|
104
|
+
@override
|
|
105
|
+
async def build_body(
|
|
106
|
+
self,
|
|
107
|
+
input: Sequence[InputItem],
|
|
108
|
+
*,
|
|
109
|
+
tools: list[ToolDefinition],
|
|
110
|
+
**kwargs: object,
|
|
111
|
+
) -> dict[str, Any]:
|
|
112
|
+
assert self.delegate
|
|
113
|
+
return await self.delegate.build_body(input, tools=tools, **kwargs)
|
|
114
|
+
|
|
61
115
|
@override
|
|
62
116
|
async def parse_input(
|
|
63
117
|
self,
|
|
64
118
|
input: Sequence[InputItem],
|
|
65
119
|
**kwargs: Any,
|
|
66
120
|
) -> Any:
|
|
67
|
-
|
|
121
|
+
assert self.delegate
|
|
122
|
+
return await self.delegate.parse_input(input, **kwargs)
|
|
68
123
|
|
|
69
124
|
@override
|
|
70
125
|
async def parse_image(
|
|
71
126
|
self,
|
|
72
127
|
image: FileInput,
|
|
73
128
|
) -> Any:
|
|
74
|
-
|
|
129
|
+
assert self.delegate
|
|
130
|
+
return await self.delegate.parse_image(image)
|
|
75
131
|
|
|
76
132
|
@override
|
|
77
133
|
async def parse_file(
|
|
78
134
|
self,
|
|
79
135
|
file: FileInput,
|
|
80
136
|
) -> Any:
|
|
81
|
-
|
|
137
|
+
assert self.delegate
|
|
138
|
+
return await self.delegate.parse_file(file)
|
|
82
139
|
|
|
83
140
|
@override
|
|
84
141
|
async def parse_tools(
|
|
85
142
|
self,
|
|
86
143
|
tools: list[ToolDefinition],
|
|
87
144
|
) -> Any:
|
|
88
|
-
|
|
145
|
+
assert self.delegate
|
|
146
|
+
return await self.delegate.parse_tools(tools)
|
|
89
147
|
|
|
90
148
|
@override
|
|
91
149
|
async def upload_file(
|
|
@@ -96,3 +154,22 @@ class DelegateOnly(LLM):
|
|
|
96
154
|
type: Literal["image", "file"] = "file",
|
|
97
155
|
) -> FileWithId:
|
|
98
156
|
raise DelegateOnlyException()
|
|
157
|
+
|
|
158
|
+
@override
|
|
159
|
+
async def get_rate_limit(self) -> Any:
|
|
160
|
+
assert self.delegate
|
|
161
|
+
return await self.delegate.get_rate_limit()
|
|
162
|
+
|
|
163
|
+
@override
|
|
164
|
+
async def count_tokens(
|
|
165
|
+
self,
|
|
166
|
+
input: Sequence[InputItem],
|
|
167
|
+
*,
|
|
168
|
+
history: Sequence[InputItem] = [],
|
|
169
|
+
tools: list[ToolDefinition] = [],
|
|
170
|
+
**kwargs: object,
|
|
171
|
+
) -> int:
|
|
172
|
+
assert self.delegate
|
|
173
|
+
return await self.delegate.count_tokens(
|
|
174
|
+
input, history=history, tools=tools, **kwargs
|
|
175
|
+
)
|
model_library/base/input.py
CHANGED
|
@@ -74,8 +74,6 @@ class ToolCall(BaseModel):
|
|
|
74
74
|
--- INPUT ---
|
|
75
75
|
"""
|
|
76
76
|
|
|
77
|
-
RawResponse = Any
|
|
78
|
-
|
|
79
77
|
|
|
80
78
|
class ToolInput(BaseModel):
|
|
81
79
|
tools: list[ToolDefinition] = []
|
|
@@ -90,11 +88,16 @@ class TextInput(BaseModel):
|
|
|
90
88
|
text: str
|
|
91
89
|
|
|
92
90
|
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
91
|
+
class RawResponse(BaseModel):
|
|
92
|
+
# used to store a received response
|
|
93
|
+
response: Any
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class RawInput(BaseModel):
|
|
97
|
+
# used to pass in anything provider specific (e.g. a mock conversation)
|
|
98
|
+
input: Any
|
|
96
99
|
|
|
97
100
|
|
|
98
101
|
InputItem = (
|
|
99
|
-
TextInput | FileInput | ToolResult |
|
|
100
|
-
) # input item can either be a prompt, a file (image or file), a tool call result,
|
|
102
|
+
TextInput | FileInput | ToolResult | RawInput | RawResponse
|
|
103
|
+
) # input item can either be a prompt, a file (image or file), a tool call result, a previous response, or raw input
|
model_library/base/output.py
CHANGED
|
@@ -24,6 +24,11 @@ class Citation(BaseModel):
|
|
|
24
24
|
index: int | None = None
|
|
25
25
|
container_id: str | None = None
|
|
26
26
|
|
|
27
|
+
@override
|
|
28
|
+
def __repr__(self):
|
|
29
|
+
attrs = vars(self).copy()
|
|
30
|
+
return f"{self.__class__.__name__}(\n{pformat(attrs, indent=2)}\n)"
|
|
31
|
+
|
|
27
32
|
|
|
28
33
|
class QueryResultExtras(BaseModel):
|
|
29
34
|
citations: list[Citation] = Field(default_factory=list)
|
|
@@ -113,6 +118,48 @@ class QueryResultCost(BaseModel):
|
|
|
113
118
|
)
|
|
114
119
|
|
|
115
120
|
|
|
121
|
+
class RateLimit(BaseModel):
|
|
122
|
+
"""Rate limit information"""
|
|
123
|
+
|
|
124
|
+
request_limit: int | None = None
|
|
125
|
+
request_remaining: int | None = None
|
|
126
|
+
|
|
127
|
+
token_limit: int | None = None
|
|
128
|
+
token_limit_input: int | None = None
|
|
129
|
+
token_limit_output: int | None = None
|
|
130
|
+
|
|
131
|
+
token_remaining: int | None = None
|
|
132
|
+
token_remaining_input: int | None = None
|
|
133
|
+
token_remaining_output: int | None = None
|
|
134
|
+
|
|
135
|
+
unix_timestamp: float
|
|
136
|
+
raw: Any
|
|
137
|
+
|
|
138
|
+
@computed_field
|
|
139
|
+
@property
|
|
140
|
+
def token_limit_total(self) -> int:
|
|
141
|
+
if self.token_limit:
|
|
142
|
+
return self.token_limit
|
|
143
|
+
else:
|
|
144
|
+
return (self.token_limit_input or 0) + (self.token_limit_output or 0)
|
|
145
|
+
|
|
146
|
+
@computed_field
|
|
147
|
+
@property
|
|
148
|
+
def token_remaining_total(self) -> int:
|
|
149
|
+
if self.token_remaining:
|
|
150
|
+
return self.token_remaining
|
|
151
|
+
else:
|
|
152
|
+
return (self.token_remaining_input or 0) + (
|
|
153
|
+
self.token_remaining_output or 0
|
|
154
|
+
)
|
|
155
|
+
|
|
156
|
+
@override
|
|
157
|
+
def __repr__(self):
|
|
158
|
+
attrs = vars(self).copy()
|
|
159
|
+
attrs.pop("raw", None)
|
|
160
|
+
return f"{self.__class__.__name__}(\n{pformat(attrs, indent=2, sort_dicts=False)}\n)"
|
|
161
|
+
|
|
162
|
+
|
|
116
163
|
class QueryResultMetadata(BaseModel):
|
|
117
164
|
"""
|
|
118
165
|
Metadata for a query: token usage and timing.
|
|
@@ -126,6 +173,7 @@ class QueryResultMetadata(BaseModel):
|
|
|
126
173
|
reasoning_tokens: int | None = None
|
|
127
174
|
cache_read_tokens: int | None = None
|
|
128
175
|
cache_write_tokens: int | None = None
|
|
176
|
+
extra: dict[str, Any] = {}
|
|
129
177
|
|
|
130
178
|
@property
|
|
131
179
|
def default_duration_seconds(self) -> float:
|