model-library 0.1.6__tar.gz → 0.1.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {model_library-0.1.6 → model_library-0.1.8}/Makefile +3 -5
- {model_library-0.1.6 → model_library-0.1.8}/PKG-INFO +4 -3
- {model_library-0.1.6 → model_library-0.1.8}/examples/advanced/custom_retrier.py +26 -33
- {model_library-0.1.6 → model_library-0.1.8}/examples/advanced/web_search.py +3 -27
- model_library-0.1.8/examples/count_tokens.py +98 -0
- model_library-0.1.8/examples/token_retry.py +85 -0
- {model_library-0.1.6 → model_library-0.1.8}/model_library/base/base.py +237 -62
- model_library-0.1.8/model_library/base/delegate_only.py +175 -0
- {model_library-0.1.6 → model_library-0.1.8}/model_library/base/input.py +10 -7
- {model_library-0.1.6 → model_library-0.1.8}/model_library/base/output.py +48 -0
- {model_library-0.1.6 → model_library-0.1.8}/model_library/base/utils.py +56 -7
- {model_library-0.1.6 → model_library-0.1.8}/model_library/config/alibaba_models.yaml +44 -57
- {model_library-0.1.6 → model_library-0.1.8}/model_library/config/all_models.json +253 -126
- {model_library-0.1.6 → model_library-0.1.8}/model_library/config/kimi_models.yaml +30 -3
- {model_library-0.1.6 → model_library-0.1.8}/model_library/config/openai_models.yaml +15 -23
- {model_library-0.1.6 → model_library-0.1.8}/model_library/config/zai_models.yaml +24 -3
- {model_library-0.1.6 → model_library-0.1.8}/model_library/exceptions.py +14 -77
- {model_library-0.1.6 → model_library-0.1.8}/model_library/logging.py +6 -2
- {model_library-0.1.6 → model_library-0.1.8}/model_library/providers/ai21labs.py +30 -14
- {model_library-0.1.6 → model_library-0.1.8}/model_library/providers/alibaba.py +17 -8
- {model_library-0.1.6 → model_library-0.1.8}/model_library/providers/amazon.py +119 -64
- {model_library-0.1.6 → model_library-0.1.8}/model_library/providers/anthropic.py +184 -104
- {model_library-0.1.6 → model_library-0.1.8}/model_library/providers/azure.py +22 -10
- {model_library-0.1.6 → model_library-0.1.8}/model_library/providers/cohere.py +7 -7
- {model_library-0.1.6 → model_library-0.1.8}/model_library/providers/deepseek.py +8 -8
- {model_library-0.1.6 → model_library-0.1.8}/model_library/providers/fireworks.py +7 -8
- {model_library-0.1.6 → model_library-0.1.8}/model_library/providers/google/batch.py +17 -13
- {model_library-0.1.6 → model_library-0.1.8}/model_library/providers/google/google.py +130 -73
- {model_library-0.1.6 → model_library-0.1.8}/model_library/providers/inception.py +7 -7
- {model_library-0.1.6 → model_library-0.1.8}/model_library/providers/kimi.py +18 -8
- model_library-0.1.8/model_library/providers/minimax.py +50 -0
- {model_library-0.1.6 → model_library-0.1.8}/model_library/providers/mistral.py +61 -35
- {model_library-0.1.6 → model_library-0.1.8}/model_library/providers/openai.py +219 -93
- model_library-0.1.8/model_library/providers/openrouter.py +34 -0
- {model_library-0.1.6 → model_library-0.1.8}/model_library/providers/perplexity.py +7 -7
- {model_library-0.1.6 → model_library-0.1.8}/model_library/providers/together.py +7 -8
- {model_library-0.1.6 → model_library-0.1.8}/model_library/providers/vals.py +16 -9
- {model_library-0.1.6 → model_library-0.1.8}/model_library/providers/xai.py +157 -144
- model_library-0.1.8/model_library/providers/zai.py +64 -0
- {model_library-0.1.6 → model_library-0.1.8}/model_library/register_models.py +4 -2
- {model_library-0.1.6 → model_library-0.1.8}/model_library/registry_utils.py +39 -15
- model_library-0.1.8/model_library/retriers/backoff.py +73 -0
- model_library-0.1.8/model_library/retriers/base.py +225 -0
- model_library-0.1.8/model_library/retriers/token.py +427 -0
- model_library-0.1.8/model_library/retriers/utils.py +11 -0
- {model_library-0.1.6 → model_library-0.1.8}/model_library/settings.py +1 -1
- {model_library-0.1.6 → model_library-0.1.8}/model_library/utils.py +13 -35
- {model_library-0.1.6 → model_library-0.1.8}/model_library.egg-info/PKG-INFO +4 -3
- {model_library-0.1.6 → model_library-0.1.8}/model_library.egg-info/SOURCES.txt +15 -8
- {model_library-0.1.6 → model_library-0.1.8}/model_library.egg-info/requires.txt +3 -2
- {model_library-0.1.6 → model_library-0.1.8}/pyproject.toml +4 -2
- {model_library-0.1.6 → model_library-0.1.8}/scripts/browse_models.py +2 -2
- {model_library-0.1.6 → model_library-0.1.8}/scripts/run_models.py +14 -17
- model_library-0.1.8/tests/README.md +75 -0
- model_library-0.1.8/tests/conftest.py +85 -0
- model_library-0.1.8/tests/integration/conftest.py +8 -0
- model_library-0.1.8/tests/integration/test_basic.py +15 -0
- {model_library-0.1.6 → model_library-0.1.8}/tests/integration/test_batch.py +4 -12
- model_library-0.1.8/tests/integration/test_files.py +38 -0
- model_library-0.1.8/tests/integration/test_long_problem.py +24 -0
- {model_library-0.1.6 → model_library-0.1.8}/tests/integration/test_reasoning.py +3 -12
- model_library-0.1.8/tests/integration/test_retry.py +28 -0
- {model_library-0.1.6 → model_library-0.1.8}/tests/integration/test_streaming.py +13 -30
- {model_library-0.1.6 → model_library-0.1.8}/tests/integration/test_structured_output.py +0 -6
- {model_library-0.1.6 → model_library-0.1.8}/tests/integration/test_tools.py +1 -18
- model_library-0.1.8/tests/test_helpers.py +183 -0
- model_library-0.1.8/tests/unit/conftest.py +102 -0
- model_library-0.1.8/tests/unit/test_batch.py +353 -0
- model_library-0.1.8/tests/unit/test_count_tokens.py +42 -0
- model_library-0.1.8/tests/unit/test_get_client.py +188 -0
- model_library-0.1.8/tests/unit/test_openai_config.py +32 -0
- {model_library-0.1.6 → model_library-0.1.8}/tests/unit/test_prompt_caching.py +31 -81
- {model_library-0.1.6 → model_library-0.1.8}/tests/unit/test_query_logger.py +18 -0
- {model_library-0.1.6 → model_library-0.1.8}/tests/unit/test_registry.py +7 -7
- {model_library-0.1.6 → model_library-0.1.8}/tests/unit/test_result_metadata.py +28 -18
- {model_library-0.1.6 → model_library-0.1.8}/tests/unit/test_retry.py +44 -41
- model_library-0.1.8/tests/unit/test_token_retry.py +405 -0
- {model_library-0.1.6 → model_library-0.1.8}/tests/unit/test_tools.py +9 -15
- {model_library-0.1.6 → model_library-0.1.8}/uv.lock +72 -24
- model_library-0.1.6/model_library/base/delegate_only.py +0 -98
- model_library-0.1.6/model_library/providers/minimax.py +0 -33
- model_library-0.1.6/model_library/providers/zai.py +0 -34
- model_library-0.1.6/tests/README.md +0 -87
- model_library-0.1.6/tests/conftest.py +0 -275
- model_library-0.1.6/tests/integration/conftest.py +0 -8
- model_library-0.1.6/tests/integration/test_completion.py +0 -41
- model_library-0.1.6/tests/integration/test_files.py +0 -279
- model_library-0.1.6/tests/integration/test_retry.py +0 -95
- model_library-0.1.6/tests/test_helpers.py +0 -89
- model_library-0.1.6/tests/unit/conftest.py +0 -52
- model_library-0.1.6/tests/unit/providers/test_fireworks_provider.py +0 -48
- model_library-0.1.6/tests/unit/providers/test_google_provider.py +0 -58
- model_library-0.1.6/tests/unit/test_batch.py +0 -236
- model_library-0.1.6/tests/unit/test_context_window.py +0 -45
- model_library-0.1.6/tests/unit/test_perplexity_provider.py +0 -73
- model_library-0.1.6/tests/unit/test_streaming.py +0 -83
- {model_library-0.1.6 → model_library-0.1.8}/.gitattributes +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/.github/workflows/publish.yml +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/.github/workflows/style.yaml +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/.github/workflows/test.yaml +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/.github/workflows/typecheck.yml +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/.gitignore +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/LICENSE +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/README.md +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/examples/README.md +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/examples/advanced/batch.py +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/examples/advanced/deep_research.py +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/examples/advanced/stress.py +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/examples/advanced/structured_output.py +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/examples/basics.py +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/examples/data/files.py +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/examples/data/images.py +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/examples/embeddings.py +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/examples/files.py +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/examples/images.py +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/examples/prompt_caching.py +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/examples/setup.py +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/examples/tool_calls.py +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/model_library/__init__.py +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/model_library/base/__init__.py +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/model_library/base/batch.py +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/model_library/config/README.md +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/model_library/config/ai21labs_models.yaml +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/model_library/config/amazon_models.yaml +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/model_library/config/anthropic_models.yaml +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/model_library/config/cohere_models.yaml +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/model_library/config/deepseek_models.yaml +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/model_library/config/dummy_model.yaml +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/model_library/config/fireworks_models.yaml +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/model_library/config/google_models.yaml +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/model_library/config/inception_models.yaml +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/model_library/config/minimax_models.yaml +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/model_library/config/mistral_models.yaml +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/model_library/config/perplexity_models.yaml +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/model_library/config/together_models.yaml +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/model_library/config/xai_models.yaml +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/model_library/file_utils.py +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/model_library/model_utils.py +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/model_library/providers/__init__.py +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/model_library/providers/google/__init__.py +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/model_library/py.typed +0 -0
- {model_library-0.1.6/tests → model_library-0.1.8/model_library/retriers}/__init__.py +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/model_library.egg-info/dependency_links.txt +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/model_library.egg-info/top_level.txt +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/scripts/config.py +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/scripts/publish.py +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/setup.cfg +0 -0
- {model_library-0.1.6/tests/integration → model_library-0.1.8/tests}/__init__.py +0 -0
- {model_library-0.1.6/tests/unit → model_library-0.1.8/tests/integration}/__init__.py +0 -0
- {model_library-0.1.6/tests/unit/providers → model_library-0.1.8/tests/unit}/__init__.py +0 -0
- {model_library-0.1.6 → model_library-0.1.8}/tests/unit/test_deep_research.py +0 -0
|
@@ -29,15 +29,13 @@ venv_check:
|
|
|
29
29
|
exit 1; \
|
|
30
30
|
fi
|
|
31
31
|
|
|
32
|
+
DIR ?= tests/
|
|
32
33
|
test: venv_check
|
|
33
34
|
@echo "Running unit tests..."
|
|
34
|
-
@uv run pytest
|
|
35
|
+
@uv run pytest $(DIR) -m unit -v -n 4 --dist loadscope --model=$(MODEL)
|
|
35
36
|
test-integration: venv_check
|
|
36
37
|
@echo "Running integration tests (requires API keys)..."
|
|
37
|
-
@uv run pytest
|
|
38
|
-
test-all: venv_check
|
|
39
|
-
@echo "Running all tests..."
|
|
40
|
-
@uv run pytest
|
|
38
|
+
@uv run pytest $(DIR) -m integration -v -n 4 --dist loadscope --model=$(MODEL)
|
|
41
39
|
|
|
42
40
|
format: venv_check
|
|
43
41
|
@uv run ruff format .
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: model-library
|
|
3
|
-
Version: 0.1.
|
|
3
|
+
Version: 0.1.8
|
|
4
4
|
Summary: Model Library for vals.ai
|
|
5
5
|
Author-email: "Vals AI, Inc." <contact@vals.ai>
|
|
6
6
|
License: MIT
|
|
@@ -13,16 +13,17 @@ Requires-Dist: pyyaml>=6.0.2
|
|
|
13
13
|
Requires-Dist: rich
|
|
14
14
|
Requires-Dist: backoff<3.0,>=2.2.1
|
|
15
15
|
Requires-Dist: redis<7.0,>=6.2.0
|
|
16
|
-
Requires-Dist: tiktoken
|
|
16
|
+
Requires-Dist: tiktoken>=0.12.0
|
|
17
17
|
Requires-Dist: pillow
|
|
18
18
|
Requires-Dist: openai<3.0,>=2.0
|
|
19
19
|
Requires-Dist: anthropic<1.0,>=0.57.1
|
|
20
20
|
Requires-Dist: mistralai<2.0,>=1.9.10
|
|
21
21
|
Requires-Dist: xai-sdk<2.0,>=1.0.0
|
|
22
|
-
Requires-Dist: ai21<5.0,>=4.0
|
|
22
|
+
Requires-Dist: ai21<5.0,>=4.3.0
|
|
23
23
|
Requires-Dist: boto3<2.0,>=1.38.27
|
|
24
24
|
Requires-Dist: google-genai[aiohttp]>=1.51.0
|
|
25
25
|
Requires-Dist: google-cloud-storage>=1.26.0
|
|
26
|
+
Requires-Dist: pytest-xdist>=3.8.0
|
|
26
27
|
Dynamic: license-file
|
|
27
28
|
|
|
28
29
|
# Model Library
|
|
@@ -4,15 +4,15 @@ from typing import Any, Awaitable, Callable
|
|
|
4
4
|
|
|
5
5
|
from model_library.base import (
|
|
6
6
|
LLM,
|
|
7
|
-
RetrierType,
|
|
8
7
|
TextInput,
|
|
9
8
|
)
|
|
10
9
|
from model_library.exceptions import (
|
|
11
10
|
BackoffRetryException,
|
|
12
11
|
RetryException,
|
|
13
|
-
retry_llm_call,
|
|
14
12
|
)
|
|
15
13
|
from model_library.registry_utils import get_registry_model
|
|
14
|
+
from model_library.retriers.backoff import ExponentialBackoffRetrier
|
|
15
|
+
from model_library.retriers.base import retry_decorator
|
|
16
16
|
|
|
17
17
|
from ..setup import console_log, setup
|
|
18
18
|
|
|
@@ -36,34 +36,28 @@ def is_context_length_error(error_str: str) -> bool:
|
|
|
36
36
|
)
|
|
37
37
|
|
|
38
38
|
|
|
39
|
-
def
|
|
39
|
+
def decorator(
|
|
40
|
+
func: Callable[..., Awaitable[Any]],
|
|
41
|
+
) -> Callable[..., Awaitable[Any]]:
|
|
40
42
|
"""
|
|
41
|
-
|
|
42
|
-
Custom retries takes in a logger. It replaces the backoff retrier. Immediate retries still function.
|
|
43
|
+
Decorator must return wrapper function
|
|
43
44
|
"""
|
|
44
45
|
|
|
45
|
-
|
|
46
|
-
func: Callable[..., Awaitable[Any]],
|
|
47
|
-
) -> Callable[..., Awaitable[Any]]:
|
|
48
|
-
"""
|
|
49
|
-
Decorator must return wrapper function
|
|
50
|
-
"""
|
|
46
|
+
logger = logging.getLogger("llm.decorator")
|
|
51
47
|
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
48
|
+
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
|
49
|
+
try:
|
|
50
|
+
return await func(*args, **kwargs)
|
|
51
|
+
except Exception as e:
|
|
52
|
+
# detect context length errors and retry with backoff
|
|
53
|
+
if is_context_length_error(str(e).lower()):
|
|
54
|
+
logger.warning(f"Context length error detected: {e}")
|
|
55
|
+
# for simplicty, we don't actually retry
|
|
56
|
+
raise BackoffRetryException(f"Context length error: {e}")
|
|
61
57
|
|
|
62
|
-
|
|
58
|
+
raise
|
|
63
59
|
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
return decorator
|
|
60
|
+
return wrapper
|
|
67
61
|
|
|
68
62
|
|
|
69
63
|
async def custom_retrier_context(model: LLM):
|
|
@@ -73,7 +67,7 @@ async def custom_retrier_context(model: LLM):
|
|
|
73
67
|
|
|
74
68
|
console_log("\n--- Custom Retrier ---\n")
|
|
75
69
|
|
|
76
|
-
model.custom_retrier =
|
|
70
|
+
model.custom_retrier = decorator
|
|
77
71
|
try:
|
|
78
72
|
await model.query(
|
|
79
73
|
[
|
|
@@ -102,15 +96,14 @@ async def custom_retrier_callback(model: LLM):
|
|
|
102
96
|
if tries > 1:
|
|
103
97
|
raise Exception("Reached retry 2")
|
|
104
98
|
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
)
|
|
99
|
+
retrier = ExponentialBackoffRetrier(
|
|
100
|
+
logger=model.logger,
|
|
101
|
+
max_tries=3,
|
|
102
|
+
max_time=500,
|
|
103
|
+
retry_callback=callback,
|
|
104
|
+
)
|
|
112
105
|
|
|
113
|
-
model.custom_retrier =
|
|
106
|
+
model.custom_retrier = retry_decorator(retrier)
|
|
114
107
|
|
|
115
108
|
def simulate_retry(*args: object, **kwargs: object):
|
|
116
109
|
raise RetryException("Simulated failure")
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
from typing import Any, cast
|
|
3
3
|
|
|
4
|
-
from model_library.base import LLM, ToolDefinition
|
|
4
|
+
from model_library.base import LLM, QueryResult, ToolDefinition
|
|
5
5
|
from model_library.registry_utils import get_registry_model
|
|
6
6
|
|
|
7
7
|
from ..setup import console_log, setup
|
|
@@ -41,31 +41,7 @@ def print_search_details(tool_call: Any) -> None:
|
|
|
41
41
|
console_log(f" - {source}")
|
|
42
42
|
|
|
43
43
|
|
|
44
|
-
def
|
|
45
|
-
"""Extract and print citations from response history."""
|
|
46
|
-
if not response.history:
|
|
47
|
-
return
|
|
48
|
-
|
|
49
|
-
for item in response.history:
|
|
50
|
-
if not (hasattr(item, "content") and isinstance(item.content, list)):
|
|
51
|
-
continue
|
|
52
|
-
|
|
53
|
-
content_list = cast(list[Any], item.content)
|
|
54
|
-
for content_item in content_list:
|
|
55
|
-
if not (hasattr(content_item, "annotations") and content_item.annotations):
|
|
56
|
-
continue
|
|
57
|
-
|
|
58
|
-
console_log("\nCitations:")
|
|
59
|
-
annotations = cast(list[Any], content_item.annotations)
|
|
60
|
-
for annotation in annotations:
|
|
61
|
-
if hasattr(annotation, "url") and annotation.url:
|
|
62
|
-
title = getattr(annotation, "title", "Untitled")
|
|
63
|
-
url = annotation.url
|
|
64
|
-
location = getattr(annotation, "location", "Unknown")
|
|
65
|
-
console_log(f"- {title}: {url} (Location: {location})")
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
def print_web_search_results(response: Any) -> None:
|
|
44
|
+
def print_web_search_results(response: QueryResult) -> None:
|
|
69
45
|
"""Print comprehensive web search results."""
|
|
70
46
|
console_log(f"Response: {response.output_text}")
|
|
71
47
|
|
|
@@ -74,7 +50,7 @@ def print_web_search_results(response: Any) -> None:
|
|
|
74
50
|
for tool_call in response.tool_calls:
|
|
75
51
|
print_search_details(tool_call)
|
|
76
52
|
|
|
77
|
-
|
|
53
|
+
print(response.extras.citations)
|
|
78
54
|
|
|
79
55
|
|
|
80
56
|
async def web_search_domain_filtered(model: LLM) -> None:
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import logging
|
|
3
|
+
|
|
4
|
+
from model_library import set_logging
|
|
5
|
+
from model_library.base import (
|
|
6
|
+
LLM,
|
|
7
|
+
QueryResult,
|
|
8
|
+
TextInput,
|
|
9
|
+
ToolBody,
|
|
10
|
+
ToolDefinition,
|
|
11
|
+
)
|
|
12
|
+
from model_library.registry_utils import get_registry_model
|
|
13
|
+
|
|
14
|
+
from .setup import console_log, setup
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
async def count_tokens(model: LLM):
|
|
18
|
+
console_log("\n--- Count Tokens ---\n")
|
|
19
|
+
|
|
20
|
+
tools = [
|
|
21
|
+
ToolDefinition(
|
|
22
|
+
name="get_weather",
|
|
23
|
+
body=ToolBody(
|
|
24
|
+
name="get_weather",
|
|
25
|
+
description="Get current temperature in a given location",
|
|
26
|
+
properties={
|
|
27
|
+
"location": {
|
|
28
|
+
"type": "string",
|
|
29
|
+
"description": "City and country e.g. Bogotá, Colombia",
|
|
30
|
+
},
|
|
31
|
+
},
|
|
32
|
+
required=["location"],
|
|
33
|
+
),
|
|
34
|
+
),
|
|
35
|
+
ToolDefinition(
|
|
36
|
+
name="get_danger",
|
|
37
|
+
body=ToolBody(
|
|
38
|
+
name="get_danger",
|
|
39
|
+
description="Get current danger in a given location",
|
|
40
|
+
properties={
|
|
41
|
+
"location": {
|
|
42
|
+
"type": "string",
|
|
43
|
+
"description": "City and country e.g. Bogotá, Colombia",
|
|
44
|
+
},
|
|
45
|
+
},
|
|
46
|
+
required=["location"],
|
|
47
|
+
),
|
|
48
|
+
),
|
|
49
|
+
]
|
|
50
|
+
|
|
51
|
+
system_prompt = "You must make exactly 0 or 1 tool calls per answer. You must not make more than 1 tool call per answer."
|
|
52
|
+
user_prompt = "What is the weather in San Francisco right now?"
|
|
53
|
+
|
|
54
|
+
input = [TextInput(text=user_prompt)]
|
|
55
|
+
|
|
56
|
+
predicted_tokens = await model.count_tokens(
|
|
57
|
+
input,
|
|
58
|
+
tools=tools,
|
|
59
|
+
system_prompt=system_prompt,
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
response: QueryResult = await model.query(
|
|
63
|
+
input,
|
|
64
|
+
tools=tools,
|
|
65
|
+
system_prompt=system_prompt,
|
|
66
|
+
)
|
|
67
|
+
metadata = response.metadata
|
|
68
|
+
|
|
69
|
+
actual_tokens = metadata.total_input_tokens
|
|
70
|
+
|
|
71
|
+
console_log(f"Predicted Token Count: {predicted_tokens}")
|
|
72
|
+
console_log(f"Actual Token Count: {actual_tokens}")
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
async def main():
|
|
76
|
+
import argparse
|
|
77
|
+
|
|
78
|
+
parser = argparse.ArgumentParser(description="Example of counting tokens")
|
|
79
|
+
parser.add_argument(
|
|
80
|
+
"model",
|
|
81
|
+
nargs="?",
|
|
82
|
+
default="google/gemini-2.5-flash",
|
|
83
|
+
type=str,
|
|
84
|
+
help="Model endpoint (default: google/gemini-2.5-flash)",
|
|
85
|
+
)
|
|
86
|
+
args = parser.parse_args()
|
|
87
|
+
|
|
88
|
+
model = get_registry_model(args.model)
|
|
89
|
+
model.logger.info(model)
|
|
90
|
+
|
|
91
|
+
set_logging(enable=True, level=logging.INFO)
|
|
92
|
+
|
|
93
|
+
await count_tokens(model)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
if __name__ == "__main__":
|
|
97
|
+
setup()
|
|
98
|
+
asyncio.run(main())
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import time
|
|
3
|
+
from logging import DEBUG
|
|
4
|
+
from typing import Any, Coroutine
|
|
5
|
+
|
|
6
|
+
from tqdm import tqdm
|
|
7
|
+
|
|
8
|
+
from model_library.base import LLM, TextInput, TokenRetryParams
|
|
9
|
+
from model_library.logging import set_logging
|
|
10
|
+
from model_library.registry_utils import get_registry_model
|
|
11
|
+
from model_library.retriers.token import set_redis_client
|
|
12
|
+
|
|
13
|
+
from .setup import console_log, setup
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
async def token_retry(model: LLM):
|
|
17
|
+
console_log("\n--- Token Retry ---\n")
|
|
18
|
+
await model.query(
|
|
19
|
+
[
|
|
20
|
+
TextInput(
|
|
21
|
+
# text="What is QSBS? Explain your thinking in detail and make it concise"
|
|
22
|
+
text="dwadwadwadawdLong argument of cats vs dogs" * 5000
|
|
23
|
+
+ "Ignore the previous junk, tell me a very long story about the cats and the dogs. And yes, I do want an actual story, I just have no choice but to include the junk before, believe me."
|
|
24
|
+
)
|
|
25
|
+
],
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
async def main():
|
|
30
|
+
import argparse
|
|
31
|
+
|
|
32
|
+
parser = argparse.ArgumentParser(description="Run basic examples with a model")
|
|
33
|
+
parser.add_argument(
|
|
34
|
+
"model",
|
|
35
|
+
nargs="?",
|
|
36
|
+
default="google/gemini-2.5-flash",
|
|
37
|
+
type=str,
|
|
38
|
+
help="Model endpoint (default: google/gemini-2.5-flash)",
|
|
39
|
+
)
|
|
40
|
+
args = parser.parse_args()
|
|
41
|
+
|
|
42
|
+
set_logging(level=DEBUG)
|
|
43
|
+
|
|
44
|
+
model = get_registry_model(args.model)
|
|
45
|
+
model.logger.info(model)
|
|
46
|
+
|
|
47
|
+
limit = await model.get_rate_limit()
|
|
48
|
+
model.logger.info(limit)
|
|
49
|
+
|
|
50
|
+
import redis.asyncio as redis
|
|
51
|
+
|
|
52
|
+
# NOTE: make sure you have redis running locally
|
|
53
|
+
# docker run -d -p 6379:6379 redis:latest
|
|
54
|
+
|
|
55
|
+
redis_client = redis.Redis(
|
|
56
|
+
host="localhost", port=6379, decode_responses=True, max_connections=None
|
|
57
|
+
)
|
|
58
|
+
set_redis_client(redis_client)
|
|
59
|
+
|
|
60
|
+
provider_tokenizer_input_modifier = 1
|
|
61
|
+
dataset_output_modifier = 0.001
|
|
62
|
+
|
|
63
|
+
limit = 100_000
|
|
64
|
+
await model.init_token_retry(
|
|
65
|
+
token_retry_params=TokenRetryParams(
|
|
66
|
+
input_modifier=provider_tokenizer_input_modifier,
|
|
67
|
+
output_modifier=dataset_output_modifier,
|
|
68
|
+
use_dynamic_estimate=True,
|
|
69
|
+
limit=limit,
|
|
70
|
+
)
|
|
71
|
+
)
|
|
72
|
+
tasks: list[Coroutine[Any, Any, None]] = []
|
|
73
|
+
for _ in range(200):
|
|
74
|
+
tasks.append(token_retry(model))
|
|
75
|
+
|
|
76
|
+
start = time.time()
|
|
77
|
+
for coro in tqdm(asyncio.as_completed(tasks), total=len(tasks)):
|
|
78
|
+
await coro
|
|
79
|
+
finish = time.time() - start
|
|
80
|
+
console_log(f"Finished in {finish:.1f}s")
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
if __name__ == "__main__":
|
|
84
|
+
setup()
|
|
85
|
+
asyncio.run(main())
|