lm-deluge 0.0.14__py3-none-any.whl → 0.0.16__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, exactly as they appear in their public registry. It is provided for informational purposes only.
- lm_deluge/api_requests/__init__.py +0 -2
- lm_deluge/api_requests/anthropic.py +58 -84
- lm_deluge/api_requests/base.py +43 -229
- lm_deluge/api_requests/bedrock.py +173 -195
- lm_deluge/api_requests/common.py +2 -0
- lm_deluge/api_requests/gemini.py +196 -0
- lm_deluge/api_requests/mistral.py +30 -60
- lm_deluge/api_requests/openai.py +147 -148
- lm_deluge/api_requests/response.py +2 -1
- lm_deluge/batches.py +1 -1
- lm_deluge/{computer_use/anthropic_tools.py → built_in_tools/anthropic.py} +56 -5
- lm_deluge/built_in_tools/openai.py +28 -0
- lm_deluge/client.py +221 -150
- lm_deluge/file.py +7 -2
- lm_deluge/image.py +13 -8
- lm_deluge/llm_tools/extract.py +23 -4
- lm_deluge/llm_tools/ocr.py +1 -0
- lm_deluge/models.py +96 -2
- lm_deluge/prompt.py +43 -27
- lm_deluge/request_context.py +75 -0
- lm_deluge/tool.py +93 -15
- lm_deluge/tracker.py +1 -0
- lm_deluge/usage.py +10 -0
- {lm_deluge-0.0.14.dist-info → lm_deluge-0.0.16.dist-info}/METADATA +25 -1
- lm_deluge-0.0.16.dist-info/RECORD +48 -0
- lm_deluge-0.0.14.dist-info/RECORD +0 -44
- {lm_deluge-0.0.14.dist-info → lm_deluge-0.0.16.dist-info}/WHEEL +0 -0
- {lm_deluge-0.0.14.dist-info → lm_deluge-0.0.16.dist-info}/licenses/LICENSE +0 -0
- {lm_deluge-0.0.14.dist-info → lm_deluge-0.0.16.dist-info}/top_level.txt +0 -0
lm_deluge/api_requests/bedrock.py CHANGED

@@ -1,8 +1,8 @@
 import asyncio
 import json
 import os
+
 from aiohttp import ClientResponse
-from typing import Callable

 try:
     from requests_aws4auth import AWS4Auth
@@ -12,186 +12,178 @@ except ImportError:
     )

 from lm_deluge.prompt import (
+    CachePattern,
     Conversation,
     Message,
     Text,
-    ToolCall,
     Thinking,
-
+    ToolCall,
 )
+from lm_deluge.request_context import RequestContext
+from lm_deluge.tool import MCPServer, Tool
 from lm_deluge.usage import Usage
-from .base import APIRequestBase, APIResponse

-from ..tracker import StatusTracker
 from ..config import SamplingParams
 from ..models import APIModel
+from .base import APIRequestBase, APIResponse


-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        )
-
-
-
-
-
-
-            results_arr=results_arr,
-            request_timeout=request_timeout,
-            sampling_params=sampling_params,
-            callback=callback,
-            all_model_names=all_model_names,
-            all_sampling_params=all_sampling_params,
-            tools=tools,
-            cache=cache,
+# according to bedrock docs the header is "anthropic_beta" vs. "anthropic-beta"
+# for anthropic. i don't know if this is a typo or the worst ever UX
+def _add_beta(headers: dict, beta: str):
+    if "anthropic_beta" in headers and headers["anthropic_beta"]:
+        if beta not in headers["anthropic_beta"]:
+            headers["anthropic_beta"] += f",{beta}"
+    else:
+        headers["anthropic_beta"] = beta
+
+
+def _build_anthropic_bedrock_request(
+    model: APIModel,
+    prompt: Conversation,
+    tools: list[Tool | dict | MCPServer] | None,
+    sampling_params: SamplingParams,
+    cache_pattern: CachePattern | None = None,
+):
+    system_message, messages = prompt.to_anthropic(cache_pattern=cache_pattern)
+
+    # handle AWS auth
+    access_key = os.getenv("AWS_ACCESS_KEY_ID")
+    secret_key = os.getenv("AWS_SECRET_ACCESS_KEY")
+    session_token = os.getenv("AWS_SESSION_TOKEN")
+
+    if not access_key or not secret_key:
+        raise ValueError(
+            "AWS credentials not found. Please set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables."
         )

-
-
-
+    # Determine region - use us-west-2 for cross-region inference models
+    if model.name.startswith("us.anthropic."):
+        # Cross-region inference profiles should use us-west-2
+        region = "us-west-2"
+    else:
+        raise ValueError("only cross-region inference for bedrock")
+        # # Direct model IDs can use default region
+        # region = getattr(model, "region", "us-east-1")
+        # if hasattr(model, "regions") and model.regions:
+        #     if isinstance(model.regions, list):
+        #         region = model.regions[0]
+        #     elif isinstance(model.regions, dict):
+        #         region = list(model.regions.keys())[0]
+
+    # Construct the endpoint URL
+    service = "bedrock"  # Service name for signing is 'bedrock' even though endpoint is bedrock-runtime
+    url = f"https://bedrock-runtime.{region}.amazonaws.com/model/{model.name}/invoke"
+
+    # Prepare headers
+    auth = AWS4Auth(
+        access_key,
+        secret_key,
+        region,
+        service,
+        session_token=session_token,
+    )

-
-
-
+    # Setup basic headers (AWS4Auth will add the Authorization header)
+    request_header = {
+        "Content-Type": "application/json",
+    }
+
+    # Prepare request body in Anthropic's bedrock format
+    request_json = {
+        "anthropic_version": "bedrock-2023-05-31",
+        "max_tokens": sampling_params.max_new_tokens,
+        "temperature": sampling_params.temperature,
+        "top_p": sampling_params.top_p,
+        "messages": messages,
+    }
+
+    if system_message is not None:
+        request_json["system"] = system_message
+
+    if tools:
+        mcp_servers = []
+        tool_definitions = []
+        for tool in tools:
+            if isinstance(tool, Tool):
+                tool_definitions.append(tool.dump_for("anthropic"))
+            elif isinstance(tool, dict):
+                tool_definitions.append(tool)
+                # add betas if needed
+                if tool["type"] in [
+                    "computer_20241022",
+                    "text_editor_20241022",
+                    "bash_20241022",
+                ]:
+                    _add_beta(request_header, "computer-use-2024-10-22")
+                elif tool["type"] == "computer_20250124":
+                    _add_beta(request_header, "computer-use-2025-01-24")
+                elif tool["type"] == "code_execution_20250522":
+                    _add_beta(request_header, "code-execution-2025-05-22")
+            elif isinstance(tool, MCPServer):
+                raise ValueError("bedrock doesn't support MCP connector right now")
+                # _add_beta(request_header, "mcp-client-2025-04-04")
+                # mcp_servers.append(tool.for_anthropic())
+
+        # Add cache control to last tool if tools_only caching is specified
+        if cache_pattern == "tools_only" and tool_definitions:
+            tool_definitions[-1]["cache_control"] = {"type": "ephemeral"}
+
+        request_json["tools"] = tool_definitions
+        if len(mcp_servers) > 0:
+            request_json["mcp_servers"] = mcp_servers
+
+    return request_json, request_header, auth, url

-        self.model = APIModel.from_registry(model_name)

-
-
-
-
+class BedrockRequest(APIRequestBase):
+    def __init__(self, context: RequestContext):
+        super().__init__(context=context)
+
+        self.model = APIModel.from_registry(self.context.model_name)
+        self.url = f"{self.model.api_base}/messages"

-
-
-
+        # Lock images as bytes if caching is enabled
+        if self.context.cache is not None:
+            self.context.prompt.lock_images_as_bytes()
+
+        self.request_json, self.request_header, self.auth, self.url = (
+            _build_anthropic_bedrock_request(
+                self.model,
+                context.prompt,
+                context.tools,
+                context.sampling_params,
+                context.cache,
             )
+        )

-
-
-
-            self.region = "us-west-2"
-        else:
-            # Direct model IDs can use default region
-            self.region = getattr(self.model, "region", "us-east-1")
-            if hasattr(self.model, "regions") and self.model.regions:
-                if isinstance(self.model.regions, list):
-                    self.region = self.model.regions[0]
-                elif isinstance(self.model.regions, dict):
-                    self.region = list(self.model.regions.keys())[0]
-
-        # Construct the endpoint URL
-        self.service = "bedrock"  # Service name for signing is 'bedrock' even though endpoint is bedrock-runtime
-        self.url = f"https://bedrock-runtime.{self.region}.amazonaws.com/model/{self.model.name}/invoke"
-
-        # Convert prompt to Anthropic format for bedrock
-        self.system_message, messages = prompt.to_anthropic(cache_pattern=cache)
-
-        # Prepare request body in Anthropic's bedrock format
-        self.request_json = {
-            "anthropic_version": "bedrock-2023-05-31",
-            "max_tokens": sampling_params.max_new_tokens,
-            "temperature": sampling_params.temperature,
-            "top_p": sampling_params.top_p,
-            "messages": messages,
-        }
-
-        if self.system_message is not None:
-            self.request_json["system"] = self.system_message
-
-        if tools or self.computer_use:
-            tool_definitions = []
-
-            # Add Computer Use tools at the beginning if enabled
-            if self.computer_use:
-                from ..computer_use.anthropic_tools import get_anthropic_cu_tools
-
-                cu_tools = get_anthropic_cu_tools(
-                    model=self.model.id,
-                    display_width=self.display_width,
-                    display_height=self.display_height,
-                )
-                tool_definitions.extend(cu_tools)
+    async def execute_once(self) -> APIResponse:
+        """Override execute_once to handle AWS4Auth signing."""
+        import aiohttp

-
-                self.request_json["computer_use_display_width_px"] = self.display_width
-                self.request_json["computer_use_display_height_px"] = (
-                    self.display_height
-                )
+        assert self.context.status_tracker

-
-
-            tool_definitions.extend([tool.dump_for("anthropic") for tool in tools])
+        self.context.status_tracker.total_requests += 1
+        timeout = aiohttp.ClientTimeout(total=self.context.request_timeout)

-
-
-                tool_definitions[-1]["cache_control"] = {"type": "ephemeral"}
+        # Prepare the request data
+        payload = json.dumps(self.request_json, separators=(",", ":")).encode("utf-8")

-
+        # Create a fake requests.PreparedRequest object for AWS4Auth to sign
+        import requests

-
-
-            self.
-
-            self.
-            self.service,
-            session_token=self.session_token,
+        fake_request = requests.Request(
+            method="POST",
+            url=self.url,
+            data=payload,
+            headers=self.request_header.copy(),
         )

-
-
-
-        }
+        prepared_request = fake_request.prepare()
+        signed_request = self.auth(prepared_request)
+        signed_headers = dict(signed_request.headers)

-    async def call_api(self):
-        """Override call_api to handle AWS4Auth signing."""
         try:
-            import aiohttp
-
-            self.status_tracker.total_requests += 1
-            timeout = aiohttp.ClientTimeout(total=self.request_timeout)
-
-            # Prepare the request data
-            payload = json.dumps(self.request_json, separators=(",", ":")).encode(
-                "utf-8"
-            )
-
-            # Create a fake requests.PreparedRequest object for AWS4Auth to sign
-            import requests
-
-            fake_request = requests.Request(
-                method="POST",
-                url=self.url,
-                data=payload,
-                headers=self.request_header.copy(),
-            )
-
-            # Prepare the request so AWS4Auth can sign it properly
-            prepared_request = fake_request.prepare()
-
-            # Let AWS4Auth sign the prepared request
-            signed_request = self.auth(prepared_request)
-
-            # Extract the signed headers
-            signed_headers = dict(signed_request.headers)
-
             async with aiohttp.ClientSession(timeout=timeout) as session:
                 async with session.post(
                     url=self.url,
@@ -199,51 +191,36 @@ class BedrockRequest(APIRequestBase):
                     data=payload,
                 ) as http_response:
                     response: APIResponse = await self.handle_response(http_response)
-
-            self.result.append(response)
-            if response.is_error:
-                self.handle_error(
-                    create_new_request=response.retry_with_different_model or False,
-                    give_up_if_no_other_models=response.give_up_if_no_other_models
-                    or False,
-                )
-            else:
-                self.handle_success(response)
+                    return response

         except asyncio.TimeoutError:
-
-
-
-
-
-
-
-
-
-
-                usage=None,
-            )
+            return APIResponse(
+                id=self.context.task_id,
+                model_internal=self.context.model_name,
+                prompt=self.context.prompt,
+                sampling_params=self.context.sampling_params,
+                status_code=None,
+                is_error=True,
+                error_message="Request timed out (terminated by client).",
+                content=None,
+                usage=None,
             )
-            self.handle_error(create_new_request=False)

         except Exception as e:
             from ..errors import raise_if_modal_exception

             raise_if_modal_exception(e)
-
-
-
-
-
-
-
-
-
-
-                usage=None,
-            )
+            return APIResponse(
+                id=self.context.task_id,
+                model_internal=self.context.model_name,
+                prompt=self.context.prompt,
+                sampling_params=self.context.sampling_params,
+                status_code=None,
+                is_error=True,
+                error_message=f"Unexpected {type(e).__name__}: {str(e) or 'No message.'}",
+                content=None,
+                usage=None,
             )
-            self.handle_error(create_new_request=False)

     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
         is_error = False
@@ -253,6 +230,7 @@ class BedrockRequest(APIRequestBase):
         usage = None
         status_code = http_response.status
         mimetype = http_response.headers.get("Content-Type", None)
+        assert self.context.status_tracker

         if status_code >= 200 and status_code < 300:
             try:
@@ -300,21 +278,21 @@ class BedrockRequest(APIRequestBase):
                 or status_code == 429
             ):
                 error_message += " (Rate limit error, triggering cooldown.)"
-                self.status_tracker.rate_limit_exceeded()
+                self.context.status_tracker.rate_limit_exceeded()
             if "context length" in error_message or "too long" in error_message:
                 error_message += " (Context length exceeded, set retries to 0.)"
-                self.attempts_left = 0
+                self.context.attempts_left = 0

         return APIResponse(
-            id=self.task_id,
+            id=self.context.task_id,
             status_code=status_code,
             is_error=is_error,
             error_message=error_message,
-            prompt=self.prompt,
+            prompt=self.context.prompt,
             content=content,
             thinking=thinking,
-            model_internal=self.model_name,
+            model_internal=self.context.model_name,
             region=self.region,
-            sampling_params=self.sampling_params,
+            sampling_params=self.context.sampling_params,
             usage=usage,
         )
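The notable pattern in the rewritten BedrockRequest.execute_once above is the signing flow: requests_aws4auth only knows how to sign a requests.PreparedRequest, so the code signs a throwaway request and copies the resulting SigV4 headers onto the real aiohttp call. Below is a minimal, self-contained sketch of that same pattern outside the library; the function name, payload shape, and default region are illustrative assumptions, not part of lm-deluge.

import json
import os

import aiohttp
import requests
from requests_aws4auth import AWS4Auth


async def invoke_bedrock(body: dict, model_id: str, region: str = "us-west-2") -> dict:
    # The SigV4 signing service is "bedrock" even though the host is bedrock-runtime.
    url = f"https://bedrock-runtime.{region}.amazonaws.com/model/{model_id}/invoke"
    payload = json.dumps(body, separators=(",", ":")).encode("utf-8")
    auth = AWS4Auth(
        os.environ["AWS_ACCESS_KEY_ID"],
        os.environ["AWS_SECRET_ACCESS_KEY"],
        region,
        "bedrock",
        session_token=os.getenv("AWS_SESSION_TOKEN"),
    )
    # Sign a throwaway requests.Request to obtain the Authorization headers...
    fake = requests.Request(
        method="POST",
        url=url,
        data=payload,
        headers={"Content-Type": "application/json"},
    ).prepare()
    signed_headers = dict(auth(fake).headers)
    # ...then send the exact same bytes with those headers over aiohttp,
    # since the signature covers a hash of the body.
    async with aiohttp.ClientSession() as session:
        async with session.post(url, headers=signed_headers, data=payload) as resp:
            return await resp.json()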
lm_deluge/api_requests/common.py CHANGED

@@ -2,6 +2,7 @@ from .openai import OpenAIRequest, OpenAIResponsesRequest
 from .anthropic import AnthropicRequest
 from .mistral import MistralRequest
 from .bedrock import BedrockRequest
+from .gemini import GeminiRequest

 CLASSES = {
     "openai": OpenAIRequest,
@@ -9,4 +10,5 @@ CLASSES = {
     "anthropic": AnthropicRequest,
     "mistral": MistralRequest,
     "bedrock": BedrockRequest,
+    "gemini": GeminiRequest,
 }
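common.py only registers the new GeminiRequest class under the "gemini" key. A hedged sketch of how a registry like CLASSES is typically consumed when building a request follows; the api_spec attribute on APIModel is an assumption inferred from the registry keys and is not shown in this diff.

from lm_deluge.api_requests.common import CLASSES
from lm_deluge.models import APIModel


def request_class_for(model_name: str):
    # Hypothetical dispatch: map the model's API spec (e.g. "gemini") to its
    # request class (e.g. GeminiRequest), assuming APIModel exposes `api_spec`.
    model = APIModel.from_registry(model_name)
    return CLASSES[model.api_spec]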
lm_deluge/api_requests/gemini.py ADDED

@@ -0,0 +1,196 @@
+import json
+import os
+import warnings
+
+from aiohttp import ClientResponse
+
+from lm_deluge.request_context import RequestContext
+from lm_deluge.tool import Tool
+
+from ..config import SamplingParams
+from ..models import APIModel
+from ..prompt import Conversation, Message, Text, Thinking, ToolCall
+from ..usage import Usage
+from .base import APIRequestBase, APIResponse
+
+
+def _build_gemini_request(
+    model: APIModel,
+    prompt: Conversation,
+    tools: list[Tool] | None,
+    sampling_params: SamplingParams,
+) -> dict:
+    system_message, messages = prompt.to_gemini()
+
+    request_json = {
+        "contents": messages,
+        "generationConfig": {
+            "temperature": sampling_params.temperature,
+            "topP": sampling_params.top_p,
+            "maxOutputTokens": sampling_params.max_new_tokens,
+        },
+    }
+
+    # Add system instruction if present
+    if system_message:
+        request_json["systemInstruction"] = {"parts": [{"text": system_message}]}
+
+    # Handle reasoning models (thinking)
+    if model.reasoning_model:
+        request_json["generationConfig"]["thinkingConfig"] = {"includeThoughts": True}
+        if sampling_params.reasoning_effort and "flash" in model.id:
+            budget = {"low": 1024, "medium": 4096, "high": 16384}.get(
+                sampling_params.reasoning_effort
+            )
+            request_json["generationConfig"]["thinkingConfig"]["thinkingBudget"] = (
+                budget
+            )
+
+    else:
+        if sampling_params.reasoning_effort:
+            warnings.warn(
+                f"Ignoring reasoning_effort param for non-reasoning model: {model.name}"
+            )
+
+    # Add tools if provided
+    if tools:
+        tool_declarations = [tool.dump_for("google") for tool in tools]
+        request_json["tools"] = [{"functionDeclarations": tool_declarations}]
+
+    # Handle JSON mode
+    if sampling_params.json_mode and model.supports_json:
+        request_json["generationConfig"]["responseMimeType"] = "application/json"
+
+    return request_json
+
+
+class GeminiRequest(APIRequestBase):
+    def __init__(self, context: RequestContext):
+        super().__init__(context=context)
+
+        # Warn if cache is specified for Gemini model
+        if self.context.cache is not None:
+            warnings.warn(
+                f"Cache parameter '{self.context.cache}' is not supported for Gemini models, ignoring for {self.context.model_name}"
+            )
+
+        self.model = APIModel.from_registry(self.context.model_name)
+        # Gemini API endpoint format: https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent
+        self.url = f"{self.model.api_base}/models/{self.model.name}:generateContent"
+        self.request_header = {
+            "Content-Type": "application/json",
+        }
+
+        # Add API key as query parameter for Gemini
+        api_key = os.getenv(self.model.api_key_env_var)
+        if not api_key:
+            raise ValueError(
+                f"API key environment variable {self.model.api_key_env_var} not set"
+            )
+        self.url += f"?key={api_key}"
+
+        self.request_json = _build_gemini_request(
+            self.model,
+            self.context.prompt,
+            self.context.tools,
+            self.context.sampling_params,
+        )
+
+    async def handle_response(self, http_response: ClientResponse) -> APIResponse:
+        is_error = False
+        error_message = None
+        thinking = None
+        content = None
+        usage = None
+        status_code = http_response.status
+        mimetype = http_response.headers.get("Content-Type", None)
+        data = None
+        assert self.context.status_tracker
+
+        if status_code >= 200 and status_code < 300:
+            try:
+                data = await http_response.json()
+            except Exception as e:
+                is_error = True
+                error_message = (
+                    f"Error calling .json() on response w/ status {status_code}: {e}"
+                )
+
+            if not is_error:
+                assert data
+                try:
+                    # Parse Gemini response format
+                    parts = []
+
+                    if "candidates" in data and data["candidates"]:
+                        candidate = data["candidates"][0]
+                        if "content" in candidate and "parts" in candidate["content"]:
+                            for part in candidate["content"]["parts"]:
+                                if "text" in part:
+                                    parts.append(Text(part["text"]))
+                                elif "thought" in part:
+                                    parts.append(Thinking(part["thought"]))
+                                elif "functionCall" in part:
+                                    func_call = part["functionCall"]
+                                    # Generate a unique ID since Gemini doesn't provide one
+                                    import uuid
+
+                                    tool_id = f"call_{uuid.uuid4().hex[:8]}"
+                                    parts.append(
+                                        ToolCall(
+                                            id=tool_id,
+                                            name=func_call["name"],
+                                            arguments=func_call.get("args", {}),
+                                        )
+                                    )
+
+                    content = Message("assistant", parts)
+
+                    # Extract usage information if present
+                    if "usageMetadata" in data:
+                        usage_data = data["usageMetadata"]
+                        usage = Usage.from_gemini_usage(usage_data)
+
+                except Exception as e:
+                    is_error = True
+                    error_message = f"Error parsing Gemini response: {str(e)}"
+
+        elif mimetype and "json" in mimetype.lower():
+            is_error = True
+            try:
+                data = await http_response.json()
+                error_message = json.dumps(data)
+            except Exception:
+                error_message = (
+                    f"HTTP {status_code} with JSON content type but failed to parse"
+                )
+        else:
+            is_error = True
+            text = await http_response.text()
+            error_message = text
+
+        # Handle special kinds of errors
+        if is_error and error_message is not None:
+            if "rate limit" in error_message.lower() or status_code == 429:
+                error_message += " (Rate limit error, triggering cooldown.)"
+                self.context.status_tracker.rate_limit_exceeded()
+            if (
+                "context length" in error_message.lower()
+                or "token limit" in error_message.lower()
+            ):
+                error_message += " (Context length exceeded, set retries to 0.)"
+                self.context.attempts_left = 0
+
+        return APIResponse(
+            id=self.context.task_id,
+            status_code=status_code,
+            is_error=is_error,
+            error_message=error_message,
+            prompt=self.context.prompt,
+            content=content,
+            thinking=thinking,
+            model_internal=self.context.model_name,
+            sampling_params=self.context.sampling_params,
+            usage=usage,
+            raw_response=data,
+        )