lm-deluge 0.0.14__py3-none-any.whl → 0.0.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lm_deluge/api_requests/__init__.py +0 -2
- lm_deluge/api_requests/anthropic.py +58 -84
- lm_deluge/api_requests/base.py +43 -229
- lm_deluge/api_requests/bedrock.py +173 -195
- lm_deluge/api_requests/common.py +2 -0
- lm_deluge/api_requests/gemini.py +196 -0
- lm_deluge/api_requests/mistral.py +30 -60
- lm_deluge/api_requests/openai.py +147 -148
- lm_deluge/api_requests/response.py +2 -1
- lm_deluge/batches.py +1 -1
- lm_deluge/{computer_use/anthropic_tools.py → built_in_tools/anthropic.py} +56 -5
- lm_deluge/built_in_tools/openai.py +28 -0
- lm_deluge/client.py +221 -150
- lm_deluge/file.py +7 -2
- lm_deluge/image.py +13 -8
- lm_deluge/llm_tools/extract.py +23 -4
- lm_deluge/llm_tools/ocr.py +1 -0
- lm_deluge/models.py +96 -2
- lm_deluge/prompt.py +43 -27
- lm_deluge/request_context.py +75 -0
- lm_deluge/tool.py +93 -15
- lm_deluge/tracker.py +1 -0
- lm_deluge/usage.py +10 -0
- {lm_deluge-0.0.14.dist-info → lm_deluge-0.0.16.dist-info}/METADATA +25 -1
- lm_deluge-0.0.16.dist-info/RECORD +48 -0
- lm_deluge-0.0.14.dist-info/RECORD +0 -44
- {lm_deluge-0.0.14.dist-info → lm_deluge-0.0.16.dist-info}/WHEEL +0 -0
- {lm_deluge-0.0.14.dist-info → lm_deluge-0.0.16.dist-info}/licenses/LICENSE +0 -0
- {lm_deluge-0.0.14.dist-info → lm_deluge-0.0.16.dist-info}/top_level.txt +0 -0
lm_deluge/api_requests/anthropic.py CHANGED

@@ -1,35 +1,39 @@
-from aiohttp import ClientResponse
 import json
 import os
-
+
+from aiohttp import ClientResponse
 
 from lm_deluge.prompt import (
+    CachePattern,
     Conversation,
     Message,
     Text,
-    ToolCall,
     Thinking,
-
+    ToolCall,
 )
-from lm_deluge.
+from lm_deluge.request_context import RequestContext
+from lm_deluge.tool import MCPServer, Tool
 from lm_deluge.usage import Usage
-from .base import APIRequestBase, APIResponse
 
-from ..tracker import StatusTracker
 from ..config import SamplingParams
 from ..models import APIModel
-from 
+from .base import APIRequestBase, APIResponse
+
+
+def _add_beta(headers: dict, beta: str):
+    if "anthropic-beta" in headers and headers["anthropic-beta"]:
+        if beta not in headers["anthropic-beta"]:
+            headers["anthropic-beta"] += f",{beta}"
+    else:
+        headers["anthropic-beta"] = beta
 
 
 def _build_anthropic_request(
     model: APIModel,
     prompt: Conversation,
-    tools: list[Tool] | None,
+    tools: list[Tool | dict | MCPServer] | None,
     sampling_params: SamplingParams,
     cache_pattern: CachePattern | None = None,
-    computer_use: bool = False,
-    display_width: int = 1024,
-    display_height: int = 768,
 ):
     system_message, messages = prompt.to_anthropic(cache_pattern=cache_pattern)
     request_header = {
@@ -38,10 +42,6 @@ def _build_anthropic_request(
         "content-type": "application/json",
     }
 
-    # Add beta header for Computer Use
-    if computer_use:
-        request_header["anthropic-beta"] = "computer-use-2025-01-24"
-
     request_json = {
         "model": model.name,
         "messages": messages,
@@ -69,89 +69,61 @@ def _build_anthropic_request(
         print("ignoring reasoning_effort for non-reasoning model")
     if system_message is not None:
         request_json["system"] = system_message
-    if tools
+    if tools:
+        mcp_servers = []
         tool_definitions = []
-
-
-
-
-
-
-
-
-
-
+        for tool in tools:
+            if isinstance(tool, Tool):
+                tool_definitions.append(tool.dump_for("anthropic"))
+            elif isinstance(tool, dict):
+                tool_definitions.append(tool)
+                # add betas if needed
+                if tool["type"] in [
+                    "computer_20241022",
+                    "text_editor_20241022",
+                    "bash_20241022",
+                ]:
+                    _add_beta(request_header, "computer-use-2024-10-22")
+                elif tool["type"] == "computer_20250124":
+                    _add_beta(request_header, "computer-use-2025-01-24")
+                elif tool["type"] == "code_execution_20250522":
+                    _add_beta(request_header, "code-execution-2025-05-22")
+            elif isinstance(tool, MCPServer):
+                _add_beta(request_header, "mcp-client-2025-04-04")
+                mcp_servers.append(tool.for_anthropic())
 
         # Add cache control to last tool if tools_only caching is specified
         if cache_pattern == "tools_only" and tool_definitions:
             tool_definitions[-1]["cache_control"] = {"type": "ephemeral"}
 
         request_json["tools"] = tool_definitions
+        if len(mcp_servers) > 0:
+            request_json["mcp_servers"] = mcp_servers
 
     return request_json, request_header
 
 
 class AnthropicRequest(APIRequestBase):
-    def __init__(
-
-
-
-        # internal logic should handle translating to specific API format
-        model_name: str,  # must correspond to registry
-        prompt: Conversation,
-        attempts_left: int,
-        status_tracker: StatusTracker,
-        results_arr: list,
-        request_timeout: int = 30,
-        sampling_params: SamplingParams = SamplingParams(),
-        callback: Callable | None = None,
-        # for retries
-        all_model_names: list[str] | None = None,
-        all_sampling_params: list[SamplingParams] | None = None,
-        tools: list | None = None,
-        cache: CachePattern | None = None,
-        # Computer Use support
-        computer_use: bool = False,
-        display_width: int = 1024,
-        display_height: int = 768,
-    ):
-        super().__init__(
-            task_id=task_id,
-            model_name=model_name,
-            prompt=prompt,
-            attempts_left=attempts_left,
-            status_tracker=status_tracker,
-            results_arr=results_arr,
-            request_timeout=request_timeout,
-            sampling_params=sampling_params,
-            callback=callback,
-            all_model_names=all_model_names,
-            all_sampling_params=all_sampling_params,
-            tools=tools,
-            cache=cache,
-        )
-        self.computer_use = computer_use
-        self.display_width = display_width
-        self.display_height = display_height
-        self.model = APIModel.from_registry(model_name)
+    def __init__(self, context: RequestContext):
+        super().__init__(context=context)
+
+        self.model = APIModel.from_registry(self.context.model_name)
         self.url = f"{self.model.api_base}/messages"
 
         # Lock images as bytes if caching is enabled
-        if cache is not None:
-            prompt.lock_images_as_bytes()
+        if self.context.cache is not None:
+            self.context.prompt.lock_images_as_bytes()
 
         self.request_json, self.request_header = _build_anthropic_request(
             self.model,
-            prompt,
-            tools,
-            sampling_params,
-            cache,
-            computer_use,
-            display_width,
-            display_height,
+            self.context.prompt,
+            self.context.tools,
+            self.context.sampling_params,
+            self.context.cache,
         )
 
     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
+        data = None
         is_error = False
         error_message = None
         thinking = None
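The rewritten tool handling above accepts `Tool` objects, raw dicts for Anthropic's built-in tools, and `MCPServer` entries, and merges the required beta flags into a single comma-separated `anthropic-beta` header via the new `_add_beta` helper. A small self-contained sketch of that header accumulation (the tool dict values here are illustrative, not taken from the package):

```python
# Mirrors the _add_beta helper added above; headers is the request header dict.
def _add_beta(headers: dict, beta: str) -> None:
    if "anthropic-beta" in headers and headers["anthropic-beta"]:
        if beta not in headers["anthropic-beta"]:
            headers["anthropic-beta"] += f",{beta}"
    else:
        headers["anthropic-beta"] = beta


headers = {"content-type": "application/json"}

# A built-in computer-use tool passed as a raw dict (illustrative shape).
tool = {"type": "computer_20250124", "name": "computer"}
if tool["type"] == "computer_20250124":
    _add_beta(headers, "computer-use-2025-01-24")

# An MCP server entry contributes its own flag the same way.
_add_beta(headers, "mcp-client-2025-04-04")

print(headers["anthropic-beta"])
# -> computer-use-2025-01-24,mcp-client-2025-04-04
```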
@@ -160,6 +132,7 @@ class AnthropicRequest(APIRequestBase):
         status_code = http_response.status
         mimetype = http_response.headers.get("Content-Type", None)
         rate_limits = {}
+        assert self.context.status_tracker
         for header in [
             "anthropic-ratelimit-requests-limit",
             "anthropic-ratelimit-requests-remaining",
@@ -215,20 +188,21 @@ class AnthropicRequest(APIRequestBase):
                 or "overloaded" in error_message.lower()
             ):
                 error_message += " (Rate limit error, triggering cooldown.)"
-                self.status_tracker.rate_limit_exceeded()
+                self.context.status_tracker.rate_limit_exceeded()
             if "context length" in error_message:
                 error_message += " (Context length exceeded, set retries to 0.)"
-                self.attempts_left = 0
+                self.context.attempts_left = 0
 
         return APIResponse(
-            id=self.task_id,
+            id=self.context.task_id,
            status_code=status_code,
            is_error=is_error,
            error_message=error_message,
-            prompt=self.prompt,
+            prompt=self.context.prompt,
            content=content,
            thinking=thinking,
-            model_internal=self.model_name,
-            sampling_params=self.sampling_params,
+            model_internal=self.context.model_name,
+            sampling_params=self.context.sampling_params,
            usage=usage,
+            raw_response=data,
        )
lm_deluge/api_requests/base.py CHANGED

@@ -1,18 +1,12 @@
 import asyncio
-import random
 import traceback
 from abc import ABC, abstractmethod
-from typing import Callable
 
 import aiohttp
 from aiohttp import ClientResponse
 
-from lm_deluge.prompt import CachePattern, Conversation
-
-from ..config import SamplingParams
 from ..errors import raise_if_modal_exception
-from ..
-from ..tracker import StatusTracker
+from ..request_context import RequestContext
 from .response import APIResponse
 
 
@@ -28,40 +22,13 @@ class APIRequestBase(ABC):
 
     def __init__(
         self,
-
-        # should always be 'role', 'content' keys.
-        # internal logic should handle translating to specific API format
-        model_name: str,  # must correspond to registry
-        prompt: Conversation,
-        attempts_left: int,
-        status_tracker: StatusTracker,
-        # needed in order to retry with a different model and not throw the output away
-        results_arr: list["APIRequestBase"],
-        request_timeout: int = 30,
-        sampling_params: SamplingParams = SamplingParams(),
-        callback: Callable | None = None,
-        all_model_names: list[str] | None = None,
-        all_sampling_params: list[SamplingParams] | None = None,
-        tools: list | None = None,
-        cache: CachePattern | None = None,
+        context: RequestContext,
     ):
-
-
-
-        self.
+        # If context is provided, use it; otherwise construct one from individual parameters
+        self.context = context
+
+        # Everything is now accessed through self.context - no copying!
         self.system_prompt = None
-        self.prompt = prompt
-        self.attempts_left = attempts_left
-        self.status_tracker = status_tracker
-        self.request_timeout = request_timeout
-        self.sampling_params = sampling_params
-        self.callback = callback
-        self.num_tokens = prompt.count_tokens(sampling_params.max_new_tokens)
-        self.results_arr = results_arr
-        self.all_model_names = all_model_names
-        self.all_sampling_params = all_sampling_params
-        self.tools = tools
-        self.cache: CachePattern | None = cache
         self.result = []  # list of APIResponse objects from each attempt
 
         # these should be set in the __init__ of the subclass
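Subclasses now take a single `RequestContext` instead of a dozen keyword arguments. The package's actual class lives in the new `lm_deluge/request_context.py`; the sketch below only infers its shape from the `self.context.*` accesses visible in these diffs (field names come from those accesses, defaults are assumptions):

```python
# Hypothetical sketch of RequestContext, inferred from self.context.* usage in
# this diff -- not the definition shipped in lm_deluge/request_context.py.
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class RequestContextSketch:
    task_id: int
    model_name: str          # must correspond to the model registry
    prompt: Any              # Conversation
    sampling_params: Any     # SamplingParams
    attempts_left: int = 5   # default is an assumption
    request_timeout: int = 30
    status_tracker: Any | None = None   # StatusTracker
    callback: Callable | None = None
    tools: list | None = None           # Tool | dict | MCPServer entries
    cache: Any | None = None            # CachePattern


# With the refactor, constructing a provider request reduces to something like:
#     request = AnthropicRequest(context=ctx)
# and everything else is read back through request.context.
```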
@@ -71,101 +38,25 @@ class APIRequestBase(ABC):
         self.region = None
 
     def increment_pbar(self):
-        self.status_tracker
+        if self.context.status_tracker:
+            self.context.status_tracker.increment_pbar()
 
     def call_callback(self):
-        if self.callback is not None:
+        if self.context.callback is not None:
             # the APIResponse in self.result includes all the information
-            self.callback(self.result[-1], self.status_tracker)
+            self.context.callback(self.result[-1], self.context.status_tracker)
 
     def handle_success(self, data):
         self.call_callback()
-        self.status_tracker
-
-    def handle_error(self, create_new_request=False, give_up_if_no_other_models=False):
-        """
-        If create_new_request is True, will create a new API request (so that it
-        has a chance of being sent to a different model). If false, will retry
-        the same request.
-        """
-        last_result: APIResponse = self.result[-1]
-        error_to_print = f"Error task {self.task_id}. "
-        error_to_print += (
-            f"Model: {last_result.model_internal} Code: {last_result.status_code}, "
-        )
-        if self.region is not None:
-            error_to_print += f"Region: {self.region}, "
-        error_to_print += f"Message: {last_result.error_message}."
-        print(error_to_print)
-        if self.attempts_left > 0:
-            self.attempts_left -= 1
-            if not create_new_request:
-                assert self.status_tracker.retry_queue
-                self.status_tracker.retry_queue.put_nowait(self)
-                return
-            else:
-                # make sure we have another model to send it to besides the current one
-                if self.all_model_names is None or len(self.all_model_names) < 2:
-                    if give_up_if_no_other_models:
-                        print(
-                            f"No other models to try for task {self.task_id}. Giving up."
-                        )
-                        self.status_tracker.task_failed(self.task_id)
-                    else:
-                        print(
-                            f"No other models to try for task {self.task_id}. Retrying with same model."
-                        )
-                        assert self.status_tracker.retry_queue
-                        self.status_tracker.retry_queue.put_nowait(self)
-                else:
-                    # two things to change: model_name and sampling_params
-                    new_model_name = self.model_name
-                    new_model_idx = 0
-                    while new_model_name == self.model_name:
-                        new_model_idx = random.randint(0, len(self.all_model_names) - 1)
-                        new_model_name = self.all_model_names[new_model_idx]
-
-                    if isinstance(self.all_sampling_params, list):
-                        new_sampling_params = self.all_sampling_params[new_model_idx]
-                    elif isinstance(self.all_sampling_params, SamplingParams):
-                        new_sampling_params = self.all_sampling_params
-                    elif self.all_sampling_params is None:
-                        new_sampling_params = self.sampling_params
-                    else:
-                        new_sampling_params = self.sampling_params
+        if self.context.status_tracker:
+            self.context.status_tracker.task_succeeded(self.context.task_id)
 
-
-
-
-                        model_name=new_model_name,
-                        prompt=self.prompt,
-                        attempts_left=self.attempts_left,
-                        status_tracker=self.status_tracker,
-                        results_arr=self.results_arr,
-                        request_timeout=self.request_timeout,
-                        sampling_params=new_sampling_params,
-                        callback=self.callback,
-                        all_model_names=self.all_model_names,
-                        all_sampling_params=self.all_sampling_params,
-                        tools=self.tools,
-                        cache=self.cache,
-                        computer_use=getattr(self, "computer_use", False),
-                        display_width=getattr(self, "display_width", 1024),
-                        display_height=getattr(self, "display_height", 768),
-                    )
-                    # PROBLEM: new request is never put into results array, so we can't get the result.
-                    assert self.status_tracker.retry_queue
-                    self.status_tracker.retry_queue.put_nowait(self)
-                    # SOLUTION: just need to make sure it's deduplicated by task_id later.
-                    self.results_arr.append(new_request)
-        else:
-            print(f"Task {self.task_id} out of tries.")
-            self.status_tracker.task_failed(self.task_id)
-
-    async def call_api(self):
+    async def execute_once(self) -> APIResponse:
+        """Send the HTTP request once and return the parsed APIResponse."""
+        assert self.context.status_tracker
         try:
-            self.status_tracker.total_requests += 1
-            timeout = aiohttp.ClientTimeout(total=self.request_timeout)
+            self.context.status_tracker.total_requests += 1
+            timeout = aiohttp.ClientTimeout(total=self.context.request_timeout)
             async with aiohttp.ClientSession(timeout=timeout) as session:
                 assert self.url is not None, "URL is not set"
                 async with session.post(
@@ -174,133 +65,56 @@ class APIRequestBase(ABC):
                     json=self.request_json,
                 ) as http_response:
                     response: APIResponse = await self.handle_response(http_response)
-
-                    self.result.append(response)
-                    if response.is_error:
-                        self.handle_error(
-                            create_new_request=response.retry_with_different_model or False,
-                            give_up_if_no_other_models=response.give_up_if_no_other_models
-                            or False,
-                        )
-                    else:
-                        self.handle_success(response)
+                    return response
 
         except asyncio.TimeoutError:
-
-
-
-
-
-
-
-
-
-
-                usage=None,
-            )
+            return APIResponse(
+                id=self.context.task_id,
+                model_internal=self.context.model_name,
+                prompt=self.context.prompt,
+                sampling_params=self.context.sampling_params,
+                status_code=None,
+                is_error=True,
+                error_message="Request timed out (terminated by client).",
+                content=None,
+                usage=None,
             )
-            self.handle_error(create_new_request=False)
 
         except Exception as e:
             raise_if_modal_exception(e)
             tb = traceback.format_exc()
             print(tb)
-
-
-
-
-
-
-
-
-
-
-                usage=None,
-            )
+            return APIResponse(
+                id=self.context.task_id,
+                model_internal=self.context.model_name,
+                prompt=self.context.prompt,
+                sampling_params=self.context.sampling_params,
+                status_code=None,
+                is_error=True,
+                error_message=f"Unexpected {type(e).__name__}: {str(e) or 'No message.'}",
+                content=None,
+                usage=None,
             )
-            # maybe consider making True?
-            self.handle_error(create_new_request=False)
 
     @abstractmethod
     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
         raise NotImplementedError
 
 
-def create_api_request(
-    task_id: int,
-    model_name: str,
-    prompt: Conversation,
-    attempts_left: int,
-    status_tracker: StatusTracker,
-    results_arr: list["APIRequestBase"],
-    request_timeout: int = 30,
-    sampling_params: SamplingParams = SamplingParams(),
-    callback: Callable | None = None,
-    all_model_names: list[str] | None = None,
-    all_sampling_params: list[SamplingParams] | None = None,
-    tools: list | None = None,
-    cache: CachePattern | None = None,
-    computer_use: bool = False,
-    display_width: int = 1024,
-    display_height: int = 768,
-    use_responses_api: bool = False,
-) -> APIRequestBase:
-    from .common import CLASSES  # circular import so made it lazy, does this work?
-
-    model_obj = APIModel.from_registry(model_name)
-
-    # Choose API spec based on use_responses_api flag and model support
-    api_spec = model_obj.api_spec
-    if use_responses_api and model_obj.supports_responses and api_spec == "openai":
-        api_spec = "openai-responses"
-
-    request_class = CLASSES.get(api_spec, None)
-    if request_class is None:
-        raise ValueError(f"Unsupported API spec: {api_spec}")
-    kwargs = {}
-    # Add computer_use to kwargs if the request class supports it
-    model_obj = APIModel.from_registry(model_name)
-    if computer_use and api_spec in ["anthropic", "bedrock", "openai-responses"]:
-        kwargs.update(
-            {
-                "computer_use": computer_use,
-                "display_width": display_width,
-                "display_height": display_height,
-            }
-        )
-
-    return request_class(
-        task_id=task_id,
-        model_name=model_name,
-        prompt=prompt,
-        attempts_left=attempts_left,
-        status_tracker=status_tracker,
-        results_arr=results_arr,
-        request_timeout=request_timeout,
-        sampling_params=sampling_params,
-        callback=callback,
-        all_model_names=all_model_names,
-        all_sampling_params=all_sampling_params,
-        tools=tools,
-        cache=cache,
-        **kwargs,
-    )
-
-
 def deduplicate_responses(results: list[APIRequestBase]) -> list[APIResponse]:
     deduplicated = {}
     for request in results:
-        if request.task_id not in deduplicated:
-            deduplicated[request.task_id] = request.result[-1]
+        if request.context.task_id not in deduplicated:
+            deduplicated[request.context.task_id] = request.result[-1]
         else:
-            current_response: APIResponse = deduplicated[request.task_id]
+            current_response: APIResponse = deduplicated[request.context.task_id]
             # only replace if the current request has no completion and the new one does
             if (
                 request.result[-1].completion is not None
                 and current_response.completion is None
             ):
-                deduplicated[request.task_id] = request.result[-1]
+                deduplicated[request.context.task_id] = request.result[-1]
 
-    output = [deduplicated[request.task_id] for request in results]
+    output = [deduplicated[request.context.task_id] for request in results]
 
     return output