model-library 0.1.6__tar.gz → 0.1.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {model_library-0.1.6 → model_library-0.1.7}/PKG-INFO +3 -3
- {model_library-0.1.6 → model_library-0.1.7}/examples/advanced/web_search.py +3 -26
- model_library-0.1.7/examples/count_tokens.py +95 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/base/base.py +98 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/base/delegate_only.py +10 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/base/input.py +10 -7
- {model_library-0.1.6 → model_library-0.1.7}/model_library/base/output.py +5 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/base/utils.py +21 -7
- {model_library-0.1.6 → model_library-0.1.7}/model_library/exceptions.py +11 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/logging.py +6 -2
- {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/ai21labs.py +19 -7
- {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/amazon.py +70 -48
- {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/anthropic.py +101 -74
- {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/google/batch.py +3 -3
- {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/google/google.py +83 -45
- {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/minimax.py +19 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/mistral.py +41 -27
- {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/openai.py +122 -73
- {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/vals.py +4 -3
- {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/xai.py +123 -115
- {model_library-0.1.6 → model_library-0.1.7}/model_library/register_models.py +4 -2
- {model_library-0.1.6 → model_library-0.1.7}/model_library/utils.py +0 -35
- {model_library-0.1.6 → model_library-0.1.7}/model_library.egg-info/PKG-INFO +3 -3
- {model_library-0.1.6 → model_library-0.1.7}/model_library.egg-info/SOURCES.txt +2 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library.egg-info/requires.txt +2 -2
- {model_library-0.1.6 → model_library-0.1.7}/pyproject.toml +2 -2
- {model_library-0.1.6 → model_library-0.1.7}/scripts/run_models.py +1 -4
- {model_library-0.1.6 → model_library-0.1.7}/tests/unit/conftest.py +1 -0
- model_library-0.1.7/tests/unit/test_count_tokens.py +67 -0
- {model_library-0.1.6 → model_library-0.1.7}/tests/unit/test_prompt_caching.py +5 -5
- {model_library-0.1.6 → model_library-0.1.7}/tests/unit/test_tools.py +5 -9
- {model_library-0.1.6 → model_library-0.1.7}/uv.lock +47 -23
- {model_library-0.1.6 → model_library-0.1.7}/.gitattributes +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/.github/workflows/publish.yml +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/.github/workflows/style.yaml +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/.github/workflows/test.yaml +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/.github/workflows/typecheck.yml +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/.gitignore +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/LICENSE +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/Makefile +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/README.md +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/examples/README.md +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/examples/advanced/batch.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/examples/advanced/custom_retrier.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/examples/advanced/deep_research.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/examples/advanced/stress.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/examples/advanced/structured_output.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/examples/basics.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/examples/data/files.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/examples/data/images.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/examples/embeddings.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/examples/files.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/examples/images.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/examples/prompt_caching.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/examples/setup.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/examples/tool_calls.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/__init__.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/base/__init__.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/base/batch.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/config/README.md +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/config/ai21labs_models.yaml +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/config/alibaba_models.yaml +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/config/all_models.json +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/config/amazon_models.yaml +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/config/anthropic_models.yaml +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/config/cohere_models.yaml +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/config/deepseek_models.yaml +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/config/dummy_model.yaml +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/config/fireworks_models.yaml +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/config/google_models.yaml +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/config/inception_models.yaml +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/config/kimi_models.yaml +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/config/minimax_models.yaml +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/config/mistral_models.yaml +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/config/openai_models.yaml +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/config/perplexity_models.yaml +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/config/together_models.yaml +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/config/xai_models.yaml +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/config/zai_models.yaml +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/file_utils.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/model_utils.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/__init__.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/alibaba.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/azure.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/cohere.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/deepseek.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/fireworks.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/google/__init__.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/inception.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/kimi.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/perplexity.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/together.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/zai.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/py.typed +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/registry_utils.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library/settings.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library.egg-info/dependency_links.txt +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/model_library.egg-info/top_level.txt +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/scripts/browse_models.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/scripts/config.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/scripts/publish.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/setup.cfg +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/tests/README.md +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/tests/__init__.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/tests/conftest.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/tests/integration/__init__.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/tests/integration/conftest.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/tests/integration/test_batch.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/tests/integration/test_completion.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/tests/integration/test_files.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/tests/integration/test_reasoning.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/tests/integration/test_retry.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/tests/integration/test_streaming.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/tests/integration/test_structured_output.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/tests/integration/test_tools.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/tests/test_helpers.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/tests/unit/__init__.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/tests/unit/providers/__init__.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/tests/unit/providers/test_fireworks_provider.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/tests/unit/providers/test_google_provider.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/tests/unit/test_batch.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/tests/unit/test_context_window.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/tests/unit/test_deep_research.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/tests/unit/test_perplexity_provider.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/tests/unit/test_query_logger.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/tests/unit/test_registry.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/tests/unit/test_result_metadata.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/tests/unit/test_retry.py +0 -0
- {model_library-0.1.6 → model_library-0.1.7}/tests/unit/test_streaming.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: model-library
|
|
3
|
-
Version: 0.1.
|
|
3
|
+
Version: 0.1.7
|
|
4
4
|
Summary: Model Library for vals.ai
|
|
5
5
|
Author-email: "Vals AI, Inc." <contact@vals.ai>
|
|
6
6
|
License: MIT
|
|
@@ -13,13 +13,13 @@ Requires-Dist: pyyaml>=6.0.2
|
|
|
13
13
|
Requires-Dist: rich
|
|
14
14
|
Requires-Dist: backoff<3.0,>=2.2.1
|
|
15
15
|
Requires-Dist: redis<7.0,>=6.2.0
|
|
16
|
-
Requires-Dist: tiktoken
|
|
16
|
+
Requires-Dist: tiktoken>=0.12.0
|
|
17
17
|
Requires-Dist: pillow
|
|
18
18
|
Requires-Dist: openai<3.0,>=2.0
|
|
19
19
|
Requires-Dist: anthropic<1.0,>=0.57.1
|
|
20
20
|
Requires-Dist: mistralai<2.0,>=1.9.10
|
|
21
21
|
Requires-Dist: xai-sdk<2.0,>=1.0.0
|
|
22
|
-
Requires-Dist: ai21<5.0,>=4.0
|
|
22
|
+
Requires-Dist: ai21<5.0,>=4.3.0
|
|
23
23
|
Requires-Dist: boto3<2.0,>=1.38.27
|
|
24
24
|
Requires-Dist: google-genai[aiohttp]>=1.51.0
|
|
25
25
|
Requires-Dist: google-cloud-storage>=1.26.0
|
|
@@ -2,6 +2,7 @@ import asyncio
|
|
|
2
2
|
from typing import Any, cast
|
|
3
3
|
|
|
4
4
|
from model_library.base import LLM, ToolDefinition
|
|
5
|
+
from model_library.base.output import QueryResult
|
|
5
6
|
from model_library.registry_utils import get_registry_model
|
|
6
7
|
|
|
7
8
|
from ..setup import console_log, setup
|
|
@@ -41,31 +42,7 @@ def print_search_details(tool_call: Any) -> None:
|
|
|
41
42
|
console_log(f" - {source}")
|
|
42
43
|
|
|
43
44
|
|
|
44
|
-
def
|
|
45
|
-
"""Extract and print citations from response history."""
|
|
46
|
-
if not response.history:
|
|
47
|
-
return
|
|
48
|
-
|
|
49
|
-
for item in response.history:
|
|
50
|
-
if not (hasattr(item, "content") and isinstance(item.content, list)):
|
|
51
|
-
continue
|
|
52
|
-
|
|
53
|
-
content_list = cast(list[Any], item.content)
|
|
54
|
-
for content_item in content_list:
|
|
55
|
-
if not (hasattr(content_item, "annotations") and content_item.annotations):
|
|
56
|
-
continue
|
|
57
|
-
|
|
58
|
-
console_log("\nCitations:")
|
|
59
|
-
annotations = cast(list[Any], content_item.annotations)
|
|
60
|
-
for annotation in annotations:
|
|
61
|
-
if hasattr(annotation, "url") and annotation.url:
|
|
62
|
-
title = getattr(annotation, "title", "Untitled")
|
|
63
|
-
url = annotation.url
|
|
64
|
-
location = getattr(annotation, "location", "Unknown")
|
|
65
|
-
console_log(f"- {title}: {url} (Location: {location})")
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
def print_web_search_results(response: Any) -> None:
|
|
45
|
+
def print_web_search_results(response: QueryResult) -> None:
|
|
69
46
|
"""Print comprehensive web search results."""
|
|
70
47
|
console_log(f"Response: {response.output_text}")
|
|
71
48
|
|
|
@@ -74,7 +51,7 @@ def print_web_search_results(response: Any) -> None:
|
|
|
74
51
|
for tool_call in response.tool_calls:
|
|
75
52
|
print_search_details(tool_call)
|
|
76
53
|
|
|
77
|
-
|
|
54
|
+
print(response.extras.citations)
|
|
78
55
|
|
|
79
56
|
|
|
80
57
|
async def web_search_domain_filtered(model: LLM) -> None:
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import logging
|
|
3
|
+
|
|
4
|
+
from model_library import set_logging
|
|
5
|
+
from model_library.base import (
|
|
6
|
+
LLM,
|
|
7
|
+
QueryResult,
|
|
8
|
+
TextInput,
|
|
9
|
+
ToolBody,
|
|
10
|
+
ToolDefinition,
|
|
11
|
+
)
|
|
12
|
+
from model_library.registry_utils import get_registry_model
|
|
13
|
+
|
|
14
|
+
from .setup import console_log, setup
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
async def count_tokens(model: LLM):
|
|
18
|
+
console_log("\n--- Count Tokens ---\n")
|
|
19
|
+
|
|
20
|
+
tools = [
|
|
21
|
+
ToolDefinition(
|
|
22
|
+
name="get_weather",
|
|
23
|
+
body=ToolBody(
|
|
24
|
+
name="get_weather",
|
|
25
|
+
description="Get current temperature in a given location",
|
|
26
|
+
properties={
|
|
27
|
+
"location": {
|
|
28
|
+
"type": "string",
|
|
29
|
+
"description": "City and country e.g. Bogotá, Colombia",
|
|
30
|
+
},
|
|
31
|
+
},
|
|
32
|
+
required=["location"],
|
|
33
|
+
),
|
|
34
|
+
),
|
|
35
|
+
ToolDefinition(
|
|
36
|
+
name="get_danger",
|
|
37
|
+
body=ToolBody(
|
|
38
|
+
name="get_danger",
|
|
39
|
+
description="Get current danger in a given location",
|
|
40
|
+
properties={
|
|
41
|
+
"location": {
|
|
42
|
+
"type": "string",
|
|
43
|
+
"description": "City and country e.g. Bogotá, Colombia",
|
|
44
|
+
},
|
|
45
|
+
},
|
|
46
|
+
required=["location"],
|
|
47
|
+
),
|
|
48
|
+
),
|
|
49
|
+
]
|
|
50
|
+
|
|
51
|
+
system_prompt = "You must make exactly 0 or 1 tool calls per answer. You must not make more than 1 tool call per answer."
|
|
52
|
+
user_prompt = "What is the weather in San Francisco right now?"
|
|
53
|
+
|
|
54
|
+
predicted_tokens = await model.count_tokens(
|
|
55
|
+
[TextInput(text=user_prompt)],
|
|
56
|
+
tools=tools,
|
|
57
|
+
system_prompt=system_prompt,
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
response: QueryResult = await model.query(
|
|
61
|
+
[TextInput(text=user_prompt)],
|
|
62
|
+
tools=tools,
|
|
63
|
+
system_prompt=system_prompt,
|
|
64
|
+
)
|
|
65
|
+
|
|
66
|
+
actual_tokens = response.metadata.total_input_tokens
|
|
67
|
+
|
|
68
|
+
console_log(f"Predicted Token Count: {predicted_tokens}")
|
|
69
|
+
console_log(f"Actual Token Count: {actual_tokens}\n")
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
async def main():
|
|
73
|
+
import argparse
|
|
74
|
+
|
|
75
|
+
parser = argparse.ArgumentParser(description="Example of counting tokens")
|
|
76
|
+
parser.add_argument(
|
|
77
|
+
"model",
|
|
78
|
+
nargs="?",
|
|
79
|
+
default="google/gemini-2.5-flash",
|
|
80
|
+
type=str,
|
|
81
|
+
help="Model endpoint (default: google/gemini-2.5-flash)",
|
|
82
|
+
)
|
|
83
|
+
args = parser.parse_args()
|
|
84
|
+
|
|
85
|
+
model = get_registry_model(args.model)
|
|
86
|
+
model.logger.info(model)
|
|
87
|
+
|
|
88
|
+
set_logging(enable=True, level=logging.INFO)
|
|
89
|
+
|
|
90
|
+
await count_tokens(model)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
if __name__ == "__main__":
|
|
94
|
+
setup()
|
|
95
|
+
asyncio.run(main())
|
|
@@ -13,8 +13,10 @@ from typing import (
|
|
|
13
13
|
TypeVar,
|
|
14
14
|
)
|
|
15
15
|
|
|
16
|
+
import tiktoken
|
|
16
17
|
from pydantic import model_serializer
|
|
17
18
|
from pydantic.main import BaseModel
|
|
19
|
+
from tiktoken.core import Encoding
|
|
18
20
|
from typing_extensions import override
|
|
19
21
|
|
|
20
22
|
from model_library.base.batch import (
|
|
@@ -35,6 +37,7 @@ from model_library.base.output import (
|
|
|
35
37
|
)
|
|
36
38
|
from model_library.base.utils import (
|
|
37
39
|
get_pretty_input_types,
|
|
40
|
+
serialize_for_tokenizing,
|
|
38
41
|
)
|
|
39
42
|
from model_library.exceptions import (
|
|
40
43
|
ImmediateRetryException,
|
|
@@ -379,6 +382,20 @@ class LLM(ABC):
|
|
|
379
382
|
"""
|
|
380
383
|
...
|
|
381
384
|
|
|
385
|
+
@abstractmethod
|
|
386
|
+
async def build_body(
|
|
387
|
+
self,
|
|
388
|
+
input: Sequence[InputItem],
|
|
389
|
+
*,
|
|
390
|
+
tools: list[ToolDefinition],
|
|
391
|
+
**kwargs: Any,
|
|
392
|
+
) -> dict[str, Any]:
|
|
393
|
+
"""
|
|
394
|
+
Builds the body of the request to the model provider
|
|
395
|
+
Calls parse_input
|
|
396
|
+
"""
|
|
397
|
+
...
|
|
398
|
+
|
|
382
399
|
@abstractmethod
|
|
383
400
|
async def parse_input(
|
|
384
401
|
self,
|
|
@@ -421,6 +438,87 @@ class LLM(ABC):
|
|
|
421
438
|
"""Upload a file to the model provider"""
|
|
422
439
|
...
|
|
423
440
|
|
|
441
|
+
async def get_encoding(self) -> Encoding:
|
|
442
|
+
"""Get the appropriate tokenizer"""
|
|
443
|
+
|
|
444
|
+
model = self.model_name.lower()
|
|
445
|
+
|
|
446
|
+
if any(x in model for x in ["gpt-4o", "o1", "o3", "gpt-4.1", "gpt-5"]):
|
|
447
|
+
return tiktoken.get_encoding("o200k_base")
|
|
448
|
+
elif "gpt-4" in model or "gpt-3.5" in model:
|
|
449
|
+
try:
|
|
450
|
+
return tiktoken.encoding_for_model(self.model_name)
|
|
451
|
+
except KeyError:
|
|
452
|
+
return tiktoken.get_encoding("cl100k_base")
|
|
453
|
+
elif "claude" in model:
|
|
454
|
+
return tiktoken.get_encoding("cl100k_base")
|
|
455
|
+
elif "gemini" in model:
|
|
456
|
+
return tiktoken.get_encoding("o200k_base")
|
|
457
|
+
elif "llama" in model or "mistral" in model:
|
|
458
|
+
return tiktoken.get_encoding("cl100k_base")
|
|
459
|
+
else:
|
|
460
|
+
return tiktoken.get_encoding("cl100k_base")
|
|
461
|
+
|
|
462
|
+
async def stringify_input(
|
|
463
|
+
self,
|
|
464
|
+
input: Sequence[InputItem],
|
|
465
|
+
*,
|
|
466
|
+
history: Sequence[InputItem] = [],
|
|
467
|
+
tools: list[ToolDefinition] = [],
|
|
468
|
+
**kwargs: object,
|
|
469
|
+
) -> str:
|
|
470
|
+
input = [*history, *input]
|
|
471
|
+
|
|
472
|
+
system_prompt = kwargs.pop(
|
|
473
|
+
"system_prompt", ""
|
|
474
|
+
) # TODO: refactor along with system prompt arg change
|
|
475
|
+
|
|
476
|
+
# special case if using a delegate
|
|
477
|
+
# don't inherit method override by default
|
|
478
|
+
if self.delegate:
|
|
479
|
+
parsed_input = await self.delegate.parse_input(input, **kwargs)
|
|
480
|
+
parsed_tools = await self.delegate.parse_tools(tools)
|
|
481
|
+
else:
|
|
482
|
+
parsed_input = await self.parse_input(input, **kwargs)
|
|
483
|
+
parsed_tools = await self.parse_tools(tools)
|
|
484
|
+
|
|
485
|
+
serialized_input = serialize_for_tokenizing(parsed_input)
|
|
486
|
+
serialized_tools = serialize_for_tokenizing(parsed_tools)
|
|
487
|
+
|
|
488
|
+
combined = f"{system_prompt}\n{serialized_input}\n{serialized_tools}"
|
|
489
|
+
|
|
490
|
+
return combined
|
|
491
|
+
|
|
492
|
+
async def count_tokens(
|
|
493
|
+
self,
|
|
494
|
+
input: Sequence[InputItem],
|
|
495
|
+
*,
|
|
496
|
+
history: Sequence[InputItem] = [],
|
|
497
|
+
tools: list[ToolDefinition] = [],
|
|
498
|
+
**kwargs: object,
|
|
499
|
+
) -> int:
|
|
500
|
+
"""
|
|
501
|
+
Count the number of tokens for a query.
|
|
502
|
+
Combines parsed input and tools, then tokenizes the result.
|
|
503
|
+
"""
|
|
504
|
+
|
|
505
|
+
if not input and not history:
|
|
506
|
+
return 0
|
|
507
|
+
|
|
508
|
+
if self.delegate:
|
|
509
|
+
encoding = await self.delegate.get_encoding()
|
|
510
|
+
else:
|
|
511
|
+
encoding = await self.get_encoding()
|
|
512
|
+
self.logger.debug(f"Token Count Encoding: {encoding}")
|
|
513
|
+
|
|
514
|
+
string_input = await self.stringify_input(
|
|
515
|
+
input, history=history, tools=tools, **kwargs
|
|
516
|
+
)
|
|
517
|
+
|
|
518
|
+
count = len(encoding.encode(string_input, disallowed_special=()))
|
|
519
|
+
self.logger.debug(f"Combined Token Count Input: {count}")
|
|
520
|
+
return count
|
|
521
|
+
|
|
424
522
|
async def query_json(
|
|
425
523
|
self,
|
|
426
524
|
input: Sequence[InputItem],
|
|
@@ -58,6 +58,16 @@ class DelegateOnly(LLM):
|
|
|
58
58
|
input, tools=tools, query_logger=query_logger, **kwargs
|
|
59
59
|
)
|
|
60
60
|
|
|
61
|
+
@override
|
|
62
|
+
async def build_body(
|
|
63
|
+
self,
|
|
64
|
+
input: Sequence[InputItem],
|
|
65
|
+
*,
|
|
66
|
+
tools: list[ToolDefinition],
|
|
67
|
+
**kwargs: object,
|
|
68
|
+
) -> dict[str, Any]:
|
|
69
|
+
raise DelegateOnlyException()
|
|
70
|
+
|
|
61
71
|
@override
|
|
62
72
|
async def parse_input(
|
|
63
73
|
self,
|
|
@@ -74,8 +74,6 @@ class ToolCall(BaseModel):
|
|
|
74
74
|
--- INPUT ---
|
|
75
75
|
"""
|
|
76
76
|
|
|
77
|
-
RawResponse = Any
|
|
78
|
-
|
|
79
77
|
|
|
80
78
|
class ToolInput(BaseModel):
|
|
81
79
|
tools: list[ToolDefinition] = []
|
|
@@ -90,11 +88,16 @@ class TextInput(BaseModel):
|
|
|
90
88
|
text: str
|
|
91
89
|
|
|
92
90
|
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
91
|
+
class RawResponse(BaseModel):
|
|
92
|
+
# used to store a received response
|
|
93
|
+
response: Any
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class RawInput(BaseModel):
|
|
97
|
+
# used to pass in anything provider specific (e.g. a mock conversation)
|
|
98
|
+
input: Any
|
|
96
99
|
|
|
97
100
|
|
|
98
101
|
InputItem = (
|
|
99
|
-
TextInput | FileInput | ToolResult |
|
|
100
|
-
) # input item can either be a prompt, a file (image or file), a tool call result,
|
|
102
|
+
TextInput | FileInput | ToolResult | RawInput | RawResponse
|
|
103
|
+
) # input item can either be a prompt, a file (image or file), a tool call result, a previous response, or raw input
|
|
@@ -24,6 +24,11 @@ class Citation(BaseModel):
|
|
|
24
24
|
index: int | None = None
|
|
25
25
|
container_id: str | None = None
|
|
26
26
|
|
|
27
|
+
@override
|
|
28
|
+
def __repr__(self):
|
|
29
|
+
attrs = vars(self).copy()
|
|
30
|
+
return f"{self.__class__.__name__}(\n{pformat(attrs, indent=2)}\n)"
|
|
31
|
+
|
|
27
32
|
|
|
28
33
|
class QueryResultExtras(BaseModel):
|
|
29
34
|
citations: list[Citation] = Field(default_factory=list)
|
|
@@ -1,18 +1,34 @@
|
|
|
1
|
-
|
|
1
|
+
import json
|
|
2
|
+
from typing import Any, Sequence, TypeVar
|
|
3
|
+
|
|
4
|
+
from pydantic import BaseModel
|
|
2
5
|
|
|
3
6
|
from model_library.base.input import (
|
|
4
7
|
FileBase,
|
|
5
8
|
InputItem,
|
|
6
|
-
|
|
9
|
+
RawInput,
|
|
10
|
+
RawResponse,
|
|
7
11
|
TextInput,
|
|
8
12
|
ToolResult,
|
|
9
13
|
)
|
|
10
14
|
from model_library.utils import truncate_str
|
|
11
|
-
from pydantic import BaseModel
|
|
12
15
|
|
|
13
16
|
T = TypeVar("T", bound=BaseModel)
|
|
14
17
|
|
|
15
18
|
|
|
19
|
+
def serialize_for_tokenizing(content: Any) -> str:
|
|
20
|
+
"""
|
|
21
|
+
Serialize parsed content into a string for tokenization
|
|
22
|
+
"""
|
|
23
|
+
parts: list[str] = []
|
|
24
|
+
if content:
|
|
25
|
+
if isinstance(content, str):
|
|
26
|
+
parts.append(content)
|
|
27
|
+
else:
|
|
28
|
+
parts.append(json.dumps(content, default=str))
|
|
29
|
+
return "\n".join(parts)
|
|
30
|
+
|
|
31
|
+
|
|
16
32
|
def add_optional(
|
|
17
33
|
a: int | float | T | None, b: int | float | T | None
|
|
18
34
|
) -> int | float | T | None:
|
|
@@ -54,11 +70,9 @@ def get_pretty_input_types(input: Sequence["InputItem"], verbose: bool = False)
|
|
|
54
70
|
return repr(item)
|
|
55
71
|
case ToolResult():
|
|
56
72
|
return repr(item)
|
|
57
|
-
case
|
|
58
|
-
item = cast(RawInputItem, item)
|
|
73
|
+
case RawInput():
|
|
59
74
|
return repr(item)
|
|
60
|
-
case
|
|
61
|
-
# RawResponse
|
|
75
|
+
case RawResponse():
|
|
62
76
|
return repr(item)
|
|
63
77
|
|
|
64
78
|
processed_items = [f" {process_item(item)}" for item in input]
|
|
@@ -146,6 +146,17 @@ class BadInputError(Exception):
|
|
|
146
146
|
super().__init__(message or BadInputError.DEFAULT_MESSAGE)
|
|
147
147
|
|
|
148
148
|
|
|
149
|
+
class NoMatchingToolCallError(Exception):
|
|
150
|
+
"""
|
|
151
|
+
Raised when a tool call result is provided with no matching tool call
|
|
152
|
+
"""
|
|
153
|
+
|
|
154
|
+
DEFAULT_MESSAGE: str = "Tool call result provided with no matching tool call"
|
|
155
|
+
|
|
156
|
+
def __init__(self, message: str | None = None):
|
|
157
|
+
super().__init__(message or NoMatchingToolCallError.DEFAULT_MESSAGE)
|
|
158
|
+
|
|
159
|
+
|
|
149
160
|
# Add more retriable exceptions as needed
|
|
150
161
|
# Providers that don't have an explicit rate limit error are handled manually
|
|
151
162
|
# by wrapping errored Http/gRPC requests with a BackoffRetryException
|
|
@@ -6,7 +6,11 @@ from rich.logging import RichHandler
|
|
|
6
6
|
_llm_logger = logging.getLogger("llm")
|
|
7
7
|
|
|
8
8
|
|
|
9
|
-
def set_logging(
|
|
9
|
+
def set_logging(
|
|
10
|
+
enable: bool = True,
|
|
11
|
+
level: int = logging.INFO,
|
|
12
|
+
handler: logging.Handler | None = None,
|
|
13
|
+
):
|
|
10
14
|
"""
|
|
11
15
|
Sets up logging for the model library
|
|
12
16
|
|
|
@@ -15,7 +19,7 @@ def set_logging(enable: bool = True, handler: logging.Handler | None = None):
|
|
|
15
19
|
handler (logging.Handler, optional): A custom logging handler. Defaults to RichHandler.
|
|
16
20
|
"""
|
|
17
21
|
if enable:
|
|
18
|
-
_llm_logger.setLevel(
|
|
22
|
+
_llm_logger.setLevel(level)
|
|
19
23
|
else:
|
|
20
24
|
_llm_logger.setLevel(logging.CRITICAL)
|
|
21
25
|
|
|
@@ -22,6 +22,7 @@ from model_library.base import (
|
|
|
22
22
|
ToolDefinition,
|
|
23
23
|
ToolResult,
|
|
24
24
|
)
|
|
25
|
+
from model_library.base.input import RawResponse
|
|
25
26
|
from model_library.exceptions import (
|
|
26
27
|
BadInputError,
|
|
27
28
|
MaxOutputTokensExceededError,
|
|
@@ -65,8 +66,6 @@ class AI21LabsModel(LLM):
|
|
|
65
66
|
match item:
|
|
66
67
|
case TextInput():
|
|
67
68
|
new_input.append(ChatMessage(role="user", content=item.text))
|
|
68
|
-
case AssistantMessage():
|
|
69
|
-
new_input.append(item)
|
|
70
69
|
case ToolResult():
|
|
71
70
|
new_input.append(
|
|
72
71
|
ToolMessage(
|
|
@@ -74,7 +73,9 @@ class AI21LabsModel(LLM):
|
|
|
74
73
|
content=item.result,
|
|
75
74
|
tool_call_id=item.tool_call.id,
|
|
76
75
|
)
|
|
77
|
-
)
|
|
76
|
+
) # TODO: tool calling metadata and test
|
|
77
|
+
case RawResponse():
|
|
78
|
+
new_input.append(item.response)
|
|
78
79
|
case _:
|
|
79
80
|
raise BadInputError("Unsupported input type")
|
|
80
81
|
return new_input
|
|
@@ -133,14 +134,13 @@ class AI21LabsModel(LLM):
|
|
|
133
134
|
raise NotImplementedError()
|
|
134
135
|
|
|
135
136
|
@override
|
|
136
|
-
async def
|
|
137
|
+
async def build_body(
|
|
137
138
|
self,
|
|
138
139
|
input: Sequence[InputItem],
|
|
139
140
|
*,
|
|
140
141
|
tools: list[ToolDefinition],
|
|
141
|
-
query_logger: logging.Logger,
|
|
142
142
|
**kwargs: object,
|
|
143
|
-
) ->
|
|
143
|
+
) -> dict[str, Any]:
|
|
144
144
|
messages: list[ChatMessage] = []
|
|
145
145
|
if "system_prompt" in kwargs:
|
|
146
146
|
messages.append(
|
|
@@ -162,6 +162,18 @@ class AI21LabsModel(LLM):
|
|
|
162
162
|
body["top_p"] = self.top_p
|
|
163
163
|
|
|
164
164
|
body.update(kwargs)
|
|
165
|
+
return body
|
|
166
|
+
|
|
167
|
+
@override
|
|
168
|
+
async def _query_impl(
|
|
169
|
+
self,
|
|
170
|
+
input: Sequence[InputItem],
|
|
171
|
+
*,
|
|
172
|
+
tools: list[ToolDefinition],
|
|
173
|
+
query_logger: logging.Logger,
|
|
174
|
+
**kwargs: object,
|
|
175
|
+
) -> QueryResult:
|
|
176
|
+
body = await self.build_body(input, tools=tools, **kwargs)
|
|
165
177
|
|
|
166
178
|
response: ChatCompletionResponse = (
|
|
167
179
|
await self.get_client().chat.completions.create(**body, stream=False) # pyright: ignore[reportAny, reportUnknownMemberType]
|
|
@@ -186,7 +198,7 @@ class AI21LabsModel(LLM):
|
|
|
186
198
|
|
|
187
199
|
output = QueryResult(
|
|
188
200
|
output_text=choice.message.content,
|
|
189
|
-
history=[*input, choice.message],
|
|
201
|
+
history=[*input, RawResponse(response=choice.message)],
|
|
190
202
|
metadata=QueryResultMetadata(
|
|
191
203
|
in_tokens=response.usage.prompt_tokens,
|
|
192
204
|
out_tokens=response.usage.completion_tokens,
|