model-library 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,7 +6,6 @@ from abc import ABC, abstractmethod
 from collections.abc import Awaitable
 from pprint import pformat
 from typing import (
-    TYPE_CHECKING,
     Any,
     Callable,
     Literal,
@@ -14,8 +13,10 @@ from typing import (
     TypeVar,
 )
 
+import tiktoken
 from pydantic import model_serializer
 from pydantic.main import BaseModel
+from tiktoken.core import Encoding
 from typing_extensions import override
 
 from model_library.base.batch import (
@@ -36,6 +37,7 @@ from model_library.base.output import (
 )
 from model_library.base.utils import (
     get_pretty_input_types,
+    serialize_for_tokenizing,
 )
 from model_library.exceptions import (
     ImmediateRetryException,
@@ -43,9 +45,6 @@ from model_library.exceptions import (
 )
 from model_library.utils import truncate_str
 
-if TYPE_CHECKING:
-    from model_library.providers.openai import OpenAIModel
-
 PydanticT = TypeVar("PydanticT", bound=BaseModel)
 
 
@@ -66,7 +65,7 @@ class LLMConfig(BaseModel):
     top_p: float | None = None
     top_k: int | None = None
     reasoning: bool = False
-    reasoning_effort: str | None = None
+    reasoning_effort: str | bool | None = None
     supports_images: bool = False
     supports_files: bool = False
     supports_videos: bool = False
@@ -110,7 +109,7 @@ class LLM(ABC):
         self.top_k: int | None = config.top_k
 
         self.reasoning: bool = config.reasoning
-        self.reasoning_effort: str | None = config.reasoning_effort
+        self.reasoning_effort: str | bool | None = config.reasoning_effort
 
         self.supports_files: bool = config.supports_files
         self.supports_videos: bool = config.supports_videos
@@ -120,7 +119,7 @@ class LLM(ABC):
         self.supports_tools: bool = config.supports_tools
 
         self.native: bool = config.native
-        self.delegate: "OpenAIModel | None" = None
+        self.delegate: "LLM | None" = None
        self.batch: LLMBatchMixin | None = None
 
         if config.provider_config:
@@ -198,11 +197,14 @@ class LLM(ABC):
         input: Sequence[InputItem],
         *,
         tools: list[ToolDefinition] = [],
+        query_logger: logging.Logger,
         **kwargs: object,
     ) -> QueryResult:
         if not self.delegate:
             raise Exception("Delegate not set")
-        return await self.delegate._query_impl(input, tools=tools, **kwargs)  # pyright: ignore[reportPrivateUsage]
+        return await self.delegate._query_impl(  # pyright: ignore[reportPrivateUsage]
+            input, tools=tools, query_logger=query_logger, **kwargs
+        )
 
     async def query(
         self,
@@ -213,6 +215,7 @@ class LLM(ABC):
         # for backwards compatibility
         files: list[FileInput] = [],
         images: list[FileInput] = [],
+        query_logger: logging.Logger | None = None,
         **kwargs: object,
     ) -> QueryResult:
         """
@@ -256,15 +259,18 @@ class LLM(ABC):
         input = [*history, *input]
 
         # unique logger for the query
-        query_id = uuid.uuid4().hex[:14]
-        query_logger = self.logger.getChild(f"query={query_id}")
+        if not query_logger:
+            query_id = uuid.uuid4().hex[:14]
+            query_logger = self.logger.getChild(f"query={query_id}")
 
         query_logger.info(
             "Query started:\n" + item_info + tool_info + f"--- kwargs: {short_kwargs}\n"
         )
 
         async def query_func() -> QueryResult:
-            return await self._query_impl(input, tools=tools, **kwargs)
+            return await self._query_impl(
+                input, tools=tools, query_logger=query_logger, **kwargs
+            )
 
         async def timed_query() -> tuple[QueryResult, float]:
             return await LLM.timer_wrapper(query_func)
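The hunks above make the per-query logger injectable: `query()` now accepts an optional `query_logger` and only mints the uuid-suffixed child logger when none is supplied, and `delegate_query`/`_query_impl` receive it explicitly. A rough caller-side sketch (the `LLM` import path and the concrete model instance are assumptions, not taken from this diff):

```python
import logging

from model_library.base import LLM  # assumed export path for the LLM ABC
from model_library.base.input import TextInput


async def ask(model: LLM, prompt: str) -> str | None:
    # Reuse one logger across related queries instead of letting query()
    # create a fresh uuid-suffixed child logger each time.
    query_logger = logging.getLogger("llm.my_session")
    result = await model.query([TextInput(text=prompt)], query_logger=query_logger)
    return result.output_text
```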
@@ -361,7 +367,8 @@ class LLM(ABC):
         input: Sequence[InputItem],
         *,
         tools: list[ToolDefinition],
-        **kwargs: object,  # TODO: pass in query logger
+        query_logger: logging.Logger,
+        **kwargs: object,
     ) -> QueryResult:
         """
         Query the model with input
@@ -375,6 +382,20 @@ class LLM(ABC):
         """
         ...
 
+    @abstractmethod
+    async def build_body(
+        self,
+        input: Sequence[InputItem],
+        *,
+        tools: list[ToolDefinition],
+        **kwargs: Any,
+    ) -> dict[str, Any]:
+        """
+        Builds the body of the request to the model provider
+        Calls parse_input
+        """
+        ...
+
     @abstractmethod
     async def parse_input(
         self,
@@ -417,6 +438,87 @@ class LLM(ABC):
         """Upload a file to the model provider"""
         ...
 
+    async def get_encoding(self) -> Encoding:
+        """Get the appropriate tokenizer"""
+
+        model = self.model_name.lower()
+
+        if any(x in model for x in ["gpt-4o", "o1", "o3", "gpt-4.1", "gpt-5"]):
+            return tiktoken.get_encoding("o200k_base")
+        elif "gpt-4" in model or "gpt-3.5" in model:
+            try:
+                return tiktoken.encoding_for_model(self.model_name)
+            except KeyError:
+                return tiktoken.get_encoding("cl100k_base")
+        elif "claude" in model:
+            return tiktoken.get_encoding("cl100k_base")
+        elif "gemini" in model:
+            return tiktoken.get_encoding("o200k_base")
+        elif "llama" in model or "mistral" in model:
+            return tiktoken.get_encoding("cl100k_base")
+        else:
+            return tiktoken.get_encoding("cl100k_base")
+
+    async def stringify_input(
+        self,
+        input: Sequence[InputItem],
+        *,
+        history: Sequence[InputItem] = [],
+        tools: list[ToolDefinition] = [],
+        **kwargs: object,
+    ) -> str:
+        input = [*history, *input]
+
+        system_prompt = kwargs.pop(
+            "system_prompt", ""
+        )  # TODO: refactor along with system prompt arg change
+
+        # special case if using a delegate
+        # don't inherit method override by default
+        if self.delegate:
+            parsed_input = await self.delegate.parse_input(input, **kwargs)
+            parsed_tools = await self.delegate.parse_tools(tools)
+        else:
+            parsed_input = await self.parse_input(input, **kwargs)
+            parsed_tools = await self.parse_tools(tools)
+
+        serialized_input = serialize_for_tokenizing(parsed_input)
+        serialized_tools = serialize_for_tokenizing(parsed_tools)
+
+        combined = f"{system_prompt}\n{serialized_input}\n{serialized_tools}"
+
+        return combined
+
+    async def count_tokens(
+        self,
+        input: Sequence[InputItem],
+        *,
+        history: Sequence[InputItem] = [],
+        tools: list[ToolDefinition] = [],
+        **kwargs: object,
+    ) -> int:
+        """
+        Count the number of tokens for a query.
+        Combines parsed input and tools, then tokenizes the result.
+        """
+
+        if not input and not history:
+            return 0
+
+        if self.delegate:
+            encoding = await self.delegate.get_encoding()
+        else:
+            encoding = await self.get_encoding()
+        self.logger.debug(f"Token Count Encoding: {encoding}")
+
+        string_input = await self.stringify_input(
+            input, history=history, tools=tools, **kwargs
+        )
+
+        count = len(encoding.encode(string_input, disallowed_special=()))
+        self.logger.debug(f"Combined Token Count Input: {count}")
+        return count
+
     async def query_json(
         self,
         input: Sequence[InputItem],
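The new `get_encoding`, `stringify_input`, and `count_tokens` methods give every `LLM` subclass a tiktoken-based token estimate: input, history, tools, and an optional `system_prompt` kwarg are parsed (via the delegate when one is set), flattened with `serialize_for_tokenizing`, and encoded with a heuristically chosen encoding (`o200k_base` for newer OpenAI- and Gemini-style model names, `cl100k_base` otherwise). A hedged usage sketch (model construction omitted; the `LLM` import path is assumed):

```python
from model_library.base import LLM  # assumed export path for the LLM ABC
from model_library.base.input import TextInput


async def estimate_tokens(model: LLM) -> int:
    # count_tokens serializes the would-be request and encodes it locally;
    # it is an estimate, not the provider's own token accounting.
    return await model.count_tokens(
        [TextInput(text="Summarize the attached report.")],
        system_prompt="You are a terse assistant.",
    )
```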
@@ -1,4 +1,5 @@
 import io
+import logging
 from typing import Any, Literal, Sequence
 
 from typing_extensions import override
@@ -48,11 +49,24 @@ class DelegateOnly(LLM):
         input: Sequence[InputItem],
         *,
         tools: list[ToolDefinition],
+        query_logger: logging.Logger,
         **kwargs: object,
     ) -> QueryResult:
         assert self.delegate
 
-        return await self.delegate_query(input, tools=tools, **kwargs)
+        return await self.delegate_query(
+            input, tools=tools, query_logger=query_logger, **kwargs
+        )
+
+    @override
+    async def build_body(
+        self,
+        input: Sequence[InputItem],
+        *,
+        tools: list[ToolDefinition],
+        **kwargs: object,
+    ) -> dict[str, Any]:
+        raise DelegateOnlyException()
 
     @override
     async def parse_input(
@@ -74,8 +74,6 @@ class ToolCall(BaseModel):
 --- INPUT ---
 """
 
-RawResponse = Any
-
 
 class ToolInput(BaseModel):
     tools: list[ToolDefinition] = []
@@ -90,11 +88,16 @@ class TextInput(BaseModel):
     text: str
 
 
-RawInputItem = dict[
-    str, Any
-]  # to pass in, for example, a mock convertsation with {"role": "user", "content": "Hello"}
+class RawResponse(BaseModel):
+    # used to store a received response
+    response: Any
+
+
+class RawInput(BaseModel):
+    # used to pass in anything provider specific (e.g. a mock conversation)
+    input: Any
 
 
 InputItem = (
-    TextInput | FileInput | ToolResult | RawInputItem | RawResponse
-)  # input item can either be a prompt, a file (image or file), a tool call result, raw input, or a previous response
+    TextInput | FileInput | ToolResult | RawInput | RawResponse
+)  # input item can either be a prompt, a file (image or file), a tool call result, a previous response, or raw input
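The bare `RawInputItem` dict alias and `RawResponse = Any` are replaced by pydantic wrapper models, so provider-specific payloads become tagged types rather than untyped dicts. For example, a mock conversation turn that previously went in as a plain `{"role": ..., "content": ...}` dict would now be wrapped, roughly like this:

```python
from model_library.base.input import RawInput, TextInput

# Previously passed as a bare dict; now wrapped in RawInput.
mock_turn = RawInput(input={"role": "assistant", "content": "Hi there"})

conversation = [
    TextInput(text="Hello"),
    mock_turn,
    TextInput(text="What did you just say?"),
]
```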
@@ -24,6 +24,11 @@ class Citation(BaseModel):
     index: int | None = None
     container_id: str | None = None
 
+    @override
+    def __repr__(self):
+        attrs = vars(self).copy()
+        return f"{self.__class__.__name__}(\n{pformat(attrs, indent=2)}\n)"
+
 
 class QueryResultExtras(BaseModel):
     citations: list[Citation] = Field(default_factory=list)
@@ -1,18 +1,34 @@
-from typing import Sequence, TypeVar, cast
+import json
+from typing import Any, Sequence, TypeVar
+
+from pydantic import BaseModel
 
 from model_library.base.input import (
     FileBase,
     InputItem,
-    RawInputItem,
+    RawInput,
+    RawResponse,
     TextInput,
     ToolResult,
 )
 from model_library.utils import truncate_str
-from pydantic import BaseModel
 
 T = TypeVar("T", bound=BaseModel)
 
 
+def serialize_for_tokenizing(content: Any) -> str:
+    """
+    Serialize parsed content into a string for tokenization
+    """
+    parts: list[str] = []
+    if content:
+        if isinstance(content, str):
+            parts.append(content)
+        else:
+            parts.append(json.dumps(content, default=str))
+    return "\n".join(parts)
+
+
 def add_optional(
     a: int | float | T | None, b: int | float | T | None
 ) -> int | float | T | None:
@@ -54,11 +70,9 @@ def get_pretty_input_types(input: Sequence["InputItem"], verbose: bool = False)
                 return repr(item)
             case ToolResult():
                 return repr(item)
-            case dict():
-                item = cast(RawInputItem, item)
+            case RawInput():
                 return repr(item)
-            case _:
-                # RawResponse
+            case RawResponse():
                 return repr(item)
 
    processed_items = [f" {process_item(item)}" for item in input]
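The new `serialize_for_tokenizing` helper normalizes whatever `parse_input`/`parse_tools` return into a single string before encoding: falsy content becomes an empty string, strings pass through, and anything else is JSON-dumped with `default=str` so non-JSON-native objects do not raise. Roughly:

```python
from model_library.base.utils import serialize_for_tokenizing

assert serialize_for_tokenizing("") == ""
assert serialize_for_tokenizing("plain text") == "plain text"
# Non-string content is json.dumps-ed with default=str.
assert (
    serialize_for_tokenizing([{"role": "user", "content": "Hi"}])
    == '[{"role": "user", "content": "Hi"}]'
)
```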
@@ -1,4 +1,94 @@
 {
+  "minimax/MiniMax-M2.1": {
+    "company": "MiniMax",
+    "label": "MiniMax-M2.1",
+    "description": null,
+    "release_date": "2025-12-23",
+    "open_source": true,
+    "documentation_url": "https://platform.minimax.io/docs",
+    "properties": {
+      "context_window": 204800,
+      "max_tokens": 131000,
+      "training_cutoff": null,
+      "reasoning_model": true
+    },
+    "supports": {
+      "images": false,
+      "files": false,
+      "temperature": true,
+      "tools": true
+    },
+    "metadata": {
+      "deprecated": false,
+      "available_for_everyone": true,
+      "available_as_evaluator": false,
+      "ignored_for_cost": false
+    },
+    "provider_properties": {},
+    "costs_per_million_token": {
+      "input": 0.3,
+      "output": 1.2,
+      "cache": {
+        "read": 0.03,
+        "write": 0.375,
+        "write_markup": 1.0
+      }
+    },
+    "alternative_keys": [],
+    "default_parameters": {
+      "temperature": 1.0,
+      "top_p": 0.95
+    },
+    "provider_endpoint": "MiniMax-M2.1",
+    "provider_name": "minimax",
+    "full_key": "minimax/MiniMax-M2.1",
+    "slug": "minimax_MiniMax-M2.1"
+  },
+  "zai/glm-4.7": {
+    "company": "zAI",
+    "label": "GLM 4.7",
+    "description": "Latest model from ZAI",
+    "release_date": "2025-12-22",
+    "open_source": true,
+    "documentation_url": "https://docs.z.ai/",
+    "properties": {
+      "context_window": 200000,
+      "max_tokens": 128000,
+      "training_cutoff": null,
+      "reasoning_model": true
+    },
+    "supports": {
+      "images": false,
+      "files": false,
+      "temperature": true,
+      "tools": true
+    },
+    "metadata": {
+      "deprecated": false,
+      "available_for_everyone": true,
+      "available_as_evaluator": false,
+      "ignored_for_cost": false
+    },
+    "provider_properties": {},
+    "costs_per_million_token": {
+      "input": 0.6,
+      "output": 2.2,
+      "cache": {
+        "read": 0.11,
+        "read_discount": 1.0,
+        "write_markup": 1.0
+      }
+    },
+    "alternative_keys": [],
+    "default_parameters": {
+      "temperature": 1.0,
+      "top_p": 1.0
+    },
+    "provider_endpoint": "glm-4.7",
+    "provider_name": "zai",
+    "full_key": "zai/glm-4.7",
+    "slug": "zai_glm-4.7"
+  },
   "google/gemini-3-flash-preview": {
     "company": "Google",
     "label": "Gemini 3 Flash (12/25)",
@@ -504,7 +594,8 @@
       }
     ],
     "default_parameters": {
-      "temperature": 1.0
+      "temperature": 1.0,
+      "reasoning_effort": "none"
     },
     "provider_endpoint": "deepseek-v3p2",
     "provider_name": "fireworks",
@@ -150,6 +150,8 @@ deepseek-models:
       context_window: 160_000
       max_tokens: 20_480
       reasoning_model: false
+    default_parameters:
+      reasoning_effort: "none"
     costs_per_million_token:
       input: 0.56
       output: 1.68
@@ -16,6 +16,24 @@ base-config:
 
 minimax-m2-models:
 
+  minimax/MiniMax-M2.1:
+    label: MiniMax-M2.1
+    release_date: 2025-12-23
+    properties:
+      context_window: 204_800
+      max_tokens: 131_000
+      reasoning_model: true
+      training_cutoff: null
+    default_parameters:
+      temperature: 1.0
+      top_p: 0.95
+    costs_per_million_token:
+      input: 0.30
+      output: 1.20
+      cache:
+        read: 0.03
+        write: 0.375
+
   minimax/MiniMax-M2:
     label: MiniMax-M2
     description: MiniMax-M2 is a cost-efficient open-source model optimized for agentic applications and coding in particular.
@@ -18,6 +18,20 @@ base-config:
         write_markup: 1
 
 zai-models:
+  zai/glm-4.7:
+    label: GLM 4.7
+    description: "Latest model from ZAI"
+    release_date: 2025-12-22
+    properties:
+      context_window: 200_000
+      max_tokens: 128_000
+    costs_per_million_token:
+      input: 0.6
+      output: 2.2
+      cache:
+        read: 0.11
+    default_parameters:
+      temperature: 1
   zai/glm-4.5:
     label: GLM 4.5
     description: "z.AI old model"
@@ -146,6 +146,17 @@ class BadInputError(Exception):
         super().__init__(message or BadInputError.DEFAULT_MESSAGE)
 
 
+class NoMatchingToolCallError(Exception):
+    """
+    Raised when a tool call result is provided with no matching tool call
+    """
+
+    DEFAULT_MESSAGE: str = "Tool call result provided with no matching tool call"
+
+    def __init__(self, message: str | None = None):
+        super().__init__(message or NoMatchingToolCallError.DEFAULT_MESSAGE)
+
+
 # Add more retriable exceptions as needed
 # Providers that don't have an explicit rate limit error are handled manually
 # by wrapping errored Http/gRPC requests with a BackoffRetryException
model_library/logging.py CHANGED
@@ -6,7 +6,11 @@ from rich.logging import RichHandler
 _llm_logger = logging.getLogger("llm")
 
 
-def set_logging(enable: bool = True, handler: logging.Handler | None = None):
+def set_logging(
+    enable: bool = True,
+    level: int = logging.INFO,
+    handler: logging.Handler | None = None,
+):
     """
     Sets up logging for the model library
 
@@ -15,7 +19,7 @@ def set_logging(enable: bool = True, handler: logging.Handler | None = None):
         handler (logging.Handler, optional): A custom logging handler. Defaults to RichHandler.
     """
     if enable:
-        _llm_logger.setLevel(logging.INFO)
+        _llm_logger.setLevel(level)
     else:
         _llm_logger.setLevel(logging.CRITICAL)
 
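`set_logging` gains a `level` parameter (defaulting to the previously hard-coded `logging.INFO`), so callers can, for example, surface the new token-count debug lines:

```python
import logging

from model_library.logging import set_logging  # assumed export path

# level is the new parameter; calling with defaults preserves the old behavior.
set_logging(enable=True, level=logging.DEBUG)
```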
 
@@ -1,4 +1,5 @@
 import io
+import logging
 from typing import Any, Literal, Sequence
 
 from ai21 import AsyncAI21Client
@@ -21,6 +22,7 @@ from model_library.base import (
     ToolDefinition,
     ToolResult,
 )
+from model_library.base.input import RawResponse
 from model_library.exceptions import (
     BadInputError,
     MaxOutputTokensExceededError,
@@ -64,8 +66,6 @@ class AI21LabsModel(LLM):
             match item:
                 case TextInput():
                     new_input.append(ChatMessage(role="user", content=item.text))
-                case AssistantMessage():
-                    new_input.append(item)
                 case ToolResult():
                     new_input.append(
                         ToolMessage(
@@ -73,7 +73,9 @@ class AI21LabsModel(LLM):
                             content=item.result,
                             tool_call_id=item.tool_call.id,
                         )
-                    )
+                    )  # TODO: tool calling metadata and test
+                case RawResponse():
+                    new_input.append(item.response)
                 case _:
                     raise BadInputError("Unsupported input type")
         return new_input
@@ -132,13 +134,13 @@ class AI21LabsModel(LLM):
         raise NotImplementedError()
 
     @override
-    async def _query_impl(
+    async def build_body(
         self,
         input: Sequence[InputItem],
         *,
         tools: list[ToolDefinition],
         **kwargs: object,
-    ) -> QueryResult:
+    ) -> dict[str, Any]:
         messages: list[ChatMessage] = []
         if "system_prompt" in kwargs:
             messages.append(
@@ -160,6 +162,18 @@ class AI21LabsModel(LLM):
             body["top_p"] = self.top_p
 
         body.update(kwargs)
+        return body
+
+    @override
+    async def _query_impl(
+        self,
+        input: Sequence[InputItem],
+        *,
+        tools: list[ToolDefinition],
+        query_logger: logging.Logger,
+        **kwargs: object,
+    ) -> QueryResult:
+        body = await self.build_body(input, tools=tools, **kwargs)
 
         response: ChatCompletionResponse = (
             await self.get_client().chat.completions.create(**body, stream=False)  # pyright: ignore[reportAny, reportUnknownMemberType]
@@ -184,7 +198,7 @@ class AI21LabsModel(LLM):
 
         output = QueryResult(
             output_text=choice.message.content,
-            history=[*input, choice.message],
+            history=[*input, RawResponse(response=choice.message)],
             metadata=QueryResultMetadata(
                 in_tokens=response.usage.prompt_tokens,
                 out_tokens=response.usage.completion_tokens,
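Splitting the AI21 provider's `_query_impl` into `build_body` plus a thin execution wrapper mirrors the new abstract `build_body` on `LLM`, so the request body can be constructed and inspected without sending anything. A hedged sketch (the provider module path and model construction are assumptions, not taken from this diff):

```python
from typing import Any

from model_library.base.input import TextInput
from model_library.providers.ai21 import AI21LabsModel  # assumed module path


async def preview_request(model: AI21LabsModel) -> dict[str, Any]:
    # build_body returns the kwargs that _query_impl passes to
    # chat.completions.create(), without performing the call.
    return await model.build_body(
        [TextInput(text="Hello")],
        tools=[],
        system_prompt="Be brief.",
    )
```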