model-library 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registry.
@@ -13,8 +13,10 @@ from typing import (
     TypeVar,
 )
 
+import tiktoken
 from pydantic import model_serializer
 from pydantic.main import BaseModel
+from tiktoken.core import Encoding
 from typing_extensions import override
 
 from model_library.base.batch import (
@@ -35,6 +37,7 @@ from model_library.base.output import (
 )
 from model_library.base.utils import (
     get_pretty_input_types,
+    serialize_for_tokenizing,
 )
 from model_library.exceptions import (
     ImmediateRetryException,
@@ -379,6 +382,20 @@ class LLM(ABC):
         """
         ...
 
+    @abstractmethod
+    async def build_body(
+        self,
+        input: Sequence[InputItem],
+        *,
+        tools: list[ToolDefinition],
+        **kwargs: Any,
+    ) -> dict[str, Any]:
+        """
+        Builds the body of the request to the model provider
+        Calls parse_input
+        """
+        ...
+
     @abstractmethod
     async def parse_input(
         self,
@@ -421,6 +438,87 @@ class LLM(ABC):
         """Upload a file to the model provider"""
         ...
 
+    async def get_encoding(self) -> Encoding:
+        """Get the appropriate tokenizer"""
+
+        model = self.model_name.lower()
+
+        if any(x in model for x in ["gpt-4o", "o1", "o3", "gpt-4.1", "gpt-5"]):
+            return tiktoken.get_encoding("o200k_base")
+        elif "gpt-4" in model or "gpt-3.5" in model:
+            try:
+                return tiktoken.encoding_for_model(self.model_name)
+            except KeyError:
+                return tiktoken.get_encoding("cl100k_base")
+        elif "claude" in model:
+            return tiktoken.get_encoding("cl100k_base")
+        elif "gemini" in model:
+            return tiktoken.get_encoding("o200k_base")
+        elif "llama" in model or "mistral" in model:
+            return tiktoken.get_encoding("cl100k_base")
+        else:
+            return tiktoken.get_encoding("cl100k_base")
+
+    async def stringify_input(
+        self,
+        input: Sequence[InputItem],
+        *,
+        history: Sequence[InputItem] = [],
+        tools: list[ToolDefinition] = [],
+        **kwargs: object,
+    ) -> str:
+        input = [*history, *input]
+
+        system_prompt = kwargs.pop(
+            "system_prompt", ""
+        )  # TODO: refactor along with system prompt arg change
+
+        # special case if using a delegate
+        # don't inherit method override by default
+        if self.delegate:
+            parsed_input = await self.delegate.parse_input(input, **kwargs)
+            parsed_tools = await self.delegate.parse_tools(tools)
+        else:
+            parsed_input = await self.parse_input(input, **kwargs)
+            parsed_tools = await self.parse_tools(tools)
+
+        serialized_input = serialize_for_tokenizing(parsed_input)
+        serialized_tools = serialize_for_tokenizing(parsed_tools)
+
+        combined = f"{system_prompt}\n{serialized_input}\n{serialized_tools}"
+
+        return combined
+
+    async def count_tokens(
+        self,
+        input: Sequence[InputItem],
+        *,
+        history: Sequence[InputItem] = [],
+        tools: list[ToolDefinition] = [],
+        **kwargs: object,
+    ) -> int:
+        """
+        Count the number of tokens for a query.
+        Combines parsed input and tools, then tokenizes the result.
+        """
+
+        if not input and not history:
+            return 0
+
+        if self.delegate:
+            encoding = await self.delegate.get_encoding()
+        else:
+            encoding = await self.get_encoding()
+        self.logger.debug(f"Token Count Encoding: {encoding}")
+
+        string_input = await self.stringify_input(
+            input, history=history, tools=tools, **kwargs
+        )
+
+        count = len(encoding.encode(string_input, disallowed_special=()))
+        self.logger.debug(f"Combined Token Count Input: {count}")
+        return count
+
     async def query_json(
         self,
         input: Sequence[InputItem],
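Worth noting: tiktoken ships only OpenAI vocabularies, so the encodings chosen above for Claude, Gemini, Llama, and Mistral are approximations rather than those providers' real tokenizers. A standalone sketch of the same heuristic (collapsing the per-model `encoding_for_model` lookup for brevity; the model id below is illustrative):

```python
import tiktoken

def pick_encoding(model_name: str) -> tiktoken.Encoding:
    """Mirror of get_encoding above, minus the encoding_for_model lookup."""
    model = model_name.lower()
    if any(x in model for x in ["gpt-4o", "o1", "o3", "gpt-4.1", "gpt-5", "gemini"]):
        return tiktoken.get_encoding("o200k_base")
    # gpt-4/gpt-3.5 fallback, claude, llama, mistral, and the default case
    return tiktoken.get_encoding("cl100k_base")

enc = pick_encoding("claude-sonnet-4")  # approximation for non-OpenAI models
print(len(enc.encode("Hello, world!", disallowed_special=())))
```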
@@ -58,6 +58,16 @@ class DelegateOnly(LLM):
             input, tools=tools, query_logger=query_logger, **kwargs
         )
 
+    @override
+    async def build_body(
+        self,
+        input: Sequence[InputItem],
+        *,
+        tools: list[ToolDefinition],
+        **kwargs: object,
+    ) -> dict[str, Any]:
+        raise DelegateOnlyException()
+
     @override
     async def parse_input(
         self,
@@ -74,8 +74,6 @@ class ToolCall(BaseModel):
     --- INPUT ---
     """
 
-RawResponse = Any
-
 
 class ToolInput(BaseModel):
     tools: list[ToolDefinition] = []
@@ -90,11 +88,16 @@ class TextInput(BaseModel):
     text: str
 
 
-RawInputItem = dict[
-    str, Any
-]  # to pass in, for example, a mock convertsation with {"role": "user", "content": "Hello"}
+class RawResponse(BaseModel):
+    # used to store a received response
+    response: Any
+
+
+class RawInput(BaseModel):
+    # used to pass in anything provider specific (e.g. a mock conversation)
+    input: Any
 
 
 InputItem = (
-    TextInput | FileInput | ToolResult | RawInputItem | RawResponse
-)  # input item can either be a prompt, a file (image or file), a tool call result, raw input, or a previous response
+    TextInput | FileInput | ToolResult | RawInput | RawResponse
+)  # input item can either be a prompt, a file (image or file), a tool call result, a previous response, or raw input
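Replacing the `RawInputItem` dict alias with explicit pydantic wrappers lets downstream `match` statements dispatch on type instead of guessing at dict shape (compare the `get_pretty_input_types` hunk below). A usage sketch:

```python
from model_library.base.input import RawInput, RawResponse  # paths per this diff

# caller-supplied, provider-specific payload, e.g. a mock conversation turn
mock_turn = RawInput(input={"role": "user", "content": "Hello"})

# a previously received provider response carried through history
previous = RawResponse(response={"role": "assistant", "content": "Hi there"})
```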
@@ -24,6 +24,11 @@ class Citation(BaseModel):
     index: int | None = None
     container_id: str | None = None
 
+    @override
+    def __repr__(self):
+        attrs = vars(self).copy()
+        return f"{self.__class__.__name__}(\n{pformat(attrs, indent=2)}\n)"
+
 
 class QueryResultExtras(BaseModel):
     citations: list[Citation] = Field(default_factory=list)
@@ -1,18 +1,34 @@
-from typing import Sequence, TypeVar, cast
+import json
+from typing import Any, Sequence, TypeVar
+
+from pydantic import BaseModel
 
 from model_library.base.input import (
     FileBase,
     InputItem,
-    RawInputItem,
+    RawInput,
+    RawResponse,
     TextInput,
     ToolResult,
 )
 from model_library.utils import truncate_str
-from pydantic import BaseModel
 
 T = TypeVar("T", bound=BaseModel)
 
 
+def serialize_for_tokenizing(content: Any) -> str:
+    """
+    Serialize parsed content into a string for tokenization
+    """
+    parts: list[str] = []
+    if content:
+        if isinstance(content, str):
+            parts.append(content)
+        else:
+            parts.append(json.dumps(content, default=str))
+    return "\n".join(parts)
+
+
 def add_optional(
     a: int | float | T | None, b: int | float | T | None
 ) -> int | float | T | None:
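`serialize_for_tokenizing` passes strings through unchanged and JSON-encodes everything else with `default=str`, so values that are not natively JSON-serializable degrade to their string form instead of raising. The `default=str` behavior in isolation:

```python
import json
from datetime import datetime

# Plain JSON-serializable content round-trips as expected
print(json.dumps([{"role": "user", "content": "hi"}]))
# [{"role": "user", "content": "hi"}]

# Non-serializable values fall back to str() rather than raising TypeError
print(json.dumps({"ts": datetime(2024, 1, 1)}, default=str))
# {"ts": "2024-01-01 00:00:00"}
```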
@@ -54,11 +70,9 @@ def get_pretty_input_types(input: Sequence["InputItem"], verbose: bool = False)
                 return repr(item)
             case ToolResult():
                 return repr(item)
-            case dict():
-                item = cast(RawInputItem, item)
+            case RawInput():
                 return repr(item)
-            case _:
-                # RawResponse
+            case RawResponse():
                 return repr(item)
 
     processed_items = [f" {process_item(item)}" for item in input]
@@ -146,6 +146,17 @@ class BadInputError(Exception):
         super().__init__(message or BadInputError.DEFAULT_MESSAGE)
 
 
+class NoMatchingToolCallError(Exception):
+    """
+    Raised when a tool call result is provided with no matching tool call
+    """
+
+    DEFAULT_MESSAGE: str = "Tool call result provided with no matching tool call"
+
+    def __init__(self, message: str | None = None):
+        super().__init__(message or NoMatchingToolCallError.DEFAULT_MESSAGE)
+
+
 # Add more retriable exceptions as needed
 # Providers that don't have an explicit rate limit error are handled manually
 # by wrapping errored Http/gRPC requests with a BackoffRetryException
model_library/logging.py CHANGED
@@ -6,7 +6,11 @@ from rich.logging import RichHandler
 _llm_logger = logging.getLogger("llm")
 
 
-def set_logging(enable: bool = True, handler: logging.Handler | None = None):
+def set_logging(
+    enable: bool = True,
+    level: int = logging.INFO,
+    handler: logging.Handler | None = None,
+):
     """
     Sets up logging for the model library
 
@@ -15,7 +19,7 @@ def set_logging(enable: bool = True, handler: logging.Handler | None = None):
         handler (logging.Handler, optional): A custom logging handler. Defaults to RichHandler.
     """
     if enable:
-        _llm_logger.setLevel(logging.INFO)
+        _llm_logger.setLevel(level)
     else:
         _llm_logger.setLevel(logging.CRITICAL)
 
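With the new `level` parameter, callers can surface the DEBUG-level diagnostics added in this release (e.g. the token-count lines emitted by `count_tokens`):

```python
import logging

from model_library.logging import set_logging  # module path per this diff

# INFO remains the default; DEBUG surfaces the new token-count logging
set_logging(enable=True, level=logging.DEBUG)
```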
@@ -22,6 +22,7 @@ from model_library.base import (
     ToolDefinition,
     ToolResult,
 )
+from model_library.base.input import RawResponse
 from model_library.exceptions import (
     BadInputError,
     MaxOutputTokensExceededError,
@@ -65,8 +66,6 @@ class AI21LabsModel(LLM):
             match item:
                 case TextInput():
                     new_input.append(ChatMessage(role="user", content=item.text))
-                case AssistantMessage():
-                    new_input.append(item)
                 case ToolResult():
                     new_input.append(
                         ToolMessage(
@@ -74,7 +73,9 @@ class AI21LabsModel(LLM):
                             content=item.result,
                             tool_call_id=item.tool_call.id,
                         )
-                    )
+                    )  # TODO: tool calling metadata and test
+                case RawResponse():
+                    new_input.append(item.response)
                 case _:
                     raise BadInputError("Unsupported input type")
         return new_input
@@ -133,14 +134,13 @@ class AI21LabsModel(LLM):
         raise NotImplementedError()
 
     @override
-    async def _query_impl(
+    async def build_body(
         self,
         input: Sequence[InputItem],
         *,
         tools: list[ToolDefinition],
-        query_logger: logging.Logger,
         **kwargs: object,
-    ) -> QueryResult:
+    ) -> dict[str, Any]:
         messages: list[ChatMessage] = []
         if "system_prompt" in kwargs:
             messages.append(
@@ -162,6 +162,18 @@ class AI21LabsModel(LLM):
             body["top_p"] = self.top_p
 
         body.update(kwargs)
+        return body
+
+    @override
+    async def _query_impl(
+        self,
+        input: Sequence[InputItem],
+        *,
+        tools: list[ToolDefinition],
+        query_logger: logging.Logger,
+        **kwargs: object,
+    ) -> QueryResult:
+        body = await self.build_body(input, tools=tools, **kwargs)
 
         response: ChatCompletionResponse = (
             await self.get_client().chat.completions.create(**body, stream=False)  # pyright: ignore[reportAny, reportUnknownMemberType]
@@ -186,7 +198,7 @@ class AI21LabsModel(LLM):
 
         output = QueryResult(
             output_text=choice.message.content,
-            history=[*input, choice.message],
+            history=[*input, RawResponse(response=choice.message)],
             metadata=QueryResultMetadata(
                 in_tokens=response.usage.prompt_tokens,
                 out_tokens=response.usage.completion_tokens,
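The AI21 hunks above show the pattern this release applies provider by provider: `_query_impl` is split so that payload assembly lives in `build_body`, a coroutine with no network side effects, and the query path shrinks to build-then-send. In miniature (class and method names here are illustrative stand-ins, not the package's API):

```python
from typing import Any

class Provider:  # illustrative stand-in, not model-library's class
    async def build_body(self, input: list[Any], **kwargs: Any) -> dict[str, Any]:
        # pure assembly: no I/O, so it can be inspected in tests or reused
        # for token estimation without issuing a request
        return {"messages": list(input), **kwargs}

    async def _query_impl(self, input: list[Any], **kwargs: Any) -> Any:
        body = await self.build_body(input, **kwargs)
        return await self._send(body)  # the only place a request goes out

    async def _send(self, body: dict[str, Any]) -> Any:
        ...  # network call elided in this sketch
```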
@@ -13,24 +13,26 @@ from typing_extensions import override
 
 from model_library.base import (
     LLM,
+    FileBase,
     FileInput,
     FileWithBase64,
     FileWithId,
-    FileWithUrl,
     InputItem,
     LLMConfig,
     QueryResult,
     QueryResultMetadata,
+    RawInput,
+    RawResponse,
     TextInput,
     ToolBody,
     ToolCall,
     ToolDefinition,
     ToolResult,
 )
-from model_library.base.input import FileBase
 from model_library.exceptions import (
     BadInputError,
     MaxOutputTokensExceededError,
+    NoMatchingToolCallError,
 )
 from model_library.model_utils import get_default_budget_tokens
 from model_library.register_models import register_provider
@@ -70,6 +72,20 @@ class AmazonModel(LLM):
 
     cache_control = {"type": "default"}
 
+    async def get_tool_call_ids(self, input: Sequence[InputItem]) -> list[str]:
+        raw_responses = [x for x in input if isinstance(x, RawResponse)]
+        tool_call_ids: list[str] = []
+
+        calls = [
+            y["toolUse"]
+            for x in raw_responses
+            if "content" in x.response
+            for y in x.response["content"]
+            if "toolUse" in y
+        ]
+        tool_call_ids.extend([x["toolUseId"] for x in calls])
+        return tool_call_ids
+
     @override
     async def parse_input(
         self,
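`get_tool_call_ids` walks prior `RawResponse` items for Converse-style assistant messages and collects every `toolUseId`, so that `parse_input` below can reject orphaned tool results. The message shape it traverses, in isolation (the id value is illustrative):

```python
# A Converse-style assistant message as carried in a RawResponse
raw_response = {
    "role": "assistant",
    "content": [
        {"text": "Checking that for you."},
        {"toolUse": {"toolUseId": "tooluse_example123"}},
    ],
}

# Equivalent to the nested comprehension above, for a single response:
ids = [
    block["toolUse"]["toolUseId"]
    for block in raw_response.get("content", [])
    if "toolUse" in block
]
print(ids)  # ['tooluse_example123']
```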
@@ -77,58 +93,63 @@ class AmazonModel(LLM):
         **kwargs: Any,
     ) -> list[dict[str, Any]]:
         new_input: list[dict[str, Any] | Any] = []
+
         content_user: list[dict[str, Any]] = []
 
+        def flush_content_user():
+            if content_user:
+                # NOTE: must make new object as we clear()
+                new_input.append({"role": "user", "content": content_user.copy()})
+                content_user.clear()
+
+        tool_call_ids = await self.get_tool_call_ids(input)
+
         for item in input:
+            if isinstance(item, TextInput):
+                content_user.append({"text": item.text})
+                continue
+
+            if isinstance(item, FileBase):
+                match item.type:
+                    case "image":
+                        parsed = await self.parse_image(item)
+                    case "file":
+                        parsed = await self.parse_file(item)
+                content_user.append(parsed)
+                continue
+
+            # non content user item
+            flush_content_user()
+
             match item:
-                case TextInput():
-                    content_user.append({"text": item.text})
-                case FileWithBase64() | FileWithUrl() | FileWithId():
-                    match item.type:
-                        case "image":
-                            content_user.append(await self.parse_image(item))
-                        case "file":
-                            content_user.append(await self.parse_file(item))
-                case _:
-                    if content_user:
-                        new_input.append({"role": "user", "content": content_user})
-                        content_user = []
-                    match item:
-                        case ToolResult():
-                            if not (
-                                isinstance(x, dict)
-                                and "toolUse" in x
-                                and x["toolUse"].get("toolUseId")
-                                == item.tool_call.call_id
-                                for x in new_input
-                            ):
-                                raise Exception(
-                                    "Tool call result provided with no matching tool call"
-                                )
-                            new_input.append(
+                case ToolResult():
+                    if item.tool_call.id not in tool_call_ids:
+                        raise NoMatchingToolCallError()
+
+                    new_input.append(
+                        {
+                            "role": "user",
+                            "content": [
                                 {
-                                    "role": "user",
-                                    "content": [
-                                        {
-                                            "toolResult": {
-                                                "toolUseId": item.tool_call.id,
-                                                "content": [
-                                                    {"json": {"result": item.result}}
-                                                ],
-                                            }
-                                        }
-                                    ],
+                                    "toolResult": {
+                                        "toolUseId": item.tool_call.id,
+                                        "content": [{"json": {"result": item.result}}],
+                                    }
                                 }
-                            )
-                        case dict():  # RawInputItem and RawResponse
-                            new_input.append(item)
+                            ],
+                        }
+                    )
+                case RawResponse():
+                    new_input.append(item.response)
+                case RawInput():
+                    new_input.append(item.input)
 
-        if content_user:
-            if self.supports_cache:
-                if not isinstance(input[-1], FileBase):
-                    # last item cannot be file
-                    content_user.append({"cachePoint": self.cache_control})
-            new_input.append({"role": "user", "content": content_user})
+        if content_user and self.supports_cache:
+            if not isinstance(input[-1], FileBase):
+                # last item cannot be file
+                content_user.append({"cachePoint": self.cache_control})
+
+        flush_content_user()
 
         return new_input
 
@@ -196,6 +217,7 @@ class AmazonModel(LLM):
     ) -> FileWithId:
         raise NotImplementedError()
 
+    @override
     async def build_body(
         self,
         input: Sequence[InputItem],
@@ -383,5 +405,5 @@ class AmazonModel(LLM):
             reasoning=reasoning,
             metadata=metadata,
             tool_calls=tool_calls,
-            history=[*input, messages],
+            history=[*input, RawResponse(response=messages)],
         )