lm-deluge 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lm-deluge might be problematic.
- lm_deluge/__init__.py +9 -1
- lm_deluge/agent.py +0 -0
- lm_deluge/api_requests/anthropic.py +107 -60
- lm_deluge/api_requests/base.py +107 -54
- lm_deluge/api_requests/bedrock.py +59 -22
- lm_deluge/api_requests/common.py +2 -1
- lm_deluge/api_requests/mistral.py +20 -22
- lm_deluge/api_requests/openai.py +283 -51
- lm_deluge/batches.py +498 -0
- lm_deluge/client.py +373 -634
- lm_deluge/computer_use/anthropic_tools.py +75 -0
- lm_deluge/{sampling_params.py → config.py} +10 -3
- lm_deluge/embed.py +17 -11
- lm_deluge/models.py +33 -0
- lm_deluge/prompt.py +173 -7
- lm_deluge/rerank.py +18 -12
- lm_deluge/tool.py +11 -1
- lm_deluge/tracker.py +212 -2
- lm_deluge/usage.py +114 -0
- lm_deluge/util/json.py +18 -1
- {lm_deluge-0.0.11.dist-info → lm_deluge-0.0.13.dist-info}/METADATA +78 -20
- lm_deluge-0.0.13.dist-info/RECORD +42 -0
- {lm_deluge-0.0.11.dist-info → lm_deluge-0.0.13.dist-info}/WHEEL +1 -1
- lm_deluge-0.0.11.dist-info/RECORD +0 -38
- {lm_deluge-0.0.11.dist-info → lm_deluge-0.0.13.dist-info}/licenses/LICENSE +0 -0
- {lm_deluge-0.0.11.dist-info → lm_deluge-0.0.13.dist-info}/top_level.txt +0 -0
lm_deluge/__init__.py
CHANGED
@@ -1,7 +1,15 @@
 from .client import LLMClient, SamplingParams, APIResponse
 from .prompt import Conversation, Message
+from .tool import Tool
 import dotenv

 dotenv.load_dotenv()

-__all__ = [
+__all__ = [
+    "LLMClient",
+    "SamplingParams",
+    "APIResponse",
+    "Conversation",
+    "Message",
+    "Tool",
+]
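As an aside, the expanded `__all__` above means the package's main entry points can be imported straight from the top level. A minimal, illustrative sketch (it assumes lm-deluge 0.0.13 is installed; nothing here beyond the import list comes from the diff):

# Illustrative import check for the names exported by the new __all__.
from lm_deluge import (
    LLMClient,
    SamplingParams,
    APIResponse,
    Conversation,
    Message,
    Tool,
)

# Confirm everything resolved; prints the six exported names.
for obj in (LLMClient, SamplingParams, APIResponse, Conversation, Message, Tool):
    print(obj.__name__)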
lm_deluge/agent.py
ADDED
File without changes

lm_deluge/api_requests/anthropic.py
CHANGED
@@ -1,17 +1,94 @@
-import asyncio
 from aiohttp import ClientResponse
 import json
 import os
-import warnings
-from tqdm import tqdm
 from typing import Callable

-from lm_deluge.prompt import
+from lm_deluge.prompt import (
+    Conversation,
+    Message,
+    Text,
+    ToolCall,
+    Thinking,
+    CachePattern,
+)
+from lm_deluge.tool import Tool
+from lm_deluge.usage import Usage
 from .base import APIRequestBase, APIResponse

 from ..tracker import StatusTracker
-from ..
+from ..config import SamplingParams
 from ..models import APIModel
+from ..computer_use.anthropic_tools import get_anthropic_cu_tools
+
+
+def _build_anthropic_request(
+    model: APIModel,
+    prompt: Conversation,
+    tools: list[Tool] | None,
+    sampling_params: SamplingParams,
+    cache_pattern: CachePattern | None = None,
+    computer_use: bool = False,
+    display_width: int = 1024,
+    display_height: int = 768,
+):
+    system_message, messages = prompt.to_anthropic(cache_pattern=cache_pattern)
+    request_header = {
+        "x-api-key": os.getenv(model.api_key_env_var),
+        "anthropic-version": "2023-06-01",
+        "content-type": "application/json",
+    }
+
+    # Add beta header for Computer Use
+    if computer_use:
+        request_header["anthropic-beta"] = "computer-use-2025-01-24"
+
+    request_json = {
+        "model": model.name,
+        "messages": messages,
+        "temperature": sampling_params.temperature,
+        "top_p": sampling_params.top_p,
+        "max_tokens": sampling_params.max_new_tokens,
+    }
+
+    # handle thinking
+    if model.reasoning_model and sampling_params.reasoning_effort:
+        # translate reasoning effort of low, medium, high to budget tokens
+        budget = {"low": 1024, "medium": 4096, "high": 16384}.get(
+            sampling_params.reasoning_effort
+        )
+        request_json["thinking"] = {
+            "type": "enabled",
+            "budget_tokens": budget,
+        }
+        request_json.pop("top_p")
+        request_json["temperature"] = 1.0
+        request_json["max_tokens"] += budget
+    else:
+        request_json["thinking"] = {"type": "disabled"}
+        if sampling_params.reasoning_effort:
+            print("ignoring reasoning_effort for non-reasoning model")
+    if system_message is not None:
+        request_json["system"] = system_message
+    if tools or computer_use:
+        tool_definitions = []
+        if tools:
+            tool_definitions.extend([tool.dump_for("anthropic") for tool in tools])
+        # Add Computer Use tools
+        if computer_use:
+            cu_tools = get_anthropic_cu_tools(
+                model=model.id,
+                display_width=display_width,  # todo: set from ComputerUseParams
+                display_height=display_height,
+            )
+            tool_definitions.extend(cu_tools)
+
+        # Add cache control to last tool if tools_only caching is specified
+        if cache_pattern == "tools_only" and tool_definitions:
+            tool_definitions[-1]["cache_control"] = {"type": "ephemeral"}
+
+        request_json["tools"] = tool_definitions
+
+    return request_json, request_header


 class AnthropicRequest(APIRequestBase):
@@ -24,17 +101,19 @@ class AnthropicRequest(APIRequestBase):
         prompt: Conversation,
         attempts_left: int,
         status_tracker: StatusTracker,
-        retry_queue: asyncio.Queue,
         results_arr: list,
         request_timeout: int = 30,
         sampling_params: SamplingParams = SamplingParams(),
-        pbar: tqdm | None = None,
         callback: Callable | None = None,
-        debug: bool = False,
         # for retries
         all_model_names: list[str] | None = None,
         all_sampling_params: list[SamplingParams] | None = None,
         tools: list | None = None,
+        cache: CachePattern | None = None,
+        # Computer Use support
+        computer_use: bool = False,
+        display_width: int = 1024,
+        display_height: int = 768,
     ):
         super().__init__(
             task_id=task_id,
@@ -42,70 +121,42 @@ class AnthropicRequest(APIRequestBase):
             prompt=prompt,
             attempts_left=attempts_left,
             status_tracker=status_tracker,
-            retry_queue=retry_queue,
             results_arr=results_arr,
             request_timeout=request_timeout,
             sampling_params=sampling_params,
-            pbar=pbar,
             callback=callback,
-            debug=debug,
             all_model_names=all_model_names,
             all_sampling_params=all_sampling_params,
             tools=tools,
+            cache=cache,
         )
+        self.computer_use = computer_use
+        self.display_width = display_width
+        self.display_height = display_height
         self.model = APIModel.from_registry(model_name)
         self.url = f"{self.model.api_base}/messages"

-
-
-
-            "anthropic-version": "2023-06-01",
-            "content-type": "application/json",
-        }
+        # Lock images as bytes if caching is enabled
+        if cache is not None:
+            prompt.lock_images_as_bytes()

-        self.request_json =
-
-
-
-
-
-
-
-
-
-                # translate reasoning effort of low, medium, high to budget tokens
-                budget = {"low": 1024, "medium": 4096, "high": 16384}.get(
-                    sampling_params.reasoning_effort
-                )
-                self.request_json["thinking"] = {
-                    "type": "enabled",
-                    "budget_tokens": budget,
-                }
-                self.request_json.pop("top_p")
-                self.request_json["temperature"] = 1.0
-                self.request_json["max_tokens"] += (
-                    budget  # assume max tokens is max completion tokens
-                )
-            else:
-                # no thinking
-                self.request_json["thinking"] = {"type": "disabled"}
-        else:
-            if sampling_params.reasoning_effort:
-                warnings.warn(
-                    f"Ignoring reasoning_effort param for non-reasoning model: {model_name}"
-                )
-        if self.system_message is not None:
-            self.request_json["system"] = self.system_message
-        if tools:
-            self.request_json["tools"] = [tool.dump_for("anthropic") for tool in tools]
+        self.request_json, self.request_header = _build_anthropic_request(
+            self.model,
+            prompt,
+            tools,
+            sampling_params,
+            cache,
+            computer_use,
+            display_width,
+            display_height,
+        )

     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
         is_error = False
         error_message = None
         thinking = None
         content = None
-
-        output_tokens = None
+        usage = None
         status_code = http_response.status
         mimetype = http_response.headers.get("Content-Type", None)
         rate_limits = {}
@@ -118,8 +169,6 @@ class AnthropicRequest(APIRequestBase):
                 "anthropic-ratelimit-tokens-reset",
             ]:
                 rate_limits[header] = http_response.headers.get(header, None)
-        if self.debug:
-            print(f"Rate limits: {rate_limits}")
         if status_code >= 200 and status_code < 300:
             try:
                 data = await http_response.json()
@@ -143,8 +192,7 @@ class AnthropicRequest(APIRequestBase):
                 )

                 content = Message("assistant", parts)
-
-                output_tokens = data["usage"]["output_tokens"]
+                usage = Usage.from_anthropic_usage(data["usage"])
             except Exception as e:
                 is_error = True
                 error_message = (
@@ -182,6 +230,5 @@ class AnthropicRequest(APIRequestBase):
             thinking=thinking,
             model_internal=self.model_name,
             sampling_params=self.sampling_params,
-
-            output_tokens=output_tokens,
+            usage=usage,
         )
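As an illustration of the reasoning-effort handling that the new `_build_anthropic_request` introduces, here is a standalone sketch of the same budget translation. `apply_thinking` is a hypothetical helper written for this note, not a function from the package; it only mirrors what the diff does to the request body.

# Hypothetical helper that mirrors the reasoning_effort -> thinking budget
# translation in _build_anthropic_request above.
def apply_thinking(request_json: dict, reasoning_effort: str | None) -> dict:
    if reasoning_effort:
        # same budget table as the diff: low/medium/high -> 1024/4096/16384 tokens
        budget = {"low": 1024, "medium": 4096, "high": 16384}.get(reasoning_effort)
        request_json["thinking"] = {"type": "enabled", "budget_tokens": budget}
        request_json.pop("top_p", None)       # top_p is dropped when thinking is on
        request_json["temperature"] = 1.0     # temperature is pinned to 1.0
        request_json["max_tokens"] += budget  # budget is added on top of max_tokens
    else:
        request_json["thinking"] = {"type": "disabled"}
    return request_json


print(apply_thinking({"top_p": 0.9, "temperature": 0.7, "max_tokens": 512}, "medium"))
# thinking enabled with budget_tokens=4096, max_tokens bumped to 4608, temperature 1.0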
lm_deluge/api_requests/base.py
CHANGED
@@ -1,19 +1,21 @@
-import aiohttp
 import asyncio
 import json
 import random
-
-from dataclasses import dataclass
+import traceback
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from typing import Callable

-
+import aiohttp
+from aiohttp import ClientResponse
+
+from lm_deluge.prompt import CachePattern, Conversation, Message
+from lm_deluge.usage import Usage

-from ..
-from ..sampling_params import SamplingParams
-from ..models import APIModel
+from ..config import SamplingParams
 from ..errors import raise_if_modal_exception
-from
+from ..models import APIModel
+from ..tracker import StatusTracker


 @dataclass
@@ -29,9 +31,8 @@ class APIResponse:
     is_error: bool | None
     error_message: str | None

-    # completion information
-
-    output_tokens: int | None
+    # completion information - unified usage tracking
+    usage: Usage | None = None

     # response content - structured format
     content: Message | None = None
@@ -48,6 +49,10 @@ class APIResponse:
     retry_with_different_model: bool | None = False
     # set to true if should NOT retry with the same model (unrecoverable error)
     give_up_if_no_other_models: bool | None = False
+    # OpenAI Responses API specific - used for computer use continuation
+    response_id: str | None = None
+    # Raw API response for debugging
+    raw_response: dict | None = None

     @property
     def completion(self) -> str | None:
@@ -56,6 +61,26 @@ class APIResponse:
             return self.content.completion
         return None

+    @property
+    def input_tokens(self) -> int | None:
+        """Get input tokens from usage object."""
+        return self.usage.input_tokens if self.usage else None
+
+    @property
+    def output_tokens(self) -> int | None:
+        """Get output tokens from usage object."""
+        return self.usage.output_tokens if self.usage else None
+
+    @property
+    def cache_read_tokens(self) -> int | None:
+        """Get cache read tokens from usage object."""
+        return self.usage.cache_read_tokens if self.usage else None
+
+    @property
+    def cache_write_tokens(self) -> int | None:
+        """Get cache write tokens from usage object."""
+        return self.usage.cache_write_tokens if self.usage else None
+
     def __post_init__(self):
         # calculate cost & get external model name
         self.id = int(self.id)
@@ -63,14 +88,13 @@ class APIResponse:
         self.model_external = api_model.name
         self.cost = None
         if (
-            self.
-            and self.output_tokens is not None
+            self.usage is not None
             and api_model.input_cost is not None
             and api_model.output_cost is not None
         ):
             self.cost = (
-                self.input_tokens * api_model.input_cost / 1e6
-                + self.output_tokens * api_model.output_cost / 1e6
+                self.usage.input_tokens * api_model.input_cost / 1e6
+                + self.usage.output_tokens * api_model.output_cost / 1e6
             )
         elif self.content is not None and self.completion is not None:
             print(
@@ -90,8 +114,7 @@ class APIResponse:
             "error_message": self.error_message,
             "completion": self.completion,  # computed property
             "content": self.content.to_log() if self.content else None,
-            "
-            "output_tokens": self.output_tokens,
+            "usage": self.usage.to_dict() if self.usage else None,
             "finish_reason": self.finish_reason,
             "cost": self.cost,
         }
@@ -107,6 +130,10 @@ class APIResponse:
             # Backward compatibility: create a Message with just text
             content = Message.ai(data["completion"])

+        usage = None
+        if "usage" in data and data["usage"] is not None:
+            usage = Usage.from_dict(data["usage"])
+
         return cls(
             id=data.get("id", random.randint(0, 1_000_000_000)),
             model_internal=data["model_internal"],
@@ -115,8 +142,7 @@ class APIResponse:
             status_code=data["status_code"],
             is_error=data["is_error"],
             error_message=data["error_message"],
-
-            output_tokens=data["output_tokens"],
+            usage=usage,
             content=content,
             thinking=data.get("thinking"),
             model_external=data.get("model_external"),
@@ -155,19 +181,15 @@ class APIRequestBase(ABC):
         prompt: Conversation,
         attempts_left: int,
         status_tracker: StatusTracker,
-        retry_queue: asyncio.Queue,
         # needed in order to retry with a different model and not throw the output away
         results_arr: list["APIRequestBase"],
         request_timeout: int = 30,
         sampling_params: SamplingParams = SamplingParams(),
-        logprobs: bool = False,
-        top_logprobs: int | None = None,
-        pbar: tqdm | None = None,
         callback: Callable | None = None,
-        debug: bool = False,
         all_model_names: list[str] | None = None,
         all_sampling_params: list[SamplingParams] | None = None,
         tools: list | None = None,
+        cache: CachePattern | None = None,
     ):
         if all_model_names is None:
             raise ValueError("all_model_names must be provided.")
@@ -177,19 +199,15 @@ class APIRequestBase(ABC):
         self.prompt = prompt
         self.attempts_left = attempts_left
         self.status_tracker = status_tracker
-        self.retry_queue = retry_queue
        self.request_timeout = request_timeout
         self.sampling_params = sampling_params
-        self.logprobs = logprobs  # len(completion) logprobs
-        self.top_logprobs = top_logprobs
-        self.pbar = pbar
         self.callback = callback
         self.num_tokens = prompt.count_tokens(sampling_params.max_new_tokens)
         self.results_arr = results_arr
-        self.debug = debug
         self.all_model_names = all_model_names
         self.all_sampling_params = all_sampling_params
         self.tools = tools
+        self.cache: CachePattern | None = cache
         self.result = []  # list of APIResponse objects from each attempt

         # these should be set in the __init__ of the subclass
@@ -199,8 +217,7 @@ class APIRequestBase(ABC):
         self.region = None

     def increment_pbar(self):
-
-            self.pbar.update(1)
+        self.status_tracker.increment_pbar()

     def call_callback(self):
         if self.callback is not None:
@@ -209,7 +226,6 @@ class APIRequestBase(ABC):

     def handle_success(self, data):
         self.call_callback()
-        self.increment_pbar()
         self.status_tracker.task_succeeded(self.task_id)

     def handle_error(self, create_new_request=False, give_up_if_no_other_models=False):
@@ -230,7 +246,8 @@ class APIRequestBase(ABC):
         if self.attempts_left > 0:
             self.attempts_left -= 1
             if not create_new_request:
-                self.retry_queue
+                assert self.status_tracker.retry_queue
+                self.status_tracker.retry_queue.put_nowait(self)
                 return
             else:
                 # make sure we have another model to send it to besides the current one
@@ -244,7 +261,8 @@ class APIRequestBase(ABC):
                     print(
                         f"No other models to try for task {self.task_id}. Retrying with same model."
                     )
-                    self.retry_queue
+                    assert self.status_tracker.retry_queue
+                    self.status_tracker.retry_queue.put_nowait(self)
                 else:
                     # two things to change: model_name and sampling_params
                     new_model_name = self.model_name
@@ -269,20 +287,21 @@ class APIRequestBase(ABC):
                        prompt=self.prompt,
                        attempts_left=self.attempts_left,
                        status_tracker=self.status_tracker,
-                        retry_queue=self.retry_queue,
                        results_arr=self.results_arr,
                        request_timeout=self.request_timeout,
                        sampling_params=new_sampling_params,
-                        logprobs=self.logprobs,
-                        top_logprobs=self.top_logprobs,
-                        pbar=self.pbar,
                        callback=self.callback,
                        all_model_names=self.all_model_names,
                        all_sampling_params=self.all_sampling_params,
                        tools=self.tools,
+                        cache=self.cache,
+                        computer_use=getattr(self, "computer_use", False),
+                        display_width=getattr(self, "display_width", 1024),
+                        display_height=getattr(self, "display_height", 768),
                    )
                    # PROBLEM: new request is never put into results array, so we can't get the result.
-                    self.retry_queue
+                    assert self.status_tracker.retry_queue
+                    self.status_tracker.retry_queue.put_nowait(self)
                    # SOLUTION: just need to make sure it's deduplicated by task_id later.
                    self.results_arr.append(new_request)
                 else:
@@ -323,14 +342,15 @@ class APIRequestBase(ABC):
                    is_error=True,
                    error_message="Request timed out (terminated by client).",
                    content=None,
-
-                    output_tokens=None,
+                    usage=None,
                )
            )
            self.handle_error(create_new_request=False)

        except Exception as e:
            raise_if_modal_exception(e)
+            tb = traceback.format_exc()
+            print(tb)
            self.result.append(
                APIResponse(
                    id=self.task_id,
@@ -341,8 +361,7 @@ class APIRequestBase(ABC):
                    is_error=True,
                    error_message=f"Unexpected {type(e).__name__}: {str(e) or 'No message.'}",
                    content=None,
-
-                    output_tokens=None,
+                    usage=None,
                )
            )
            # maybe consider making True?
@@ -359,41 +378,75 @@ def create_api_request(
     prompt: Conversation,
     attempts_left: int,
     status_tracker: StatusTracker,
-    retry_queue: asyncio.Queue,
     results_arr: list["APIRequestBase"],
     request_timeout: int = 30,
     sampling_params: SamplingParams = SamplingParams(),
-    logprobs: bool = False,
-    top_logprobs: int | None = None,
-    pbar: tqdm | None = None,
     callback: Callable | None = None,
     all_model_names: list[str] | None = None,
     all_sampling_params: list[SamplingParams] | None = None,
     tools: list | None = None,
+    cache: CachePattern | None = None,
+    computer_use: bool = False,
+    display_width: int = 1024,
+    display_height: int = 768,
+    use_responses_api: bool = False,
 ) -> APIRequestBase:
     from .common import CLASSES  # circular import so made it lazy, does this work?

     model_obj = APIModel.from_registry(model_name)
-
+
+    # Choose API spec based on use_responses_api flag and model support
+    api_spec = model_obj.api_spec
+    if use_responses_api and model_obj.supports_responses and api_spec == "openai":
+        api_spec = "openai-responses"
+
+    request_class = CLASSES.get(api_spec, None)
     if request_class is None:
-        raise ValueError(f"Unsupported API spec: {
-    kwargs =
-
-    )
+        raise ValueError(f"Unsupported API spec: {api_spec}")
+    kwargs = {}
+    # Add computer_use to kwargs if the request class supports it
+    model_obj = APIModel.from_registry(model_name)
+    if computer_use and api_spec in ["anthropic", "bedrock", "openai-responses"]:
+        kwargs.update(
+            {
+                "computer_use": computer_use,
+                "display_width": display_width,
+                "display_height": display_height,
+            }
+        )
+
     return request_class(
         task_id=task_id,
         model_name=model_name,
         prompt=prompt,
         attempts_left=attempts_left,
         status_tracker=status_tracker,
-        retry_queue=retry_queue,
         results_arr=results_arr,
         request_timeout=request_timeout,
         sampling_params=sampling_params,
-        pbar=pbar,
         callback=callback,
         all_model_names=all_model_names,
         all_sampling_params=all_sampling_params,
         tools=tools,
+        cache=cache,
         **kwargs,
     )
+
+
+def deduplicate_responses(results: list[APIRequestBase]) -> list[APIResponse]:
+    deduplicated = {}
+    for request in results:
+        if request.task_id not in deduplicated:
+            deduplicated[request.task_id] = request.result[-1]
+        else:
+            current_response: APIResponse = deduplicated[request.task_id]
+            # only replace if the current request has no completion and the new one does
+            if (
+                request.result[-1].completion is not None
+                and current_response.completion is None
+            ):
+                deduplicated[request.task_id] = request.result[-1]
+
+    output = [deduplicated[request.task_id] for request in results]
+
+    return output
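To make the new `Usage`-based accounting concrete, here is a small self-contained sketch of the cost formula that `APIResponse.__post_init__` now applies. The `Usage` dataclass below is a minimal stand-in for `lm_deluge.usage.Usage`, whose definition is not shown in this diff, and the per-million-token prices are made up for the example.

# Stand-in for lm_deluge.usage.Usage; only the two fields the cost formula
# needs are modeled here.
from dataclasses import dataclass


@dataclass
class Usage:
    input_tokens: int
    output_tokens: int


def response_cost(usage: Usage, input_cost: float, output_cost: float) -> float:
    # input_cost / output_cost are prices per 1M tokens, as in __post_init__
    return (
        usage.input_tokens * input_cost / 1e6
        + usage.output_tokens * output_cost / 1e6
    )


# Made-up prices: $3 per 1M input tokens, $15 per 1M output tokens.
print(response_cost(Usage(input_tokens=1200, output_tokens=350), 3.0, 15.0))  # ~0.00885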