model-library 0.1.6-py3-none-any.whl → 0.1.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,15 +3,15 @@ import logging
 from typing import Any, Literal, Sequence, cast
 
 from anthropic import AsyncAnthropic
-from anthropic.types import TextBlock, ToolUseBlock
 from anthropic.types.beta.beta_tool_use_block import BetaToolUseBlock
-from anthropic.types.message import Message
+from anthropic.types.beta.parsed_beta_message import ParsedBetaMessage
 from typing_extensions import override
 
 from model_library import model_library_settings
 from model_library.base import (
     LLM,
     BatchResult,
+    FileBase,
     FileInput,
     FileWithBase64,
     FileWithId,
@@ -22,7 +22,8 @@ from model_library.base import (
     QueryResult,
     QueryResultCost,
     QueryResultMetadata,
-    RawInputItem,
+    RawInput,
+    RawResponse,
     TextInput,
     ToolBody,
     ToolCall,
@@ -31,6 +32,7 @@ from model_library.base import (
 )
 from model_library.exceptions import (
     MaxOutputTokensExceededError,
+    NoMatchingToolCallError,
 )
 from model_library.model_utils import get_default_budget_tokens
 from model_library.providers.openai import OpenAIModel
@@ -38,8 +40,6 @@ from model_library.register_models import register_provider
 from model_library.utils import (
     create_openai_client_with_defaults,
     default_httpx_client,
-    filter_empty_text_blocks,
-    normalize_tool_result,
 )
 
 
@@ -62,9 +62,9 @@ class AnthropicBatchMixin(LLMBatchMixin):
 
         Format: {"custom_id": str, "params": {...message params...}}
         """
-        # Build the message body using the parent model's create_body method
+        # Build the message body using the parent model's build_body method
         tools = cast(list[ToolDefinition], kwargs.pop("tools", []))
-        body = await self._root.create_body(input, tools=tools, **kwargs)
+        body = await self._root.build_body(input, tools=tools, **kwargs)
 
         return {
             "custom_id": custom_id,
@@ -300,6 +300,20 @@ class AnthropicModel(LLM):
             AnthropicBatchMixin(self) if self.supports_batch else None
         )
 
+    async def get_tool_call_ids(self, input: Sequence[InputItem]) -> list[str]:
+        raw_responses = [x for x in input if isinstance(x, RawResponse)]
+        tool_call_ids: list[str] = []
+
+        calls = [
+            y
+            for x in raw_responses
+            if isinstance(x.response, ParsedBetaMessage)
+            for y in x.response.content  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
+            if isinstance(y, BetaToolUseBlock)
+        ]
+        tool_call_ids.extend([x.id for x in calls])
+        return tool_call_ids
+
     @override
     async def parse_input(
         self,
@@ -307,77 +321,61 @@ class AnthropicModel(LLM):
         **kwargs: Any,
     ) -> list[dict[str, Any] | Any]:
         new_input: list[dict[str, Any] | Any] = []
+
         content_user: list[dict[str, Any]] = []
 
-        # First pass: collect all tool calls from Message objects for validation
-        tool_calls_in_input: set[str] = set()
-        for item in input:
-            if hasattr(item, "content") and hasattr(item, "role"):
-                content_list = getattr(item, "content", [])
-                for content in content_list:
-                    # Check for both ToolUseBlock and BetaToolUseBlock
-                    if isinstance(content, (ToolUseBlock, BetaToolUseBlock)):
-                        tool_calls_in_input.add(content.id)
+        def flush_content_user():
+            if content_user:
+                # NOTE: must make new object as we clear()
+                new_input.append({"role": "user", "content": content_user.copy()})
+                content_user.clear()
+
+        tool_call_ids = await self.get_tool_call_ids(input)
 
         for item in input:
+            if isinstance(item, TextInput):
+                content_user.append({"type": "text", "text": item.text})
+                continue
+
+            if isinstance(item, FileBase):
+                match item.type:
+                    case "image":
+                        parsed = await self.parse_image(item)
+                    case "file":
+                        parsed = await self.parse_file(item)
+                content_user.append(parsed)
+                continue
+
+            # non content user item
+            flush_content_user()
+
             match item:
-                case TextInput():
-                    if item.text.strip():
-                        content_user.append({"type": "text", "text": item.text})
-                case FileWithBase64() | FileWithUrl() | FileWithId():
-                    match item.type:
-                        case "image":
-                            content_user.append(await self.parse_image(item))
-                        case "file":
-                            content_user.append(await self.parse_file(item))
-                case _:
-                    if content_user:
-                        filtered = filter_empty_text_blocks(content_user)
-                        if filtered:
-                            new_input.append({"role": "user", "content": filtered})
-                        content_user = []
-                    match item:
-                        case ToolResult():
-                            if item.tool_call.id not in tool_calls_in_input:
-                                raise Exception(
-                                    "Tool call result provided with no matching tool call"
-                                )
-                            result_str = normalize_tool_result(item.result)
-                            new_input.append(
+                case ToolResult():
+                    if item.tool_call.id not in tool_call_ids:
+                        raise NoMatchingToolCallError()
+
+                    new_input.append(
+                        {
+                            "role": "user",
+                            "content": [
                                 {
-                                    "role": "user",
-                                    "content": [
-                                        {
-                                            "type": "tool_result",
-                                            "tool_use_id": item.tool_call.id,
-                                            "content": [
-                                                {"type": "text", "text": result_str}
-                                            ],
-                                        }
-                                    ],
+                                    "type": "tool_result",
+                                    "tool_use_id": item.tool_call.id,
+                                    "content": [{"type": "text", "text": item.result}],
                                 }
-                            )
-                        case dict():  # RawInputItem
-                            item = cast(RawInputItem, item)
-                            new_input.append(item)
-                        case _:  # RawResponse
-                            item = cast(Message, item)
-                            filtered_content = [
-                                block
-                                for block in item.content
-                                if not isinstance(block, TextBlock)
-                                or block.text.strip()
-                            ]
-                            if filtered_content:
-                                new_input.append(
-                                    {"role": "assistant", "content": filtered_content}
-                                )
-
-        if content_user:
-            filtered = filter_empty_text_blocks(content_user)
-            if filtered:
-                new_input.append({"role": "user", "content": filtered})
+                            ],
+                        }
+                    )
+                case RawResponse():
+                    content = cast(ParsedBetaMessage, item.response).content
+                    new_input.append({"role": "assistant", "content": content})
+                case RawInput():
+                    new_input.append(item.input)
+
+        # in case content user item is the last item
+        flush_content_user()
 
+        # cache control
         if new_input:
             last_msg = new_input[-1]
             if not isinstance(last_msg, dict):
@@ -495,7 +493,7 @@ class AnthropicModel(LLM):
         bytes: io.BytesIO,
         type: Literal["image", "file"] = "file",
     ) -> FileWithId:
-        file_mime = f"image/{mime}" if type == "image" else mime  # TODO:
+        file_mime = f"image/{mime}" if type == "image" else mime
         response = await self.get_client().beta.files.upload(
             file=(
                 name,
@@ -513,7 +511,8 @@ class AnthropicModel(LLM):
 
     cache_control = {"type": "ephemeral"}  # 5 min cache
 
-    async def create_body(
+    @override
+    async def build_body(
         self,
         input: Sequence[InputItem],
         *,
@@ -573,7 +572,7 @@ class AnthropicModel(LLM):
             input, tools=tools, query_logger=query_logger, **kwargs
         )
 
-        body = await self.create_body(input, tools=tools, **kwargs)
+        body = await self.build_body(input, tools=tools, **kwargs)
 
         client = self.get_client()
 
@@ -630,9 +629,37 @@ class AnthropicModel(LLM):
                 cache_write_tokens=message.usage.cache_creation_input_tokens,
             ),
             tool_calls=tool_calls,
-            history=[*input, message],
+            history=[*input, RawResponse(response=message)],
         )
 
+    @override
+    async def count_tokens(
+        self,
+        input: Sequence[InputItem],
+        *,
+        history: Sequence[InputItem] = [],
+        tools: list[ToolDefinition] = [],
+        **kwargs: object,
+    ) -> int:
+        """
+        Count the number of tokens using Anthropic's native token counting API.
+        https://docs.anthropic.com/en/docs/build-with-claude/token-counting
+        """
+        input = [*history, *input]
+        if not input:
+            return 0
+
+        body = await self.build_body(input, tools=tools, **kwargs)
+
+        # Remove fields not supported by count_tokens endpoint
+        body.pop("max_tokens", None)
+        body.pop("temperature", None)
+
+        client = self.get_client()
+        response = await client.messages.count_tokens(**body)
+
+        return response.input_tokens
+
     @override
     async def _calculate_cost(
         self,
@@ -2,8 +2,6 @@ import io
 import json
 from typing import TYPE_CHECKING, Any, Final, Sequence, cast
 
-from typing_extensions import override
-
 from google.genai.types import (
     BatchJob,
     Content,
@@ -11,6 +9,8 @@ from google.genai.types import (
     JobState,
     UploadFileConfig,
 )
+from typing_extensions import override
+
 from model_library.base import BatchResult, InputItem, LLMBatchMixin
 
 if TYPE_CHECKING:
@@ -144,7 +144,7 @@ class GoogleBatchMixin(LLMBatchMixin):
         **kwargs: object,
     ) -> dict[str, Any]:
         self._root.logger.debug(f"Creating batch request for custom_id: {custom_id}")
-        body = await self._root.create_body(input, tools=[], **kwargs)
+        body = await self._root.build_body(input, tools=[], **kwargs)
 
         contents_any = body["contents"]
         serialized_contents: list[dict[str, Any]] = [
@@ -1,13 +1,16 @@
 import base64
 import io
 import logging
+import uuid
 from typing import Any, Literal, Sequence, cast
 
 from google.genai import Client
 from google.genai import errors as genai_errors
 from google.genai.types import (
     Content,
+    CountTokensConfig,
     File,
+    FinishReason,
     FunctionDeclaration,
     GenerateContentConfig,
     GenerateContentResponse,
@@ -21,13 +24,13 @@ from google.genai.types import (
     Tool,
     ToolListUnion,
     UploadFileConfig,
-    FinishReason,
 )
 from typing_extensions import override
 
 from model_library import model_library_settings
 from model_library.base import (
     LLM,
+    FileBase,
     FileInput,
     FileWithBase64,
     FileWithId,
@@ -40,6 +43,8 @@ from model_library.base import (
     QueryResult,
     QueryResultCost,
     QueryResultMetadata,
+    RawInput,
+    RawResponse,
     TextInput,
     ToolBody,
     ToolCall,
@@ -54,8 +59,6 @@ from model_library.exceptions import (
 )
 from model_library.providers.google.batch import GoogleBatchMixin
 from model_library.register_models import register_provider
-from model_library.utils import normalize_tool_result
-import uuid
 
 
 def generate_tool_call_id(tool_name: str) -> str:
@@ -146,63 +149,52 @@ class GoogleModel(LLM):
         input: Sequence[InputItem],
         **kwargs: Any,
     ) -> list[Content]:
-        parsed_input: list[Content] = []
-        parts: list[Part] = []
+        new_input: list[Content] = []
 
-        def flush_parts():
-            nonlocal parts
+        content_user: list[Part] = []
 
-            if parts:
-                parsed_input.append(Content(parts=parts, role="user"))
-                parts = []
+        def flush_content_user():
+            if content_user:
+                new_input.append(Content(parts=content_user, role="user"))
+                content_user.clear()
 
         for item in input:
-            match item:
-                case TextInput():
-                    if item.text.strip():
-                        parts.append(Part.from_text(text=item.text))
+            if isinstance(item, TextInput):
+                content_user.append(Part.from_text(text=item.text))
+                continue
+
+            if isinstance(item, FileBase):
+                parsed = await self.parse_file(item)
+                content_user.append(parsed)
+                continue
 
-                case FileWithBase64() | FileWithUrl() | FileWithId():
-                    part = await self.parse_file(item)
-                    parts.append(part)
+            # non content user item
+            flush_content_user()
 
+            match item:
                 case ToolResult():
-                    flush_parts()
-                    result_str = normalize_tool_result(item.result)
-                    parsed_input.append(
+                    # id check
+                    new_input.append(
                         Content(
                             role="function",
                             parts=[
                                 Part.from_function_response(
                                     name=item.tool_call.name,
-                                    response={"result": result_str},
+                                    response={"result": item.result},
                                 )
                             ],
                         )
                     )
 
-                case GenerateContentResponse():
-                    flush_parts()
-                    candidates = item.candidates
-                    if candidates and candidates[0]:
-                        content0 = candidates[0].content
-                        if content0 is not None:
-                            parsed_input.append(content0)
-                    else:
-                        self.logger.debug(
-                            "GenerateContentResponse missing candidates; skipping"
-                        )
-
-                case Content():
-                    flush_parts()
-                    parsed_input.append(item)
+                case RawResponse():
+                    new_input.extend(item.response)
+                case RawInput():
+                    new_input.append(item.input)
 
-                case _:
-                    raise BadInputError(f"Unsupported input type: {type(item)}")
+        # in case content user item is the last item
+        flush_content_user()
 
-        flush_parts()
-
-        return parsed_input
+        return new_input
 
     @override
     async def parse_file(self, file: FileInput) -> Part:
@@ -284,7 +276,8 @@ class GoogleModel(LLM):
             mime=mime,
         )
 
-    async def create_body(
+    @override
+    async def build_body(
         self,
         input: Sequence[InputItem],
         *,
@@ -337,7 +330,7 @@ class GoogleModel(LLM):
         query_logger: logging.Logger,
         **kwargs: object,
     ) -> QueryResult:
-        body: dict[str, Any] = await self.create_body(input, tools=tools, **kwargs)
+        body: dict[str, Any] = await self.build_body(input, tools=tools, **kwargs)
 
         text: str = ""
         reasoning: str = ""
@@ -395,7 +388,7 @@ class GoogleModel(LLM):
         result = QueryResult(
             output_text=text,
             reasoning=reasoning,
-            history=[*input, RawResponse(response=contents)],
+            history=[*input, RawResponse(response=contents)],
             tool_calls=tool_calls,
         )
 
@@ -410,6 +403,51 @@ class GoogleModel(LLM):
         )
         return result
 
+    @override
+    async def count_tokens(
+        self,
+        input: Sequence[InputItem],
+        *,
+        history: Sequence[InputItem] = [],
+        tools: list[ToolDefinition] = [],
+        **kwargs: object,
+    ) -> int:
+        """
+        Count the number of tokens using Google's native token counting API.
+        https://ai.google.dev/gemini-api/docs/tokens
+
+        Only Vertex AI supports system_instruction and tools in count_tokens.
+        For Gemini API, fall back to the base implementation.
+        TODO: implement token counting for non-Vertex models.
+        """
+        if not self.provider_config.use_vertex:
+            return await super().count_tokens(
+                input, history=history, tools=tools, **kwargs
+            )
+
+        input = [*history, *input]
+        if not input:
+            return 0
+
+        system_prompt = kwargs.pop("system_prompt", None)
+        contents = await self.parse_input(input, **kwargs)
+        parsed_tools = await self.parse_tools(tools) if tools else None
+        config = CountTokensConfig(
+            system_instruction=str(system_prompt) if system_prompt else None,
+            tools=parsed_tools,
+        )
+
+        response = await self.client.aio.models.count_tokens(
+            model=self.model_name,
+            contents=cast(Any, contents),
+            config=config,
+        )
+
+        if response.total_tokens is None:
+            raise ValueError("count_tokens returned None")
+
+        return response.total_tokens
+
     @override
     async def _calculate_cost(
         self,
@@ -446,7 +484,7 @@ class GoogleModel(LLM):
         **kwargs: object,
     ) -> PydanticT:
         # Create the request body with JSON schema
-        body: dict[str, Any] = await self.create_body(input, tools=[], **kwargs)
+        body: dict[str, Any] = await self.build_body(input, tools=[], **kwargs)
 
         # Get the JSON schema from the Pydantic model
         json_schema = pydantic_model.model_json_schema()
@@ -2,12 +2,16 @@ from typing import Literal
 
 from model_library import model_library_settings
 from model_library.base import DelegateOnly, LLMConfig
+from model_library.base.input import InputItem, ToolDefinition
 from model_library.providers.anthropic import AnthropicModel
 from model_library.register_models import register_provider
 from model_library.utils import default_httpx_client
 
 from anthropic import AsyncAnthropic
 
+from typing import Sequence
+from typing_extensions import override
+
 
 @register_provider("minimax")
 class MinimaxModel(DelegateOnly):
@@ -31,3 +35,18 @@ class MinimaxModel(DelegateOnly):
                 max_retries=1,
             ),
         )
+
+    # minimax client shares anthropic's syntax
+    @override
+    async def count_tokens(
+        self,
+        input: Sequence[InputItem],
+        *,
+        history: Sequence[InputItem] = [],
+        tools: list[ToolDefinition] = [],
+        **kwargs: object,
+    ) -> int:
+        assert self.delegate
+        return await self.delegate.count_tokens(
+            input, history=history, tools=tools, **kwargs
+        )