model-library 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- model_library/base/base.py +237 -62
- model_library/base/delegate_only.py +86 -9
- model_library/base/input.py +10 -7
- model_library/base/output.py +48 -0
- model_library/base/utils.py +56 -7
- model_library/config/alibaba_models.yaml +44 -57
- model_library/config/all_models.json +253 -126
- model_library/config/kimi_models.yaml +30 -3
- model_library/config/openai_models.yaml +15 -23
- model_library/config/zai_models.yaml +24 -3
- model_library/exceptions.py +14 -77
- model_library/logging.py +6 -2
- model_library/providers/ai21labs.py +30 -14
- model_library/providers/alibaba.py +17 -8
- model_library/providers/amazon.py +119 -64
- model_library/providers/anthropic.py +184 -104
- model_library/providers/azure.py +22 -10
- model_library/providers/cohere.py +7 -7
- model_library/providers/deepseek.py +8 -8
- model_library/providers/fireworks.py +7 -8
- model_library/providers/google/batch.py +17 -13
- model_library/providers/google/google.py +130 -73
- model_library/providers/inception.py +7 -7
- model_library/providers/kimi.py +18 -8
- model_library/providers/minimax.py +30 -13
- model_library/providers/mistral.py +61 -35
- model_library/providers/openai.py +219 -93
- model_library/providers/openrouter.py +34 -0
- model_library/providers/perplexity.py +7 -7
- model_library/providers/together.py +7 -8
- model_library/providers/vals.py +16 -9
- model_library/providers/xai.py +157 -144
- model_library/providers/zai.py +38 -8
- model_library/register_models.py +4 -2
- model_library/registry_utils.py +39 -15
- model_library/retriers/__init__.py +0 -0
- model_library/retriers/backoff.py +73 -0
- model_library/retriers/base.py +225 -0
- model_library/retriers/token.py +427 -0
- model_library/retriers/utils.py +11 -0
- model_library/settings.py +1 -1
- model_library/utils.py +13 -35
- {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/METADATA +4 -3
- model_library-0.1.8.dist-info/RECORD +70 -0
- {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/WHEEL +1 -1
- model_library-0.1.6.dist-info/RECORD +0 -64
- {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/licenses/LICENSE +0 -0
- {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/top_level.txt +0 -0
model_library/providers/anthropic.py
CHANGED

@@ -1,17 +1,21 @@
+import datetime
 import io
 import logging
+import time
 from typing import Any, Literal, Sequence, cast

-from anthropic import AsyncAnthropic
-from anthropic.types import TextBlock, ToolUseBlock
+from anthropic import APIConnectionError, AsyncAnthropic
 from anthropic.types.beta.beta_tool_use_block import BetaToolUseBlock
-from anthropic.types.
+from anthropic.types.beta.parsed_beta_message import ParsedBetaMessage
+from pydantic import SecretStr
 from typing_extensions import override

 from model_library import model_library_settings
 from model_library.base import (
     LLM,
     BatchResult,
+    DelegateConfig,
+    FileBase,
     FileInput,
     FileWithBase64,
     FileWithId,
@@ -22,7 +26,9 @@ from model_library.base import (
     QueryResult,
     QueryResultCost,
     QueryResultMetadata,
-
+    RateLimit,
+    RawInput,
+    RawResponse,
     TextInput,
     ToolBody,
     ToolCall,
@@ -30,16 +36,15 @@ from model_library.base import (
     ToolResult,
 )
 from model_library.exceptions import (
+    ImmediateRetryException,
     MaxOutputTokensExceededError,
+    NoMatchingToolCallError,
 )
 from model_library.model_utils import get_default_budget_tokens
 from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
 from model_library.utils import (
-
-    default_httpx_client,
-    filter_empty_text_blocks,
-    normalize_tool_result,
+    create_anthropic_client_with_defaults,
 )


@@ -62,9 +67,9 @@ class AnthropicBatchMixin(LLMBatchMixin):

         Format: {"custom_id": str, "params": {...message params...}}
         """
-        # Build the message body using the parent model's
+        # Build the message body using the parent model's build_body method
         tools = cast(list[ToolDefinition], kwargs.pop("tools", []))
-        body = await self._root.
+        body = await self._root.build_body(input, tools=tools, **kwargs)

         return {
             "custom_id": custom_id,
@@ -246,21 +251,25 @@ class AnthropicBatchMixin(LLMBatchMixin):

 @register_provider("anthropic")
 class AnthropicModel(LLM):
-
+    def _get_default_api_key(self) -> str:
+        if self.delegate_config:
+            return self.delegate_config.api_key.get_secret_value()
+        return model_library_settings.ANTHROPIC_API_KEY

     @override
-    def get_client(self) -> AsyncAnthropic:
-        if self.
-
-        if not AnthropicModel._client:
+    def get_client(self, api_key: str | None = None) -> AsyncAnthropic:
+        if not self.has_client():
+            assert api_key
             headers: dict[str, str] = {}
-
-
-
-
+            client = create_anthropic_client_with_defaults(
+                base_url=self.delegate_config.base_url
+                if self.delegate_config
+                else None,
+                api_key=api_key,
                 default_headers=headers,
             )
-
+            self.assign_client(client)
+        return super().get_client()

     def __init__(
         self,
@@ -268,38 +277,51 @@ class AnthropicModel(LLM):
         provider: str = "anthropic",
         *,
         config: LLMConfig | None = None,
-
+        delegate_config: DelegateConfig | None = None,
     ):
-
+        self.delegate_config = delegate_config

-
-        self._delegate_client: AsyncAnthropic | None = custom_client
+        super().__init__(model_name, provider, config=config)

         # https://docs.anthropic.com/en/api/openai-sdk
         self.delegate = (
             None
-            if self.native or
+            if self.native or self.delegate_config
             else OpenAIModel(
                 model_name=self.model_name,
-                provider=provider,
+                provider=self.provider,
                 config=config,
-
-
+                use_completions=True,
+                delegate_config=DelegateConfig(
                     base_url="https://api.anthropic.com/v1/",
+                    api_key=SecretStr(model_library_settings.ANTHROPIC_API_KEY),
                 ),
-                use_completions=True,
             )
         )

         # Initialize batch support if enabled
         # Disable batch when using custom_client (similar to OpenAI)
         self.supports_batch: bool = (
-            self.supports_batch and self.native and not
+            self.supports_batch and self.native and not self.delegate_config
         )
         self.batch: LLMBatchMixin | None = (
             AnthropicBatchMixin(self) if self.supports_batch else None
         )

+    async def get_tool_call_ids(self, input: Sequence[InputItem]) -> list[str]:
+        raw_responses = [x for x in input if isinstance(x, RawResponse)]
+        tool_call_ids: list[str] = []
+
+        calls = [
+            y
+            for x in raw_responses
+            if isinstance(x.response, ParsedBetaMessage)
+            for y in x.response.content  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
+            if isinstance(y, BetaToolUseBlock)
+        ]
+        tool_call_ids.extend([x.id for x in calls])
+        return tool_call_ids
+
     @override
     async def parse_input(
         self,
@@ -307,77 +329,61 @@ class AnthropicModel(LLM):
         **kwargs: Any,
     ) -> list[dict[str, Any] | Any]:
         new_input: list[dict[str, Any] | Any] = []
+
         content_user: list[dict[str, Any]] = []

-
-
-
-
-
-
-
-            if isinstance(content, (ToolUseBlock, BetaToolUseBlock)):
-                tool_calls_in_input.add(content.id)
+        def flush_content_user():
+            if content_user:
+                # NOTE: must make new object as we clear()
+                new_input.append({"role": "user", "content": content_user.copy()})
+                content_user.clear()
+
+        tool_call_ids = await self.get_tool_call_ids(input)

         for item in input:
+            if isinstance(item, TextInput):
+                content_user.append({"type": "text", "text": item.text})
+                continue
+
+            if isinstance(item, FileBase):
+                match item.type:
+                    case "image":
+                        parsed = await self.parse_image(item)
+                    case "file":
+                        parsed = await self.parse_file(item)
+                content_user.append(parsed)
+                continue
+
+            # non content user item
+            flush_content_user()
+
             match item:
-                case
-                    if item.
-
-
-
-
-
-
-                        content_user.append(await self.parse_file(item))
-                case _:
-                    if content_user:
-                        filtered = filter_empty_text_blocks(content_user)
-                        if filtered:
-                            new_input.append({"role": "user", "content": filtered})
-                        content_user = []
-                    match item:
-                        case ToolResult():
-                            if item.tool_call.id not in tool_calls_in_input:
-                                raise Exception(
-                                    "Tool call result provided with no matching tool call"
-                                )
-                            result_str = normalize_tool_result(item.result)
-                            new_input.append(
+                case ToolResult():
+                    if item.tool_call.id not in tool_call_ids:
+                        raise NoMatchingToolCallError()
+
+                    new_input.append(
+                        {
+                            "role": "user",
+                            "content": [
                                 {
-                                    "
-                                    "
-
-                                    "type": "tool_result",
-                                    "tool_use_id": item.tool_call.id,
-                                    "content": [
-                                        {"type": "text", "text": result_str}
-                                    ],
-                                }
-                            ],
+                                    "type": "tool_result",
+                                    "tool_use_id": item.tool_call.id,
+                                    "content": [{"type": "text", "text": item.result}],
                                 }
-
-
-
-
-
-
-
-
-
-
-
-                    ]
-                    if filtered_content:
-                        new_input.append(
-                            {"role": "assistant", "content": filtered_content}
-                        )
-
-        if content_user:
-            filtered = filter_empty_text_blocks(content_user)
-            if filtered:
-                new_input.append({"role": "user", "content": filtered})
+                            ],
+                        }
+                    )
+                case RawResponse():
+                    content = cast(ParsedBetaMessage, item.response).content
+                    new_input.append({"role": "assistant", "content": content})
+                case RawInput():
+                    new_input.append(item.input)
+
+        # in case content user item is the last item
+        flush_content_user()

+        # cache control
         if new_input:
             last_msg = new_input[-1]
             if not isinstance(last_msg, dict):
@@ -495,7 +501,7 @@ class AnthropicModel(LLM):
         bytes: io.BytesIO,
         type: Literal["image", "file"] = "file",
     ) -> FileWithId:
-        file_mime = f"image/{mime}" if type == "image" else mime
+        file_mime = f"image/{mime}" if type == "image" else mime
         response = await self.get_client().beta.files.upload(
             file=(
                 name,
@@ -513,7 +519,8 @@ class AnthropicModel(LLM):

     cache_control = {"type": "ephemeral"}  # 5 min cache

-
+    @override
+    async def build_body(
         self,
         input: Sequence[InputItem],
         *,
@@ -521,7 +528,6 @@ class AnthropicModel(LLM):
         **kwargs: object,
     ) -> dict[str, Any]:
         body: dict[str, Any] = {
-            "max_tokens": self.max_tokens,
             "model": self.model_name,
             "messages": await self.parse_input(input),
         }
@@ -535,6 +541,11 @@ class AnthropicModel(LLM):
            }
        ]

+        if not self.max_tokens:
+            raise Exception("Anthropic models require a max_tokens parameter")
+
+        body["max_tokens"] = self.max_tokens
+
         if self.reasoning:
             budget_tokens = kwargs.pop(
                 "budget_tokens", get_default_budget_tokens(self.max_tokens)
@@ -573,12 +584,12 @@ class AnthropicModel(LLM):
             input, tools=tools, query_logger=query_logger, **kwargs
         )

-        body = await self.
+        body = await self.build_body(input, tools=tools, **kwargs)

         client = self.get_client()

         # only send betas for the official Anthropic endpoint
-        is_anthropic_endpoint = self.
+        is_anthropic_endpoint = self.delegate_config is None
         if not is_anthropic_endpoint:
             client_base_url = getattr(client, "_base_url", None) or getattr(
                 client, "base_url", None
@@ -593,11 +604,14 @@ class AnthropicModel(LLM):
             betas.append("context-1m-2025-08-07")
         stream_kwargs["betas"] = betas

-
-
-
-
-
+        try:
+            async with client.beta.messages.stream(
+                **stream_kwargs,
+            ) as stream:  # pyright: ignore[reportAny]
+                message = await stream.get_final_message()
+                self.logger.info(f"Anthropic Response finished: {message.id}")
+        except APIConnectionError:
+            raise ImmediateRetryException("Failed to connect to Anthropic")

         text = ""
         reasoning = ""
@@ -630,9 +644,75 @@ class AnthropicModel(LLM):
                 cache_write_tokens=message.usage.cache_creation_input_tokens,
             ),
             tool_calls=tool_calls,
-            history=[*input, message],
+            history=[*input, RawResponse(response=message)],
         )

+    @override
+    async def get_rate_limit(self) -> RateLimit:
+        response = await self.get_client().messages.with_raw_response.create(
+            max_tokens=1,
+            messages=[
+                {
+                    "role": "user",
+                    "content": "Ping",
+                }
+            ],
+            model=self.model_name,
+        )
+        headers = response.headers
+
+        server_time_str = headers.get("date")
+        if server_time_str:
+            server_time = datetime.datetime.strptime(
+                server_time_str, "%a, %d %b %Y %H:%M:%S GMT"
+            ).replace(tzinfo=datetime.timezone.utc)
+            timestamp = server_time.timestamp()
+        else:
+            timestamp = time.time()
+
+        return RateLimit(
+            unix_timestamp=timestamp,
+            raw=headers,
+            request_limit=int(headers["anthropic-ratelimit-requests-limit"]),
+            request_remaining=int(headers["anthropic-ratelimit-requests-remaining"]),
+            token_limit=int(response.headers["anthropic-ratelimit-tokens-limit"]),
+            token_remaining=int(headers["anthropic-ratelimit-tokens-remaining"]),
+        )
+
+    @override
+    async def count_tokens(
+        self,
+        input: Sequence[InputItem],
+        *,
+        history: Sequence[InputItem] = [],
+        tools: list[ToolDefinition] = [],
+        **kwargs: object,
+    ) -> int:
+        """
+        Count the number of tokens using Anthropic's native token counting API.
+        https://docs.anthropic.com/en/docs/build-with-claude/token-counting
+        """
+        try:
+            input = [*history, *input]
+            if not input:
+                return 0
+
+            body = await self.build_body(input, tools=tools, **kwargs)
+
+            # Remove fields not supported by count_tokens endpoint
+            body.pop("max_tokens", None)
+            body.pop("temperature", None)
+
+            client = self.get_client()
+            response = await client.messages.count_tokens(**body)
+
+            return response.input_tokens
+        except Exception as e:
+            self.logger.error(f"Error counting tokens: {e}")
+            return await super().count_tokens(
+                input, history=history, tools=tools, **kwargs
+            )
+
     @override
     async def _calculate_cost(
         self,
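Note on the new get_rate_limit override above: it derives the snapshot timestamp from the server's RFC 7231 `date` header and only falls back to the local clock when that header is missing. A minimal standalone sketch of that parsing step, assuming `headers` is a plain dict rather than the SDK's raw-response object:

import datetime
import time


def parse_rate_limit_headers(headers: dict[str, str]) -> dict[str, float | int]:
    """Parse Anthropic-style rate-limit headers, preferring server time."""
    server_time_str = headers.get("date")
    if server_time_str:
        # RFC 7231 date, e.g. "Tue, 01 Jul 2025 12:00:00 GMT"
        server_time = datetime.datetime.strptime(
            server_time_str, "%a, %d %b %Y %H:%M:%S GMT"
        ).replace(tzinfo=datetime.timezone.utc)
        timestamp = server_time.timestamp()
    else:
        timestamp = time.time()  # fall back to the local clock

    return {
        "unix_timestamp": timestamp,
        "request_limit": int(headers["anthropic-ratelimit-requests-limit"]),
        "request_remaining": int(headers["anthropic-ratelimit-requests-remaining"]),
        "token_limit": int(headers["anthropic-ratelimit-tokens-limit"]),
        "token_remaining": int(headers["anthropic-ratelimit-tokens-remaining"]),
    }

Using server time keeps rate-limit windows consistent across machines whose clocks drift from Anthropic's.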
|
model_library/providers/azure.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import json
|
|
1
2
|
from typing import Literal
|
|
2
3
|
|
|
3
4
|
from openai.lib.azure import AsyncAzureOpenAI
|
|
@@ -14,21 +15,32 @@ from model_library.utils import default_httpx_client
|
|
|
14
15
|
|
|
15
16
|
@register_provider("azure")
|
|
16
17
|
class AzureOpenAIModel(OpenAIModel):
|
|
17
|
-
_azure_client: AsyncAzureOpenAI | None = None
|
|
18
|
-
|
|
19
18
|
@override
|
|
20
|
-
def
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
19
|
+
def _get_default_api_key(self) -> str:
|
|
20
|
+
return json.dumps(
|
|
21
|
+
{
|
|
22
|
+
"AZURE_API_KEY": model_library_settings.AZURE_API_KEY,
|
|
23
|
+
"AZURE_ENDPOINT": model_library_settings.AZURE_ENDPOINT,
|
|
24
|
+
"AZURE_API_VERSION": model_library_settings.get(
|
|
26
25
|
"AZURE_API_VERSION", "2025-04-01-preview"
|
|
27
26
|
),
|
|
27
|
+
}
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
@override
|
|
31
|
+
def get_client(self, api_key: str | None = None) -> AsyncAzureOpenAI:
|
|
32
|
+
if not self.has_client():
|
|
33
|
+
assert api_key
|
|
34
|
+
creds = json.loads(api_key)
|
|
35
|
+
client = AsyncAzureOpenAI(
|
|
36
|
+
api_key=creds["AZURE_API_KEY"],
|
|
37
|
+
azure_endpoint=creds["AZURE_ENDPOINT"],
|
|
38
|
+
api_version=creds["AZURE_API_VERSION"],
|
|
28
39
|
http_client=default_httpx_client(),
|
|
29
|
-
max_retries=
|
|
40
|
+
max_retries=3,
|
|
30
41
|
)
|
|
31
|
-
|
|
42
|
+
self.assign_client(client)
|
|
43
|
+
return super(OpenAIModel, self).get_client(api_key)
|
|
32
44
|
|
|
33
45
|
def __init__(
|
|
34
46
|
self,
|
|
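The Azure override above has to squeeze three settings (key, endpoint, API version) through the single `api_key` string that `_get_default_api_key` and `get_client` pass around, so it round-trips them through JSON. A toy sketch of that round trip, with placeholder values standing in for model_library_settings:

import json

# Pack the three Azure credentials into one string (placeholder values).
packed = json.dumps(
    {
        "AZURE_API_KEY": "<key>",
        "AZURE_ENDPOINT": "https://example.openai.azure.com",
        "AZURE_API_VERSION": "2025-04-01-preview",
    }
)

# get_client() later unpacks the same string before building the client.
creds = json.loads(packed)
assert creds["AZURE_ENDPOINT"] == "https://example.openai.azure.com"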
model_library/providers/cohere.py
CHANGED

@@ -1,13 +1,14 @@
 from typing import Literal

+from pydantic import SecretStr
+
 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
     DelegateOnly,
     LLMConfig,
 )
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults


 @register_provider("cohere")
@@ -22,13 +23,12 @@ class CohereModel(DelegateOnly):
         super().__init__(model_name, provider, config=config)

         # https://docs.cohere.com/docs/compatibility-api
-        self.
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
            config=config,
-
-            api_key=model_library_settings.COHERE_API_KEY,
+            delegate_config=DelegateConfig(
                base_url="https://api.cohere.ai/compatibility/v1",
+                api_key=SecretStr(model_library_settings.COHERE_API_KEY),
            ),
            use_completions=True,
+            delegate_provider="openai",
        )
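Cohere here, and DeepSeek and Fireworks in the sections below, now share one shape: `init_delegate` with a `DelegateConfig` holding the base URL and a `SecretStr`-wrapped key. The diff doesn't show DelegateConfig's definition (it lives in model_library/base/delegate_only.py); a minimal sketch inferred from its call sites, assuming a pydantic model:

from pydantic import BaseModel, SecretStr


class DelegateConfig(BaseModel):
    """Inferred from call sites in this diff; the real class may differ."""

    base_url: str
    api_key: SecretStr


cfg = DelegateConfig(
    base_url="https://api.cohere.ai/compatibility/v1",
    api_key=SecretStr("placeholder-key"),
)
# SecretStr keeps the key out of repr() and logs until explicitly unwrapped,
# matching delegate_config.api_key.get_secret_value() in the Anthropic diff.
assert "placeholder-key" not in repr(cfg)
assert cfg.api_key.get_secret_value() == "placeholder-key"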
model_library/providers/deepseek.py
CHANGED

@@ -5,14 +5,15 @@ https://cdn.deepseek.com/policies/en-US/deepseek-privacy-policy.html

 from typing import Literal

+from pydantic import SecretStr
+
 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
     DelegateOnly,
     LLMConfig,
 )
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults


 @register_provider("deepseek")
@@ -27,13 +28,12 @@ class DeepSeekModel(DelegateOnly):
         super().__init__(model_name, provider, config=config)

         # https://api-docs.deepseek.com/
-        self.
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
            config=config,
-
-
-
+            delegate_config=DelegateConfig(
+                base_url="https://api.deepseek.com/v1",
+                api_key=SecretStr(model_library_settings.DEEPSEEK_API_KEY),
            ),
            use_completions=True,
+            delegate_provider="openai",
        )
model_library/providers/fireworks.py
CHANGED

@@ -1,18 +1,18 @@
 from typing import Literal

+from pydantic import SecretStr
 from typing_extensions import override

 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
+    DelegateOnly,
     LLMConfig,
     ProviderConfig,
     QueryResultCost,
     QueryResultMetadata,
 )
-from model_library.base.delegate_only import DelegateOnly
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults


 class FireworksConfig(ProviderConfig):
@@ -38,15 +38,14 @@ class FireworksModel(DelegateOnly):
         self.model_name = "accounts/rayan-936e28/deployedModels/" + self.model_name

         # https://docs.fireworks.ai/tools-sdks/openai-compatibility
-        self.
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
            config=config,
-
-            api_key=model_library_settings.FIREWORKS_API_KEY,
+            delegate_config=DelegateConfig(
                base_url="https://api.fireworks.ai/inference/v1",
+                api_key=SecretStr(model_library_settings.FIREWORKS_API_KEY),
            ),
            use_completions=True,
+            delegate_provider="openai",
        )

     @override
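The new model_library/retriers package (backoff.py, token.py, base.py, utils.py) appears in the file list above, but its contents are not shown in this diff. The only visible hook is the Anthropic provider raising ImmediateRetryException on APIConnectionError, which suggests a retry loop that distinguishes retry-now failures from backed-off ones. A generic illustration of that split — not the library's actual implementation:

import asyncio
import random


class ImmediateRetryException(Exception):
    """Stand-in for model_library.exceptions.ImmediateRetryException."""


async def run_with_retries(fn, max_attempts: int = 5, base_delay: float = 1.0):
    """Retry fn(): immediate-retry errors skip the sleep, everything else
    backs off exponentially with jitter."""
    last_exc: Exception | None = None
    for attempt in range(max_attempts):
        try:
            return await fn()
        except ImmediateRetryException as e:
            last_exc = e  # e.g. the connection was never established
            continue
        except Exception as e:
            last_exc = e
            if attempt < max_attempts - 1:
                await asyncio.sleep(
                    base_delay * 2**attempt * random.uniform(0.5, 1.5)
                )
    assert last_exc is not None
    raise last_exc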