lm-deluge 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of lm-deluge might be problematic.

lm_deluge/__init__.py CHANGED
@@ -1,7 +1,17 @@
  from .client import LLMClient, SamplingParams, APIResponse
  from .prompt import Conversation, Message
+ from .tool import Tool
+ from .file import File
  import dotenv

  dotenv.load_dotenv()

- __all__ = ["LLMClient", "SamplingParams", "APIResponse", "Conversation", "Message"]
+ __all__ = [
+     "LLMClient",
+     "SamplingParams",
+     "APIResponse",
+     "Conversation",
+     "Message",
+     "Tool",
+     "File",
+ ]
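With this change, `Tool` and `File` become part of the package's top-level API. A minimal sketch of the new import surface (nothing here beyond what the hunk above shows; `Tool` and `File` construction details are not part of this diff):

```python
# Sketch of the 0.0.14 import surface, per the __init__.py hunk above.
from lm_deluge import LLMClient, SamplingParams, Tool, File

# Before this release, these had to come from the submodules directly:
# from lm_deluge.tool import Tool
# from lm_deluge.file import File
```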
lm_deluge/agent.py ADDED
File without changes
@@ -1,9 +1,6 @@
- import asyncio
  from aiohttp import ClientResponse
  import json
  import os
- import warnings
- from tqdm import tqdm
  from typing import Callable

  from lm_deluge.prompt import (
@@ -14,12 +11,84 @@ from lm_deluge.prompt import (
      Thinking,
      CachePattern,
  )
+ from lm_deluge.tool import Tool
  from lm_deluge.usage import Usage
  from .base import APIRequestBase, APIResponse

  from ..tracker import StatusTracker
- from ..sampling_params import SamplingParams
+ from ..config import SamplingParams
  from ..models import APIModel
+ from ..computer_use.anthropic_tools import get_anthropic_cu_tools
+
+
+ def _build_anthropic_request(
+     model: APIModel,
+     prompt: Conversation,
+     tools: list[Tool] | None,
+     sampling_params: SamplingParams,
+     cache_pattern: CachePattern | None = None,
+     computer_use: bool = False,
+     display_width: int = 1024,
+     display_height: int = 768,
+ ):
+     system_message, messages = prompt.to_anthropic(cache_pattern=cache_pattern)
+     request_header = {
+         "x-api-key": os.getenv(model.api_key_env_var),
+         "anthropic-version": "2023-06-01",
+         "content-type": "application/json",
+     }
+
+     # Add beta header for Computer Use
+     if computer_use:
+         request_header["anthropic-beta"] = "computer-use-2025-01-24"
+
+     request_json = {
+         "model": model.name,
+         "messages": messages,
+         "temperature": sampling_params.temperature,
+         "top_p": sampling_params.top_p,
+         "max_tokens": sampling_params.max_new_tokens,
+     }
+
+     # handle thinking
+     if model.reasoning_model and sampling_params.reasoning_effort:
+         # translate reasoning effort of low, medium, high to budget tokens
+         budget = {"low": 1024, "medium": 4096, "high": 16384}.get(
+             sampling_params.reasoning_effort
+         )
+         request_json["thinking"] = {
+             "type": "enabled",
+             "budget_tokens": budget,
+         }
+         request_json.pop("top_p")
+         request_json["temperature"] = 1.0
+         request_json["max_tokens"] += budget
+     else:
+         request_json["thinking"] = {"type": "disabled"}
+         if sampling_params.reasoning_effort:
+             print("ignoring reasoning_effort for non-reasoning model")
+     if system_message is not None:
+         request_json["system"] = system_message
+     if tools or computer_use:
+         tool_definitions = []
+         if tools:
+             tool_definitions.extend([tool.dump_for("anthropic") for tool in tools])
+         # Add Computer Use tools
+         if computer_use:
+             cu_tools = get_anthropic_cu_tools(
+                 model=model.id,
+                 display_width=display_width,  # todo: set from ComputerUseParams
+                 display_height=display_height,
+             )
+             tool_definitions.extend(cu_tools)
+
+         # Add cache control to last tool if tools_only caching is specified
+         if cache_pattern == "tools_only" and tool_definitions:
+             tool_definitions[-1]["cache_control"] = {"type": "ephemeral"}
+
+         request_json["tools"] = tool_definitions
+
+     return request_json, request_header


  class AnthropicRequest(APIRequestBase):
@@ -32,18 +101,19 @@ class AnthropicRequest(APIRequestBase):
          prompt: Conversation,
          attempts_left: int,
          status_tracker: StatusTracker,
-         retry_queue: asyncio.Queue,
          results_arr: list,
          request_timeout: int = 30,
          sampling_params: SamplingParams = SamplingParams(),
-         pbar: tqdm | None = None,
          callback: Callable | None = None,
-         debug: bool = False,
          # for retries
          all_model_names: list[str] | None = None,
          all_sampling_params: list[SamplingParams] | None = None,
          tools: list | None = None,
          cache: CachePattern | None = None,
+         # Computer Use support
+         computer_use: bool = False,
+         display_width: int = 1024,
+         display_height: int = 768,
      ):
          super().__init__(
              task_id=task_id,
@@ -51,18 +121,18 @@ class AnthropicRequest(APIRequestBase):
              prompt=prompt,
              attempts_left=attempts_left,
              status_tracker=status_tracker,
-             retry_queue=retry_queue,
              results_arr=results_arr,
              request_timeout=request_timeout,
              sampling_params=sampling_params,
-             pbar=pbar,
              callback=callback,
-             debug=debug,
              all_model_names=all_model_names,
              all_sampling_params=all_sampling_params,
              tools=tools,
              cache=cache,
          )
+         self.computer_use = computer_use
+         self.display_width = display_width
+         self.display_height = display_height
          self.model = APIModel.from_registry(model_name)
          self.url = f"{self.model.api_base}/messages"

@@ -70,52 +140,16 @@ class AnthropicRequest(APIRequestBase):
          if cache is not None:
              prompt.lock_images_as_bytes()

-         self.system_message, messages = prompt.to_anthropic(cache_pattern=cache)
-         self.request_header = {
-             "x-api-key": os.getenv(self.model.api_key_env_var),
-             "anthropic-version": "2023-06-01",
-             "content-type": "application/json",
-         }
-
-         self.request_json = {
-             "model": self.model.name,
-             "messages": messages,
-             "temperature": self.sampling_params.temperature,
-             "top_p": self.sampling_params.top_p,
-             "max_tokens": self.sampling_params.max_new_tokens,
-         }
-         # handle thinking
-         if self.model.reasoning_model:
-             if sampling_params.reasoning_effort:
-                 # translate reasoning effort of low, medium, high to budget tokens
-                 budget = {"low": 1024, "medium": 4096, "high": 16384}.get(
-                     sampling_params.reasoning_effort
-                 )
-                 self.request_json["thinking"] = {
-                     "type": "enabled",
-                     "budget_tokens": budget,
-                 }
-                 self.request_json.pop("top_p")
-                 self.request_json["temperature"] = 1.0
-                 self.request_json["max_tokens"] += (
-                     budget  # assume max tokens is max completion tokens
-                 )
-             else:
-                 # no thinking
-                 self.request_json["thinking"] = {"type": "disabled"}
-         else:
-             if sampling_params.reasoning_effort:
-                 warnings.warn(
-                     f"Ignoring reasoning_effort param for non-reasoning model: {model_name}"
-                 )
-         if self.system_message is not None:
-             self.request_json["system"] = self.system_message
-         if tools:
-             tool_definitions = [tool.dump_for("anthropic") for tool in tools]
-             # Add cache control to last tool if tools_only caching is specified
-             if cache == "tools_only" and tool_definitions:
-                 tool_definitions[-1]["cache_control"] = {"type": "ephemeral"}
-             self.request_json["tools"] = tool_definitions
+         self.request_json, self.request_header = _build_anthropic_request(
+             self.model,
+             prompt,
+             tools,
+             sampling_params,
+             cache,
+             computer_use,
+             display_width,
+             display_height,
+         )

      async def handle_response(self, http_response: ClientResponse) -> APIResponse:
          is_error = False
@@ -135,8 +169,6 @@ class AnthropicRequest(APIRequestBase):
              "anthropic-ratelimit-tokens-reset",
          ]:
              rate_limits[header] = http_response.headers.get(header, None)
-         if self.debug:
-             print(f"Rate limits: {rate_limits}")
          if status_code >= 200 and status_code < 300:
              try:
                  data = await http_response.json()
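The hunks above factor request construction out of `AnthropicRequest.__init__` into the module-level `_build_anthropic_request` helper. One piece worth calling out is the reasoning-effort translation it carries over: `low`/`medium`/`high` map onto Anthropic `budget_tokens`, and the budget is added on top of `max_tokens`. A minimal standalone sketch of that translation (constants copied from the diff; the function name here is hypothetical):

```python
# Standalone sketch of the reasoning_effort -> thinking-budget translation
# applied by _build_anthropic_request; the numbers come from the hunk above.
def thinking_fields(reasoning_model: bool, effort: str | None, max_tokens: int) -> dict:
    if reasoning_model and effort:
        budget = {"low": 1024, "medium": 4096, "high": 16384}[effort]
        # When thinking is enabled, top_p is dropped, temperature is pinned
        # to 1.0, and the thinking budget is added on top of max_tokens.
        return {
            "thinking": {"type": "enabled", "budget_tokens": budget},
            "temperature": 1.0,
            "max_tokens": max_tokens + budget,
        }
    return {"thinking": {"type": "disabled"}, "max_tokens": max_tokens}

assert thinking_fields(True, "medium", 2048)["max_tokens"] == 6144
```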
@@ -1,160 +1,19 @@
- import aiohttp
  import asyncio
- import json
  import random
- from tqdm import tqdm
- from dataclasses import dataclass
+ import traceback
  from abc import ABC, abstractmethod
  from typing import Callable

- from lm_deluge.prompt import Conversation, Message, CachePattern
- from lm_deluge.usage import Usage
-
- from ..tracker import StatusTracker
- from ..sampling_params import SamplingParams
- from ..models import APIModel
- from ..errors import raise_if_modal_exception
+ import aiohttp
  from aiohttp import ClientResponse

+ from lm_deluge.prompt import CachePattern, Conversation

- @dataclass
- class APIResponse:
-     # request information
-     id: int  # should be unique to the request within a given prompt-processing call
-     model_internal: str  # our internal model tag
-     prompt: Conversation
-     sampling_params: SamplingParams
-
-     # http response information
-     status_code: int | None
-     is_error: bool | None
-     error_message: str | None
-
-     # completion information - unified usage tracking
-     usage: Usage | None = None
-
-     # response content - structured format
-     content: Message | None = None
-
-     # optional or calculated automatically
-     thinking: str | None = None  # if model shows thinking tokens
-     model_external: str | None = None  # the model tag used by the API
-     region: str | None = None
-     logprobs: list | None = None
-     finish_reason: str | None = None  # make required later
-     cost: float | None = None  # calculated automatically
-     cache_hit: bool = False  # manually set if true
-     # set to true if is_error and should be retried with a different model
-     retry_with_different_model: bool | None = False
-     # set to true if should NOT retry with the same model (unrecoverable error)
-     give_up_if_no_other_models: bool | None = False
-
-     @property
-     def completion(self) -> str | None:
-         """Backward compatibility: extract text from content Message."""
-         if self.content is not None:
-             return self.content.completion
-         return None
-
-     @property
-     def input_tokens(self) -> int | None:
-         """Get input tokens from usage object."""
-         return self.usage.input_tokens if self.usage else None
-
-     @property
-     def output_tokens(self) -> int | None:
-         """Get output tokens from usage object."""
-         return self.usage.output_tokens if self.usage else None
-
-     @property
-     def cache_read_tokens(self) -> int | None:
-         """Get cache read tokens from usage object."""
-         return self.usage.cache_read_tokens if self.usage else None
-
-     @property
-     def cache_write_tokens(self) -> int | None:
-         """Get cache write tokens from usage object."""
-         return self.usage.cache_write_tokens if self.usage else None
-
-     def __post_init__(self):
-         # calculate cost & get external model name
-         self.id = int(self.id)
-         api_model = APIModel.from_registry(self.model_internal)
-         self.model_external = api_model.name
-         self.cost = None
-         if (
-             self.usage is not None
-             and api_model.input_cost is not None
-             and api_model.output_cost is not None
-         ):
-             self.cost = (
-                 self.usage.input_tokens * api_model.input_cost / 1e6
-                 + self.usage.output_tokens * api_model.output_cost / 1e6
-             )
-         elif self.content is not None and self.completion is not None:
-             print(
-                 f"Warning: Completion provided without token counts for model {self.model_internal}."
-             )
-
-     def to_dict(self):
-         return {
-             "id": self.id,
-             "model_internal": self.model_internal,
-             "model_external": self.model_external,
-             "region": self.region,
-             "prompt": self.prompt.to_log(),  # destroys image if present
-             "sampling_params": self.sampling_params.__dict__,
-             "status_code": self.status_code,
-             "is_error": self.is_error,
-             "error_message": self.error_message,
-             "completion": self.completion,  # computed property
-             "content": self.content.to_log() if self.content else None,
-             "usage": self.usage.to_dict() if self.usage else None,
-             "finish_reason": self.finish_reason,
-             "cost": self.cost,
-         }
-
-     @classmethod
-     def from_dict(cls, data: dict):
-         # Handle backward compatibility for content/completion
-         content = None
-         if "content" in data and data["content"] is not None:
-             # Reconstruct message from log format
-             content = Message.from_log(data["content"])
-         elif "completion" in data and data["completion"] is not None:
-             # Backward compatibility: create a Message with just text
-             content = Message.ai(data["completion"])
-
-         usage = None
-         if "usage" in data and data["usage"] is not None:
-             usage = Usage.from_dict(data["usage"])
-
-         return cls(
-             id=data.get("id", random.randint(0, 1_000_000_000)),
-             model_internal=data["model_internal"],
-             prompt=Conversation.from_log(data["prompt"]),
-             sampling_params=SamplingParams(**data["sampling_params"]),
-             status_code=data["status_code"],
-             is_error=data["is_error"],
-             error_message=data["error_message"],
-             usage=usage,
-             content=content,
-             thinking=data.get("thinking"),
-             model_external=data.get("model_external"),
-             region=data.get("region"),
-             logprobs=data.get("logprobs"),
-             finish_reason=data.get("finish_reason"),
-             cost=data.get("cost"),
-             cache_hit=data.get("cache_hit", False),
-         )
-
-     def write_to_file(self, filename):
-         """
-         Writes the APIResponse as a line to a file.
-         If file exists, appends to it.
-         """
-         with open(filename, "a") as f:
-             f.write(json.dumps(self.to_dict()) + "\n")
+ from ..config import SamplingParams
+ from ..errors import raise_if_modal_exception
+ from ..models import APIModel
+ from ..tracker import StatusTracker
+ from .response import APIResponse


  class APIRequestBase(ABC):
@@ -176,16 +35,11 @@ class APIRequestBase(ABC):
          prompt: Conversation,
          attempts_left: int,
          status_tracker: StatusTracker,
-         retry_queue: asyncio.Queue,
          # needed in order to retry with a different model and not throw the output away
          results_arr: list["APIRequestBase"],
          request_timeout: int = 30,
          sampling_params: SamplingParams = SamplingParams(),
-         logprobs: bool = False,
-         top_logprobs: int | None = None,
-         pbar: tqdm | None = None,
          callback: Callable | None = None,
-         debug: bool = False,
          all_model_names: list[str] | None = None,
          all_sampling_params: list[SamplingParams] | None = None,
          tools: list | None = None,
@@ -199,16 +53,11 @@ class APIRequestBase(ABC):
          self.prompt = prompt
          self.attempts_left = attempts_left
          self.status_tracker = status_tracker
-         self.retry_queue = retry_queue
          self.request_timeout = request_timeout
          self.sampling_params = sampling_params
-         self.logprobs = logprobs  # len(completion) logprobs
-         self.top_logprobs = top_logprobs
-         self.pbar = pbar
          self.callback = callback
          self.num_tokens = prompt.count_tokens(sampling_params.max_new_tokens)
          self.results_arr = results_arr
-         self.debug = debug
          self.all_model_names = all_model_names
          self.all_sampling_params = all_sampling_params
          self.tools = tools
@@ -222,8 +71,7 @@ class APIRequestBase(ABC):
          self.region = None

      def increment_pbar(self):
-         if self.pbar is not None:
-             self.pbar.update(1)
+         self.status_tracker.increment_pbar()

      def call_callback(self):
          if self.callback is not None:
@@ -232,7 +80,6 @@ class APIRequestBase(ABC):

      def handle_success(self, data):
          self.call_callback()
-         self.increment_pbar()
          self.status_tracker.task_succeeded(self.task_id)

      def handle_error(self, create_new_request=False, give_up_if_no_other_models=False):
@@ -253,7 +100,8 @@ class APIRequestBase(ABC):
          if self.attempts_left > 0:
              self.attempts_left -= 1
              if not create_new_request:
-                 self.retry_queue.put_nowait(self)
+                 assert self.status_tracker.retry_queue
+                 self.status_tracker.retry_queue.put_nowait(self)
                  return
              else:
                  # make sure we have another model to send it to besides the current one
@@ -267,7 +115,8 @@ class APIRequestBase(ABC):
                      print(
                          f"No other models to try for task {self.task_id}. Retrying with same model."
                      )
-                     self.retry_queue.put_nowait(self)
+                     assert self.status_tracker.retry_queue
+                     self.status_tracker.retry_queue.put_nowait(self)
                  else:
                      # two things to change: model_name and sampling_params
                      new_model_name = self.model_name
@@ -292,21 +141,21 @@ class APIRequestBase(ABC):
                      prompt=self.prompt,
                      attempts_left=self.attempts_left,
                      status_tracker=self.status_tracker,
-                     retry_queue=self.retry_queue,
                      results_arr=self.results_arr,
                      request_timeout=self.request_timeout,
                      sampling_params=new_sampling_params,
-                     logprobs=self.logprobs,
-                     top_logprobs=self.top_logprobs,
-                     pbar=self.pbar,
                      callback=self.callback,
                      all_model_names=self.all_model_names,
                      all_sampling_params=self.all_sampling_params,
                      tools=self.tools,
                      cache=self.cache,
+                     computer_use=getattr(self, "computer_use", False),
+                     display_width=getattr(self, "display_width", 1024),
+                     display_height=getattr(self, "display_height", 768),
                  )
                  # PROBLEM: new request is never put into results array, so we can't get the result.
-                 self.retry_queue.put_nowait(new_request)
+                 assert self.status_tracker.retry_queue
+                 self.status_tracker.retry_queue.put_nowait(self)
                  # SOLUTION: just need to make sure it's deduplicated by task_id later.
                  self.results_arr.append(new_request)
              else:
@@ -354,6 +203,8 @@ class APIRequestBase(ABC):

          except Exception as e:
              raise_if_modal_exception(e)
+             tb = traceback.format_exc()
+             print(tb)
              self.result.append(
                  APIResponse(
                      id=self.task_id,
@@ -381,39 +232,52 @@ def create_api_request(
      prompt: Conversation,
      attempts_left: int,
      status_tracker: StatusTracker,
-     retry_queue: asyncio.Queue,
      results_arr: list["APIRequestBase"],
      request_timeout: int = 30,
      sampling_params: SamplingParams = SamplingParams(),
-     logprobs: bool = False,
-     top_logprobs: int | None = None,
-     pbar: tqdm | None = None,
      callback: Callable | None = None,
      all_model_names: list[str] | None = None,
      all_sampling_params: list[SamplingParams] | None = None,
      tools: list | None = None,
      cache: CachePattern | None = None,
+     computer_use: bool = False,
+     display_width: int = 1024,
+     display_height: int = 768,
+     use_responses_api: bool = False,
  ) -> APIRequestBase:
      from .common import CLASSES  # circular import so made it lazy, does this work?

      model_obj = APIModel.from_registry(model_name)
-     request_class = CLASSES.get(model_obj.api_spec, None)
+
+     # Choose API spec based on use_responses_api flag and model support
+     api_spec = model_obj.api_spec
+     if use_responses_api and model_obj.supports_responses and api_spec == "openai":
+         api_spec = "openai-responses"
+
+     request_class = CLASSES.get(api_spec, None)
      if request_class is None:
-         raise ValueError(f"Unsupported API spec: {model_obj.api_spec}")
-     kwargs = (
-         {} if not logprobs else {"logprobs": logprobs, "top_logprobs": top_logprobs}
-     )
+         raise ValueError(f"Unsupported API spec: {api_spec}")
+     kwargs = {}
+     # Add computer_use to kwargs if the request class supports it
+     model_obj = APIModel.from_registry(model_name)
+     if computer_use and api_spec in ["anthropic", "bedrock", "openai-responses"]:
+         kwargs.update(
+             {
+                 "computer_use": computer_use,
+                 "display_width": display_width,
+                 "display_height": display_height,
+             }
+         )
+
      return request_class(
          task_id=task_id,
          model_name=model_name,
          prompt=prompt,
          attempts_left=attempts_left,
          status_tracker=status_tracker,
-         retry_queue=retry_queue,
          results_arr=results_arr,
          request_timeout=request_timeout,
          sampling_params=sampling_params,
-         pbar=pbar,
          callback=callback,
          all_model_names=all_model_names,
          all_sampling_params=all_sampling_params,
@@ -421,3 +285,22 @@ def create_api_request(
          cache=cache,
          **kwargs,
      )
+
+
+ def deduplicate_responses(results: list[APIRequestBase]) -> list[APIResponse]:
+     deduplicated = {}
+     for request in results:
+         if request.task_id not in deduplicated:
+             deduplicated[request.task_id] = request.result[-1]
+         else:
+             current_response: APIResponse = deduplicated[request.task_id]
+             # only replace if the current request has no completion and the new one does
+             if (
+                 request.result[-1].completion is not None
+                 and current_response.completion is None
+             ):
+                 deduplicated[request.task_id] = request.result[-1]
+
+     output = [deduplicated[request.task_id] for request in results]
+
+     return output
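Because a retried task now appends a second request object to `results_arr`, the new `deduplicate_responses` helper collapses results by `task_id`, preferring whichever attempt produced a completion. A small self-contained sketch of that rule (`FakeRequest`/`FakeResponse` are hypothetical stand-ins for `APIRequestBase`/`APIResponse`):

```python
# Stand-ins with just enough shape to demonstrate the deduplication rule
# that deduplicate_responses (added above) applies to retried tasks.
from dataclasses import dataclass, field

@dataclass
class FakeResponse:
    completion: str | None

@dataclass
class FakeRequest:
    task_id: int
    result: list = field(default_factory=list)

failed = FakeRequest(task_id=7, result=[FakeResponse(None)])   # first attempt
retried = FakeRequest(task_id=7, result=[FakeResponse("ok")])  # successful retry

# Same logic as the helper above: keep one response per task_id, replacing an
# earlier entry only when it lacks a completion and a later attempt has one.
best: dict[int, FakeResponse] = {}
for req in (failed, retried):
    if req.task_id not in best or (
        req.result[-1].completion is not None and best[req.task_id].completion is None
    ):
        best[req.task_id] = req.result[-1]

assert best[7].completion == "ok"
```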
@@ -2,7 +2,6 @@ import asyncio
  import json
  import os
  from aiohttp import ClientResponse
- from tqdm import tqdm
  from typing import Callable

  try:
@@ -24,7 +23,7 @@ from lm_deluge.usage import Usage
  from .base import APIRequestBase, APIResponse

  from ..tracker import StatusTracker
- from ..sampling_params import SamplingParams
+ from ..config import SamplingParams
  from ..models import APIModel

@@ -36,17 +35,18 @@ class BedrockRequest(APIRequestBase):
          prompt: Conversation,
          attempts_left: int,
          status_tracker: StatusTracker,
-         retry_queue: asyncio.Queue,
          results_arr: list,
          request_timeout: int = 30,
          sampling_params: SamplingParams = SamplingParams(),
-         pbar: tqdm | None = None,
          callback: Callable | None = None,
-         debug: bool = False,
          all_model_names: list[str] | None = None,
          all_sampling_params: list[SamplingParams] | None = None,
          tools: list | None = None,
          cache: CachePattern | None = None,
+         # Computer Use support
+         computer_use: bool = False,
+         display_width: int = 1024,
+         display_height: int = 768,
      ):
          super().__init__(
              task_id=task_id,
@@ -54,19 +54,20 @@ class BedrockRequest(APIRequestBase):
              prompt=prompt,
              attempts_left=attempts_left,
              status_tracker=status_tracker,
-             retry_queue=retry_queue,
              results_arr=results_arr,
              request_timeout=request_timeout,
              sampling_params=sampling_params,
-             pbar=pbar,
              callback=callback,
-             debug=debug,
              all_model_names=all_model_names,
              all_sampling_params=all_sampling_params,
              tools=tools,
              cache=cache,
          )

+         self.computer_use = computer_use
+         self.display_width = display_width
+         self.display_height = display_height
+
          # Lock images as bytes if caching is enabled
          if cache is not None:
              prompt.lock_images_as_bytes()
@@ -115,11 +116,34 @@ class BedrockRequest(APIRequestBase):
          if self.system_message is not None:
              self.request_json["system"] = self.system_message

-         if tools:
-             tool_definitions = [tool.dump_for("anthropic") for tool in tools]
+         if tools or self.computer_use:
+             tool_definitions = []
+
+             # Add Computer Use tools at the beginning if enabled
+             if self.computer_use:
+                 from ..computer_use.anthropic_tools import get_anthropic_cu_tools
+
+                 cu_tools = get_anthropic_cu_tools(
+                     model=self.model.id,
+                     display_width=self.display_width,
+                     display_height=self.display_height,
+                 )
+                 tool_definitions.extend(cu_tools)
+
+                 # Add computer use display parameters to the request
+                 self.request_json["computer_use_display_width_px"] = self.display_width
+                 self.request_json["computer_use_display_height_px"] = (
+                     self.display_height
+                 )
+
+             # Add user-provided tools
+             if tools:
+                 tool_definitions.extend([tool.dump_for("anthropic") for tool in tools])
+
              # Add cache control to last tool if tools_only caching is specified
              if cache == "tools_only" and tool_definitions:
                  tool_definitions[-1]["cache_control"] = {"type": "ephemeral"}
+
              self.request_json["tools"] = tool_definitions

          # Setup AWS4Auth for signing
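When `computer_use=True`, the Bedrock request gains the two display-size body fields shown above in addition to the tool definitions from `get_anthropic_cu_tools`. A rough sketch of the resulting body additions; the tool entry's exact shape is an assumption, since `get_anthropic_cu_tools` is not shown in this diff:

```python
# Sketch of what the Bedrock hunk above adds to request_json when
# computer_use=True. The tool definition's shape is an assumption; the real
# entries come from lm_deluge.computer_use.anthropic_tools.get_anthropic_cu_tools.
display_width, display_height = 1024, 768
request_json: dict = {"messages": [], "max_tokens": 1024}

request_json["computer_use_display_width_px"] = display_width
request_json["computer_use_display_height_px"] = display_height
request_json["tools"] = [
    {
        "type": "computer_20250124",  # assumed Anthropic computer-use tool type
        "name": "computer",
        "display_width_px": display_width,
        "display_height_px": display_height,
    }
]
```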