lm-deluge 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of lm-deluge might be problematic.
- lm_deluge/__init__.py +11 -1
- lm_deluge/agent.py +0 -0
- lm_deluge/api_requests/anthropic.py +90 -58
- lm_deluge/api_requests/base.py +63 -180
- lm_deluge/api_requests/bedrock.py +34 -10
- lm_deluge/api_requests/common.py +2 -1
- lm_deluge/api_requests/mistral.py +6 -15
- lm_deluge/api_requests/openai.py +342 -50
- lm_deluge/api_requests/response.py +153 -0
- lm_deluge/batches.py +498 -0
- lm_deluge/client.py +354 -636
- lm_deluge/computer_use/anthropic_tools.py +75 -0
- lm_deluge/{sampling_params.py → config.py} +12 -4
- lm_deluge/embed.py +17 -11
- lm_deluge/file.py +149 -0
- lm_deluge/models.py +33 -0
- lm_deluge/prompt.py +156 -15
- lm_deluge/rerank.py +18 -12
- lm_deluge/tool.py +11 -1
- lm_deluge/tracker.py +214 -2
- lm_deluge/util/json.py +18 -1
- {lm_deluge-0.0.12.dist-info → lm_deluge-0.0.14.dist-info}/METADATA +8 -5
- lm_deluge-0.0.14.dist-info/RECORD +44 -0
- {lm_deluge-0.0.12.dist-info → lm_deluge-0.0.14.dist-info}/WHEEL +1 -1
- lm_deluge-0.0.12.dist-info/RECORD +0 -39
- {lm_deluge-0.0.12.dist-info → lm_deluge-0.0.14.dist-info}/licenses/LICENSE +0 -0
- {lm_deluge-0.0.12.dist-info → lm_deluge-0.0.14.dist-info}/top_level.txt +0 -0
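One structural change visible in the listing: sampling_params.py was renamed to config.py, so the package's internal imports move with it (the diffs below change from ..sampling_params import SamplingParams to from ..config import SamplingParams). As a hedged sketch of what that means for import paths, assuming SamplingParams itself is otherwise unchanged:

    # 0.0.12 module path (per the old relative imports shown in the diffs below):
    # from lm_deluge.sampling_params import SamplingParams
    # 0.0.14 module path after the rename:
    from lm_deluge.config import SamplingParams

    # the root-level re-export (see __init__.py in the next section) works in both versions:
    from lm_deluge import SamplingParams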
lm_deluge/__init__.py
CHANGED
@@ -1,7 +1,17 @@
 from .client import LLMClient, SamplingParams, APIResponse
 from .prompt import Conversation, Message
+from .tool import Tool
+from .file import File
 import dotenv
 
 dotenv.load_dotenv()
 
-__all__ = [
+__all__ = [
+    "LLMClient",
+    "SamplingParams",
+    "APIResponse",
+    "Conversation",
+    "Message",
+    "Tool",
+    "File",
+]
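The practical effect of the change above is that Tool and File join the package's root-level exports. Only the import line below is confirmed by this diff; it is a minimal sketch rather than complete usage:

    # names re-exported from the package root per the updated __all__ above
    from lm_deluge import LLMClient, SamplingParams, APIResponse, Conversation, Message, Tool, File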
lm_deluge/agent.py
ADDED
File without changes
lm_deluge/api_requests/anthropic.py
CHANGED
@@ -1,9 +1,6 @@
-import asyncio
 from aiohttp import ClientResponse
 import json
 import os
-import warnings
-from tqdm import tqdm
 from typing import Callable
 
 from lm_deluge.prompt import (
@@ -14,12 +11,84 @@ from lm_deluge.prompt import (
     Thinking,
     CachePattern,
 )
+from lm_deluge.tool import Tool
 from lm_deluge.usage import Usage
 from .base import APIRequestBase, APIResponse
 
 from ..tracker import StatusTracker
-from ..
+from ..config import SamplingParams
 from ..models import APIModel
+from ..computer_use.anthropic_tools import get_anthropic_cu_tools
+
+
+def _build_anthropic_request(
+    model: APIModel,
+    prompt: Conversation,
+    tools: list[Tool] | None,
+    sampling_params: SamplingParams,
+    cache_pattern: CachePattern | None = None,
+    computer_use: bool = False,
+    display_width: int = 1024,
+    display_height: int = 768,
+):
+    system_message, messages = prompt.to_anthropic(cache_pattern=cache_pattern)
+    request_header = {
+        "x-api-key": os.getenv(model.api_key_env_var),
+        "anthropic-version": "2023-06-01",
+        "content-type": "application/json",
+    }
+
+    # Add beta header for Computer Use
+    if computer_use:
+        request_header["anthropic-beta"] = "computer-use-2025-01-24"
+
+    request_json = {
+        "model": model.name,
+        "messages": messages,
+        "temperature": sampling_params.temperature,
+        "top_p": sampling_params.top_p,
+        "max_tokens": sampling_params.max_new_tokens,
+    }
+
+    # handle thinking
+    if model.reasoning_model and sampling_params.reasoning_effort:
+        # translate reasoning effort of low, medium, high to budget tokens
+        budget = {"low": 1024, "medium": 4096, "high": 16384}.get(
+            sampling_params.reasoning_effort
+        )
+        request_json["thinking"] = {
+            "type": "enabled",
+            "budget_tokens": budget,
+        }
+        request_json.pop("top_p")
+        request_json["temperature"] = 1.0
+        request_json["max_tokens"] += budget
+    else:
+        request_json["thinking"] = {"type": "disabled"}
+        if sampling_params.reasoning_effort:
+            print("ignoring reasoning_effort for non-reasoning model")
+    if system_message is not None:
+        request_json["system"] = system_message
+    if tools or computer_use:
+        tool_definitions = []
+        if tools:
+            tool_definitions.extend([tool.dump_for("anthropic") for tool in tools])
+        # Add Computer Use tools
+        if computer_use:
+            cu_tools = get_anthropic_cu_tools(
+                model=model.id,
+                display_width=display_width,  # todo: set from ComputerUseParams
+                display_height=display_height,
+            )
+            tool_definitions.extend(cu_tools)
+
+        # Add cache control to last tool if tools_only caching is specified
+        if cache_pattern == "tools_only" and tool_definitions:
+            tool_definitions[-1]["cache_control"] = {"type": "ephemeral"}
+
+        request_json["tools"] = tool_definitions
+
+    return request_json, request_header
 
 
 class AnthropicRequest(APIRequestBase):
@@ -32,18 +101,19 @@ class AnthropicRequest(APIRequestBase):
         prompt: Conversation,
         attempts_left: int,
         status_tracker: StatusTracker,
-        retry_queue: asyncio.Queue,
         results_arr: list,
         request_timeout: int = 30,
         sampling_params: SamplingParams = SamplingParams(),
-        pbar: tqdm | None = None,
         callback: Callable | None = None,
-        debug: bool = False,
         # for retries
         all_model_names: list[str] | None = None,
         all_sampling_params: list[SamplingParams] | None = None,
         tools: list | None = None,
         cache: CachePattern | None = None,
+        # Computer Use support
+        computer_use: bool = False,
+        display_width: int = 1024,
+        display_height: int = 768,
     ):
         super().__init__(
             task_id=task_id,
@@ -51,18 +121,18 @@ class AnthropicRequest(APIRequestBase):
             prompt=prompt,
             attempts_left=attempts_left,
             status_tracker=status_tracker,
-            retry_queue=retry_queue,
             results_arr=results_arr,
             request_timeout=request_timeout,
             sampling_params=sampling_params,
-            pbar=pbar,
             callback=callback,
-            debug=debug,
             all_model_names=all_model_names,
             all_sampling_params=all_sampling_params,
             tools=tools,
             cache=cache,
         )
+        self.computer_use = computer_use
+        self.display_width = display_width
+        self.display_height = display_height
         self.model = APIModel.from_registry(model_name)
         self.url = f"{self.model.api_base}/messages"
 
@@ -70,52 +140,16 @@
         if cache is not None:
             prompt.lock_images_as_bytes()
 
-        self.
-
-
-
-
-
-
-
-
-
-            "temperature": self.sampling_params.temperature,
-            "top_p": self.sampling_params.top_p,
-            "max_tokens": self.sampling_params.max_new_tokens,
-        }
-        # handle thinking
-        if self.model.reasoning_model:
-            if sampling_params.reasoning_effort:
-                # translate reasoning effort of low, medium, high to budget tokens
-                budget = {"low": 1024, "medium": 4096, "high": 16384}.get(
-                    sampling_params.reasoning_effort
-                )
-                self.request_json["thinking"] = {
-                    "type": "enabled",
-                    "budget_tokens": budget,
-                }
-                self.request_json.pop("top_p")
-                self.request_json["temperature"] = 1.0
-                self.request_json["max_tokens"] += (
-                    budget  # assume max tokens is max completion tokens
-                )
-            else:
-                # no thinking
-                self.request_json["thinking"] = {"type": "disabled"}
-        else:
-            if sampling_params.reasoning_effort:
-                warnings.warn(
-                    f"Ignoring reasoning_effort param for non-reasoning model: {model_name}"
-                )
-        if self.system_message is not None:
-            self.request_json["system"] = self.system_message
-        if tools:
-            tool_definitions = [tool.dump_for("anthropic") for tool in tools]
-            # Add cache control to last tool if tools_only caching is specified
-            if cache == "tools_only" and tool_definitions:
-                tool_definitions[-1]["cache_control"] = {"type": "ephemeral"}
-            self.request_json["tools"] = tool_definitions
+        self.request_json, self.request_header = _build_anthropic_request(
+            self.model,
+            prompt,
+            tools,
+            sampling_params,
+            cache,
+            computer_use,
+            display_width,
+            display_height,
+        )
 
     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
         is_error = False
@@ -135,8 +169,6 @@
             "anthropic-ratelimit-tokens-reset",
         ]:
             rate_limits[header] = http_response.headers.get(header, None)
-        if self.debug:
-            print(f"Rate limits: {rate_limits}")
         if status_code >= 200 and status_code < 300:
             try:
                 data = await http_response.json()
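The helper _build_anthropic_request introduced above maps the coarse reasoning_effort setting onto Anthropic's thinking budget_tokens and adds that budget on top of max_tokens. A standalone sketch of just that mapping (pick_thinking_config and EFFORT_TO_BUDGET are local names for this example, not part of the package API):

    # Mirrors the effort-to-budget dict in the diff above; the helper itself is hypothetical.
    EFFORT_TO_BUDGET = {"low": 1024, "medium": 4096, "high": 16384}

    def pick_thinking_config(reasoning_effort: str | None, max_new_tokens: int) -> dict:
        # reasoning_effort is expected to be None, "low", "medium", or "high"
        if reasoning_effort is None:
            return {"thinking": {"type": "disabled"}, "max_tokens": max_new_tokens}
        budget = EFFORT_TO_BUDGET[reasoning_effort]
        return {
            "thinking": {"type": "enabled", "budget_tokens": budget},
            # the diff treats max_tokens as completion tokens, so the budget is added on top
            "max_tokens": max_new_tokens + budget,
            # thinking requests pin temperature to 1.0 and drop top_p in the diff above
            "temperature": 1.0,
        }

    print(pick_thinking_config("medium", 2048))
    # {'thinking': {'type': 'enabled', 'budget_tokens': 4096}, 'max_tokens': 6144, 'temperature': 1.0}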
lm_deluge/api_requests/base.py
CHANGED
@@ -1,160 +1,19 @@
-import aiohttp
 import asyncio
-import json
 import random
-
-from dataclasses import dataclass
+import traceback
 from abc import ABC, abstractmethod
 from typing import Callable
 
-
-from lm_deluge.usage import Usage
-
-from ..tracker import StatusTracker
-from ..sampling_params import SamplingParams
-from ..models import APIModel
-from ..errors import raise_if_modal_exception
+import aiohttp
 from aiohttp import ClientResponse
 
+from lm_deluge.prompt import CachePattern, Conversation
 
-
-
-
-
-
-    prompt: Conversation
-    sampling_params: SamplingParams
-
-    # http response information
-    status_code: int | None
-    is_error: bool | None
-    error_message: str | None
-
-    # completion information - unified usage tracking
-    usage: Usage | None = None
-
-    # response content - structured format
-    content: Message | None = None
-
-    # optional or calculated automatically
-    thinking: str | None = None  # if model shows thinking tokens
-    model_external: str | None = None  # the model tag used by the API
-    region: str | None = None
-    logprobs: list | None = None
-    finish_reason: str | None = None  # make required later
-    cost: float | None = None  # calculated automatically
-    cache_hit: bool = False  # manually set if true
-    # set to true if is_error and should be retried with a different model
-    retry_with_different_model: bool | None = False
-    # set to true if should NOT retry with the same model (unrecoverable error)
-    give_up_if_no_other_models: bool | None = False
-
-    @property
-    def completion(self) -> str | None:
-        """Backward compatibility: extract text from content Message."""
-        if self.content is not None:
-            return self.content.completion
-        return None
-
-    @property
-    def input_tokens(self) -> int | None:
-        """Get input tokens from usage object."""
-        return self.usage.input_tokens if self.usage else None
-
-    @property
-    def output_tokens(self) -> int | None:
-        """Get output tokens from usage object."""
-        return self.usage.output_tokens if self.usage else None
-
-    @property
-    def cache_read_tokens(self) -> int | None:
-        """Get cache read tokens from usage object."""
-        return self.usage.cache_read_tokens if self.usage else None
-
-    @property
-    def cache_write_tokens(self) -> int | None:
-        """Get cache write tokens from usage object."""
-        return self.usage.cache_write_tokens if self.usage else None
-
-    def __post_init__(self):
-        # calculate cost & get external model name
-        self.id = int(self.id)
-        api_model = APIModel.from_registry(self.model_internal)
-        self.model_external = api_model.name
-        self.cost = None
-        if (
-            self.usage is not None
-            and api_model.input_cost is not None
-            and api_model.output_cost is not None
-        ):
-            self.cost = (
-                self.usage.input_tokens * api_model.input_cost / 1e6
-                + self.usage.output_tokens * api_model.output_cost / 1e6
-            )
-        elif self.content is not None and self.completion is not None:
-            print(
-                f"Warning: Completion provided without token counts for model {self.model_internal}."
-            )
-
-    def to_dict(self):
-        return {
-            "id": self.id,
-            "model_internal": self.model_internal,
-            "model_external": self.model_external,
-            "region": self.region,
-            "prompt": self.prompt.to_log(),  # destroys image if present
-            "sampling_params": self.sampling_params.__dict__,
-            "status_code": self.status_code,
-            "is_error": self.is_error,
-            "error_message": self.error_message,
-            "completion": self.completion,  # computed property
-            "content": self.content.to_log() if self.content else None,
-            "usage": self.usage.to_dict() if self.usage else None,
-            "finish_reason": self.finish_reason,
-            "cost": self.cost,
-        }
-
-    @classmethod
-    def from_dict(cls, data: dict):
-        # Handle backward compatibility for content/completion
-        content = None
-        if "content" in data and data["content"] is not None:
-            # Reconstruct message from log format
-            content = Message.from_log(data["content"])
-        elif "completion" in data and data["completion"] is not None:
-            # Backward compatibility: create a Message with just text
-            content = Message.ai(data["completion"])
-
-        usage = None
-        if "usage" in data and data["usage"] is not None:
-            usage = Usage.from_dict(data["usage"])
-
-        return cls(
-            id=data.get("id", random.randint(0, 1_000_000_000)),
-            model_internal=data["model_internal"],
-            prompt=Conversation.from_log(data["prompt"]),
-            sampling_params=SamplingParams(**data["sampling_params"]),
-            status_code=data["status_code"],
-            is_error=data["is_error"],
-            error_message=data["error_message"],
-            usage=usage,
-            content=content,
-            thinking=data.get("thinking"),
-            model_external=data.get("model_external"),
-            region=data.get("region"),
-            logprobs=data.get("logprobs"),
-            finish_reason=data.get("finish_reason"),
-            cost=data.get("cost"),
-            cache_hit=data.get("cache_hit", False),
-        )
-
-    def write_to_file(self, filename):
-        """
-        Writes the APIResponse as a line to a file.
-        If file exists, appends to it.
-        """
-        with open(filename, "a") as f:
-            f.write(json.dumps(self.to_dict()) + "\n")
+from ..config import SamplingParams
+from ..errors import raise_if_modal_exception
+from ..models import APIModel
+from ..tracker import StatusTracker
+from .response import APIResponse
 
 
 class APIRequestBase(ABC):
@@ -176,16 +35,11 @@ class APIRequestBase(ABC):
         prompt: Conversation,
         attempts_left: int,
         status_tracker: StatusTracker,
-        retry_queue: asyncio.Queue,
         # needed in order to retry with a different model and not throw the output away
         results_arr: list["APIRequestBase"],
         request_timeout: int = 30,
        sampling_params: SamplingParams = SamplingParams(),
-        logprobs: bool = False,
-        top_logprobs: int | None = None,
-        pbar: tqdm | None = None,
         callback: Callable | None = None,
-        debug: bool = False,
         all_model_names: list[str] | None = None,
         all_sampling_params: list[SamplingParams] | None = None,
         tools: list | None = None,
@@ -199,16 +53,11 @@ class APIRequestBase(ABC):
         self.prompt = prompt
         self.attempts_left = attempts_left
         self.status_tracker = status_tracker
-        self.retry_queue = retry_queue
         self.request_timeout = request_timeout
         self.sampling_params = sampling_params
-        self.logprobs = logprobs  # len(completion) logprobs
-        self.top_logprobs = top_logprobs
-        self.pbar = pbar
         self.callback = callback
         self.num_tokens = prompt.count_tokens(sampling_params.max_new_tokens)
         self.results_arr = results_arr
-        self.debug = debug
         self.all_model_names = all_model_names
         self.all_sampling_params = all_sampling_params
         self.tools = tools
@@ -222,8 +71,7 @@ class APIRequestBase(ABC):
         self.region = None
 
     def increment_pbar(self):
-
-        self.pbar.update(1)
+        self.status_tracker.increment_pbar()
 
     def call_callback(self):
         if self.callback is not None:
@@ -232,7 +80,6 @@ class APIRequestBase(ABC):
 
     def handle_success(self, data):
         self.call_callback()
-        self.increment_pbar()
         self.status_tracker.task_succeeded(self.task_id)
 
     def handle_error(self, create_new_request=False, give_up_if_no_other_models=False):
@@ -253,7 +100,8 @@
        if self.attempts_left > 0:
            self.attempts_left -= 1
            if not create_new_request:
-                self.retry_queue
+                assert self.status_tracker.retry_queue
+                self.status_tracker.retry_queue.put_nowait(self)
                return
            else:
                # make sure we have another model to send it to besides the current one
@@ -267,7 +115,8 @@
                    print(
                        f"No other models to try for task {self.task_id}. Retrying with same model."
                    )
-                    self.retry_queue
+                    assert self.status_tracker.retry_queue
+                    self.status_tracker.retry_queue.put_nowait(self)
                else:
                    # two things to change: model_name and sampling_params
                    new_model_name = self.model_name
@@ -292,21 +141,21 @@
                        prompt=self.prompt,
                        attempts_left=self.attempts_left,
                        status_tracker=self.status_tracker,
-                        retry_queue=self.retry_queue,
                        results_arr=self.results_arr,
                        request_timeout=self.request_timeout,
                        sampling_params=new_sampling_params,
-                        logprobs=self.logprobs,
-                        top_logprobs=self.top_logprobs,
-                        pbar=self.pbar,
                        callback=self.callback,
                        all_model_names=self.all_model_names,
                        all_sampling_params=self.all_sampling_params,
                        tools=self.tools,
                        cache=self.cache,
+                        computer_use=getattr(self, "computer_use", False),
+                        display_width=getattr(self, "display_width", 1024),
+                        display_height=getattr(self, "display_height", 768),
                    )
                    # PROBLEM: new request is never put into results array, so we can't get the result.
-                    self.retry_queue
+                    assert self.status_tracker.retry_queue
+                    self.status_tracker.retry_queue.put_nowait(self)
                    # SOLUTION: just need to make sure it's deduplicated by task_id later.
                    self.results_arr.append(new_request)
        else:
@@ -354,6 +203,8 @@
 
        except Exception as e:
            raise_if_modal_exception(e)
+            tb = traceback.format_exc()
+            print(tb)
            self.result.append(
                APIResponse(
                    id=self.task_id,
@@ -381,39 +232,52 @@ def create_api_request(
     prompt: Conversation,
     attempts_left: int,
     status_tracker: StatusTracker,
-    retry_queue: asyncio.Queue,
     results_arr: list["APIRequestBase"],
     request_timeout: int = 30,
     sampling_params: SamplingParams = SamplingParams(),
-    logprobs: bool = False,
-    top_logprobs: int | None = None,
-    pbar: tqdm | None = None,
     callback: Callable | None = None,
     all_model_names: list[str] | None = None,
     all_sampling_params: list[SamplingParams] | None = None,
     tools: list | None = None,
     cache: CachePattern | None = None,
+    computer_use: bool = False,
+    display_width: int = 1024,
+    display_height: int = 768,
+    use_responses_api: bool = False,
 ) -> APIRequestBase:
     from .common import CLASSES  # circular import so made it lazy, does this work?
 
     model_obj = APIModel.from_registry(model_name)
-
+
+    # Choose API spec based on use_responses_api flag and model support
+    api_spec = model_obj.api_spec
+    if use_responses_api and model_obj.supports_responses and api_spec == "openai":
+        api_spec = "openai-responses"
+
+    request_class = CLASSES.get(api_spec, None)
     if request_class is None:
-        raise ValueError(f"Unsupported API spec: {
-        kwargs =
-
-    )
+        raise ValueError(f"Unsupported API spec: {api_spec}")
+    kwargs = {}
+    # Add computer_use to kwargs if the request class supports it
+    model_obj = APIModel.from_registry(model_name)
+    if computer_use and api_spec in ["anthropic", "bedrock", "openai-responses"]:
+        kwargs.update(
+            {
+                "computer_use": computer_use,
+                "display_width": display_width,
+                "display_height": display_height,
+            }
+        )
+
     return request_class(
         task_id=task_id,
         model_name=model_name,
         prompt=prompt,
         attempts_left=attempts_left,
         status_tracker=status_tracker,
-        retry_queue=retry_queue,
         results_arr=results_arr,
         request_timeout=request_timeout,
         sampling_params=sampling_params,
-        pbar=pbar,
         callback=callback,
         all_model_names=all_model_names,
         all_sampling_params=all_sampling_params,
@@ -421,3 +285,22 @@ def create_api_request(
         cache=cache,
         **kwargs,
     )
+
+
+def deduplicate_responses(results: list[APIRequestBase]) -> list[APIResponse]:
+    deduplicated = {}
+    for request in results:
+        if request.task_id not in deduplicated:
+            deduplicated[request.task_id] = request.result[-1]
+        else:
+            current_response: APIResponse = deduplicated[request.task_id]
+            # only replace if the current request has no completion and the new one does
+            if (
+                request.result[-1].completion is not None
+                and current_response.completion is None
+            ):
+                deduplicated[request.task_id] = request.result[-1]
+
+    output = [deduplicated[request.task_id] for request in results]
+
+    return output
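The new deduplicate_responses helper keeps one APIResponse per task_id, preferring an attempt that produced a completion over one that did not. A self-contained sketch of the same selection rule, using plain dicts in place of the library's request and response objects:

    # Same keep-the-completed-attempt rule as deduplicate_responses above,
    # demonstrated with plain dicts instead of APIRequestBase/APIResponse.
    results = [
        {"task_id": 7, "completion": None},     # failed first attempt
        {"task_id": 7, "completion": "hello"},  # retry that succeeded
        {"task_id": 8, "completion": "world"},
    ]

    deduped: dict[int, dict] = {}
    for r in results:
        prev = deduped.get(r["task_id"])
        if prev is None or (prev["completion"] is None and r["completion"] is not None):
            deduped[r["task_id"]] = r

    assert deduped[7]["completion"] == "hello"
    assert deduped[8]["completion"] == "world"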
lm_deluge/api_requests/bedrock.py
CHANGED
@@ -2,7 +2,6 @@ import asyncio
 import json
 import os
 from aiohttp import ClientResponse
-from tqdm import tqdm
 from typing import Callable
 
 try:
@@ -24,7 +23,7 @@ from lm_deluge.usage import Usage
 from .base import APIRequestBase, APIResponse
 
 from ..tracker import StatusTracker
-from ..
+from ..config import SamplingParams
 from ..models import APIModel
 
 
@@ -36,17 +35,18 @@ class BedrockRequest(APIRequestBase):
         prompt: Conversation,
         attempts_left: int,
         status_tracker: StatusTracker,
-        retry_queue: asyncio.Queue,
         results_arr: list,
         request_timeout: int = 30,
         sampling_params: SamplingParams = SamplingParams(),
-        pbar: tqdm | None = None,
         callback: Callable | None = None,
-        debug: bool = False,
         all_model_names: list[str] | None = None,
         all_sampling_params: list[SamplingParams] | None = None,
         tools: list | None = None,
         cache: CachePattern | None = None,
+        # Computer Use support
+        computer_use: bool = False,
+        display_width: int = 1024,
+        display_height: int = 768,
     ):
         super().__init__(
             task_id=task_id,
@@ -54,19 +54,20 @@ class BedrockRequest(APIRequestBase):
             prompt=prompt,
             attempts_left=attempts_left,
             status_tracker=status_tracker,
-            retry_queue=retry_queue,
             results_arr=results_arr,
             request_timeout=request_timeout,
             sampling_params=sampling_params,
-            pbar=pbar,
             callback=callback,
-            debug=debug,
             all_model_names=all_model_names,
             all_sampling_params=all_sampling_params,
             tools=tools,
             cache=cache,
         )
 
+        self.computer_use = computer_use
+        self.display_width = display_width
+        self.display_height = display_height
+
         # Lock images as bytes if caching is enabled
         if cache is not None:
             prompt.lock_images_as_bytes()
@@ -115,11 +116,34 @@
         if self.system_message is not None:
             self.request_json["system"] = self.system_message
 
-        if tools:
-            tool_definitions = [
+        if tools or self.computer_use:
+            tool_definitions = []
+
+            # Add Computer Use tools at the beginning if enabled
+            if self.computer_use:
+                from ..computer_use.anthropic_tools import get_anthropic_cu_tools
+
+                cu_tools = get_anthropic_cu_tools(
+                    model=self.model.id,
+                    display_width=self.display_width,
+                    display_height=self.display_height,
+                )
+                tool_definitions.extend(cu_tools)
+
+                # Add computer use display parameters to the request
+                self.request_json["computer_use_display_width_px"] = self.display_width
+                self.request_json["computer_use_display_height_px"] = (
+                    self.display_height
+                )
+
+            # Add user-provided tools
+            if tools:
+                tool_definitions.extend([tool.dump_for("anthropic") for tool in tools])
+
             # Add cache control to last tool if tools_only caching is specified
             if cache == "tools_only" and tool_definitions:
                 tool_definitions[-1]["cache_control"] = {"type": "ephemeral"}
+
             self.request_json["tools"] = tool_definitions
 
         # Setup AWS4Auth for signing