model-library 0.1.7__tar.gz → 0.1.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (151)
  1. {model_library-0.1.7 → model_library-0.1.8}/Makefile +3 -5
  2. {model_library-0.1.7 → model_library-0.1.8}/PKG-INFO +2 -1
  3. {model_library-0.1.7 → model_library-0.1.8}/examples/advanced/custom_retrier.py +26 -33
  4. {model_library-0.1.7 → model_library-0.1.8}/examples/advanced/web_search.py +1 -2
  5. {model_library-0.1.7 → model_library-0.1.8}/examples/count_tokens.py +7 -4
  6. model_library-0.1.8/examples/token_retry.py +85 -0
  7. {model_library-0.1.7 → model_library-0.1.8}/model_library/base/base.py +139 -62
  8. model_library-0.1.8/model_library/base/delegate_only.py +175 -0
  9. {model_library-0.1.7 → model_library-0.1.8}/model_library/base/output.py +43 -0
  10. {model_library-0.1.7 → model_library-0.1.8}/model_library/base/utils.py +35 -0
  11. {model_library-0.1.7 → model_library-0.1.8}/model_library/config/alibaba_models.yaml +44 -57
  12. {model_library-0.1.7 → model_library-0.1.8}/model_library/config/all_models.json +253 -126
  13. {model_library-0.1.7 → model_library-0.1.8}/model_library/config/kimi_models.yaml +30 -3
  14. {model_library-0.1.7 → model_library-0.1.8}/model_library/config/openai_models.yaml +15 -23
  15. {model_library-0.1.7 → model_library-0.1.8}/model_library/config/zai_models.yaml +24 -3
  16. {model_library-0.1.7 → model_library-0.1.8}/model_library/exceptions.py +3 -77
  17. {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/ai21labs.py +12 -8
  18. {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/alibaba.py +17 -8
  19. {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/amazon.py +49 -16
  20. {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/anthropic.py +93 -40
  21. {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/azure.py +22 -10
  22. {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/cohere.py +7 -7
  23. {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/deepseek.py +8 -8
  24. {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/fireworks.py +7 -8
  25. {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/google/batch.py +14 -10
  26. {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/google/google.py +48 -29
  27. {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/inception.py +7 -7
  28. {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/kimi.py +18 -8
  29. {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/minimax.py +15 -17
  30. {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/mistral.py +20 -8
  31. {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/openai.py +99 -22
  32. model_library-0.1.8/model_library/providers/openrouter.py +34 -0
  33. {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/perplexity.py +7 -7
  34. {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/together.py +7 -8
  35. {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/vals.py +12 -6
  36. {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/xai.py +47 -42
  37. model_library-0.1.8/model_library/providers/zai.py +64 -0
  38. {model_library-0.1.7 → model_library-0.1.8}/model_library/registry_utils.py +39 -15
  39. model_library-0.1.8/model_library/retriers/backoff.py +73 -0
  40. model_library-0.1.8/model_library/retriers/base.py +225 -0
  41. model_library-0.1.8/model_library/retriers/token.py +427 -0
  42. model_library-0.1.8/model_library/retriers/utils.py +11 -0
  43. {model_library-0.1.7 → model_library-0.1.8}/model_library/settings.py +1 -1
  44. {model_library-0.1.7 → model_library-0.1.8}/model_library/utils.py +13 -0
  45. {model_library-0.1.7 → model_library-0.1.8}/model_library.egg-info/PKG-INFO +2 -1
  46. {model_library-0.1.7 → model_library-0.1.8}/model_library.egg-info/SOURCES.txt +13 -8
  47. {model_library-0.1.7 → model_library-0.1.8}/model_library.egg-info/requires.txt +1 -0
  48. {model_library-0.1.7 → model_library-0.1.8}/pyproject.toml +2 -0
  49. {model_library-0.1.7 → model_library-0.1.8}/scripts/browse_models.py +2 -2
  50. {model_library-0.1.7 → model_library-0.1.8}/scripts/run_models.py +13 -13
  51. model_library-0.1.8/tests/README.md +75 -0
  52. model_library-0.1.8/tests/conftest.py +85 -0
  53. model_library-0.1.8/tests/integration/conftest.py +8 -0
  54. model_library-0.1.8/tests/integration/test_basic.py +15 -0
  55. {model_library-0.1.7 → model_library-0.1.8}/tests/integration/test_batch.py +4 -12
  56. model_library-0.1.8/tests/integration/test_files.py +38 -0
  57. model_library-0.1.8/tests/integration/test_long_problem.py +24 -0
  58. {model_library-0.1.7 → model_library-0.1.8}/tests/integration/test_reasoning.py +3 -12
  59. model_library-0.1.8/tests/integration/test_retry.py +28 -0
  60. {model_library-0.1.7 → model_library-0.1.8}/tests/integration/test_streaming.py +13 -30
  61. {model_library-0.1.7 → model_library-0.1.8}/tests/integration/test_structured_output.py +0 -6
  62. {model_library-0.1.7 → model_library-0.1.8}/tests/integration/test_tools.py +1 -18
  63. model_library-0.1.8/tests/test_helpers.py +183 -0
  64. model_library-0.1.8/tests/unit/conftest.py +102 -0
  65. model_library-0.1.8/tests/unit/test_batch.py +353 -0
  66. model_library-0.1.8/tests/unit/test_count_tokens.py +42 -0
  67. model_library-0.1.8/tests/unit/test_get_client.py +188 -0
  68. model_library-0.1.8/tests/unit/test_openai_config.py +32 -0
  69. {model_library-0.1.7 → model_library-0.1.8}/tests/unit/test_prompt_caching.py +31 -81
  70. {model_library-0.1.7 → model_library-0.1.8}/tests/unit/test_query_logger.py +18 -0
  71. {model_library-0.1.7 → model_library-0.1.8}/tests/unit/test_registry.py +7 -7
  72. {model_library-0.1.7 → model_library-0.1.8}/tests/unit/test_result_metadata.py +28 -18
  73. {model_library-0.1.7 → model_library-0.1.8}/tests/unit/test_retry.py +44 -41
  74. model_library-0.1.8/tests/unit/test_token_retry.py +405 -0
  75. {model_library-0.1.7 → model_library-0.1.8}/tests/unit/test_tools.py +4 -6
  76. {model_library-0.1.7 → model_library-0.1.8}/uv.lock +25 -1
  77. model_library-0.1.7/model_library/base/delegate_only.py +0 -108
  78. model_library-0.1.7/model_library/providers/zai.py +0 -34
  79. model_library-0.1.7/tests/README.md +0 -87
  80. model_library-0.1.7/tests/conftest.py +0 -275
  81. model_library-0.1.7/tests/integration/conftest.py +0 -8
  82. model_library-0.1.7/tests/integration/test_completion.py +0 -41
  83. model_library-0.1.7/tests/integration/test_files.py +0 -279
  84. model_library-0.1.7/tests/integration/test_retry.py +0 -95
  85. model_library-0.1.7/tests/test_helpers.py +0 -89
  86. model_library-0.1.7/tests/unit/conftest.py +0 -53
  87. model_library-0.1.7/tests/unit/providers/test_fireworks_provider.py +0 -48
  88. model_library-0.1.7/tests/unit/providers/test_google_provider.py +0 -58
  89. model_library-0.1.7/tests/unit/test_batch.py +0 -236
  90. model_library-0.1.7/tests/unit/test_context_window.py +0 -45
  91. model_library-0.1.7/tests/unit/test_count_tokens.py +0 -67
  92. model_library-0.1.7/tests/unit/test_perplexity_provider.py +0 -73
  93. model_library-0.1.7/tests/unit/test_streaming.py +0 -83
  94. {model_library-0.1.7 → model_library-0.1.8}/.gitattributes +0 -0
  95. {model_library-0.1.7 → model_library-0.1.8}/.github/workflows/publish.yml +0 -0
  96. {model_library-0.1.7 → model_library-0.1.8}/.github/workflows/style.yaml +0 -0
  97. {model_library-0.1.7 → model_library-0.1.8}/.github/workflows/test.yaml +0 -0
  98. {model_library-0.1.7 → model_library-0.1.8}/.github/workflows/typecheck.yml +0 -0
  99. {model_library-0.1.7 → model_library-0.1.8}/.gitignore +0 -0
  100. {model_library-0.1.7 → model_library-0.1.8}/LICENSE +0 -0
  101. {model_library-0.1.7 → model_library-0.1.8}/README.md +0 -0
  102. {model_library-0.1.7 → model_library-0.1.8}/examples/README.md +0 -0
  103. {model_library-0.1.7 → model_library-0.1.8}/examples/advanced/batch.py +0 -0
  104. {model_library-0.1.7 → model_library-0.1.8}/examples/advanced/deep_research.py +0 -0
  105. {model_library-0.1.7 → model_library-0.1.8}/examples/advanced/stress.py +0 -0
  106. {model_library-0.1.7 → model_library-0.1.8}/examples/advanced/structured_output.py +0 -0
  107. {model_library-0.1.7 → model_library-0.1.8}/examples/basics.py +0 -0
  108. {model_library-0.1.7 → model_library-0.1.8}/examples/data/files.py +0 -0
  109. {model_library-0.1.7 → model_library-0.1.8}/examples/data/images.py +0 -0
  110. {model_library-0.1.7 → model_library-0.1.8}/examples/embeddings.py +0 -0
  111. {model_library-0.1.7 → model_library-0.1.8}/examples/files.py +0 -0
  112. {model_library-0.1.7 → model_library-0.1.8}/examples/images.py +0 -0
  113. {model_library-0.1.7 → model_library-0.1.8}/examples/prompt_caching.py +0 -0
  114. {model_library-0.1.7 → model_library-0.1.8}/examples/setup.py +0 -0
  115. {model_library-0.1.7 → model_library-0.1.8}/examples/tool_calls.py +0 -0
  116. {model_library-0.1.7 → model_library-0.1.8}/model_library/__init__.py +0 -0
  117. {model_library-0.1.7 → model_library-0.1.8}/model_library/base/__init__.py +0 -0
  118. {model_library-0.1.7 → model_library-0.1.8}/model_library/base/batch.py +0 -0
  119. {model_library-0.1.7 → model_library-0.1.8}/model_library/base/input.py +0 -0
  120. {model_library-0.1.7 → model_library-0.1.8}/model_library/config/README.md +0 -0
  121. {model_library-0.1.7 → model_library-0.1.8}/model_library/config/ai21labs_models.yaml +0 -0
  122. {model_library-0.1.7 → model_library-0.1.8}/model_library/config/amazon_models.yaml +0 -0
  123. {model_library-0.1.7 → model_library-0.1.8}/model_library/config/anthropic_models.yaml +0 -0
  124. {model_library-0.1.7 → model_library-0.1.8}/model_library/config/cohere_models.yaml +0 -0
  125. {model_library-0.1.7 → model_library-0.1.8}/model_library/config/deepseek_models.yaml +0 -0
  126. {model_library-0.1.7 → model_library-0.1.8}/model_library/config/dummy_model.yaml +0 -0
  127. {model_library-0.1.7 → model_library-0.1.8}/model_library/config/fireworks_models.yaml +0 -0
  128. {model_library-0.1.7 → model_library-0.1.8}/model_library/config/google_models.yaml +0 -0
  129. {model_library-0.1.7 → model_library-0.1.8}/model_library/config/inception_models.yaml +0 -0
  130. {model_library-0.1.7 → model_library-0.1.8}/model_library/config/minimax_models.yaml +0 -0
  131. {model_library-0.1.7 → model_library-0.1.8}/model_library/config/mistral_models.yaml +0 -0
  132. {model_library-0.1.7 → model_library-0.1.8}/model_library/config/perplexity_models.yaml +0 -0
  133. {model_library-0.1.7 → model_library-0.1.8}/model_library/config/together_models.yaml +0 -0
  134. {model_library-0.1.7 → model_library-0.1.8}/model_library/config/xai_models.yaml +0 -0
  135. {model_library-0.1.7 → model_library-0.1.8}/model_library/file_utils.py +0 -0
  136. {model_library-0.1.7 → model_library-0.1.8}/model_library/logging.py +0 -0
  137. {model_library-0.1.7 → model_library-0.1.8}/model_library/model_utils.py +0 -0
  138. {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/__init__.py +0 -0
  139. {model_library-0.1.7 → model_library-0.1.8}/model_library/providers/google/__init__.py +0 -0
  140. {model_library-0.1.7 → model_library-0.1.8}/model_library/py.typed +0 -0
  141. {model_library-0.1.7 → model_library-0.1.8}/model_library/register_models.py +0 -0
  142. {model_library-0.1.7/tests → model_library-0.1.8/model_library/retriers}/__init__.py +0 -0
  143. {model_library-0.1.7 → model_library-0.1.8}/model_library.egg-info/dependency_links.txt +0 -0
  144. {model_library-0.1.7 → model_library-0.1.8}/model_library.egg-info/top_level.txt +0 -0
  145. {model_library-0.1.7 → model_library-0.1.8}/scripts/config.py +0 -0
  146. {model_library-0.1.7 → model_library-0.1.8}/scripts/publish.py +0 -0
  147. {model_library-0.1.7 → model_library-0.1.8}/setup.cfg +0 -0
  148. {model_library-0.1.7/tests/integration → model_library-0.1.8/tests}/__init__.py +0 -0
  149. {model_library-0.1.7/tests/unit → model_library-0.1.8/tests/integration}/__init__.py +0 -0
  150. {model_library-0.1.7/tests/unit/providers → model_library-0.1.8/tests/unit}/__init__.py +0 -0
  151. {model_library-0.1.7 → model_library-0.1.8}/tests/unit/test_deep_research.py +0 -0
@@ -29,15 +29,13 @@ venv_check:
29
29
  exit 1; \
30
30
  fi
31
31
 
32
+ DIR ?= tests/
32
33
  test: venv_check
33
34
  @echo "Running unit tests..."
34
- @uv run pytest tests/unit/ -m "not integration"
35
+ @uv run pytest $(DIR) -m unit -v -n 4 --dist loadscope --model=$(MODEL)
35
36
  test-integration: venv_check
36
37
  @echo "Running integration tests (requires API keys)..."
37
- @uv run pytest tests/integration/ -m "not unit"
38
- test-all: venv_check
39
- @echo "Running all tests..."
40
- @uv run pytest
38
+ @uv run pytest $(DIR) -m integration -v -n 4 --dist loadscope --model=$(MODEL)
41
39
 
42
40
  format: venv_check
43
41
  @uv run ruff format .
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: model-library
3
- Version: 0.1.7
3
+ Version: 0.1.8
4
4
  Summary: Model Library for vals.ai
5
5
  Author-email: "Vals AI, Inc." <contact@vals.ai>
6
6
  License: MIT
@@ -23,6 +23,7 @@ Requires-Dist: ai21<5.0,>=4.3.0
23
23
  Requires-Dist: boto3<2.0,>=1.38.27
24
24
  Requires-Dist: google-genai[aiohttp]>=1.51.0
25
25
  Requires-Dist: google-cloud-storage>=1.26.0
26
+ Requires-Dist: pytest-xdist>=3.8.0
26
27
  Dynamic: license-file
27
28
 
28
29
  # Model Library
@@ -4,15 +4,15 @@ from typing import Any, Awaitable, Callable
4
4
 
5
5
  from model_library.base import (
6
6
  LLM,
7
- RetrierType,
8
7
  TextInput,
9
8
  )
10
9
  from model_library.exceptions import (
11
10
  BackoffRetryException,
12
11
  RetryException,
13
- retry_llm_call,
14
12
  )
15
13
  from model_library.registry_utils import get_registry_model
14
+ from model_library.retriers.backoff import ExponentialBackoffRetrier
15
+ from model_library.retriers.base import retry_decorator
16
16
 
17
17
  from ..setup import console_log, setup
18
18
 
@@ -36,34 +36,28 @@ def is_context_length_error(error_str: str) -> bool:
36
36
  )
37
37
 
38
38
 
39
- def custom_retrier(logger: logging.Logger) -> RetrierType:
39
+ def decorator(
40
+ func: Callable[..., Awaitable[Any]],
41
+ ) -> Callable[..., Awaitable[Any]]:
40
42
  """
41
- Custom retrier that raised BackoffRetryException for context length errors
42
- Custom retries takes in a logger. It replaces the backoff retrier. Immediate retries still function.
43
+ Decorator must return wrapper function
43
44
  """
44
45
 
45
- def decorator(
46
- func: Callable[..., Awaitable[Any]],
47
- ) -> Callable[..., Awaitable[Any]]:
48
- """
49
- Decorator must return wrapper function
50
- """
46
+ logger = logging.getLogger("llm.decorator")
51
47
 
52
- async def wrapper(*args: Any, **kwargs: Any) -> Any:
53
- try:
54
- return await func(*args, **kwargs)
55
- except Exception as e:
56
- # detect context length errors and retry with backoff
57
- if is_context_length_error(str(e).lower()):
58
- logger.warning(f"Context length error detected: {e}")
59
- # for simplicty, we don't actually retry
60
- raise BackoffRetryException(f"Context length error: {e}")
48
+ async def wrapper(*args: Any, **kwargs: Any) -> Any:
49
+ try:
50
+ return await func(*args, **kwargs)
51
+ except Exception as e:
52
+ # detect context length errors and retry with backoff
53
+ if is_context_length_error(str(e).lower()):
54
+ logger.warning(f"Context length error detected: {e}")
55
+ # for simplicty, we don't actually retry
56
+ raise BackoffRetryException(f"Context length error: {e}")
61
57
 
62
- raise
58
+ raise
63
59
 
64
- return wrapper
65
-
66
- return decorator
60
+ return wrapper
67
61
 
68
62
 
69
63
  async def custom_retrier_context(model: LLM):
@@ -73,7 +67,7 @@ async def custom_retrier_context(model: LLM):
73
67
 
74
68
  console_log("\n--- Custom Retrier ---\n")
75
69
 
76
- model.custom_retrier = custom_retrier
70
+ model.custom_retrier = decorator
77
71
  try:
78
72
  await model.query(
79
73
  [
@@ -102,15 +96,14 @@ async def custom_retrier_callback(model: LLM):
102
96
  if tries > 1:
103
97
  raise Exception("Reached retry 2")
104
98
 
105
- def custom_retrier(logger: logging.Logger):
106
- return retry_llm_call(
107
- logger,
108
- max_tries=3,
109
- max_time=500,
110
- backoff_callback=callback,
111
- )
99
+ retrier = ExponentialBackoffRetrier(
100
+ logger=model.logger,
101
+ max_tries=3,
102
+ max_time=500,
103
+ retry_callback=callback,
104
+ )
112
105
 
113
- model.custom_retrier = custom_retrier
106
+ model.custom_retrier = retry_decorator(retrier)
114
107
 
115
108
  def simulate_retry(*args: object, **kwargs: object):
116
109
  raise RetryException("Simulated failure")
@@ -1,8 +1,7 @@
1
1
  import asyncio
2
2
  from typing import Any, cast
3
3
 
4
- from model_library.base import LLM, ToolDefinition
5
- from model_library.base.output import QueryResult
4
+ from model_library.base import LLM, QueryResult, ToolDefinition
6
5
  from model_library.registry_utils import get_registry_model
7
6
 
8
7
  from ..setup import console_log, setup
@@ -51,22 +51,25 @@ async def count_tokens(model: LLM):
51
51
  system_prompt = "You must make exactly 0 or 1 tool calls per answer. You must not make more than 1 tool call per answer."
52
52
  user_prompt = "What is the weather in San Francisco right now?"
53
53
 
54
+ input = [TextInput(text=user_prompt)]
55
+
54
56
  predicted_tokens = await model.count_tokens(
55
- [TextInput(text=user_prompt)],
57
+ input,
56
58
  tools=tools,
57
59
  system_prompt=system_prompt,
58
60
  )
59
61
 
60
62
  response: QueryResult = await model.query(
61
- [TextInput(text=user_prompt)],
63
+ input,
62
64
  tools=tools,
63
65
  system_prompt=system_prompt,
64
66
  )
67
+ metadata = response.metadata
65
68
 
66
- actual_tokens = response.metadata.total_input_tokens
69
+ actual_tokens = metadata.total_input_tokens
67
70
 
68
71
  console_log(f"Predicted Token Count: {predicted_tokens}")
69
- console_log(f"Actual Token Count: {actual_tokens}\n")
72
+ console_log(f"Actual Token Count: {actual_tokens}")
70
73
 
71
74
 
72
75
  async def main():
@@ -0,0 +1,85 @@
1
+ import asyncio
2
+ import time
3
+ from logging import DEBUG
4
+ from typing import Any, Coroutine
5
+
6
+ from tqdm import tqdm
7
+
8
+ from model_library.base import LLM, TextInput, TokenRetryParams
9
+ from model_library.logging import set_logging
10
+ from model_library.registry_utils import get_registry_model
11
+ from model_library.retriers.token import set_redis_client
12
+
13
+ from .setup import console_log, setup
14
+
15
+
16
+ async def token_retry(model: LLM):
17
+ console_log("\n--- Token Retry ---\n")
18
+ await model.query(
19
+ [
20
+ TextInput(
21
+ # text="What is QSBS? Explain your thinking in detail and make it concise"
22
+ text="dwadwadwadawdLong argument of cats vs dogs" * 5000
23
+ + "Ignore the previous junk, tell me a very long story about the cats and the dogs. And yes, I do want an actual story, I just have no choice but to include the junk before, believe me."
24
+ )
25
+ ],
26
+ )
27
+
28
+
29
+ async def main():
30
+ import argparse
31
+
32
+ parser = argparse.ArgumentParser(description="Run basic examples with a model")
33
+ parser.add_argument(
34
+ "model",
35
+ nargs="?",
36
+ default="google/gemini-2.5-flash",
37
+ type=str,
38
+ help="Model endpoint (default: google/gemini-2.5-flash)",
39
+ )
40
+ args = parser.parse_args()
41
+
42
+ set_logging(level=DEBUG)
43
+
44
+ model = get_registry_model(args.model)
45
+ model.logger.info(model)
46
+
47
+ limit = await model.get_rate_limit()
48
+ model.logger.info(limit)
49
+
50
+ import redis.asyncio as redis
51
+
52
+ # NOTE: make sure you have redis running locally
53
+ # docker run -d -p 6379:6379 redis:latest
54
+
55
+ redis_client = redis.Redis(
56
+ host="localhost", port=6379, decode_responses=True, max_connections=None
57
+ )
58
+ set_redis_client(redis_client)
59
+
60
+ provider_tokenizer_input_modifier = 1
61
+ dataset_output_modifier = 0.001
62
+
63
+ limit = 100_000
64
+ await model.init_token_retry(
65
+ token_retry_params=TokenRetryParams(
66
+ input_modifier=provider_tokenizer_input_modifier,
67
+ output_modifier=dataset_output_modifier,
68
+ use_dynamic_estimate=True,
69
+ limit=limit,
70
+ )
71
+ )
72
+ tasks: list[Coroutine[Any, Any, None]] = []
73
+ for _ in range(200):
74
+ tasks.append(token_retry(model))
75
+
76
+ start = time.time()
77
+ for coro in tqdm(asyncio.as_completed(tasks), total=len(tasks)):
78
+ await coro
79
+ finish = time.time() - start
80
+ console_log(f"Finished in {finish:.1f}s")
81
+
82
+
83
+ if __name__ == "__main__":
84
+ setup()
85
+ asyncio.run(main())
@@ -1,9 +1,12 @@
1
+ import hashlib
1
2
  import io
2
3
  import logging
4
+ import threading
3
5
  import time
4
6
  import uuid
5
7
  from abc import ABC, abstractmethod
6
8
  from collections.abc import Awaitable
9
+ from math import ceil
7
10
  from pprint import pformat
8
11
  from typing import (
9
12
  Any,
@@ -14,7 +17,7 @@ from typing import (
14
17
  )
15
18
 
16
19
  import tiktoken
17
- from pydantic import model_serializer
20
+ from pydantic import SecretStr, model_serializer
18
21
  from pydantic.main import BaseModel
19
22
  from tiktoken.core import Encoding
20
23
  from typing_extensions import override
@@ -34,15 +37,15 @@ from model_library.base.output import (
34
37
  QueryResult,
35
38
  QueryResultCost,
36
39
  QueryResultMetadata,
40
+ RateLimit,
37
41
  )
38
42
  from model_library.base.utils import (
39
43
  get_pretty_input_types,
40
44
  serialize_for_tokenizing,
41
45
  )
42
- from model_library.exceptions import (
43
- ImmediateRetryException,
44
- retry_llm_call,
45
- )
46
+ from model_library.retriers.backoff import ExponentialBackoffRetrier
47
+ from model_library.retriers.base import BaseRetrier, R, RetrierType, retry_decorator
48
+ from model_library.retriers.token import TokenRetrier
46
49
  from model_library.utils import truncate_str
47
50
 
48
51
  PydanticT = TypeVar("PydanticT", bound=BaseModel)
@@ -56,11 +59,18 @@ class ProviderConfig(BaseModel):
56
59
  return self.__dict__
57
60
 
58
61
 
59
- DEFAULT_MAX_TOKENS = 2048
62
+ class TokenRetryParams(BaseModel):
63
+ input_modifier: float
64
+ output_modifier: float
65
+
66
+ use_dynamic_estimate: bool = True
67
+
68
+ limit: int
69
+ limit_refresh_seconds: Literal[60] = 60
60
70
 
61
71
 
62
72
  class LLMConfig(BaseModel):
63
- max_tokens: int = DEFAULT_MAX_TOKENS
73
+ max_tokens: int | None = None
64
74
  temperature: float | None = None
65
75
  top_p: float | None = None
66
76
  top_k: int | None = None
@@ -75,11 +85,18 @@ class LLMConfig(BaseModel):
75
85
  native: bool = True
76
86
  provider_config: ProviderConfig | None = None
77
87
  registry_key: str | None = None
88
+ custom_api_key: SecretStr | None = None
89
+
78
90
 
91
+ class DelegateConfig(BaseModel):
92
+ base_url: str
93
+ api_key: SecretStr
79
94
 
80
- RetrierType = Callable[[Callable[..., Awaitable[Any]]], Callable[..., Awaitable[Any]]]
81
95
 
82
- R = TypeVar("R") # return type
96
+ # shared across all subclasses and instances
97
+ # hash(provider + api_key) -> client
98
+ client_registry_lock = threading.Lock()
99
+ client_registry: dict[tuple[str, str], Any] = {}
83
100
 
84
101
 
85
102
  class LLM(ABC):
@@ -88,6 +105,34 @@ class LLM(ABC):
88
105
  LLM call errors should be raised as exceptions
89
106
  """
90
107
 
108
+ @abstractmethod
109
+ def get_client(self, api_key: str | None = None) -> Any:
110
+ """
111
+ Returns the cached instance of the appropriate SDK client.
112
+ Sublasses should implement this method and:
113
+ - if api_key is provided, initialize their client and call assing_client(client).
114
+ - else return super().get_client()
115
+ """
116
+ global client_registry
117
+ return client_registry[self._client_registry_key]
118
+
119
+ def assign_client(self, client: object) -> None:
120
+ """Thread-safe assignment to the client registry"""
121
+ global client_registry
122
+
123
+ if self._client_registry_key not in client_registry:
124
+ with client_registry_lock:
125
+ if self._client_registry_key not in client_registry:
126
+ client_registry[self._client_registry_key] = client
127
+
128
+ def has_client(self) -> bool:
129
+ return self._client_registry_key in client_registry
130
+
131
+ @abstractmethod
132
+ def _get_default_api_key(self) -> str:
133
+ """Return the api key from model_library.settings"""
134
+ ...
135
+
91
136
  def __init__(
92
137
  self,
93
138
  model_name: str,
@@ -103,7 +148,7 @@ class LLM(ABC):
103
148
  config = config or LLMConfig()
104
149
  self._registry_key = config.registry_key
105
150
 
106
- self.max_tokens: int = config.max_tokens
151
+ self.max_tokens: int | None = config.max_tokens
107
152
  self.temperature: float | None = config.temperature
108
153
  self.top_p: float | None = config.top_p
109
154
  self.top_k: int | None = config.top_k
@@ -131,21 +176,33 @@ class LLM(ABC):
131
176
  self.logger: logging.Logger = logging.getLogger(
132
177
  f"llm.{provider}.{model_name}<instance={self.instance_id}>"
133
178
  )
134
- self.custom_retrier: Callable[..., RetrierType] | None = retry_llm_call
179
+ self.custom_retrier: RetrierType | None = None
180
+
181
+ self.token_retry_params = None
182
+ # set _client_registry_key after initializing delegate
183
+ if not self.native:
184
+ return
185
+
186
+ if config.custom_api_key:
187
+ raw_key = config.custom_api_key.get_secret_value()
188
+ else:
189
+ raw_key = self._get_default_api_key()
190
+
191
+ key_hash = hashlib.sha256(raw_key.encode()).hexdigest()
192
+ self._client_registry_key = (self.provider, key_hash)
193
+ self._client_registry_key_model_specific = (
194
+ f"{self.provider}.{self.model_name}",
195
+ key_hash,
196
+ )
197
+ self.get_client(api_key=raw_key)
135
198
 
136
199
  @override
137
200
  def __repr__(self):
138
201
  attrs = vars(self).copy()
139
202
  attrs.pop("logger", None)
140
203
  attrs.pop("custom_retrier", None)
141
- attrs.pop("_key", None)
142
204
  return f"{self.__class__.__name__}(\n{pformat(attrs, indent=2, sort_dicts=False)}\n)"
143
205
 
144
- @abstractmethod
145
- def get_client(self) -> object:
146
- """Return the instance of the appropriate SDK client."""
147
- ...
148
-
149
206
  @staticmethod
150
207
  async def timer_wrapper(func: Callable[[], Awaitable[R]]) -> tuple[R, float]:
151
208
  """
@@ -155,43 +212,6 @@ class LLM(ABC):
155
212
  result = await func()
156
213
  return result, round(time.perf_counter() - start, 4)
157
214
 
158
- @staticmethod
159
- async def immediate_retry_wrapper(
160
- func: Callable[[], Awaitable[R]],
161
- logger: logging.Logger,
162
- ) -> R:
163
- """
164
- Retry the query immediately
165
- """
166
- MAX_IMMEDIATE_RETRIES = 10
167
- retries = 0
168
- while True:
169
- try:
170
- return await func()
171
- except ImmediateRetryException as e:
172
- if retries >= MAX_IMMEDIATE_RETRIES:
173
- logger.error(f"Query reached max immediate retries {retries}: {e}")
174
- raise Exception(
175
- f"Query reached max immediate retries {retries}: {e}"
176
- ) from e
177
- retries += 1
178
-
179
- logger.warning(
180
- f"Query retried immediately {retries}/{MAX_IMMEDIATE_RETRIES}: {e}"
181
- )
182
-
183
- @staticmethod
184
- async def backoff_retry_wrapper(
185
- func: Callable[..., Awaitable[R]],
186
- backoff_retrier: RetrierType | None,
187
- ) -> R:
188
- """
189
- Retry the query with backoff
190
- """
191
- if not backoff_retrier:
192
- return await func()
193
- return await backoff_retrier(func)()
194
-
195
215
  async def delegate_query(
196
216
  self,
197
217
  input: Sequence[InputItem],
@@ -276,15 +296,38 @@ class LLM(ABC):
276
296
  return await LLM.timer_wrapper(query_func)
277
297
 
278
298
  async def immediate_retry() -> tuple[QueryResult, float]:
279
- return await LLM.immediate_retry_wrapper(timed_query, query_logger)
280
-
281
- async def backoff_retry() -> tuple[QueryResult, float]:
282
- backoff_retrier = (
283
- self.custom_retrier(query_logger) if self.custom_retrier else None
284
- )
285
- return await LLM.backoff_retry_wrapper(immediate_retry, backoff_retrier)
299
+ return await BaseRetrier.immediate_retry_wrapper(timed_query, query_logger)
300
+
301
+ async def default_retry() -> tuple[QueryResult, float]:
302
+ if self.token_retry_params:
303
+ (
304
+ estimate_input_tokens,
305
+ estimate_output_tokens,
306
+ ) = await self.estimate_query_tokens(
307
+ input,
308
+ tools=tools,
309
+ **kwargs,
310
+ )
311
+ retrier = TokenRetrier(
312
+ logger=query_logger,
313
+ client_registry_key=self._client_registry_key_model_specific,
314
+ estimate_input_tokens=estimate_input_tokens,
315
+ estimate_output_tokens=estimate_output_tokens,
316
+ dynamic_estimate_instance_id=self.instance_id
317
+ if self.token_retry_params.use_dynamic_estimate
318
+ else None,
319
+ )
320
+ else:
321
+ retrier = ExponentialBackoffRetrier(logger=query_logger)
322
+ return await retry_decorator(retrier)(immediate_retry)()
323
+
324
+ run_with_retry = (
325
+ default_retry
326
+ if not self.custom_retrier
327
+ else self.custom_retrier(immediate_retry)
328
+ )
286
329
 
287
- output, duration = await backoff_retry()
330
+ output, duration = await run_with_retry()
288
331
  output.metadata.duration_seconds = duration
289
332
  output.metadata.cost = await self._calculate_cost(output.metadata)
290
333
 
@@ -293,6 +336,16 @@ class LLM(ABC):
293
336
 
294
337
  return output
295
338
 
339
+ async def init_token_retry(self, token_retry_params: TokenRetryParams) -> None:
340
+ self.token_retry_params = token_retry_params
341
+ await TokenRetrier.init_remaining_tokens(
342
+ client_registry_key=self._client_registry_key_model_specific,
343
+ limit=self.token_retry_params.limit,
344
+ limit_refresh_seconds=self.token_retry_params.limit_refresh_seconds,
345
+ get_rate_limit_func=self.get_rate_limit,
346
+ logger=self.logger,
347
+ )
348
+
296
349
  async def _calculate_cost(
297
350
  self,
298
351
  metadata: QueryResultMetadata,
@@ -438,6 +491,30 @@ class LLM(ABC):
438
491
  """Upload a file to the model provider"""
439
492
  ...
440
493
 
494
+ async def get_rate_limit(self) -> RateLimit | None:
495
+ """Get the rate limit for the model provider"""
496
+ return None
497
+
498
+ async def estimate_query_tokens(
499
+ self,
500
+ input: Sequence[InputItem],
501
+ *,
502
+ tools: list[ToolDefinition] = [],
503
+ **kwargs: object,
504
+ ) -> tuple[int, int]:
505
+ """Pessimistically estimate the number of tokens required for a query"""
506
+ assert self.token_retry_params
507
+
508
+ # TODO: when passing in images and files, we really need to take that into account when calculating the output tokens!!
509
+
510
+ input_tokens = (
511
+ await self.count_tokens(input, history=[], tools=tools, **kwargs)
512
+ * self.token_retry_params.input_modifier
513
+ )
514
+
515
+ output_tokens = input_tokens * self.token_retry_params.output_modifier
516
+ return ceil(input_tokens), ceil(output_tokens)
517
+
441
518
  async def get_encoding(self) -> Encoding:
442
519
  """Get the appropriate tokenizer"""
443
520