model-library 0.1.6__tar.gz → 0.1.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (129)
  1. {model_library-0.1.6 → model_library-0.1.7}/PKG-INFO +3 -3
  2. {model_library-0.1.6 → model_library-0.1.7}/examples/advanced/web_search.py +3 -26
  3. model_library-0.1.7/examples/count_tokens.py +95 -0
  4. {model_library-0.1.6 → model_library-0.1.7}/model_library/base/base.py +98 -0
  5. {model_library-0.1.6 → model_library-0.1.7}/model_library/base/delegate_only.py +10 -0
  6. {model_library-0.1.6 → model_library-0.1.7}/model_library/base/input.py +10 -7
  7. {model_library-0.1.6 → model_library-0.1.7}/model_library/base/output.py +5 -0
  8. {model_library-0.1.6 → model_library-0.1.7}/model_library/base/utils.py +21 -7
  9. {model_library-0.1.6 → model_library-0.1.7}/model_library/exceptions.py +11 -0
  10. {model_library-0.1.6 → model_library-0.1.7}/model_library/logging.py +6 -2
  11. {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/ai21labs.py +19 -7
  12. {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/amazon.py +70 -48
  13. {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/anthropic.py +101 -74
  14. {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/google/batch.py +3 -3
  15. {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/google/google.py +83 -45
  16. {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/minimax.py +19 -0
  17. {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/mistral.py +41 -27
  18. {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/openai.py +122 -73
  19. {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/vals.py +4 -3
  20. {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/xai.py +123 -115
  21. {model_library-0.1.6 → model_library-0.1.7}/model_library/register_models.py +4 -2
  22. {model_library-0.1.6 → model_library-0.1.7}/model_library/utils.py +0 -35
  23. {model_library-0.1.6 → model_library-0.1.7}/model_library.egg-info/PKG-INFO +3 -3
  24. {model_library-0.1.6 → model_library-0.1.7}/model_library.egg-info/SOURCES.txt +2 -0
  25. {model_library-0.1.6 → model_library-0.1.7}/model_library.egg-info/requires.txt +2 -2
  26. {model_library-0.1.6 → model_library-0.1.7}/pyproject.toml +2 -2
  27. {model_library-0.1.6 → model_library-0.1.7}/scripts/run_models.py +1 -4
  28. {model_library-0.1.6 → model_library-0.1.7}/tests/unit/conftest.py +1 -0
  29. model_library-0.1.7/tests/unit/test_count_tokens.py +67 -0
  30. {model_library-0.1.6 → model_library-0.1.7}/tests/unit/test_prompt_caching.py +5 -5
  31. {model_library-0.1.6 → model_library-0.1.7}/tests/unit/test_tools.py +5 -9
  32. {model_library-0.1.6 → model_library-0.1.7}/uv.lock +47 -23
  33. {model_library-0.1.6 → model_library-0.1.7}/.gitattributes +0 -0
  34. {model_library-0.1.6 → model_library-0.1.7}/.github/workflows/publish.yml +0 -0
  35. {model_library-0.1.6 → model_library-0.1.7}/.github/workflows/style.yaml +0 -0
  36. {model_library-0.1.6 → model_library-0.1.7}/.github/workflows/test.yaml +0 -0
  37. {model_library-0.1.6 → model_library-0.1.7}/.github/workflows/typecheck.yml +0 -0
  38. {model_library-0.1.6 → model_library-0.1.7}/.gitignore +0 -0
  39. {model_library-0.1.6 → model_library-0.1.7}/LICENSE +0 -0
  40. {model_library-0.1.6 → model_library-0.1.7}/Makefile +0 -0
  41. {model_library-0.1.6 → model_library-0.1.7}/README.md +0 -0
  42. {model_library-0.1.6 → model_library-0.1.7}/examples/README.md +0 -0
  43. {model_library-0.1.6 → model_library-0.1.7}/examples/advanced/batch.py +0 -0
  44. {model_library-0.1.6 → model_library-0.1.7}/examples/advanced/custom_retrier.py +0 -0
  45. {model_library-0.1.6 → model_library-0.1.7}/examples/advanced/deep_research.py +0 -0
  46. {model_library-0.1.6 → model_library-0.1.7}/examples/advanced/stress.py +0 -0
  47. {model_library-0.1.6 → model_library-0.1.7}/examples/advanced/structured_output.py +0 -0
  48. {model_library-0.1.6 → model_library-0.1.7}/examples/basics.py +0 -0
  49. {model_library-0.1.6 → model_library-0.1.7}/examples/data/files.py +0 -0
  50. {model_library-0.1.6 → model_library-0.1.7}/examples/data/images.py +0 -0
  51. {model_library-0.1.6 → model_library-0.1.7}/examples/embeddings.py +0 -0
  52. {model_library-0.1.6 → model_library-0.1.7}/examples/files.py +0 -0
  53. {model_library-0.1.6 → model_library-0.1.7}/examples/images.py +0 -0
  54. {model_library-0.1.6 → model_library-0.1.7}/examples/prompt_caching.py +0 -0
  55. {model_library-0.1.6 → model_library-0.1.7}/examples/setup.py +0 -0
  56. {model_library-0.1.6 → model_library-0.1.7}/examples/tool_calls.py +0 -0
  57. {model_library-0.1.6 → model_library-0.1.7}/model_library/__init__.py +0 -0
  58. {model_library-0.1.6 → model_library-0.1.7}/model_library/base/__init__.py +0 -0
  59. {model_library-0.1.6 → model_library-0.1.7}/model_library/base/batch.py +0 -0
  60. {model_library-0.1.6 → model_library-0.1.7}/model_library/config/README.md +0 -0
  61. {model_library-0.1.6 → model_library-0.1.7}/model_library/config/ai21labs_models.yaml +0 -0
  62. {model_library-0.1.6 → model_library-0.1.7}/model_library/config/alibaba_models.yaml +0 -0
  63. {model_library-0.1.6 → model_library-0.1.7}/model_library/config/all_models.json +0 -0
  64. {model_library-0.1.6 → model_library-0.1.7}/model_library/config/amazon_models.yaml +0 -0
  65. {model_library-0.1.6 → model_library-0.1.7}/model_library/config/anthropic_models.yaml +0 -0
  66. {model_library-0.1.6 → model_library-0.1.7}/model_library/config/cohere_models.yaml +0 -0
  67. {model_library-0.1.6 → model_library-0.1.7}/model_library/config/deepseek_models.yaml +0 -0
  68. {model_library-0.1.6 → model_library-0.1.7}/model_library/config/dummy_model.yaml +0 -0
  69. {model_library-0.1.6 → model_library-0.1.7}/model_library/config/fireworks_models.yaml +0 -0
  70. {model_library-0.1.6 → model_library-0.1.7}/model_library/config/google_models.yaml +0 -0
  71. {model_library-0.1.6 → model_library-0.1.7}/model_library/config/inception_models.yaml +0 -0
  72. {model_library-0.1.6 → model_library-0.1.7}/model_library/config/kimi_models.yaml +0 -0
  73. {model_library-0.1.6 → model_library-0.1.7}/model_library/config/minimax_models.yaml +0 -0
  74. {model_library-0.1.6 → model_library-0.1.7}/model_library/config/mistral_models.yaml +0 -0
  75. {model_library-0.1.6 → model_library-0.1.7}/model_library/config/openai_models.yaml +0 -0
  76. {model_library-0.1.6 → model_library-0.1.7}/model_library/config/perplexity_models.yaml +0 -0
  77. {model_library-0.1.6 → model_library-0.1.7}/model_library/config/together_models.yaml +0 -0
  78. {model_library-0.1.6 → model_library-0.1.7}/model_library/config/xai_models.yaml +0 -0
  79. {model_library-0.1.6 → model_library-0.1.7}/model_library/config/zai_models.yaml +0 -0
  80. {model_library-0.1.6 → model_library-0.1.7}/model_library/file_utils.py +0 -0
  81. {model_library-0.1.6 → model_library-0.1.7}/model_library/model_utils.py +0 -0
  82. {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/__init__.py +0 -0
  83. {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/alibaba.py +0 -0
  84. {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/azure.py +0 -0
  85. {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/cohere.py +0 -0
  86. {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/deepseek.py +0 -0
  87. {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/fireworks.py +0 -0
  88. {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/google/__init__.py +0 -0
  89. {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/inception.py +0 -0
  90. {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/kimi.py +0 -0
  91. {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/perplexity.py +0 -0
  92. {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/together.py +0 -0
  93. {model_library-0.1.6 → model_library-0.1.7}/model_library/providers/zai.py +0 -0
  94. {model_library-0.1.6 → model_library-0.1.7}/model_library/py.typed +0 -0
  95. {model_library-0.1.6 → model_library-0.1.7}/model_library/registry_utils.py +0 -0
  96. {model_library-0.1.6 → model_library-0.1.7}/model_library/settings.py +0 -0
  97. {model_library-0.1.6 → model_library-0.1.7}/model_library.egg-info/dependency_links.txt +0 -0
  98. {model_library-0.1.6 → model_library-0.1.7}/model_library.egg-info/top_level.txt +0 -0
  99. {model_library-0.1.6 → model_library-0.1.7}/scripts/browse_models.py +0 -0
  100. {model_library-0.1.6 → model_library-0.1.7}/scripts/config.py +0 -0
  101. {model_library-0.1.6 → model_library-0.1.7}/scripts/publish.py +0 -0
  102. {model_library-0.1.6 → model_library-0.1.7}/setup.cfg +0 -0
  103. {model_library-0.1.6 → model_library-0.1.7}/tests/README.md +0 -0
  104. {model_library-0.1.6 → model_library-0.1.7}/tests/__init__.py +0 -0
  105. {model_library-0.1.6 → model_library-0.1.7}/tests/conftest.py +0 -0
  106. {model_library-0.1.6 → model_library-0.1.7}/tests/integration/__init__.py +0 -0
  107. {model_library-0.1.6 → model_library-0.1.7}/tests/integration/conftest.py +0 -0
  108. {model_library-0.1.6 → model_library-0.1.7}/tests/integration/test_batch.py +0 -0
  109. {model_library-0.1.6 → model_library-0.1.7}/tests/integration/test_completion.py +0 -0
  110. {model_library-0.1.6 → model_library-0.1.7}/tests/integration/test_files.py +0 -0
  111. {model_library-0.1.6 → model_library-0.1.7}/tests/integration/test_reasoning.py +0 -0
  112. {model_library-0.1.6 → model_library-0.1.7}/tests/integration/test_retry.py +0 -0
  113. {model_library-0.1.6 → model_library-0.1.7}/tests/integration/test_streaming.py +0 -0
  114. {model_library-0.1.6 → model_library-0.1.7}/tests/integration/test_structured_output.py +0 -0
  115. {model_library-0.1.6 → model_library-0.1.7}/tests/integration/test_tools.py +0 -0
  116. {model_library-0.1.6 → model_library-0.1.7}/tests/test_helpers.py +0 -0
  117. {model_library-0.1.6 → model_library-0.1.7}/tests/unit/__init__.py +0 -0
  118. {model_library-0.1.6 → model_library-0.1.7}/tests/unit/providers/__init__.py +0 -0
  119. {model_library-0.1.6 → model_library-0.1.7}/tests/unit/providers/test_fireworks_provider.py +0 -0
  120. {model_library-0.1.6 → model_library-0.1.7}/tests/unit/providers/test_google_provider.py +0 -0
  121. {model_library-0.1.6 → model_library-0.1.7}/tests/unit/test_batch.py +0 -0
  122. {model_library-0.1.6 → model_library-0.1.7}/tests/unit/test_context_window.py +0 -0
  123. {model_library-0.1.6 → model_library-0.1.7}/tests/unit/test_deep_research.py +0 -0
  124. {model_library-0.1.6 → model_library-0.1.7}/tests/unit/test_perplexity_provider.py +0 -0
  125. {model_library-0.1.6 → model_library-0.1.7}/tests/unit/test_query_logger.py +0 -0
  126. {model_library-0.1.6 → model_library-0.1.7}/tests/unit/test_registry.py +0 -0
  127. {model_library-0.1.6 → model_library-0.1.7}/tests/unit/test_result_metadata.py +0 -0
  128. {model_library-0.1.6 → model_library-0.1.7}/tests/unit/test_retry.py +0 -0
  129. {model_library-0.1.6 → model_library-0.1.7}/tests/unit/test_streaming.py +0 -0
{model_library-0.1.6 → model_library-0.1.7}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: model-library
-Version: 0.1.6
+Version: 0.1.7
 Summary: Model Library for vals.ai
 Author-email: "Vals AI, Inc." <contact@vals.ai>
 License: MIT
@@ -13,13 +13,13 @@ Requires-Dist: pyyaml>=6.0.2
 Requires-Dist: rich
 Requires-Dist: backoff<3.0,>=2.2.1
 Requires-Dist: redis<7.0,>=6.2.0
-Requires-Dist: tiktoken==0.11.0
+Requires-Dist: tiktoken>=0.12.0
 Requires-Dist: pillow
 Requires-Dist: openai<3.0,>=2.0
 Requires-Dist: anthropic<1.0,>=0.57.1
 Requires-Dist: mistralai<2.0,>=1.9.10
 Requires-Dist: xai-sdk<2.0,>=1.0.0
-Requires-Dist: ai21<5.0,>=4.0.3
+Requires-Dist: ai21<5.0,>=4.3.0
 Requires-Dist: boto3<2.0,>=1.38.27
 Requires-Dist: google-genai[aiohttp]>=1.51.0
 Requires-Dist: google-cloud-storage>=1.26.0
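
Note: the tiktoken pin is loosened from an exact ==0.11.0 to >=0.12.0, presumably to back the new count_tokens support added in base.py below, and the ai21 floor rises from 4.0.3 to 4.3.0 alongside the AI21 provider refactor.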
{model_library-0.1.6 → model_library-0.1.7}/examples/advanced/web_search.py

@@ -2,6 +2,7 @@ import asyncio
 from typing import Any, cast

 from model_library.base import LLM, ToolDefinition
+from model_library.base.output import QueryResult
 from model_library.registry_utils import get_registry_model

 from ..setup import console_log, setup
@@ -41,31 +42,7 @@ def print_search_details(tool_call: Any) -> None:
         console_log(f" - {source}")


-def print_citations(response: Any) -> None:
-    """Extract and print citations from response history."""
-    if not response.history:
-        return
-
-    for item in response.history:
-        if not (hasattr(item, "content") and isinstance(item.content, list)):
-            continue
-
-        content_list = cast(list[Any], item.content)
-        for content_item in content_list:
-            if not (hasattr(content_item, "annotations") and content_item.annotations):
-                continue
-
-            console_log("\nCitations:")
-            annotations = cast(list[Any], content_item.annotations)
-            for annotation in annotations:
-                if hasattr(annotation, "url") and annotation.url:
-                    title = getattr(annotation, "title", "Untitled")
-                    url = annotation.url
-                    location = getattr(annotation, "location", "Unknown")
-                    console_log(f"- {title}: {url} (Location: {location})")
-
-
-def print_web_search_results(response: Any) -> None:
+def print_web_search_results(response: QueryResult) -> None:
     """Print comprehensive web search results."""
     console_log(f"Response: {response.output_text}")

@@ -74,7 +51,7 @@ def print_web_search_results(response: Any) -> None:
     for tool_call in response.tool_calls:
         print_search_details(tool_call)

-    print_citations(response)
+    print(response.extras.citations)


 async def web_search_domain_filtered(model: LLM) -> None:
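
The removed print_citations helper walked response.history and duck-typed its way to annotation objects; citations are now first-class on QueryResult.extras. A minimal sketch of the new flow, assuming a provider that populates extras.citations (the field is visible in the output.py hunk below):

from model_library.base.output import QueryResult


def log_citations(response: QueryResult) -> None:
    # Citation now defines a readable __repr__ (see the output.py hunk),
    # so printing each entry directly is enough for quick inspection.
    for citation in response.extras.citations:
        print(citation)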
model_library-0.1.7/examples/count_tokens.py (new file)

@@ -0,0 +1,95 @@
+import asyncio
+import logging
+
+from model_library import set_logging
+from model_library.base import (
+    LLM,
+    QueryResult,
+    TextInput,
+    ToolBody,
+    ToolDefinition,
+)
+from model_library.registry_utils import get_registry_model
+
+from .setup import console_log, setup
+
+
+async def count_tokens(model: LLM):
+    console_log("\n--- Count Tokens ---\n")
+
+    tools = [
+        ToolDefinition(
+            name="get_weather",
+            body=ToolBody(
+                name="get_weather",
+                description="Get current temperature in a given location",
+                properties={
+                    "location": {
+                        "type": "string",
+                        "description": "City and country e.g. Bogotá, Colombia",
+                    },
+                },
+                required=["location"],
+            ),
+        ),
+        ToolDefinition(
+            name="get_danger",
+            body=ToolBody(
+                name="get_danger",
+                description="Get current danger in a given location",
+                properties={
+                    "location": {
+                        "type": "string",
+                        "description": "City and country e.g. Bogotá, Colombia",
+                    },
+                },
+                required=["location"],
+            ),
+        ),
+    ]
+
+    system_prompt = "You must make exactly 0 or 1 tool calls per answer. You must not make more than 1 tool call per answer."
+    user_prompt = "What is the weather in San Francisco right now?"
+
+    predicted_tokens = await model.count_tokens(
+        [TextInput(text=user_prompt)],
+        tools=tools,
+        system_prompt=system_prompt,
+    )
+
+    response: QueryResult = await model.query(
+        [TextInput(text=user_prompt)],
+        tools=tools,
+        system_prompt=system_prompt,
+    )
+
+    actual_tokens = response.metadata.total_input_tokens
+
+    console_log(f"Predicted Token Count: {predicted_tokens}")
+    console_log(f"Actual Token Count: {actual_tokens}\n")
+
+
+async def main():
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Example of counting tokens")
+    parser.add_argument(
+        "model",
+        nargs="?",
+        default="google/gemini-2.5-flash",
+        type=str,
+        help="Model endpoint (default: google/gemini-2.5-flash)",
+    )
+    args = parser.parse_args()
+
+    model = get_registry_model(args.model)
+    model.logger.info(model)
+
+    set_logging(enable=True, level=logging.INFO)
+
+    await count_tokens(model)
+
+
+if __name__ == "__main__":
+    setup()
+    asyncio.run(main())
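
Given the relative import of setup, the example presumably runs as a module from the repository root, e.g. python -m examples.count_tokens google/gemini-2.5-flash (the argparse default). Expect the predicted and actual counts to diverge for non-OpenAI models: the prediction comes from a tiktoken approximation, as the base.py changes below show.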
{model_library-0.1.6 → model_library-0.1.7}/model_library/base/base.py

@@ -13,8 +13,10 @@ from typing import (
     TypeVar,
 )

+import tiktoken
 from pydantic import model_serializer
 from pydantic.main import BaseModel
+from tiktoken.core import Encoding
 from typing_extensions import override

 from model_library.base.batch import (
@@ -35,6 +37,7 @@ from model_library.base.output import (
 )
 from model_library.base.utils import (
     get_pretty_input_types,
+    serialize_for_tokenizing,
 )
 from model_library.exceptions import (
     ImmediateRetryException,
@@ -379,6 +382,20 @@ class LLM(ABC):
         """
         ...

+    @abstractmethod
+    async def build_body(
+        self,
+        input: Sequence[InputItem],
+        *,
+        tools: list[ToolDefinition],
+        **kwargs: Any,
+    ) -> dict[str, Any]:
+        """
+        Builds the body of the request to the model provider
+        Calls parse_input
+        """
+        ...
+
     @abstractmethod
     async def parse_input(
         self,
@@ -421,6 +438,87 @@ class LLM(ABC):
         """Upload a file to the model provider"""
         ...

+    async def get_encoding(self) -> Encoding:
+        """Get the appropriate tokenizer"""
+
+        model = self.model_name.lower()
+
+        if any(x in model for x in ["gpt-4o", "o1", "o3", "gpt-4.1", "gpt-5"]):
+            return tiktoken.get_encoding("o200k_base")
+        elif "gpt-4" in model or "gpt-3.5" in model:
+            try:
+                return tiktoken.encoding_for_model(self.model_name)
+            except KeyError:
+                return tiktoken.get_encoding("cl100k_base")
+        elif "claude" in model:
+            return tiktoken.get_encoding("cl100k_base")
+        elif "gemini" in model:
+            return tiktoken.get_encoding("o200k_base")
+        elif "llama" in model or "mistral" in model:
+            return tiktoken.get_encoding("cl100k_base")
+        else:
+            return tiktoken.get_encoding("cl100k_base")
+
+    async def stringify_input(
+        self,
+        input: Sequence[InputItem],
+        *,
+        history: Sequence[InputItem] = [],
+        tools: list[ToolDefinition] = [],
+        **kwargs: object,
+    ) -> str:
+        input = [*history, *input]
+
+        system_prompt = kwargs.pop(
+            "system_prompt", ""
+        )  # TODO: refactor along with system prompt arg change
+
+        # special case if using a delegate
+        # don't inherit method override by default
+        if self.delegate:
+            parsed_input = await self.delegate.parse_input(input, **kwargs)
+            parsed_tools = await self.delegate.parse_tools(tools)
+        else:
+            parsed_input = await self.parse_input(input, **kwargs)
+            parsed_tools = await self.parse_tools(tools)
+
+        serialized_input = serialize_for_tokenizing(parsed_input)
+        serialized_tools = serialize_for_tokenizing(parsed_tools)
+
+        combined = f"{system_prompt}\n{serialized_input}\n{serialized_tools}"
+
+        return combined
+
+    async def count_tokens(
+        self,
+        input: Sequence[InputItem],
+        *,
+        history: Sequence[InputItem] = [],
+        tools: list[ToolDefinition] = [],
+        **kwargs: object,
+    ) -> int:
+        """
+        Count the number of tokens for a query.
+        Combines parsed input and tools, then tokenizes the result.
+        """
+
+        if not input and not history:
+            return 0
+
+        if self.delegate:
+            encoding = await self.delegate.get_encoding()
+        else:
+            encoding = await self.get_encoding()
+        self.logger.debug(f"Token Count Encoding: {encoding}")
+
+        string_input = await self.stringify_input(
+            input, history=history, tools=tools, **kwargs
+        )
+
+        count = len(encoding.encode(string_input, disallowed_special=()))
+        self.logger.debug(f"Combined Token Count Input: {count}")
+        return count
+
     async def query_json(
         self,
         input: Sequence[InputItem],
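
A standalone sketch of the heuristic that get_encoding encodes, which is why count_tokens is a prediction rather than an exact count: Gemini and the newer GPT families map to o200k_base, while Claude, Llama, Mistral, older GPTs, and anything unrecognized fall back to OpenAI's cl100k_base rather than the provider's own tokenizer. Model names below are illustrative only, and the sketch condenses the branch structure (it drops the encoding_for_model lookup for older GPTs):

import tiktoken


def pick_encoding(model_name: str) -> tiktoken.Encoding:
    # Condensed mirror of LLM.get_encoding's branches.
    model = model_name.lower()
    if any(x in model for x in ["gpt-4o", "o1", "o3", "gpt-4.1", "gpt-5"]):
        return tiktoken.get_encoding("o200k_base")
    if "gemini" in model:
        return tiktoken.get_encoding("o200k_base")
    # claude, llama, mistral, older GPTs, and unknown names all land here
    return tiktoken.get_encoding("cl100k_base")


for name in ["gpt-5-mini", "gemini-2.5-flash", "claude-sonnet-4"]:
    enc = pick_encoding(name)
    text = "What is the weather in San Francisco right now?"
    print(name, enc.name, len(enc.encode(text, disallowed_special=())))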
{model_library-0.1.6 → model_library-0.1.7}/model_library/base/delegate_only.py

@@ -58,6 +58,16 @@ class DelegateOnly(LLM):
             input, tools=tools, query_logger=query_logger, **kwargs
         )

+    @override
+    async def build_body(
+        self,
+        input: Sequence[InputItem],
+        *,
+        tools: list[ToolDefinition],
+        **kwargs: object,
+    ) -> dict[str, Any]:
+        raise DelegateOnlyException()
+
     @override
     async def parse_input(
         self,
{model_library-0.1.6 → model_library-0.1.7}/model_library/base/input.py

@@ -74,8 +74,6 @@ class ToolCall(BaseModel):
     --- INPUT ---
     """

-RawResponse = Any
-

 class ToolInput(BaseModel):
     tools: list[ToolDefinition] = []
@@ -90,11 +88,16 @@ class TextInput(BaseModel):
     text: str


-RawInputItem = dict[
-    str, Any
-]  # to pass in, for example, a mock convertsation with {"role": "user", "content": "Hello"}
+class RawResponse(BaseModel):
+    # used to store a received response
+    response: Any
+
+
+class RawInput(BaseModel):
+    # used to pass in anything provider specific (e.g. a mock conversation)
+    input: Any


 InputItem = (
-    TextInput | FileInput | ToolResult | RawInputItem | RawResponse
-)  # input item can either be a prompt, a file (image or file), a tool call result, raw input, or a previous response
+    TextInput | FileInput | ToolResult | RawInput | RawResponse
+)  # input item can either be a prompt, a file (image or file), a tool call result, a previous response, or raw input
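
RawInput and RawResponse replace the old bare aliases (RawInputItem = dict[str, Any] and RawResponse = Any) with explicit pydantic wrappers that provider code can pattern-match on, as the utils.py and ai21labs.py hunks below do. A minimal sketch, reusing the OpenAI-style message dict from the old comment; the exact payload each provider accepts through RawInput is up to its parse_input:

from model_library.base.input import RawInput, TextInput

# A provider-specific mock turn wrapped in RawInput, followed by a plain prompt.
items = [
    RawInput(input={"role": "user", "content": "Hello"}),
    TextInput(text="What did I just say?"),
]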
{model_library-0.1.6 → model_library-0.1.7}/model_library/base/output.py

@@ -24,6 +24,11 @@ class Citation(BaseModel):
     index: int | None = None
     container_id: str | None = None

+    @override
+    def __repr__(self):
+        attrs = vars(self).copy()
+        return f"{self.__class__.__name__}(\n{pformat(attrs, indent=2)}\n)"
+

 class QueryResultExtras(BaseModel):
     citations: list[Citation] = Field(default_factory=list)
{model_library-0.1.6 → model_library-0.1.7}/model_library/base/utils.py

@@ -1,18 +1,34 @@
-from typing import Sequence, TypeVar, cast
+import json
+from typing import Any, Sequence, TypeVar
+
+from pydantic import BaseModel

 from model_library.base.input import (
     FileBase,
     InputItem,
-    RawInputItem,
+    RawInput,
+    RawResponse,
     TextInput,
     ToolResult,
 )
 from model_library.utils import truncate_str
-from pydantic import BaseModel

 T = TypeVar("T", bound=BaseModel)


+def serialize_for_tokenizing(content: Any) -> str:
+    """
+    Serialize parsed content into a string for tokenization
+    """
+    parts: list[str] = []
+    if content:
+        if isinstance(content, str):
+            parts.append(content)
+        else:
+            parts.append(json.dumps(content, default=str))
+    return "\n".join(parts)
+
+
 def add_optional(
     a: int | float | T | None, b: int | float | T | None
 ) -> int | float | T | None:
@@ -54,11 +70,9 @@ def get_pretty_input_types(input: Sequence["InputItem"], verbose: bool = False)
             return repr(item)
         case ToolResult():
             return repr(item)
-        case dict():
-            item = cast(RawInputItem, item)
+        case RawInput():
             return repr(item)
-        case _:
-            # RawResponse
+        case RawResponse():
             return repr(item)

     processed_items = [f" {process_item(item)}" for item in input]
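
Illustrative behavior of the new serializer, assuming parsed inputs are JSON-compatible lists or dicts; the default=str argument keeps non-JSON values such as datetimes from raising:

from model_library.base.utils import serialize_for_tokenizing

print(serialize_for_tokenizing("already a string"))  # returned unchanged
print(serialize_for_tokenizing([{"role": "user", "content": "Hi"}]))  # JSON-dumped
print(serialize_for_tokenizing(None))  # falsy content -> ""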
{model_library-0.1.6 → model_library-0.1.7}/model_library/exceptions.py

@@ -146,6 +146,17 @@ class BadInputError(Exception):
         super().__init__(message or BadInputError.DEFAULT_MESSAGE)


+class NoMatchingToolCallError(Exception):
+    """
+    Raised when a tool call result is provided with no matching tool call
+    """
+
+    DEFAULT_MESSAGE: str = "Tool call result provided with no matching tool call"
+
+    def __init__(self, message: str | None = None):
+        super().__init__(message or NoMatchingToolCallError.DEFAULT_MESSAGE)
+
+
 # Add more retriable exceptions as needed
 # Providers that don't have an explicit rate limit error are handled manually
 # by wrapping errored Http/gRPC requests with a BackoffRetryException
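
Nothing in this diff shows where NoMatchingToolCallError is raised; presumably provider input parsing throws it when a ToolResult references an unknown tool call id. A sketch of handling it:

from model_library.exceptions import NoMatchingToolCallError

try:
    raise NoMatchingToolCallError()
except NoMatchingToolCallError as exc:
    print(exc)  # Tool call result provided with no matching tool call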
{model_library-0.1.6 → model_library-0.1.7}/model_library/logging.py

@@ -6,7 +6,11 @@ from rich.logging import RichHandler
 _llm_logger = logging.getLogger("llm")


-def set_logging(enable: bool = True, handler: logging.Handler | None = None):
+def set_logging(
+    enable: bool = True,
+    level: int = logging.INFO,
+    handler: logging.Handler | None = None,
+):
     """
     Sets up logging for the model library

@@ -15,7 +19,7 @@ def set_logging(enable: bool = True, handler: logging.Handler | None = None):
         handler (logging.Handler, optional): A custom logging handler. Defaults to RichHandler.
     """
     if enable:
-        _llm_logger.setLevel(logging.INFO)
+        _llm_logger.setLevel(level)
     else:
         _llm_logger.setLevel(logging.CRITICAL)
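
Callers can now pick the verbosity; at DEBUG, the new count_tokens diagnostics in base.py (the chosen encoding and the combined token count) become visible:

import logging

from model_library import set_logging

set_logging(enable=True, level=logging.DEBUG)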
 
{model_library-0.1.6 → model_library-0.1.7}/model_library/providers/ai21labs.py

@@ -22,6 +22,7 @@ from model_library.base import (
     ToolDefinition,
     ToolResult,
 )
+from model_library.base.input import RawResponse
 from model_library.exceptions import (
     BadInputError,
     MaxOutputTokensExceededError,
@@ -65,8 +66,6 @@ class AI21LabsModel(LLM):
             match item:
                 case TextInput():
                     new_input.append(ChatMessage(role="user", content=item.text))
-                case AssistantMessage():
-                    new_input.append(item)
                 case ToolResult():
                     new_input.append(
                         ToolMessage(
@@ -74,7 +73,9 @@ class AI21LabsModel(LLM):
                             content=item.result,
                             tool_call_id=item.tool_call.id,
                         )
-                    )
+                    )  # TODO: tool calling metadata and test
+                case RawResponse():
+                    new_input.append(item.response)
                 case _:
                     raise BadInputError("Unsupported input type")
         return new_input
@@ -133,14 +134,13 @@ class AI21LabsModel(LLM):
         raise NotImplementedError()

     @override
-    async def _query_impl(
+    async def build_body(
         self,
         input: Sequence[InputItem],
         *,
         tools: list[ToolDefinition],
-        query_logger: logging.Logger,
         **kwargs: object,
-    ) -> QueryResult:
+    ) -> dict[str, Any]:
         messages: list[ChatMessage] = []
         if "system_prompt" in kwargs:
             messages.append(
@@ -162,6 +162,18 @@ class AI21LabsModel(LLM):
             body["top_p"] = self.top_p

         body.update(kwargs)
+        return body
+
+    @override
+    async def _query_impl(
+        self,
+        input: Sequence[InputItem],
+        *,
+        tools: list[ToolDefinition],
+        query_logger: logging.Logger,
+        **kwargs: object,
+    ) -> QueryResult:
+        body = await self.build_body(input, tools=tools, **kwargs)

         response: ChatCompletionResponse = (
             await self.get_client().chat.completions.create(**body, stream=False)  # pyright: ignore[reportAny, reportUnknownMemberType]
@@ -186,7 +198,7 @@ class AI21LabsModel(LLM):

         output = QueryResult(
             output_text=choice.message.content,
-            history=[*input, choice.message],
+            history=[*input, RawResponse(response=choice.message)],
             metadata=QueryResultMetadata(
                 in_tokens=response.usage.prompt_tokens,
                 out_tokens=response.usage.completion_tokens,
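
Because the assistant turn now round-trips as a RawResponse wrapper instead of a bare SDK message, a follow-up query can feed the previous result's history back in as ordinary input items, which parse_input unwraps via its new RawResponse arm. A minimal sketch, assuming some AI21 endpoint is registered (the id below is hypothetical):

import asyncio

from model_library.base import TextInput
from model_library.registry_utils import get_registry_model


async def main() -> None:
    model = get_registry_model("ai21labs/jamba-large")  # hypothetical registry id
    first = await model.query([TextInput(text="Name one prime number.")])
    # first.history ends with RawResponse(response=choice.message), which
    # parse_input unwraps on the next turn.
    follow_up = await model.query([*first.history, TextInput(text="Double it.")])
    print(follow_up.output_text)


asyncio.run(main())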