lm-deluge 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lm-deluge might be problematic.
- lm_deluge/__init__.py +11 -1
- lm_deluge/agent.py +0 -0
- lm_deluge/api_requests/anthropic.py +90 -58
- lm_deluge/api_requests/base.py +63 -180
- lm_deluge/api_requests/bedrock.py +34 -10
- lm_deluge/api_requests/common.py +2 -1
- lm_deluge/api_requests/mistral.py +6 -15
- lm_deluge/api_requests/openai.py +342 -50
- lm_deluge/api_requests/response.py +153 -0
- lm_deluge/batches.py +498 -0
- lm_deluge/client.py +354 -636
- lm_deluge/computer_use/anthropic_tools.py +75 -0
- lm_deluge/{sampling_params.py → config.py} +12 -4
- lm_deluge/embed.py +17 -11
- lm_deluge/file.py +149 -0
- lm_deluge/models.py +33 -0
- lm_deluge/prompt.py +156 -15
- lm_deluge/rerank.py +18 -12
- lm_deluge/tool.py +11 -1
- lm_deluge/tracker.py +214 -2
- lm_deluge/util/json.py +18 -1
- {lm_deluge-0.0.12.dist-info → lm_deluge-0.0.14.dist-info}/METADATA +8 -5
- lm_deluge-0.0.14.dist-info/RECORD +44 -0
- {lm_deluge-0.0.12.dist-info → lm_deluge-0.0.14.dist-info}/WHEEL +1 -1
- lm_deluge-0.0.12.dist-info/RECORD +0 -39
- {lm_deluge-0.0.12.dist-info → lm_deluge-0.0.14.dist-info}/licenses/LICENSE +0 -0
- {lm_deluge-0.0.12.dist-info → lm_deluge-0.0.14.dist-info}/top_level.txt +0 -0
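One upgrade note that follows from the file list: `lm_deluge/sampling_params.py` has been renamed to `lm_deluge/config.py`, and per-request knobs now travel on `SamplingParams` rather than as separate keyword arguments (see the per-file diffs below). A minimal, hypothetical import-path sketch, assuming `SamplingParams` was previously importable from the old module path:

```python
# 0.0.12 (hypothetical old path, per the rename shown in the file list):
#   from lm_deluge.sampling_params import SamplingParams

# 0.0.14 -- the module is now lm_deluge/config.py, matching the internal
# `from ..config import SamplingParams` imports in the diffs below:
from lm_deluge.config import SamplingParams

params = SamplingParams()  # defaults, as used by the request classes
```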
lm_deluge/api_requests/common.py
CHANGED
@@ -1,10 +1,11 @@
-from .openai import OpenAIRequest
+from .openai import OpenAIRequest, OpenAIResponsesRequest
 from .anthropic import AnthropicRequest
 from .mistral import MistralRequest
 from .bedrock import BedrockRequest
 
 CLASSES = {
     "openai": OpenAIRequest,
+    "openai-responses": OpenAIResponsesRequest,
     "anthropic": AnthropicRequest,
     "mistral": MistralRequest,
     "bedrock": BedrockRequest,
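The registry change above is what routes requests for the new OpenAI Responses API: a model whose spec string is `"openai-responses"` now resolves to `OpenAIResponsesRequest` rather than the chat-completions class. A small dispatch sketch; the helper function here is hypothetical (the real lookup lives in the client code), and only `CLASSES` and the class names come from the diff:

```python
from lm_deluge.api_requests.common import CLASSES


def request_class_for(api_spec: str):
    """Resolve an api_spec string (e.g. from the model registry) to a request class."""
    try:
        return CLASSES[api_spec]
    except KeyError:
        raise ValueError(f"Unsupported api_spec: {api_spec!r}")


# 0.0.14 adds the Responses API spec alongside the existing providers:
cls = request_class_for("openai-responses")  # -> OpenAIResponsesRequest
```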
lm_deluge/api_requests/mistral.py
CHANGED
@@ -1,16 +1,14 @@
-import asyncio
 import warnings
 from aiohttp import ClientResponse
 import json
 import os
-from tqdm.auto import tqdm
 from typing import Callable
 
 from .base import APIRequestBase, APIResponse
 from ..prompt import Conversation, Message, CachePattern
 from ..usage import Usage
 from ..tracker import StatusTracker
-from ..
+from ..config import SamplingParams
 from ..models import APIModel
 
 
@@ -24,15 +22,10 @@ class MistralRequest(APIRequestBase):
         prompt: Conversation,
         attempts_left: int,
         status_tracker: StatusTracker,
-        retry_queue: asyncio.Queue,
         results_arr: list,
         request_timeout: int = 30,
         sampling_params: SamplingParams = SamplingParams(),
-        logprobs: bool = False,
-        top_logprobs: int | None = None,
-        pbar: tqdm | None = None,
         callback: Callable | None = None,
-        debug: bool = False,
         all_model_names: list[str] | None = None,
         all_sampling_params: list[SamplingParams] | None = None,
         tools: list | None = None,
@@ -44,15 +37,10 @@ class MistralRequest(APIRequestBase):
             prompt=prompt,
             attempts_left=attempts_left,
             status_tracker=status_tracker,
-            retry_queue=retry_queue,
             results_arr=results_arr,
             request_timeout=request_timeout,
             sampling_params=sampling_params,
-            logprobs=logprobs,
-            top_logprobs=top_logprobs,
-            pbar=pbar,
             callback=callback,
-            debug=debug,
             all_model_names=all_model_names,
             all_sampling_params=all_sampling_params,
             tools=tools,
@@ -80,7 +68,7 @@ class MistralRequest(APIRequestBase):
             warnings.warn(
                 f"Ignoring reasoning_effort param for non-reasoning model: {model_name}"
             )
-        if logprobs:
+        if sampling_params.logprobs:
             warnings.warn(
                 f"Ignoring logprobs param for non-logprobs model: {model_name}"
             )
@@ -109,7 +97,10 @@ class MistralRequest(APIRequestBase):
             try:
                 completion = data["choices"][0]["message"]["content"]
                 usage = Usage.from_mistral_usage(data["usage"])
-                if
+                if (
+                    self.sampling_params.logprobs
+                    and "logprobs" in data["choices"][0]
+                ):
                     logprobs = data["choices"][0]["logprobs"]["content"]
             except Exception:
                 is_error = True
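For callers that built request objects directly, the signature change above means `retry_queue`, `pbar`, `debug`, `logprobs`, and `top_logprobs` are no longer constructor arguments; the logprob options ride on `SamplingParams` instead. A hedged sketch of the migration, assuming `SamplingParams` accepts these fields as constructor keywords (which the `SamplingParams(**...)` call in the new response.py suggests):

```python
from lm_deluge.config import SamplingParams

# 0.0.12 style (no longer accepted):
#   MistralRequest(..., retry_queue=queue, logprobs=True, top_logprobs=5, debug=True)

# 0.0.14 style -- the options move onto SamplingParams:
params = SamplingParams(logprobs=True, top_logprobs=5)
#   MistralRequest(..., sampling_params=params)

# The request classes now read them back from sampling_params,
# e.g. `if sampling_params.logprobs:` in the hunk above.
assert params.logprobs and params.top_logprobs == 5
```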
lm_deluge/api_requests/openai.py
CHANGED
@@ -1,17 +1,56 @@
-import asyncio
-import warnings
-from aiohttp import ClientResponse
 import json
 import os
-
+import warnings
 from typing import Callable
 
-
-from
-
-from
-
+import aiohttp
+from aiohttp import ClientResponse
+
+from lm_deluge.tool import Tool
+
+from ..config import SamplingParams
 from ..models import APIModel
+from ..prompt import CachePattern, Conversation, Message, Text, Thinking, ToolCall
+from ..tracker import StatusTracker
+from ..usage import Usage
+from .base import APIRequestBase, APIResponse
+
+
+def _build_oa_chat_request(
+    model: APIModel,
+    prompt: Conversation,
+    tools: list[Tool] | None,
+    sampling_params: SamplingParams,
+) -> dict:
+    request_json = {
+        "model": model.name,
+        "messages": prompt.to_openai(),
+        "temperature": sampling_params.temperature,
+        "top_p": sampling_params.top_p,
+    }
+    # set max_tokens or max_completion_tokens dep. on provider
+    if "cohere" in model.api_base:
+        request_json["max_tokens"] = sampling_params.max_new_tokens
+    else:
+        request_json["max_completion_tokens"] = sampling_params.max_new_tokens
+    if model.reasoning_model:
+        request_json["temperature"] = 1.0
+        request_json["top_p"] = 1.0
+        request_json["reasoning_effort"] = sampling_params.reasoning_effort
+    else:
+        if sampling_params.reasoning_effort:
+            warnings.warn(
+                f"Ignoring reasoning_effort param for non-reasoning model: {model.name}"
+            )
+    if sampling_params.logprobs:
+        request_json["logprobs"] = True
+        if sampling_params.top_logprobs is not None:
+            request_json["top_logprobs"] = sampling_params.top_logprobs
+    if sampling_params.json_mode and model.supports_json:
+        request_json["response_format"] = {"type": "json_object"}
+    if tools:
+        request_json["tools"] = [tool.dump_for("openai-completions") for tool in tools]
+    return request_json
 
 
 class OpenAIRequest(APIRequestBase):
@@ -24,15 +63,10 @@ class OpenAIRequest(APIRequestBase):
         prompt: Conversation,
         attempts_left: int,
         status_tracker: StatusTracker,
-        retry_queue: asyncio.Queue,
         results_arr: list,
         request_timeout: int = 30,
         sampling_params: SamplingParams = SamplingParams(),
-        logprobs: bool = False,
-        top_logprobs: int | None = None,
-        pbar: tqdm | None = None,
         callback: Callable | None = None,
-        debug: bool = False,
         all_model_names: list[str] | None = None,
         all_sampling_params: list[SamplingParams] | None = None,
         tools: list | None = None,
@@ -44,15 +78,10 @@ class OpenAIRequest(APIRequestBase):
             prompt=prompt,
             attempts_left=attempts_left,
             status_tracker=status_tracker,
-            retry_queue=retry_queue,
             results_arr=results_arr,
             request_timeout=request_timeout,
             sampling_params=sampling_params,
-            logprobs=logprobs,
-            top_logprobs=top_logprobs,
-            pbar=pbar,
             callback=callback,
-            debug=debug,
             all_model_names=all_model_names,
             all_sampling_params=all_sampling_params,
             tools=tools,
@@ -70,36 +99,9 @@ class OpenAIRequest(APIRequestBase):
             "Authorization": f"Bearer {os.getenv(self.model.api_key_env_var)}"
         }
 
-        self.request_json =
-
-
-            "temperature": sampling_params.temperature,
-            "top_p": sampling_params.top_p,
-        }
-        # set max_tokens or max_completion_tokens dep. on provider
-        if "cohere" in self.model.api_base:
-            self.request_json["max_tokens"] = sampling_params.max_new_tokens
-        elif "openai" in self.model.api_base:
-            self.request_json["max_completion_tokens"] = sampling_params.max_new_tokens
-        if self.model.reasoning_model:
-            self.request_json["temperature"] = 1.0
-            self.request_json["top_p"] = 1.0
-            self.request_json["reasoning_effort"] = sampling_params.reasoning_effort
-        else:
-            if sampling_params.reasoning_effort:
-                warnings.warn(
-                    f"Ignoring reasoning_effort param for non-reasoning model: {model_name}"
-                )
-        if logprobs:
-            self.request_json["logprobs"] = True
-            if top_logprobs is not None:
-                self.request_json["top_logprobs"] = top_logprobs
-        if sampling_params.json_mode and self.model.supports_json:
-            self.request_json["response_format"] = {"type": "json_object"}
-        if tools:
-            self.request_json["tools"] = [
-                tool.dump_for("openai-completions") for tool in tools
-            ]
+        self.request_json = _build_oa_chat_request(
+            self.model, prompt, tools, sampling_params
+        )
 
     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
         is_error = False
@@ -111,6 +113,7 @@ class OpenAIRequest(APIRequestBase):
         status_code = http_response.status
         mimetype = http_response.headers.get("Content-Type", None)
         data = None
+        finish_reason = None
        if status_code >= 200 and status_code < 300:
             try:
                 data = await http_response.json()
@@ -125,6 +128,7 @@ class OpenAIRequest(APIRequestBase):
                 # Parse response into Message with parts
                 parts = []
                 message = data["choices"][0]["message"]
+                finish_reason = data["choices"][0]["finish_reason"]
 
                 # Add text content if present
                 if message.get("content"):
@@ -151,7 +155,10 @@ class OpenAIRequest(APIRequestBase):
                 content = Message("assistant", parts)
 
                 usage = Usage.from_openai_usage(data["usage"])
-                if
+                if (
+                    self.sampling_params.logprobs
+                    and "logprobs" in data["choices"][0]
+                ):
                     logprobs = data["choices"][0]["logprobs"]["content"]
             except Exception:
                 is_error = True
@@ -186,4 +193,289 @@ class OpenAIRequest(APIRequestBase):
             model_internal=self.model_name,
             sampling_params=self.sampling_params,
             usage=usage,
+            raw_response=data,
+            finish_reason=finish_reason,
+        )
+
+
+class OpenAIResponsesRequest(APIRequestBase):
+    def __init__(
+        self,
+        task_id: int,
+        model_name: str,
+        prompt: Conversation,
+        attempts_left: int,
+        status_tracker: StatusTracker,
+        results_arr: list,
+        request_timeout: int = 30,
+        sampling_params: SamplingParams = SamplingParams(),
+        callback: Callable | None = None,
+        all_model_names: list[str] | None = None,
+        all_sampling_params: list[SamplingParams] | None = None,
+        tools: list | None = None,
+        cache: CachePattern | None = None,
+        computer_use: bool = False,
+        display_width: int = 1024,
+        display_height: int = 768,
+    ):
+        super().__init__(
+            task_id=task_id,
+            model_name=model_name,
+            prompt=prompt,
+            attempts_left=attempts_left,
+            status_tracker=status_tracker,
+            results_arr=results_arr,
+            request_timeout=request_timeout,
+            sampling_params=sampling_params,
+            callback=callback,
+            all_model_names=all_model_names,
+            all_sampling_params=all_sampling_params,
+            tools=tools,
+            cache=cache,
+        )
+
+        # Store computer use parameters
+        self.computer_use = computer_use
+        self.display_width = display_width
+        self.display_height = display_height
+
+        # Validate computer use requirements
+        if computer_use and model_name != "openai-computer-use-preview":
+            raise ValueError(
+                f"Computer use is only supported with openai-computer-use-preview model, got {model_name}"
+            )
+
+        # Warn if cache is specified for non-Anthropic model
+        if cache is not None:
+            warnings.warn(
+                f"Cache parameter '{cache}' is only supported for Anthropic models, ignoring for {model_name}"
+            )
+        self.model = APIModel.from_registry(model_name)
+        self.url = f"{self.model.api_base}/responses"
+        self.request_header = {
+            "Authorization": f"Bearer {os.getenv(self.model.api_key_env_var)}"
+        }
+
+        # Convert conversation to input format for Responses API
+        openai_responses_format = prompt.to_openai_responses()
+
+        self.request_json = {
+            "model": self.model.name,
+            "input": openai_responses_format["input"],
+            "temperature": sampling_params.temperature,
+            "top_p": sampling_params.top_p,
+        }
+
+        # Add max_output_tokens for responses API
+        if sampling_params.max_new_tokens:
+            self.request_json["max_output_tokens"] = sampling_params.max_new_tokens
+
+        if self.model.reasoning_model:
+            if sampling_params.reasoning_effort in [None, "none"]:
+                # gemini models can switch reasoning off
+                if "gemini" in self.model.id:
+                    self.sampling_params.reasoning_effort = "none"  # expects string
+                # openai models can only go down to "low"
+                else:
+                    self.sampling_params.reasoning_effort = "low"
+            self.request_json["temperature"] = 1.0
+            self.request_json["top_p"] = 1.0
+            self.request_json["reasoning"] = {
+                "effort": sampling_params.reasoning_effort
+            }
+        else:
+            if sampling_params.reasoning_effort:
+                warnings.warn(
+                    f"Ignoring reasoning_effort param for non-reasoning model: {model_name}"
+                )
+
+        if sampling_params.json_mode and self.model.supports_json:
+            self.request_json["text"] = {"format": {"type": "json_object"}}
+
+        # Handle tools
+        request_tools = []
+        if computer_use:
+            # Add computer use tool
+            request_tools.append(
+                {
+                    "type": "computer_use_preview",
+                    "display_width": display_width,
+                    "display_height": display_height,
+                    "environment": "browser",  # Default to browser, could be configurable
+                }
+            )
+            # Set truncation to auto as required for computer use
+            self.request_json["truncation"] = "auto"
+
+        if tools:
+            # Add regular function tools
+            request_tools.extend([tool.dump_for("openai-responses") for tool in tools])
+
+        if request_tools:
+            self.request_json["tools"] = request_tools
+
+    async def handle_response(self, http_response: ClientResponse) -> APIResponse:
+        is_error = False
+        error_message = None
+        thinking = None
+        content = None
+        usage = None
+        logprobs = None
+        status_code = http_response.status
+        mimetype = http_response.headers.get("Content-Type", None)
+        data = None
+
+        if status_code >= 200 and status_code < 300:
+            try:
+                data = await http_response.json()
+            except Exception:
+                is_error = True
+                error_message = (
+                    f"Error calling .json() on response w/ status {status_code}"
+                )
+            if not is_error:
+                assert data is not None, "data is None"
+                try:
+                    # Parse Responses API format
+                    parts = []
+
+                    # Get the output array from the response
+                    output = data.get("output", [])
+                    if not output:
+                        is_error = True
+                        error_message = "No output in response"
+                    else:
+                        # Process each output item
+                        for item in output:
+                            if item.get("type") == "message":
+                                message_content = item.get("content", [])
+                                for content_item in message_content:
+                                    if content_item.get("type") == "output_text":
+                                        parts.append(Text(content_item["text"]))
+                                    # Handle tool calls if present
+                                    elif content_item.get("type") == "tool_call":
+                                        tool_call = content_item["tool_call"]
+                                        parts.append(
+                                            ToolCall(
+                                                id=tool_call["id"],
+                                                name=tool_call["function"]["name"],
+                                                arguments=json.loads(
+                                                    tool_call["function"]["arguments"]
+                                                ),
+                                            )
+                                        )
+                            elif item.get("type") == "computer_call":
+                                # Handle computer use actions
+                                action = item.get("action", {})
+                                parts.append(
+                                    ToolCall(
+                                        id=item["call_id"],
+                                        name=f"_computer_{action.get('type', 'action')}",
+                                        arguments=action,
+                                    )
+                                )
+
+                        # Handle reasoning if present
+                        if "reasoning" in data and data["reasoning"].get("summary"):
+                            thinking = data["reasoning"]["summary"]
+                            parts.append(Thinking(thinking))
+
+                        content = Message("assistant", parts)
+
+                        # Extract usage information
+                        if "usage" in data:
+                            usage = Usage.from_openai_usage(data["usage"])
+
+                        # Extract response_id for computer use continuation
+                        # response_id = data.get("id")
+
+                except Exception as e:
+                    is_error = True
+                    error_message = f"Error parsing {self.model.name} responses API response: {str(e)}"
+
+        elif mimetype and "json" in mimetype.lower():
+            is_error = True
+            data = await http_response.json()
+            error_message = json.dumps(data)
+        else:
+            is_error = True
+            text = await http_response.text()
+            error_message = text
+
+        # Handle special kinds of errors
+        if is_error and error_message is not None:
+            if "rate limit" in error_message.lower() or status_code == 429:
+                error_message += " (Rate limit error, triggering cooldown.)"
+                self.status_tracker.rate_limit_exceeded()
+            if "context length" in error_message:
+                error_message += " (Context length exceeded, set retries to 0.)"
+                self.attempts_left = 0
+
+        return APIResponse(
+            id=self.task_id,
+            status_code=status_code,
+            is_error=is_error,
+            error_message=error_message,
+            prompt=self.prompt,
+            logprobs=logprobs,
+            thinking=thinking,
+            content=content,
+            model_internal=self.model_name,
+            sampling_params=self.sampling_params,
+            usage=usage,
+            raw_response=data,
         )
+
+
+async def stream_chat(
+    model_name: str,  # must correspond to registry
+    prompt: Conversation,
+    sampling_params: SamplingParams = SamplingParams(),
+    tools: list | None = None,
+    cache: CachePattern | None = None,
+):
+    if cache is not None:
+        warnings.warn(
+            f"Cache parameter '{cache}' is only supported for Anthropic models, ignoring for {model_name}"
+        )
+
+    model = APIModel.from_registry(model_name)
+    if model.api_spec != "openai":
+        raise ValueError("streaming only supported on openai models for now")
+    url = f"{model.api_base}/chat/completions"
+    request_header = {"Authorization": f"Bearer {os.getenv(model.api_key_env_var)}"}
+    request_json = _build_oa_chat_request(model, prompt, tools, sampling_params)
+    request_json["stream"] = True
+
+    async with aiohttp.ClientSession() as s:
+        async with s.post(url, headers=request_header, json=request_json) as r:
+            r.raise_for_status()  # bail on 4xx/5xx
+            content = ""
+            buf = ""
+            async for chunk in r.content.iter_any():  # raw bytes
+                buf += chunk.decode()
+                while "\n\n" in buf:  # full SSE frame
+                    event, buf = buf.split("\n\n", 1)
+                    if not event.startswith("data:"):
+                        continue  # ignore comments
+                    data = event[5:].strip()  # after "data:"
+                    if data == "[DONE]":
+                        yield APIResponse(
+                            id=0,
+                            status_code=None,
+                            is_error=False,
+                            error_message=None,
+                            prompt=prompt,
+                            content=Message(
+                                role="assistant", parts=[Text(text=content)]
+                            ),
+                            model_internal=model.id,
+                            sampling_params=sampling_params,
+                            usage=None,
+                            raw_response=None,
+                        )
+                    msg = json.loads(data)  # SSE payload
+                    delta = msg["choices"][0]["delta"].get("content")
+                    if delta:
+                        content += delta
+                        yield delta
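`stream_chat` is an async generator: it yields plain text deltas as SSE frames arrive and, on the terminating `[DONE]` frame, a final `APIResponse` containing the assembled assistant message (with `usage=None` for streams). A minimal consumer sketch; the model id is a placeholder that must exist in the package's model registry, and `prompt` is assumed to be an lm_deluge `Conversation` built elsewhere:

```python
import asyncio

from lm_deluge.api_requests.base import APIResponse  # same class stream_chat yields
from lm_deluge.api_requests.openai import stream_chat


async def print_stream(prompt) -> None:
    """Print text deltas as they arrive; keep the final APIResponse from [DONE]."""
    final: APIResponse | None = None
    async for item in stream_chat("gpt-4o-mini", prompt):  # placeholder model id
        if isinstance(item, APIResponse):
            final = item  # fully assembled assistant message
        else:
            print(item, end="", flush=True)
    print()
    if final is not None:
        print("finished; completion length:", len(final.completion or ""))


# asyncio.run(print_stream(some_conversation))  # some_conversation: a Conversation
```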
lm_deluge/api_requests/response.py
ADDED
@@ -0,0 +1,153 @@
+import json
+import random
+from dataclasses import dataclass
+
+from lm_deluge.prompt import Conversation, Message
+from lm_deluge.usage import Usage
+
+from ..config import SamplingParams
+from ..models import APIModel
+
+
+@dataclass
+class APIResponse:
+    # request information
+    id: int  # should be unique to the request within a given prompt-processing call
+    model_internal: str  # our internal model tag
+    prompt: Conversation
+    sampling_params: SamplingParams
+
+    # http response information
+    status_code: int | None
+    is_error: bool | None
+    error_message: str | None
+
+    # completion information - unified usage tracking
+    usage: Usage | None = None
+
+    # response content - structured format
+    content: Message | None = None
+
+    # optional or calculated automatically
+    thinking: str | None = None  # if model shows thinking tokens
+    model_external: str | None = None  # the model tag used by the API
+    region: str | None = None
+    logprobs: list | None = None
+    finish_reason: str | None = None  # make required later
+    cost: float | None = None  # calculated automatically
+    cache_hit: bool = False  # manually set if true
+    # set to true if is_error and should be retried with a different model
+    retry_with_different_model: bool | None = False
+    # set to true if should NOT retry with the same model (unrecoverable error)
+    give_up_if_no_other_models: bool | None = False
+    # OpenAI Responses API specific - used for computer use continuation
+    response_id: str | None = None
+    # Raw API response for debugging
+    raw_response: dict | None = None
+
+    @property
+    def completion(self) -> str | None:
+        """Backward compatibility: extract text from content Message."""
+        if self.content is not None:
+            return self.content.completion
+        return None
+
+    @property
+    def input_tokens(self) -> int | None:
+        """Get input tokens from usage object."""
+        return self.usage.input_tokens if self.usage else None
+
+    @property
+    def output_tokens(self) -> int | None:
+        """Get output tokens from usage object."""
+        return self.usage.output_tokens if self.usage else None
+
+    @property
+    def cache_read_tokens(self) -> int | None:
+        """Get cache read tokens from usage object."""
+        return self.usage.cache_read_tokens if self.usage else None
+
+    @property
+    def cache_write_tokens(self) -> int | None:
+        """Get cache write tokens from usage object."""
+        return self.usage.cache_write_tokens if self.usage else None
+
+    def __post_init__(self):
+        # calculate cost & get external model name
+        self.id = int(self.id)
+        api_model = APIModel.from_registry(self.model_internal)
+        self.model_external = api_model.name
+        self.cost = None
+        if (
+            self.usage is not None
+            and api_model.input_cost is not None
+            and api_model.output_cost is not None
+        ):
+            self.cost = (
+                self.usage.input_tokens * api_model.input_cost / 1e6
+                + self.usage.output_tokens * api_model.output_cost / 1e6
+            )
+        elif self.content is not None and self.completion is not None:
+            print(
+                f"Warning: Completion provided without token counts for model {self.model_internal}."
+            )
+
+    def to_dict(self):
+        return {
+            "id": self.id,
+            "model_internal": self.model_internal,
+            "model_external": self.model_external,
+            "region": self.region,
+            "prompt": self.prompt.to_log(),  # destroys image if present
+            "sampling_params": self.sampling_params.__dict__,
+            "status_code": self.status_code,
+            "is_error": self.is_error,
+            "error_message": self.error_message,
+            "completion": self.completion,  # computed property
+            "content": self.content.to_log() if self.content else None,
+            "usage": self.usage.to_dict() if self.usage else None,
+            "finish_reason": self.finish_reason,
+            "cost": self.cost,
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict):
+        # Handle backward compatibility for content/completion
+        content = None
+        if "content" in data and data["content"] is not None:
+            # Reconstruct message from log format
+            content = Message.from_log(data["content"])
+        elif "completion" in data and data["completion"] is not None:
+            # Backward compatibility: create a Message with just text
+            content = Message.ai(data["completion"])
+
+        usage = None
+        if "usage" in data and data["usage"] is not None:
+            usage = Usage.from_dict(data["usage"])
+
+        return cls(
+            id=data.get("id", random.randint(0, 1_000_000_000)),
+            model_internal=data["model_internal"],
+            prompt=Conversation.from_log(data["prompt"]),
+            sampling_params=SamplingParams(**data["sampling_params"]),
+            status_code=data["status_code"],
+            is_error=data["is_error"],
+            error_message=data["error_message"],
+            usage=usage,
+            content=content,
+            thinking=data.get("thinking"),
+            model_external=data.get("model_external"),
+            region=data.get("region"),
+            logprobs=data.get("logprobs"),
+            finish_reason=data.get("finish_reason"),
+            cost=data.get("cost"),
+            cache_hit=data.get("cache_hit", False),
+        )
+
+    def write_to_file(self, filename):
+        """
+        Writes the APIResponse as a line to a file.
+        If file exists, appends to it.
+        """
+        with open(filename, "a") as f:
+            f.write(json.dumps(self.to_dict()) + "\n")