model-library 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,6 +3,7 @@ import asyncio
3
3
  import base64
4
4
  import io
5
5
  import json
6
+ import logging
6
7
  from typing import Any, Literal, Sequence, cast
7
8
 
8
9
  import boto3
@@ -12,24 +13,26 @@ from typing_extensions import override
12
13
 
13
14
  from model_library.base import (
14
15
  LLM,
16
+ FileBase,
15
17
  FileInput,
16
18
  FileWithBase64,
17
19
  FileWithId,
18
- FileWithUrl,
19
20
  InputItem,
20
21
  LLMConfig,
21
22
  QueryResult,
22
23
  QueryResultMetadata,
24
+ RawInput,
25
+ RawResponse,
23
26
  TextInput,
24
27
  ToolBody,
25
28
  ToolCall,
26
29
  ToolDefinition,
27
30
  ToolResult,
28
31
  )
29
- from model_library.base.input import FileBase
30
32
  from model_library.exceptions import (
31
33
  BadInputError,
32
34
  MaxOutputTokensExceededError,
35
+ NoMatchingToolCallError,
33
36
  )
34
37
  from model_library.model_utils import get_default_budget_tokens
35
38
  from model_library.register_models import register_provider
@@ -69,6 +72,20 @@ class AmazonModel(LLM):
69
72
 
70
73
  cache_control = {"type": "default"}
71
74
 
75
async def get_tool_call_ids(self, input: Sequence[InputItem]) -> list[str]:
    """Return the ``toolUseId`` of every tool call found in earlier raw responses.

    Only ``RawResponse`` items in ``input`` are inspected. For each one whose
    ``response`` has a ``"content"`` entry, every content element that
    contains a ``"toolUse"`` key contributes ``element["toolUse"]["toolUseId"]``.

    Args:
        input: The conversation items to scan.

    Returns:
        The collected ids in encounter order (duplicates are kept).
    """
    found_ids: list[str] = []

    for entry in input:
        # anything that is not a raw provider response cannot carry tool calls
        if not isinstance(entry, RawResponse):
            continue

        response = entry.response
        if "content" not in response:
            continue

        for part in response["content"]:
            if "toolUse" in part:
                found_ids.append(part["toolUse"]["toolUseId"])

    return found_ids
88
+
72
89
  @override
73
90
  async def parse_input(
74
91
  self,
@@ -76,58 +93,63 @@ class AmazonModel(LLM):
76
93
  **kwargs: Any,
77
94
  ) -> list[dict[str, Any]]:
78
95
  new_input: list[dict[str, Any] | Any] = []
96
+
79
97
  content_user: list[dict[str, Any]] = []
80
98
 
99
+ def flush_content_user():
100
+ if content_user:
101
+ # NOTE: must make new object as we clear()
102
+ new_input.append({"role": "user", "content": content_user.copy()})
103
+ content_user.clear()
104
+
105
+ tool_call_ids = await self.get_tool_call_ids(input)
106
+
81
107
  for item in input:
108
+ if isinstance(item, TextInput):
109
+ content_user.append({"text": item.text})
110
+ continue
111
+
112
+ if isinstance(item, FileBase):
113
+ match item.type:
114
+ case "image":
115
+ parsed = await self.parse_image(item)
116
+ case "file":
117
+ parsed = await self.parse_file(item)
118
+ content_user.append(parsed)
119
+ continue
120
+
121
+ # non content user item
122
+ flush_content_user()
123
+
82
124
  match item:
83
- case TextInput():
84
- content_user.append({"text": item.text})
85
- case FileWithBase64() | FileWithUrl() | FileWithId():
86
- match item.type:
87
- case "image":
88
- content_user.append(await self.parse_image(item))
89
- case "file":
90
- content_user.append(await self.parse_file(item))
91
- case _:
92
- if content_user:
93
- new_input.append({"role": "user", "content": content_user})
94
- content_user = []
95
- match item:
96
- case ToolResult():
97
- if not (
98
- isinstance(x, dict)
99
- and "toolUse" in x
100
- and x["toolUse"].get("toolUseId")
101
- == item.tool_call.call_id
102
- for x in new_input
103
- ):
104
- raise Exception(
105
- "Tool call result provided with no matching tool call"
106
- )
107
- new_input.append(
125
+ case ToolResult():
126
+ if item.tool_call.id not in tool_call_ids:
127
+ raise NoMatchingToolCallError()
128
+
129
+ new_input.append(
130
+ {
131
+ "role": "user",
132
+ "content": [
108
133
  {
109
- "role": "user",
110
- "content": [
111
- {
112
- "toolResult": {
113
- "toolUseId": item.tool_call.id,
114
- "content": [
115
- {"json": {"result": item.result}}
116
- ],
117
- }
118
- }
119
- ],
134
+ "toolResult": {
135
+ "toolUseId": item.tool_call.id,
136
+ "content": [{"json": {"result": item.result}}],
137
+ }
120
138
  }
121
- )
122
- case dict(): # RawInputItem and RawResponse
123
- new_input.append(item)
139
+ ],
140
+ }
141
+ )
142
+ case RawResponse():
143
+ new_input.append(item.response)
144
+ case RawInput():
145
+ new_input.append(item.input)
124
146
 
