model-library 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- model_library/base/base.py +13 -6
- model_library/base/output.py +55 -0
- model_library/base/utils.py +3 -2
- model_library/config/README.md +169 -0
- model_library/config/ai21labs_models.yaml +11 -10
- model_library/config/alibaba_models.yaml +21 -22
- model_library/config/all_models.json +4708 -2471
- model_library/config/amazon_models.yaml +100 -102
- model_library/config/anthropic_models.yaml +59 -45
- model_library/config/cohere_models.yaml +25 -24
- model_library/config/deepseek_models.yaml +28 -25
- model_library/config/dummy_model.yaml +9 -7
- model_library/config/fireworks_models.yaml +86 -56
- model_library/config/google_models.yaml +156 -102
- model_library/config/inception_models.yaml +6 -6
- model_library/config/kimi_models.yaml +13 -14
- model_library/config/minimax_models.yaml +37 -0
- model_library/config/mistral_models.yaml +85 -29
- model_library/config/openai_models.yaml +192 -159
- model_library/config/perplexity_models.yaml +8 -23
- model_library/config/together_models.yaml +115 -103
- model_library/config/xai_models.yaml +85 -57
- model_library/config/zai_models.yaml +23 -15
- model_library/exceptions.py +12 -17
- model_library/file_utils.py +1 -1
- model_library/providers/amazon.py +32 -17
- model_library/providers/anthropic.py +2 -6
- model_library/providers/google/google.py +35 -29
- model_library/providers/minimax.py +33 -0
- model_library/providers/mistral.py +10 -1
- model_library/providers/openai.py +10 -8
- model_library/providers/together.py +18 -211
- model_library/register_models.py +36 -38
- model_library/registry_utils.py +18 -16
- model_library/utils.py +2 -2
- {model_library-0.1.2.dist-info → model_library-0.1.4.dist-info}/METADATA +3 -4
- model_library-0.1.4.dist-info/RECORD +64 -0
- model_library-0.1.2.dist-info/RECORD +0 -61
- {model_library-0.1.2.dist-info → model_library-0.1.4.dist-info}/WHEEL +0 -0
- {model_library-0.1.2.dist-info → model_library-0.1.4.dist-info}/licenses/LICENSE +0 -0
- {model_library-0.1.2.dist-info → model_library-0.1.4.dist-info}/top_level.txt +0 -0
model_library/providers/google/google.py
CHANGED

@@ -2,8 +2,6 @@ import base64
 import io
 from typing import Any, Literal, Sequence, cast
 
-from typing_extensions import override
-
 from google.genai import Client
 from google.genai import errors as genai_errors
 from google.genai.types import (
@@ -18,10 +16,14 @@ from google.genai.types import (
     Part,
     SafetySetting,
     ThinkingConfig,
+    ThinkingLevel,
     Tool,
     ToolListUnion,
     UploadFileConfig,
+    FinishReason,
 )
+from typing_extensions import override
+
 from model_library import model_library_settings
 from model_library.base import (
     LLM,
@@ -119,15 +121,6 @@ class GoogleModel(LLM):
     ):
         super().__init__(model_name, provider, config=config)
 
-        # thinking tag
-        if self.model_name.endswith("-thinking"):
-            original_name = self.model_name
-            self.model_name = self.model_name.replace("-thinking", "")
-            self.reasoning = True
-            self.logger.info(
-                f"Enabled thinking mode for {original_name} -> {self.model_name}"
-            )
-
         if self.provider_config.use_vertex:
             self.supports_batch = False
 
@@ -261,14 +254,12 @@ class GoogleModel(LLM):
         bytes: io.BytesIO,
         type: Literal["image", "file"] = "file",
     ) -> FileWithId:
-        if
-
-
-
-
-
-        )
-            raise Exception("Model does not support batching")
+        if self.provider_config.use_vertex:
+            raise Exception(
+                "Vertex AI does not support file uploads. "
+                "use FileWithBase64 to pass files as inline data"
+                "or use genai for file uploads"
+            )
 
         mime = f"image/{mime}" if type == "image" else mime  # TODO:
         response: File = self.client.files.upload(
@@ -294,7 +285,6 @@ class GoogleModel(LLM):
         tools: list[ToolDefinition],
         **kwargs: object,
     ) -> dict[str, Any]:
-        self.logger.debug(f"Creating request body for {self.model_name}")
         generation_config = GenerateContentConfig(
             max_output_tokens=self.max_tokens,
         )
@@ -310,13 +300,15 @@ class GoogleModel(LLM):
         if system_prompt and isinstance(system_prompt, str) and system_prompt.strip():
             generation_config.system_instruction = str(system_prompt)
 
-        if
-
-
+        if self.reasoning:
+            reasoning_config = ThinkingConfig(include_thoughts=True)
+            if self.reasoning_effort:
+                reasoning_config.thinking_level = ThinkingLevel(self.reasoning_effort)
+            else:
+                reasoning_config.thinking_budget = cast(
                     int, kwargs.pop("thinking_budget", self.DEFAULT_THINKING_BUDGET)
-            )
-
-            )
+                )
+            generation_config.thinking_config = reasoning_config
 
         if tools:
             generation_config.tools = cast(ToolListUnion, await self.parse_tools(tools))
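The reasoning branch above now picks between the two thinking controls: a configured `reasoning_effort` is mapped onto the `ThinkingLevel` enum, otherwise a token `thinking_budget` is used. A minimal standalone sketch of the same shape, using only the `google-genai` types this file imports (the default budget and max-token values below are illustrative placeholders, not the library's defaults):

```python
from google.genai.types import GenerateContentConfig, ThinkingConfig, ThinkingLevel


def build_generation_config(
    reasoning: bool,
    reasoning_effort: str | None,
    thinking_budget: int = 8192,  # placeholder, not the library's DEFAULT_THINKING_BUDGET
) -> GenerateContentConfig:
    config = GenerateContentConfig(max_output_tokens=4096)
    if reasoning:
        thinking = ThinkingConfig(include_thoughts=True)
        if reasoning_effort:
            # as in the hunk: the configured effort string is handed straight to the enum
            thinking.thinking_level = ThinkingLevel(reasoning_effort)
        else:
            thinking.thinking_budget = thinking_budget
        config.thinking_config = thinking
    return config
```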
@@ -343,17 +335,20 @@ class GoogleModel(LLM):
         text: str = ""
         reasoning: str = ""
         tool_calls: list[ToolCall] = []
-        last_content: Content | None = None
 
         metadata: GenerateContentResponseUsageMetadata | None = None
 
         stream = await self.client.aio.models.generate_content_stream(**body)
+        contents: list[Content | None] = []
+        finish_reason: FinishReason | None = None
+
         async for chunk in stream:
             candidates = chunk.candidates
             if not candidates:
                 continue
 
             content = candidates[0].content
+
             if content and content.parts:
                 for part in content.parts:
                     if part.function_call:
@@ -378,14 +373,24 @@ class GoogleModel(LLM):
 
             if chunk.usage_metadata:
                 metadata = chunk.usage_metadata
-
+            if content:
+                contents.append(content)
+            if candidates[0].finish_reason:
+                finish_reason = candidates[0].finish_reason
+
+        if finish_reason != FinishReason.STOP:
+            self.logger.error(f"Unexpected finish reason: {finish_reason}")
+
+        if not text and not reasoning and not tool_calls:
+            raise ModelNoOutputError("Model returned empty response")
 
         result = QueryResult(
             output_text=text,
             reasoning=reasoning,
-            history=[*input,
+            history=[*input, *contents],
             tool_calls=tool_calls,
         )
+
         if metadata:
             # see _calculate_cost
             cache_read_tokens = metadata.cached_content_token_count or 0
@@ -446,6 +451,7 @@ class GoogleModel(LLM):
                 "response_mime_type": "application/json",
             }
         )
+
        body["config"] = config
 
        # Make the request with retry wrapper
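The streaming changes above accumulate every chunk's `Content` into `contents` (so the full model turn can be appended to `history`), track the final `FinishReason`, and raise `ModelNoOutputError` when nothing was produced. A sketch of the same accumulation pattern against the `google-genai` streaming API, assuming a configured `Client`; the model name and prompt are illustrative:

```python
from google.genai import Client
from google.genai.types import Content, FinishReason


async def collect_stream(client: Client) -> list[Content]:
    contents: list[Content] = []
    finish_reason: FinishReason | None = None
    stream = await client.aio.models.generate_content_stream(
        model="gemini-2.5-flash", contents="Hello"
    )
    async for chunk in stream:
        if not chunk.candidates:
            continue
        candidate = chunk.candidates[0]
        if candidate.content:
            contents.append(candidate.content)
        if candidate.finish_reason:
            finish_reason = candidate.finish_reason
    if finish_reason != FinishReason.STOP:
        print(f"Unexpected finish reason: {finish_reason}")
    return contents
```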
model_library/providers/minimax.py
ADDED

@@ -0,0 +1,33 @@
+from typing import Literal
+
+from model_library import model_library_settings
+from model_library.base import (
+    DelegateOnly,
+    LLMConfig,
+)
+from model_library.providers.openai import OpenAIModel
+from model_library.register_models import register_provider
+from model_library.utils import create_openai_client_with_defaults
+
+
+@register_provider("minimax")
+class MinimaxModel(DelegateOnly):
+    def __init__(
+        self,
+        model_name: str,
+        provider: Literal["minimax"] = "minimax",
+        *,
+        config: LLMConfig | None = None,
+    ):
+        super().__init__(model_name, provider, config=config)
+
+        self.delegate = OpenAIModel(
+            model_name=self.model_name,
+            provider=self.provider,
+            config=config,
+            custom_client=create_openai_client_with_defaults(
+                api_key=model_library_settings.MINIMAX_API_KEY,
+                base_url="https://api.minimax.io/v1",
+            ),
+            use_completions=True,
+        )
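The new provider is a thin `DelegateOnly` wrapper: every request is forwarded to an `OpenAIModel` pointed at MiniMax's OpenAI-compatible endpoint via the chat-completions API. A hypothetical usage sketch; the model name is illustrative and `MINIMAX_API_KEY` must be configured for the client to be built:

```python
from model_library.providers.minimax import MinimaxModel

model = MinimaxModel("MiniMax-M1")  # illustrative name, not necessarily a registry key
print(type(model.delegate).__name__)  # OpenAIModel, targeting https://api.minimax.io/v1
```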
model_library/providers/mistral.py
CHANGED

@@ -29,6 +29,7 @@ from model_library.base import (
 from model_library.exceptions import (
     BadInputError,
     MaxOutputTokensExceededError,
+    ModelNoOutputError,
 )
 from model_library.file_utils import trim_images
 from model_library.register_models import register_provider
@@ -250,9 +251,17 @@ class MistralModel(LLM):
             self.logger.error(f"Error: {e}", exc_info=True)
             raise e
 
-        if
+        if (
+            finish_reason == "length"
+            and not text
+            and not reasoning
+            and not raw_tool_calls
+        ):
             raise MaxOutputTokensExceededError()
 
+        if not text and not reasoning and not raw_tool_calls:
+            raise ModelNoOutputError()
+
         tool_calls: list[ToolCall] = []
 
         for tool_call in raw_tool_calls or []:
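The new guards separate two failure modes: a `length` finish with nothing produced raises `MaxOutputTokensExceededError`, while any other empty completion raises the new `ModelNoOutputError`. The same decision, restated as a standalone helper (a sketch, not library code):

```python
from model_library.exceptions import MaxOutputTokensExceededError, ModelNoOutputError


def check_completion(
    finish_reason: str | None,
    text: str,
    reasoning: str | None,
    raw_tool_calls: list[object] | None,
) -> None:
    empty = not text and not reasoning and not raw_tool_calls
    if finish_reason == "length" and empty:
        # the whole output budget was spent without producing usable content
        raise MaxOutputTokensExceededError()
    if empty:
        raise ModelNoOutputError()
```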
model_library/providers/openai.py
CHANGED

@@ -5,7 +5,11 @@ import json
 from typing import Any, Literal, Sequence, cast
 
 from openai import APIConnectionError, AsyncOpenAI
-from openai.types.chat import
+from openai.types.chat import (
+    ChatCompletionMessage,
+    ChatCompletionMessageToolCall,
+    ChatCompletionMessageToolCallUnion,
+)
 from openai.types.chat.chat_completion_message_tool_call import Function
 from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
 from openai.types.create_embedding_response import CreateEmbeddingResponse
@@ -517,10 +521,6 @@ class OpenAIModel(LLM):
         metadata: QueryResultMetadata = QueryResultMetadata()
         raw_tool_calls: list[ChatCompletionMessageToolCall] = []
 
-        # enable usage data in streaming responses
-        if "stream_options" not in body:
-            body["stream_options"] = {"include_usage": True}
-
         stream = await self.get_client().chat.completions.create(
             **body,  # pyright: ignore[reportAny]
             stream=True,
@@ -583,7 +583,7 @@ class OpenAIModel(LLM):
                 cache_read_tokens = (
                     chunk.usage.prompt_tokens_details.cached_tokens or 0
                     if chunk.usage.prompt_tokens_details
-                    else 0
+                    else getattr(chunk.usage, "cached_tokens", 0)  # for kimi
                 )
                 metadata = QueryResultMetadata(
                     in_tokens=chunk.usage.prompt_tokens - cache_read_tokens,
@@ -617,9 +617,11 @@ class OpenAIModel(LLM):
         final_message = ChatCompletionMessage(
             role="assistant",
             content=output_text if output_text else None,
-            tool_calls=
+            tool_calls=cast(list[ChatCompletionMessageToolCallUnion], raw_tool_calls)
+            if raw_tool_calls
+            else None,
         )
-        if
+        if reasoning_text:
             setattr(final_message, "reasoning_content", reasoning_text)
 
         return QueryResult(
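The cache-accounting tweak keeps reading `prompt_tokens_details.cached_tokens` when the provider reports it and otherwise falls back to a top-level `cached_tokens` attribute (the shape Kimi's OpenAI-compatible endpoint returns). A sketch of the same fallback against an `openai` usage object; the top-level attribute is non-standard and taken from the hunk, not from the OpenAI SDK:

```python
from openai.types import CompletionUsage


def uncached_prompt_tokens(usage: CompletionUsage) -> int:
    cached = (
        usage.prompt_tokens_details.cached_tokens or 0
        if usage.prompt_tokens_details
        else getattr(usage, "cached_tokens", 0)  # non-standard field, e.g. Kimi
    )
    return usage.prompt_tokens - cached
```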
model_library/providers/together.py
CHANGED

@@ -1,51 +1,27 @@
-import
-from typing import Any, Literal, Sequence, cast
+from typing import Literal
 
-from together import AsyncTogether
-from together.types.chat_completions import (
-    ChatCompletionMessage,
-    ChatCompletionResponse,
-)
 from typing_extensions import override
 
 from model_library import model_library_settings
 from model_library.base import (
-
-    FileInput,
-    FileWithBase64,
-    FileWithId,
-    FileWithUrl,
-    InputItem,
+    DelegateOnly,
     LLMConfig,
-
+    ProviderConfig,
     QueryResultCost,
     QueryResultMetadata,
-    TextInput,
-    ToolDefinition,
-)
-from model_library.exceptions import (
-    BadInputError,
-    MaxOutputTokensExceededError,
-    ModelNoOutputError,
 )
-from model_library.file_utils import trim_images
-from model_library.model_utils import get_reasoning_in_tag
 from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
 from model_library.utils import create_openai_client_with_defaults
 
 
-
-
-    _client: AsyncTogether | None = None
+class TogetherConfig(ProviderConfig):
+    serverless: bool = True
 
-
-
-
-
-            api_key=model_library_settings.TOGETHER_API_KEY,
-        )
-        return TogetherModel._client
+
+@register_provider("together")
+class TogetherModel(DelegateOnly):
+    provider_config = TogetherConfig()
 
     def __init__(
         self,
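With the `together` SDK gone from the imports, the provider relies entirely on Together's OpenAI-compatible endpoint (see the delegate set up in the next hunk). A sketch of what the delegate's client amounts to, assuming `create_openai_client_with_defaults` is a thin wrapper around `AsyncOpenAI` (its exact defaults are not shown in this diff; the key placeholder is illustrative):

```python
from openai import AsyncOpenAI

client = AsyncOpenAI(
    api_key="<TOGETHER_API_KEY>",
    base_url="https://api.together.xyz/v1",  # same base_url as the delegate below
)
```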
@@ -55,187 +31,18 @@ class TogetherModel(LLM):
         config: LLMConfig | None = None,
     ):
         super().__init__(model_name, provider, config=config)
-
         # https://docs.together.ai/docs/openai-api-compatibility
-        self.delegate
-
-
-
-
-
-
-
-
-                base_url="https://api.together.xyz/v1",
-            ),
-            use_completions=False,
-        )
-        )
-
-    @override
-    async def parse_input(
-        self,
-        input: Sequence[InputItem],
-        **kwargs: Any,
-    ) -> list[dict[str, Any] | Any]:
-        new_input: list[dict[str, Any] | Any] = []
-        content_user: list[dict[str, Any]] = []
-
-        def flush_content_user():
-            nonlocal content_user
-
-            if content_user:
-                new_input.append({"role": "user", "content": content_user})
-                content_user = []
-
-        for item in input:
-            match item:
-                case TextInput():
-                    content_user.append({"type": "text", "text": item.text})
-                case FileWithBase64() | FileWithUrl() | FileWithId():
-                    match item.type:
-                        case "image":
-                            content_user.append(await self.parse_image(item))
-                        case "file":
-                            content_user.append(await self.parse_file(item))
-                case ChatCompletionMessage():
-                    flush_content_user()
-                    new_input.append(item)
-                case _:
-                    raise BadInputError("Unsupported input type")
-
-        flush_content_user()
-
-        return new_input
-
-    @override
-    async def parse_image(
-        self,
-        image: FileInput,
-    ) -> dict[str, Any]:
-        match image:
-            case FileWithBase64():
-                return {
-                    "type": "image_url",
-                    "image_url": {
-                        "url": f"data:image/{image.mime};base64,{image.base64}"
-                    },
-                }
-            case _:
-                # docs show that we can pass in s3 location somehow
-                raise BadInputError("Unsupported image type")
-
-    @override
-    async def parse_file(
-        self,
-        file: FileInput,
-    ) -> Any:
-        raise NotImplementedError()
-
-    @override
-    async def parse_tools(
-        self,
-        tools: list[ToolDefinition],
-    ) -> Any:
-        raise NotImplementedError()
-
-    @override
-    async def upload_file(
-        self,
-        name: str,
-        mime: str,
-        bytes: io.BytesIO,
-        type: Literal["image", "file"] = "file",
-    ) -> FileWithId:
-        raise NotImplementedError()
-
-    @override
-    async def _query_impl(
-        self,
-        input: Sequence[InputItem],
-        *,
-        tools: list[ToolDefinition],
-        **kwargs: object,
-    ) -> QueryResult:
-        if self.delegate:
-            return await self.delegate_query(input, tools=tools, **kwargs)
-
-        # llama supports max 5 images
-        if "lama-4" in self.model_name:
-            input = trim_images(input, max_images=5)
-
-        messages: list[dict[str, Any]] = []
-
-        if "nemotron-super" in self.model_name:
-            # move system prompt to prompt
-            if "system_prompt" in kwargs:
-                first_text_item = next(
-                    (item for item in input if isinstance(item, TextInput)), None
-                )
-                if not first_text_item:
-                    raise Exception(
-                        "Given system prompt for nemotron-super model, but no text input found"
-                    )
-                system_prompt = kwargs.pop("system_prompt")
-                first_text_item.text = f"SYSTEM PROMPT: {system_prompt}\nUSER PROMPT: {first_text_item.text}"
-
-            # set system prompt to detailed thinking
-            mode = "on" if self.reasoning else "off"
-            kwargs["system_prompt"] = f"detailed thinking {mode}"
-            messages.append(
-                {
-                    "role": "system",
-                    "content": f"detailed thinking {mode}",
-                }
-            )
-
-        if "system_prompt" in kwargs:
-            messages.append({"role": "system", "content": kwargs.pop("system_prompt")})
-
-        messages.extend(await self.parse_input(input))
-
-        body: dict[str, Any] = {
-            "max_tokens": self.max_tokens,
-            "model": self.model_name,
-            "messages": messages,
-        }
-
-        if self.supports_temperature:
-            if self.temperature is not None:
-                body["temperature"] = self.temperature
-            if self.top_p is not None:
-                body["top_p"] = self.top_p
-
-        body.update(kwargs)
-
-        response = await self.get_client().chat.completions.create(**body, stream=False)  # pyright: ignore[reportAny]
-
-        response = cast(ChatCompletionResponse, response)
-
-        if not response or not response.choices or not response.choices[0].message:
-            raise ModelNoOutputError("Model returned no completions")
-
-        text = str(response.choices[0].message.content)
-        reasoning = None
-
-        if response.choices[0].finish_reason == "length" and not text:
-            raise MaxOutputTokensExceededError()
-
-        if self.reasoning:
-            text, reasoning = get_reasoning_in_tag(text)
-
-        output = QueryResult(
-            output_text=text,
-            reasoning=reasoning,
-            history=[*input, response.choices[0].message],
+        self.delegate = OpenAIModel(
+            model_name=self.model_name,
+            provider=self.provider,
+            config=config,
+            custom_client=create_openai_client_with_defaults(
+                api_key=model_library_settings.TOGETHER_API_KEY,
+                base_url="https://api.together.xyz/v1",
+            ),
+            use_completions=True,
         )
 
-        if response.usage:
-            output.metadata.in_tokens = response.usage.prompt_tokens
-            output.metadata.out_tokens = response.usage.completion_tokens
-            # no cache tokens it seems
-        return output
-
     @override
     async def _calculate_cost(
         self,
model_library/register_models.py
CHANGED
@@ -27,9 +27,25 @@ You can set metadata configs that are not passed into the LLMConfig class here,
 """
 
 
+class Supports(BaseModel):
+    images: bool | None = None
+    videos: bool | None = None
+    files: bool | None = None
+    batch: bool | None = None
+    temperature: bool | None = None
+    tools: bool | None = None
+
+
+class Metadata(BaseModel):
+    deprecated: bool = False
+    available_for_everyone: bool = True
+    available_as_evaluator: bool = False
+    ignored_for_cost: bool = False
+
+
 class Properties(BaseModel):
     context_window: int | None = None
-
+    max_tokens: int | None = None
     training_cutoff: str | None = None
     reasoning_model: bool | None = None
 
@@ -118,33 +134,9 @@ class CostProperties(BaseModel):
     context: ContextCost | None = None
 
 
-class ClassProperties(BaseModel):
-    supports_images: bool | None = None
-    supports_videos: bool | None = None
-    supports_files: bool | None = None
-    supports_batch_requests: bool | None = None
-    supports_temperature: bool | None = None
-    supports_tools: bool | None = None
-    # vals specific
-    deprecated: bool = False
-    available_for_everyone: bool = True
-    available_as_evaluator: bool = False
-    ignored_for_cost: bool = False
-
-
-"""
-Each provider can have a set of provider-specific properties, we however want to accept
-any possible property from a provider in the yaml, and validate later. So we join all
-provider-specific properties into a single class.
-This has no effect on runtime use of ProviderConfig, only used to load the yaml
-"""
-
-
 class BaseProviderProperties(BaseModel):
     """Static base class for dynamic ProviderProperties."""
 
-    pass
-
 
 def all_subclasses(cls: type) -> list[type]:
     """Recursively find all subclasses of a class."""
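Taken together, these two hunks split the old flat `ClassProperties` into two nested models: the capability flags move to `Supports` and the vals-specific flags to `Metadata`. An illustrative mapping; the field correspondence is inferred from the matching names rather than stated in the diff:

```python
from model_library.register_models import Metadata, Supports

old_class_properties = {
    "supports_images": True,
    "supports_tools": True,
    "deprecated": False,
    "available_as_evaluator": True,
}

supports = Supports(
    images=old_class_properties["supports_images"],
    tools=old_class_properties["supports_tools"],
)
metadata = Metadata(
    deprecated=old_class_properties["deprecated"],
    available_as_evaluator=old_class_properties["available_as_evaluator"],
)
```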
@@ -174,9 +166,9 @@ def get_dynamic_provider_properties_model() -> type[BaseProviderProperties]:
 
 
 class DefaultParameters(BaseModel):
-    max_output_tokens: int | None = None
     temperature: float | None = None
     top_p: float | None = None
+    top_k: int | None = None
     reasoning_effort: str | None = None
 
 
@@ -188,26 +180,29 @@ class RawModelConfig(BaseModel):
     open_source: bool
     documentation_url: str | None = None
     properties: Properties = Field(default_factory=Properties)
-
-
+    supports: Supports
+    metadata: Metadata = Field(default_factory=Metadata)
+    provider_properties: BaseProviderProperties = Field(
+        default_factory=BaseProviderProperties
+    )
     costs_per_million_token: CostProperties = Field(default_factory=CostProperties)
     alternative_keys: list[str | dict[str, Any]] = Field(default_factory=list)
     default_parameters: DefaultParameters = Field(default_factory=DefaultParameters)
+    provider_endpoint: str | None = None
 
     def model_dump(self, *args: object, **kwargs: object):
         data = super().model_dump(*args, **kwargs)
-
-
-
-
-        )
+        # explicitly dump dynamic ProviderProperties instance
+        data["provider_properties"] = self.provider_properties.model_dump(
+            *args, **kwargs
+        )
         return data
 
 
 class ModelConfig(RawModelConfig):
     # post processing fields
+    provider_endpoint: str  # pyright: ignore[reportIncompatibleVariableOverride, reportGeneralTypeIssues]
     provider_name: str
-    provider_endpoint: str
     full_key: str
     slug: str
 
@@ -274,14 +269,17 @@ def _register_models() -> ModelRegistry:
                 current_model_config, model_config
             )
 
+            provider_properties = current_model_config.pop(
+                "provider_properties", {}
+            )
+
             # create model config object
             raw_model_obj: RawModelConfig = RawModelConfig.model_validate(
-                current_model_config, strict=True
+                current_model_config, strict=True, extra="forbid"
             )
 
             provider_endpoint = (
-
-                or model_name.split("/", 1)[1]
+                raw_model_obj.provider_endpoint or model_name.split("/", 1)[1]
             )
             # add provider metadata
             model_obj = ModelConfig.model_validate(
@@ -295,7 +293,7 @@ def _register_models() -> ModelRegistry:
             )
             # load provider properties separately since the model was generated at runtime
             model_obj.provider_properties = ProviderProperties.model_validate(
-
+                provider_properties
             )
 
             registry[model_name] = model_obj
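Registration now also lets a YAML entry pin `provider_endpoint` explicitly, falling back to everything after the first `/` of the model key. A worked example of that fallback rule; the model key is illustrative:

```python
model_name = "together/meta-llama/Llama-3.3-70B-Instruct-Turbo"  # illustrative key
explicit_endpoint = None  # what raw_model_obj.provider_endpoint holds when the YAML omits it

provider_endpoint = explicit_endpoint or model_name.split("/", 1)[1]
print(provider_endpoint)  # meta-llama/Llama-3.3-70B-Instruct-Turbo
```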