model-library 0.1.7__tar.gz → 0.1.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {model_library-0.1.7 → model_library-0.1.8}/Makefile +3 -5
- {model_library-0.1.7 → model_library-0.1.8}/PKG-INFO +2 -1
- {model_library-0.1.7 → model_library-0.1.8}/examples/advanced/custom_retrier.py +26 -33
- {model_library-0.1.7 → model_library-0.1.8}/examples/advanced/web_search.py +1 -2
- {model_library-0.1.7 → model_library-0.1.8}/examples/count_tokens.py +7 -4
- model_library-0.1.8/examples/token_retry.py +85 -0
- {model_library-0.1.7 → model_library-0.1.8}/model_library/base/base.py +139 -62
- model_library-0.1.8/model_library/base/delegate_only.py +175 -0
- {model_library-0.1.7 → model_library-0.1.8}/model_library/base/output.py +43 -0
- {model_library-0.1.7 → model_library-0.1.8}/model_library/base/utils.py +35 -0
- {model_library-0.1.7 → model_library-0.1.8}/model_library/config/alibaba_models.yaml +44 -57
- {model_library-0.1.7 → model_library-0.1.8}/model_library/config/all_models.json +253 -126
- {model_library-0.1.7 → model_library-0.1.8}/model_library/config/kimi_models.yaml +30 -3
- {model_library-0.1.7 → model_library-0.1.8}/model_library/config/openai_models.yaml +15 -23
- {model_library-0.1.7 → model_library-0.1.8}/model_library/config/zai_models.yaml +24 -3
- {model_library-0.1.7 → model_library-0.1.8}/model_library/exceptions.py +3 -77
- {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/ai21labs.py +12 -8
- {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/alibaba.py +17 -8
- {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/amazon.py +49 -16
- {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/anthropic.py +93 -40
- {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/azure.py +22 -10
- {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/cohere.py +7 -7
- {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/deepseek.py +8 -8
- {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/fireworks.py +7 -8
- {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/google/batch.py +14 -10
- {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/google/google.py +48 -29
- {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/inception.py +7 -7
- {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/kimi.py +18 -8
- {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/minimax.py +15 -17
- {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/mistral.py +20 -8
- {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/openai.py +99 -22
- model_library-0.1.8/model_library/providers/openrouter.py +34 -0
- {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/perplexity.py +7 -7
- {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/together.py +7 -8
- {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/vals.py +12 -6
- {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/xai.py +47 -42
- model_library-0.1.8/model_library/providers/zai.py +64 -0
- {model_library-0.1.7 → model_library-0.1.8}/model_library/registry_utils.py +39 -15
- model_library-0.1.8/model_library/retriers/backoff.py +73 -0
- model_library-0.1.8/model_library/retriers/base.py +225 -0
- model_library-0.1.8/model_library/retriers/token.py +427 -0
- model_library-0.1.8/model_library/retriers/utils.py +11 -0
- {model_library-0.1.7 → model_library-0.1.8}/model_library/settings.py +1 -1
- {model_library-0.1.7 → model_library-0.1.8}/model_library/utils.py +13 -0
- {model_library-0.1.7 → model_library-0.1.8}/model_library.egg-info/PKG-INFO +2 -1
- {model_library-0.1.7 → model_library-0.1.8}/model_library.egg-info/SOURCES.txt +13 -8
- {model_library-0.1.7 → model_library-0.1.8}/model_library.egg-info/requires.txt +1 -0
- {model_library-0.1.7 → model_library-0.1.8}/pyproject.toml +2 -0
- {model_library-0.1.7 → model_library-0.1.8}/scripts/browse_models.py +2 -2
- {model_library-0.1.7 → model_library-0.1.8}/scripts/run_models.py +13 -13
- model_library-0.1.8/tests/README.md +75 -0
- model_library-0.1.8/tests/conftest.py +85 -0
- model_library-0.1.8/tests/integration/conftest.py +8 -0
- model_library-0.1.8/tests/integration/test_basic.py +15 -0
- {model_library-0.1.7 → model_library-0.1.8}/tests/integration/test_batch.py +4 -12
- model_library-0.1.8/tests/integration/test_files.py +38 -0
- model_library-0.1.8/tests/integration/test_long_problem.py +24 -0
- {model_library-0.1.7 → model_library-0.1.8}/tests/integration/test_reasoning.py +3 -12
- model_library-0.1.8/tests/integration/test_retry.py +28 -0
- {model_library-0.1.7 → model_library-0.1.8}/tests/integration/test_streaming.py +13 -30
- {model_library-0.1.7 → model_library-0.1.8}/tests/integration/test_structured_output.py +0 -6
- {model_library-0.1.7 → model_library-0.1.8}/tests/integration/test_tools.py +1 -18
- model_library-0.1.8/tests/test_helpers.py +183 -0
- model_library-0.1.8/tests/unit/conftest.py +102 -0
- model_library-0.1.8/tests/unit/test_batch.py +353 -0
- model_library-0.1.8/tests/unit/test_count_tokens.py +42 -0
- model_library-0.1.8/tests/unit/test_get_client.py +188 -0
- model_library-0.1.8/tests/unit/test_openai_config.py +32 -0
- {model_library-0.1.7 → model_library-0.1.8}/tests/unit/test_prompt_caching.py +31 -81
- {model_library-0.1.7 → model_library-0.1.8}/tests/unit/test_query_logger.py +18 -0
- {model_library-0.1.7 → model_library-0.1.8}/tests/unit/test_registry.py +7 -7
- {model_library-0.1.7 → model_library-0.1.8}/tests/unit/test_result_metadata.py +28 -18
- {model_library-0.1.7 → model_library-0.1.8}/tests/unit/test_retry.py +44 -41
- model_library-0.1.8/tests/unit/test_token_retry.py +405 -0
- {model_library-0.1.7 → model_library-0.1.8}/tests/unit/test_tools.py +4 -6
- {model_library-0.1.7 → model_library-0.1.8}/uv.lock +25 -1
- model_library-0.1.7/model_library/base/delegate_only.py +0 -108
- model_library-0.1.7/model_library/providers/zai.py +0 -34
- model_library-0.1.7/tests/README.md +0 -87
- model_library-0.1.7/tests/conftest.py +0 -275
- model_library-0.1.7/tests/integration/conftest.py +0 -8
- model_library-0.1.7/tests/integration/test_completion.py +0 -41
- model_library-0.1.7/tests/integration/test_files.py +0 -279
- model_library-0.1.7/tests/integration/test_retry.py +0 -95
- model_library-0.1.7/tests/test_helpers.py +0 -89
- model_library-0.1.7/tests/unit/conftest.py +0 -53
- model_library-0.1.7/tests/unit/providers/test_fireworks_provider.py +0 -48
- model_library-0.1.7/tests/unit/providers/test_google_provider.py +0 -58
- model_library-0.1.7/tests/unit/test_batch.py +0 -236
- model_library-0.1.7/tests/unit/test_context_window.py +0 -45
- model_library-0.1.7/tests/unit/test_count_tokens.py +0 -67
- model_library-0.1.7/tests/unit/test_perplexity_provider.py +0 -73
- model_library-0.1.7/tests/unit/test_streaming.py +0 -83
- {model_library-0.1.7 → model_library-0.1.8}/.gitattributes +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/.github/workflows/publish.yml +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/.github/workflows/style.yaml +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/.github/workflows/test.yaml +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/.github/workflows/typecheck.yml +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/.gitignore +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/LICENSE +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/README.md +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/examples/README.md +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/examples/advanced/batch.py +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/examples/advanced/deep_research.py +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/examples/advanced/stress.py +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/examples/advanced/structured_output.py +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/examples/basics.py +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/examples/data/files.py +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/examples/data/images.py +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/examples/embeddings.py +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/examples/files.py +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/examples/images.py +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/examples/prompt_caching.py +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/examples/setup.py +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/examples/tool_calls.py +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/model_library/__init__.py +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/model_library/base/__init__.py +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/model_library/base/batch.py +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/model_library/base/input.py +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/model_library/config/README.md +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/model_library/config/ai21labs_models.yaml +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/model_library/config/amazon_models.yaml +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/model_library/config/anthropic_models.yaml +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/model_library/config/cohere_models.yaml +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/model_library/config/deepseek_models.yaml +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/model_library/config/dummy_model.yaml +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/model_library/config/fireworks_models.yaml +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/model_library/config/google_models.yaml +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/model_library/config/inception_models.yaml +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/model_library/config/minimax_models.yaml +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/model_library/config/mistral_models.yaml +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/model_library/config/perplexity_models.yaml +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/model_library/config/together_models.yaml +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/model_library/config/xai_models.yaml +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/model_library/file_utils.py +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/model_library/logging.py +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/model_library/model_utils.py +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/__init__.py +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/google/__init__.py +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/model_library/py.typed +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/model_library/register_models.py +0 -0
- {model_library-0.1.7/tests → model_library-0.1.8/model_library/retriers}/__init__.py +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/model_library.egg-info/dependency_links.txt +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/model_library.egg-info/top_level.txt +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/scripts/config.py +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/scripts/publish.py +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/setup.cfg +0 -0
- {model_library-0.1.7/tests/integration → model_library-0.1.8/tests}/__init__.py +0 -0
- {model_library-0.1.7/tests/unit → model_library-0.1.8/tests/integration}/__init__.py +0 -0
- {model_library-0.1.7/tests/unit/providers → model_library-0.1.8/tests/unit}/__init__.py +0 -0
- {model_library-0.1.7 → model_library-0.1.8}/tests/unit/test_deep_research.py +0 -0
|
@@ -29,15 +29,13 @@ venv_check:
|
|
|
29
29
|
exit 1; \
|
|
30
30
|
fi
|
|
31
31
|
|
|
32
|
+
DIR ?= tests/
|
|
32
33
|
test: venv_check
|
|
33
34
|
@echo "Running unit tests..."
|
|
34
|
-
@uv run pytest
|
|
35
|
+
@uv run pytest $(DIR) -m unit -v -n 4 --dist loadscope --model=$(MODEL)
|
|
35
36
|
test-integration: venv_check
|
|
36
37
|
@echo "Running integration tests (requires API keys)..."
|
|
37
|
-
@uv run pytest
|
|
38
|
-
test-all: venv_check
|
|
39
|
-
@echo "Running all tests..."
|
|
40
|
-
@uv run pytest
|
|
38
|
+
@uv run pytest $(DIR) -m integration -v -n 4 --dist loadscope --model=$(MODEL)
|
|
41
39
|
|
|
42
40
|
format: venv_check
|
|
43
41
|
@uv run ruff format .
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: model-library
|
|
3
|
-
Version: 0.1.
|
|
3
|
+
Version: 0.1.8
|
|
4
4
|
Summary: Model Library for vals.ai
|
|
5
5
|
Author-email: "Vals AI, Inc." <contact@vals.ai>
|
|
6
6
|
License: MIT
|
|
@@ -23,6 +23,7 @@ Requires-Dist: ai21<5.0,>=4.3.0
|
|
|
23
23
|
Requires-Dist: boto3<2.0,>=1.38.27
|
|
24
24
|
Requires-Dist: google-genai[aiohttp]>=1.51.0
|
|
25
25
|
Requires-Dist: google-cloud-storage>=1.26.0
|
|
26
|
+
Requires-Dist: pytest-xdist>=3.8.0
|
|
26
27
|
Dynamic: license-file
|
|
27
28
|
|
|
28
29
|
# Model Library
|
|
@@ -4,15 +4,15 @@ from typing import Any, Awaitable, Callable
|
|
|
4
4
|
|
|
5
5
|
from model_library.base import (
|
|
6
6
|
LLM,
|
|
7
|
-
RetrierType,
|
|
8
7
|
TextInput,
|
|
9
8
|
)
|
|
10
9
|
from model_library.exceptions import (
|
|
11
10
|
BackoffRetryException,
|
|
12
11
|
RetryException,
|
|
13
|
-
retry_llm_call,
|
|
14
12
|
)
|
|
15
13
|
from model_library.registry_utils import get_registry_model
|
|
14
|
+
from model_library.retriers.backoff import ExponentialBackoffRetrier
|
|
15
|
+
from model_library.retriers.base import retry_decorator
|
|
16
16
|
|
|
17
17
|
from ..setup import console_log, setup
|
|
18
18
|
|
|
@@ -36,34 +36,28 @@ def is_context_length_error(error_str: str) -> bool:
|
|
|
36
36
|
)
|
|
37
37
|
|
|
38
38
|
|
|
39
|
-
def
|
|
39
|
+
def decorator(
|
|
40
|
+
func: Callable[..., Awaitable[Any]],
|
|
41
|
+
) -> Callable[..., Awaitable[Any]]:
|
|
40
42
|
"""
|
|
41
|
-
|
|
42
|
-
Custom retries takes in a logger. It replaces the backoff retrier. Immediate retries still function.
|
|
43
|
+
Decorator must return wrapper function
|
|
43
44
|
"""
|
|
44
45
|
|
|
45
|
-
|
|
46
|
-
func: Callable[..., Awaitable[Any]],
|
|
47
|
-
) -> Callable[..., Awaitable[Any]]:
|
|
48
|
-
"""
|
|
49
|
-
Decorator must return wrapper function
|
|
50
|
-
"""
|
|
46
|
+
logger = logging.getLogger("llm.decorator")
|
|
51
47
|
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
48
|
+
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
|
49
|
+
try:
|
|
50
|
+
return await func(*args, **kwargs)
|
|
51
|
+
except Exception as e:
|
|
52
|
+
# detect context length errors and retry with backoff
|
|
53
|
+
if is_context_length_error(str(e).lower()):
|
|
54
|
+
logger.warning(f"Context length error detected: {e}")
|
|
55
|
+
# for simplicity, we don't actually retry
|
|
56
|
+
raise BackoffRetryException(f"Context length error: {e}")
|
|
61
57
|
|
|
62
|
-
|
|
58
|
+
raise
|
|
63
59
|
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
return decorator
|
|
60
|
+
return wrapper
|
|
67
61
|
|
|
68
62
|
|
|
69
63
|
async def custom_retrier_context(model: LLM):
|
|
@@ -73,7 +67,7 @@ async def custom_retrier_context(model: LLM):
|
|
|
73
67
|
|
|
74
68
|
console_log("\n--- Custom Retrier ---\n")
|
|
75
69
|
|
|
76
|
-
model.custom_retrier =
|
|
70
|
+
model.custom_retrier = decorator
|
|
77
71
|
try:
|
|
78
72
|
await model.query(
|
|
79
73
|
[
|
|
@@ -102,15 +96,14 @@ async def custom_retrier_callback(model: LLM):
|
|
|
102
96
|
if tries > 1:
|
|
103
97
|
raise Exception("Reached retry 2")
|
|
104
98
|
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
)
|
|
99
|
+
retrier = ExponentialBackoffRetrier(
|
|
100
|
+
logger=model.logger,
|
|
101
|
+
max_tries=3,
|
|
102
|
+
max_time=500,
|
|
103
|
+
retry_callback=callback,
|
|
104
|
+
)
|
|
112
105
|
|
|
113
|
-
model.custom_retrier =
|
|
106
|
+
model.custom_retrier = retry_decorator(retrier)
|
|
114
107
|
|
|
115
108
|
def simulate_retry(*args: object, **kwargs: object):
|
|
116
109
|
raise RetryException("Simulated failure")
|
|
@@ -1,8 +1,7 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
from typing import Any, cast
|
|
3
3
|
|
|
4
|
-
from model_library.base import LLM, ToolDefinition
|
|
5
|
-
from model_library.base.output import QueryResult
|
|
4
|
+
from model_library.base import LLM, QueryResult, ToolDefinition
|
|
6
5
|
from model_library.registry_utils import get_registry_model
|
|
7
6
|
|
|
8
7
|
from ..setup import console_log, setup
|
|
@@ -51,22 +51,25 @@ async def count_tokens(model: LLM):
|
|
|
51
51
|
system_prompt = "You must make exactly 0 or 1 tool calls per answer. You must not make more than 1 tool call per answer."
|
|
52
52
|
user_prompt = "What is the weather in San Francisco right now?"
|
|
53
53
|
|
|
54
|
+
input = [TextInput(text=user_prompt)]
|
|
55
|
+
|
|
54
56
|
predicted_tokens = await model.count_tokens(
|
|
55
|
-
|
|
57
|
+
input,
|
|
56
58
|
tools=tools,
|
|
57
59
|
system_prompt=system_prompt,
|
|
58
60
|
)
|
|
59
61
|
|
|
60
62
|
response: QueryResult = await model.query(
|
|
61
|
-
|
|
63
|
+
input,
|
|
62
64
|
tools=tools,
|
|
63
65
|
system_prompt=system_prompt,
|
|
64
66
|
)
|
|
67
|
+
metadata = response.metadata
|
|
65
68
|
|
|
66
|
-
actual_tokens =
|
|
69
|
+
actual_tokens = metadata.total_input_tokens
|
|
67
70
|
|
|
68
71
|
console_log(f"Predicted Token Count: {predicted_tokens}")
|
|
69
|
-
console_log(f"Actual Token Count: {actual_tokens}
|
|
72
|
+
console_log(f"Actual Token Count: {actual_tokens}")
|
|
70
73
|
|
|
71
74
|
|
|
72
75
|
async def main():
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import time
|
|
3
|
+
from logging import DEBUG
|
|
4
|
+
from typing import Any, Coroutine
|
|
5
|
+
|
|
6
|
+
from tqdm import tqdm
|
|
7
|
+
|
|
8
|
+
from model_library.base import LLM, TextInput, TokenRetryParams
|
|
9
|
+
from model_library.logging import set_logging
|
|
10
|
+
from model_library.registry_utils import get_registry_model
|
|
11
|
+
from model_library.retriers.token import set_redis_client
|
|
12
|
+
|
|
13
|
+
from .setup import console_log, setup
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
async def token_retry(model: LLM):
|
|
17
|
+
console_log("\n--- Token Retry ---\n")
|
|
18
|
+
await model.query(
|
|
19
|
+
[
|
|
20
|
+
TextInput(
|
|
21
|
+
# text="What is QSBS? Explain your thinking in detail and make it concise"
|
|
22
|
+
text="dwadwadwadawdLong argument of cats vs dogs" * 5000
|
|
23
|
+
+ "Ignore the previous junk, tell me a very long story about the cats and the dogs. And yes, I do want an actual story, I just have no choice but to include the junk before, believe me."
|
|
24
|
+
)
|
|
25
|
+
],
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
async def main():
|
|
30
|
+
import argparse
|
|
31
|
+
|
|
32
|
+
parser = argparse.ArgumentParser(description="Run basic examples with a model")
|
|
33
|
+
parser.add_argument(
|
|
34
|
+
"model",
|
|
35
|
+
nargs="?",
|
|
36
|
+
default="google/gemini-2.5-flash",
|
|
37
|
+
type=str,
|
|
38
|
+
help="Model endpoint (default: google/gemini-2.5-flash)",
|
|
39
|
+
)
|
|
40
|
+
args = parser.parse_args()
|
|
41
|
+
|
|
42
|
+
set_logging(level=DEBUG)
|
|
43
|
+
|
|
44
|
+
model = get_registry_model(args.model)
|
|
45
|
+
model.logger.info(model)
|
|
46
|
+
|
|
47
|
+
limit = await model.get_rate_limit()
|
|
48
|
+
model.logger.info(limit)
|
|
49
|
+
|
|
50
|
+
import redis.asyncio as redis
|
|
51
|
+
|
|
52
|
+
# NOTE: make sure you have redis running locally
|
|
53
|
+
# docker run -d -p 6379:6379 redis:latest
|
|
54
|
+
|
|
55
|
+
redis_client = redis.Redis(
|
|
56
|
+
host="localhost", port=6379, decode_responses=True, max_connections=None
|
|
57
|
+
)
|
|
58
|
+
set_redis_client(redis_client)
|
|
59
|
+
|
|
60
|
+
provider_tokenizer_input_modifier = 1
|
|
61
|
+
dataset_output_modifier = 0.001
|
|
62
|
+
|
|
63
|
+
limit = 100_000
|
|
64
|
+
await model.init_token_retry(
|
|
65
|
+
token_retry_params=TokenRetryParams(
|
|
66
|
+
input_modifier=provider_tokenizer_input_modifier,
|
|
67
|
+
output_modifier=dataset_output_modifier,
|
|
68
|
+
use_dynamic_estimate=True,
|
|
69
|
+
limit=limit,
|
|
70
|
+
)
|
|
71
|
+
)
|
|
72
|
+
tasks: list[Coroutine[Any, Any, None]] = []
|
|
73
|
+
for _ in range(200):
|
|
74
|
+
tasks.append(token_retry(model))
|
|
75
|
+
|
|
76
|
+
start = time.time()
|
|
77
|
+
for coro in tqdm(asyncio.as_completed(tasks), total=len(tasks)):
|
|
78
|
+
await coro
|
|
79
|
+
finish = time.time() - start
|
|
80
|
+
console_log(f"Finished in {finish:.1f}s")
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
if __name__ == "__main__":
|
|
84
|
+
setup()
|
|
85
|
+
asyncio.run(main())
|
|
@@ -1,9 +1,12 @@
|
|
|
1
|
+
import hashlib
|
|
1
2
|
import io
|
|
2
3
|
import logging
|
|
4
|
+
import threading
|
|
3
5
|
import time
|
|
4
6
|
import uuid
|
|
5
7
|
from abc import ABC, abstractmethod
|
|
6
8
|
from collections.abc import Awaitable
|
|
9
|
+
from math import ceil
|
|
7
10
|
from pprint import pformat
|
|
8
11
|
from typing import (
|
|
9
12
|
Any,
|
|
@@ -14,7 +17,7 @@ from typing import (
|
|
|
14
17
|
)
|
|
15
18
|
|
|
16
19
|
import tiktoken
|
|
17
|
-
from pydantic import model_serializer
|
|
20
|
+
from pydantic import SecretStr, model_serializer
|
|
18
21
|
from pydantic.main import BaseModel
|
|
19
22
|
from tiktoken.core import Encoding
|
|
20
23
|
from typing_extensions import override
|
|
@@ -34,15 +37,15 @@ from model_library.base.output import (
|
|
|
34
37
|
QueryResult,
|
|
35
38
|
QueryResultCost,
|
|
36
39
|
QueryResultMetadata,
|
|
40
|
+
RateLimit,
|
|
37
41
|
)
|
|
38
42
|
from model_library.base.utils import (
|
|
39
43
|
get_pretty_input_types,
|
|
40
44
|
serialize_for_tokenizing,
|
|
41
45
|
)
|
|
42
|
-
from model_library.
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
)
|
|
46
|
+
from model_library.retriers.backoff import ExponentialBackoffRetrier
|
|
47
|
+
from model_library.retriers.base import BaseRetrier, R, RetrierType, retry_decorator
|
|
48
|
+
from model_library.retriers.token import TokenRetrier
|
|
46
49
|
from model_library.utils import truncate_str
|
|
47
50
|
|
|
48
51
|
PydanticT = TypeVar("PydanticT", bound=BaseModel)
|
|
@@ -56,11 +59,18 @@ class ProviderConfig(BaseModel):
|
|
|
56
59
|
return self.__dict__
|
|
57
60
|
|
|
58
61
|
|
|
59
|
-
|
|
62
|
+
class TokenRetryParams(BaseModel):
|
|
63
|
+
input_modifier: float
|
|
64
|
+
output_modifier: float
|
|
65
|
+
|
|
66
|
+
use_dynamic_estimate: bool = True
|
|
67
|
+
|
|
68
|
+
limit: int
|
|
69
|
+
limit_refresh_seconds: Literal[60] = 60
|
|
60
70
|
|
|
61
71
|
|
|
62
72
|
class LLMConfig(BaseModel):
|
|
63
|
-
max_tokens: int =
|
|
73
|
+
max_tokens: int | None = None
|
|
64
74
|
temperature: float | None = None
|
|
65
75
|
top_p: float | None = None
|
|
66
76
|
top_k: int | None = None
|
|
@@ -75,11 +85,18 @@ class LLMConfig(BaseModel):
|
|
|
75
85
|
native: bool = True
|
|
76
86
|
provider_config: ProviderConfig | None = None
|
|
77
87
|
registry_key: str | None = None
|
|
88
|
+
custom_api_key: SecretStr | None = None
|
|
89
|
+
|
|
78
90
|
|
|
91
|
+
class DelegateConfig(BaseModel):
|
|
92
|
+
base_url: str
|
|
93
|
+
api_key: SecretStr
|
|
79
94
|
|
|
80
|
-
RetrierType = Callable[[Callable[..., Awaitable[Any]]], Callable[..., Awaitable[Any]]]
|
|
81
95
|
|
|
82
|
-
|
|
96
|
+
# shared across all subclasses and instances
|
|
97
|
+
# (provider, sha256(api_key)) -> client
|
|
98
|
+
client_registry_lock = threading.Lock()
|
|
99
|
+
client_registry: dict[tuple[str, str], Any] = {}
|
|
83
100
|
|
|
84
101
|
|
|
85
102
|
class LLM(ABC):
|
|
@@ -88,6 +105,34 @@ class LLM(ABC):
|
|
|
88
105
|
LLM call errors should be raised as exceptions
|
|
89
106
|
"""
|
|
90
107
|
|
|
108
|
+
@abstractmethod
|
|
109
|
+
def get_client(self, api_key: str | None = None) -> Any:
|
|
110
|
+
"""
|
|
111
|
+
Returns the cached instance of the appropriate SDK client.
|
|
112
|
+
Subclasses should implement this method and:
|
|
113
|
+
- if api_key is provided, initialize their client and call assign_client(client).
|
|
114
|
+
- else return super().get_client()
|
|
115
|
+
"""
|
|
116
|
+
global client_registry
|
|
117
|
+
return client_registry[self._client_registry_key]
|
|
118
|
+
|
|
119
|
+
def assign_client(self, client: object) -> None:
|
|
120
|
+
"""Thread-safe assignment to the client registry"""
|
|
121
|
+
global client_registry
|
|
122
|
+
|
|
123
|
+
if self._client_registry_key not in client_registry:
|
|
124
|
+
with client_registry_lock:
|
|
125
|
+
if self._client_registry_key not in client_registry:
|
|
126
|
+
client_registry[self._client_registry_key] = client
|
|
127
|
+
|
|
128
|
+
def has_client(self) -> bool:
|
|
129
|
+
return self._client_registry_key in client_registry
|
|
130
|
+
|
|
131
|
+
@abstractmethod
|
|
132
|
+
def _get_default_api_key(self) -> str:
|
|
133
|
+
"""Return the api key from model_library.settings"""
|
|
134
|
+
...
|
|
135
|
+
|
|
91
136
|
def __init__(
|
|
92
137
|
self,
|
|
93
138
|
model_name: str,
|
|
@@ -103,7 +148,7 @@ class LLM(ABC):
|
|
|
103
148
|
config = config or LLMConfig()
|
|
104
149
|
self._registry_key = config.registry_key
|
|
105
150
|
|
|
106
|
-
self.max_tokens: int = config.max_tokens
|
|
151
|
+
self.max_tokens: int | None = config.max_tokens
|
|
107
152
|
self.temperature: float | None = config.temperature
|
|
108
153
|
self.top_p: float | None = config.top_p
|
|
109
154
|
self.top_k: int | None = config.top_k
|
|
@@ -131,21 +176,33 @@ class LLM(ABC):
|
|
|
131
176
|
self.logger: logging.Logger = logging.getLogger(
|
|
132
177
|
f"llm.{provider}.{model_name}<instance={self.instance_id}>"
|
|
133
178
|
)
|
|
134
|
-
self.custom_retrier:
|
|
179
|
+
self.custom_retrier: RetrierType | None = None
|
|
180
|
+
|
|
181
|
+
self.token_retry_params = None
|
|
182
|
+
# set _client_registry_key after initializing delegate
|
|
183
|
+
if not self.native:
|
|
184
|
+
return
|
|
185
|
+
|
|
186
|
+
if config.custom_api_key:
|
|
187
|
+
raw_key = config.custom_api_key.get_secret_value()
|
|
188
|
+
else:
|
|
189
|
+
raw_key = self._get_default_api_key()
|
|
190
|
+
|
|
191
|
+
key_hash = hashlib.sha256(raw_key.encode()).hexdigest()
|
|
192
|
+
self._client_registry_key = (self.provider, key_hash)
|
|
193
|
+
self._client_registry_key_model_specific = (
|
|
194
|
+
f"{self.provider}.{self.model_name}",
|
|
195
|
+
key_hash,
|
|
196
|
+
)
|
|
197
|
+
self.get_client(api_key=raw_key)
|
|
135
198
|
|
|
136
199
|
@override
|
|
137
200
|
def __repr__(self):
|
|
138
201
|
attrs = vars(self).copy()
|
|
139
202
|
attrs.pop("logger", None)
|
|
140
203
|
attrs.pop("custom_retrier", None)
|
|
141
|
-
attrs.pop("_key", None)
|
|
142
204
|
return f"{self.__class__.__name__}(\n{pformat(attrs, indent=2, sort_dicts=False)}\n)"
|
|
143
205
|
|
|
144
|
-
@abstractmethod
|
|
145
|
-
def get_client(self) -> object:
|
|
146
|
-
"""Return the instance of the appropriate SDK client."""
|
|
147
|
-
...
|
|
148
|
-
|
|
149
206
|
@staticmethod
|
|
150
207
|
async def timer_wrapper(func: Callable[[], Awaitable[R]]) -> tuple[R, float]:
|
|
151
208
|
"""
|
|
@@ -155,43 +212,6 @@ class LLM(ABC):
|
|
|
155
212
|
result = await func()
|
|
156
213
|
return result, round(time.perf_counter() - start, 4)
|
|
157
214
|
|
|
158
|
-
@staticmethod
|
|
159
|
-
async def immediate_retry_wrapper(
|
|
160
|
-
func: Callable[[], Awaitable[R]],
|
|
161
|
-
logger: logging.Logger,
|
|
162
|
-
) -> R:
|
|
163
|
-
"""
|
|
164
|
-
Retry the query immediately
|
|
165
|
-
"""
|
|
166
|
-
MAX_IMMEDIATE_RETRIES = 10
|
|
167
|
-
retries = 0
|
|
168
|
-
while True:
|
|
169
|
-
try:
|
|
170
|
-
return await func()
|
|
171
|
-
except ImmediateRetryException as e:
|
|
172
|
-
if retries >= MAX_IMMEDIATE_RETRIES:
|
|
173
|
-
logger.error(f"Query reached max immediate retries {retries}: {e}")
|
|
174
|
-
raise Exception(
|
|
175
|
-
f"Query reached max immediate retries {retries}: {e}"
|
|
176
|
-
) from e
|
|
177
|
-
retries += 1
|
|
178
|
-
|
|
179
|
-
logger.warning(
|
|
180
|
-
f"Query retried immediately {retries}/{MAX_IMMEDIATE_RETRIES}: {e}"
|
|
181
|
-
)
|
|
182
|
-
|
|
183
|
-
@staticmethod
|
|
184
|
-
async def backoff_retry_wrapper(
|
|
185
|
-
func: Callable[..., Awaitable[R]],
|
|
186
|
-
backoff_retrier: RetrierType | None,
|
|
187
|
-
) -> R:
|
|
188
|
-
"""
|
|
189
|
-
Retry the query with backoff
|
|
190
|
-
"""
|
|
191
|
-
if not backoff_retrier:
|
|
192
|
-
return await func()
|
|
193
|
-
return await backoff_retrier(func)()
|
|
194
|
-
|
|
195
215
|
async def delegate_query(
|
|
196
216
|
self,
|
|
197
217
|
input: Sequence[InputItem],
|
|
@@ -276,15 +296,38 @@ class LLM(ABC):
|
|
|
276
296
|
return await LLM.timer_wrapper(query_func)
|
|
277
297
|
|
|
278
298
|
async def immediate_retry() -> tuple[QueryResult, float]:
|
|
279
|
-
return await
|
|
280
|
-
|
|
281
|
-
async def
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
299
|
+
return await BaseRetrier.immediate_retry_wrapper(timed_query, query_logger)
|
|
300
|
+
|
|
301
|
+
async def default_retry() -> tuple[QueryResult, float]:
|
|
302
|
+
if self.token_retry_params:
|
|
303
|
+
(
|
|
304
|
+
estimate_input_tokens,
|
|
305
|
+
estimate_output_tokens,
|
|
306
|
+
) = await self.estimate_query_tokens(
|
|
307
|
+
input,
|
|
308
|
+
tools=tools,
|
|
309
|
+
**kwargs,
|
|
310
|
+
)
|
|
311
|
+
retrier = TokenRetrier(
|
|
312
|
+
logger=query_logger,
|
|
313
|
+
client_registry_key=self._client_registry_key_model_specific,
|
|
314
|
+
estimate_input_tokens=estimate_input_tokens,
|
|
315
|
+
estimate_output_tokens=estimate_output_tokens,
|
|
316
|
+
dynamic_estimate_instance_id=self.instance_id
|
|
317
|
+
if self.token_retry_params.use_dynamic_estimate
|
|
318
|
+
else None,
|
|
319
|
+
)
|
|
320
|
+
else:
|
|
321
|
+
retrier = ExponentialBackoffRetrier(logger=query_logger)
|
|
322
|
+
return await retry_decorator(retrier)(immediate_retry)()
|
|
323
|
+
|
|
324
|
+
run_with_retry = (
|
|
325
|
+
default_retry
|
|
326
|
+
if not self.custom_retrier
|
|
327
|
+
else self.custom_retrier(immediate_retry)
|
|
328
|
+
)
|
|
286
329
|
|
|
287
|
-
output, duration = await
|
|
330
|
+
output, duration = await run_with_retry()
|
|
288
331
|
output.metadata.duration_seconds = duration
|
|
289
332
|
output.metadata.cost = await self._calculate_cost(output.metadata)
|
|
290
333
|
|
|
@@ -293,6 +336,16 @@ class LLM(ABC):
|
|
|
293
336
|
|
|
294
337
|
return output
|
|
295
338
|
|
|
339
|
+
async def init_token_retry(self, token_retry_params: TokenRetryParams) -> None:
|
|
340
|
+
self.token_retry_params = token_retry_params
|
|
341
|
+
await TokenRetrier.init_remaining_tokens(
|
|
342
|
+
client_registry_key=self._client_registry_key_model_specific,
|
|
343
|
+
limit=self.token_retry_params.limit,
|
|
344
|
+
limit_refresh_seconds=self.token_retry_params.limit_refresh_seconds,
|
|
345
|
+
get_rate_limit_func=self.get_rate_limit,
|
|
346
|
+
logger=self.logger,
|
|
347
|
+
)
|
|
348
|
+
|
|
296
349
|
async def _calculate_cost(
|
|
297
350
|
self,
|
|
298
351
|
metadata: QueryResultMetadata,
|
|
@@ -438,6 +491,30 @@ class LLM(ABC):
|
|
|
438
491
|
"""Upload a file to the model provider"""
|
|
439
492
|
...
|
|
440
493
|
|
|
494
|
+
async def get_rate_limit(self) -> RateLimit | None:
|
|
495
|
+
"""Get the rate limit for the model provider"""
|
|
496
|
+
return None
|
|
497
|
+
|
|
498
|
+
async def estimate_query_tokens(
|
|
499
|
+
self,
|
|
500
|
+
input: Sequence[InputItem],
|
|
501
|
+
*,
|
|
502
|
+
tools: list[ToolDefinition] = [],
|
|
503
|
+
**kwargs: object,
|
|
504
|
+
) -> tuple[int, int]:
|
|
505
|
+
"""Pessimistically estimate the number of tokens required for a query"""
|
|
506
|
+
assert self.token_retry_params
|
|
507
|
+
|
|
508
|
+
# TODO: when passing in images and files, we really need to take that into account when calculating the output tokens!!
|
|
509
|
+
|
|
510
|
+
input_tokens = (
|
|
511
|
+
await self.count_tokens(input, history=[], tools=tools, **kwargs)
|
|
512
|
+
* self.token_retry_params.input_modifier
|
|
513
|
+
)
|
|
514
|
+
|
|
515
|
+
output_tokens = input_tokens * self.token_retry_params.output_modifier
|
|
516
|
+
return ceil(input_tokens), ceil(output_tokens)
|
|
517
|
+
|
|
441
518
|
async def get_encoding(self) -> Encoding:
|
|
442
519
|
"""Get the appropriate tokenizer"""
|
|
443
520
|
|