lm-deluge 0.0.14__py3-none-any.whl → 0.0.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,3 +1 @@
-from .base import create_api_request
 
-__all__ = ["create_api_request"]
@@ -1,35 +1,39 @@
-from aiohttp import ClientResponse
 import json
 import os
-from typing import Callable
+
+from aiohttp import ClientResponse
 
 from lm_deluge.prompt import (
+    CachePattern,
     Conversation,
     Message,
     Text,
-    ToolCall,
     Thinking,
-    CachePattern,
+    ToolCall,
 )
-from lm_deluge.tool import Tool
+from lm_deluge.request_context import RequestContext
+from lm_deluge.tool import MCPServer, Tool
 from lm_deluge.usage import Usage
-from .base import APIRequestBase, APIResponse
 
-from ..tracker import StatusTracker
 from ..config import SamplingParams
 from ..models import APIModel
-from ..computer_use.anthropic_tools import get_anthropic_cu_tools
+from .base import APIRequestBase, APIResponse
+
+
+def _add_beta(headers: dict, beta: str):
+    if "anthropic-beta" in headers and headers["anthropic-beta"]:
+        if beta not in headers["anthropic-beta"]:
+            headers["anthropic-beta"] += f",{beta}"
+    else:
+        headers["anthropic-beta"] = beta
 
 
 def _build_anthropic_request(
     model: APIModel,
     prompt: Conversation,
-    tools: list[Tool] | None,
+    tools: list[Tool | dict | MCPServer] | None,
     sampling_params: SamplingParams,
     cache_pattern: CachePattern | None = None,
-    computer_use: bool = False,
-    display_width: int = 1024,
-    display_height: int = 768,
 ):
     system_message, messages = prompt.to_anthropic(cache_pattern=cache_pattern)
     request_header = {
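The new _add_beta helper accumulates Anthropic beta flags in a single comma-separated header value, skipping flags that are already present. A minimal sketch of its behavior, using only the function shown above (note that the membership test is a substring check on the header value, which is sufficient for the flag names used here):

    headers: dict = {}
    _add_beta(headers, "computer-use-2025-01-24")
    _add_beta(headers, "mcp-client-2025-04-04")
    _add_beta(headers, "mcp-client-2025-04-04")  # already present, so no change
    assert headers == {"anthropic-beta": "computer-use-2025-01-24,mcp-client-2025-04-04"}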
@@ -38,10 +42,6 @@ def _build_anthropic_request(
         "content-type": "application/json",
     }
 
-    # Add beta header for Computer Use
-    if computer_use:
-        request_header["anthropic-beta"] = "computer-use-2025-01-24"
-
     request_json = {
         "model": model.name,
         "messages": messages,
@@ -69,89 +69,61 @@ def _build_anthropic_request(
         print("ignoring reasoning_effort for non-reasoning model")
     if system_message is not None:
         request_json["system"] = system_message
-    if tools or computer_use:
+    if tools:
+        mcp_servers = []
         tool_definitions = []
-        if tools:
-            tool_definitions.extend([tool.dump_for("anthropic") for tool in tools])
-        # Add Computer Use tools
-        if computer_use:
-            cu_tools = get_anthropic_cu_tools(
-                model=model.id,
-                display_width=display_width,  # todo: set from ComputerUseParams
-                display_height=display_height,
-            )
-            tool_definitions.extend(cu_tools)
+        for tool in tools:
+            if isinstance(tool, Tool):
+                tool_definitions.append(tool.dump_for("anthropic"))
+            elif isinstance(tool, dict):
+                tool_definitions.append(tool)
+                # add betas if needed
+                if tool["type"] in [
+                    "computer_20241022",
+                    "text_editor_20241022",
+                    "bash_20241022",
+                ]:
+                    _add_beta(request_header, "computer-use-2024-10-22")
+                elif tool["type"] == "computer_20250124":
+                    _add_beta(request_header, "computer-use-2025-01-24")
+                elif tool["type"] == "code_execution_20250522":
+                    _add_beta(request_header, "code-execution-2025-05-22")
+            elif isinstance(tool, MCPServer):
+                _add_beta(request_header, "mcp-client-2025-04-04")
+                mcp_servers.append(tool.for_anthropic())
 
         # Add cache control to last tool if tools_only caching is specified
         if cache_pattern == "tools_only" and tool_definitions:
             tool_definitions[-1]["cache_control"] = {"type": "ephemeral"}
 
         request_json["tools"] = tool_definitions
+        if len(mcp_servers) > 0:
+            request_json["mcp_servers"] = mcp_servers
 
     return request_json, request_header
 
 
 class AnthropicRequest(APIRequestBase):
-    def __init__(
-        self,
-        task_id: int,
-        # should always be 'role', 'content' keys.
-        # internal logic should handle translating to specific API format
-        model_name: str,  # must correspond to registry
-        prompt: Conversation,
-        attempts_left: int,
-        status_tracker: StatusTracker,
-        results_arr: list,
-        request_timeout: int = 30,
-        sampling_params: SamplingParams = SamplingParams(),
-        callback: Callable | None = None,
-        # for retries
-        all_model_names: list[str] | None = None,
-        all_sampling_params: list[SamplingParams] | None = None,
-        tools: list | None = None,
-        cache: CachePattern | None = None,
-        # Computer Use support
-        computer_use: bool = False,
-        display_width: int = 1024,
-        display_height: int = 768,
-    ):
-        super().__init__(
-            task_id=task_id,
-            model_name=model_name,
-            prompt=prompt,
-            attempts_left=attempts_left,
-            status_tracker=status_tracker,
-            results_arr=results_arr,
-            request_timeout=request_timeout,
-            sampling_params=sampling_params,
-            callback=callback,
-            all_model_names=all_model_names,
-            all_sampling_params=all_sampling_params,
-            tools=tools,
-            cache=cache,
-        )
-        self.computer_use = computer_use
-        self.display_width = display_width
-        self.display_height = display_height
-        self.model = APIModel.from_registry(model_name)
+    def __init__(self, context: RequestContext):
+        super().__init__(context=context)
+
+        self.model = APIModel.from_registry(self.context.model_name)
         self.url = f"{self.model.api_base}/messages"
 
         # Lock images as bytes if caching is enabled
-        if cache is not None:
-            prompt.lock_images_as_bytes()
+        if self.context.cache is not None:
+            self.context.prompt.lock_images_as_bytes()
 
         self.request_json, self.request_header = _build_anthropic_request(
             self.model,
-            prompt,
-            tools,
-            sampling_params,
-            cache,
-            computer_use,
-            display_width,
-            display_height,
+            self.context.prompt,
+            self.context.tools,
+            self.context.sampling_params,
+            self.context.cache,
         )
 
     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
+        data = None
         is_error = False
         error_message = None
         thinking = None
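With this change, the tools argument may mix three kinds of entries: lm_deluge Tool objects (serialized via dump_for("anthropic")), raw Anthropic tool dicts (which also switch on the matching beta header), and MCPServer instances (serialized via for_anthropic() and collected into mcp_servers). A sketch of such a call, assuming model, prompt, sampling_params, my_tool, and my_server are already in scope, and that the dict follows Anthropic's documented computer-use tool shape:

    tools = [
        my_tool,  # a Tool; dumped with my_tool.dump_for("anthropic")
        {         # a raw tool dict; triggers the computer-use-2025-01-24 beta
            "type": "computer_20250124",
            "name": "computer",
            "display_width_px": 1024,
            "display_height_px": 768,
        },
        my_server,  # an MCPServer; triggers mcp-client-2025-04-04 and lands in mcp_servers
    ]
    request_json, request_header = _build_anthropic_request(
        model, prompt, tools, sampling_params
    )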
@@ -160,6 +132,7 @@ class AnthropicRequest(APIRequestBase):
         status_code = http_response.status
         mimetype = http_response.headers.get("Content-Type", None)
         rate_limits = {}
+        assert self.context.status_tracker
         for header in [
             "anthropic-ratelimit-requests-limit",
             "anthropic-ratelimit-requests-remaining",
@@ -215,20 +188,21 @@ class AnthropicRequest(APIRequestBase):
                 or "overloaded" in error_message.lower()
             ):
                 error_message += " (Rate limit error, triggering cooldown.)"
-                self.status_tracker.rate_limit_exceeded()
+                self.context.status_tracker.rate_limit_exceeded()
             if "context length" in error_message:
                 error_message += " (Context length exceeded, set retries to 0.)"
-                self.attempts_left = 0
+                self.context.attempts_left = 0
 
         return APIResponse(
-            id=self.task_id,
+            id=self.context.task_id,
             status_code=status_code,
             is_error=is_error,
             error_message=error_message,
-            prompt=self.prompt,
+            prompt=self.context.prompt,
             content=content,
             thinking=thinking,
-            model_internal=self.model_name,
-            sampling_params=self.sampling_params,
+            model_internal=self.context.model_name,
+            sampling_params=self.context.sampling_params,
             usage=usage,
+            raw_response=data,
         )
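The RequestContext class itself lives in lm_deluge/request_context.py and is not part of this diff. From the fields the code above and below reads off it, its shape is roughly the following; this is a sketch inferred from usage, not the actual definition, and the defaults are guesses:

    from dataclasses import dataclass
    from typing import Any, Callable

    @dataclass
    class RequestContext:  # hypothetical reconstruction
        task_id: int
        model_name: str               # must correspond to the model registry
        prompt: Any                   # a Conversation
        sampling_params: Any          # a SamplingParams
        attempts_left: int = 5
        request_timeout: int = 30
        status_tracker: Any | None = None  # a StatusTracker; optional per the asserts
        callback: Callable | None = None
        tools: list | None = None
        cache: Any | None = None           # a CachePattern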
@@ -1,18 +1,12 @@
 import asyncio
-import random
 import traceback
 from abc import ABC, abstractmethod
-from typing import Callable
 
 import aiohttp
 from aiohttp import ClientResponse
 
-from lm_deluge.prompt import CachePattern, Conversation
-
-from ..config import SamplingParams
 from ..errors import raise_if_modal_exception
-from ..models import APIModel
-from ..tracker import StatusTracker
+from ..request_context import RequestContext
 from .response import APIResponse
 
 
@@ -28,40 +22,13 @@ class APIRequestBase(ABC):
 
     def __init__(
         self,
-        task_id: int,
-        # should always be 'role', 'content' keys.
-        # internal logic should handle translating to specific API format
-        model_name: str,  # must correspond to registry
-        prompt: Conversation,
-        attempts_left: int,
-        status_tracker: StatusTracker,
-        # needed in order to retry with a different model and not throw the output away
-        results_arr: list["APIRequestBase"],
-        request_timeout: int = 30,
-        sampling_params: SamplingParams = SamplingParams(),
-        callback: Callable | None = None,
-        all_model_names: list[str] | None = None,
-        all_sampling_params: list[SamplingParams] | None = None,
-        tools: list | None = None,
-        cache: CachePattern | None = None,
+        context: RequestContext,
     ):
-        if all_model_names is None:
-            raise ValueError("all_model_names must be provided.")
-        self.task_id = task_id
-        self.model_name = model_name
+        # If context is provided, use it; otherwise construct one from individual parameters
+        self.context = context
+
+        # Everything is now accessed through self.context - no copying!
         self.system_prompt = None
-        self.prompt = prompt
-        self.attempts_left = attempts_left
-        self.status_tracker = status_tracker
-        self.request_timeout = request_timeout
-        self.sampling_params = sampling_params
-        self.callback = callback
-        self.num_tokens = prompt.count_tokens(sampling_params.max_new_tokens)
-        self.results_arr = results_arr
-        self.all_model_names = all_model_names
-        self.all_sampling_params = all_sampling_params
-        self.tools = tools
-        self.cache: CachePattern | None = cache
         self.result = []  # list of APIResponse objects from each attempt
 
         # these should be set in the __init__ of the subclass
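Subclasses now take the context as their only constructor argument, so building and sending a request reduces to assembling one object. Hypothetical usage inside a coroutine, with a placeholder model name and the sketched RequestContext from above; execute_once appears in the next hunk:

    context = RequestContext(
        task_id=0,
        model_name="some-registry-model",  # placeholder for a registry entry
        prompt=conversation,
        sampling_params=SamplingParams(),
    )
    request = AnthropicRequest(context)
    response = await request.execute_once()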
@@ -71,101 +38,25 @@ class APIRequestBase(ABC):
         self.region = None
 
     def increment_pbar(self):
-        self.status_tracker.increment_pbar()
+        if self.context.status_tracker:
+            self.context.status_tracker.increment_pbar()
 
     def call_callback(self):
-        if self.callback is not None:
+        if self.context.callback is not None:
             # the APIResponse in self.result includes all the information
-            self.callback(self.result[-1], self.status_tracker)
+            self.context.callback(self.result[-1], self.context.status_tracker)
 
     def handle_success(self, data):
         self.call_callback()
-        self.status_tracker.task_succeeded(self.task_id)
-
-    def handle_error(self, create_new_request=False, give_up_if_no_other_models=False):
-        """
-        If create_new_request is True, will create a new API request (so that it
-        has a chance of being sent to a different model). If false, will retry
-        the same request.
-        """
-        last_result: APIResponse = self.result[-1]
-        error_to_print = f"Error task {self.task_id}. "
-        error_to_print += (
-            f"Model: {last_result.model_internal} Code: {last_result.status_code}, "
-        )
-        if self.region is not None:
-            error_to_print += f"Region: {self.region}, "
-        error_to_print += f"Message: {last_result.error_message}."
-        print(error_to_print)
-        if self.attempts_left > 0:
-            self.attempts_left -= 1
-            if not create_new_request:
-                assert self.status_tracker.retry_queue
-                self.status_tracker.retry_queue.put_nowait(self)
-                return
-            else:
-                # make sure we have another model to send it to besides the current one
-                if self.all_model_names is None or len(self.all_model_names) < 2:
-                    if give_up_if_no_other_models:
-                        print(
-                            f"No other models to try for task {self.task_id}. Giving up."
-                        )
-                        self.status_tracker.task_failed(self.task_id)
-                    else:
-                        print(
-                            f"No other models to try for task {self.task_id}. Retrying with same model."
-                        )
-                        assert self.status_tracker.retry_queue
-                        self.status_tracker.retry_queue.put_nowait(self)
-                else:
-                    # two things to change: model_name and sampling_params
-                    new_model_name = self.model_name
-                    new_model_idx = 0
-                    while new_model_name == self.model_name:
-                        new_model_idx = random.randint(0, len(self.all_model_names) - 1)
-                        new_model_name = self.all_model_names[new_model_idx]
-
-                    if isinstance(self.all_sampling_params, list):
-                        new_sampling_params = self.all_sampling_params[new_model_idx]
-                    elif isinstance(self.all_sampling_params, SamplingParams):
-                        new_sampling_params = self.all_sampling_params
-                    elif self.all_sampling_params is None:
-                        new_sampling_params = self.sampling_params
-                    else:
-                        new_sampling_params = self.sampling_params
+        if self.context.status_tracker:
+            self.context.status_tracker.task_succeeded(self.context.task_id)
 
-                    print("Creating new request with model", new_model_name)
-                    new_request = create_api_request(
-                        task_id=self.task_id,
-                        model_name=new_model_name,
-                        prompt=self.prompt,
-                        attempts_left=self.attempts_left,
-                        status_tracker=self.status_tracker,
-                        results_arr=self.results_arr,
-                        request_timeout=self.request_timeout,
-                        sampling_params=new_sampling_params,
-                        callback=self.callback,
-                        all_model_names=self.all_model_names,
-                        all_sampling_params=self.all_sampling_params,
-                        tools=self.tools,
-                        cache=self.cache,
-                        computer_use=getattr(self, "computer_use", False),
-                        display_width=getattr(self, "display_width", 1024),
-                        display_height=getattr(self, "display_height", 768),
-                    )
-                    # PROBLEM: new request is never put into results array, so we can't get the result.
-                    assert self.status_tracker.retry_queue
-                    self.status_tracker.retry_queue.put_nowait(self)
-                    # SOLUTION: just need to make sure it's deduplicated by task_id later.
-                    self.results_arr.append(new_request)
-        else:
-            print(f"Task {self.task_id} out of tries.")
-            self.status_tracker.task_failed(self.task_id)
-
-    async def call_api(self):
+    async def execute_once(self) -> APIResponse:
+        """Send the HTTP request once and return the parsed APIResponse."""
+        assert self.context.status_tracker
         try:
-            self.status_tracker.total_requests += 1
-            timeout = aiohttp.ClientTimeout(total=self.request_timeout)
+            self.context.status_tracker.total_requests += 1
+            timeout = aiohttp.ClientTimeout(total=self.context.request_timeout)
             async with aiohttp.ClientSession(timeout=timeout) as session:
                 assert self.url is not None, "URL is not set"
                 async with session.post(
@@ -174,133 +65,56 @@ class APIRequestBase(ABC):
                     json=self.request_json,
                 ) as http_response:
                     response: APIResponse = await self.handle_response(http_response)
-
-                    self.result.append(response)
-                    if response.is_error:
-                        self.handle_error(
-                            create_new_request=response.retry_with_different_model or False,
-                            give_up_if_no_other_models=response.give_up_if_no_other_models
-                            or False,
-                        )
-                    else:
-                        self.handle_success(response)
+                    return response
 
         except asyncio.TimeoutError:
-            self.result.append(
-                APIResponse(
-                    id=self.task_id,
-                    model_internal=self.model_name,
-                    prompt=self.prompt,
-                    sampling_params=self.sampling_params,
-                    status_code=None,
-                    is_error=True,
-                    error_message="Request timed out (terminated by client).",
-                    content=None,
-                    usage=None,
-                )
+            return APIResponse(
+                id=self.context.task_id,
+                model_internal=self.context.model_name,
+                prompt=self.context.prompt,
+                sampling_params=self.context.sampling_params,
+                status_code=None,
+                is_error=True,
+                error_message="Request timed out (terminated by client).",
+                content=None,
+                usage=None,
             )
-            self.handle_error(create_new_request=False)
 
         except Exception as e:
             raise_if_modal_exception(e)
             tb = traceback.format_exc()
             print(tb)
-            self.result.append(
-                APIResponse(
-                    id=self.task_id,
-                    model_internal=self.model_name,
-                    prompt=self.prompt,
-                    sampling_params=self.sampling_params,
-                    status_code=None,
-                    is_error=True,
-                    error_message=f"Unexpected {type(e).__name__}: {str(e) or 'No message.'}",
-                    content=None,
-                    usage=None,
-                )
+            return APIResponse(
+                id=self.context.task_id,
+                model_internal=self.context.model_name,
+                prompt=self.context.prompt,
+                sampling_params=self.context.sampling_params,
+                status_code=None,
+                is_error=True,
+                error_message=f"Unexpected {type(e).__name__}: {str(e) or 'No message.'}",
+                content=None,
+                usage=None,
             )
-            # maybe consider making True?
-            self.handle_error(create_new_request=False)
 
     @abstractmethod
     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
         raise NotImplementedError
 
 
-def create_api_request(
-    task_id: int,
-    model_name: str,
-    prompt: Conversation,
-    attempts_left: int,
-    status_tracker: StatusTracker,
-    results_arr: list["APIRequestBase"],
-    request_timeout: int = 30,
-    sampling_params: SamplingParams = SamplingParams(),
-    callback: Callable | None = None,
-    all_model_names: list[str] | None = None,
-    all_sampling_params: list[SamplingParams] | None = None,
-    tools: list | None = None,
-    cache: CachePattern | None = None,
-    computer_use: bool = False,
-    display_width: int = 1024,
-    display_height: int = 768,
-    use_responses_api: bool = False,
-) -> APIRequestBase:
-    from .common import CLASSES  # circular import so made it lazy, does this work?
-
-    model_obj = APIModel.from_registry(model_name)
-
-    # Choose API spec based on use_responses_api flag and model support
-    api_spec = model_obj.api_spec
-    if use_responses_api and model_obj.supports_responses and api_spec == "openai":
-        api_spec = "openai-responses"
-
-    request_class = CLASSES.get(api_spec, None)
-    if request_class is None:
-        raise ValueError(f"Unsupported API spec: {api_spec}")
-    kwargs = {}
-    # Add computer_use to kwargs if the request class supports it
-    model_obj = APIModel.from_registry(model_name)
-    if computer_use and api_spec in ["anthropic", "bedrock", "openai-responses"]:
-        kwargs.update(
-            {
-                "computer_use": computer_use,
-                "display_width": display_width,
-                "display_height": display_height,
-            }
-        )
-
-    return request_class(
-        task_id=task_id,
-        model_name=model_name,
-        prompt=prompt,
-        attempts_left=attempts_left,
-        status_tracker=status_tracker,
-        results_arr=results_arr,
-        request_timeout=request_timeout,
-        sampling_params=sampling_params,
-        callback=callback,
-        all_model_names=all_model_names,
-        all_sampling_params=all_sampling_params,
-        tools=tools,
-        cache=cache,
-        **kwargs,
-    )
-
-
 def deduplicate_responses(results: list[APIRequestBase]) -> list[APIResponse]:
     deduplicated = {}
     for request in results:
-        if request.task_id not in deduplicated:
-            deduplicated[request.task_id] = request.result[-1]
+        if request.context.task_id not in deduplicated:
+            deduplicated[request.context.task_id] = request.result[-1]
         else:
-            current_response: APIResponse = deduplicated[request.task_id]
+            current_response: APIResponse = deduplicated[request.context.task_id]
             # only replace if the current request has no completion and the new one does
             if (
                 request.result[-1].completion is not None
                 and current_response.completion is None
             ):
-                deduplicated[request.task_id] = request.result[-1]
+                deduplicated[request.context.task_id] = request.result[-1]
 
-    output = [deduplicated[request.task_id] for request in results]
+    output = [deduplicated[request.context.task_id] for request in results]
 
     return output
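deduplicate_responses still collapses retries that left several requests sharing one task_id, preferring whichever attempt produced a completion. A sketch, assuming req_a and req_b are finished requests for the same task:

    # req_a errored (its result[-1].completion is None); req_b completed.
    responses = deduplicate_responses([req_a, req_b])
    # Both positions map to req_b's final APIResponse: the completed attempt
    # wins, and output order follows the input list.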