lm-deluge 0.0.15__py3-none-any.whl → 0.0.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lm_deluge/api_requests/__init__.py +0 -2
- lm_deluge/api_requests/anthropic.py +58 -84
- lm_deluge/api_requests/base.py +43 -229
- lm_deluge/api_requests/bedrock.py +173 -195
- lm_deluge/api_requests/gemini.py +18 -44
- lm_deluge/api_requests/mistral.py +30 -60
- lm_deluge/api_requests/openai.py +147 -148
- lm_deluge/api_requests/response.py +2 -1
- lm_deluge/batches.py +1 -1
- lm_deluge/{computer_use/anthropic_tools.py → built_in_tools/anthropic.py} +56 -5
- lm_deluge/built_in_tools/openai.py +28 -0
- lm_deluge/client.py +221 -150
- lm_deluge/image.py +13 -8
- lm_deluge/llm_tools/extract.py +23 -4
- lm_deluge/llm_tools/ocr.py +1 -0
- lm_deluge/models.py +39 -2
- lm_deluge/prompt.py +43 -27
- lm_deluge/request_context.py +75 -0
- lm_deluge/tool.py +93 -15
- lm_deluge/tracker.py +1 -0
- {lm_deluge-0.0.15.dist-info → lm_deluge-0.0.16.dist-info}/METADATA +25 -1
- {lm_deluge-0.0.15.dist-info → lm_deluge-0.0.16.dist-info}/RECORD +25 -22
- {lm_deluge-0.0.15.dist-info → lm_deluge-0.0.16.dist-info}/WHEEL +0 -0
- {lm_deluge-0.0.15.dist-info → lm_deluge-0.0.16.dist-info}/licenses/LICENSE +0 -0
- {lm_deluge-0.0.15.dist-info → lm_deluge-0.0.16.dist-info}/top_level.txt +0 -0
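
Most of the churn in this release comes from one refactor: the per-provider request classes no longer take a long keyword-argument list and instead receive a single `RequestContext` object (new module `lm_deluge/request_context.py`, +75 lines). The sketch below is not the real class; it only lists the fields the new code visibly reads (`task_id`, `model_name`, `prompt`, `attempts_left`, `status_tracker`, `sampling_params`, `tools`, `cache`), so types and defaults are assumptions.

```python
# Hypothetical sketch only -- NOT the actual definition in
# lm_deluge/request_context.py. Field names come from the attributes the
# new request classes read; types and defaults are guesses.
from dataclasses import dataclass
from typing import Any


@dataclass
class RequestContextSketch:
    task_id: int                        # used as APIResponse.id
    model_name: str                     # key into the APIModel registry
    prompt: Any                         # Conversation, serialized per provider
    attempts_left: int = 3              # set to 0 on unrecoverable errors
    status_tracker: Any | None = None   # asserted non-None in handle_response
    sampling_params: Any | None = None  # temperature, top_p, max_new_tokens, ...
    tools: list[Any] | None = None      # Tool | dict | MCPServer entries
    cache: Any | None = None            # CachePattern, only honored by Anthropic
```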
lm_deluge/api_requests/mistral.py
CHANGED

@@ -1,78 +1,46 @@
-import warnings
-from aiohttp import ClientResponse
 import json
 import os
-
+import warnings
+
+from aiohttp import ClientResponse

-from .base import APIRequestBase, APIResponse
-from ..prompt import Conversation, Message, CachePattern
-from ..usage import Usage
-from ..tracker import StatusTracker
-from ..config import SamplingParams
 from ..models import APIModel
+from ..prompt import Message
+from ..request_context import RequestContext
+from ..usage import Usage
+from .base import APIRequestBase, APIResponse


 class MistralRequest(APIRequestBase):
-    def __init__(
-        self,
-        task_id: int,
-        # should always be 'role', 'content' keys.
-        # internal logic should handle translating to specific API format
-        model_name: str,  # must correspond to registry
-        prompt: Conversation,
-        attempts_left: int,
-        status_tracker: StatusTracker,
-        results_arr: list,
-        request_timeout: int = 30,
-        sampling_params: SamplingParams = SamplingParams(),
-        callback: Callable | None = None,
-        all_model_names: list[str] | None = None,
-        all_sampling_params: list[SamplingParams] | None = None,
-        tools: list | None = None,
-        cache: CachePattern | None = None,
-    ):
-        super().__init__(
-            task_id=task_id,
-            model_name=model_name,
-            prompt=prompt,
-            attempts_left=attempts_left,
-            status_tracker=status_tracker,
-            results_arr=results_arr,
-            request_timeout=request_timeout,
-            sampling_params=sampling_params,
-            callback=callback,
-            all_model_names=all_model_names,
-            all_sampling_params=all_sampling_params,
-            tools=tools,
-            cache=cache,
-        )
+    def __init__(self, context: RequestContext):
+        super().__init__(context=context)

         # Warn if cache is specified for non-Anthropic model
-        if cache is not None:
+        if self.context.cache is not None:
             warnings.warn(
-                f"Cache parameter '{cache}' is only supported for Anthropic models, ignoring for {model_name}"
+                f"Cache parameter '{self.context.cache}' is only supported for Anthropic models, ignoring for {self.context.model_name}"
             )
-        self.model = APIModel.from_registry(model_name)
+        self.model = APIModel.from_registry(self.context.model_name)
         self.url = f"{self.model.api_base}/chat/completions"
         self.request_header = {
             "Authorization": f"Bearer {os.getenv(self.model.api_key_env_var)}"
         }
         self.request_json = {
             "model": self.model.name,
-            "messages": prompt.to_mistral(),
-            "temperature": sampling_params.temperature,
-            "top_p": sampling_params.top_p,
-            "max_tokens": sampling_params.max_new_tokens,
+            "messages": self.context.prompt.to_mistral(),
+            "temperature": self.context.sampling_params.temperature,
+            "top_p": self.context.sampling_params.top_p,
+            "max_tokens": self.context.sampling_params.max_new_tokens,
         }
-        if sampling_params.reasoning_effort:
+        if self.context.sampling_params.reasoning_effort:
             warnings.warn(
-                f"Ignoring reasoning_effort param for non-reasoning model: {model_name}"
+                f"Ignoring reasoning_effort param for non-reasoning model: {self.context.model_name}"
             )
-        if sampling_params.logprobs:
+        if self.context.sampling_params.logprobs:
             warnings.warn(
-                f"Ignoring logprobs param for non-logprobs model: {model_name}"
+                f"Ignoring logprobs param for non-logprobs model: {self.context.model_name}"
             )
-        if sampling_params.json_mode and self.model.supports_json:
+        if self.context.sampling_params.json_mode and self.model.supports_json:
             self.request_json["response_format"] = {"type": "json_object"}

     async def handle_response(self, http_response: ClientResponse) -> APIResponse:

@@ -84,6 +52,8 @@ class MistralRequest(APIRequestBase):
         status_code = http_response.status
         mimetype = http_response.headers.get("Content-Type", None)
         data = None
+        assert self.context.status_tracker
+
         if status_code >= 200 and status_code < 300:
             try:
                 data = await http_response.json()

@@ -98,7 +68,7 @@ class MistralRequest(APIRequestBase):
                 completion = data["choices"][0]["message"]["content"]
                 usage = Usage.from_mistral_usage(data["usage"])
                 if (
-                    self.sampling_params.logprobs
+                    self.context.sampling_params.logprobs
                     and "logprobs" in data["choices"][0]
                 ):
                     logprobs = data["choices"][0]["logprobs"]["content"]

@@ -118,20 +88,20 @@ class MistralRequest(APIRequestBase):
         if is_error and error_message is not None:
             if "rate limit" in error_message.lower() or status_code == 429:
                 error_message += " (Rate limit error, triggering cooldown.)"
-                self.status_tracker.rate_limit_exceeded()
+                self.context.status_tracker.rate_limit_exceeded()
             if "context length" in error_message:
                 error_message += " (Context length exceeded, set retries to 0.)"
-                self.attempts_left = 0
+                self.context.attempts_left = 0

         return APIResponse(
-            id=self.task_id,
+            id=self.context.task_id,
             status_code=status_code,
             is_error=is_error,
             error_message=error_message,
-            prompt=self.prompt,
+            prompt=self.context.prompt,
             logprobs=logprobs,
             content=Message.ai(completion),
-            model_internal=self.model_name,
-            sampling_params=self.sampling_params,
+            model_internal=self.context.model_name,
+            sampling_params=self.context.sampling_params,
             usage=usage,
         )
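
For orientation, this is roughly how a provider request is constructed after the refactor. `MistralRequest(context)` is taken from the diff above; the `RequestContext` keyword names, the `Conversation.user(...)` helper, and the model key are assumptions, since none of them are defined in this diff.

```python
from lm_deluge.api_requests.mistral import MistralRequest
from lm_deluge.config import SamplingParams
from lm_deluge.prompt import Conversation
from lm_deluge.request_context import RequestContext

# Keyword names are guesses based on the attributes the new code reads
# (context.task_id, context.model_name, context.prompt, ...).
context = RequestContext(
    task_id=0,
    model_name="mistral-large",  # placeholder; must exist in the model registry
    prompt=Conversation.user("Summarize the attached report."),  # assumed helper
    attempts_left=3,
    sampling_params=SamplingParams(temperature=0.2, max_new_tokens=512),
)

request = MistralRequest(context)
print(request.url)  # "<api_base>/chat/completions", per the constructor above
```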
lm_deluge/api_requests/openai.py
CHANGED

@@ -1,17 +1,16 @@
 import json
 import os
 import warnings
-from typing import Callable

 import aiohttp
 from aiohttp import ClientResponse

-from lm_deluge.
+from lm_deluge.request_context import RequestContext
+from lm_deluge.tool import MCPServer, Tool

 from ..config import SamplingParams
 from ..models import APIModel
 from ..prompt import CachePattern, Conversation, Message, Text, Thinking, ToolCall
-from ..tracker import StatusTracker
 from ..usage import Usage
 from .base import APIRequestBase, APIResponse


@@ -53,54 +52,36 @@ def _build_oa_chat_request(
     return request_json


+def _build_oa_responses_request(
+    model: APIModel,
+    prompt: Conversation,
+    tools: list[Tool] | None,
+    sampling_params: SamplingParams,
+):
+    pass  # TODO: implement
+
+
 class OpenAIRequest(APIRequestBase):
-    def __init__(
-        self,
-        task_id: int,
-        # should always be 'role', 'content' keys.
-        # internal logic should handle translating to specific API format
-        model_name: str,  # must correspond to registry
-        prompt: Conversation,
-        attempts_left: int,
-        status_tracker: StatusTracker,
-        results_arr: list,
-        request_timeout: int = 30,
-        sampling_params: SamplingParams = SamplingParams(),
-        callback: Callable | None = None,
-        all_model_names: list[str] | None = None,
-        all_sampling_params: list[SamplingParams] | None = None,
-        tools: list | None = None,
-        cache: CachePattern | None = None,
-    ):
-        super().__init__(
-            task_id=task_id,
-            model_name=model_name,
-            prompt=prompt,
-            attempts_left=attempts_left,
-            status_tracker=status_tracker,
-            results_arr=results_arr,
-            request_timeout=request_timeout,
-            sampling_params=sampling_params,
-            callback=callback,
-            all_model_names=all_model_names,
-            all_sampling_params=all_sampling_params,
-            tools=tools,
-            cache=cache,
-        )
+    def __init__(self, context: RequestContext):
+        # Pass context to parent, which will handle backwards compatibility
+        super().__init__(context=context)

         # Warn if cache is specified for non-Anthropic model
-        if cache is not None:
+        if self.context.cache is not None:
             warnings.warn(
-                f"Cache parameter '{cache}' is only supported for Anthropic models, ignoring for {model_name}"
+                f"Cache parameter '{self.context.cache}' is only supported for Anthropic models, ignoring for {self.context.model_name}"
             )
-        self.model = APIModel.from_registry(model_name)
+        self.model = APIModel.from_registry(self.context.model_name)
         self.url = f"{self.model.api_base}/chat/completions"
         self.request_header = {
             "Authorization": f"Bearer {os.getenv(self.model.api_key_env_var)}"
         }

         self.request_json = _build_oa_chat_request(
-            self.model,
+            self.model,
+            self.context.prompt,
+            self.context.tools,
+            self.context.sampling_params,
         )

     async def handle_response(self, http_response: ClientResponse) -> APIResponse:

@@ -114,6 +95,8 @@ class OpenAIRequest(APIRequestBase):
         mimetype = http_response.headers.get("Content-Type", None)
         data = None
         finish_reason = None
+        assert self.context.status_tracker
+
         if status_code >= 200 and status_code < 300:
             try:
                 data = await http_response.json()

@@ -156,7 +139,7 @@ class OpenAIRequest(APIRequestBase):

                 usage = Usage.from_openai_usage(data["usage"])
                 if (
-                    self.sampling_params.logprobs
+                    self.context.sampling_params.logprobs
                     and "logprobs" in data["choices"][0]
                 ):
                     logprobs = data["choices"][0]["logprobs"]["content"]

@@ -176,22 +159,22 @@ class OpenAIRequest(APIRequestBase):
         if is_error and error_message is not None:
             if "rate limit" in error_message.lower() or status_code == 429:
                 error_message += " (Rate limit error, triggering cooldown.)"
-                self.status_tracker.rate_limit_exceeded()
+                self.context.status_tracker.rate_limit_exceeded()
             if "context length" in error_message:
                 error_message += " (Context length exceeded, set retries to 0.)"
-                self.attempts_left = 0
+                self.context.attempts_left = 0

         return APIResponse(
-            id=self.task_id,
+            id=self.context.task_id,
             status_code=status_code,
             is_error=is_error,
             error_message=error_message,
-            prompt=self.prompt,
+            prompt=self.context.prompt,
             logprobs=logprobs,
             thinking=thinking,
             content=content,
-            model_internal=self.model_name,
-            sampling_params=self.sampling_params,
+            model_internal=self.context.model_name,
+            sampling_params=self.context.sampling_params,
             usage=usage,
             raw_response=data,
             finish_reason=finish_reason,

@@ -199,117 +182,78 @@ class OpenAIRequest(APIRequestBase):


 class OpenAIResponsesRequest(APIRequestBase):
-    def __init__(
-        self,
-        task_id: int,
-        model_name: str,
-        prompt: Conversation,
-        attempts_left: int,
-        status_tracker: StatusTracker,
-        results_arr: list,
-        request_timeout: int = 30,
-        sampling_params: SamplingParams = SamplingParams(),
-        callback: Callable | None = None,
-        all_model_names: list[str] | None = None,
-        all_sampling_params: list[SamplingParams] | None = None,
-        tools: list | None = None,
-        cache: CachePattern | None = None,
-        computer_use: bool = False,
-        display_width: int = 1024,
-        display_height: int = 768,
-    ):
-        super().__init__(
-            task_id=task_id,
-            model_name=model_name,
-            prompt=prompt,
-            attempts_left=attempts_left,
-            status_tracker=status_tracker,
-            results_arr=results_arr,
-            request_timeout=request_timeout,
-            sampling_params=sampling_params,
-            callback=callback,
-            all_model_names=all_model_names,
-            all_sampling_params=all_sampling_params,
-            tools=tools,
-            cache=cache,
-        )
-
-        # Store computer use parameters
-        self.computer_use = computer_use
-        self.display_width = display_width
-        self.display_height = display_height
-
-        # Validate computer use requirements
-        if computer_use and model_name != "openai-computer-use-preview":
-            raise ValueError(
-                f"Computer use is only supported with openai-computer-use-preview model, got {model_name}"
-            )
-
+    def __init__(self, context: RequestContext):
+        super().__init__(context)
         # Warn if cache is specified for non-Anthropic model
-        if cache is not None:
+        if self.context.cache is not None:
             warnings.warn(
-                f"Cache parameter '{cache}' is only supported for Anthropic models, ignoring for {model_name}"
+                f"Cache parameter '{self.context.cache}' is only supported for Anthropic models, ignoring for {self.context.model_name}"
             )
-        self.model = APIModel.from_registry(model_name)
+        self.model = APIModel.from_registry(self.context.model_name)
         self.url = f"{self.model.api_base}/responses"
         self.request_header = {
             "Authorization": f"Bearer {os.getenv(self.model.api_key_env_var)}"
         }

         # Convert conversation to input format for Responses API
-        openai_responses_format = prompt.to_openai_responses()
+        openai_responses_format = self.context.prompt.to_openai_responses()

         self.request_json = {
             "model": self.model.name,
             "input": openai_responses_format["input"],
-            "temperature": sampling_params.temperature,
-            "top_p": sampling_params.top_p,
+            "temperature": self.context.sampling_params.temperature,
+            "top_p": self.context.sampling_params.top_p,
         }

         # Add max_output_tokens for responses API
-        if sampling_params.max_new_tokens:
-            self.request_json["max_output_tokens"] =
+        if self.context.sampling_params.max_new_tokens:
+            self.request_json["max_output_tokens"] = (
+                self.context.sampling_params.max_new_tokens
+            )

         if self.model.reasoning_model:
-            if sampling_params.reasoning_effort in [None, "none"]:
+            if self.context.sampling_params.reasoning_effort in [None, "none"]:
                 # gemini models can switch reasoning off
                 if "gemini" in self.model.id:
-                    self.sampling_params.reasoning_effort =
+                    self.context.sampling_params.reasoning_effort = (
+                        "none"  # expects string
+                    )
                 # openai models can only go down to "low"
                 else:
-                    self.sampling_params.reasoning_effort = "low"
+                    self.context.sampling_params.reasoning_effort = "low"
             self.request_json["temperature"] = 1.0
             self.request_json["top_p"] = 1.0
             self.request_json["reasoning"] = {
-                "effort": sampling_params.reasoning_effort
+                "effort": self.context.sampling_params.reasoning_effort
             }
         else:
-            if sampling_params.reasoning_effort:
+            if self.context.sampling_params.reasoning_effort:
                 warnings.warn(
-                    f"Ignoring reasoning_effort param for non-reasoning model: {model_name}"
+                    f"Ignoring reasoning_effort param for non-reasoning model: {self.context.model_name}"
                 )

-        if sampling_params.json_mode and self.model.supports_json:
+        if self.context.sampling_params.json_mode and self.model.supports_json:
             self.request_json["text"] = {"format": {"type": "json_object"}}

         # Handle tools
         request_tools = []
-        if
-            # Add computer use tool
-            request_tools.append(
-                {
-                    "type": "computer_use_preview",
-                    "display_width": display_width,
-                    "display_height": display_height,
-                    "environment": "browser",  # Default to browser, could be configurable
-                }
-            )
-            # Set truncation to auto as required for computer use
-            self.request_json["truncation"] = "auto"
-
-        if tools:
+        if self.context.tools:
             # Add regular function tools
-
+            for tool in self.context.tools:
+                if isinstance(tool, Tool):
+                    request_tools.append(tool.dump_for("openai-responses"))
+                elif isinstance(tool, dict):
+                    # if computer use, make sure model supports it
+                    if tool["type"] == "computer_use_preview":
+                        if self.context.model_name != "openai-computer-use-preview":
+                            raise ValueError(
+                                f"model {self.context.model_name} does not support computer use"
+                            )
+                        # have to use truncation
+                        self.request_json["truncation"] = "auto"
+                    request_tools.append(tool)  # allow passing dict
+                elif isinstance(tool, MCPServer):
+                    request_tools.append(tool.for_openai_responses())

         if request_tools:
             self.request_json["tools"] = request_tools

@@ -324,6 +268,7 @@ class OpenAIResponsesRequest(APIRequestBase):
         status_code = http_response.status
         mimetype = http_response.headers.get("Content-Type", None)
         data = None
+        assert self.context.status_tracker

         if status_code >= 200 and status_code < 300:
             try:

@@ -352,26 +297,83 @@ class OpenAIResponsesRequest(APIRequestBase):
                     for content_item in message_content:
                         if content_item.get("type") == "output_text":
                             parts.append(Text(content_item["text"]))
-
-
-
-
-
-
-
-
-
-
-
-
+                        elif content_item.get("type") == "refusal":
+                            parts.append(Text(content_item["refusal"]))
+                elif item.get("type") == "reasoning":
+                    parts.append(Thinking(item["summary"]["text"]))
+                elif item.get("type") == "function_call":
+                    parts.append(
+                        ToolCall(
+                            id=item["call_id"],
+                            name=item["name"],
+                            arguments=json.loads(item["arguments"]),
+                        )
+                    )
+                elif item.get("type") == "mcp_call":
+                    parts.append(
+                        ToolCall(
+                            id=item["id"],
+                            name=item["name"],
+                            arguments=json.loads(item["arguments"]),
+                            built_in=True,
+                            built_in_type="mcp_call",
+                            extra_body={
+                                "server_label": item["server_label"],
+                                "error": item.get("error"),
+                                "output": item.get("output"),
+                            },
+                        )
+                    )
+
                 elif item.get("type") == "computer_call":
-                    # Handle computer use actions
-                    action = item.get("action", {})
                     parts.append(
                         ToolCall(
                             id=item["call_id"],
-                            name=
-                            arguments=action,
+                            name="computer_call",
+                            arguments=item.get("action"),
+                            built_in=True,
+                            built_in_type="computer_call",
+                        )
+                    )
+
+                elif item.get("type") == "web_search_call":
+                    parts.append(
+                        ToolCall(
+                            id=item["id"],
+                            name="web_search_call",
+                            arguments={},
+                            built_in=True,
+                            built_in_type="web_search_call",
+                            extra_body={"status": item["status"]},
+                        )
+                    )
+
+                elif item.get("type") == "file_search_call":
+                    parts.append(
+                        ToolCall(
+                            id=item["id"],
+                            name="file_search_call",
+                            arguments={"queries": item["queries"]},
+                            built_in=True,
+                            built_in_type="file_search_call",
+                            extra_body={
+                                "status": item["status"],
+                                "results": item["results"],
+                            },
+                        )
+                    )
+                elif item.get("type") == "image_generation_call":
+                    parts.append(
+                        ToolCall(
+                            id=item["id"],
+                            name="image_generation_call",
+                            arguments={},
+                            built_in=True,
+                            built_in_type="image_generation_call",
+                            extra_body={
+                                "status": item["status"],
+                                "result": item["result"],
+                            },
                         )
                     )


@@ -386,9 +388,6 @@ class OpenAIResponsesRequest(APIRequestBase):
                 if "usage" in data:
                     usage = Usage.from_openai_usage(data["usage"])

-                # Extract response_id for computer use continuation
-                # response_id = data.get("id")
-
             except Exception as e:
                 is_error = True
                 error_message = f"Error parsing {self.model.name} responses API response: {str(e)}"

@@ -406,22 +405,22 @@ class OpenAIResponsesRequest(APIRequestBase):
         if is_error and error_message is not None:
             if "rate limit" in error_message.lower() or status_code == 429:
                 error_message += " (Rate limit error, triggering cooldown.)"
-                self.status_tracker.rate_limit_exceeded()
+                self.context.status_tracker.rate_limit_exceeded()
             if "context length" in error_message:
                 error_message += " (Context length exceeded, set retries to 0.)"
-                self.attempts_left = 0
+                self.context.attempts_left = 0

         return APIResponse(
-            id=self.task_id,
+            id=self.context.task_id,
             status_code=status_code,
             is_error=is_error,
             error_message=error_message,
-            prompt=self.prompt,
+            prompt=self.context.prompt,
             logprobs=logprobs,
             thinking=thinking,
             content=content,
-            model_internal=self.model_name,
-            sampling_params=self.sampling_params,
+            model_internal=self.context.model_name,
+            sampling_params=self.context.sampling_params,
             usage=usage,
             raw_response=data,
         )
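
The larger behavioral change in `OpenAIResponsesRequest` is that computer use is no longer a constructor flag (`computer_use`, `display_width`, and `display_height` are gone); built-in tools now arrive through `context.tools` as `Tool` objects, plain dicts, or `MCPServer` instances. The helper below is not part of the library; it just restates the branching from the diff as a standalone function for readers skimming the change.

```python
# Condensed restatement of the tool-routing branches above; simplified,
# not the actual implementation inside OpenAIResponsesRequest.__init__.
from lm_deluge.tool import MCPServer, Tool


def route_tool(tool, request_json: dict, model_name: str) -> dict:
    """Return the Responses-API JSON form of one context.tools entry."""
    if isinstance(tool, Tool):
        return tool.dump_for("openai-responses")
    if isinstance(tool, dict):
        if tool["type"] == "computer_use_preview":
            # only the computer-use preview model accepts this tool
            if model_name != "openai-computer-use-preview":
                raise ValueError(f"model {model_name} does not support computer use")
            request_json["truncation"] = "auto"  # required for computer use
        return tool  # dicts are passed through as-is
    if isinstance(tool, MCPServer):
        return tool.for_openai_responses()
    raise TypeError(f"unsupported tool entry: {type(tool)!r}")
```

On the response side, results of built-in tools are now parsed into `ToolCall` parts with `built_in=True` and a `built_in_type` of `mcp_call`, `computer_call`, `web_search_call`, `file_search_call`, or `image_generation_call`.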
lm_deluge/api_requests/response.py
CHANGED

@@ -35,7 +35,8 @@ class APIResponse:
     logprobs: list | None = None
     finish_reason: str | None = None  # make required later
     cost: float | None = None  # calculated automatically
-    cache_hit: bool = False  # manually set if true
+    cache_hit: bool = False  # manually set if true (provider-side caching)
+    local_cache_hit: bool = False  # set if hit our local dynamic cache
     # set to true if is_error and should be retried with a different model
     retry_with_different_model: bool | None = False
     # set to true if should NOT retry with the same model (unrecoverable error)
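
`APIResponse` now separates a provider-side cache hit (`cache_hit`) from a hit in the client's own local cache (`local_cache_hit`). A small, hypothetical accounting helper to illustrate the distinction:

```python
def cache_summary(responses) -> dict[str, int]:
    """Count provider-side vs. local cache hits over a batch of APIResponse objects."""
    return {
        "provider_cache_hits": sum(1 for r in responses if r.cache_hit),
        "local_cache_hits": sum(1 for r in responses if r.local_cache_hit),
    }
```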
lm_deluge/batches.py
CHANGED

@@ -218,7 +218,7 @@ async def submit_batches_anthropic(
     batch_tasks = []
     async with aiohttp.ClientSession() as session:
         for batch in batches:
-            url = f"{registry[model]
+            url = f"{registry[model].api_base}/messages/batches"
             data = {"requests": batch}

             async def submit_batch(data, url, headers):