lm-deluge 0.0.15__py3-none-any.whl → 0.0.16__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
- lm_deluge/api_requests/__init__.py +0 -2
- lm_deluge/api_requests/anthropic.py +58 -84
- lm_deluge/api_requests/base.py +43 -229
- lm_deluge/api_requests/bedrock.py +173 -195
- lm_deluge/api_requests/gemini.py +18 -44
- lm_deluge/api_requests/mistral.py +30 -60
- lm_deluge/api_requests/openai.py +147 -148
- lm_deluge/api_requests/response.py +2 -1
- lm_deluge/batches.py +1 -1
- lm_deluge/{computer_use/anthropic_tools.py → built_in_tools/anthropic.py} +56 -5
- lm_deluge/built_in_tools/openai.py +28 -0
- lm_deluge/client.py +221 -150
- lm_deluge/image.py +13 -8
- lm_deluge/llm_tools/extract.py +23 -4
- lm_deluge/llm_tools/ocr.py +1 -0
- lm_deluge/models.py +39 -2
- lm_deluge/prompt.py +43 -27
- lm_deluge/request_context.py +75 -0
- lm_deluge/tool.py +93 -15
- lm_deluge/tracker.py +1 -0
- {lm_deluge-0.0.15.dist-info → lm_deluge-0.0.16.dist-info}/METADATA +25 -1
- {lm_deluge-0.0.15.dist-info → lm_deluge-0.0.16.dist-info}/RECORD +25 -22
- {lm_deluge-0.0.15.dist-info → lm_deluge-0.0.16.dist-info}/WHEEL +0 -0
- {lm_deluge-0.0.15.dist-info → lm_deluge-0.0.16.dist-info}/licenses/LICENSE +0 -0
- {lm_deluge-0.0.15.dist-info → lm_deluge-0.0.16.dist-info}/top_level.txt +0 -0
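
The headline change in 0.0.16 is the new `lm_deluge/request_context.py` module: the long keyword-argument lists that every `APIRequestBase` subclass previously took (`task_id`, `model_name`, `prompt`, `attempts_left`, `status_tracker`, `results_arr`, `request_timeout`, `sampling_params`, `callback`, `all_model_names`, `all_sampling_params`, `tools`, `cache`) are collapsed into a single `RequestContext` object. A hypothetical sketch of the shape, with field names inferred from the `self.context.*` accesses in the diffs below (the real dataclass may carry more state):

```python
# Hypothetical sketch of the RequestContext pattern introduced in 0.0.16.
# Field names are inferred from `self.context.*` accesses in the diffs below;
# the real lm_deluge.request_context.RequestContext may differ.
from dataclasses import dataclass
from typing import Any

@dataclass
class RequestContext:
    task_id: int
    model_name: str             # must correspond to the model registry
    prompt: Any                 # lm_deluge.prompt.Conversation
    sampling_params: Any        # lm_deluge.config.SamplingParams
    attempts_left: int = 5
    request_timeout: int = 30
    status_tracker: Any = None  # lm_deluge.tracker.StatusTracker
    tools: list | None = None
    cache: Any = None           # lm_deluge.prompt.CachePattern

# 0.0.15: GeminiRequest(task_id=0, model_name=..., prompt=..., attempts_left=5, status_tracker=..., ...)
# 0.0.16: GeminiRequest(context=RequestContext(task_id=0, model_name=..., prompt=..., sampling_params=...))
```

Each provider request class now reduces to `def __init__(self, context: RequestContext)`, as the bedrock and gemini diffs below show.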
lm_deluge/api_requests/bedrock.py
CHANGED

@@ -1,8 +1,8 @@
 import asyncio
 import json
 import os
+
 from aiohttp import ClientResponse
-from typing import Callable
 
 try:
     from requests_aws4auth import AWS4Auth
@@ -12,186 +12,178 @@ except ImportError:
 )
 
 from lm_deluge.prompt import (
+    CachePattern,
     Conversation,
     Message,
     Text,
-    ToolCall,
     Thinking,
-
+    ToolCall,
 )
+from lm_deluge.request_context import RequestContext
+from lm_deluge.tool import MCPServer, Tool
 from lm_deluge.usage import Usage
-from .base import APIRequestBase, APIResponse
 
-from ..tracker import StatusTracker
 from ..config import SamplingParams
 from ..models import APIModel
+from .base import APIRequestBase, APIResponse
 
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        )
-
-
-
-
-
-
-            results_arr=results_arr,
-            request_timeout=request_timeout,
-            sampling_params=sampling_params,
-            callback=callback,
-            all_model_names=all_model_names,
-            all_sampling_params=all_sampling_params,
-            tools=tools,
-            cache=cache,
+# according to bedrock docs the header is "anthropic_beta" vs. "anthropic-beta"
+# for anthropic. i don't know if this is a typo or the worst ever UX
+def _add_beta(headers: dict, beta: str):
+    if "anthropic_beta" in headers and headers["anthropic_beta"]:
+        if beta not in headers["anthropic_beta"]:
+            headers["anthropic_beta"] += f",{beta}"
+    else:
+        headers["anthropic_beta"] = beta
+
+
+def _build_anthropic_bedrock_request(
+    model: APIModel,
+    prompt: Conversation,
+    tools: list[Tool | dict | MCPServer] | None,
+    sampling_params: SamplingParams,
+    cache_pattern: CachePattern | None = None,
+):
+    system_message, messages = prompt.to_anthropic(cache_pattern=cache_pattern)
+
+    # handle AWS auth
+    access_key = os.getenv("AWS_ACCESS_KEY_ID")
+    secret_key = os.getenv("AWS_SECRET_ACCESS_KEY")
+    session_token = os.getenv("AWS_SESSION_TOKEN")
+
+    if not access_key or not secret_key:
+        raise ValueError(
+            "AWS credentials not found. Please set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables."
         )
 
-
-
-
+    # Determine region - use us-west-2 for cross-region inference models
+    if model.name.startswith("us.anthropic."):
+        # Cross-region inference profiles should use us-west-2
+        region = "us-west-2"
+    else:
+        raise ValueError("only cross-region inference for bedrock")
+    # # Direct model IDs can use default region
+    # region = getattr(model, "region", "us-east-1")
+    # if hasattr(model, "regions") and model.regions:
+    #     if isinstance(model.regions, list):
+    #         region = model.regions[0]
+    #     elif isinstance(model.regions, dict):
+    #         region = list(model.regions.keys())[0]
+
+    # Construct the endpoint URL
+    service = "bedrock"  # Service name for signing is 'bedrock' even though endpoint is bedrock-runtime
+    url = f"https://bedrock-runtime.{region}.amazonaws.com/model/{model.name}/invoke"
+
+    # Prepare headers
+    auth = AWS4Auth(
+        access_key,
+        secret_key,
+        region,
+        service,
+        session_token=session_token,
+    )
 
-
-
-
+    # Setup basic headers (AWS4Auth will add the Authorization header)
+    request_header = {
+        "Content-Type": "application/json",
+    }
+
+    # Prepare request body in Anthropic's bedrock format
+    request_json = {
+        "anthropic_version": "bedrock-2023-05-31",
+        "max_tokens": sampling_params.max_new_tokens,
+        "temperature": sampling_params.temperature,
+        "top_p": sampling_params.top_p,
+        "messages": messages,
+    }
+
+    if system_message is not None:
+        request_json["system"] = system_message
+
+    if tools:
+        mcp_servers = []
+        tool_definitions = []
+        for tool in tools:
+            if isinstance(tool, Tool):
+                tool_definitions.append(tool.dump_for("anthropic"))
+            elif isinstance(tool, dict):
+                tool_definitions.append(tool)
+                # add betas if needed
+                if tool["type"] in [
+                    "computer_20241022",
+                    "text_editor_20241022",
+                    "bash_20241022",
+                ]:
+                    _add_beta(request_header, "computer-use-2024-10-22")
+                elif tool["type"] == "computer_20250124":
+                    _add_beta(request_header, "computer-use-2025-01-24")
+                elif tool["type"] == "code_execution_20250522":
+                    _add_beta(request_header, "code-execution-2025-05-22")
+            elif isinstance(tool, MCPServer):
+                raise ValueError("bedrock doesn't support MCP connector right now")
+                # _add_beta(request_header, "mcp-client-2025-04-04")
+                # mcp_servers.append(tool.for_anthropic())
+
+        # Add cache control to last tool if tools_only caching is specified
+        if cache_pattern == "tools_only" and tool_definitions:
+            tool_definitions[-1]["cache_control"] = {"type": "ephemeral"}
+
+        request_json["tools"] = tool_definitions
+        if len(mcp_servers) > 0:
+            request_json["mcp_servers"] = mcp_servers
+
+    return request_json, request_header, auth, url
 
-        self.model = APIModel.from_registry(model_name)
 
-
-
-
+class BedrockRequest(APIRequestBase):
+    def __init__(self, context: RequestContext):
+        super().__init__(context=context)
+
+        self.model = APIModel.from_registry(self.context.model_name)
+        self.url = f"{self.model.api_base}/messages"
 
-
-
-
+        # Lock images as bytes if caching is enabled
+        if self.context.cache is not None:
+            self.context.prompt.lock_images_as_bytes()
+
+        self.request_json, self.request_header, self.auth, self.url = (
+            _build_anthropic_bedrock_request(
+                self.model,
+                context.prompt,
+                context.tools,
+                context.sampling_params,
+                context.cache,
            )
+        )
 
-
-
-
-            self.region = "us-west-2"
-        else:
-            # Direct model IDs can use default region
-            self.region = getattr(self.model, "region", "us-east-1")
-            if hasattr(self.model, "regions") and self.model.regions:
-                if isinstance(self.model.regions, list):
-                    self.region = self.model.regions[0]
-                elif isinstance(self.model.regions, dict):
-                    self.region = list(self.model.regions.keys())[0]
-
-        # Construct the endpoint URL
-        self.service = "bedrock"  # Service name for signing is 'bedrock' even though endpoint is bedrock-runtime
-        self.url = f"https://bedrock-runtime.{self.region}.amazonaws.com/model/{self.model.name}/invoke"
-
-        # Convert prompt to Anthropic format for bedrock
-        self.system_message, messages = prompt.to_anthropic(cache_pattern=cache)
-
-        # Prepare request body in Anthropic's bedrock format
-        self.request_json = {
-            "anthropic_version": "bedrock-2023-05-31",
-            "max_tokens": sampling_params.max_new_tokens,
-            "temperature": sampling_params.temperature,
-            "top_p": sampling_params.top_p,
-            "messages": messages,
-        }
-
-        if self.system_message is not None:
-            self.request_json["system"] = self.system_message
-
-        if tools or self.computer_use:
-            tool_definitions = []
-
-            # Add Computer Use tools at the beginning if enabled
-            if self.computer_use:
-                from ..computer_use.anthropic_tools import get_anthropic_cu_tools
-
-                cu_tools = get_anthropic_cu_tools(
-                    model=self.model.id,
-                    display_width=self.display_width,
-                    display_height=self.display_height,
-                )
-                tool_definitions.extend(cu_tools)
+    async def execute_once(self) -> APIResponse:
+        """Override execute_once to handle AWS4Auth signing."""
+        import aiohttp
 
-
-                self.request_json["computer_use_display_width_px"] = self.display_width
-                self.request_json["computer_use_display_height_px"] = (
-                    self.display_height
-                )
+        assert self.context.status_tracker
 
-
-
-            tool_definitions.extend([tool.dump_for("anthropic") for tool in tools])
+        self.context.status_tracker.total_requests += 1
+        timeout = aiohttp.ClientTimeout(total=self.context.request_timeout)
 
-
-
-                tool_definitions[-1]["cache_control"] = {"type": "ephemeral"}
+        # Prepare the request data
+        payload = json.dumps(self.request_json, separators=(",", ":")).encode("utf-8")
 
-
+        # Create a fake requests.PreparedRequest object for AWS4Auth to sign
+        import requests
 
-
-
-            self.
-
-            self.
-            self.service,
-            session_token=self.session_token,
+        fake_request = requests.Request(
+            method="POST",
+            url=self.url,
+            data=payload,
+            headers=self.request_header.copy(),
        )
 
-
-
-
-        }
+        prepared_request = fake_request.prepare()
+        signed_request = self.auth(prepared_request)
+        signed_headers = dict(signed_request.headers)
 
-    async def call_api(self):
-        """Override call_api to handle AWS4Auth signing."""
        try:
-            import aiohttp
-
-            self.status_tracker.total_requests += 1
-            timeout = aiohttp.ClientTimeout(total=self.request_timeout)
-
-            # Prepare the request data
-            payload = json.dumps(self.request_json, separators=(",", ":")).encode(
-                "utf-8"
-            )
-
-            # Create a fake requests.PreparedRequest object for AWS4Auth to sign
-            import requests
-
-            fake_request = requests.Request(
-                method="POST",
-                url=self.url,
-                data=payload,
-                headers=self.request_header.copy(),
-            )
-
-            # Prepare the request so AWS4Auth can sign it properly
-            prepared_request = fake_request.prepare()
-
-            # Let AWS4Auth sign the prepared request
-            signed_request = self.auth(prepared_request)
-
-            # Extract the signed headers
-            signed_headers = dict(signed_request.headers)
-
             async with aiohttp.ClientSession(timeout=timeout) as session:
                 async with session.post(
                     url=self.url,
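
One quirk worth flagging from the hunk above: on Bedrock the beta-feature header key is `anthropic_beta` (underscore) rather than the `anthropic-beta` (hyphen) used by Anthropic's first-party API, and `_add_beta` merges comma-separated values so that registering several built-in tools neither clobbers nor duplicates existing flags. A standalone copy of the helper with a usage sketch:

```python
# Copy of the _add_beta helper from the hunk above, plus a usage sketch.
def _add_beta(headers: dict, beta: str):
    # Bedrock expects "anthropic_beta" (underscore, not hyphen); append
    # comma-separated values instead of overwriting existing flags.
    if "anthropic_beta" in headers and headers["anthropic_beta"]:
        if beta not in headers["anthropic_beta"]:
            headers["anthropic_beta"] += f",{beta}"
    else:
        headers["anthropic_beta"] = beta

headers = {"Content-Type": "application/json"}
_add_beta(headers, "computer-use-2025-01-24")
_add_beta(headers, "code-execution-2025-05-22")
_add_beta(headers, "computer-use-2025-01-24")  # already present, no-op
assert headers["anthropic_beta"] == "computer-use-2025-01-24,code-execution-2025-05-22"
```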
@@ -199,51 +191,36 @@ class BedrockRequest(APIRequestBase):
                     data=payload,
                 ) as http_response:
                     response: APIResponse = await self.handle_response(http_response)
-
-                    self.result.append(response)
-                    if response.is_error:
-                        self.handle_error(
-                            create_new_request=response.retry_with_different_model or False,
-                            give_up_if_no_other_models=response.give_up_if_no_other_models
-                            or False,
-                        )
-                    else:
-                        self.handle_success(response)
+                    return response
 
         except asyncio.TimeoutError:
-
-
-
-
-
-
-
-
-
-
-                    usage=None,
-                )
+            return APIResponse(
+                id=self.context.task_id,
+                model_internal=self.context.model_name,
+                prompt=self.context.prompt,
+                sampling_params=self.context.sampling_params,
+                status_code=None,
+                is_error=True,
+                error_message="Request timed out (terminated by client).",
+                content=None,
+                usage=None,
             )
-            self.handle_error(create_new_request=False)
 
         except Exception as e:
             from ..errors import raise_if_modal_exception
 
             raise_if_modal_exception(e)
-
-
-
-
-
-
-
-
-
-
-                    usage=None,
-                )
+            return APIResponse(
+                id=self.context.task_id,
+                model_internal=self.context.model_name,
+                prompt=self.context.prompt,
+                sampling_params=self.context.sampling_params,
+                status_code=None,
+                is_error=True,
+                error_message=f"Unexpected {type(e).__name__}: {str(e) or 'No message.'}",
+                content=None,
+                usage=None,
             )
-            self.handle_error(create_new_request=False)
 
     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
         is_error = False
@@ -253,6 +230,7 @@ class BedrockRequest(APIRequestBase):
         usage = None
         status_code = http_response.status
         mimetype = http_response.headers.get("Content-Type", None)
+        assert self.context.status_tracker
 
         if status_code >= 200 and status_code < 300:
             try:
@@ -300,21 +278,21 @@ class BedrockRequest(APIRequestBase):
                 or status_code == 429
             ):
                 error_message += " (Rate limit error, triggering cooldown.)"
-                self.status_tracker.rate_limit_exceeded()
+                self.context.status_tracker.rate_limit_exceeded()
             if "context length" in error_message or "too long" in error_message:
                 error_message += " (Context length exceeded, set retries to 0.)"
-                self.attempts_left = 0
+                self.context.attempts_left = 0
 
         return APIResponse(
-            id=self.task_id,
+            id=self.context.task_id,
             status_code=status_code,
             is_error=is_error,
            error_message=error_message,
-            prompt=self.prompt,
+            prompt=self.context.prompt,
             content=content,
             thinking=thinking,
-            model_internal=self.model_name,
+            model_internal=self.context.model_name,
             region=self.region,
-            sampling_params=self.sampling_params,
+            sampling_params=self.context.sampling_params,
             usage=usage,
         )
lm_deluge/api_requests/gemini.py
CHANGED
@@ -1,16 +1,15 @@
 import json
 import os
 import warnings
-from typing import Callable
 
 from aiohttp import ClientResponse
 
+from lm_deluge.request_context import RequestContext
 from lm_deluge.tool import Tool
 
 from ..config import SamplingParams
 from ..models import APIModel
-from ..prompt import
-from ..tracker import StatusTracker
+from ..prompt import Conversation, Message, Text, Thinking, ToolCall
 from ..usage import Usage
 from .base import APIRequestBase, APIResponse
 
@@ -66,45 +65,16 @@ def _build_gemini_request(
 
 
 class GeminiRequest(APIRequestBase):
-    def __init__(
-        self,
-        task_id: int,
-        model_name: str,  # must correspond to registry
-        prompt: Conversation,
-        attempts_left: int,
-        status_tracker: StatusTracker,
-        results_arr: list,
-        request_timeout: int = 30,
-        sampling_params: SamplingParams = SamplingParams(),
-        callback: Callable | None = None,
-        all_model_names: list[str] | None = None,
-        all_sampling_params: list[SamplingParams] | None = None,
-        tools: list | None = None,
-        cache: CachePattern | None = None,
-    ):
-        super().__init__(
-            task_id=task_id,
-            model_name=model_name,
-            prompt=prompt,
-            attempts_left=attempts_left,
-            status_tracker=status_tracker,
-            results_arr=results_arr,
-            request_timeout=request_timeout,
-            sampling_params=sampling_params,
-            callback=callback,
-            all_model_names=all_model_names,
-            all_sampling_params=all_sampling_params,
-            tools=tools,
-            cache=cache,
-        )
+    def __init__(self, context: RequestContext):
+        super().__init__(context=context)
 
         # Warn if cache is specified for Gemini model
-        if cache is not None:
+        if self.context.cache is not None:
             warnings.warn(
-                f"Cache parameter '{cache}' is not supported for Gemini models, ignoring for {model_name}"
+                f"Cache parameter '{self.context.cache}' is not supported for Gemini models, ignoring for {self.context.model_name}"
             )
 
-        self.model = APIModel.from_registry(model_name)
+        self.model = APIModel.from_registry(self.context.model_name)
         # Gemini API endpoint format: https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent
         self.url = f"{self.model.api_base}/models/{self.model.name}:generateContent"
         self.request_header = {
@@ -120,7 +90,10 @@ class GeminiRequest(APIRequestBase):
             self.url += f"?key={api_key}"
 
         self.request_json = _build_gemini_request(
-            self.model,
+            self.model,
+            self.context.prompt,
+            self.context.tools,
+            self.context.sampling_params,
         )
 
     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
@@ -132,6 +105,7 @@ class GeminiRequest(APIRequestBase):
         status_code = http_response.status
         mimetype = http_response.headers.get("Content-Type", None)
         data = None
+        assert self.context.status_tracker
 
         if status_code >= 200 and status_code < 300:
             try:
@@ -199,24 +173,24 @@ class GeminiRequest(APIRequestBase):
         if is_error and error_message is not None:
             if "rate limit" in error_message.lower() or status_code == 429:
                 error_message += " (Rate limit error, triggering cooldown.)"
-                self.status_tracker.rate_limit_exceeded()
+                self.context.status_tracker.rate_limit_exceeded()
             if (
                 "context length" in error_message.lower()
                 or "token limit" in error_message.lower()
             ):
                 error_message += " (Context length exceeded, set retries to 0.)"
-                self.attempts_left = 0
+                self.context.attempts_left = 0
 
         return APIResponse(
-            id=self.task_id,
+            id=self.context.task_id,
             status_code=status_code,
             is_error=is_error,
             error_message=error_message,
-            prompt=self.prompt,
+            prompt=self.context.prompt,
             content=content,
             thinking=thinking,
-            model_internal=self.model_name,
-            sampling_params=self.sampling_params,
+            model_internal=self.context.model_name,
+            sampling_params=self.context.sampling_params,
             usage=usage,
             raw_response=data,
         )