125
- if content_user:
126
- if self.supports_cache:
127
- if not isinstance(input[-1], FileBase):
128
- # last item cannot be file
129
- content_user.append({"cachePoint": self.cache_control})
130
- new_input.append({"role": "user", "content": content_user})
147
+ if content_user and self.supports_cache:
148
+ if not isinstance(input[-1], FileBase):
149
+ # last item cannot be file
150
+ content_user.append({"cachePoint": self.cache_control})
151
+
152
+ flush_content_user()
131
153
 
132
154
  return new_input
133
155
 
@@ -195,6 +217,7 @@ class AmazonModel(LLM):
195
217
  ) -> FileWithId:
196
218
  raise NotImplementedError()
197
219
 
220
+ @override
198
221
  async def build_body(
199
222
  self,
200
223
  input: Sequence[InputItem],
@@ -337,6 +360,7 @@ class AmazonModel(LLM):
337
360
  input: Sequence[InputItem],
338
361
  *,
339
362
  tools: list[ToolDefinition],
363
+ query_logger: logging.Logger,
340
364
  **kwargs: object,
341
365
  ) -> QueryResult:
342
366
  body = await self.build_body(input, tools=tools, **kwargs)
@@ -381,5 +405,5 @@ class AmazonModel(LLM):
381
405
  reasoning=reasoning,
382
406
  metadata=metadata,
383
407
  tool_calls=tool_calls,
384
- history=[*input, messages],
408
+ history=[*input, RawResponse(response=messages)],
385
409
  )
@@ -1,16 +1,17 @@
1
1
  import io
2
+ import logging
2
3
  from typing import Any, Literal, Sequence, cast
3
4
 
4
5
  from anthropic import AsyncAnthropic
5
- from anthropic.types import TextBlock, ToolUseBlock
6
6
  from anthropic.types.beta.beta_tool_use_block import BetaToolUseBlock
7
- from anthropic.types.message import Message
7
+ from anthropic.types.beta.parsed_beta_message import ParsedBetaMessage
8
8
  from typing_extensions import override
9
9
 
10
10
  from model_library import model_library_settings
