model-library 0.1.6-py3-none-any.whl → 0.1.7-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- model_library/base/base.py +98 -0
- model_library/base/delegate_only.py +10 -0
- model_library/base/input.py +10 -7
- model_library/base/output.py +5 -0
- model_library/base/utils.py +21 -7
- model_library/exceptions.py +11 -0
- model_library/logging.py +6 -2
- model_library/providers/ai21labs.py +19 -7
- model_library/providers/amazon.py +70 -48
- model_library/providers/anthropic.py +101 -74
- model_library/providers/google/batch.py +3 -3
- model_library/providers/google/google.py +83 -45
- model_library/providers/minimax.py +19 -0
- model_library/providers/mistral.py +41 -27
- model_library/providers/openai.py +122 -73
- model_library/providers/vals.py +4 -3
- model_library/providers/xai.py +123 -115
- model_library/register_models.py +4 -2
- model_library/utils.py +0 -35
- {model_library-0.1.6.dist-info → model_library-0.1.7.dist-info}/METADATA +3 -3
- {model_library-0.1.6.dist-info → model_library-0.1.7.dist-info}/RECORD +24 -24
- {model_library-0.1.6.dist-info → model_library-0.1.7.dist-info}/WHEEL +0 -0
- {model_library-0.1.6.dist-info → model_library-0.1.7.dist-info}/licenses/LICENSE +0 -0
- {model_library-0.1.6.dist-info → model_library-0.1.7.dist-info}/top_level.txt +0 -0
model_library/providers/mistral.py
CHANGED

@@ -1,6 +1,5 @@
 import io
 import logging
-import time
 from collections.abc import Sequence
 from typing import Any, Literal

@@ -13,14 +12,16 @@ from typing_extensions import override
 from model_library import model_library_settings
 from model_library.base import (
     LLM,
+    FileBase,
     FileInput,
     FileWithBase64,
     FileWithId,
-    FileWithUrl,
     InputItem,
     LLMConfig,
     QueryResult,
     QueryResultMetadata,
+    RawInput,
+    RawResponse,
     TextInput,
     ToolBody,
     ToolCall,
@@ -69,27 +70,30 @@ class MistralModel(LLM):
         content_user: list[dict[str, Any]] = []

         def flush_content_user():
-            nonlocal content_user
-
             if content_user:
-                new_input.append({"role": "user", "content": content_user})
-                content_user = []
+                # NOTE: must make new object as we clear()
+                new_input.append({"role": "user", "content": content_user.copy()})
+                content_user.clear()

         for item in input:
+            if isinstance(item, TextInput):
+                content_user.append({"type": "text", "text": item.text})
+                continue
+
+            if isinstance(item, FileBase):
+                match item.type:
+                    case "image":
+                        parsed = await self.parse_image(item)
+                    case "file":
+                        parsed = await self.parse_file(item)
+                content_user.append(parsed)
+                continue
+
+            # non content user item
+            flush_content_user()
+
             match item:
-                case TextInput():
-                    content_user.append({"type": "text", "text": item.text})
-                case FileWithBase64() | FileWithUrl() | FileWithId():
-                    match item.type:
-                        case "image":
-                            content_user.append(await self.parse_image(item))
-                        case "file":
-                            content_user.append(await self.parse_file(item))
-                case AssistantMessage():
-                    flush_content_user()
-                    new_input.append(item)
                 case ToolResult():
-                    flush_content_user()
                     new_input.append(
                         {
                             "role": "tool",
@@ -98,9 +102,12 @@ class MistralModel(LLM):
                             "tool_call_id": item.tool_call.id,
                         }
                     )
-                case
-
+                case RawResponse():
+                    new_input.append(item.response)
+                case RawInput():
+                    new_input.append(item.input)

+        # in case content user item is the last item
         flush_content_user()

         return new_input
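The Mistral `parse_input` rewrite above (and the OpenAI one below) drops the `nonlocal` rebinding from `flush_content_user` and mutates the buffer in place instead, which is what the new `# NOTE: must make new object as we clear()` comment is about: `list.clear()` empties the same object that was just appended, so a snapshot has to be stored first. A minimal standalone sketch of the pattern (only the two lists and the closure; the surrounding method is omitted):

```python
new_input: list[dict] = []
content_user: list[dict] = []

def flush_content_user() -> None:
    # Mutating (not rebinding) content_user means no `nonlocal` is needed,
    # but the flushed message must hold a copy: clear() would otherwise
    # empty the same list object that was just appended.
    if content_user:
        new_input.append({"role": "user", "content": content_user.copy()})
        content_user.clear()

content_user.append({"type": "text", "text": "first"})
flush_content_user()
content_user.append({"type": "text", "text": "second"})
flush_content_user()

assert [m["content"][0]["text"] for m in new_input] == ["first", "second"]
```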
@@ -167,14 +174,13 @@ class MistralModel(LLM):
         raise NotImplementedError()

     @override
-    async def
+    async def build_body(
         self,
         input: Sequence[InputItem],
         *,
         tools: list[ToolDefinition],
-        query_logger: logging.Logger,
         **kwargs: object,
-    ) ->
+    ) -> dict[str, Any]:
         # mistral supports max 8 images, merge extra images into the 8th image
         input = trim_images(input, max_images=8)

@@ -205,8 +211,18 @@ class MistralModel(LLM):
             body["top_p"] = self.top_p

         body.update(kwargs)
+        return body

-
+    @override
+    async def _query_impl(
+        self,
+        input: Sequence[InputItem],
+        *,
+        tools: list[ToolDefinition],
+        query_logger: logging.Logger,
+        **kwargs: object,
+    ) -> QueryResult:
+        body = await self.build_body(input, tools=tools, **kwargs)

         response: EventStreamAsync[
             CompletionEvent
@@ -247,8 +263,6 @@ class MistralModel(LLM):
                     in_tokens += data.usage.prompt_tokens or 0
                     out_tokens += data.usage.completion_tokens or 0

-            self.logger.info(f"Finished in: {time.time() - start}")
-
         except Exception as e:
             self.logger.error(f"Error: {e}", exc_info=True)
             raise e
@@ -302,7 +316,7 @@ class MistralModel(LLM):
         return QueryResult(
             output_text=text,
             reasoning=reasoning or None,
-            history=[*input, message],
+            history=[*input, RawResponse(response=message)],
             tool_calls=tool_calls,
             metadata=QueryResultMetadata(
                 in_tokens=in_tokens,
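The `history=[*input, RawResponse(response=message)]` change is this release's recurring refactor: provider-native assistant messages are wrapped in `RawResponse` before entering history, and each provider's `parse_input` unwraps them (and `RawInput`) back to wire format on the next turn. A hedged sketch of the round trip; the dataclass fields are inferred from the diff, not the library's actual definitions:

```python
from dataclasses import dataclass
from typing import Any

# Field names inferred from the diff (RawResponse(response=...), item.input);
# the real classes live in model_library.base and may differ.
@dataclass
class RawResponse:
    response: Any  # provider-native assistant message

@dataclass
class RawInput:
    input: Any  # provider-native request item, passed through verbatim

# After a query: wrap the provider message instead of appending it bare.
provider_message = {"role": "assistant", "content": "hello"}
history: list[Any] = [
    {"role": "user", "content": "hi"},
    RawResponse(response=provider_message),
]

# On the next turn, parse_input unwraps back to the provider's wire format.
new_input: list[Any] = []
for item in history:
    match item:
        case RawResponse():
            new_input.append(item.response)
        case RawInput():
            new_input.append(item.input)
        case dict():  # plain user content, already wire format
            new_input.append(item)

assert new_input == [{"role": "user", "content": "hi"}, provider_message]
```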
model_library/providers/openai.py
CHANGED

@@ -16,6 +16,7 @@ from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
 from openai.types.create_embedding_response import CreateEmbeddingResponse
 from openai.types.moderation_create_response import ModerationCreateResponse
 from openai.types.responses import (
+    ResponseFunctionToolCall,
     ResponseOutputItem,
     ResponseOutputText,
     ResponseStreamEvent,
@@ -29,6 +30,7 @@ from model_library.base import (
     LLM,
     BatchResult,
     Citation,
+    FileBase,
     FileInput,
     FileWithBase64,
     FileWithId,
@@ -42,7 +44,8 @@ from model_library.base import (
     QueryResultCost,
     QueryResultExtras,
     QueryResultMetadata,
-
+    RawInput,
+    RawResponse,
     TextInput,
     ToolBody,
     ToolCall,
@@ -53,6 +56,7 @@ from model_library.exceptions import (
     ImmediateRetryException,
     MaxOutputTokensExceededError,
     ModelNoOutputError,
+    NoMatchingToolCallError,
 )
 from model_library.model_utils import get_reasoning_in_tag
 from model_library.register_models import register_provider
@@ -258,7 +262,9 @@ class OpenAIModel(LLM):
         use_completions: bool = False,
     ):
         super().__init__(model_name, provider, config=config)
-        self.use_completions: bool =
+        self.use_completions: bool = (
+            use_completions  # TODO: do completions in a separate file
+        )
         self.deep_research = self.provider_config.deep_research

         # allow custom client to act as delegate (native)
@@ -270,6 +276,29 @@ class OpenAIModel(LLM):
             OpenAIBatchMixin(self) if self.supports_batch else None
         )

+    async def get_tool_call_ids(self, input: Sequence[InputItem]) -> list[str]:
+        raw_responses = [x for x in input if isinstance(x, RawResponse)]
+        tool_call_ids: list[str] = []
+
+        if self.use_completions:
+            calls = [
+                y
+                for x in raw_responses
+                if isinstance(x.response, ChatCompletionMessage)
+                and x.response.tool_calls
+                for y in x.response.tool_calls
+            ]
+            tool_call_ids.extend([x.id for x in calls if x.id])
+        else:
+            calls = [
+                y
+                for x in raw_responses
+                for y in x.response
+                if isinstance(y, ResponseFunctionToolCall)
+            ]
+            tool_call_ids.extend([x.id for x in calls if x.id])
+        return tool_call_ids
+
     @override
     async def parse_input(
         self,
@@ -277,63 +306,70 @@ class OpenAIModel(LLM):
         **kwargs: Any,
     ) -> list[dict[str, Any] | Any]:
         new_input: list[dict[str, Any] | Any] = []
+
         content_user: list[dict[str, Any]] = []
+
+        def flush_content_user():
+            if content_user:
+                # NOTE: must make new object as we clear()
+                new_input.append({"role": "user", "content": content_user.copy()})
+                content_user.clear()
+
+        tool_call_ids = await self.get_tool_call_ids(input)
+
         for item in input:
+            if isinstance(item, TextInput):
+                if self.use_completions:
+                    text_key = "text"
+                else:
+                    text_key = "input_text"
+                content_user.append({"type": text_key, "text": item.text})
+                continue
+
+            if isinstance(item, FileBase):
+                match item.type:
+                    case "image":
+                        parsed = await self.parse_image(item)
+                    case "file":
+                        parsed = await self.parse_file(item)
+                content_user.append(parsed)
+                continue
+
+            # non content user item
+            flush_content_user()
+
             match item:
-                case
+                case ToolResult():
+                    if item.tool_call.id not in tool_call_ids:
+                        raise NoMatchingToolCallError()
+
                     if self.use_completions:
-
+                        new_input.append(
+                            {
+                                "role": "tool",
+                                "tool_call_id": item.tool_call.id,
+                                "content": item.result,
+                            }
+                        )
                     else:
-
-
-
-
-
-
-
-                case
-                    if
-
-
-
-
-
-
-
-
-
-                    ):
-                        raise Exception(
-                            "Tool call result provided with no matching tool call"
-                        )
-                    if self.use_completions:
-                        new_input.append(
-                            {
-                                "role": "tool",
-                                "tool_call_id": item.tool_call.id,
-                                "content": item.result,
-                            }
-                        )
-                    else:
-                        new_input.append(
-                            {
-                                "type": "function_call_output",
-                                "call_id": item.tool_call.call_id,
-                                "output": item.result,
-                            }
-                        )
-                case dict(): # RawInputItem
-                    item = cast(RawInputItem, item)
-                    new_input.append(item)
-                case _: # RawResponse
-                    if self.use_completions:
-                        item = cast(ChatCompletionMessageToolCall, item)
-                    else:
-                        item = cast(ResponseOutputItem, item)
-                    new_input.append(item)
-
-        if content_user:
-            new_input.append({"role": "user", "content": content_user})
+                        new_input.append(
+                            {
+                                "type": "function_call_output",
+                                "call_id": item.tool_call.call_id,
+                                "output": item.result,
+                            }
+                        )
+                case RawResponse():
+                    if self.use_completions:
+                        pass
+                        new_input.append(item.response)
+                    else:
+                        new_input.extend(item.response)
+                case RawInput():
+                    new_input.append(item.input)
+
+        # in case content user item is the last item
+        flush_content_user()

         return new_input

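The rewritten OpenAI `parse_input` also moves tool-result validation up front: `get_tool_call_ids` collects the ids of every tool call found in prior `RawResponse` items, and a `ToolResult` with an unknown id now raises the typed `NoMatchingToolCallError` (the file list shows additions to `model_library/exceptions.py`) instead of a bare `Exception`. A simplified, hypothetical illustration of the check:

```python
from dataclasses import dataclass

# Simplified, hypothetical types; the real ToolCall/ToolResult come from
# model_library.base and NoMatchingToolCallError from model_library.exceptions.
@dataclass
class ToolCall:
    id: str

@dataclass
class ToolResult:
    tool_call: ToolCall
    result: str

class NoMatchingToolCallError(Exception):
    """A ToolResult references a tool call id absent from prior responses."""

def validate_tool_results(results: list[ToolResult], tool_call_ids: list[str]) -> None:
    # Fail fast before building the request body, instead of letting
    # the provider reject the mismatched tool output later.
    for item in results:
        if item.tool_call.id not in tool_call_ids:
            raise NoMatchingToolCallError()

validate_tool_results(
    [ToolResult(ToolCall(id="call_1"), result="42")], tool_call_ids=["call_1"]
)  # passes; an id not in the list would raise NoMatchingToolCallError
```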
@@ -469,19 +505,13 @@ class OpenAIModel(LLM):
                 file_id=response.id,
             )

-    async def
+    async def _build_body_completions(
         self,
         input: Sequence[InputItem],
         *,
         tools: list[ToolDefinition],
         **kwargs: object,
-    ) ->
-        """
-        Completions endpoint
-        Generally not used for openai models
-        Used by some providers using openai as a delegate
-        """
-
+    ) -> dict[str, Any]:
         parsed_input: list[dict[str, Any] | ChatCompletionMessage] = []
         if "system_prompt" in kwargs:
             parsed_input.append(
@@ -520,6 +550,23 @@ class OpenAIModel(LLM):

         body.update(kwargs)

+        return body
+
+    async def _query_completions(
+        self,
+        input: Sequence[InputItem],
+        *,
+        tools: list[ToolDefinition],
+        **kwargs: object,
+    ) -> QueryResult:
+        """
+        Completions endpoint
+        Generally not used for openai models
+        Used by providers using openai as a delegate
+        """
+
+        body = await self.build_body(input, tools=tools, **kwargs)
+
         output_text: str = ""
         reasoning_text: str = ""
         metadata: QueryResultMetadata = QueryResultMetadata()
@@ -632,7 +679,7 @@ class OpenAIModel(LLM):
             output_text=output_text,
             reasoning=reasoning_text,
             tool_calls=tool_calls,
-            history=[*input, final_message],
+            history=[*input, RawResponse(response=final_message)],
             metadata=metadata,
         )

@@ -667,13 +714,17 @@ class OpenAIModel(LLM):
         if not valid:
             raise Exception("Deep research models require web search tools")

+    @override
     async def build_body(
         self,
         input: Sequence[InputItem],
         *,
-        tools:
+        tools: list[ToolDefinition],
         **kwargs: object,
     ) -> dict[str, Any]:
+        if self.use_completions:
+            return await self._build_body_completions(input, tools=tools, **kwargs)
+
         if self.deep_research:
             await self._check_deep_research_args(tools, **kwargs)

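With this hunk, `build_body` becomes the single entry point for request construction, routing to `_build_body_completions` when `use_completions` is set. That is what lets callers such as the vals.py batch mixin below invoke `self._root.build_body(...)` without knowing which endpoint the model targets. A hypothetical skeleton of the dispatch shape (the real signatures take `Sequence[InputItem]` and `list[ToolDefinition]`; the body layouts here are assumptions):

```python
from typing import Any

class SketchModel:
    """Hypothetical stand-in for OpenAIModel; only the dispatch shape is real."""

    def __init__(self, use_completions: bool = False) -> None:
        self.use_completions = use_completions

    async def _build_body_completions(
        self, input: list[Any], *, tools: list[Any], **kwargs: Any
    ) -> dict[str, Any]:
        # chat.completions-style body (assumed layout)
        return {"messages": input, "tools": tools, **kwargs}

    async def build_body(
        self, input: list[Any], *, tools: list[Any], **kwargs: Any
    ) -> dict[str, Any]:
        # single entry point: _query_impl and batch mixins both call this
        if self.use_completions:
            return await self._build_body_completions(input, tools=tools, **kwargs)
        # responses-style body (assumed layout)
        return {"input": input, "tools": tools, **kwargs}
```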
@@ -717,7 +768,6 @@ class OpenAIModel(LLM):
         _ = kwargs.pop("stream", None)

         body.update(kwargs)
-
         return body

     @override
@@ -785,13 +835,12 @@ class OpenAIModel(LLM):
         citations: list[Citation] = []
         reasoning = None
         for output in response.output:
-            if
-
-
-
-
-
-                    citations.append(Citation(**citation.model_dump()))
+            if output.type == "message":
+                for content in output.content:
+                    if not isinstance(content, ResponseOutputText):
+                        continue
+                    for citation in content.annotations:
+                        citations.append(Citation(**citation.model_dump()))

             if output.type == "reasoning":
                 reasoning = " ".join([i.text for i in output.summary])
@@ -814,7 +863,7 @@ class OpenAIModel(LLM):
             output_text=response.output_text,
             reasoning=reasoning,
             tool_calls=tool_calls,
-            history=[*input,
+            history=[*input, RawResponse(response=response.output)],
             extras=QueryResultExtras(citations=citations),
         )
         if response.usage:
model_library/providers/vals.py
CHANGED

@@ -51,7 +51,7 @@ class DummyAIBatchMixin(LLMBatchMixin):
             "custom_id": custom_id,
             "method": "",
             "url": "",
-            "body": await self._root.
+            "body": await self._root.build_body(input, tools=[], **kwargs),
         }

     @override
@@ -227,7 +227,8 @@ class DummyAIModel(LLM):
     ) -> FileWithId:
         raise NotImplementedError()

-
+    @override
+    async def build_body(
         self,
         input: Sequence[InputItem],
         *,
@@ -275,7 +276,7 @@ class DummyAIModel(LLM):
         query_logger: logging.Logger,
         **kwargs: object,
     ) -> QueryResult:
-        body = await self.
+        body = await self.build_body(input, tools=tools, **kwargs)

         fail_rate = FAIL_RATE
