model-library 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as published in their public registry.
@@ -1,12 +1,16 @@
  import base64
  import io
+ import logging
+ import uuid
  from typing import Any, Literal, Sequence, cast

  from google.genai import Client
  from google.genai import errors as genai_errors
  from google.genai.types import (
      Content,
+     CountTokensConfig,
      File,
+     FinishReason,
      FunctionDeclaration,
      GenerateContentConfig,
      GenerateContentResponse,
@@ -20,13 +24,13 @@ from google.genai.types import (
      Tool,
      ToolListUnion,
      UploadFileConfig,
-     FinishReason,
  )
  from typing_extensions import override

  from model_library import model_library_settings
  from model_library.base import (
      LLM,
+     FileBase,
      FileInput,
      FileWithBase64,
      FileWithId,
@@ -39,6 +43,8 @@ from model_library.base import (
      QueryResult,
      QueryResultCost,
      QueryResultMetadata,
+     RawInput,
+     RawResponse,
      TextInput,
      ToolBody,
      ToolCall,
@@ -53,7 +59,10 @@ from model_library.exceptions import (
  )
  from model_library.providers.google.batch import GoogleBatchMixin
  from model_library.register_models import register_provider
- from model_library.utils import normalize_tool_result
+
+
+ def generate_tool_call_id(tool_name: str) -> str:
+     return str(tool_name + "_" + str(uuid.uuid4()))


  class GoogleConfig(ProviderConfig):
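Note on the new helper: generate_tool_call_id suffixes the tool name with a random UUID, so synthesized IDs stay traceable to the tool that produced them. A quick illustration (the UUID shown is made up; every call returns a fresh one):

    import uuid

    def generate_tool_call_id(tool_name: str) -> str:
        return str(tool_name + "_" + str(uuid.uuid4()))

    generate_tool_call_id("get_weather")
    # e.g. 'get_weather_4f1c2b9e-73a5-4f0e-9a61-8c2d5e7b1a30'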
@@ -140,63 +149,52 @@ class GoogleModel(LLM):
          input: Sequence[InputItem],
          **kwargs: Any,
      ) -> list[Content]:
-         parsed_input: list[Content] = []
-         parts: list[Part] = []
+         new_input: list[Content] = []

-         def flush_parts():
-             nonlocal parts
+         content_user: list[Part] = []

-             if parts:
-                 parsed_input.append(Content(parts=parts, role="user"))
-                 parts = []
+         def flush_content_user():
+             if content_user:
+                 new_input.append(Content(parts=content_user, role="user"))
+                 content_user.clear()

          for item in input:
-             match item:
-                 case TextInput():
-                     if item.text.strip():
-                         parts.append(Part.from_text(text=item.text))
+             if isinstance(item, TextInput):
+                 content_user.append(Part.from_text(text=item.text))
+                 continue
+
+             if isinstance(item, FileBase):
+                 parsed = await self.parse_file(item)
+                 content_user.append(parsed)
+                 continue

-                 case FileWithBase64() | FileWithUrl() | FileWithId():
-                     part = await self.parse_file(item)
-                     parts.append(part)
+             # non content user item
+             flush_content_user()

+             match item:
                  case ToolResult():
-                     flush_parts()
-                     result_str = normalize_tool_result(item.result)
-                     parsed_input.append(
+                     # id check
+                     new_input.append(
                          Content(
                              role="function",
                              parts=[
                                  Part.from_function_response(
                                      name=item.tool_call.name,
-                                     response={"result": result_str},
+                                     response={"result": item.result},
                                  )
                              ],
                          )
                      )

-                 case GenerateContentResponse():
-                     flush_parts()
-                     candidates = item.candidates
-                     if candidates and candidates[0]:
-                         content0 = candidates[0].content
-                         if content0 is not None:
-                             parsed_input.append(content0)
-                     else:
-                         self.logger.debug(
-                             "GenerateContentResponse missing candidates; skipping"
-                         )
-
-                 case Content():
-                     flush_parts()
-                     parsed_input.append(item)
+                 case RawResponse():
+                     new_input.extend(item.response)
+                 case RawInput():
+                     new_input.append(item.input)

-                 case _:
-                     raise BadInputError(f"Unsupported input type: {type(item)}")
+         # in case content user item is the last item
+         flush_content_user()

-         flush_parts()
-
-         return parsed_input
+         return new_input

      @override
      async def parse_file(self, file: FileInput) -> Part:
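Note on the reworked parse_input: consecutive user-facing items (text and files) are buffered in content_user and flushed into a single user Content, while any other item (tool result, raw input/response) first closes the pending user turn. A self-contained sketch of that buffering pattern, using plain dicts instead of the google.genai types (names below are illustrative only):

    items: list = ["first question", "second question", ("tool_result", "42"), "follow-up"]

    turns: list[dict] = []
    buffer: list[str] = []

    def flush() -> None:
        # emit the pending user turn, if any
        if buffer:
            turns.append({"role": "user", "parts": buffer.copy()})
            buffer.clear()

    for item in items:
        if isinstance(item, str):
            buffer.append(item)          # user content keeps accumulating
            continue
        flush()                          # a non-user item closes the user turn
        turns.append({"role": "function", "result": item[1]})
    flush()                              # in case user content is the last item

    # turns ->
    #   {"role": "user", "parts": ["first question", "second question"]}
    #   {"role": "function", "result": "42"}
    #   {"role": "user", "parts": ["follow-up"]}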
@@ -278,7 +276,8 @@ class GoogleModel(LLM):
              mime=mime,
          )

-     async def create_body(
+     @override
+     async def build_body(
          self,
          input: Sequence[InputItem],
          *,
@@ -328,9 +327,10 @@ class GoogleModel(LLM):
          input: Sequence[InputItem],
          *,
          tools: list[ToolDefinition],
+         query_logger: logging.Logger,
          **kwargs: object,
      ) -> QueryResult:
-         body: dict[str, Any] = await self.create_body(input, tools=tools, **kwargs)
+         body: dict[str, Any] = await self.build_body(input, tools=tools, **kwargs)

          text: str = ""
          reasoning: str = ""
@@ -357,9 +357,10 @@ class GoogleModel(LLM):

                      call_args = part.function_call.args or {}
                      tool_calls.append(
-                         # weirdly, id is not required
+                         # Weirdly, id is not required. If not provided, we generate one.
                          ToolCall(
-                             id=part.function_call.id or "",
+                             id=part.function_call.id
+                             or generate_tool_call_id(part.function_call.name),
                              name=part.function_call.name,
                              args=call_args,
                          )
@@ -387,7 +388,7 @@ class GoogleModel(LLM):
          result = QueryResult(
              output_text=text,
              reasoning=reasoning,
-             history=[*input, *contents],
+             history=[*input, RawResponse(response=contents)],
              tool_calls=tool_calls,
          )

@@ -402,6 +403,51 @@ class GoogleModel(LLM):
          )
          return result

+     @override
+     async def count_tokens(
+         self,
+         input: Sequence[InputItem],
+         *,
+         history: Sequence[InputItem] = [],
+         tools: list[ToolDefinition] = [],
+         **kwargs: object,
+     ) -> int:
+         """
+         Count the number of tokens using Google's native token counting API.
+         https://ai.google.dev/gemini-api/docs/tokens
+
+         Only Vertex AI supports system_instruction and tools in count_tokens.
+         For Gemini API, fall back to the base implementation.
+         TODO: implement token counting for non-Vertex models.
+         """
+         if not self.provider_config.use_vertex:
+             return await super().count_tokens(
+                 input, history=history, tools=tools, **kwargs
+             )
+
+         input = [*history, *input]
+         if not input:
+             return 0
+
+         system_prompt = kwargs.pop("system_prompt", None)
+         contents = await self.parse_input(input, **kwargs)
+         parsed_tools = await self.parse_tools(tools) if tools else None
+         config = CountTokensConfig(
+             system_instruction=str(system_prompt) if system_prompt else None,
+             tools=parsed_tools,
+         )
+
+         response = await self.client.aio.models.count_tokens(
+             model=self.model_name,
+             contents=cast(Any, contents),
+             config=config,
+         )
+
+         if response.total_tokens is None:
+             raise ValueError("count_tokens returned None")
+
+         return response.total_tokens
+
      @override
      async def _calculate_cost(
          self,
@@ -438,7 +484,7 @@ class GoogleModel(LLM):
          **kwargs: object,
      ) -> PydanticT:
          # Create the request body with JSON schema
-         body: dict[str, Any] = await self.create_body(input, tools=[], **kwargs)
+         body: dict[str, Any] = await self.build_body(input, tools=[], **kwargs)

          # Get the JSON schema from the Pydantic model
          json_schema = pydantic_model.model_json_schema()
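Per the docstring of the new count_tokens, only Vertex AI accepts system_instruction and tools, so non-Vertex configurations fall back to the inherited estimate. A minimal, self-contained sketch of that dispatch (the two helper functions below are stand-ins, not part of the package):

    import asyncio

    async def base_estimate(texts: list[str]) -> int:
        # stand-in for the base-class fallback (hypothetical heuristic)
        return sum(len(t) // 4 for t in texts)

    async def native_count(texts: list[str]) -> int:
        # stand-in for client.aio.models.count_tokens on Vertex
        return sum(len(t.split()) for t in texts)

    async def count_tokens(texts: list[str], *, use_vertex: bool) -> int:
        if not use_vertex:
            return await base_estimate(texts)
        if not texts:
            return 0
        return await native_count(texts)

    print(asyncio.run(count_tokens(["hello there, world"], use_vertex=True)))  # 3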
@@ -1,13 +1,16 @@
  from typing import Literal

  from model_library import model_library_settings
- from model_library.base import (
-     DelegateOnly,
-     LLMConfig,
- )
- from model_library.providers.openai import OpenAIModel
+ from model_library.base import DelegateOnly, LLMConfig
+ from model_library.base.input import InputItem, ToolDefinition
+ from model_library.providers.anthropic import AnthropicModel
  from model_library.register_models import register_provider
- from model_library.utils import create_openai_client_with_defaults
+ from model_library.utils import default_httpx_client
+
+ from anthropic import AsyncAnthropic
+
+ from typing import Sequence
+ from typing_extensions import override


  @register_provider("minimax")
@@ -21,13 +24,29 @@ class MinimaxModel(DelegateOnly):
      ):
          super().__init__(model_name, provider, config=config)

-         self.delegate = OpenAIModel(
+         self.delegate = AnthropicModel(
              model_name=self.model_name,
              provider=self.provider,
              config=config,
-             custom_client=create_openai_client_with_defaults(
+             custom_client=AsyncAnthropic(
                  api_key=model_library_settings.MINIMAX_API_KEY,
-                 base_url="https://api.minimax.io/v1",
+                 base_url="https://api.minimax.io/anthropic",
+                 http_client=default_httpx_client(),
+                 max_retries=1,
              ),
-             use_completions=True,
+         )
+
+     # minimax client shares anthropic's syntax
+     @override
+     async def count_tokens(
+         self,
+         input: Sequence[InputItem],
+         *,
+         history: Sequence[InputItem] = [],
+         tools: list[ToolDefinition] = [],
+         **kwargs: object,
+     ) -> int:
+         assert self.delegate
+         return await self.delegate.count_tokens(
+             input, history=history, tools=tools, **kwargs
          )
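The MiniMax delegate moves from the OpenAI-compatible endpoint (https://api.minimax.io/v1) to the Anthropic-compatible one (https://api.minimax.io/anthropic), and count_tokens is simply forwarded to the delegate. A minimal sketch of that forwarding pattern with stand-in classes (not the package's real base classes):

    import asyncio

    class StubDelegate:
        async def count_tokens(self, texts: list[str]) -> int:
            return sum(len(t) // 4 for t in texts)  # stand-in estimate

    class DelegatingModel:
        def __init__(self) -> None:
            self.delegate = StubDelegate()

        async def count_tokens(self, texts: list[str]) -> int:
            # mirrors the new MinimaxModel.count_tokens: defer entirely to the delegate
            assert self.delegate
            return await self.delegate.count_tokens(texts)

    print(asyncio.run(DelegatingModel().count_tokens(["hello world"])))  # 2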
@@ -1,5 +1,5 @@
  import io
- import time
+ import logging
  from collections.abc import Sequence
  from typing import Any, Literal

@@ -12,14 +12,16 @@ from typing_extensions import override
  from model_library import model_library_settings
  from model_library.base import (
      LLM,
+     FileBase,
      FileInput,
      FileWithBase64,
      FileWithId,
-     FileWithUrl,
      InputItem,
      LLMConfig,
      QueryResult,
      QueryResultMetadata,
+     RawInput,
+     RawResponse,
      TextInput,
      ToolBody,
      ToolCall,
@@ -68,27 +70,30 @@ class MistralModel(LLM):
          content_user: list[dict[str, Any]] = []

          def flush_content_user():
-             nonlocal content_user
-
              if content_user:
-                 new_input.append({"role": "user", "content": content_user})
-                 content_user = []
+                 # NOTE: must make new object as we clear()
+                 new_input.append({"role": "user", "content": content_user.copy()})
+                 content_user.clear()

          for item in input:
+             if isinstance(item, TextInput):
+                 content_user.append({"type": "text", "text": item.text})
+                 continue
+
+             if isinstance(item, FileBase):
+                 match item.type:
+                     case "image":
+                         parsed = await self.parse_image(item)
+                     case "file":
+                         parsed = await self.parse_file(item)
+                 content_user.append(parsed)
+                 continue
+
+             # non content user item
+             flush_content_user()
+
              match item:
-                 case TextInput():
-                     content_user.append({"type": "text", "text": item.text})
-                 case FileWithBase64() | FileWithUrl() | FileWithId():
-                     match item.type:
-                         case "image":
-                             content_user.append(await self.parse_image(item))
-                         case "file":
-                             content_user.append(await self.parse_file(item))
-                 case AssistantMessage():
-                     flush_content_user()
-                     new_input.append(item)
                  case ToolResult():
-                     flush_content_user()
                      new_input.append(
                          {
                              "role": "tool",
@@ -97,9 +102,12 @@ class MistralModel(LLM):
                              "tool_call_id": item.tool_call.id,
                          }
                      )
-                 case _:
-                     raise BadInputError("Unsupported input type")
+                 case RawResponse():
+                     new_input.append(item.response)
+                 case RawInput():
+                     new_input.append(item.input)

+         # in case content user item is the last item
          flush_content_user()

          return new_input
@@ -166,13 +174,13 @@ class MistralModel(LLM):
          raise NotImplementedError()

      @override
-     async def _query_impl(
+     async def build_body(
          self,
          input: Sequence[InputItem],
          *,
          tools: list[ToolDefinition],
          **kwargs: object,
-     ) -> QueryResult:
+     ) -> dict[str, Any]:
          # mistral supports max 8 images, merge extra images into the 8th image
          input = trim_images(input, max_images=8)

@@ -203,8 +211,18 @@ class MistralModel(LLM):
              body["top_p"] = self.top_p

          body.update(kwargs)
+         return body

-         start = time.time()
+     @override
+     async def _query_impl(
+         self,
+         input: Sequence[InputItem],
+         *,
+         tools: list[ToolDefinition],
+         query_logger: logging.Logger,
+         **kwargs: object,
+     ) -> QueryResult:
+         body = await self.build_body(input, tools=tools, **kwargs)

          response: EventStreamAsync[
              CompletionEvent
@@ -245,8 +263,6 @@ class MistralModel(LLM):
                      in_tokens += data.usage.prompt_tokens or 0
                      out_tokens += data.usage.completion_tokens or 0

-             self.logger.info(f"Finished in: {time.time() - start}")
-
          except Exception as e:
              self.logger.error(f"Error: {e}", exc_info=True)
              raise e
@@ -300,7 +316,7 @@ class MistralModel(LLM):
          return QueryResult(
              output_text=text,
              reasoning=reasoning or None,
-             history=[*input, message],
+             history=[*input, RawResponse(response=message)],
              tool_calls=tool_calls,
              metadata=QueryResultMetadata(
                  in_tokens=in_tokens,
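Across both the Google and Mistral providers, the assistant turn returned by the provider is now stored in history wrapped in RawResponse (with RawInput as the counterpart for pre-formed provider messages) instead of being spliced into the item list directly; parse_input later unwraps it. A rough sketch of that round trip with a stand-in dataclass (the real RawResponse/RawInput live in model_library.base):

    from dataclasses import dataclass
    from typing import Any

    @dataclass
    class RawResponse:  # stand-in for model_library.base.RawResponse
        response: Any

    # turn 1: keep the provider-native assistant message wrapped in the history
    assistant_message = {"role": "assistant", "content": "Hi, how can I help?"}
    history = ["Hello!", RawResponse(response=assistant_message)]

    # turn 2: parse_input unwraps the marker back into a provider-shaped message
    new_input: list[dict[str, Any]] = []
    for item in history:
        if isinstance(item, RawResponse):
            new_input.append(item.response)
        else:
            new_input.append({"role": "user", "content": item})

    # new_input ->
    #   {"role": "user", "content": "Hello!"}
    #   {"role": "assistant", "content": "Hi, how can I help?"}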