model-library 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- model_library/__init__.py +7 -3
- model_library/base/__init__.py +7 -0
- model_library/{base.py → base/base.py} +47 -423
- model_library/base/batch.py +121 -0
- model_library/base/delegate_only.py +94 -0
- model_library/base/input.py +100 -0
- model_library/base/output.py +175 -0
- model_library/base/utils.py +42 -0
- model_library/config/all_models.json +164 -2
- model_library/config/anthropic_models.yaml +4 -0
- model_library/config/deepseek_models.yaml +3 -1
- model_library/config/openai_models.yaml +48 -0
- model_library/exceptions.py +2 -0
- model_library/logging.py +30 -0
- model_library/providers/__init__.py +0 -0
- model_library/providers/ai21labs.py +2 -0
- model_library/providers/alibaba.py +16 -78
- model_library/providers/amazon.py +3 -0
- model_library/providers/anthropic.py +213 -2
- model_library/providers/azure.py +2 -0
- model_library/providers/cohere.py +14 -80
- model_library/providers/deepseek.py +14 -90
- model_library/providers/fireworks.py +17 -81
- model_library/providers/google/google.py +22 -20
- model_library/providers/inception.py +15 -83
- model_library/providers/kimi.py +15 -83
- model_library/providers/mistral.py +2 -0
- model_library/providers/openai.py +2 -0
- model_library/providers/perplexity.py +12 -79
- model_library/providers/together.py +2 -0
- model_library/providers/vals.py +2 -0
- model_library/providers/xai.py +2 -0
- model_library/providers/zai.py +15 -83
- model_library/register_models.py +75 -55
- model_library/registry_utils.py +5 -5
- model_library/utils.py +3 -28
- {model_library-0.1.0.dist-info → model_library-0.1.2.dist-info}/METADATA +36 -7
- model_library-0.1.2.dist-info/RECORD +61 -0
- model_library-0.1.0.dist-info/RECORD +0 -53
- {model_library-0.1.0.dist-info → model_library-0.1.2.dist-info}/WHEEL +0 -0
- {model_library-0.1.0.dist-info → model_library-0.1.2.dist-info}/licenses/LICENSE +0 -0
- {model_library-0.1.0.dist-info → model_library-0.1.2.dist-info}/top_level.txt +0 -0
|
@@ -31,6 +31,34 @@ gpt-5-models:
|
|
|
31
31
|
supports_temperature: false
|
|
32
32
|
supports_files: true
|
|
33
33
|
supports_tools: true
|
|
34
|
+
|
|
35
|
+
openai/gpt-5.1-codex:
|
|
36
|
+
label: GPT 5.1 Codex
|
|
37
|
+
documentation_url: https://platform.openai.com/docs/models/gpt-5.1-codex
|
|
38
|
+
description: OpenAI's latest coding model
|
|
39
|
+
release_date: 2025-11-13
|
|
40
|
+
costs_per_million_token:
|
|
41
|
+
input: 1.25
|
|
42
|
+
output: 10.0
|
|
43
|
+
cache:
|
|
44
|
+
read: 0.125
|
|
45
|
+
default_parameters:
|
|
46
|
+
temperature: 1
|
|
47
|
+
max_output_tokens: 128_000
|
|
48
|
+
|
|
49
|
+
openai/gpt-5.1-codex-mini:
|
|
50
|
+
label: GPT 5.1 Codex Mini
|
|
51
|
+
documentation_url: https://platform.openai.com/docs/models/gpt-5.1-codex-mini
|
|
52
|
+
description: OpenAI's miniature coding model
|
|
53
|
+
release_date: 2025-11-13
|
|
54
|
+
costs_per_million_token:
|
|
55
|
+
input: 0.25
|
|
56
|
+
output: 2.00
|
|
57
|
+
cache:
|
|
58
|
+
read: 0.025
|
|
59
|
+
default_parameters:
|
|
60
|
+
temperature: 1
|
|
61
|
+
max_output_tokens: 128_000
|
|
34
62
|
|
|
35
63
|
openai/gpt-5-codex:
|
|
36
64
|
label: GPT 5 Codex
|
|
@@ -51,6 +79,26 @@ gpt-5-models:
|
|
|
51
79
|
temperature: 1
|
|
52
80
|
max_output_tokens: 128_000
|
|
53
81
|
|
|
82
|
+
|
|
83
|
+
openai/gpt-5.1-2025-11-13:
|
|
84
|
+
label: GPT 5.1
|
|
85
|
+
documentation_url: https://platform.openai.com/docs/models/gpt-5.1
|
|
86
|
+
description: GPT-5.1 is OpenAI's flagship model for coding and agentic tasks with configurable reasoning and non-reasoning effort.
|
|
87
|
+
release_date: 2025-11-13
|
|
88
|
+
costs_per_million_token:
|
|
89
|
+
input: 1.25
|
|
90
|
+
output: 10
|
|
91
|
+
cache:
|
|
92
|
+
read: 0.125
|
|
93
|
+
properties:
|
|
94
|
+
training_cutoff: "2024-09"
|
|
95
|
+
class_properties:
|
|
96
|
+
available_as_evaluator: true
|
|
97
|
+
supports_images: true
|
|
98
|
+
default_parameters:
|
|
99
|
+
temperature: 1
|
|
100
|
+
max_output_tokens: 128_000
|
|
101
|
+
|
|
54
102
|
openai/gpt-5-2025-08-07:
|
|
55
103
|
label: GPT 5
|
|
56
104
|
documentation_url: https://platform.openai.com/docs/models/gpt-5
|
model_library/exceptions.py
CHANGED
model_library/logging.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
from rich.console import Console
|
|
4
|
+
from rich.logging import RichHandler
|
|
5
|
+
|
|
6
|
+
_llm_logger = logging.getLogger("llm")
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def set_logging(enable: bool = True, handler: logging.Handler | None = None):
|
|
10
|
+
"""
|
|
11
|
+
Sets up logging for the model library
|
|
12
|
+
|
|
13
|
+
Args:
|
|
14
|
+
enable (bool): Enable or disable logging.
|
|
15
|
+
handler (logging.Handler, optional): A custom logging handler. Defaults to RichHandler.
|
|
16
|
+
"""
|
|
17
|
+
if enable:
|
|
18
|
+
_llm_logger.setLevel(logging.INFO)
|
|
19
|
+
else:
|
|
20
|
+
_llm_logger.setLevel(logging.CRITICAL)
|
|
21
|
+
|
|
22
|
+
if not enable or _llm_logger.hasHandlers():
|
|
23
|
+
return
|
|
24
|
+
|
|
25
|
+
if handler is None:
|
|
26
|
+
console = Console()
|
|
27
|
+
handler = RichHandler(console=console, markup=True, show_time=False)
|
|
28
|
+
|
|
29
|
+
handler.setFormatter(logging.Formatter("%(name)s - %(levelname)s - %(message)s"))
|
|
30
|
+
_llm_logger.addHandler(handler)
|
|
File without changes
|
|
@@ -26,9 +26,11 @@ from model_library.exceptions import (
|
|
|
26
26
|
MaxOutputTokensExceededError,
|
|
27
27
|
ModelNoOutputError,
|
|
28
28
|
)
|
|
29
|
+
from model_library.register_models import register_provider
|
|
29
30
|
from model_library.utils import default_httpx_client
|
|
30
31
|
|
|
31
32
|
|
|
33
|
+
@register_provider("ai21labs")
|
|
32
34
|
class AI21LabsModel(LLM):
|
|
33
35
|
_client: AsyncAI21Client | None = None
|
|
34
36
|
|
|
@@ -1,29 +1,21 @@
|
|
|
1
|
-
import
|
|
2
|
-
from typing import Any, Literal, Sequence
|
|
1
|
+
from typing import Literal
|
|
3
2
|
|
|
4
3
|
from typing_extensions import override
|
|
5
4
|
|
|
6
5
|
from model_library import model_library_settings
|
|
7
6
|
from model_library.base import (
|
|
8
|
-
|
|
9
|
-
FileInput,
|
|
10
|
-
FileWithId,
|
|
11
|
-
InputItem,
|
|
7
|
+
DelegateOnly,
|
|
12
8
|
LLMConfig,
|
|
13
|
-
QueryResult,
|
|
14
9
|
QueryResultCost,
|
|
15
10
|
QueryResultMetadata,
|
|
16
|
-
ToolDefinition,
|
|
17
11
|
)
|
|
18
12
|
from model_library.providers.openai import OpenAIModel
|
|
13
|
+
from model_library.register_models import register_provider
|
|
19
14
|
from model_library.utils import create_openai_client_with_defaults
|
|
20
15
|
|
|
21
16
|
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
def get_client(self) -> None:
|
|
25
|
-
raise NotImplementedError("Not implemented")
|
|
26
|
-
|
|
17
|
+
@register_provider("alibaba")
|
|
18
|
+
class AlibabaModel(DelegateOnly):
|
|
27
19
|
def __init__(
|
|
28
20
|
self,
|
|
29
21
|
model_name: str,
|
|
@@ -32,23 +24,20 @@ class AlibabaModel(LLM):
|
|
|
32
24
|
config: LLMConfig | None = None,
|
|
33
25
|
):
|
|
34
26
|
super().__init__(model_name, provider, config=config)
|
|
35
|
-
self.native: bool = False
|
|
36
27
|
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
),
|
|
48
|
-
use_completions=True,
|
|
49
|
-
)
|
|
28
|
+
# https://www.alibabacloud.com/help/en/model-studio/first-api-call-to-qwen
|
|
29
|
+
self.delegate = OpenAIModel(
|
|
30
|
+
model_name=self.model_name,
|
|
31
|
+
provider=self.provider,
|
|
32
|
+
config=config,
|
|
33
|
+
custom_client=create_openai_client_with_defaults(
|
|
34
|
+
api_key=model_library_settings.DASHSCOPE_API_KEY,
|
|
35
|
+
base_url="https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
|
|
36
|
+
),
|
|
37
|
+
use_completions=True,
|
|
50
38
|
)
|
|
51
39
|
|
|
40
|
+
@override
|
|
52
41
|
async def _calculate_cost(
|
|
53
42
|
self,
|
|
54
43
|
metadata: QueryResultMetadata,
|
|
@@ -94,54 +83,3 @@ class AlibabaModel(LLM):
|
|
|
94
83
|
else None,
|
|
95
84
|
cache_write=None,
|
|
96
85
|
)
|
|
97
|
-
|
|
98
|
-
@override
|
|
99
|
-
async def parse_input(
|
|
100
|
-
self,
|
|
101
|
-
input: Sequence[InputItem],
|
|
102
|
-
**kwargs: Any,
|
|
103
|
-
) -> Any:
|
|
104
|
-
raise NotImplementedError()
|
|
105
|
-
|
|
106
|
-
@override
|
|
107
|
-
async def parse_image(
|
|
108
|
-
self,
|
|
109
|
-
image: FileInput,
|
|
110
|
-
) -> Any:
|
|
111
|
-
raise NotImplementedError()
|
|
112
|
-
|
|
113
|
-
@override
|
|
114
|
-
async def parse_file(
|
|
115
|
-
self,
|
|
116
|
-
file: FileInput,
|
|
117
|
-
) -> Any:
|
|
118
|
-
raise NotImplementedError()
|
|
119
|
-
|
|
120
|
-
@override
|
|
121
|
-
async def parse_tools(
|
|
122
|
-
self,
|
|
123
|
-
tools: list[ToolDefinition],
|
|
124
|
-
) -> Any:
|
|
125
|
-
raise NotImplementedError()
|
|
126
|
-
|
|
127
|
-
@override
|
|
128
|
-
async def upload_file(
|
|
129
|
-
self,
|
|
130
|
-
name: str,
|
|
131
|
-
mime: str,
|
|
132
|
-
bytes: io.BytesIO,
|
|
133
|
-
type: Literal["image", "file"] = "file",
|
|
134
|
-
) -> FileWithId:
|
|
135
|
-
raise NotImplementedError()
|
|
136
|
-
|
|
137
|
-
@override
|
|
138
|
-
async def _query_impl(
|
|
139
|
-
self,
|
|
140
|
-
input: Sequence[InputItem],
|
|
141
|
-
*,
|
|
142
|
-
tools: list[ToolDefinition],
|
|
143
|
-
**kwargs: object,
|
|
144
|
-
) -> QueryResult:
|
|
145
|
-
if self.delegate:
|
|
146
|
-
return await self.delegate_query(input, tools=tools, **kwargs)
|
|
147
|
-
raise NotImplementedError()
|
|
@@ -31,8 +31,11 @@ from model_library.exceptions import (
|
|
|
31
31
|
MaxOutputTokensExceededError,
|
|
32
32
|
)
|
|
33
33
|
from model_library.model_utils import get_default_budget_tokens
|
|
34
|
+
from model_library.register_models import register_provider
|
|
34
35
|
|
|
35
36
|
|
|
37
|
+
@register_provider("amazon")
|
|
38
|
+
@register_provider("bedrock")
|
|
36
39
|
class AmazonModel(LLM):
|
|
37
40
|
_client: BaseClient | None = None
|
|
38
41
|
|
|
@@ -5,14 +5,18 @@ from anthropic import AsyncAnthropic
|
|
|
5
5
|
from anthropic.types import TextBlock, ToolUseBlock
|
|
6
6
|
from anthropic.types.beta.beta_tool_use_block import BetaToolUseBlock
|
|
7
7
|
from anthropic.types.message import Message
|
|
8
|
+
from typing_extensions import override
|
|
9
|
+
|
|
8
10
|
from model_library import model_library_settings
|
|
9
11
|
from model_library.base import (
|
|
10
12
|
LLM,
|
|
13
|
+
BatchResult,
|
|
11
14
|
FileInput,
|
|
12
15
|
FileWithBase64,
|
|
13
16
|
FileWithId,
|
|
14
17
|
FileWithUrl,
|
|
15
18
|
InputItem,
|
|
19
|
+
LLMBatchMixin,
|
|
16
20
|
LLMConfig,
|
|
17
21
|
QueryResult,
|
|
18
22
|
QueryResultCost,
|
|
@@ -29,15 +33,217 @@ from model_library.exceptions import (
|
|
|
29
33
|
)
|
|
30
34
|
from model_library.model_utils import get_default_budget_tokens
|
|
31
35
|
from model_library.providers.openai import OpenAIModel
|
|
36
|
+
from model_library.register_models import register_provider
|
|
32
37
|
from model_library.utils import (
|
|
33
38
|
create_openai_client_with_defaults,
|
|
34
39
|
default_httpx_client,
|
|
35
40
|
filter_empty_text_blocks,
|
|
36
41
|
normalize_tool_result,
|
|
37
42
|
)
|
|
38
|
-
from typing_extensions import override
|
|
39
43
|
|
|
40
44
|
|
|
45
|
+
class AnthropicBatchMixin(LLMBatchMixin):
|
|
46
|
+
"""Batch processing support for Anthropic's Message Batches API."""
|
|
47
|
+
|
|
48
|
+
COMPLETED_RESULT_TYPES = ["succeeded", "errored", "canceled", "expired"]
|
|
49
|
+
|
|
50
|
+
def __init__(self, model: "AnthropicModel"):
|
|
51
|
+
self._root = model
|
|
52
|
+
|
|
53
|
+
@override
|
|
54
|
+
async def create_batch_query_request(
|
|
55
|
+
self,
|
|
56
|
+
custom_id: str,
|
|
57
|
+
input: Sequence[InputItem],
|
|
58
|
+
**kwargs: object,
|
|
59
|
+
) -> dict[str, Any]:
|
|
60
|
+
"""Create a single batch request in Anthropic's format.
|
|
61
|
+
|
|
62
|
+
Format: {"custom_id": str, "params": {...message params...}}
|
|
63
|
+
"""
|
|
64
|
+
# Build the message body using the parent model's create_body method
|
|
65
|
+
tools = cast(list[ToolDefinition], kwargs.pop("tools", []))
|
|
66
|
+
body = await self._root.create_body(input, tools=tools, **kwargs)
|
|
67
|
+
|
|
68
|
+
return {
|
|
69
|
+
"custom_id": custom_id,
|
|
70
|
+
"params": body,
|
|
71
|
+
}
|
|
72
|
+
|
|
73
|
+
@override
|
|
74
|
+
async def batch_query(
|
|
75
|
+
self,
|
|
76
|
+
batch_name: str,
|
|
77
|
+
requests: list[dict[str, Any]],
|
|
78
|
+
) -> str:
|
|
79
|
+
"""Submit a batch of requests to Anthropic's Message Batches API.
|
|
80
|
+
|
|
81
|
+
Returns the batch ID for status tracking.
|
|
82
|
+
"""
|
|
83
|
+
client = self._root.get_client()
|
|
84
|
+
|
|
85
|
+
# Create the batch using Anthropic's batches API
|
|
86
|
+
batch = await client.messages.batches.create(
|
|
87
|
+
requests=cast(Any, requests), # Type mismatch in SDK, cast to Any
|
|
88
|
+
)
|
|
89
|
+
|
|
90
|
+
self._root.logger.info(
|
|
91
|
+
f"Created Anthropic batch {batch.id} with {len(requests)} requests"
|
|
92
|
+
)
|
|
93
|
+
|
|
94
|
+
return batch.id
|
|
95
|
+
|
|
96
|
+
@override
|
|
97
|
+
async def get_batch_results(self, batch_id: str) -> list[BatchResult]:
|
|
98
|
+
"""Retrieve results from a completed batch.
|
|
99
|
+
|
|
100
|
+
Streams results using the SDK's batches.results() method.
|
|
101
|
+
"""
|
|
102
|
+
client = self._root.get_client()
|
|
103
|
+
|
|
104
|
+
# Get batch status to verify it's completed
|
|
105
|
+
batch = await client.messages.batches.retrieve(batch_id)
|
|
106
|
+
|
|
107
|
+
if batch.processing_status != "ended":
|
|
108
|
+
raise ValueError(
|
|
109
|
+
f"Batch {batch_id} is not completed yet. Status: {batch.processing_status}"
|
|
110
|
+
)
|
|
111
|
+
|
|
112
|
+
# Stream results using the SDK's results method
|
|
113
|
+
batch_results: list[BatchResult] = []
|
|
114
|
+
async for result_item in await client.messages.batches.results(batch_id):
|
|
115
|
+
# result_item is a MessageBatchIndividualResponse - convert to dict
|
|
116
|
+
result_dict = result_item.model_dump()
|
|
117
|
+
custom_id = cast(str, result_dict["custom_id"])
|
|
118
|
+
result_type = cast(str, result_dict["result"]["type"])
|
|
119
|
+
|
|
120
|
+
if result_type not in self.COMPLETED_RESULT_TYPES:
|
|
121
|
+
self._root.logger.warning(
|
|
122
|
+
f"Unknown result type '{result_type}' for request {custom_id}"
|
|
123
|
+
)
|
|
124
|
+
continue
|
|
125
|
+
|
|
126
|
+
if result_type == "succeeded":
|
|
127
|
+
# Extract the message from the successful result
|
|
128
|
+
message_data = cast(dict[str, Any], result_dict["result"]["message"])
|
|
129
|
+
|
|
130
|
+
# Parse the message content to extract text, reasoning, and tool calls
|
|
131
|
+
text = ""
|
|
132
|
+
reasoning = ""
|
|
133
|
+
tool_calls: list[ToolCall] = []
|
|
134
|
+
|
|
135
|
+
for content in message_data.get("content", []):
|
|
136
|
+
if content.get("type") == "text":
|
|
137
|
+
text += content.get("text", "")
|
|
138
|
+
elif content.get("type") == "thinking":
|
|
139
|
+
reasoning += content.get("thinking", "")
|
|
140
|
+
elif content.get("type") == "tool_use":
|
|
141
|
+
tool_calls.append(
|
|
142
|
+
ToolCall(
|
|
143
|
+
id=content["id"],
|
|
144
|
+
name=content["name"],
|
|
145
|
+
args=content.get("input", {}),
|
|
146
|
+
)
|
|
147
|
+
)
|
|
148
|
+
|
|
149
|
+
# Extract usage information
|
|
150
|
+
usage = message_data.get("usage", {})
|
|
151
|
+
metadata = QueryResultMetadata(
|
|
152
|
+
in_tokens=usage.get("input_tokens", 0),
|
|
153
|
+
out_tokens=usage.get("output_tokens", 0),
|
|
154
|
+
cache_read_tokens=usage.get("cache_read_input_tokens", 0),
|
|
155
|
+
cache_write_tokens=usage.get("cache_creation_input_tokens", 0),
|
|
156
|
+
)
|
|
157
|
+
|
|
158
|
+
query_result = QueryResult(
|
|
159
|
+
output_text=text,
|
|
160
|
+
reasoning=reasoning,
|
|
161
|
+
metadata=metadata,
|
|
162
|
+
tool_calls=tool_calls,
|
|
163
|
+
history=[], # History not available in batch results
|
|
164
|
+
)
|
|
165
|
+
|
|
166
|
+
batch_results.append(
|
|
167
|
+
BatchResult(
|
|
168
|
+
custom_id=custom_id,
|
|
169
|
+
output=query_result,
|
|
170
|
+
)
|
|
171
|
+
)
|
|
172
|
+
|
|
173
|
+
elif result_type == "errored":
|
|
174
|
+
# Handle errored results
|
|
175
|
+
error = cast(dict[str, Any], result_dict["result"]["error"])
|
|
176
|
+
error_message = f"{error.get('type', 'unknown_error')}: {error.get('message', 'Unknown error')}"
|
|
177
|
+
output = QueryResult(output_text=error_message)
|
|
178
|
+
batch_results.append(
|
|
179
|
+
BatchResult(
|
|
180
|
+
custom_id=custom_id,
|
|
181
|
+
output=output,
|
|
182
|
+
error_message=error_message,
|
|
183
|
+
)
|
|
184
|
+
)
|
|
185
|
+
|
|
186
|
+
elif result_type in ["canceled", "expired"]:
|
|
187
|
+
# Handle canceled/expired results
|
|
188
|
+
error_message = f"Request {result_type}"
|
|
189
|
+
batch_results.append(
|
|
190
|
+
BatchResult(
|
|
191
|
+
custom_id=custom_id,
|
|
192
|
+
output=QueryResult(output_text=""),
|
|
193
|
+
error_message=error_message,
|
|
194
|
+
)
|
|
195
|
+
)
|
|
196
|
+
|
|
197
|
+
return batch_results
|
|
198
|
+
|
|
199
|
+
@override
|
|
200
|
+
async def get_batch_progress(self, batch_id: str) -> int:
|
|
201
|
+
"""Get the number of completed requests in a batch."""
|
|
202
|
+
client = self._root.get_client()
|
|
203
|
+
batch = await client.messages.batches.retrieve(batch_id)
|
|
204
|
+
|
|
205
|
+
# Return the number of processed requests
|
|
206
|
+
request_counts = batch.request_counts
|
|
207
|
+
return (
|
|
208
|
+
request_counts.succeeded
|
|
209
|
+
+ request_counts.errored
|
|
210
|
+
+ request_counts.canceled
|
|
211
|
+
+ request_counts.expired
|
|
212
|
+
)
|
|
213
|
+
|
|
214
|
+
@override
|
|
215
|
+
async def cancel_batch_request(self, batch_id: str) -> None:
|
|
216
|
+
"""Cancel a running batch request."""
|
|
217
|
+
client = self._root.get_client()
|
|
218
|
+
await client.messages.batches.cancel(batch_id)
|
|
219
|
+
self._root.logger.info(f"Canceled Anthropic batch {batch_id}")
|
|
220
|
+
|
|
221
|
+
@override
|
|
222
|
+
async def get_batch_status(self, batch_id: str) -> str:
|
|
223
|
+
"""Get the current status of a batch."""
|
|
224
|
+
client = self._root.get_client()
|
|
225
|
+
batch = await client.messages.batches.retrieve(batch_id)
|
|
226
|
+
return batch.processing_status
|
|
227
|
+
|
|
228
|
+
@classmethod
|
|
229
|
+
def is_batch_status_completed(cls, batch_status: str) -> bool:
|
|
230
|
+
"""Check if a batch status indicates completion."""
|
|
231
|
+
return batch_status == "ended"
|
|
232
|
+
|
|
233
|
+
@classmethod
|
|
234
|
+
def is_batch_status_failed(cls, batch_status: str) -> bool:
|
|
235
|
+
"""Check if a batch status indicates failure."""
|
|
236
|
+
# Anthropic batches can have individual request failures but the batch
|
|
237
|
+
# itself doesn't have a "failed" status - it just ends
|
|
238
|
+
return False
|
|
239
|
+
|
|
240
|
+
@classmethod
|
|
241
|
+
def is_batch_status_cancelled(cls, batch_status: str) -> bool:
|
|
242
|
+
"""Check if a batch status indicates cancellation."""
|
|
243
|
+
return batch_status == "canceling" or batch_status == "canceled"
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
@register_provider("anthropic")
|
|
41
247
|
class AnthropicModel(LLM):
|
|
42
248
|
_client: AsyncAnthropic | None = None
|
|
43
249
|
|
|
@@ -78,6 +284,12 @@ class AnthropicModel(LLM):
|
|
|
78
284
|
)
|
|
79
285
|
)
|
|
80
286
|
|
|
287
|
+
# Initialize batch support if enabled
|
|
288
|
+
self.supports_batch: bool = self.supports_batch and self.native
|
|
289
|
+
self.batch: LLMBatchMixin | None = (
|
|
290
|
+
AnthropicBatchMixin(self) if self.supports_batch else None
|
|
291
|
+
)
|
|
292
|
+
|
|
81
293
|
@override
|
|
82
294
|
async def parse_input(
|
|
83
295
|
self,
|
|
@@ -88,7 +300,6 @@ class AnthropicModel(LLM):
|
|
|
88
300
|
content_user: list[dict[str, Any]] = []
|
|
89
301
|
|
|
90
302
|
# First pass: collect all tool calls from Message objects for validation
|
|
91
|
-
# This handles both Message and BetaMessage types
|
|
92
303
|
tool_calls_in_input: set[str] = set()
|
|
93
304
|
for item in input:
|
|
94
305
|
if hasattr(item, "content") and hasattr(item, "role"):
|
model_library/providers/azure.py
CHANGED
|
@@ -8,9 +8,11 @@ from model_library.base import (
|
|
|
8
8
|
LLMConfig,
|
|
9
9
|
)
|
|
10
10
|
from model_library.providers.openai import OpenAIModel
|
|
11
|
+
from model_library.register_models import register_provider
|
|
11
12
|
from model_library.utils import default_httpx_client
|
|
12
13
|
|
|
13
14
|
|
|
15
|
+
@register_provider("azure")
|
|
14
16
|
class AzureOpenAIModel(OpenAIModel):
|
|
15
17
|
_azure_client: AsyncAzureOpenAI | None = None
|
|
16
18
|
|
|
@@ -1,27 +1,17 @@
|
|
|
1
|
-
import
|
|
2
|
-
from typing import Any, Literal, Sequence
|
|
3
|
-
|
|
4
|
-
from typing_extensions import override
|
|
1
|
+
from typing import Literal
|
|
5
2
|
|
|
6
3
|
from model_library import model_library_settings
|
|
7
4
|
from model_library.base import (
|
|
8
|
-
|
|
9
|
-
FileInput,
|
|
10
|
-
FileWithId,
|
|
11
|
-
InputItem,
|
|
5
|
+
DelegateOnly,
|
|
12
6
|
LLMConfig,
|
|
13
|
-
QueryResult,
|
|
14
|
-
ToolDefinition,
|
|
15
7
|
)
|
|
16
8
|
from model_library.providers.openai import OpenAIModel
|
|
9
|
+
from model_library.register_models import register_provider
|
|
17
10
|
from model_library.utils import create_openai_client_with_defaults
|
|
18
11
|
|
|
19
12
|
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
def get_client(self) -> None:
|
|
23
|
-
raise NotImplementedError("Not implemented")
|
|
24
|
-
|
|
13
|
+
@register_provider("cohere")
|
|
14
|
+
class CohereModel(DelegateOnly):
|
|
25
15
|
def __init__(
|
|
26
16
|
self,
|
|
27
17
|
model_name: str,
|
|
@@ -30,71 +20,15 @@ class CohereModel(LLM):
|
|
|
30
20
|
config: LLMConfig | None = None,
|
|
31
21
|
):
|
|
32
22
|
super().__init__(model_name, provider, config=config)
|
|
33
|
-
self.native: bool = False
|
|
34
23
|
|
|
35
24
|
# https://docs.cohere.com/docs/compatibility-api
|
|
36
|
-
self.delegate
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
base_url="https://api.cohere.ai/compatibility/v1",
|
|
46
|
-
),
|
|
47
|
-
use_completions=True,
|
|
48
|
-
)
|
|
25
|
+
self.delegate = OpenAIModel(
|
|
26
|
+
model_name=self.model_name,
|
|
27
|
+
provider=self.provider,
|
|
28
|
+
config=config,
|
|
29
|
+
custom_client=create_openai_client_with_defaults(
|
|
30
|
+
api_key=model_library_settings.COHERE_API_KEY,
|
|
31
|
+
base_url="https://api.cohere.ai/compatibility/v1",
|
|
32
|
+
),
|
|
33
|
+
use_completions=True,
|
|
49
34
|
)
|
|
50
|
-
|
|
51
|
-
@override
|
|
52
|
-
async def parse_input(
|
|
53
|
-
self,
|
|
54
|
-
input: Sequence[InputItem],
|
|
55
|
-
**kwargs: Any,
|
|
56
|
-
) -> Any:
|
|
57
|
-
raise NotImplementedError()
|
|
58
|
-
|
|
59
|
-
@override
|
|
60
|
-
async def parse_image(
|
|
61
|
-
self,
|
|
62
|
-
image: FileInput,
|
|
63
|
-
) -> Any:
|
|
64
|
-
raise NotImplementedError()
|
|
65
|
-
|
|
66
|
-
@override
|
|
67
|
-
async def parse_file(
|
|
68
|
-
self,
|
|
69
|
-
file: FileInput,
|
|
70
|
-
) -> Any:
|
|
71
|
-
raise NotImplementedError()
|
|
72
|
-
|
|
73
|
-
@override
|
|
74
|
-
async def parse_tools(
|
|
75
|
-
self,
|
|
76
|
-
tools: list[ToolDefinition],
|
|
77
|
-
) -> Any:
|
|
78
|
-
raise NotImplementedError()
|
|
79
|
-
|
|
80
|
-
@override
|
|
81
|
-
async def upload_file(
|
|
82
|
-
self,
|
|
83
|
-
name: str,
|
|
84
|
-
mime: str,
|
|
85
|
-
bytes: io.BytesIO,
|
|
86
|
-
type: Literal["image", "file"] = "file",
|
|
87
|
-
) -> FileWithId:
|
|
88
|
-
raise NotImplementedError()
|
|
89
|
-
|
|
90
|
-
@override
|
|
91
|
-
async def _query_impl(
|
|
92
|
-
self,
|
|
93
|
-
input: Sequence[InputItem],
|
|
94
|
-
*,
|
|
95
|
-
tools: list[ToolDefinition],
|
|
96
|
-
**kwargs: object,
|
|
97
|
-
) -> QueryResult:
|
|
98
|
-
if self.delegate:
|
|
99
|
-
return await self.delegate_query(input, tools=tools, **kwargs)
|
|
100
|
-
raise NotImplementedError()
|