model-library 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
- model_library/base/base.py +114 -12
- model_library/base/delegate_only.py +15 -1
- model_library/base/input.py +10 -7
- model_library/base/output.py +5 -0
- model_library/base/utils.py +21 -7
- model_library/config/all_models.json +92 -1
- model_library/config/fireworks_models.yaml +2 -0
- model_library/config/minimax_models.yaml +18 -0
- model_library/config/zai_models.yaml +14 -0
- model_library/exceptions.py +11 -0
- model_library/logging.py +6 -2
- model_library/providers/ai21labs.py +20 -6
- model_library/providers/amazon.py +72 -48
- model_library/providers/anthropic.py +138 -85
- model_library/providers/google/batch.py +3 -3
- model_library/providers/google/google.py +92 -46
- model_library/providers/minimax.py +29 -10
- model_library/providers/mistral.py +42 -26
- model_library/providers/openai.py +131 -77
- model_library/providers/vals.py +6 -3
- model_library/providers/xai.py +125 -113
- model_library/register_models.py +5 -3
- model_library/utils.py +0 -35
- {model_library-0.1.5.dist-info → model_library-0.1.7.dist-info}/METADATA +3 -3
- {model_library-0.1.5.dist-info → model_library-0.1.7.dist-info}/RECORD +28 -28
- {model_library-0.1.5.dist-info → model_library-0.1.7.dist-info}/WHEEL +0 -0
- {model_library-0.1.5.dist-info → model_library-0.1.7.dist-info}/licenses/LICENSE +0 -0
- {model_library-0.1.5.dist-info → model_library-0.1.7.dist-info}/top_level.txt +0 -0
|
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import io
|
|
4
4
|
import json
|
|
5
|
+
import logging
|
|
5
6
|
from typing import Any, Literal, Sequence, cast
|
|
6
7
|
|
|
7
8
|
from openai import APIConnectionError, AsyncOpenAI
|
|
@@ -15,6 +16,7 @@ from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
|
|
|
15
16
|
from openai.types.create_embedding_response import CreateEmbeddingResponse
|
|
16
17
|
from openai.types.moderation_create_response import ModerationCreateResponse
|
|
17
18
|
from openai.types.responses import (
|
|
19
|
+
ResponseFunctionToolCall,
|
|
18
20
|
ResponseOutputItem,
|
|
19
21
|
ResponseOutputText,
|
|
20
22
|
ResponseStreamEvent,
|
|
@@ -28,6 +30,7 @@ from model_library.base import (
|
|
|
28
30
|
LLM,
|
|
29
31
|
BatchResult,
|
|
30
32
|
Citation,
|
|
33
|
+
FileBase,
|
|
31
34
|
FileInput,
|
|
32
35
|
FileWithBase64,
|
|
33
36
|
FileWithId,
|
|
@@ -41,7 +44,8 @@ from model_library.base import (
|
|
|
41
44
|
QueryResultCost,
|
|
42
45
|
QueryResultExtras,
|
|
43
46
|
QueryResultMetadata,
|
|
44
|
-
|
|
47
|
+
RawInput,
|
|
48
|
+
RawResponse,
|
|
45
49
|
TextInput,
|
|
46
50
|
ToolBody,
|
|
47
51
|
ToolCall,
|
|
@@ -52,6 +56,7 @@ from model_library.exceptions import (
|
|
|
52
56
|
ImmediateRetryException,
|
|
53
57
|
MaxOutputTokensExceededError,
|
|
54
58
|
ModelNoOutputError,
|
|
59
|
+
NoMatchingToolCallError,
|
|
55
60
|
)
|
|
56
61
|
from model_library.model_utils import get_reasoning_in_tag
|
|
57
62
|
from model_library.register_models import register_provider
|
|
@@ -257,7 +262,9 @@ class OpenAIModel(LLM):
|
|
|
257
262
|
use_completions: bool = False,
|
|
258
263
|
):
|
|
259
264
|
super().__init__(model_name, provider, config=config)
|
|
260
|
-
self.use_completions: bool =
|
|
265
|
+
self.use_completions: bool = (
|
|
266
|
+
use_completions # TODO: do completions in a separate file
|
|
267
|
+
)
|
|
261
268
|
self.deep_research = self.provider_config.deep_research
|
|
262
269
|
|
|
263
270
|
# allow custom client to act as delegate (native)
|
|
@@ -269,6 +276,29 @@ class OpenAIModel(LLM):
|
|
|
269
276
|
OpenAIBatchMixin(self) if self.supports_batch else None
|
|
270
277
|
)
|
|
271
278
|
|
|
279
|
+
async def get_tool_call_ids(self, input: Sequence[InputItem]) -> list[str]:
|
|
280
|
+
raw_responses = [x for x in input if isinstance(x, RawResponse)]
|
|
281
|
+
tool_call_ids: list[str] = []
|
|
282
|
+
|
|
283
|
+
if self.use_completions:
|
|
284
|
+
calls = [
|
|
285
|
+
y
|
|
286
|
+
for x in raw_responses
|
|
287
|
+
if isinstance(x.response, ChatCompletionMessage)
|
|
288
|
+
and x.response.tool_calls
|
|
289
|
+
for y in x.response.tool_calls
|
|
290
|
+
]
|
|
291
|
+
tool_call_ids.extend([x.id for x in calls if x.id])
|
|
292
|
+
else:
|
|
293
|
+
calls = [
|
|
294
|
+
y
|
|
295
|
+
for x in raw_responses
|
|
296
|
+
for y in x.response
|
|
297
|
+
if isinstance(y, ResponseFunctionToolCall)
|
|
298
|
+
]
|
|
299
|
+
tool_call_ids.extend([x.id for x in calls if x.id])
|
|
300
|
+
return tool_call_ids
|
|
301
|
+
|
|
272
302
|
@override
|
|
273
303
|
async def parse_input(
|
|
274
304
|
self,
|
|
@@ -276,63 +306,70 @@ class OpenAIModel(LLM):
|
|
|
276
306
|
**kwargs: Any,
|
|
277
307
|
) -> list[dict[str, Any] | Any]:
|
|
278
308
|
new_input: list[dict[str, Any] | Any] = []
|
|
309
|
+
|
|
279
310
|
content_user: list[dict[str, Any]] = []
|
|
311
|
+
|
|
312
|
+
def flush_content_user():
|
|
313
|
+
if content_user:
|
|
314
|
+
# NOTE: must make new object as we clear()
|
|
315
|
+
new_input.append({"role": "user", "content": content_user.copy()})
|
|
316
|
+
content_user.clear()
|
|
317
|
+
|
|
318
|
+
tool_call_ids = await self.get_tool_call_ids(input)
|
|
319
|
+
|
|
280
320
|
for item in input:
|
|
321
|
+
if isinstance(item, TextInput):
|
|
322
|
+
if self.use_completions:
|
|
323
|
+
text_key = "text"
|
|
324
|
+
else:
|
|
325
|
+
text_key = "input_text"
|
|
326
|
+
content_user.append({"type": text_key, "text": item.text})
|
|
327
|
+
continue
|
|
328
|
+
|
|
329
|
+
if isinstance(item, FileBase):
|
|
330
|
+
match item.type:
|
|
331
|
+
case "image":
|
|
332
|
+
parsed = await self.parse_image(item)
|
|
333
|
+
case "file":
|
|
334
|
+
parsed = await self.parse_file(item)
|
|
335
|
+
content_user.append(parsed)
|
|
336
|
+
continue
|
|
337
|
+
|
|
338
|
+
# non content user item
|
|
339
|
+
flush_content_user()
|
|
340
|
+
|
|
281
341
|
match item:
|
|
282
|
-
case
|
|
342
|
+
case ToolResult():
|
|
343
|
+
if item.tool_call.id not in tool_call_ids:
|
|
344
|
+
raise NoMatchingToolCallError()
|
|
345
|
+
|
|
283
346
|
if self.use_completions:
|
|
284
|
-
|
|
347
|
+
new_input.append(
|
|
348
|
+
{
|
|
349
|
+
"role": "tool",
|
|
350
|
+
"tool_call_id": item.tool_call.id,
|
|
351
|
+
"content": item.result,
|
|
352
|
+
}
|
|
353
|
+
)
|
|
285
354
|
else:
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
case
|
|
294
|
-
if
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
):
|
|
305
|
-
raise Exception(
|
|
306
|
-
"Tool call result provided with no matching tool call"
|
|
307
|
-
)
|
|
308
|
-
if self.use_completions:
|
|
309
|
-
new_input.append(
|
|
310
|
-
{
|
|
311
|
-
"role": "tool",
|
|
312
|
-
"tool_call_id": item.tool_call.id,
|
|
313
|
-
"content": item.result,
|
|
314
|
-
}
|
|
315
|
-
)
|
|
316
|
-
else:
|
|
317
|
-
new_input.append(
|
|
318
|
-
{
|
|
319
|
-
"type": "function_call_output",
|
|
320
|
-
"call_id": item.tool_call.call_id,
|
|
321
|
-
"output": item.result,
|
|
322
|
-
}
|
|
323
|
-
)
|
|
324
|
-
case dict(): # RawInputItem
|
|
325
|
-
item = cast(RawInputItem, item)
|
|
326
|
-
new_input.append(item)
|
|
327
|
-
case _: # RawResponse
|
|
328
|
-
if self.use_completions:
|
|
329
|
-
item = cast(ChatCompletionMessageToolCall, item)
|
|
330
|
-
else:
|
|
331
|
-
item = cast(ResponseOutputItem, item)
|
|
332
|
-
new_input.append(item)
|
|
333
|
-
|
|
334
|
-
if content_user:
|
|
335
|
-
new_input.append({"role": "user", "content": content_user})
|
|
355
|
+
new_input.append(
|
|
356
|
+
{
|
|
357
|
+
"type": "function_call_output",
|
|
358
|
+
"call_id": item.tool_call.call_id,
|
|
359
|
+
"output": item.result,
|
|
360
|
+
}
|
|
361
|
+
)
|
|
362
|
+
case RawResponse():
|
|
363
|
+
if self.use_completions:
|
|
364
|
+
pass
|
|
365
|
+
new_input.append(item.response)
|
|
366
|
+
else:
|
|
367
|
+
new_input.extend(item.response)
|
|
368
|
+
case RawInput():
|
|
369
|
+
new_input.append(item.input)
|
|
370
|
+
|
|
371
|
+
# in case content user item is the last item
|
|
372
|
+
flush_content_user()
|
|
336
373
|
|
|
337
374
|
return new_input
|
|
338
375
|
|
|
@@ -468,19 +505,13 @@ class OpenAIModel(LLM):
|
|
|
468
505
|
file_id=response.id,
|
|
469
506
|
)
|
|
470
507
|
|
|
471
|
-
async def
|
|
508
|
+
async def _build_body_completions(
|
|
472
509
|
self,
|
|
473
510
|
input: Sequence[InputItem],
|
|
474
511
|
*,
|
|
475
512
|
tools: list[ToolDefinition],
|
|
476
513
|
**kwargs: object,
|
|
477
|
-
) ->
|
|
478
|
-
"""
|
|
479
|
-
Completions endpoint
|
|
480
|
-
Generally not used for openai models
|
|
481
|
-
Used by some providers using openai as a delegate
|
|
482
|
-
"""
|
|
483
|
-
|
|
514
|
+
) -> dict[str, Any]:
|
|
484
515
|
parsed_input: list[dict[str, Any] | ChatCompletionMessage] = []
|
|
485
516
|
if "system_prompt" in kwargs:
|
|
486
517
|
parsed_input.append(
|
|
@@ -505,8 +536,11 @@ class OpenAIModel(LLM):
|
|
|
505
536
|
if self.reasoning:
|
|
506
537
|
del body["max_tokens"]
|
|
507
538
|
body["max_completion_tokens"] = self.max_tokens
|
|
508
|
-
|
|
509
|
-
|
|
539
|
+
|
|
540
|
+
# some model endpoints (like `fireworks/deepseek-v3p2`)
|
|
541
|
+
# require explicitly setting reasoning effort to disable thinking
|
|
542
|
+
if self.reasoning_effort is not None:
|
|
543
|
+
body["reasoning_effort"] = self.reasoning_effort
|
|
510
544
|
|
|
511
545
|
if self.supports_temperature:
|
|
512
546
|
if self.temperature is not None:
|
|
@@ -516,6 +550,23 @@ class OpenAIModel(LLM):
|
|
|
516
550
|
|
|
517
551
|
body.update(kwargs)
|
|
518
552
|
|
|
553
|
+
return body
|
|
554
|
+
|
|
555
|
+
async def _query_completions(
|
|
556
|
+
self,
|
|
557
|
+
input: Sequence[InputItem],
|
|
558
|
+
*,
|
|
559
|
+
tools: list[ToolDefinition],
|
|
560
|
+
**kwargs: object,
|
|
561
|
+
) -> QueryResult:
|
|
562
|
+
"""
|
|
563
|
+
Completions endpoint
|
|
564
|
+
Generally not used for openai models
|
|
565
|
+
Used by providers using openai as a delegate
|
|
566
|
+
"""
|
|
567
|
+
|
|
568
|
+
body = await self.build_body(input, tools=tools, **kwargs)
|
|
569
|
+
|
|
519
570
|
output_text: str = ""
|
|
520
571
|
reasoning_text: str = ""
|
|
521
572
|
metadata: QueryResultMetadata = QueryResultMetadata()
|
|
@@ -628,7 +679,7 @@ class OpenAIModel(LLM):
|
|
|
628
679
|
output_text=output_text,
|
|
629
680
|
reasoning=reasoning_text,
|
|
630
681
|
tool_calls=tool_calls,
|
|
631
|
-
history=[*input, final_message],
|
|
682
|
+
history=[*input, RawResponse(response=final_message)],
|
|
632
683
|
metadata=metadata,
|
|
633
684
|
)
|
|
634
685
|
|
|
@@ -663,13 +714,17 @@ class OpenAIModel(LLM):
|
|
|
663
714
|
if not valid:
|
|
664
715
|
raise Exception("Deep research models require web search tools")
|
|
665
716
|
|
|
717
|
+
@override
|
|
666
718
|
async def build_body(
|
|
667
719
|
self,
|
|
668
720
|
input: Sequence[InputItem],
|
|
669
721
|
*,
|
|
670
|
-
tools:
|
|
722
|
+
tools: list[ToolDefinition],
|
|
671
723
|
**kwargs: object,
|
|
672
724
|
) -> dict[str, Any]:
|
|
725
|
+
if self.use_completions:
|
|
726
|
+
return await self._build_body_completions(input, tools=tools, **kwargs)
|
|
727
|
+
|
|
673
728
|
if self.deep_research:
|
|
674
729
|
await self._check_deep_research_args(tools, **kwargs)
|
|
675
730
|
|
|
@@ -701,8 +756,8 @@ class OpenAIModel(LLM):
|
|
|
701
756
|
|
|
702
757
|
if self.reasoning:
|
|
703
758
|
body["reasoning"] = {"summary": "auto"}
|
|
704
|
-
if self.reasoning_effort:
|
|
705
|
-
body["reasoning"]["effort"] = self.reasoning_effort
|
|
759
|
+
if self.reasoning_effort is not None:
|
|
760
|
+
body["reasoning"]["effort"] = self.reasoning_effort # type: ignore[reportArgumentType]
|
|
706
761
|
|
|
707
762
|
if self.supports_temperature:
|
|
708
763
|
if self.temperature is not None:
|
|
@@ -713,7 +768,6 @@ class OpenAIModel(LLM):
|
|
|
713
768
|
_ = kwargs.pop("stream", None)
|
|
714
769
|
|
|
715
770
|
body.update(kwargs)
|
|
716
|
-
|
|
717
771
|
return body
|
|
718
772
|
|
|
719
773
|
@override
|
|
@@ -722,6 +776,7 @@ class OpenAIModel(LLM):
|
|
|
722
776
|
input: Sequence[InputItem],
|
|
723
777
|
*,
|
|
724
778
|
tools: list[ToolDefinition],
|
|
779
|
+
query_logger: logging.Logger,
|
|
725
780
|
**kwargs: object,
|
|
726
781
|
) -> QueryResult:
|
|
727
782
|
if self.use_completions:
|
|
@@ -780,13 +835,12 @@ class OpenAIModel(LLM):
|
|
|
780
835
|
citations: list[Citation] = []
|
|
781
836
|
reasoning = None
|
|
782
837
|
for output in response.output:
|
|
783
|
-
if
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
citations.append(Citation(**citation.model_dump()))
|
|
838
|
+
if output.type == "message":
|
|
839
|
+
for content in output.content:
|
|
840
|
+
if not isinstance(content, ResponseOutputText):
|
|
841
|
+
continue
|
|
842
|
+
for citation in content.annotations:
|
|
843
|
+
citations.append(Citation(**citation.model_dump()))
|
|
790
844
|
|
|
791
845
|
if output.type == "reasoning":
|
|
792
846
|
reasoning = " ".join([i.text for i in output.summary])
|
|
@@ -809,7 +863,7 @@ class OpenAIModel(LLM):
|
|
|
809
863
|
output_text=response.output_text,
|
|
810
864
|
reasoning=reasoning,
|
|
811
865
|
tool_calls=tool_calls,
|
|
812
|
-
history=[*input,
|
|
866
|
+
history=[*input, RawResponse(response=response.output)],
|
|
813
867
|
extras=QueryResultExtras(citations=citations),
|
|
814
868
|
)
|
|
815
869
|
if response.usage:
|
model_library/providers/vals.py
CHANGED
|
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import io
|
|
4
4
|
import json
|
|
5
|
+
import logging
|
|
5
6
|
import random
|
|
6
7
|
import re
|
|
7
8
|
import time
|
|
@@ -50,7 +51,7 @@ class DummyAIBatchMixin(LLMBatchMixin):
|
|
|
50
51
|
"custom_id": custom_id,
|
|
51
52
|
"method": "",
|
|
52
53
|
"url": "",
|
|
53
|
-
"body": await self._root.
|
|
54
|
+
"body": await self._root.build_body(input, tools=[], **kwargs),
|
|
54
55
|
}
|
|
55
56
|
|
|
56
57
|
@override
|
|
@@ -226,7 +227,8 @@ class DummyAIModel(LLM):
|
|
|
226
227
|
) -> FileWithId:
|
|
227
228
|
raise NotImplementedError()
|
|
228
229
|
|
|
229
|
-
|
|
230
|
+
@override
|
|
231
|
+
async def build_body(
|
|
230
232
|
self,
|
|
231
233
|
input: Sequence[InputItem],
|
|
232
234
|
*,
|
|
@@ -271,9 +273,10 @@ class DummyAIModel(LLM):
|
|
|
271
273
|
input: Sequence[InputItem],
|
|
272
274
|
*,
|
|
273
275
|
tools: list[ToolDefinition],
|
|
276
|
+
query_logger: logging.Logger,
|
|
274
277
|
**kwargs: object,
|
|
275
278
|
) -> QueryResult:
|
|
276
|
-
body = await self.
|
|
279
|
+
body = await self.build_body(input, tools=tools, **kwargs)
|
|
277
280
|
|
|
278
281
|
fail_rate = FAIL_RATE
|
|
279
282
|
|