model-library 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- model_library/base/base.py +141 -62
- model_library/base/delegate_only.py +77 -10
- model_library/base/output.py +43 -0
- model_library/base/utils.py +35 -0
- model_library/config/alibaba_models.yaml +49 -57
- model_library/config/all_models.json +353 -120
- model_library/config/anthropic_models.yaml +2 -1
- model_library/config/kimi_models.yaml +30 -3
- model_library/config/mistral_models.yaml +2 -0
- model_library/config/openai_models.yaml +15 -23
- model_library/config/together_models.yaml +2 -0
- model_library/config/xiaomi_models.yaml +43 -0
- model_library/config/zai_models.yaml +27 -3
- model_library/exceptions.py +3 -77
- model_library/providers/ai21labs.py +12 -8
- model_library/providers/alibaba.py +17 -8
- model_library/providers/amazon.py +49 -16
- model_library/providers/anthropic.py +128 -48
- model_library/providers/azure.py +22 -10
- model_library/providers/cohere.py +7 -7
- model_library/providers/deepseek.py +8 -8
- model_library/providers/fireworks.py +7 -8
- model_library/providers/google/batch.py +14 -10
- model_library/providers/google/google.py +57 -30
- model_library/providers/inception.py +7 -7
- model_library/providers/kimi.py +18 -8
- model_library/providers/minimax.py +15 -17
- model_library/providers/mistral.py +20 -8
- model_library/providers/openai.py +99 -22
- model_library/providers/openrouter.py +34 -0
- model_library/providers/perplexity.py +7 -7
- model_library/providers/together.py +7 -8
- model_library/providers/vals.py +12 -6
- model_library/providers/vercel.py +34 -0
- model_library/providers/xai.py +47 -42
- model_library/providers/xiaomi.py +34 -0
- model_library/providers/zai.py +38 -8
- model_library/register_models.py +5 -0
- model_library/registry_utils.py +48 -17
- model_library/retriers/__init__.py +0 -0
- model_library/retriers/backoff.py +73 -0
- model_library/retriers/base.py +225 -0
- model_library/retriers/token.py +427 -0
- model_library/retriers/utils.py +11 -0
- model_library/settings.py +1 -1
- model_library/utils.py +17 -7
- {model_library-0.1.7.dist-info → model_library-0.1.9.dist-info}/METADATA +2 -1
- model_library-0.1.9.dist-info/RECORD +73 -0
- {model_library-0.1.7.dist-info → model_library-0.1.9.dist-info}/WHEEL +1 -1
- model_library-0.1.7.dist-info/RECORD +0 -64
- {model_library-0.1.7.dist-info → model_library-0.1.9.dist-info}/licenses/LICENSE +0 -0
- {model_library-0.1.7.dist-info → model_library-0.1.9.dist-info}/top_level.txt +0 -0

model_library/providers/anthropic.py
CHANGED

@@ -1,16 +1,20 @@
+import datetime
 import io
 import logging
+import time
 from typing import Any, Literal, Sequence, cast
 
-from anthropic import AsyncAnthropic
+from anthropic import APIConnectionError, AsyncAnthropic
 from anthropic.types.beta.beta_tool_use_block import BetaToolUseBlock
 from anthropic.types.beta.parsed_beta_message import ParsedBetaMessage
+from pydantic import SecretStr
 from typing_extensions import override
 
 from model_library import model_library_settings
 from model_library.base import (
     LLM,
     BatchResult,
+    DelegateConfig,
     FileBase,
     FileInput,
     FileWithBase64,
@@ -19,9 +23,11 @@ from model_library.base import (
     InputItem,
     LLMBatchMixin,
     LLMConfig,
+    ProviderConfig,
     QueryResult,
     QueryResultCost,
     QueryResultMetadata,
+    RateLimit,
     RawInput,
     RawResponse,
     TextInput,
@@ -31,6 +37,7 @@ from model_library.base import (
     ToolResult,
 )
 from model_library.exceptions import (
+    ImmediateRetryException,
     MaxOutputTokensExceededError,
     NoMatchingToolCallError,
 )
@@ -38,11 +45,15 @@ from model_library.model_utils import get_default_budget_tokens
 from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
 from model_library.utils import (
-
-    default_httpx_client,
+    create_anthropic_client_with_defaults,
 )
 
 
+class AnthropicConfig(ProviderConfig):
+    supports_compute_effort: bool = False
+    supports_auto_thinking: bool = False
+
+
 class AnthropicBatchMixin(LLMBatchMixin):
     """Batch processing support for Anthropic's Message Batches API."""
 
@@ -246,21 +257,27 @@ class AnthropicBatchMixin(LLMBatchMixin):
 
 @register_provider("anthropic")
 class AnthropicModel(LLM):
-
+    provider_config = AnthropicConfig()
+
+    def _get_default_api_key(self) -> str:
+        if self.delegate_config:
+            return self.delegate_config.api_key.get_secret_value()
+        return model_library_settings.ANTHROPIC_API_KEY
 
     @override
-    def get_client(self) -> AsyncAnthropic:
-        if self.
-
-        if not AnthropicModel._client:
+    def get_client(self, api_key: str | None = None) -> AsyncAnthropic:
+        if not self.has_client():
+            assert api_key
             headers: dict[str, str] = {}
-
-
-
-
+            client = create_anthropic_client_with_defaults(
+                base_url=self.delegate_config.base_url
+                if self.delegate_config
+                else None,
+                api_key=api_key,
                 default_headers=headers,
             )
-
+            self.assign_client(client)
+        return super().get_client()
 
     def __init__(
         self,
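
The hunk above replaces the class-level client cache with the shared `has_client()` / `assign_client()` / `get_client()` helpers plus a `_get_default_api_key()` hook. A minimal, self-contained sketch of that lazy-client pattern follows; the real helpers live in `model_library.base` and are not shown in this diff, so their behaviour here is an assumption.

```python
# Sketch only: stand-in for the has_client()/assign_client() helpers assumed
# to exist on the LLM base class; not the package's actual implementation.
class LazyClientSketch:
    def __init__(self) -> None:
        self._client: object | None = None  # hypothetical per-instance cache

    def has_client(self) -> bool:
        return self._client is not None

    def assign_client(self, client: object) -> None:
        self._client = client

    def get_client(self, api_key: str | None = None) -> object:
        if not self.has_client():
            assert api_key, "first call must supply a key to build the client"
            self.assign_client(object())  # real code builds an AsyncAnthropic here
        assert self._client is not None
        return self._client
```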
@@ -268,33 +285,32 @@ class AnthropicModel(LLM):
         provider: str = "anthropic",
         *,
         config: LLMConfig | None = None,
-
+        delegate_config: DelegateConfig | None = None,
     ):
-
+        self.delegate_config = delegate_config
 
-
-        self._delegate_client: AsyncAnthropic | None = custom_client
+        super().__init__(model_name, provider, config=config)
 
         # https://docs.anthropic.com/en/api/openai-sdk
         self.delegate = (
             None
-            if self.native or
+            if self.native or self.delegate_config
             else OpenAIModel(
                 model_name=self.model_name,
-                provider=provider,
+                provider=self.provider,
                 config=config,
-
-
+                use_completions=True,
+                delegate_config=DelegateConfig(
                    base_url="https://api.anthropic.com/v1/",
+                    api_key=SecretStr(model_library_settings.ANTHROPIC_API_KEY),
                ),
-                use_completions=True,
             )
         )
 
         # Initialize batch support if enabled
         # Disable batch when using custom_client (similar to OpenAI)
         self.supports_batch: bool = (
-            self.supports_batch and self.native and not
+            self.supports_batch and self.native and not self.delegate_config
         )
         self.batch: LLMBatchMixin | None = (
             AnthropicBatchMixin(self) if self.supports_batch else None
@@ -520,7 +536,6 @@ class AnthropicModel(LLM):
         **kwargs: object,
     ) -> dict[str, Any]:
         body: dict[str, Any] = {
-            "max_tokens": self.max_tokens,
             "model": self.model_name,
             "messages": await self.parse_input(input),
         }
@@ -534,14 +549,28 @@ class AnthropicModel(LLM):
             }
         ]
 
+        if not self.max_tokens:
+            raise Exception("Anthropic models require a max_tokens parameter")
+
+        body["max_tokens"] = self.max_tokens
+
         if self.reasoning:
-
-            "
-
-
-
-
-
+            if self.provider_config.supports_auto_thinking:
+                body["thinking"] = {"type": "auto"}
+            else:
+                budget_tokens = kwargs.pop(
+                    "budget_tokens", get_default_budget_tokens(self.max_tokens)
+                )
+                body["thinking"] = {
+                    "type": "enabled",
+                    "budget_tokens": budget_tokens,
+                }
+
+        # effort controls compute allocation for text, tool calls, and thinking. Opus-4.5+
+        # use instead of reasoning_effort with auto_thinking
+        if self.provider_config.supports_compute_effort and self.compute_effort:
+            # default is "high"
+            body["output_config"] = {"effort": self.compute_effort}
 
         # Thinking models don't support temperature: https://docs.claude.com/en/docs/build-with-claude/extended-thinking#feature-compatibility
         if self.supports_temperature and not self.reasoning:
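
For reference, the branch added above yields one of two `thinking` configurations depending on `supports_auto_thinking`. The sketch below shows the resulting body shapes; the model name and token budgets are invented examples, not library defaults.

```python
# Illustration of the body shapes produced by the reasoning branch above.
# "claude-example", 8192 and 4096 are placeholders.
def example_thinking_body(supports_auto_thinking: bool) -> dict:
    body: dict = {"model": "claude-example", "max_tokens": 8192}
    if supports_auto_thinking:
        body["thinking"] = {"type": "auto"}
    else:
        body["thinking"] = {"type": "enabled", "budget_tokens": 4096}
    return body

assert example_thinking_body(True)["thinking"] == {"type": "auto"}
assert example_thinking_body(False)["thinking"]["budget_tokens"] == 4096
```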
@@ -577,7 +606,7 @@ class AnthropicModel(LLM):
         client = self.get_client()
 
         # only send betas for the official Anthropic endpoint
-        is_anthropic_endpoint = self.
+        is_anthropic_endpoint = self.delegate_config is None
         if not is_anthropic_endpoint:
             client_base_url = getattr(client, "_base_url", None) or getattr(
                 client, "base_url", None
@@ -587,16 +616,29 @@ class AnthropicModel(LLM):
 
         stream_kwargs = {**body}
         if is_anthropic_endpoint:
-            betas = ["files-api-2025-04-14"
+            betas = ["files-api-2025-04-14"]
+            if self.provider_config.supports_auto_thinking:
+                betas.extend(
+                    [
+                        "auto-thinking-2026-01-12",
+                        "effort-2025-11-24",
+                        "max-effort-2026-01-24",
+                    ]
+                )
+            else:
+                betas.extend(["interleaved-thinking-2025-05-14"])
             if "sonnet-4-5" in self.model_name:
                 betas.append("context-1m-2025-08-07")
             stream_kwargs["betas"] = betas
 
-
-
-
-
-
+        try:
+            async with client.beta.messages.stream(
+                **stream_kwargs,
+            ) as stream:  # pyright: ignore[reportAny]
+                message = await stream.get_final_message()
+                self.logger.info(f"Anthropic Response finished: {message.id}")
+        except APIConnectionError:
+            raise ImmediateRetryException("Failed to connect to Anthropic")
 
         text = ""
         reasoning = ""
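
The new `try`/`except` maps the SDK's `APIConnectionError` onto `ImmediateRetryException`. How that exception is consumed is not shown here (retry logic lives in the new `model_library/retriers/` package listed above), so the sketch below is only one plausible way a caller could treat it: retry immediately on connection failures, back off for everything else.

```python
import asyncio

class ImmediateRetryException(Exception):
    """Stand-in for model_library.exceptions.ImmediateRetryException."""

async def call_with_retries(fn, attempts: int = 3, backoff_s: float = 2.0):
    # Sketch under assumptions; the real retriers/ package may behave differently.
    last_exc: Exception | None = None
    for attempt in range(attempts):
        try:
            return await fn()
        except ImmediateRetryException as exc:
            last_exc = exc  # transient connection failure: retry right away
        except Exception as exc:
            last_exc = exc
            await asyncio.sleep(backoff_s * (attempt + 1))  # back off, then retry
    assert last_exc is not None
    raise last_exc
```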
@@ -632,6 +674,38 @@ class AnthropicModel(LLM):
             history=[*input, RawResponse(response=message)],
         )
 
+    @override
+    async def get_rate_limit(self) -> RateLimit:
+        response = await self.get_client().messages.with_raw_response.create(
+            max_tokens=1,
+            messages=[
+                {
+                    "role": "user",
+                    "content": "Ping",
+                }
+            ],
+            model=self.model_name,
+        )
+        headers = response.headers
+
+        server_time_str = headers.get("date")
+        if server_time_str:
+            server_time = datetime.datetime.strptime(
+                server_time_str, "%a, %d %b %Y %H:%M:%S GMT"
+            ).replace(tzinfo=datetime.timezone.utc)
+            timestamp = server_time.timestamp()
+        else:
+            timestamp = time.time()
+
+        return RateLimit(
+            unix_timestamp=timestamp,
+            raw=headers,
+            request_limit=int(headers["anthropic-ratelimit-requests-limit"]),
+            request_remaining=int(headers["anthropic-ratelimit-requests-remaining"]),
+            token_limit=int(response.headers["anthropic-ratelimit-tokens-limit"]),
+            token_remaining=int(headers["anthropic-ratelimit-tokens-remaining"]),
+        )
+
     @override
     async def count_tokens(
         self,
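
The new `get_rate_limit()` issues a 1-token "Ping" request purely to read Anthropic's rate-limit response headers. Below is a small, hedged example of parsing such a header set; the values are invented, only the header names and the date format come from the code above.

```python
import datetime

headers = {  # invented example values
    "date": "Mon, 06 Jan 2025 12:00:00 GMT",
    "anthropic-ratelimit-requests-limit": "4000",
    "anthropic-ratelimit-requests-remaining": "3999",
    "anthropic-ratelimit-tokens-limit": "400000",
    "anthropic-ratelimit-tokens-remaining": "399000",
}

server_time = datetime.datetime.strptime(
    headers["date"], "%a, %d %b %Y %H:%M:%S GMT"
).replace(tzinfo=datetime.timezone.utc)

snapshot = {
    "unix_timestamp": server_time.timestamp(),
    "request_remaining": int(headers["anthropic-ratelimit-requests-remaining"]),
    "token_remaining": int(headers["anthropic-ratelimit-tokens-remaining"]),
}
print(snapshot)
```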
@@ -645,20 +719,26 @@ class AnthropicModel(LLM):
         Count the number of tokens using Anthropic's native token counting API.
         https://docs.anthropic.com/en/docs/build-with-claude/token-counting
         """
-
-
-
+        try:
+            input = [*history, *input]
+            if not input:
+                return 0
 
-
+            body = await self.build_body(input, tools=tools, **kwargs)
 
-
-
-
+            # Remove fields not supported by count_tokens endpoint
+            body.pop("max_tokens", None)
+            body.pop("temperature", None)
 
-
-
+            client = self.get_client()
+            response = await client.messages.count_tokens(**body)
 
-
+            return response.input_tokens
+        except Exception as e:
+            self.logger.error(f"Error counting tokens: {e}")
+            return await super().count_tokens(
+                input, history=history, tools=tools, **kwargs
+            )
 
     @override
     async def _calculate_cost(

model_library/providers/azure.py
CHANGED

@@ -1,3 +1,4 @@
+import json
 from typing import Literal
 
 from openai.lib.azure import AsyncAzureOpenAI
@@ -14,21 +15,32 @@ from model_library.utils import default_httpx_client
 
 @register_provider("azure")
 class AzureOpenAIModel(OpenAIModel):
-    _azure_client: AsyncAzureOpenAI | None = None
-
     @override
-    def 
-
-
-
-
-
+    def _get_default_api_key(self) -> str:
+        return json.dumps(
+            {
+                "AZURE_API_KEY": model_library_settings.AZURE_API_KEY,
+                "AZURE_ENDPOINT": model_library_settings.AZURE_ENDPOINT,
+                "AZURE_API_VERSION": model_library_settings.get(
                     "AZURE_API_VERSION", "2025-04-01-preview"
                 ),
+            }
+        )
+
+    @override
+    def get_client(self, api_key: str | None = None) -> AsyncAzureOpenAI:
+        if not self.has_client():
+            assert api_key
+            creds = json.loads(api_key)
+            client = AsyncAzureOpenAI(
+                api_key=creds["AZURE_API_KEY"],
+                azure_endpoint=creds["AZURE_ENDPOINT"],
+                api_version=creds["AZURE_API_VERSION"],
                 http_client=default_httpx_client(),
-                max_retries=
+                max_retries=3,
             )
-
+            self.assign_client(client)
+        return super(OpenAIModel, self).get_client(api_key)
 
     def __init__(
         self,
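
Azure needs three values (key, endpoint, API version) rather than a single API key, so `_get_default_api_key()` packs them into one JSON string and `get_client()` unpacks it. A sketch of that round-trip, with placeholder values:

```python
import json

# Placeholder credentials; the real values come from model_library_settings.
packed = json.dumps(
    {
        "AZURE_API_KEY": "example-key",
        "AZURE_ENDPOINT": "https://my-resource.openai.azure.com",
        "AZURE_API_VERSION": "2025-04-01-preview",
    }
)

creds = json.loads(packed)  # what get_client() does with its api_key argument
assert creds["AZURE_API_VERSION"] == "2025-04-01-preview"
```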

model_library/providers/cohere.py
CHANGED

@@ -1,13 +1,14 @@
 from typing import Literal
 
+from pydantic import SecretStr
+
 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
     DelegateOnly,
     LLMConfig,
 )
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults
 
 
 @register_provider("cohere")
@@ -22,13 +23,12 @@ class CohereModel(DelegateOnly):
         super().__init__(model_name, provider, config=config)
 
         # https://docs.cohere.com/docs/compatibility-api
-        self.
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-
-            api_key=model_library_settings.COHERE_API_KEY,
+            delegate_config=DelegateConfig(
                base_url="https://api.cohere.ai/compatibility/v1",
+                api_key=SecretStr(model_library_settings.COHERE_API_KEY),
            ),
            use_completions=True,
+            delegate_provider="openai",
        )
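
CohereModel (and DeepSeek and Fireworks below) now hands a `DelegateConfig` to `init_delegate()` instead of building an OpenAI client itself. Only `base_url` and a `SecretStr` `api_key` are visible in this diff; the sketch below models just those two fields and is not the package's actual `DelegateConfig`.

```python
from pydantic import BaseModel, SecretStr

class DelegateConfigSketch(BaseModel):
    # Only the fields visible in this diff; the real DelegateConfig may have more.
    base_url: str
    api_key: SecretStr

cfg = DelegateConfigSketch(
    base_url="https://api.cohere.ai/compatibility/v1",
    api_key=SecretStr("example-key"),  # placeholder
)
assert cfg.api_key.get_secret_value() == "example-key"
```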

model_library/providers/deepseek.py
CHANGED

@@ -5,14 +5,15 @@ https://cdn.deepseek.com/policies/en-US/deepseek-privacy-policy.html
 
 from typing import Literal
 
+from pydantic import SecretStr
+
 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
     DelegateOnly,
     LLMConfig,
 )
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults
 
 
 @register_provider("deepseek")
@@ -27,13 +28,12 @@ class DeepSeekModel(DelegateOnly):
         super().__init__(model_name, provider, config=config)
 
         # https://api-docs.deepseek.com/
-        self.
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-
-
-
+            delegate_config=DelegateConfig(
+                base_url="https://api.deepseek.com/v1",
+                api_key=SecretStr(model_library_settings.DEEPSEEK_API_KEY),
            ),
            use_completions=True,
+            delegate_provider="openai",
        )

model_library/providers/fireworks.py
CHANGED

@@ -1,18 +1,18 @@
 from typing import Literal
 
+from pydantic import SecretStr
 from typing_extensions import override
 
 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
+    DelegateOnly,
     LLMConfig,
     ProviderConfig,
     QueryResultCost,
     QueryResultMetadata,
 )
-from model_library.base.delegate_only import DelegateOnly
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults
 
 
 class FireworksConfig(ProviderConfig):
@@ -38,15 +38,14 @@ class FireworksModel(DelegateOnly):
         self.model_name = "accounts/rayan-936e28/deployedModels/" + self.model_name
 
         # https://docs.fireworks.ai/tools-sdks/openai-compatibility
-        self.
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-
-            api_key=model_library_settings.FIREWORKS_API_KEY,
+            delegate_config=DelegateConfig(
                base_url="https://api.fireworks.ai/inference/v1",
+                api_key=SecretStr(model_library_settings.FIREWORKS_API_KEY),
            ),
            use_completions=True,
+            delegate_provider="openai",
        )
 
     @override

model_library/providers/google/batch.py
CHANGED

@@ -24,16 +24,19 @@ from google.genai.types import (
 )
 
 
-def extract_text_from_json_response(response: dict[str, Any]) -> str:
+def extract_text_from_json_response(response: dict[str, Any]) -> tuple[str, str]:
     """Extract concatenated non-thought text from a JSON response structure."""
     # TODO: fix the typing we always ignore
     text = ""
+    reasoning = ""
     for candidate in response.get("candidates", []) or []:  # type: ignore
         content = (candidate or {}).get("content") or {}  # type: ignore
         for part in content.get("parts", []) or []:  # type: ignore
-            if 
+            if part.get("thought", False):  # type: ignore
+                reasoning += part.get("text", "")  # type: ignore
+            else:
                 text += part.get("text", "")  # type: ignore
-    return text  # type: ignore
+    return text, reasoning  # type: ignore
 
 
 def parse_predictions_jsonl(jsonl: str) -> list[BatchResult]:
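
`extract_text_from_json_response` now returns a `(text, reasoning)` pair, routing parts flagged as `thought` into the reasoning string. A self-contained example with a hand-made payload in the same shape:

```python
response = {  # invented payload, same shape the helper walks
    "candidates": [
        {
            "content": {
                "parts": [
                    {"text": "Let me think this through. ", "thought": True},
                    {"text": "The answer is 42."},
                ]
            }
        }
    ]
}

text, reasoning = "", ""
for candidate in response.get("candidates", []) or []:
    parts = ((candidate or {}).get("content") or {}).get("parts", []) or []
    for part in parts:
        if part.get("thought", False):
            reasoning += part.get("text", "")
        else:
            text += part.get("text", "")

assert text == "The answer is 42."
assert reasoning == "Let me think this through. "
```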
@@ -48,9 +51,10 @@ def parse_predictions_jsonl(jsonl: str) -> list[BatchResult]:
         custom_id = data.get("key", "unknown")
         if "response" in data:
             response = data["response"]
-            text = extract_text_from_json_response(response)
+            text, reasoning = extract_text_from_json_response(response)
             output = QueryResult()
             output.output_text = text
+            output.reasoning = reasoning
             if "usageMetadata" in response:
                 output.metadata.in_tokens = response["usageMetadata"].get(
                     "promptTokenCount", 0
@@ -196,7 +200,7 @@ class GoogleBatchMixin(LLMBatchMixin):
             custom_id = labels.get("qa_pair_id", f"request-{i}")
             jsonl_lines.append(json.dumps({"key": custom_id, "request": request_data}))
 
-        batch_request_file = self._root.
+        batch_request_file = self._root.get_client().files.upload(
             file=io.StringIO("\n".join(jsonl_lines)),
             config=UploadFileConfig(mime_type="application/jsonl"),
         )
@@ -205,7 +209,7 @@ class GoogleBatchMixin(LLMBatchMixin):
             raise Exception("Failed to upload batch jsonl")
 
         try:
-            job: BatchJob = await self._root.
+            job: BatchJob = await self._root.get_client().aio.batches.create(
                 model=self._root.model_name,
                 src=batch_request_file.name,
                 config={"display_name": batch_name},
@@ -224,14 +228,14 @@ class GoogleBatchMixin(LLMBatchMixin):
     async def get_batch_results(self, batch_id: str) -> list[BatchResult]:
         self._root.logger.info(f"Retrieving batch results for {batch_id}")
 
-        job = await self._root.
+        job = await self._root.get_client().aio.batches.get(name=batch_id)
 
         results: list[BatchResult] = []
 
         if job.state == JobState.JOB_STATE_SUCCEEDED:
             if job.dest and job.dest.file_name:
                 results_file_name = job.dest.file_name
-                file_content = await self._root.
+                file_content = await self._root.get_client().aio.files.download(
                     file=results_file_name
                 )
                 decoded = file_content.decode("utf-8")
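
Every batch call in this mixin now goes through `self._root.get_client()` rather than a cached attribute. Below is a hedged polling sketch built only from the method signatures visible in this file (`get_batch_progress`, `get_batch_results`, `cancel_batch_request`); whether progress counts up to 100 is an assumption.

```python
import asyncio

async def wait_for_batch(batch, batch_id: str, poll_s: float = 30.0, max_polls: int = 20):
    # Sketch only: `batch` is a GoogleBatchMixin-like object; names come from the
    # signatures above, the 100-means-done convention is assumed.
    for _ in range(max_polls):
        progress = await batch.get_batch_progress(batch_id)
        if progress >= 100:
            return await batch.get_batch_results(batch_id)
        await asyncio.sleep(poll_s)
    await batch.cancel_batch_request(batch_id)
    raise TimeoutError(f"batch {batch_id} did not finish after {max_polls} polls")
```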
@@ -250,7 +254,7 @@ class GoogleBatchMixin(LLMBatchMixin):
     @override
     async def cancel_batch_request(self, batch_id: str):
         self._root.logger.info(f"Cancelling batch {batch_id}")
-        await self._root.
+        await self._root.get_client().aio.batches.cancel(name=batch_id)
 
     @override
     async def get_batch_progress(self, batch_id: str) -> int:
@@ -262,7 +266,7 @@ class GoogleBatchMixin(LLMBatchMixin):
 
         try:
             self._root.logger.debug(f"Checking batch status for {batch_id}")
-            job: BatchJob = await self._root.
+            job: BatchJob = await self._root.get_client().aio.batches.get(name=batch_id)
             state = job.state
 
             if not state:
|