model-library 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- model_library/base/base.py +141 -62
- model_library/base/delegate_only.py +77 -10
- model_library/base/output.py +43 -0
- model_library/base/utils.py +35 -0
- model_library/config/alibaba_models.yaml +49 -57
- model_library/config/all_models.json +353 -120
- model_library/config/anthropic_models.yaml +2 -1
- model_library/config/kimi_models.yaml +30 -3
- model_library/config/mistral_models.yaml +2 -0
- model_library/config/openai_models.yaml +15 -23
- model_library/config/together_models.yaml +2 -0
- model_library/config/xiaomi_models.yaml +43 -0
- model_library/config/zai_models.yaml +27 -3
- model_library/exceptions.py +3 -77
- model_library/providers/ai21labs.py +12 -8
- model_library/providers/alibaba.py +17 -8
- model_library/providers/amazon.py +49 -16
- model_library/providers/anthropic.py +128 -48
- model_library/providers/azure.py +22 -10
- model_library/providers/cohere.py +7 -7
- model_library/providers/deepseek.py +8 -8
- model_library/providers/fireworks.py +7 -8
- model_library/providers/google/batch.py +14 -10
- model_library/providers/google/google.py +57 -30
- model_library/providers/inception.py +7 -7
- model_library/providers/kimi.py +18 -8
- model_library/providers/minimax.py +15 -17
- model_library/providers/mistral.py +20 -8
- model_library/providers/openai.py +99 -22
- model_library/providers/openrouter.py +34 -0
- model_library/providers/perplexity.py +7 -7
- model_library/providers/together.py +7 -8
- model_library/providers/vals.py +12 -6
- model_library/providers/vercel.py +34 -0
- model_library/providers/xai.py +47 -42
- model_library/providers/xiaomi.py +34 -0
- model_library/providers/zai.py +38 -8
- model_library/register_models.py +5 -0
- model_library/registry_utils.py +48 -17
- model_library/retriers/__init__.py +0 -0
- model_library/retriers/backoff.py +73 -0
- model_library/retriers/base.py +225 -0
- model_library/retriers/token.py +427 -0
- model_library/retriers/utils.py +11 -0
- model_library/settings.py +1 -1
- model_library/utils.py +17 -7
- {model_library-0.1.7.dist-info → model_library-0.1.9.dist-info}/METADATA +2 -1
- model_library-0.1.9.dist-info/RECORD +73 -0
- {model_library-0.1.7.dist-info → model_library-0.1.9.dist-info}/WHEEL +1 -1
- model_library-0.1.7.dist-info/RECORD +0 -64
- {model_library-0.1.7.dist-info → model_library-0.1.9.dist-info}/licenses/LICENSE +0 -0
- {model_library-0.1.7.dist-info → model_library-0.1.9.dist-info}/top_level.txt +0 -0
model_library/providers/together.py
CHANGED

@@ -1,18 +1,18 @@
 from typing import Literal
 
+from pydantic import SecretStr
 from typing_extensions import override
 
 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
     DelegateOnly,
     LLMConfig,
     ProviderConfig,
     QueryResultCost,
     QueryResultMetadata,
 )
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults
 
 
 class TogetherConfig(ProviderConfig):

@@ -32,15 +32,14 @@ class TogetherModel(DelegateOnly):
     ):
         super().__init__(model_name, provider, config=config)
         # https://docs.together.ai/docs/openai-api-compatibility
-        self.
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-
-
-
+            delegate_config=DelegateConfig(
+                base_url="https://api.together.xyz/v1/",
+                api_key=SecretStr(model_library_settings.TOGETHER_API_KEY),
             ),
             use_completions=True,
+            delegate_provider="openai",
         )
 
     @override

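Note (illustrative, not part of the package diff): the switch to SecretStr here and in the other providers means the key is masked in any printed or logged config and must be unwrapped explicitly. A minimal pydantic sketch with a made-up key value:

    from pydantic import SecretStr

    key = SecretStr("sk-example")        # illustrative value only
    print(key)                           # **********  (masked in str()/repr())
    print(key.get_secret_value())        # sk-example  (explicit unwrap)
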
model_library/providers/vals.py
CHANGED
@@ -151,13 +151,17 @@ class DummyAIBatchMixin(LLMBatchMixin):
 class DummyAIModel(LLM):
     _client: Redis | None = None
 
-
-
-
-
+    def _get_default_api_key(self) -> str:
+        return model_library_settings.REDIS_URL
+
+    def get_client(self, api_key: str | None = None) -> Redis:
+        if not self.has_client():
+            assert api_key
+            client = redis.from_url(  # pyright: ignore[reportUnknownMemberType]
                 model_library_settings.REDIS_URL, decode_responses=True
             )
-
+            self.assign_client(client)
+        return super().get_client()
 
     def __init__(
         self,

@@ -238,12 +242,14 @@ class DummyAIModel(LLM):
         messages = await self.parse_input(input)
         body: dict[str, Any] = {
             "model": self.model_name,
-            "max_tokens": self.max_tokens,
             "seed": 0,
             "messages": messages,
             "tools": await self.parse_tools(tools),
         }
 
+        if self.max_tokens:
+            body["max_tokens"] = self.max_tokens
+
         if self.supports_temperature:
             if self.temperature is not None:
                 body["temperature"] = self.temperature

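Note (illustrative, not part of the package diff): the reworked client accessors in this release share one lazy-initialization contract: _get_default_api_key() names the credential and get_client(api_key) builds the SDK client once, caches it, then defers to the base class. A condensed sketch of that pattern, assuming has_client, assign_client, and the base get_client behave as they do in the hunks above; SomeClient and SOME_API_KEY are placeholders:

    def _get_default_api_key(self) -> str:
        return model_library_settings.SOME_API_KEY      # hypothetical setting name

    def get_client(self, api_key: str | None = None) -> SomeClient:
        if not self.has_client():
            assert api_key                              # key injected on first use
            self.assign_client(SomeClient(api_key=api_key))
        return super().get_client()                     # returns the cached client
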
model_library/providers/vercel.py
ADDED

@@ -0,0 +1,34 @@
+from typing import Literal
+
+from pydantic import SecretStr
+
+from model_library import model_library_settings
+from model_library.base import (
+    DelegateConfig,
+    DelegateOnly,
+    LLMConfig,
+)
+from model_library.register_models import register_provider
+
+
+@register_provider("vercel")
+class VercelModel(DelegateOnly):
+    def __init__(
+        self,
+        model_name: str,
+        provider: Literal["vercel"] = "vercel",
+        *,
+        config: LLMConfig | None = None,
+    ):
+        super().__init__(model_name, provider, config=config)
+
+        # https://vercel.com/docs/ai-gateway/sdks-and-apis#quick-start
+        self.init_delegate(
+            config=config,
+            delegate_config=DelegateConfig(
+                base_url="https://ai-gateway.vercel.sh/v1",
+                api_key=SecretStr(model_library_settings.VERCEL_API_KEY),
+            ),
+            use_completions=True,
+            delegate_provider="openai",
+        )

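Note (illustrative, not part of the package diff): like the other DelegateOnly providers, the new gateway model is constructed with just a model name and delegates requests to the OpenAI-compatible endpoint configured above; the model name below is a placeholder, not taken from the config files in this release:

    model = VercelModel("example-model")
    # Chat Completions requests go to https://ai-gateway.vercel.sh/v1 using VERCEL_API_KEY
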
model_library/providers/xai.py
CHANGED
@@ -2,7 +2,7 @@ import io
 import logging
 from typing import Any, Literal, Sequence
 
-import
+from pydantic import SecretStr
 from typing_extensions import override
 from xai_sdk import AsyncClient
 from xai_sdk.aio.chat import Chat

@@ -14,6 +14,7 @@ from xai_sdk.proto.v6.chat_pb2 import Message, Tool
 from model_library import model_library_settings
 from model_library.base import (
     LLM,
+    DelegateConfig,
     FileBase,
     FileInput,
     FileWithBase64,

@@ -36,24 +37,26 @@ from model_library.exceptions import (
     MaxOutputTokensExceededError,
     ModelNoOutputError,
     NoMatchingToolCallError,
-    RateLimitException,
 )
 from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults
 
 
 @register_provider("grok")
 class XAIModel(LLM):
-
+    @override
+    def _get_default_api_key(self) -> str:
+        return model_library_settings.XAI_API_KEY
 
     @override
-    def get_client(self) -> AsyncClient:
-        if not
-
-
+    def get_client(self, api_key: str | None = None) -> AsyncClient:
+        if not self.has_client():
+            assert api_key
+            client = AsyncClient(
+                api_key=api_key,
             )
-
+            self.assign_client(client)
+        return super().get_client()
 
     @override
     def __init__(

@@ -73,13 +76,13 @@ class XAIModel(LLM):
             model_name=self.model_name,
             provider=provider,
             config=config,
-
-            api_key=model_library_settings.XAI_API_KEY,
+            delegate_config=DelegateConfig(
                 base_url=(
                     "https://us-west-1.api.x.ai/v1"
                     if "grok-3-mini-reasoning" in self.model_name
                     else "https://api.x.ai/v1"
                 ),
+                api_key=SecretStr(model_library_settings.XAI_API_KEY),
             ),
             use_completions=True,
         )

@@ -210,12 +213,14 @@ class XAIModel(LLM):
             messages.append(system(str(kwargs.pop("system_prompt"))))
 
         body: dict[str, Any] = {
-            "max_tokens": self.max_tokens,
             "model": self.model_name,
             "tools": await self.parse_tools(tools),
             "messages": messages,
         }
 
+        if self.max_tokens:
+            body["max_tokens"] = self.max_tokens
+
         if self.supports_temperature:
             if self.temperature is not None:
                 body["temperature"] = self.temperature

@@ -253,38 +258,35 @@ class XAIModel(LLM):
 
         body = await self.build_body(input, tools=tools, **kwargs)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                        args=tool_call.function.arguments,
-                    )
+        chat: Chat = self.get_client().chat.create(**body)
+
+        latest_response: Response | None = None
+        async for response, _ in chat.stream():
+            latest_response = response
+
+        if not latest_response:
+            raise ModelNoOutputError("Model failed to produce a response")
+
+        tool_calls: list[ToolCall] = []
+        if (
+            latest_response.finish_reason == "REASON_TOOL_CALLS"
+            and latest_response.tool_calls
+        ):
+            for tool_call in latest_response.tool_calls:
+                tool_calls.append(
+                    ToolCall(
+                        id=tool_call.id,
+                        name=tool_call.function.name,
+                        args=tool_call.function.arguments,
                     )
+                )
 
-
-
-
-
-
-
-        except grpc.RpcError as e:
-            raise RateLimitException(e.details())
+        if (
+            latest_response.finish_reason == "REASON_MAX_LEN"
+            and not latest_response.content
+            and not latest_response.reasoning_content
+        ):
+            raise MaxOutputTokensExceededError()
 
         return QueryResult(
             output_text=latest_response.content,

@@ -310,6 +312,9 @@ class XAIModel(LLM):
         tools: list[ToolDefinition] = [],
         **kwargs: object,
     ) -> int:
+        if not input and not history:
+            return 0
+
         string_input = await self.stringify_input(input, history=history, tools=tools)
         self.logger.debug(string_input)
 

model_library/providers/xiaomi.py
ADDED

@@ -0,0 +1,34 @@
+from typing import Literal
+
+from pydantic import SecretStr
+
+from model_library import model_library_settings
+from model_library.base import (
+    DelegateConfig,
+    DelegateOnly,
+    LLMConfig,
+)
+from model_library.register_models import register_provider
+
+
+@register_provider("xiaomi")
+class XiaomiModel(DelegateOnly):
+    def __init__(
+        self,
+        model_name: str,
+        provider: Literal["xiaomi"] = "xiaomi",
+        *,
+        config: LLMConfig | None = None,
+    ):
+        super().__init__(model_name, provider, config=config)
+
+        # https://platform.xiaomimimo.com/#/docs/quick-start/first-api-call
+        self.init_delegate(
+            config=config,
+            delegate_config=DelegateConfig(
+                base_url="https://api.xiaomimimo.com/v1",
+                api_key=SecretStr(model_library_settings.XIAOMI_API_KEY),
+            ),
+            use_completions=True,
+            delegate_provider="openai",
+        )

model_library/providers/zai.py
CHANGED
@@ -1,17 +1,36 @@
-from typing import Literal
+from typing import Any, Literal
+
+from pydantic import SecretStr
+from typing_extensions import override
 
 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
     DelegateOnly,
     LLMConfig,
+    ProviderConfig,
 )
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-
+
+
+class ZAIConfig(ProviderConfig):
+    """Configuration for ZAI (GLM) models.
+
+    Attributes:
+        clear_thinking: When disabled, reasoning content from previous turns is
+            preserved in context. This is useful for multi-turn conversations where
+            you want the model to maintain coherent reasoning across turns.
+            Enabled by default on the standard API endpoint.
+        See: https://docs.z.ai/guides/capabilities/thinking-mode
+    """
+
+    clear_thinking: bool = True
 
 
 @register_provider("zai")
 class ZAIModel(DelegateOnly):
+    provider_config = ZAIConfig()
+
     def __init__(
         self,
         model_name: str,

@@ -21,14 +40,25 @@ class ZAIModel(DelegateOnly):
     ):
         super().__init__(model_name, provider, config=config)
 
+        self.clear_thinking = self.provider_config.clear_thinking
+
         # https://docs.z.ai/guides/develop/openai/python
-        self.
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-
-            api_key=model_library_settings.ZAI_API_KEY,
+            delegate_config=DelegateConfig(
                 base_url="https://open.bigmodel.cn/api/paas/v4/",
+                api_key=SecretStr(model_library_settings.ZAI_API_KEY),
             ),
             use_completions=True,
+            delegate_provider="openai",
         )
+
+    @override
+    def _get_extra_body(self) -> dict[str, Any]:
+        """Build extra body parameters for GLM-specific features."""
+        return {
+            "thinking": {
+                "type": "enabled" if self.reasoning else "disabled",
+                "clear_thinking": self.clear_thinking,
+            }
+        }

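Note (illustrative, not part of the package diff): with the _get_extra_body override above, each GLM request carries a "thinking" block in addition to the standard Chat Completions payload. The shape it produces for a reasoning-enabled model with the default clear_thinking=True:

    extra_body = {
        "thinking": {
            "type": "enabled",        # "disabled" when self.reasoning is falsy
            "clear_thinking": True,   # ZAIConfig default
        }
    }
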
model_library/register_models.py
CHANGED
@@ -170,6 +170,7 @@ class DefaultParameters(BaseModel):
     top_p: float | None = None
     top_k: int | None = None
     reasoning_effort: str | bool | None = None
+    compute_effort: str | bool | None = None
 
 
 class RawModelConfig(BaseModel):

@@ -338,6 +339,10 @@ def _register_models() -> ModelRegistry:
         copy.slug = key.replace("/", "_")
         copy.full_key = key
         copy.alternative_keys = []
+        copy.provider_properties = ProviderProperties.model_validate(
+            provider_properties
+        )
+
         registry[key] = copy
 
     return registry

model_library/registry_utils.py
CHANGED
@@ -1,11 +1,15 @@
+import tiktoken
 from functools import cache
 from pathlib import Path
 from typing import TypedDict
 
-import
-
-
-
+from model_library.base import (
+    LLM,
+    LLMConfig,
+    ProviderConfig,
+    QueryResultCost,
+    QueryResultMetadata,
+)
 from model_library.register_models import (
     CostProperties,
     ModelConfig,
@@ -196,19 +200,46 @@ def get_provider_names() -> list[str]:
 
 
 @cache
-def get_model_names(
-
-
+def get_model_names(
+    provider: str | None = None,
+    include_deprecated: bool = False,
+    include_alt_keys: bool = True,
+) -> list[str]:
+    """
+    Return model names in the registry
+    - provider: Filter by provider name
+    - include_deprecated: Include deprecated models
+    - include_alt_keys: Include alternative keys from the same provider
+    """
+    registry = get_model_registry()
+    alternative_keys_set: set[str] = set()
+
+    if not include_alt_keys:
+        for model in registry.values():
+            for alt_item in model.alternative_keys:
+                alt_key = (
+                    alt_item if isinstance(alt_item, str) else list(alt_item.keys())[0]
+                )
+                if alt_key.split("/")[0] == model.provider_name:
+                    alternative_keys_set.add(alt_key)
+
+    return sorted(
+        [
+            model.full_key
+            for model in get_model_registry().values()
+            if (not provider or model.provider_name.lower() == provider.lower())
+            and (not model.metadata.deprecated or include_deprecated)
+            and model.full_key not in alternative_keys_set
+        ]
+    )
 
 
-
-
-
-
-
-
-        if model.provider_name.lower() == provider_name.lower()
-    ]
+"""
+everything below this comment is included for legacy support of caselaw/corpfin custom models.
+@orestes please remove this as part of the migration to a standard CorpFin harness.
+"""
+
+DEFAULT_CONTEXT_WINDOW = 128_000
 
 
 @cache

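Note (illustrative, not part of the package diff): a usage sketch of the expanded get_model_names signature; the provider string is illustrative:

    # Current (non-deprecated) models for one provider, with same-provider
    # alternative keys collapsed into their canonical entries.
    names = get_model_names(
        provider="openai",
        include_deprecated=False,
        include_alt_keys=False,
    )
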
@@ -233,7 +264,7 @@ def auto_trim_document(
         Trimmed document, or original document if trimming isn't needed
     """
 
-    max_tokens = get_max_document_tokens(model_name)
+    max_tokens = get_max_document_tokens(model_name) or DEFAULT_CONTEXT_WINDOW
 
     encoding = _get_tiktoken_encoder()
     tokens = encoding.encode(document)

@@ -260,5 +291,5 @@ def get_max_document_tokens(model_name: str, output_buffer: int = 10000) -> int:
     # Import here to avoid circular imports
     from model_library.utils import get_context_window_for_model
 
-    context_window = get_context_window_for_model(model_name)
+    context_window = get_context_window_for_model(model_name) or DEFAULT_CONTEXT_WINDOW
     return context_window - output_buffer

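Note (illustrative, not part of the package diff): with the DEFAULT_CONTEXT_WINDOW fallback, a model whose context window cannot be resolved is budgeted like this:

    context_window = get_context_window_for_model(model_name) or 128_000   # lookup came back empty
    max_document_tokens = context_window - 10_000                          # default output_buffer
    # => 118_000 tokens of document survive auto_trim_document
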
model_library/retriers/__init__.py
File without changes

model_library/retriers/backoff.py
ADDED

@@ -0,0 +1,73 @@
+import logging
+from typing import Callable
+
+from model_library.base.base import QueryResult
+from model_library.exceptions import exception_message
+from model_library.retriers.base import BaseRetrier
+from model_library.retriers.utils import jitter
+
+RETRY_MAX_TRIES: int = 20
+RETRY_INITIAL: float = 10.0
+RETRY_EXPO: float = 1.4
+RETRY_MAX_BACKOFF_WAIT: float = 240.0
+
+
+class ExponentialBackoffRetrier(BaseRetrier):
+    """
+    Exponential backoff retry strategy.
+    Uses exponential backoff with jitter for wait times.
+    """
+
+    def __init__(
+        self,
+        logger: logging.Logger,
+        max_tries: int = RETRY_MAX_TRIES,
+        max_time: float | None = None,
+        retry_callback: Callable[[int, Exception | None, float, float], None]
+        | None = None,
+        *,
+        initial: float = RETRY_INITIAL,
+        expo: float = RETRY_EXPO,
+        max_backoff_wait: float = RETRY_MAX_BACKOFF_WAIT,
+    ):
+        super().__init__(
+            strategy="backoff",
+            logger=logger,
+            max_tries=max_tries,
+            max_time=max_time,
+            retry_callback=retry_callback,
+        )
+
+        self.initial = initial
+        self.expo = expo
+        self.max_backoff_wait = max_backoff_wait
+
+    async def _calculate_wait_time(
+        self, attempt: int, exception: Exception | None = None
+    ) -> float:
+        """Calculate exponential backoff wait time with jitter"""
+
+        exponential_wait = self.initial * (self.expo**attempt)
+        capped_wait = min(exponential_wait, self.max_backoff_wait)
+        return jitter(capped_wait)
+
+    async def _on_retry(
+        self, exception: Exception | None, elapsed: float, wait_time: float
+    ) -> None:
+        """Increment attempt counter and log retry attempt"""
+
+        logger_msg = f"[Retry] | {self.strategy} | Attempt: {self.attempts} | Elapsed: {elapsed:.1f}s | Next wait: {wait_time:.1f}s | Exception: {exception_message(exception)} "
+
+        self.logger.warning(logger_msg)
+
+        if self.retry_callback:
+            self.retry_callback(self.attempts, exception, elapsed, wait_time)
+
+    async def _pre_function(self) -> None:
+        return
+
+    async def _post_function(self, result: tuple[QueryResult, float]) -> None:
+        return
+
+    async def validate(self) -> None:
+        return