lm-deluge 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


lm_deluge/__init__.py CHANGED
@@ -1,7 +1,15 @@
 from .client import LLMClient, SamplingParams, APIResponse
 from .prompt import Conversation, Message
+from .tool import Tool
 import dotenv
 
 dotenv.load_dotenv()
 
-__all__ = ["LLMClient", "SamplingParams", "APIResponse", "Conversation", "Message"]
+__all__ = [
+    "LLMClient",
+    "SamplingParams",
+    "APIResponse",
+    "Conversation",
+    "Message",
+    "Tool",
+]
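
With this change, `Tool` is re-exported from the package root alongside the client types, so downstream code can import everything from one place (a minimal sketch; only the names listed in `__all__` above are assumed):

```python
# Assumes the expanded __all__ shown above (lm-deluge 0.0.13).
from lm_deluge import LLMClient, SamplingParams, Conversation, Tool
```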
lm_deluge/agent.py ADDED
File without changes
@@ -1,17 +1,94 @@
-import asyncio
 from aiohttp import ClientResponse
 import json
 import os
-import warnings
-from tqdm import tqdm
 from typing import Callable
 
-from lm_deluge.prompt import Conversation, Message, Text, ToolCall, Thinking
+from lm_deluge.prompt import (
+    Conversation,
+    Message,
+    Text,
+    ToolCall,
+    Thinking,
+    CachePattern,
+)
+from lm_deluge.tool import Tool
+from lm_deluge.usage import Usage
 from .base import APIRequestBase, APIResponse
 
 from ..tracker import StatusTracker
-from ..sampling_params import SamplingParams
+from ..config import SamplingParams
 from ..models import APIModel
+from ..computer_use.anthropic_tools import get_anthropic_cu_tools
+
+
+def _build_anthropic_request(
+    model: APIModel,
+    prompt: Conversation,
+    tools: list[Tool] | None,
+    sampling_params: SamplingParams,
+    cache_pattern: CachePattern | None = None,
+    computer_use: bool = False,
+    display_width: int = 1024,
+    display_height: int = 768,
+):
+    system_message, messages = prompt.to_anthropic(cache_pattern=cache_pattern)
+    request_header = {
+        "x-api-key": os.getenv(model.api_key_env_var),
+        "anthropic-version": "2023-06-01",
+        "content-type": "application/json",
+    }
+
+    # Add beta header for Computer Use
+    if computer_use:
+        request_header["anthropic-beta"] = "computer-use-2025-01-24"
+
+    request_json = {
+        "model": model.name,
+        "messages": messages,
+        "temperature": sampling_params.temperature,
+        "top_p": sampling_params.top_p,
+        "max_tokens": sampling_params.max_new_tokens,
+    }
+
+    # handle thinking
+    if model.reasoning_model and sampling_params.reasoning_effort:
+        # translate reasoning effort of low, medium, high to budget tokens
+        budget = {"low": 1024, "medium": 4096, "high": 16384}.get(
+            sampling_params.reasoning_effort
+        )
+        request_json["thinking"] = {
+            "type": "enabled",
+            "budget_tokens": budget,
+        }
+        request_json.pop("top_p")
+        request_json["temperature"] = 1.0
+        request_json["max_tokens"] += budget
+    else:
+        request_json["thinking"] = {"type": "disabled"}
+        if sampling_params.reasoning_effort:
+            print("ignoring reasoning_effort for non-reasoning model")
+    if system_message is not None:
+        request_json["system"] = system_message
+    if tools or computer_use:
+        tool_definitions = []
+        if tools:
+            tool_definitions.extend([tool.dump_for("anthropic") for tool in tools])
+        # Add Computer Use tools
+        if computer_use:
+            cu_tools = get_anthropic_cu_tools(
+                model=model.id,
+                display_width=display_width,  # todo: set from ComputerUseParams
+                display_height=display_height,
+            )
+            tool_definitions.extend(cu_tools)
+
+        # Add cache control to last tool if tools_only caching is specified
+        if cache_pattern == "tools_only" and tool_definitions:
+            tool_definitions[-1]["cache_control"] = {"type": "ephemeral"}
+
+        request_json["tools"] = tool_definitions
+
+    return request_json, request_header
 
 
 class AnthropicRequest(APIRequestBase):
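
The new `_build_anthropic_request` helper centralizes request construction, including the mapping from `reasoning_effort` to Anthropic's `thinking` budget. A standalone sketch of that mapping (hypothetical helper name; the budget numbers and side effects are taken from the code above):

```python
# Illustrative only - mirrors the reasoning_effort handling in _build_anthropic_request.
EFFORT_TO_BUDGET = {"low": 1024, "medium": 4096, "high": 16384}

def thinking_fields(reasoning_effort: str, max_new_tokens: int) -> dict:
    budget = EFFORT_TO_BUDGET[reasoning_effort]
    return {
        "thinking": {"type": "enabled", "budget_tokens": budget},
        "temperature": 1.0,  # forced to 1.0 when thinking is enabled
        "max_tokens": max_new_tokens + budget,  # budget is added on top of max_new_tokens
        # top_p is removed from the request entirely in this branch
    }

print(thinking_fields("medium", 2048))
# {'thinking': {'type': 'enabled', 'budget_tokens': 4096}, 'temperature': 1.0, 'max_tokens': 6144}
```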
@@ -24,17 +101,19 @@ class AnthropicRequest(APIRequestBase):
         prompt: Conversation,
         attempts_left: int,
         status_tracker: StatusTracker,
-        retry_queue: asyncio.Queue,
         results_arr: list,
         request_timeout: int = 30,
         sampling_params: SamplingParams = SamplingParams(),
-        pbar: tqdm | None = None,
         callback: Callable | None = None,
-        debug: bool = False,
         # for retries
         all_model_names: list[str] | None = None,
         all_sampling_params: list[SamplingParams] | None = None,
         tools: list | None = None,
+        cache: CachePattern | None = None,
+        # Computer Use support
+        computer_use: bool = False,
+        display_width: int = 1024,
+        display_height: int = 768,
     ):
         super().__init__(
             task_id=task_id,
@@ -42,70 +121,42 @@ class AnthropicRequest(APIRequestBase):
             prompt=prompt,
             attempts_left=attempts_left,
             status_tracker=status_tracker,
-            retry_queue=retry_queue,
             results_arr=results_arr,
             request_timeout=request_timeout,
             sampling_params=sampling_params,
-            pbar=pbar,
             callback=callback,
-            debug=debug,
             all_model_names=all_model_names,
             all_sampling_params=all_sampling_params,
             tools=tools,
+            cache=cache,
         )
+        self.computer_use = computer_use
+        self.display_width = display_width
+        self.display_height = display_height
         self.model = APIModel.from_registry(model_name)
         self.url = f"{self.model.api_base}/messages"
 
-        self.system_message, messages = prompt.to_anthropic()
-        self.request_header = {
-            "x-api-key": os.getenv(self.model.api_key_env_var),
-            "anthropic-version": "2023-06-01",
-            "content-type": "application/json",
-        }
+        # Lock images as bytes if caching is enabled
+        if cache is not None:
+            prompt.lock_images_as_bytes()
 
-        self.request_json = {
-            "model": self.model.name,
-            "messages": messages,
-            "temperature": self.sampling_params.temperature,
-            "top_p": self.sampling_params.top_p,
-            "max_tokens": self.sampling_params.max_new_tokens,
-        }
-        # handle thinking
-        if self.model.reasoning_model:
-            if sampling_params.reasoning_effort:
-                # translate reasoning effort of low, medium, high to budget tokens
-                budget = {"low": 1024, "medium": 4096, "high": 16384}.get(
-                    sampling_params.reasoning_effort
-                )
-                self.request_json["thinking"] = {
-                    "type": "enabled",
-                    "budget_tokens": budget,
-                }
-                self.request_json.pop("top_p")
-                self.request_json["temperature"] = 1.0
-                self.request_json["max_tokens"] += (
-                    budget  # assume max tokens is max completion tokens
-                )
-            else:
-                # no thinking
-                self.request_json["thinking"] = {"type": "disabled"}
-        else:
-            if sampling_params.reasoning_effort:
-                warnings.warn(
-                    f"Ignoring reasoning_effort param for non-reasoning model: {model_name}"
-                )
-        if self.system_message is not None:
-            self.request_json["system"] = self.system_message
-        if tools:
-            self.request_json["tools"] = [tool.dump_for("anthropic") for tool in tools]
+        self.request_json, self.request_header = _build_anthropic_request(
+            self.model,
+            prompt,
+            tools,
+            sampling_params,
+            cache,
+            computer_use,
+            display_width,
+            display_height,
+        )
 
     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
         is_error = False
         error_message = None
         thinking = None
         content = None
-        input_tokens = None
-        output_tokens = None
+        usage = None
         status_code = http_response.status
         mimetype = http_response.headers.get("Content-Type", None)
         rate_limits = {}
@@ -118,8 +169,6 @@ class AnthropicRequest(APIRequestBase):
             "anthropic-ratelimit-tokens-reset",
         ]:
             rate_limits[header] = http_response.headers.get(header, None)
-        if self.debug:
-            print(f"Rate limits: {rate_limits}")
         if status_code >= 200 and status_code < 300:
             try:
                 data = await http_response.json()
@@ -143,8 +192,7 @@ class AnthropicRequest(APIRequestBase):
                        )
 
                content = Message("assistant", parts)
-                input_tokens = data["usage"]["input_tokens"]
-                output_tokens = data["usage"]["output_tokens"]
+                usage = Usage.from_anthropic_usage(data["usage"])
            except Exception as e:
                is_error = True
                error_message = (
@@ -182,6 +230,5 @@ class AnthropicRequest(APIRequestBase):
            thinking=thinking,
            model_internal=self.model_name,
            sampling_params=self.sampling_params,
-            input_tokens=input_tokens,
-            output_tokens=output_tokens,
+            usage=usage,
        )
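
`AnthropicRequest.handle_response` now builds a single `Usage` object instead of tracking raw token counts. The `Usage` class itself is not part of this diff; judging from the `APIResponse` properties added further down, a minimal stand-in would look roughly like this (illustrative only):

```python
from dataclasses import dataclass

@dataclass
class UsageSketch:
    # Field names inferred from the input_tokens / output_tokens / cache_read_tokens /
    # cache_write_tokens properties added to APIResponse below; the real
    # lm_deluge.usage.Usage presumably adds helpers such as from_anthropic_usage and to_dict.
    input_tokens: int | None = None
    output_tokens: int | None = None
    cache_read_tokens: int | None = None
    cache_write_tokens: int | None = None
```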
@@ -1,19 +1,21 @@
-import aiohttp
 import asyncio
 import json
 import random
-from tqdm import tqdm
-from dataclasses import dataclass
+import traceback
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from typing import Callable
 
-from lm_deluge.prompt import Conversation, Message
+import aiohttp
+from aiohttp import ClientResponse
+
+from lm_deluge.prompt import CachePattern, Conversation, Message
+from lm_deluge.usage import Usage
 
-from ..tracker import StatusTracker
-from ..sampling_params import SamplingParams
-from ..models import APIModel
+from ..config import SamplingParams
 from ..errors import raise_if_modal_exception
-from aiohttp import ClientResponse
+from ..models import APIModel
+from ..tracker import StatusTracker
 
 
 @dataclass
@@ -29,9 +31,8 @@ class APIResponse:
     is_error: bool | None
     error_message: str | None
 
-    # completion information
-    input_tokens: int | None
-    output_tokens: int | None
+    # completion information - unified usage tracking
+    usage: Usage | None = None
 
     # response content - structured format
     content: Message | None = None
@@ -48,6 +49,10 @@ class APIResponse:
     retry_with_different_model: bool | None = False
     # set to true if should NOT retry with the same model (unrecoverable error)
     give_up_if_no_other_models: bool | None = False
+    # OpenAI Responses API specific - used for computer use continuation
+    response_id: str | None = None
+    # Raw API response for debugging
+    raw_response: dict | None = None
 
     @property
     def completion(self) -> str | None:
@@ -56,6 +61,26 @@
             return self.content.completion
         return None
 
+    @property
+    def input_tokens(self) -> int | None:
+        """Get input tokens from usage object."""
+        return self.usage.input_tokens if self.usage else None
+
+    @property
+    def output_tokens(self) -> int | None:
+        """Get output tokens from usage object."""
+        return self.usage.output_tokens if self.usage else None
+
+    @property
+    def cache_read_tokens(self) -> int | None:
+        """Get cache read tokens from usage object."""
+        return self.usage.cache_read_tokens if self.usage else None
+
+    @property
+    def cache_write_tokens(self) -> int | None:
+        """Get cache write tokens from usage object."""
+        return self.usage.cache_write_tokens if self.usage else None
+
     def __post_init__(self):
         # calculate cost & get external model name
         self.id = int(self.id)
@@ -63,14 +88,13 @@
         self.model_external = api_model.name
         self.cost = None
         if (
-            self.input_tokens is not None
-            and self.output_tokens is not None
+            self.usage is not None
             and api_model.input_cost is not None
             and api_model.output_cost is not None
         ):
             self.cost = (
-                self.input_tokens * api_model.input_cost / 1e6
-                + self.output_tokens * api_model.output_cost / 1e6
+                self.usage.input_tokens * api_model.input_cost / 1e6
+                + self.usage.output_tokens * api_model.output_cost / 1e6
             )
         elif self.content is not None and self.completion is not None:
             print(
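
The cost estimate keeps the same per-million-token arithmetic, it just reads the counts from `usage` now. A worked example with placeholder prices (not real lm-deluge registry values):

```python
# Hypothetical prices in USD per 1M tokens; the formula matches __post_init__ above.
input_tokens, output_tokens = 1200, 350
input_cost, output_cost = 3.00, 15.00
cost = input_tokens * input_cost / 1e6 + output_tokens * output_cost / 1e6
print(round(cost, 5))  # 0.00885
```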
@@ -90,8 +114,7 @@
             "error_message": self.error_message,
             "completion": self.completion,  # computed property
             "content": self.content.to_log() if self.content else None,
-            "input_tokens": self.input_tokens,
-            "output_tokens": self.output_tokens,
+            "usage": self.usage.to_dict() if self.usage else None,
             "finish_reason": self.finish_reason,
             "cost": self.cost,
         }
@@ -107,6 +130,10 @@
             # Backward compatibility: create a Message with just text
             content = Message.ai(data["completion"])
 
+        usage = None
+        if "usage" in data and data["usage"] is not None:
+            usage = Usage.from_dict(data["usage"])
+
         return cls(
             id=data.get("id", random.randint(0, 1_000_000_000)),
             model_internal=data["model_internal"],
@@ -115,8 +142,7 @@
             status_code=data["status_code"],
             is_error=data["is_error"],
             error_message=data["error_message"],
-            input_tokens=data["input_tokens"],
-            output_tokens=data["output_tokens"],
+            usage=usage,
             content=content,
             thinking=data.get("thinking"),
             model_external=data.get("model_external"),
@@ -155,19 +181,15 @@ class APIRequestBase(ABC):
         prompt: Conversation,
         attempts_left: int,
         status_tracker: StatusTracker,
-        retry_queue: asyncio.Queue,
         # needed in order to retry with a different model and not throw the output away
         results_arr: list["APIRequestBase"],
         request_timeout: int = 30,
         sampling_params: SamplingParams = SamplingParams(),
-        logprobs: bool = False,
-        top_logprobs: int | None = None,
-        pbar: tqdm | None = None,
         callback: Callable | None = None,
-        debug: bool = False,
         all_model_names: list[str] | None = None,
         all_sampling_params: list[SamplingParams] | None = None,
         tools: list | None = None,
+        cache: CachePattern | None = None,
     ):
         if all_model_names is None:
             raise ValueError("all_model_names must be provided.")
@@ -177,19 +199,15 @@ class APIRequestBase(ABC):
         self.prompt = prompt
         self.attempts_left = attempts_left
         self.status_tracker = status_tracker
-        self.retry_queue = retry_queue
         self.request_timeout = request_timeout
         self.sampling_params = sampling_params
-        self.logprobs = logprobs  # len(completion) logprobs
-        self.top_logprobs = top_logprobs
-        self.pbar = pbar
         self.callback = callback
         self.num_tokens = prompt.count_tokens(sampling_params.max_new_tokens)
         self.results_arr = results_arr
-        self.debug = debug
         self.all_model_names = all_model_names
         self.all_sampling_params = all_sampling_params
         self.tools = tools
+        self.cache: CachePattern | None = cache
         self.result = []  # list of APIResponse objects from each attempt
 
         # these should be set in the __init__ of the subclass
@@ -199,8 +217,7 @@ class APIRequestBase(ABC):
         self.region = None
 
     def increment_pbar(self):
-        if self.pbar is not None:
-            self.pbar.update(1)
+        self.status_tracker.increment_pbar()
 
     def call_callback(self):
         if self.callback is not None:
@@ -209,7 +226,6 @@ class APIRequestBase(ABC):
 
     def handle_success(self, data):
         self.call_callback()
-        self.increment_pbar()
         self.status_tracker.task_succeeded(self.task_id)
 
     def handle_error(self, create_new_request=False, give_up_if_no_other_models=False):
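
Progress and retry state now live on the `StatusTracker` rather than on each request (`pbar`, `retry_queue`, and `debug` are gone from the request constructors). The tracker itself is outside this diff; a rough stand-in for the parts the requests rely on (names beyond `increment_pbar`, `task_succeeded`, and `retry_queue` are assumptions):

```python
import asyncio
from tqdm import tqdm

class StatusTrackerSketch:
    """Illustrative only - the real lm_deluge.tracker.StatusTracker is not shown here."""

    def __init__(self, total: int):
        self.retry_queue: asyncio.Queue = asyncio.Queue()  # failed requests re-enqueue themselves
        self._pbar = tqdm(total=total)

    def increment_pbar(self) -> None:
        self._pbar.update(1)

    def task_succeeded(self, task_id: int) -> None:
        self.increment_pbar()
```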
@@ -230,7 +246,8 @@ class APIRequestBase(ABC):
         if self.attempts_left > 0:
             self.attempts_left -= 1
             if not create_new_request:
-                self.retry_queue.put_nowait(self)
+                assert self.status_tracker.retry_queue
+                self.status_tracker.retry_queue.put_nowait(self)
                 return
             else:
                 # make sure we have another model to send it to besides the current one
@@ -244,7 +261,8 @@ class APIRequestBase(ABC):
                     print(
                         f"No other models to try for task {self.task_id}. Retrying with same model."
                     )
-                    self.retry_queue.put_nowait(self)
+                    assert self.status_tracker.retry_queue
+                    self.status_tracker.retry_queue.put_nowait(self)
                 else:
                     # two things to change: model_name and sampling_params
                     new_model_name = self.model_name
@@ -269,20 +287,21 @@ class APIRequestBase(ABC):
                         prompt=self.prompt,
                         attempts_left=self.attempts_left,
                         status_tracker=self.status_tracker,
-                        retry_queue=self.retry_queue,
                         results_arr=self.results_arr,
                         request_timeout=self.request_timeout,
                         sampling_params=new_sampling_params,
-                        logprobs=self.logprobs,
-                        top_logprobs=self.top_logprobs,
-                        pbar=self.pbar,
                         callback=self.callback,
                         all_model_names=self.all_model_names,
                         all_sampling_params=self.all_sampling_params,
                         tools=self.tools,
+                        cache=self.cache,
+                        computer_use=getattr(self, "computer_use", False),
+                        display_width=getattr(self, "display_width", 1024),
+                        display_height=getattr(self, "display_height", 768),
                     )
                     # PROBLEM: new request is never put into results array, so we can't get the result.
-                    self.retry_queue.put_nowait(new_request)
+                    assert self.status_tracker.retry_queue
+                    self.status_tracker.retry_queue.put_nowait(self)
                     # SOLUTION: just need to make sure it's deduplicated by task_id later.
                     self.results_arr.append(new_request)
                 else:
@@ -323,14 +342,15 @@ class APIRequestBase(ABC):
                     is_error=True,
                     error_message="Request timed out (terminated by client).",
                     content=None,
-                    input_tokens=None,
-                    output_tokens=None,
+                    usage=None,
                 )
             )
             self.handle_error(create_new_request=False)
 
         except Exception as e:
             raise_if_modal_exception(e)
+            tb = traceback.format_exc()
+            print(tb)
             self.result.append(
                 APIResponse(
                     id=self.task_id,
@@ -341,8 +361,7 @@ class APIRequestBase(ABC):
                     is_error=True,
                     error_message=f"Unexpected {type(e).__name__}: {str(e) or 'No message.'}",
                     content=None,
-                    input_tokens=None,
-                    output_tokens=None,
+                    usage=None,
                 )
             )
             # maybe consider making True?
@@ -359,41 +378,75 @@ def create_api_request(
     prompt: Conversation,
     attempts_left: int,
     status_tracker: StatusTracker,
-    retry_queue: asyncio.Queue,
     results_arr: list["APIRequestBase"],
     request_timeout: int = 30,
     sampling_params: SamplingParams = SamplingParams(),
-    logprobs: bool = False,
-    top_logprobs: int | None = None,
-    pbar: tqdm | None = None,
     callback: Callable | None = None,
     all_model_names: list[str] | None = None,
     all_sampling_params: list[SamplingParams] | None = None,
     tools: list | None = None,
+    cache: CachePattern | None = None,
+    computer_use: bool = False,
+    display_width: int = 1024,
+    display_height: int = 768,
+    use_responses_api: bool = False,
 ) -> APIRequestBase:
     from .common import CLASSES  # circular import so made it lazy, does this work?
 
     model_obj = APIModel.from_registry(model_name)
-    request_class = CLASSES.get(model_obj.api_spec, None)
+
+    # Choose API spec based on use_responses_api flag and model support
+    api_spec = model_obj.api_spec
+    if use_responses_api and model_obj.supports_responses and api_spec == "openai":
+        api_spec = "openai-responses"
+
+    request_class = CLASSES.get(api_spec, None)
     if request_class is None:
-        raise ValueError(f"Unsupported API spec: {model_obj.api_spec}")
-    kwargs = (
-        {} if not logprobs else {"logprobs": logprobs, "top_logprobs": top_logprobs}
-    )
+        raise ValueError(f"Unsupported API spec: {api_spec}")
+    kwargs = {}
+    # Add computer_use to kwargs if the request class supports it
+    model_obj = APIModel.from_registry(model_name)
+    if computer_use and api_spec in ["anthropic", "bedrock", "openai-responses"]:
+        kwargs.update(
+            {
+                "computer_use": computer_use,
+                "display_width": display_width,
+                "display_height": display_height,
+            }
+        )
+
     return request_class(
         task_id=task_id,
         model_name=model_name,
         prompt=prompt,
         attempts_left=attempts_left,
         status_tracker=status_tracker,
-        retry_queue=retry_queue,
         results_arr=results_arr,
         request_timeout=request_timeout,
        sampling_params=sampling_params,
-        pbar=pbar,
        callback=callback,
        all_model_names=all_model_names,
        all_sampling_params=all_sampling_params,
        tools=tools,
+        cache=cache,
        **kwargs,
    )
+
+
+def deduplicate_responses(results: list[APIRequestBase]) -> list[APIResponse]:
+    deduplicated = {}
+    for request in results:
+        if request.task_id not in deduplicated:
+            deduplicated[request.task_id] = request.result[-1]
+        else:
+            current_response: APIResponse = deduplicated[request.task_id]
+            # only replace if the current request has no completion and the new one does
+            if (
+                request.result[-1].completion is not None
+                and current_response.completion is None
+            ):
+                deduplicated[request.task_id] = request.result[-1]
+
+    output = [deduplicated[request.task_id] for request in results]
+
+    return output
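
The new `deduplicate_responses` helper keeps one response per `task_id`, preferring an attempt that produced a completion over one that did not. A standalone illustration of that rule (plain dicts stand in for the request objects):

```python
attempts = [
    {"task_id": 7, "completion": None},     # first attempt errored
    {"task_id": 7, "completion": "hello"},  # retry succeeded
]
best: dict[int, dict] = {}
for a in attempts:
    prev = best.get(a["task_id"])
    if prev is None or (prev["completion"] is None and a["completion"] is not None):
        best[a["task_id"]] = a
print(best[7]["completion"])  # hello
```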