model-library 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- model_library/base/__init__.py +7 -0
- model_library/{base.py → base/base.py} +58 -429
- model_library/base/batch.py +121 -0
- model_library/base/delegate_only.py +94 -0
- model_library/base/input.py +100 -0
- model_library/base/output.py +229 -0
- model_library/base/utils.py +43 -0
- model_library/config/ai21labs_models.yaml +1 -0
- model_library/config/all_models.json +461 -36
- model_library/config/anthropic_models.yaml +30 -3
- model_library/config/deepseek_models.yaml +3 -1
- model_library/config/google_models.yaml +49 -0
- model_library/config/openai_models.yaml +43 -4
- model_library/config/together_models.yaml +1 -0
- model_library/config/xai_models.yaml +63 -3
- model_library/exceptions.py +8 -2
- model_library/file_utils.py +1 -1
- model_library/providers/__init__.py +0 -0
- model_library/providers/ai21labs.py +2 -0
- model_library/providers/alibaba.py +16 -78
- model_library/providers/amazon.py +3 -0
- model_library/providers/anthropic.py +215 -8
- model_library/providers/azure.py +2 -0
- model_library/providers/cohere.py +14 -80
- model_library/providers/deepseek.py +14 -90
- model_library/providers/fireworks.py +17 -81
- model_library/providers/google/google.py +55 -47
- model_library/providers/inception.py +15 -83
- model_library/providers/kimi.py +15 -83
- model_library/providers/mistral.py +2 -0
- model_library/providers/openai.py +10 -2
- model_library/providers/perplexity.py +12 -79
- model_library/providers/together.py +19 -210
- model_library/providers/vals.py +2 -0
- model_library/providers/xai.py +2 -0
- model_library/providers/zai.py +15 -83
- model_library/register_models.py +75 -57
- model_library/registry_utils.py +5 -5
- model_library/utils.py +3 -28
- {model_library-0.1.1.dist-info → model_library-0.1.3.dist-info}/METADATA +2 -3
- model_library-0.1.3.dist-info/RECORD +61 -0
- model_library-0.1.1.dist-info/RECORD +0 -54
- {model_library-0.1.1.dist-info → model_library-0.1.3.dist-info}/WHEEL +0 -0
- {model_library-0.1.1.dist-info → model_library-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {model_library-0.1.1.dist-info → model_library-0.1.3.dist-info}/top_level.txt +0 -0

model_library/providers/anthropic.py
CHANGED
@@ -5,14 +5,18 @@ from anthropic import AsyncAnthropic
 from anthropic.types import TextBlock, ToolUseBlock
 from anthropic.types.beta.beta_tool_use_block import BetaToolUseBlock
 from anthropic.types.message import Message
+from typing_extensions import override
+
 from model_library import model_library_settings
 from model_library.base import (
     LLM,
+    BatchResult,
     FileInput,
     FileWithBase64,
     FileWithId,
     FileWithUrl,
     InputItem,
+    LLMBatchMixin,
     LLMConfig,
     QueryResult,
     QueryResultCost,
@@ -29,15 +33,217 @@ from model_library.exceptions import (
 )
 from model_library.model_utils import get_default_budget_tokens
 from model_library.providers.openai import OpenAIModel
+from model_library.register_models import register_provider
 from model_library.utils import (
     create_openai_client_with_defaults,
     default_httpx_client,
     filter_empty_text_blocks,
     normalize_tool_result,
 )
-from typing_extensions import override
 
 
+class AnthropicBatchMixin(LLMBatchMixin):
+    """Batch processing support for Anthropic's Message Batches API."""
+
+    COMPLETED_RESULT_TYPES = ["succeeded", "errored", "canceled", "expired"]
+
+    def __init__(self, model: "AnthropicModel"):
+        self._root = model
+
+    @override
+    async def create_batch_query_request(
+        self,
+        custom_id: str,
+        input: Sequence[InputItem],
+        **kwargs: object,
+    ) -> dict[str, Any]:
+        """Create a single batch request in Anthropic's format.
+
+        Format: {"custom_id": str, "params": {...message params...}}
+        """
+        # Build the message body using the parent model's create_body method
+        tools = cast(list[ToolDefinition], kwargs.pop("tools", []))
+        body = await self._root.create_body(input, tools=tools, **kwargs)
+
+        return {
+            "custom_id": custom_id,
+            "params": body,
+        }
+
+    @override
+    async def batch_query(
+        self,
+        batch_name: str,
+        requests: list[dict[str, Any]],
+    ) -> str:
+        """Submit a batch of requests to Anthropic's Message Batches API.
+
+        Returns the batch ID for status tracking.
+        """
+        client = self._root.get_client()
+
+        # Create the batch using Anthropic's batches API
+        batch = await client.messages.batches.create(
+            requests=cast(Any, requests),  # Type mismatch in SDK, cast to Any
+        )
+
+        self._root.logger.info(
+            f"Created Anthropic batch {batch.id} with {len(requests)} requests"
+        )
+
+        return batch.id
+
+    @override
+    async def get_batch_results(self, batch_id: str) -> list[BatchResult]:
+        """Retrieve results from a completed batch.
+
+        Streams results using the SDK's batches.results() method.
+        """
+        client = self._root.get_client()
+
+        # Get batch status to verify it's completed
+        batch = await client.messages.batches.retrieve(batch_id)
+
+        if batch.processing_status != "ended":
+            raise ValueError(
+                f"Batch {batch_id} is not completed yet. Status: {batch.processing_status}"
+            )
+
+        # Stream results using the SDK's results method
+        batch_results: list[BatchResult] = []
+        async for result_item in await client.messages.batches.results(batch_id):
+            # result_item is a MessageBatchIndividualResponse - convert to dict
+            result_dict = result_item.model_dump()
+            custom_id = cast(str, result_dict["custom_id"])
+            result_type = cast(str, result_dict["result"]["type"])
+
+            if result_type not in self.COMPLETED_RESULT_TYPES:
+                self._root.logger.warning(
+                    f"Unknown result type '{result_type}' for request {custom_id}"
+                )
+                continue
+
+            if result_type == "succeeded":
+                # Extract the message from the successful result
+                message_data = cast(dict[str, Any], result_dict["result"]["message"])
+
+                # Parse the message content to extract text, reasoning, and tool calls
+                text = ""
+                reasoning = ""
+                tool_calls: list[ToolCall] = []
+
+                for content in message_data.get("content", []):
+                    if content.get("type") == "text":
+                        text += content.get("text", "")
+                    elif content.get("type") == "thinking":
+                        reasoning += content.get("thinking", "")
+                    elif content.get("type") == "tool_use":
+                        tool_calls.append(
+                            ToolCall(
+                                id=content["id"],
+                                name=content["name"],
+                                args=content.get("input", {}),
+                            )
+                        )
+
+                # Extract usage information
+                usage = message_data.get("usage", {})
+                metadata = QueryResultMetadata(
+                    in_tokens=usage.get("input_tokens", 0),
+                    out_tokens=usage.get("output_tokens", 0),
+                    cache_read_tokens=usage.get("cache_read_input_tokens", 0),
+                    cache_write_tokens=usage.get("cache_creation_input_tokens", 0),
+                )
+
+                query_result = QueryResult(
+                    output_text=text,
+                    reasoning=reasoning,
+                    metadata=metadata,
+                    tool_calls=tool_calls,
+                    history=[],  # History not available in batch results
+                )
+
+                batch_results.append(
+                    BatchResult(
+                        custom_id=custom_id,
+                        output=query_result,
+                    )
+                )
+
+            elif result_type == "errored":
+                # Handle errored results
+                error = cast(dict[str, Any], result_dict["result"]["error"])
+                error_message = f"{error.get('type', 'unknown_error')}: {error.get('message', 'Unknown error')}"
+                output = QueryResult(output_text=error_message)
+                batch_results.append(
+                    BatchResult(
+                        custom_id=custom_id,
+                        output=output,
+                        error_message=error_message,
+                    )
+                )
+
+            elif result_type in ["canceled", "expired"]:
+                # Handle canceled/expired results
+                error_message = f"Request {result_type}"
+                batch_results.append(
+                    BatchResult(
+                        custom_id=custom_id,
+                        output=QueryResult(output_text=""),
+                        error_message=error_message,
+                    )
+                )
+
+        return batch_results
+
+    @override
+    async def get_batch_progress(self, batch_id: str) -> int:
+        """Get the number of completed requests in a batch."""
+        client = self._root.get_client()
+        batch = await client.messages.batches.retrieve(batch_id)
+
+        # Return the number of processed requests
+        request_counts = batch.request_counts
+        return (
+            request_counts.succeeded
+            + request_counts.errored
+            + request_counts.canceled
+            + request_counts.expired
+        )
+
+    @override
+    async def cancel_batch_request(self, batch_id: str) -> None:
+        """Cancel a running batch request."""
+        client = self._root.get_client()
+        await client.messages.batches.cancel(batch_id)
+        self._root.logger.info(f"Canceled Anthropic batch {batch_id}")
+
+    @override
+    async def get_batch_status(self, batch_id: str) -> str:
+        """Get the current status of a batch."""
+        client = self._root.get_client()
+        batch = await client.messages.batches.retrieve(batch_id)
+        return batch.processing_status
+
+    @classmethod
+    def is_batch_status_completed(cls, batch_status: str) -> bool:
+        """Check if a batch status indicates completion."""
+        return batch_status == "ended"
+
+    @classmethod
+    def is_batch_status_failed(cls, batch_status: str) -> bool:
+        """Check if a batch status indicates failure."""
+        # Anthropic batches can have individual request failures but the batch
+        # itself doesn't have a "failed" status - it just ends
+        return False
+
+    @classmethod
+    def is_batch_status_cancelled(cls, batch_status: str) -> bool:
+        """Check if a batch status indicates cancellation."""
+        return batch_status == "canceling" or batch_status == "canceled"
+
+
+@register_provider("anthropic")
 class AnthropicModel(LLM):
     _client: AsyncAnthropic | None = None
 
@@ -78,6 +284,12 @@ class AnthropicModel(LLM):
             )
         )
 
+        # Initialize batch support if enabled
+        self.supports_batch: bool = self.supports_batch and self.native
+        self.batch: LLMBatchMixin | None = (
+            AnthropicBatchMixin(self) if self.supports_batch else None
+        )
+
     @override
     async def parse_input(
         self,
@@ -88,7 +300,6 @@ class AnthropicModel(LLM):
         content_user: list[dict[str, Any]] = []
 
         # First pass: collect all tool calls from Message objects for validation
-        # This handles both Message and BetaMessage types
         tool_calls_in_input: set[str] = set()
         for item in input:
             if hasattr(item, "content") and hasattr(item, "role"):
@@ -351,12 +562,8 @@ class AnthropicModel(LLM):
 
         body = await self.create_body(input, tools=tools, **kwargs)
 
-        betas = [
-
-            "interleaved-thinking-2025-05-14",
-        ]
-
-        if "claude-sonnet-4-5" in self.model_name:
+        betas = ["files-api-2025-04-14", "interleaved-thinking-2025-05-14"]
+        if "sonnet-4-5" in self.model_name:
             betas.append("context-1m-2025-08-07")
 
         async with self.get_client().beta.messages.stream(
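
Taken together, the new AnthropicBatchMixin gives AnthropicModel a build/submit/poll/fetch batch workflow. The sketch below is illustrative only and not part of the package: it assumes an already-constructed AnthropicModel named `model` whose `batch` attribute is populated (i.e. `supports_batch` is enabled); the helper name `run_batch`, the `prompts` mapping of custom IDs to input items, and the 30-second polling interval are all assumptions.

import asyncio

async def run_batch(model, prompts):
    # Build one Anthropic-format request per custom_id via the batch mixin
    requests = [
        await model.batch.create_batch_query_request(custom_id, items)
        for custom_id, items in prompts.items()
    ]

    # Submit the batch; Anthropic's batch id is returned for tracking
    batch_id = await model.batch.batch_query("example-batch", requests)

    # Poll until the batch reaches the "ended" processing status
    while not model.batch.is_batch_status_completed(
        await model.batch.get_batch_status(batch_id)
    ):
        await asyncio.sleep(30)

    # Map custom_id -> QueryResult; failed requests carry error_message instead
    results = await model.batch.get_batch_results(batch_id)
    return {result.custom_id: result.output for result in results}
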
model_library/providers/azure.py
CHANGED
@@ -8,9 +8,11 @@ from model_library.base import (
     LLMConfig,
 )
 from model_library.providers.openai import OpenAIModel
+from model_library.register_models import register_provider
 from model_library.utils import default_httpx_client
 
 
+@register_provider("azure")
 class AzureOpenAIModel(OpenAIModel):
     _azure_client: AsyncAzureOpenAI | None = None
 

model_library/providers/cohere.py
CHANGED
@@ -1,27 +1,17 @@
-import
-from typing import Any, Literal, Sequence
-
-from typing_extensions import override
+from typing import Literal
 
 from model_library import model_library_settings
 from model_library.base import (
-    LLM,
-    FileInput,
-    FileWithId,
-    InputItem,
+    DelegateOnly,
     LLMConfig,
-    QueryResult,
-    ToolDefinition,
 )
 from model_library.providers.openai import OpenAIModel
+from model_library.register_models import register_provider
 from model_library.utils import create_openai_client_with_defaults
 
 
-class CohereModel(LLM):
-
-    def get_client(self) -> None:
-        raise NotImplementedError("Not implemented")
-
+@register_provider("cohere")
+class CohereModel(DelegateOnly):
     def __init__(
         self,
         model_name: str,
@@ -30,71 +20,15 @@ class CohereModel(LLM):
         config: LLMConfig | None = None,
     ):
         super().__init__(model_name, provider, config=config)
-        self.native: bool = False
 
         # https://docs.cohere.com/docs/compatibility-api
-        self.delegate
-
-
-
-
-
-
-
-
-            base_url="https://api.cohere.ai/compatibility/v1",
-            ),
-            use_completions=True,
-        )
+        self.delegate = OpenAIModel(
+            model_name=self.model_name,
+            provider=self.provider,
+            config=config,
+            custom_client=create_openai_client_with_defaults(
+                api_key=model_library_settings.COHERE_API_KEY,
+                base_url="https://api.cohere.ai/compatibility/v1",
+            ),
+            use_completions=True,
         )
-
-    @override
-    async def parse_input(
-        self,
-        input: Sequence[InputItem],
-        **kwargs: Any,
-    ) -> Any:
-        raise NotImplementedError()
-
-    @override
-    async def parse_image(
-        self,
-        image: FileInput,
-    ) -> Any:
-        raise NotImplementedError()
-
-    @override
-    async def parse_file(
-        self,
-        file: FileInput,
-    ) -> Any:
-        raise NotImplementedError()
-
-    @override
-    async def parse_tools(
-        self,
-        tools: list[ToolDefinition],
-    ) -> Any:
-        raise NotImplementedError()
-
-    @override
-    async def upload_file(
-        self,
-        name: str,
-        mime: str,
-        bytes: io.BytesIO,
-        type: Literal["image", "file"] = "file",
-    ) -> FileWithId:
-        raise NotImplementedError()
-
-    @override
-    async def _query_impl(
-        self,
-        input: Sequence[InputItem],
-        *,
-        tools: list[ToolDefinition],
-        **kwargs: object,
-    ) -> QueryResult:
-        if self.delegate:
-            return await self.delegate_query(input, tools=tools, **kwargs)
-        raise NotImplementedError()
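
The DeepSeek and Fireworks changes below follow the same shape as CohereModel above: each provider now subclasses the new DelegateOnly base, drops its boilerplate parse/upload/query stubs, and only wires up an OpenAIModel delegate against its own OpenAI-compatible endpoint. As a rough sketch of that pattern (the provider name "exampleprovider", the EXAMPLE_API_KEY setting, and the base URL are placeholders, not part of the package), a hypothetical new provider would look roughly like this:

from model_library import model_library_settings
from model_library.base import DelegateOnly, LLMConfig
from model_library.providers.openai import OpenAIModel
from model_library.register_models import register_provider
from model_library.utils import create_openai_client_with_defaults


@register_provider("exampleprovider")  # placeholder provider name
class ExampleProviderModel(DelegateOnly):
    def __init__(
        self,
        model_name: str,
        provider: str,
        config: LLMConfig | None = None,
    ):
        super().__init__(model_name, provider, config=config)

        # Parsing and querying are inherited from DelegateOnly and forwarded
        # to an OpenAI-compatible client; only the endpoint and key differ.
        # EXAMPLE_API_KEY and the base_url below are assumed placeholders.
        self.delegate = OpenAIModel(
            model_name=self.model_name,
            provider=self.provider,
            config=config,
            custom_client=create_openai_client_with_defaults(
                api_key=model_library_settings.EXAMPLE_API_KEY,
                base_url="https://api.example.com/v1",
            ),
            use_completions=True,
        )
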

model_library/providers/deepseek.py
CHANGED
@@ -3,31 +3,20 @@ See deepseek data retention policy
 https://cdn.deepseek.com/policies/en-US/deepseek-privacy-policy.html
 """
 
-import
-from typing import Any, Literal, Sequence
-
-from typing_extensions import override
+from typing import Literal
 
 from model_library import model_library_settings
 from model_library.base import (
-    LLM,
-    FileInput,
-    FileWithId,
-    InputItem,
+    DelegateOnly,
     LLMConfig,
-    QueryResult,
-    ToolDefinition,
 )
-from model_library.exceptions import ToolCallingNotSupportedError
 from model_library.providers.openai import OpenAIModel
+from model_library.register_models import register_provider
 from model_library.utils import create_openai_client_with_defaults
 
 
-class DeepSeekModel(LLM):
-
-    def get_client(self) -> None:
-        raise NotImplementedError("Not implemented")
-
+@register_provider("deepseek")
+class DeepSeekModel(DelegateOnly):
     def __init__(
         self,
         model_name: str,
@@ -36,80 +25,15 @@ class DeepSeekModel(LLM)
         config: LLMConfig | None = None,
     ):
         super().__init__(model_name, provider, config=config)
-        self.model_name: str = model_name
-        self.native: bool = False
 
         # https://api-docs.deepseek.com/
-        self.delegate
-
-
-
-
-
-
-
-
-            base_url="https://api.deepseek.com",
-            ),
-            use_completions=True,
-        )
+        self.delegate = OpenAIModel(
+            model_name=self.model_name,
+            provider=self.provider,
+            config=config,
+            custom_client=create_openai_client_with_defaults(
+                api_key=model_library_settings.DEEPSEEK_API_KEY,
+                base_url="https://api.deepseek.com",
+            ),
+            use_completions=True,
         )
-
-    @override
-    async def parse_input(
-        self,
-        input: Sequence[InputItem],
-        **kwargs: Any,
-    ) -> Any:
-        raise NotImplementedError()
-
-    @override
-    async def parse_image(
-        self,
-        image: FileInput,
-    ) -> Any:
-        raise NotImplementedError()
-
-    @override
-    async def parse_file(
-        self,
-        file: FileInput,
-    ) -> Any:
-        raise NotImplementedError()
-
-    @override
-    async def parse_tools(
-        self,
-        tools: list[ToolDefinition],
-    ) -> Any:
-        raise NotImplementedError()
-
-    @override
-    async def upload_file(
-        self,
-        name: str,
-        mime: str,
-        bytes: io.BytesIO,
-        type: Literal["image", "file"] = "file",
-    ) -> FileWithId:
-        raise NotImplementedError()
-
-    @override
-    async def _query_impl(
-        self,
-        input: Sequence[InputItem],
-        *,
-        tools: list[ToolDefinition],
-        **kwargs: object,
-    ) -> QueryResult:
-        # DeepSeek reasoning models don't support tools - they auto-route to chat model
-        # which loses reasoning capability
-        if tools and not self.supports_tools:
-            raise ToolCallingNotSupportedError(
-                f"DeepSeek model ({self.model_name}) does not support tools. "
-                f"Use deepseek/deepseek-chat for tool calls."
-            )
-
-        if self.delegate:
-            return await self.delegate_query(input, tools=tools, **kwargs)
-        raise NotImplementedError()

model_library/providers/fireworks.py
CHANGED
@@ -1,22 +1,17 @@
-import
-from typing import Any, Literal, Sequence
+from typing import Literal
 
 from typing_extensions import override
 
 from model_library import model_library_settings
 from model_library.base import (
-    LLM,
-    FileInput,
-    FileWithId,
-    InputItem,
     LLMConfig,
     ProviderConfig,
-    QueryResult,
     QueryResultCost,
     QueryResultMetadata,
-    ToolDefinition,
 )
+from model_library.base.delegate_only import DelegateOnly
 from model_library.providers.openai import OpenAIModel
+from model_library.register_models import register_provider
 from model_library.utils import create_openai_client_with_defaults
 
 
@@ -24,13 +19,10 @@ class FireworksConfig(ProviderConfig):
     serverless: bool = True
 
 
-class FireworksModel(LLM):
+@register_provider("fireworks")
+class FireworksModel(DelegateOnly):
     provider_config = FireworksConfig()
 
-    @override
-    def get_client(self) -> None:
-        raise NotImplementedError("Not implemented")
-
     def __init__(
         self,
         model_name: str,
@@ -45,76 +37,18 @@ class FireworksModel(LLM):
         else:
             self.model_name = "accounts/rayan-936e28/deployedModels/" + self.model_name
 
-        # not using Fireworks SDK
-        self.native: bool = False
-
         # https://docs.fireworks.ai/tools-sdks/openai-compatibility
-        self.delegate
-
-
-
-
-
-
-
-
-            base_url="https://api.fireworks.ai/inference/v1",
-            ),
-            use_completions=True,
-        )
+        self.delegate = OpenAIModel(
+            model_name=self.model_name,
+            provider=self.provider,
+            config=config,
+            custom_client=create_openai_client_with_defaults(
+                api_key=model_library_settings.FIREWORKS_API_KEY,
+                base_url="https://api.fireworks.ai/inference/v1",
+            ),
+            use_completions=True,
         )
 
-    @override
-    async def parse_input(
-        self,
-        input: Sequence[InputItem],
-        **kwargs: Any,
-    ) -> Any:
-        raise NotImplementedError()
-
-    @override
-    async def parse_image(
-        self,
-        image: FileInput,
-    ) -> Any:
-        raise NotImplementedError()
-
-    @override
-    async def parse_file(
-        self,
-        file: FileInput,
-    ) -> Any:
-        raise NotImplementedError()
-
-    @override
-    async def parse_tools(
-        self,
-        tools: list[ToolDefinition],
-    ) -> Any:
-        raise NotImplementedError()
-
-    @override
-    async def upload_file(
-        self,
-        name: str,
-        mime: str,
-        bytes: io.BytesIO,
-        type: Literal["image", "file"] = "file",
-    ) -> FileWithId:
-        raise NotImplementedError()
-
-    @override
-    async def _query_impl(
-        self,
-        input: Sequence[InputItem],
-        *,
-        tools: list[ToolDefinition],
-        **kwargs: object,
-    ) -> QueryResult:
-        if self.delegate:
-            return await self.delegate_query(input, tools=tools, **kwargs)
-        raise NotImplementedError()
-
     @override
     async def _calculate_cost(
         self,
@@ -130,4 +64,6 @@ class FireworksModel(LLM):
         # https://docs.fireworks.ai/faq-new/billing-pricing/is-prompt-caching-billed-differently
         # prompt caching does not affect billing for serverless models
 
-        return await super()._calculate_cost(
+        return await super()._calculate_cost(
+            metadata, batch, bill_reasoning=bill_reasoning
+        )