model-library 0.1.7__tar.gz → 0.1.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {model_library-0.1.7 → model_library-0.1.9}/Makefile +3 -5
- {model_library-0.1.7 → model_library-0.1.9}/PKG-INFO +2 -1
- {model_library-0.1.7 → model_library-0.1.9}/examples/advanced/custom_retrier.py +26 -33
- {model_library-0.1.7 → model_library-0.1.9}/examples/advanced/web_search.py +1 -2
- {model_library-0.1.7 → model_library-0.1.9}/examples/count_tokens.py +7 -4
- model_library-0.1.9/examples/token_retry.py +85 -0
- {model_library-0.1.7 → model_library-0.1.9}/model_library/base/base.py +141 -62
- model_library-0.1.9/model_library/base/delegate_only.py +175 -0
- {model_library-0.1.7 → model_library-0.1.9}/model_library/base/output.py +43 -0
- {model_library-0.1.7 → model_library-0.1.9}/model_library/base/utils.py +35 -0
- {model_library-0.1.7 → model_library-0.1.9}/model_library/config/alibaba_models.yaml +49 -57
- {model_library-0.1.7 → model_library-0.1.9}/model_library/config/all_models.json +353 -120
- {model_library-0.1.7 → model_library-0.1.9}/model_library/config/anthropic_models.yaml +2 -1
- {model_library-0.1.7 → model_library-0.1.9}/model_library/config/kimi_models.yaml +30 -3
- {model_library-0.1.7 → model_library-0.1.9}/model_library/config/mistral_models.yaml +2 -0
- {model_library-0.1.7 → model_library-0.1.9}/model_library/config/openai_models.yaml +15 -23
- {model_library-0.1.7 → model_library-0.1.9}/model_library/config/together_models.yaml +2 -0
- model_library-0.1.9/model_library/config/xiaomi_models.yaml +43 -0
- {model_library-0.1.7 → model_library-0.1.9}/model_library/config/zai_models.yaml +27 -3
- {model_library-0.1.7 → model_library-0.1.9}/model_library/exceptions.py +3 -77
- {model_library-0.1.7 → model_library-0.1.9}/model_library/providers/ai21labs.py +12 -8
- {model_library-0.1.7 → model_library-0.1.9}/model_library/providers/alibaba.py +17 -8
- {model_library-0.1.7 → model_library-0.1.9}/model_library/providers/amazon.py +49 -16
- {model_library-0.1.7 → model_library-0.1.9}/model_library/providers/anthropic.py +128 -48
- {model_library-0.1.7 → model_library-0.1.9}/model_library/providers/azure.py +22 -10
- {model_library-0.1.7 → model_library-0.1.9}/model_library/providers/cohere.py +7 -7
- {model_library-0.1.7 → model_library-0.1.9}/model_library/providers/deepseek.py +8 -8
- {model_library-0.1.7 → model_library-0.1.9}/model_library/providers/fireworks.py +7 -8
- {model_library-0.1.7 → model_library-0.1.9}/model_library/providers/google/batch.py +14 -10
- {model_library-0.1.7 → model_library-0.1.9}/model_library/providers/google/google.py +57 -30
- {model_library-0.1.7 → model_library-0.1.9}/model_library/providers/inception.py +7 -7
- {model_library-0.1.7 → model_library-0.1.9}/model_library/providers/kimi.py +18 -8
- {model_library-0.1.7 → model_library-0.1.9}/model_library/providers/minimax.py +15 -17
- {model_library-0.1.7 → model_library-0.1.9}/model_library/providers/mistral.py +20 -8
- {model_library-0.1.7 → model_library-0.1.9}/model_library/providers/openai.py +99 -22
- model_library-0.1.9/model_library/providers/openrouter.py +34 -0
- {model_library-0.1.7 → model_library-0.1.9}/model_library/providers/perplexity.py +7 -7
- {model_library-0.1.7 → model_library-0.1.9}/model_library/providers/together.py +7 -8
- {model_library-0.1.7 → model_library-0.1.9}/model_library/providers/vals.py +12 -6
- model_library-0.1.9/model_library/providers/vercel.py +34 -0
- {model_library-0.1.7 → model_library-0.1.9}/model_library/providers/xai.py +47 -42
- model_library-0.1.9/model_library/providers/xiaomi.py +34 -0
- model_library-0.1.9/model_library/providers/zai.py +64 -0
- {model_library-0.1.7 → model_library-0.1.9}/model_library/register_models.py +5 -0
- {model_library-0.1.7 → model_library-0.1.9}/model_library/registry_utils.py +48 -17
- model_library-0.1.9/model_library/retriers/backoff.py +73 -0
- model_library-0.1.9/model_library/retriers/base.py +225 -0
- model_library-0.1.9/model_library/retriers/token.py +427 -0
- model_library-0.1.9/model_library/retriers/utils.py +11 -0
- {model_library-0.1.7 → model_library-0.1.9}/model_library/settings.py +1 -1
- {model_library-0.1.7 → model_library-0.1.9}/model_library/utils.py +17 -7
- {model_library-0.1.7 → model_library-0.1.9}/model_library.egg-info/PKG-INFO +2 -1
- {model_library-0.1.7 → model_library-0.1.9}/model_library.egg-info/SOURCES.txt +16 -7
- {model_library-0.1.7 → model_library-0.1.9}/model_library.egg-info/requires.txt +1 -0
- {model_library-0.1.7 → model_library-0.1.9}/pyproject.toml +2 -0
- {model_library-0.1.7 → model_library-0.1.9}/scripts/browse_models.py +2 -2
- {model_library-0.1.7 → model_library-0.1.9}/scripts/run_models.py +13 -13
- model_library-0.1.9/tests/README.md +75 -0
- model_library-0.1.9/tests/conftest.py +85 -0
- model_library-0.1.9/tests/integration/conftest.py +8 -0
- model_library-0.1.9/tests/integration/test_basic.py +15 -0
- {model_library-0.1.7 → model_library-0.1.9}/tests/integration/test_batch.py +4 -12
- model_library-0.1.9/tests/integration/test_files.py +38 -0
- model_library-0.1.9/tests/integration/test_long_problem.py +24 -0
- {model_library-0.1.7 → model_library-0.1.9}/tests/integration/test_reasoning.py +3 -12
- model_library-0.1.9/tests/integration/test_retry.py +28 -0
- {model_library-0.1.7 → model_library-0.1.9}/tests/integration/test_streaming.py +13 -30
- {model_library-0.1.7 → model_library-0.1.9}/tests/integration/test_structured_output.py +0 -6
- {model_library-0.1.7 → model_library-0.1.9}/tests/integration/test_tools.py +1 -18
- model_library-0.1.9/tests/test_helpers.py +183 -0
- model_library-0.1.9/tests/unit/conftest.py +102 -0
- model_library-0.1.9/tests/unit/test_batch.py +353 -0
- model_library-0.1.9/tests/unit/test_count_tokens.py +42 -0
- model_library-0.1.9/tests/unit/test_get_client.py +188 -0
- model_library-0.1.9/tests/unit/test_openai_config.py +32 -0
- {model_library-0.1.7 → model_library-0.1.9}/tests/unit/test_prompt_caching.py +31 -81
- {model_library-0.1.7 → model_library-0.1.9}/tests/unit/test_query_logger.py +18 -0
- {model_library-0.1.7 → model_library-0.1.9}/tests/unit/test_registry.py +7 -7
- {model_library-0.1.7 → model_library-0.1.9}/tests/unit/test_result_metadata.py +28 -18
- {model_library-0.1.7 → model_library-0.1.9}/tests/unit/test_retry.py +44 -41
- model_library-0.1.9/tests/unit/test_token_retry.py +405 -0
- {model_library-0.1.7 → model_library-0.1.9}/tests/unit/test_tools.py +4 -6
- model_library-0.1.9/tests/unit/test_utils.py +15 -0
- {model_library-0.1.7 → model_library-0.1.9}/uv.lock +28 -4
- model_library-0.1.7/model_library/base/delegate_only.py +0 -108
- model_library-0.1.7/model_library/providers/zai.py +0 -34
- model_library-0.1.7/tests/README.md +0 -87
- model_library-0.1.7/tests/conftest.py +0 -275
- model_library-0.1.7/tests/integration/conftest.py +0 -8
- model_library-0.1.7/tests/integration/test_completion.py +0 -41
- model_library-0.1.7/tests/integration/test_files.py +0 -279
- model_library-0.1.7/tests/integration/test_retry.py +0 -95
- model_library-0.1.7/tests/test_helpers.py +0 -89
- model_library-0.1.7/tests/unit/conftest.py +0 -53
- model_library-0.1.7/tests/unit/providers/test_fireworks_provider.py +0 -48
- model_library-0.1.7/tests/unit/providers/test_google_provider.py +0 -58
- model_library-0.1.7/tests/unit/test_batch.py +0 -236
- model_library-0.1.7/tests/unit/test_context_window.py +0 -45
- model_library-0.1.7/tests/unit/test_count_tokens.py +0 -67
- model_library-0.1.7/tests/unit/test_perplexity_provider.py +0 -73
- model_library-0.1.7/tests/unit/test_streaming.py +0 -83
- {model_library-0.1.7 → model_library-0.1.9}/.gitattributes +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/.github/workflows/publish.yml +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/.github/workflows/style.yaml +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/.github/workflows/test.yaml +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/.github/workflows/typecheck.yml +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/.gitignore +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/LICENSE +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/README.md +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/examples/README.md +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/examples/advanced/batch.py +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/examples/advanced/deep_research.py +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/examples/advanced/stress.py +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/examples/advanced/structured_output.py +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/examples/basics.py +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/examples/data/files.py +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/examples/data/images.py +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/examples/embeddings.py +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/examples/files.py +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/examples/images.py +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/examples/prompt_caching.py +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/examples/setup.py +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/examples/tool_calls.py +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/model_library/__init__.py +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/model_library/base/__init__.py +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/model_library/base/batch.py +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/model_library/base/input.py +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/model_library/config/README.md +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/model_library/config/ai21labs_models.yaml +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/model_library/config/amazon_models.yaml +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/model_library/config/cohere_models.yaml +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/model_library/config/deepseek_models.yaml +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/model_library/config/dummy_model.yaml +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/model_library/config/fireworks_models.yaml +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/model_library/config/google_models.yaml +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/model_library/config/inception_models.yaml +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/model_library/config/minimax_models.yaml +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/model_library/config/perplexity_models.yaml +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/model_library/config/xai_models.yaml +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/model_library/file_utils.py +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/model_library/logging.py +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/model_library/model_utils.py +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/model_library/providers/__init__.py +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/model_library/providers/google/__init__.py +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/model_library/py.typed +0 -0
- {model_library-0.1.7/tests → model_library-0.1.9/model_library/retriers}/__init__.py +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/model_library.egg-info/dependency_links.txt +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/model_library.egg-info/top_level.txt +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/scripts/config.py +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/scripts/publish.py +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/setup.cfg +0 -0
- {model_library-0.1.7/tests/integration → model_library-0.1.9/tests}/__init__.py +0 -0
- {model_library-0.1.7/tests/unit → model_library-0.1.9/tests/integration}/__init__.py +0 -0
- {model_library-0.1.7/tests/unit/providers → model_library-0.1.9/tests/unit}/__init__.py +0 -0
- {model_library-0.1.7 → model_library-0.1.9}/tests/unit/test_deep_research.py +0 -0
{model_library-0.1.7 → model_library-0.1.9}/Makefile

```diff
@@ -29,15 +29,13 @@ venv_check:
		exit 1; \
	fi
 
+DIR ?= tests/
test: venv_check
	@echo "Running unit tests..."
-	@uv run pytest
+	@uv run pytest $(DIR) -m unit -v -n 4 --dist loadscope --model=$(MODEL)
test-integration: venv_check
	@echo "Running integration tests (requires API keys)..."
-	@uv run pytest
-test-all: venv_check
-	@echo "Running all tests..."
-	@uv run pytest
+	@uv run pytest $(DIR) -m integration -v -n 4 --dist loadscope --model=$(MODEL)
 
format: venv_check
	@uv run ruff format .
```
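With the reworked targets, unit and integration runs share one parametrized entry point, e.g. `make test DIR=tests/unit MODEL=openai/gpt-4o` (model id illustrative): `$(MODEL)` is forwarded to the suite's `--model` pytest option, the run is split across 4 xdist workers, and the separate `test-all` target is gone.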
{model_library-0.1.7 → model_library-0.1.9}/PKG-INFO

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: model-library
-Version: 0.1.7
+Version: 0.1.9
 Summary: Model Library for vals.ai
 Author-email: "Vals AI, Inc." <contact@vals.ai>
 License: MIT
```
```diff
@@ -23,6 +23,7 @@ Requires-Dist: ai21<5.0,>=4.3.0
 Requires-Dist: boto3<2.0,>=1.38.27
 Requires-Dist: google-genai[aiohttp]>=1.51.0
 Requires-Dist: google-cloud-storage>=1.26.0
+Requires-Dist: pytest-xdist>=3.8.0
 Dynamic: license-file
 
 # Model Library
```
{model_library-0.1.7 → model_library-0.1.9}/examples/advanced/custom_retrier.py

```diff
@@ -4,15 +4,15 @@ from typing import Any, Awaitable, Callable
 
 from model_library.base import (
     LLM,
-    RetrierType,
     TextInput,
 )
 from model_library.exceptions import (
     BackoffRetryException,
     RetryException,
-    retry_llm_call,
 )
 from model_library.registry_utils import get_registry_model
+from model_library.retriers.backoff import ExponentialBackoffRetrier
+from model_library.retriers.base import retry_decorator
 
 from ..setup import console_log, setup
 
```
```diff
@@ -36,34 +36,28 @@ def is_context_length_error(error_str: str) -> bool:
     )
 
 
-def
+def decorator(
+    func: Callable[..., Awaitable[Any]],
+) -> Callable[..., Awaitable[Any]]:
     """
-
-    Custom retries takes in a logger. It replaces the backoff retrier. Immediate retries still function.
+    Decorator must return wrapper function
     """
 
-
-    func: Callable[..., Awaitable[Any]],
-) -> Callable[..., Awaitable[Any]]:
-    """
-    Decorator must return wrapper function
-    """
+    logger = logging.getLogger("llm.decorator")
 
-
-
-
-
-
-
-
-
-
+    async def wrapper(*args: Any, **kwargs: Any) -> Any:
+        try:
+            return await func(*args, **kwargs)
+        except Exception as e:
+            # detect context length errors and retry with backoff
+            if is_context_length_error(str(e).lower()):
+                logger.warning(f"Context length error detected: {e}")
+                # for simplicity, we don't actually retry
+                raise BackoffRetryException(f"Context length error: {e}")
 
-
+            raise
 
-
-
-    return decorator
+    return wrapper
 
 
 async def custom_retrier_context(model: LLM):
```
```diff
@@ -73,7 +67,7 @@ async def custom_retrier_context(model: LLM):
 
     console_log("\n--- Custom Retrier ---\n")
 
-    model.custom_retrier =
+    model.custom_retrier = decorator
     try:
         await model.query(
             [
```
```diff
@@ -102,15 +96,14 @@ async def custom_retrier_callback(model: LLM):
         if tries > 1:
             raise Exception("Reached retry 2")
 
-
-
-
-
-
-
-    )
+    retrier = ExponentialBackoffRetrier(
+        logger=model.logger,
+        max_tries=3,
+        max_time=500,
+        retry_callback=callback,
+    )
 
-    model.custom_retrier =
+    model.custom_retrier = retry_decorator(retrier)
 
     def simulate_retry(*args: object, **kwargs: object):
         raise RetryException("Simulated failure")
```
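Pulled out of the hunks above, the 0.1.9 retrier surface used by this example is compact. A minimal sketch of the same wiring (the model id is illustrative, and the `callback` signature is assumed from the `tries` counter in the hunk, not confirmed by it):

```python
import asyncio

from model_library.base import TextInput
from model_library.registry_utils import get_registry_model
from model_library.retriers.backoff import ExponentialBackoffRetrier
from model_library.retriers.base import retry_decorator


async def main() -> None:
    model = get_registry_model("google/gemini-2.5-flash")

    def callback(tries: int) -> None:
        # assumption: the retrier reports the attempt number to the callback
        print(f"retrying, attempt {tries}")

    # 0.1.9 replaces the old retry_llm_call/RetrierType plumbing in
    # model_library.exceptions with explicit retrier objects
    retrier = ExponentialBackoffRetrier(
        logger=model.logger,
        max_tries=3,
        max_time=500,
        retry_callback=callback,
    )
    model.custom_retrier = retry_decorator(retrier)

    result = await model.query([TextInput(text="Hello!")])
    print(result)


asyncio.run(main())
```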
{model_library-0.1.7 → model_library-0.1.9}/examples/advanced/web_search.py

```diff
@@ -1,8 +1,7 @@
 import asyncio
 from typing import Any, cast
 
-from model_library.base import LLM, ToolDefinition
-from model_library.base.output import QueryResult
+from model_library.base import LLM, QueryResult, ToolDefinition
 from model_library.registry_utils import get_registry_model
 
 from ..setup import console_log, setup
```
{model_library-0.1.7 → model_library-0.1.9}/examples/count_tokens.py

```diff
@@ -51,22 +51,25 @@
     system_prompt = "You must make exactly 0 or 1 tool calls per answer. You must not make more than 1 tool call per answer."
     user_prompt = "What is the weather in San Francisco right now?"
 
+    input = [TextInput(text=user_prompt)]
+
     predicted_tokens = await model.count_tokens(
-
+        input,
         tools=tools,
         system_prompt=system_prompt,
     )
 
     response: QueryResult = await model.query(
-
+        input,
         tools=tools,
         system_prompt=system_prompt,
     )
+    metadata = response.metadata
 
-    actual_tokens =
+    actual_tokens = metadata.total_input_tokens
 
     console_log(f"Predicted Token Count: {predicted_tokens}")
-    console_log(f"Actual Token Count: {actual_tokens}
+    console_log(f"Actual Token Count: {actual_tokens}")
 
 
 async def main():
```
model_library-0.1.9/examples/token_retry.py (new file)

```diff
@@ -0,0 +1,85 @@
+import asyncio
+import time
+from logging import DEBUG
+from typing import Any, Coroutine
+
+from tqdm import tqdm
+
+from model_library.base import LLM, TextInput, TokenRetryParams
+from model_library.logging import set_logging
+from model_library.registry_utils import get_registry_model
+from model_library.retriers.token import set_redis_client
+
+from .setup import console_log, setup
+
+
+async def token_retry(model: LLM):
+    console_log("\n--- Token Retry ---\n")
+    await model.query(
+        [
+            TextInput(
+                # text="What is QSBS? Explain your thinking in detail and make it concise"
+                text="dwadwadwadawdLong argument of cats vs dogs" * 5000
+                + "Ignore the previous junk, tell me a very long story about the cats and the dogs. And yes, I do want an actual story, I just have no choice but to include the junk before, believe me."
+            )
+        ],
+    )
+
+
+async def main():
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Run basic examples with a model")
+    parser.add_argument(
+        "model",
+        nargs="?",
+        default="google/gemini-2.5-flash",
+        type=str,
+        help="Model endpoint (default: google/gemini-2.5-flash)",
+    )
+    args = parser.parse_args()
+
+    set_logging(level=DEBUG)
+
+    model = get_registry_model(args.model)
+    model.logger.info(model)
+
+    limit = await model.get_rate_limit()
+    model.logger.info(limit)
+
+    import redis.asyncio as redis
+
+    # NOTE: make sure you have redis running locally
+    # docker run -d -p 6379:6379 redis:latest
+
+    redis_client = redis.Redis(
+        host="localhost", port=6379, decode_responses=True, max_connections=None
+    )
+    set_redis_client(redis_client)
+
+    provider_tokenizer_input_modifier = 1
+    dataset_output_modifier = 0.001
+
+    limit = 100_000
+    await model.init_token_retry(
+        token_retry_params=TokenRetryParams(
+            input_modifier=provider_tokenizer_input_modifier,
+            output_modifier=dataset_output_modifier,
+            use_dynamic_estimate=True,
+            limit=limit,
+        )
+    )
+    tasks: list[Coroutine[Any, Any, None]] = []
+    for _ in range(200):
+        tasks.append(token_retry(model))
+
+    start = time.time()
+    for coro in tqdm(asyncio.as_completed(tasks), total=len(tasks)):
+        await coro
+    finish = time.time() - start
+    console_log(f"Finished in {finish:.1f}s")
+
+
+if __name__ == "__main__":
+    setup()
+    asyncio.run(main())
```
{model_library-0.1.7 → model_library-0.1.9}/model_library/base/base.py

```diff
@@ -1,9 +1,12 @@
+import hashlib
 import io
 import logging
+import threading
 import time
 import uuid
 from abc import ABC, abstractmethod
 from collections.abc import Awaitable
+from math import ceil
 from pprint import pformat
 from typing import (
     Any,
```
```diff
@@ -14,7 +17,7 @@ from typing import (
 )
 
 import tiktoken
-from pydantic import model_serializer
+from pydantic import SecretStr, model_serializer
 from pydantic.main import BaseModel
 from tiktoken.core import Encoding
 from typing_extensions import override
```
```diff
@@ -34,15 +37,15 @@ from model_library.base.output import (
     QueryResult,
     QueryResultCost,
     QueryResultMetadata,
+    RateLimit,
 )
 from model_library.base.utils import (
     get_pretty_input_types,
     serialize_for_tokenizing,
 )
-from model_library.
-
-
-)
+from model_library.retriers.backoff import ExponentialBackoffRetrier
+from model_library.retriers.base import BaseRetrier, R, RetrierType, retry_decorator
+from model_library.retriers.token import TokenRetrier
 from model_library.utils import truncate_str
 
 PydanticT = TypeVar("PydanticT", bound=BaseModel)
```
```diff
@@ -56,16 +59,24 @@ class ProviderConfig(BaseModel):
         return self.__dict__
 
 
-
+class TokenRetryParams(BaseModel):
+    input_modifier: float
+    output_modifier: float
+
+    use_dynamic_estimate: bool = True
+
+    limit: int
+    limit_refresh_seconds: Literal[60] = 60
 
 
 class LLMConfig(BaseModel):
-    max_tokens: int =
+    max_tokens: int | None = None
     temperature: float | None = None
     top_p: float | None = None
     top_k: int | None = None
     reasoning: bool = False
     reasoning_effort: str | bool | None = None
+    compute_effort: str | None = None
     supports_images: bool = False
     supports_files: bool = False
     supports_videos: bool = False
```
```diff
@@ -75,11 +86,18 @@ class LLMConfig(BaseModel):
     native: bool = True
     provider_config: ProviderConfig | None = None
     registry_key: str | None = None
+    custom_api_key: SecretStr | None = None
+
 
+class DelegateConfig(BaseModel):
+    base_url: str
+    api_key: SecretStr
 
-RetrierType = Callable[[Callable[..., Awaitable[Any]]], Callable[..., Awaitable[Any]]]
 
-
+# shared across all subclasses and instances
+# hash(provider + api_key) -> client
+client_registry_lock = threading.Lock()
+client_registry: dict[tuple[str, str], Any] = {}
 
 
 class LLM(ABC):
```
```diff
@@ -88,6 +106,34 @@ class LLM(ABC):
     LLM call errors should be raised as exceptions
     """
 
+    @abstractmethod
+    def get_client(self, api_key: str | None = None) -> Any:
+        """
+        Returns the cached instance of the appropriate SDK client.
+        Subclasses should implement this method and:
+        - if api_key is provided, initialize their client and call assign_client(client).
+        - else return super().get_client()
+        """
+        global client_registry
+        return client_registry[self._client_registry_key]
+
+    def assign_client(self, client: object) -> None:
+        """Thread-safe assignment to the client registry"""
+        global client_registry
+
+        if self._client_registry_key not in client_registry:
+            with client_registry_lock:
+                if self._client_registry_key not in client_registry:
+                    client_registry[self._client_registry_key] = client
+
+    def has_client(self) -> bool:
+        return self._client_registry_key in client_registry
+
+    @abstractmethod
+    def _get_default_api_key(self) -> str:
+        """Return the api key from model_library.settings"""
+        ...
+
     def __init__(
         self,
         model_name: str,
```
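The `get_client` docstring above fixes the contract for providers. A minimal sketch of a subclass following it; the provider and `FakeClient` type are hypothetical, and the other abstract members of `LLM` (`query`, `count_tokens`, and so on) are omitted:

```python
from typing import Any

from model_library.base import LLM


class FakeClient:
    """Hypothetical stand-in for a provider SDK client."""

    def __init__(self, api_key: str) -> None:
        self.api_key = api_key


class FakeProviderLLM(LLM):
    # remaining abstract methods omitted for brevity

    def get_client(self, api_key: str | None = None) -> Any:
        if api_key is not None and not self.has_client():
            # first call, made from LLM.__init__ with the raw key:
            # build the SDK client and cache it under (provider, sha256(key))
            self.assign_client(FakeClient(api_key=api_key))
        # subsequent calls return the shared cached client
        return super().get_client()

    def _get_default_api_key(self) -> str:
        # real providers read this from model_library.settings
        return "sk-fake-key"
```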
```diff
@@ -103,13 +149,14 @@
         config = config or LLMConfig()
         self._registry_key = config.registry_key
 
-        self.max_tokens: int = config.max_tokens
+        self.max_tokens: int | None = config.max_tokens
         self.temperature: float | None = config.temperature
         self.top_p: float | None = config.top_p
         self.top_k: int | None = config.top_k
 
         self.reasoning: bool = config.reasoning
         self.reasoning_effort: str | bool | None = config.reasoning_effort
+        self.compute_effort: str | None = config.compute_effort
 
         self.supports_files: bool = config.supports_files
         self.supports_videos: bool = config.supports_videos
```
```diff
@@ -131,21 +178,33 @@
         self.logger: logging.Logger = logging.getLogger(
             f"llm.{provider}.{model_name}<instance={self.instance_id}>"
         )
-        self.custom_retrier:
+        self.custom_retrier: RetrierType | None = None
+
+        self.token_retry_params = None
+        # set _client_registry_key after initializing delegate
+        if not self.native:
+            return
+
+        if config.custom_api_key:
+            raw_key = config.custom_api_key.get_secret_value()
+        else:
+            raw_key = self._get_default_api_key()
+
+        key_hash = hashlib.sha256(raw_key.encode()).hexdigest()
+        self._client_registry_key = (self.provider, key_hash)
+        self._client_registry_key_model_specific = (
+            f"{self.provider}.{self.model_name}",
+            key_hash,
+        )
+        self.get_client(api_key=raw_key)
 
     @override
     def __repr__(self):
         attrs = vars(self).copy()
         attrs.pop("logger", None)
         attrs.pop("custom_retrier", None)
-        attrs.pop("_key", None)
         return f"{self.__class__.__name__}(\n{pformat(attrs, indent=2, sort_dicts=False)}\n)"
 
-    @abstractmethod
-    def get_client(self) -> object:
-        """Return the instance of the appropriate SDK client."""
-        ...
-
     @staticmethod
     async def timer_wrapper(func: Callable[[], Awaitable[R]]) -> tuple[R, float]:
         """
```
```diff
@@ -155,43 +214,6 @@
         result = await func()
         return result, round(time.perf_counter() - start, 4)
 
-    @staticmethod
-    async def immediate_retry_wrapper(
-        func: Callable[[], Awaitable[R]],
-        logger: logging.Logger,
-    ) -> R:
-        """
-        Retry the query immediately
-        """
-        MAX_IMMEDIATE_RETRIES = 10
-        retries = 0
-        while True:
-            try:
-                return await func()
-            except ImmediateRetryException as e:
-                if retries >= MAX_IMMEDIATE_RETRIES:
-                    logger.error(f"Query reached max immediate retries {retries}: {e}")
-                    raise Exception(
-                        f"Query reached max immediate retries {retries}: {e}"
-                    ) from e
-                retries += 1
-
-                logger.warning(
-                    f"Query retried immediately {retries}/{MAX_IMMEDIATE_RETRIES}: {e}"
-                )
-
-    @staticmethod
-    async def backoff_retry_wrapper(
-        func: Callable[..., Awaitable[R]],
-        backoff_retrier: RetrierType | None,
-    ) -> R:
-        """
-        Retry the query with backoff
-        """
-        if not backoff_retrier:
-            return await func()
-        return await backoff_retrier(func)()
-
     async def delegate_query(
         self,
         input: Sequence[InputItem],
```
```diff
@@ -276,15 +298,38 @@
             return await LLM.timer_wrapper(query_func)
 
         async def immediate_retry() -> tuple[QueryResult, float]:
-            return await
-
-        async def
-
-
-
-
+            return await BaseRetrier.immediate_retry_wrapper(timed_query, query_logger)
+
+        async def default_retry() -> tuple[QueryResult, float]:
+            if self.token_retry_params:
+                (
+                    estimate_input_tokens,
+                    estimate_output_tokens,
+                ) = await self.estimate_query_tokens(
+                    input,
+                    tools=tools,
+                    **kwargs,
+                )
+                retrier = TokenRetrier(
+                    logger=query_logger,
+                    client_registry_key=self._client_registry_key_model_specific,
+                    estimate_input_tokens=estimate_input_tokens,
+                    estimate_output_tokens=estimate_output_tokens,
+                    dynamic_estimate_instance_id=self.instance_id
+                    if self.token_retry_params.use_dynamic_estimate
+                    else None,
+                )
+            else:
+                retrier = ExponentialBackoffRetrier(logger=query_logger)
+            return await retry_decorator(retrier)(immediate_retry)()
+
+        run_with_retry = (
+            default_retry
+            if not self.custom_retrier
+            else self.custom_retrier(immediate_retry)
+        )
 
-        output, duration = await
+        output, duration = await run_with_retry()
         output.metadata.duration_seconds = duration
         output.metadata.cost = await self._calculate_cost(output.metadata)
 
```
```diff
@@ -293,6 +338,16 @@
 
         return output
 
+    async def init_token_retry(self, token_retry_params: TokenRetryParams) -> None:
+        self.token_retry_params = token_retry_params
+        await TokenRetrier.init_remaining_tokens(
+            client_registry_key=self._client_registry_key_model_specific,
+            limit=self.token_retry_params.limit,
+            limit_refresh_seconds=self.token_retry_params.limit_refresh_seconds,
+            get_rate_limit_func=self.get_rate_limit,
+            logger=self.logger,
+        )
+
     async def _calculate_cost(
         self,
         metadata: QueryResultMetadata,
```
```diff
@@ -438,6 +493,30 @@
         """Upload a file to the model provider"""
         ...
 
+    async def get_rate_limit(self) -> RateLimit | None:
+        """Get the rate limit for the model provider"""
+        return None
+
+    async def estimate_query_tokens(
+        self,
+        input: Sequence[InputItem],
+        *,
+        tools: list[ToolDefinition] = [],
+        **kwargs: object,
+    ) -> tuple[int, int]:
+        """Pessimistically estimate the number of tokens required for a query"""
+        assert self.token_retry_params
+
+        # TODO: when passing in images and files, we really need to take that into account when calculating the output tokens!!
+
+        input_tokens = (
+            await self.count_tokens(input, history=[], tools=tools, **kwargs)
+            * self.token_retry_params.input_modifier
+        )
+
+        output_tokens = input_tokens * self.token_retry_params.output_modifier
+        return ceil(input_tokens), ceil(output_tokens)
+
     async def get_encoding(self) -> Encoding:
         """Get the appropriate tokenizer"""
 
```
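The estimate itself is plain arithmetic. A worked sketch using the modifiers from examples/token_retry.py (the 52,000-token count is an invented figure):

```python
from math import ceil

counted = 52_000            # hypothetical count_tokens() result
input_modifier = 1.0        # trust the provider tokenizer as-is
output_modifier = 0.001     # this workload expects tiny outputs

input_tokens = counted * input_modifier          # 52000.0
output_tokens = input_tokens * output_modifier   # 52.0

# these ceil()ed, pessimistic estimates feed TokenRetrier, which tracks a
# shared per-model token budget (limit per limit_refresh_seconds) in Redis
print(ceil(input_tokens), ceil(output_tokens))   # 52000 52
```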