lm-deluge 0.0.15__py3-none-any.whl → 0.0.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -1,8 +1,8 @@
 import asyncio
 import json
 import os
+
 from aiohttp import ClientResponse
-from typing import Callable

 try:
     from requests_aws4auth import AWS4Auth
@@ -12,186 +12,178 @@ except ImportError:
     )

 from lm_deluge.prompt import (
+    CachePattern,
     Conversation,
     Message,
     Text,
-    ToolCall,
     Thinking,
-    CachePattern,
+    ToolCall,
 )
+from lm_deluge.request_context import RequestContext
+from lm_deluge.tool import MCPServer, Tool
 from lm_deluge.usage import Usage
-from .base import APIRequestBase, APIResponse

-from ..tracker import StatusTracker
 from ..config import SamplingParams
 from ..models import APIModel
+from .base import APIRequestBase, APIResponse


-class BedrockRequest(APIRequestBase):
-    def __init__(
-        self,
-        task_id: int,
-        model_name: str,
-        prompt: Conversation,
-        attempts_left: int,
-        status_tracker: StatusTracker,
-        results_arr: list,
-        request_timeout: int = 30,
-        sampling_params: SamplingParams = SamplingParams(),
-        callback: Callable | None = None,
-        all_model_names: list[str] | None = None,
-        all_sampling_params: list[SamplingParams] | None = None,
-        tools: list | None = None,
-        cache: CachePattern | None = None,
-        # Computer Use support
-        computer_use: bool = False,
-        display_width: int = 1024,
-        display_height: int = 768,
-    ):
-        super().__init__(
-            task_id=task_id,
-            model_name=model_name,
-            prompt=prompt,
-            attempts_left=attempts_left,
-            status_tracker=status_tracker,
-            results_arr=results_arr,
-            request_timeout=request_timeout,
-            sampling_params=sampling_params,
-            callback=callback,
-            all_model_names=all_model_names,
-            all_sampling_params=all_sampling_params,
-            tools=tools,
-            cache=cache,
+# according to bedrock docs the header is "anthropic_beta" vs. "anthropic-beta"
+# for anthropic. i don't know if this is a typo or the worst ever UX
+def _add_beta(headers: dict, beta: str):
+    if "anthropic_beta" in headers and headers["anthropic_beta"]:
+        if beta not in headers["anthropic_beta"]:
+            headers["anthropic_beta"] += f",{beta}"
+    else:
+        headers["anthropic_beta"] = beta
+
+
+def _build_anthropic_bedrock_request(
+    model: APIModel,
+    prompt: Conversation,
+    tools: list[Tool | dict | MCPServer] | None,
+    sampling_params: SamplingParams,
+    cache_pattern: CachePattern | None = None,
+):
+    system_message, messages = prompt.to_anthropic(cache_pattern=cache_pattern)
+
+    # handle AWS auth
+    access_key = os.getenv("AWS_ACCESS_KEY_ID")
+    secret_key = os.getenv("AWS_SECRET_ACCESS_KEY")
+    session_token = os.getenv("AWS_SESSION_TOKEN")
+
+    if not access_key or not secret_key:
+        raise ValueError(
+            "AWS credentials not found. Please set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables."
         )

-        self.computer_use = computer_use
-        self.display_width = display_width
-        self.display_height = display_height
+    # Determine region - use us-west-2 for cross-region inference models
+    if model.name.startswith("us.anthropic."):
+        # Cross-region inference profiles should use us-west-2
+        region = "us-west-2"
+    else:
+        raise ValueError("only cross-region inference for bedrock")
+        # # Direct model IDs can use default region
+        # region = getattr(model, "region", "us-east-1")
+        # if hasattr(model, "regions") and model.regions:
+        #     if isinstance(model.regions, list):
+        #         region = model.regions[0]
+        #     elif isinstance(model.regions, dict):
+        #         region = list(model.regions.keys())[0]
+
+    # Construct the endpoint URL
+    service = "bedrock"  # Service name for signing is 'bedrock' even though endpoint is bedrock-runtime
+    url = f"https://bedrock-runtime.{region}.amazonaws.com/model/{model.name}/invoke"
+
+    # Prepare headers
+    auth = AWS4Auth(
+        access_key,
+        secret_key,
+        region,
+        service,
+        session_token=session_token,
+    )

-        # Lock images as bytes if caching is enabled
-        if cache is not None:
-            prompt.lock_images_as_bytes()
+    # Setup basic headers (AWS4Auth will add the Authorization header)
+    request_header = {
+        "Content-Type": "application/json",
+    }
+
+    # Prepare request body in Anthropic's bedrock format
+    request_json = {
+        "anthropic_version": "bedrock-2023-05-31",
+        "max_tokens": sampling_params.max_new_tokens,
+        "temperature": sampling_params.temperature,
+        "top_p": sampling_params.top_p,
+        "messages": messages,
+    }
+
+    if system_message is not None:
+        request_json["system"] = system_message
+
+    if tools:
+        mcp_servers = []
+        tool_definitions = []
+        for tool in tools:
+            if isinstance(tool, Tool):
+                tool_definitions.append(tool.dump_for("anthropic"))
+            elif isinstance(tool, dict):
+                tool_definitions.append(tool)
+                # add betas if needed
+                if tool["type"] in [
+                    "computer_20241022",
+                    "text_editor_20241022",
+                    "bash_20241022",
+                ]:
+                    _add_beta(request_header, "computer-use-2024-10-22")
+                elif tool["type"] == "computer_20250124":
+                    _add_beta(request_header, "computer-use-2025-01-24")
+                elif tool["type"] == "code_execution_20250522":
+                    _add_beta(request_header, "code-execution-2025-05-22")
+            elif isinstance(tool, MCPServer):
+                raise ValueError("bedrock doesn't support MCP connector right now")
+                # _add_beta(request_header, "mcp-client-2025-04-04")
+                # mcp_servers.append(tool.for_anthropic())
+
+        # Add cache control to last tool if tools_only caching is specified
+        if cache_pattern == "tools_only" and tool_definitions:
+            tool_definitions[-1]["cache_control"] = {"type": "ephemeral"}
+
+        request_json["tools"] = tool_definitions
+        if len(mcp_servers) > 0:
+            request_json["mcp_servers"] = mcp_servers
+
+    return request_json, request_header, auth, url

-        self.model = APIModel.from_registry(model_name)

-        # Get AWS credentials from environment
-        self.access_key = os.getenv("AWS_ACCESS_KEY_ID")
-        self.secret_key = os.getenv("AWS_SECRET_ACCESS_KEY")
-        self.session_token = os.getenv("AWS_SESSION_TOKEN")
+class BedrockRequest(APIRequestBase):
+    def __init__(self, context: RequestContext):
+        super().__init__(context=context)
+
+        self.model = APIModel.from_registry(self.context.model_name)
+        self.url = f"{self.model.api_base}/messages"

-        if not self.access_key or not self.secret_key:
-            raise ValueError(
-                "AWS credentials not found. Please set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables."
+        # Lock images as bytes if caching is enabled
+        if self.context.cache is not None:
+            self.context.prompt.lock_images_as_bytes()
+
+        self.request_json, self.request_header, self.auth, self.url = (
+            _build_anthropic_bedrock_request(
+                self.model,
+                context.prompt,
+                context.tools,
+                context.sampling_params,
+                context.cache,
             )
+        )

-        # Determine region - use us-west-2 for cross-region inference models
-        if self.model.name.startswith("us.anthropic."):
-            # Cross-region inference profiles should use us-west-2
-            self.region = "us-west-2"
-        else:
-            # Direct model IDs can use default region
-            self.region = getattr(self.model, "region", "us-east-1")
-            if hasattr(self.model, "regions") and self.model.regions:
-                if isinstance(self.model.regions, list):
-                    self.region = self.model.regions[0]
-                elif isinstance(self.model.regions, dict):
-                    self.region = list(self.model.regions.keys())[0]
-
-        # Construct the endpoint URL
-        self.service = "bedrock"  # Service name for signing is 'bedrock' even though endpoint is bedrock-runtime
-        self.url = f"https://bedrock-runtime.{self.region}.amazonaws.com/model/{self.model.name}/invoke"
-
-        # Convert prompt to Anthropic format for bedrock
-        self.system_message, messages = prompt.to_anthropic(cache_pattern=cache)
-
-        # Prepare request body in Anthropic's bedrock format
-        self.request_json = {
-            "anthropic_version": "bedrock-2023-05-31",
-            "max_tokens": sampling_params.max_new_tokens,
-            "temperature": sampling_params.temperature,
-            "top_p": sampling_params.top_p,
-            "messages": messages,
-        }
-
-        if self.system_message is not None:
-            self.request_json["system"] = self.system_message
-
-        if tools or self.computer_use:
-            tool_definitions = []
-
-            # Add Computer Use tools at the beginning if enabled
-            if self.computer_use:
-                from ..computer_use.anthropic_tools import get_anthropic_cu_tools
-
-                cu_tools = get_anthropic_cu_tools(
-                    model=self.model.id,
-                    display_width=self.display_width,
-                    display_height=self.display_height,
-                )
-                tool_definitions.extend(cu_tools)
+    async def execute_once(self) -> APIResponse:
+        """Override execute_once to handle AWS4Auth signing."""
+        import aiohttp

-                # Add computer use display parameters to the request
-                self.request_json["computer_use_display_width_px"] = self.display_width
-                self.request_json["computer_use_display_height_px"] = (
-                    self.display_height
-                )
+        assert self.context.status_tracker

-            # Add user-provided tools
-            if tools:
-                tool_definitions.extend([tool.dump_for("anthropic") for tool in tools])
+        self.context.status_tracker.total_requests += 1
+        timeout = aiohttp.ClientTimeout(total=self.context.request_timeout)

-            # Add cache control to last tool if tools_only caching is specified
-            if cache == "tools_only" and tool_definitions:
-                tool_definitions[-1]["cache_control"] = {"type": "ephemeral"}
+        # Prepare the request data
+        payload = json.dumps(self.request_json, separators=(",", ":")).encode("utf-8")

-            self.request_json["tools"] = tool_definitions
+        # Create a fake requests.PreparedRequest object for AWS4Auth to sign
+        import requests

-        # Setup AWS4Auth for signing
-        self.auth = AWS4Auth(
-            self.access_key,
-            self.secret_key,
-            self.region,
-            self.service,
-            session_token=self.session_token,
+        fake_request = requests.Request(
+            method="POST",
+            url=self.url,
+            data=payload,
+            headers=self.request_header.copy(),
         )

-        # Setup basic headers (AWS4Auth will add the Authorization header)
-        self.request_header = {
-            "Content-Type": "application/json",
-        }
+        prepared_request = fake_request.prepare()
+        signed_request = self.auth(prepared_request)
+        signed_headers = dict(signed_request.headers)

-    async def call_api(self):
-        """Override call_api to handle AWS4Auth signing."""
         try:
-            import aiohttp
-
-            self.status_tracker.total_requests += 1
-            timeout = aiohttp.ClientTimeout(total=self.request_timeout)
-
-            # Prepare the request data
-            payload = json.dumps(self.request_json, separators=(",", ":")).encode(
-                "utf-8"
-            )
-
-            # Create a fake requests.PreparedRequest object for AWS4Auth to sign
-            import requests
-
-            fake_request = requests.Request(
-                method="POST",
-                url=self.url,
-                data=payload,
-                headers=self.request_header.copy(),
-            )
-
-            # Prepare the request so AWS4Auth can sign it properly
-            prepared_request = fake_request.prepare()
-
-            # Let AWS4Auth sign the prepared request
-            signed_request = self.auth(prepared_request)
-
-            # Extract the signed headers
-            signed_headers = dict(signed_request.headers)
-
             async with aiohttp.ClientSession(timeout=timeout) as session:
                 async with session.post(
                     url=self.url,
@@ -199,51 +191,36 @@ class BedrockRequest(APIRequestBase):
                     data=payload,
                 ) as http_response:
                     response: APIResponse = await self.handle_response(http_response)
-
-                    self.result.append(response)
-                    if response.is_error:
-                        self.handle_error(
-                            create_new_request=response.retry_with_different_model or False,
-                            give_up_if_no_other_models=response.give_up_if_no_other_models
-                            or False,
-                        )
-                    else:
-                        self.handle_success(response)
+                    return response

         except asyncio.TimeoutError:
-            self.result.append(
-                APIResponse(
-                    id=self.task_id,
-                    model_internal=self.model_name,
-                    prompt=self.prompt,
-                    sampling_params=self.sampling_params,
-                    status_code=None,
-                    is_error=True,
-                    error_message="Request timed out (terminated by client).",
-                    content=None,
-                    usage=None,
-                )
+            return APIResponse(
+                id=self.context.task_id,
+                model_internal=self.context.model_name,
+                prompt=self.context.prompt,
+                sampling_params=self.context.sampling_params,
+                status_code=None,
+                is_error=True,
+                error_message="Request timed out (terminated by client).",
+                content=None,
+                usage=None,
             )
-            self.handle_error(create_new_request=False)

         except Exception as e:
             from ..errors import raise_if_modal_exception

             raise_if_modal_exception(e)
-            self.result.append(
-                APIResponse(
-                    id=self.task_id,
-                    model_internal=self.model_name,
-                    prompt=self.prompt,
-                    sampling_params=self.sampling_params,
-                    status_code=None,
-                    is_error=True,
-                    error_message=f"Unexpected {type(e).__name__}: {str(e) or 'No message.'}",
-                    content=None,
-                    usage=None,
-                )
+            return APIResponse(
+                id=self.context.task_id,
+                model_internal=self.context.model_name,
+                prompt=self.context.prompt,
+                sampling_params=self.context.sampling_params,
+                status_code=None,
+                is_error=True,
+                error_message=f"Unexpected {type(e).__name__}: {str(e) or 'No message.'}",
+                content=None,
+                usage=None,
             )
-            self.handle_error(create_new_request=False)

     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
         is_error = False
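
A note on the execute_once flow above: AWS4Auth is a requests-style auth hook, so it can only sign a requests.PreparedRequest, not an aiohttp request. The new code therefore builds a throwaway requests.Request, lets AWS4Auth sign it, and copies the signed headers onto the aiohttp call. A condensed, self-contained sketch of that pattern (names here are illustrative, not part of the package):

    import json

    import aiohttp
    import requests
    from requests_aws4auth import AWS4Auth

    async def signed_post(url: str, body: dict, auth: AWS4Auth) -> dict:
        # SigV4 signs a hash of the body, so the exact bytes that get signed
        # must also be the bytes that get sent.
        payload = json.dumps(body, separators=(",", ":")).encode("utf-8")
        # Throwaway request, prepared only so AWS4Auth can compute the
        # Authorization and x-amz-* headers.
        prepared = requests.Request(
            method="POST",
            url=url,
            data=payload,
            headers={"Content-Type": "application/json"},
        ).prepare()
        signed_headers = dict(auth(prepared).headers)
        async with aiohttp.ClientSession() as session:
            async with session.post(url, headers=signed_headers, data=payload) as resp:
                return await resp.json()
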
@@ -253,6 +230,7 @@ class BedrockRequest(APIRequestBase):
         usage = None
         status_code = http_response.status
         mimetype = http_response.headers.get("Content-Type", None)
+        assert self.context.status_tracker

         if status_code >= 200 and status_code < 300:
             try:
@@ -300,21 +278,21 @@ class BedrockRequest(APIRequestBase):
                 or status_code == 429
             ):
                 error_message += " (Rate limit error, triggering cooldown.)"
-                self.status_tracker.rate_limit_exceeded()
+                self.context.status_tracker.rate_limit_exceeded()
             if "context length" in error_message or "too long" in error_message:
                 error_message += " (Context length exceeded, set retries to 0.)"
-                self.attempts_left = 0
+                self.context.attempts_left = 0

         return APIResponse(
-            id=self.task_id,
+            id=self.context.task_id,
             status_code=status_code,
             is_error=is_error,
             error_message=error_message,
-            prompt=self.prompt,
+            prompt=self.context.prompt,
             content=content,
             thinking=thinking,
-            model_internal=self.model_name,
+            model_internal=self.context.model_name,
             region=self.region,
-            sampling_params=self.sampling_params,
+            sampling_params=self.context.sampling_params,
             usage=usage,
         )
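
The _add_beta helper introduced above handles the quirk called out in its comment: Bedrock expects beta flags under an "anthropic_beta" key (underscore) rather than the "anthropic-beta" header the Anthropic API uses, as a comma-separated list without duplicates. A small illustration of its behavior, with the helper copied from the hunk above:

    def _add_beta(headers: dict, beta: str):
        if "anthropic_beta" in headers and headers["anthropic_beta"]:
            if beta not in headers["anthropic_beta"]:
                headers["anthropic_beta"] += f",{beta}"
        else:
            headers["anthropic_beta"] = beta

    headers: dict = {"Content-Type": "application/json"}
    _add_beta(headers, "computer-use-2025-01-24")
    _add_beta(headers, "code-execution-2025-05-22")  # appended with a comma
    _add_beta(headers, "computer-use-2025-01-24")    # already present: no-op
    assert headers["anthropic_beta"] == "computer-use-2025-01-24,code-execution-2025-05-22"

The remaining hunks apply the same constructor refactor to the Gemini request module.
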
@@ -1,16 +1,15 @@
 import json
 import os
 import warnings
-from typing import Callable

 from aiohttp import ClientResponse

+from lm_deluge.request_context import RequestContext
 from lm_deluge.tool import Tool

 from ..config import SamplingParams
 from ..models import APIModel
-from ..prompt import CachePattern, Conversation, Message, Text, Thinking, ToolCall
-from ..tracker import StatusTracker
+from ..prompt import Conversation, Message, Text, Thinking, ToolCall
 from ..usage import Usage
 from .base import APIRequestBase, APIResponse

@@ -66,45 +65,16 @@ def _build_gemini_request(


 class GeminiRequest(APIRequestBase):
-    def __init__(
-        self,
-        task_id: int,
-        model_name: str,  # must correspond to registry
-        prompt: Conversation,
-        attempts_left: int,
-        status_tracker: StatusTracker,
-        results_arr: list,
-        request_timeout: int = 30,
-        sampling_params: SamplingParams = SamplingParams(),
-        callback: Callable | None = None,
-        all_model_names: list[str] | None = None,
-        all_sampling_params: list[SamplingParams] | None = None,
-        tools: list | None = None,
-        cache: CachePattern | None = None,
-    ):
-        super().__init__(
-            task_id=task_id,
-            model_name=model_name,
-            prompt=prompt,
-            attempts_left=attempts_left,
-            status_tracker=status_tracker,
-            results_arr=results_arr,
-            request_timeout=request_timeout,
-            sampling_params=sampling_params,
-            callback=callback,
-            all_model_names=all_model_names,
-            all_sampling_params=all_sampling_params,
-            tools=tools,
-            cache=cache,
-        )
+    def __init__(self, context: RequestContext):
+        super().__init__(context=context)

         # Warn if cache is specified for Gemini model
-        if cache is not None:
+        if self.context.cache is not None:
             warnings.warn(
-                f"Cache parameter '{cache}' is not supported for Gemini models, ignoring for {model_name}"
+                f"Cache parameter '{self.context.cache}' is not supported for Gemini models, ignoring for {self.context.model_name}"
             )

-        self.model = APIModel.from_registry(model_name)
+        self.model = APIModel.from_registry(self.context.model_name)
         # Gemini API endpoint format: https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent
         self.url = f"{self.model.api_base}/models/{self.model.name}:generateContent"
         self.request_header = {
@@ -120,7 +90,10 @@ class GeminiRequest(APIRequestBase):
             self.url += f"?key={api_key}"

         self.request_json = _build_gemini_request(
-            self.model, prompt, tools, sampling_params
+            self.model,
+            self.context.prompt,
+            self.context.tools,
+            self.context.sampling_params,
         )

     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
@@ -132,6 +105,7 @@ class GeminiRequest(APIRequestBase):
         status_code = http_response.status
         mimetype = http_response.headers.get("Content-Type", None)
         data = None
+        assert self.context.status_tracker

         if status_code >= 200 and status_code < 300:
             try:
@@ -199,24 +173,24 @@ class GeminiRequest(APIRequestBase):
         if is_error and error_message is not None:
             if "rate limit" in error_message.lower() or status_code == 429:
                 error_message += " (Rate limit error, triggering cooldown.)"
-                self.status_tracker.rate_limit_exceeded()
+                self.context.status_tracker.rate_limit_exceeded()
             if (
                 "context length" in error_message.lower()
                 or "token limit" in error_message.lower()
             ):
                 error_message += " (Context length exceeded, set retries to 0.)"
-                self.attempts_left = 0
+                self.context.attempts_left = 0

         return APIResponse(
-            id=self.task_id,
+            id=self.context.task_id,
             status_code=status_code,
             is_error=is_error,
             error_message=error_message,
-            prompt=self.prompt,
+            prompt=self.context.prompt,
             content=content,
             thinking=thinking,
-            model_internal=self.model_name,
-            sampling_params=self.sampling_params,
+            model_internal=self.context.model_name,
+            sampling_params=self.context.sampling_params,
             usage=usage,
             raw_response=data,
         )
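
Taken together, the two files replace their long keyword-argument constructors with a single RequestContext object. A hedged sketch of what construction now looks like (field names are inferred from the attribute accesses in this diff; the exact RequestContext signature and the Conversation helper used below are assumptions, not confirmed API):

    from lm_deluge.config import SamplingParams
    from lm_deluge.prompt import Conversation
    from lm_deluge.request_context import RequestContext

    # BedrockRequest / GeminiRequest are imported from their own modules,
    # whose paths are not shown in this diff.
    context = RequestContext(
        task_id=0,
        model_name="claude-3.5-sonnet",      # hypothetical registry key
        prompt=Conversation.user("Hello!"),  # assumed convenience constructor
        attempts_left=3,
        request_timeout=30,
        sampling_params=SamplingParams(),
    )

    # Both classes now take the single context argument; note that
    # execute_once() also expects context.status_tracker to be set
    # (see the asserts in the hunks above).
    request = BedrockRequest(context)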