11
11
  from model_library.base import (
12
12
  LLM,
13
13
  BatchResult,
14
+ FileBase,
14
15
  FileInput,
15
16
  FileWithBase64,
16
17
  FileWithId,
@@ -21,7 +22,8 @@ from model_library.base import (
21
22
  QueryResult,
22
23
  QueryResultCost,
23
24
  QueryResultMetadata,
24
- RawInputItem,
25
+ RawInput,
26
+ RawResponse,
25
27
  TextInput,
26
28
  ToolBody,
27
29
  ToolCall,
@@ -30,6 +32,7 @@ from model_library.base import (
30
32
  )
31
33
  from model_library.exceptions import (
32
34
  MaxOutputTokensExceededError,
35
+ NoMatchingToolCallError,
33
36
  )
34
37
  from model_library.model_utils import get_default_budget_tokens
35
38
  from model_library.providers.openai import OpenAIModel
@@ -37,8 +40,6 @@ from model_library.register_models import register_provider
37
40
  from model_library.utils import (
38
41
  create_openai_client_with_defaults,
39
42
  default_httpx_client,
40
- filter_empty_text_blocks,
41
- normalize_tool_result,
42
43
  )
43
44
 
44
45
 
@@ -61,9 +62,9 @@ class AnthropicBatchMixin(LLMBatchMixin):
61
62
 
62
63
  Format: {"custom_id": str, "params": {...message params...}}
63
64
  """
64
- # Build the message body using the parent model's create_body method
65
+ # Build the message body using the parent model's build_body method
65
66
  tools = cast(list[ToolDefinition], kwargs.pop("tools", []))
66
- body = await self._root.create_body(input, tools=tools, **kwargs)
67
+ body = await self._root.build_body(input, tools=tools, **kwargs)
67
68
 
68
69
  return {
69
70
  "custom_id": custom_id,
@@ -249,6 +250,8 @@ class AnthropicModel(LLM):
249
250
 
250
251
  @override
251
252
  def get_client(self) -> AsyncAnthropic:
253
+ if self._delegate_client:
254
+ return self._delegate_client
252
255
  if not AnthropicModel._client:
253
256
  headers: dict[str, str] = {}
254
257
  AnthropicModel._client = AsyncAnthropic(
@@ -262,16 +265,20 @@ class AnthropicModel(LLM):
262
265
  def __init__(
263
266
  self,
264
267
  model_name: str,
265
- provider: Literal["anthropic"] = "anthropic",
268
+ provider: str = "anthropic",
266
269
  *,
267
270
  config: LLMConfig | None = None,
271
+ custom_client: AsyncAnthropic | None = None,
268
272
  ):
269
273
  super().__init__(model_name, provider, config=config)
270
274
 
275
+ # allow custom client to act as delegate (native)
276
+ self._delegate_client: AsyncAnthropic | None = custom_client
277
+
271
278
  # https://docs.anthropic.com/en/api/openai-sdk
272
- self.delegate: OpenAIModel | None = (
279
+ self.delegate = (
273
280
  None
274
- if self.native
281
+ if self.native or custom_client
275
282
  else OpenAIModel(
276
283
  model_name=self.model_name,
277
284
  provider=provider,
@@ -285,11 +292,28 @@ class AnthropicModel(LLM):
285
292
  )
286
293
 
287
294
  # Initialize batch support if enabled
288
- self.supports_batch: bool = self.supports_batch and self.native
295
+ # Disable batch when using custom_client (similar to OpenAI)
296
+ self.supports_batch: bool = (
297
+ self.supports_batch and self.native and not custom_client
298
+ )
289
299
  self.batch: LLMBatchMixin | None = (
290
300
  AnthropicBatchMixin(self) if self.supports_batch else None
291
301
  )
292
302
 
303
async def get_tool_call_ids(self, input: Sequence[InputItem]) -> list[str]:
    """Return the id of every tool-use block found in earlier raw responses.

    Only ``RawResponse`` items whose ``response`` is a ``ParsedBetaMessage``
    are inspected; within those, each ``BetaToolUseBlock`` in the message
    content contributes its ``id``.

    Args:
        input: The conversation items to scan.

    Returns:
        The collected ids in encounter order (duplicates are kept).
    """
    found_ids: list[str] = []

    for entry in input:
        if not isinstance(entry, RawResponse):
            continue

        message = entry.response
        # responses of any other type are skipped entirely
        if not isinstance(message, ParsedBetaMessage):
            continue

        for block in message.content:  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
            if isinstance(block, BetaToolUseBlock):
                found_ids.append(block.id)

    return found_ids
316
+
293
317
  @override
294
318
  async def parse_input(
295
319
  self,
@@ -297,77 +321,61 @@ class AnthropicModel(LLM):
297
321
  **kwargs: Any,
298
322
  ) -> list[dict[str, Any] | Any]:
299
323
  new_input: list[dict[str, Any] | Any] = []
324
+
300
325
  content_user: list[dict[str, Any]] = []
301
326
 
302
- # First pass: collect all tool calls from Message objects for validation
303
- tool_calls_in_input: set[str] = set()
304
- for item in input:
305
- if hasattr(item, "content") and hasattr(item, "role"):
306
- content_list = getattr(item, "content", [])
307
- for content in content_list:
308
- # Check for both ToolUseBlock and BetaToolUseBlock
309
- if isinstance(content, (ToolUseBlock, BetaToolUseBlock)):
310
- tool_calls_in_input.add(content.id)
327
+ def flush_content_user():
328
+ if content_user:
329
+ # NOTE: must make new object as we clear()
330
+ new_input.append({"role": "user", "content": content_user.copy()})
331
+ content_user.clear()
332
+
333
+ tool_call_ids = await self.get_tool_call_ids(input)
311
334
 
312
335
  for item in input:
336
+ if isinstance(item, TextInput):
337
+ content_user.append({"type": "text", "text": item.text})
338
+ continue
339
+
340
+ if isinstance(item, FileBase):
341
+ match item.type:
342
+ case "image":
343
+ parsed = await self.parse_image(item)
344
+ case "file":
345
+ parsed = await self.parse_file(item)
346
+ content_user.append(parsed)
347
+ continue
348
+
349
+ # non content user item
350
+ flush_content_user()
351
+
313
352
  match item:
314
- case TextInput():
315
- if item.text.strip():
316
- content_user.append({"type": "text", "text": item.text})
317
- case FileWithBase64() | FileWithUrl() | FileWithId():
318
- match item.type:
319
- case "image":
320
- content_user.append(await self.parse_image(item))
321
- case "file":
322
- content_user.append(await self.parse_file(item))
323
- case _:
324
- if content_user:
325
- filtered = filter_empty_text_blocks(content_user)
326
- if filtered:
327
- new_input.append({"role": "user", "content": filtered})
328
- content_user = []
329
- match item:
330
- case ToolResult():
331
- if item.tool_call.id not in tool_calls_in_input:
332
- raise Exception(
333
- "Tool call result provided with no matching tool call"
334
- )
335
- result_str = normalize_tool_result(item.result)
336
- new_input.append(
353
+ case ToolResult():
354
+ if item.tool_call.id not in tool_call_ids:
355
+ raise NoMatchingToolCallError()
356
+
357
+ new_input.append(
358
+ {
359
+ "role": "user",
360
+ "content": [
337
361
  {
338
- "role": "user",
339
- "content": [
340
- {
341
- "type": "tool_result",
342
- "tool_use_id": item.tool_call.id,
343
- "content": [
344
- {"type": "text", "text": result_str}
345
- ],
346
- }
347
- ],
362
+ "type": "tool_result",
363
+ "tool_use_id": item.tool_call.id,
364
+ "content": [{"type": "text", "text": item.result}],
348
365
  }
349
- )
350
- case dict(): # RawInputItem
351
- item = cast(RawInputItem, item)
352
- new_input.append(item)
353
- case _: # RawResponse
354
- item = cast(Message, item)
355
- filtered_content = [
356
- block
357
- for block in item.content
358
- if not isinstance(block, TextBlock)
359
- or block.text.strip()
360
- ]
361
- if filtered_content:
362
- new_input.append(
363
- {"role": "assistant", "content": filtered_content}
364
- )
365
-
366
- if content_user:
367
- filtered = filter_empty_text_blocks(content_user)
368
- if filtered:
369
- new_input.append({"role": "user", "content": filtered})
366
+ ],
367
+ }
368
+ )
369
+ case RawResponse():
370
+ content = cast(ParsedBetaMessage, item.response).content
371
+ new_input.append({"role": "assistant", "content": content})
372
+ case RawInput():
373
+ new_input.append(item.input)
374
+
375
+ # in case content user item is the last item
376
+ flush_content_user()
370
377
 
378
+ # cache control
371
379
  if new_input:
372
380
  last_msg = new_input[-1]
373
381
  if not isinstance(last_msg, dict):
@@ -485,7 +493,7 @@ class AnthropicModel(LLM):
485
493
  bytes: io.BytesIO,
486
494
  type: Literal["image", "file"] = "file",
487
495
  ) -> FileWithId:
488
- file_mime = f"image/{mime}" if type == "image" else mime # TODO:
496
+ file_mime = f"image/{mime}" if type == "image" else mime
489
497
  response = await self.get_client().beta.files.upload(
490
498
  file=(
491
499
  name,
@@ -503,7 +511,8 @@ class AnthropicModel(LLM):
503
511
 
504
512
  cache_control = {"type": "ephemeral"} # 5 min cache
505
513
 
506
- async def create_body(
514
+ @override
515
+ async def build_body(
507
516
  self,
508
517
  input: Sequence[InputItem],
509
518
  *,
@@ -555,20 +564,36 @@ class AnthropicModel(LLM):
555
564
  input: Sequence[InputItem],
556
565
  *,
557
566
  tools: list[ToolDefinition],
567
+ query_logger: logging.Logger,
558
568
  **kwargs: object,
559
569
  ) -> QueryResult:
560
570
  if self.delegate:
561
- return await self.delegate_query(input, tools=tools, **kwargs)
571
+ return await self.delegate_query(
572
+ input, tools=tools, query_logger=query_logger, **kwargs
573
+ )
562
574
 
563
- body = await self.create_body(input, tools=tools, **kwargs)
575
+ body = await self.build_body(input, tools=tools, **kwargs)
564
576
 
565
- betas = ["files-api-2025-04-14", "interleaved-thinking-2025-05-14"]
566
- if "sonnet-4-5" in self.model_name:
567
- betas.append("context-1m-2025-08-07")
577
+ client = self.get_client()
568
578
 
569
- async with self.get_client().beta.messages.stream(
570
- **body,
571
- betas=betas,
579
+ # only send betas for the official Anthropic endpoint
580
+ is_anthropic_endpoint = self._delegate_client is None
581
+ if not is_anthropic_endpoint:
582
+ client_base_url = getattr(client, "_base_url", None) or getattr(
583
+ client, "base_url", None
584
+ )
585
+ if client_base_url:
586
+ is_anthropic_endpoint = "api.anthropic.com" in str(client_base_url)
587
+
588
+ stream_kwargs = {**body}
589
+ if is_anthropic_endpoint:
590
+ betas = ["files-api-2025-04-14", "interleaved-thinking-2025-05-14"]
591
+ if "sonnet-4-5" in self.model_name:
592
+ betas.append("context-1m-2025-08-07")
593
+ stream_kwargs["betas"] = betas
594
+
595
+ async with client.beta.messages.stream(
596
+ **stream_kwargs,
572
597
  ) as stream: # pyright: ignore[reportAny]
573
598
  message = await stream.get_final_message()
574
599
  self.logger.info(f"Anthropic Response finished: {message.id}")
@@ -604,9 +629,37 @@ class AnthropicModel(LLM):
604
629
  cache_write_tokens=message.usage.cache_creation_input_tokens,
605
630
  ),
606
631
  tool_calls=tool_calls,
607
- history=[*input, message],
632
+ history=[*input, RawResponse(response=message)],
608
633
  )
609
634
 
635
+ @override
636
async def count_tokens(
    self,
    input: Sequence[InputItem],
    *,
    history: Sequence[InputItem] = (),
    tools: list[ToolDefinition] | None = None,
    **kwargs: object,
) -> int:
    """
    Count the number of tokens using Anthropic's native token counting API.
    https://docs.anthropic.com/en/docs/build-with-claude/token-counting

    Args:
        input: New input items to count.
        history: Earlier conversation items; they are placed before ``input``.
        tools: Tool definitions forwarded to ``build_body``. ``None`` (the
            default) is treated as "no tools".
        **kwargs: Forwarded unchanged to ``build_body``.

    Returns:
        ``0`` when ``history`` and ``input`` are both empty (no request is
        made); otherwise the ``input_tokens`` value reported by the API.
    """
    # Defaults are immutable/None so that no list object is shared between
    # calls; build_body receives the tools list and is free to modify it.
    combined = [*history, *input]
    if not combined:
        return 0

    body = await self.build_body(
        combined, tools=[] if tools is None else tools, **kwargs
    )

    # Remove fields not supported by count_tokens endpoint
    body.pop("max_tokens", None)
    body.pop("temperature", None)

    # NOTE(review): this uses the non-beta `messages` endpoint while queries
    # go through `beta.messages` — confirm beta-only content is accepted here.
    client = self.get_client()
    response = await client.messages.count_tokens(**body)

    return response.input_tokens
662
+
610
663
  @override
611
664
  async def _calculate_cost(
612
665
  self,
@@ -2,8 +2,6 @@ import io
2
2
  import json
3
3
  from typing import TYPE_CHECKING, Any, Final, Sequence, cast
4
4
 
5
- from typing_extensions import override
6
-
7
5
  from google.genai.types import (
8
6
  BatchJob,
9
7
  Content,
@@ -11,6 +9,8 @@ from google.genai.types import (
11
9
  JobState,
12
10
  UploadFileConfig,
13
11
  )
12
+ from typing_extensions import override
13
+
14
14
  from model_library.base import BatchResult, InputItem, LLMBatchMixin
15
15
 
16
16
  if TYPE_CHECKING:
@@ -144,7 +144,7 @@ class GoogleBatchMixin(LLMBatchMixin):
144
144
  **kwargs: object,
145
145
  ) -> dict[str, Any]:
146
146
  self._root.logger.debug(f"Creating batch request for custom_id: {custom_id}")
147
- body = await self._root.create_body(input, tools=[], **kwargs)
147
+ body = await self._root.build_body(input, tools=[], **kwargs)
148
148
 
149
149
  contents_any = body["contents"]
150
150
  serialized_contents: list[dict[str, Any]] = [