lm-deluge 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lm-deluge might be problematic.
- lm_deluge/__init__.py +9 -1
- lm_deluge/agent.py +0 -0
- lm_deluge/api_requests/anthropic.py +107 -60
- lm_deluge/api_requests/base.py +107 -54
- lm_deluge/api_requests/bedrock.py +59 -22
- lm_deluge/api_requests/common.py +2 -1
- lm_deluge/api_requests/mistral.py +20 -22
- lm_deluge/api_requests/openai.py +283 -51
- lm_deluge/batches.py +498 -0
- lm_deluge/client.py +373 -634
- lm_deluge/computer_use/anthropic_tools.py +75 -0
- lm_deluge/{sampling_params.py → config.py} +10 -3
- lm_deluge/embed.py +17 -11
- lm_deluge/models.py +33 -0
- lm_deluge/prompt.py +173 -7
- lm_deluge/rerank.py +18 -12
- lm_deluge/tool.py +11 -1
- lm_deluge/tracker.py +212 -2
- lm_deluge/usage.py +114 -0
- lm_deluge/util/json.py +18 -1
- {lm_deluge-0.0.11.dist-info → lm_deluge-0.0.13.dist-info}/METADATA +78 -20
- lm_deluge-0.0.13.dist-info/RECORD +42 -0
- {lm_deluge-0.0.11.dist-info → lm_deluge-0.0.13.dist-info}/WHEEL +1 -1
- lm_deluge-0.0.11.dist-info/RECORD +0 -38
- {lm_deluge-0.0.11.dist-info → lm_deluge-0.0.13.dist-info}/licenses/LICENSE +0 -0
- {lm_deluge-0.0.11.dist-info → lm_deluge-0.0.13.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,6 @@ import asyncio
 import json
 import os
 from aiohttp import ClientResponse
-from tqdm import tqdm
 from typing import Callable
 
 try:
@@ -12,11 +11,19 @@ except ImportError:
         "aws4auth is required for bedrock support. Install with: pip install requests-aws4auth"
     )
 
-from lm_deluge.prompt import
+from lm_deluge.prompt import (
+    Conversation,
+    Message,
+    Text,
+    ToolCall,
+    Thinking,
+    CachePattern,
+)
+from lm_deluge.usage import Usage
 from .base import APIRequestBase, APIResponse
 
 from ..tracker import StatusTracker
-from ..
+from ..config import SamplingParams
 from ..models import APIModel
 
 
@@ -28,16 +35,18 @@ class BedrockRequest(APIRequestBase):
         prompt: Conversation,
         attempts_left: int,
         status_tracker: StatusTracker,
-        retry_queue: asyncio.Queue,
         results_arr: list,
         request_timeout: int = 30,
         sampling_params: SamplingParams = SamplingParams(),
-        pbar: tqdm | None = None,
         callback: Callable | None = None,
-        debug: bool = False,
         all_model_names: list[str] | None = None,
         all_sampling_params: list[SamplingParams] | None = None,
         tools: list | None = None,
+        cache: CachePattern | None = None,
+        # Computer Use support
+        computer_use: bool = False,
+        display_width: int = 1024,
+        display_height: int = 768,
     ):
         super().__init__(
             task_id=task_id,
@@ -45,18 +54,24 @@ class BedrockRequest(APIRequestBase):
             prompt=prompt,
             attempts_left=attempts_left,
             status_tracker=status_tracker,
-            retry_queue=retry_queue,
             results_arr=results_arr,
             request_timeout=request_timeout,
             sampling_params=sampling_params,
-            pbar=pbar,
             callback=callback,
-            debug=debug,
             all_model_names=all_model_names,
             all_sampling_params=all_sampling_params,
             tools=tools,
+            cache=cache,
         )
 
+        self.computer_use = computer_use
+        self.display_width = display_width
+        self.display_height = display_height
+
+        # Lock images as bytes if caching is enabled
+        if cache is not None:
+            prompt.lock_images_as_bytes()
+
         self.model = APIModel.from_registry(model_name)
 
         # Get AWS credentials from environment
@@ -87,7 +102,7 @@ class BedrockRequest(APIRequestBase):
         self.url = f"https://bedrock-runtime.{self.region}.amazonaws.com/model/{self.model.name}/invoke"
 
         # Convert prompt to Anthropic format for bedrock
-        self.system_message, messages = prompt.to_anthropic()
+        self.system_message, messages = prompt.to_anthropic(cache_pattern=cache)
 
         # Prepare request body in Anthropic's bedrock format
         self.request_json = {
@@ -101,8 +116,35 @@ class BedrockRequest(APIRequestBase):
         if self.system_message is not None:
             self.request_json["system"] = self.system_message
 
-        if tools:
-
+        if tools or self.computer_use:
+            tool_definitions = []
+
+            # Add Computer Use tools at the beginning if enabled
+            if self.computer_use:
+                from ..computer_use.anthropic_tools import get_anthropic_cu_tools
+
+                cu_tools = get_anthropic_cu_tools(
+                    model=self.model.id,
+                    display_width=self.display_width,
+                    display_height=self.display_height,
+                )
+                tool_definitions.extend(cu_tools)
+
+                # Add computer use display parameters to the request
+                self.request_json["computer_use_display_width_px"] = self.display_width
+                self.request_json["computer_use_display_height_px"] = (
+                    self.display_height
+                )
+
+            # Add user-provided tools
+            if tools:
+                tool_definitions.extend([tool.dump_for("anthropic") for tool in tools])
+
+            # Add cache control to last tool if tools_only caching is specified
+            if cache == "tools_only" and tool_definitions:
+                tool_definitions[-1]["cache_control"] = {"type": "ephemeral"}
+
+            self.request_json["tools"] = tool_definitions
 
         # Setup AWS4Auth for signing
         self.auth = AWS4Auth(
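For context on the tools hunk above, this is roughly what the assembled `tools` payload looks like when `cache == "tools_only"`. It is an illustrative sketch only; the exact output of `get_anthropic_cu_tools` and `tool.dump_for("anthropic")` is not shown in this diff, so the field values below are assumptions based on Anthropic's published tool and prompt-caching schema.

```python
# Hypothetical shape of request_json["tools"] built by the code above.
request_json["tools"] = [
    {
        # Computer Use tool injected by get_anthropic_cu_tools (shape assumed;
        # the tool type version depends on the model).
        "type": "computer_20250124",
        "name": "computer",
        "display_width_px": 1024,
        "display_height_px": 768,
    },
    {
        # A user-provided tool as emitted by tool.dump_for("anthropic") (assumed).
        "name": "get_weather",
        "description": "Look up the weather for a city.",
        "input_schema": {"type": "object", "properties": {"city": {"type": "string"}}},
        # cache == "tools_only" marks the last tool definition as a cache breakpoint.
        "cache_control": {"type": "ephemeral"},
    },
]
```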
@@ -179,8 +221,7 @@ class BedrockRequest(APIRequestBase):
                     is_error=True,
                     error_message="Request timed out (terminated by client).",
                     content=None,
-
-                    output_tokens=None,
+                    usage=None,
                 )
             )
             self.handle_error(create_new_request=False)
@@ -199,8 +240,7 @@
                     is_error=True,
                     error_message=f"Unexpected {type(e).__name__}: {str(e) or 'No message.'}",
                     content=None,
-
-                    output_tokens=None,
+                    usage=None,
                 )
             )
             self.handle_error(create_new_request=False)
@@ -210,8 +250,7 @@
         error_message = None
         thinking = None
         content = None
-
-        output_tokens = None
+        usage = None
         status_code = http_response.status
         mimetype = http_response.headers.get("Content-Type", None)
 
@@ -238,8 +277,7 @@
                 )
 
                 content = Message("assistant", parts)
-
-                output_tokens = data["usage"]["output_tokens"]
+                usage = Usage.from_anthropic_usage(data["usage"])
             except Exception as e:
                 is_error = True
                 error_message = (
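The new `lm_deluge/usage.py` module referenced above is not included in this section of the diff. As a rough sketch of what `Usage.from_anthropic_usage` plausibly does, assuming a small dataclass of token counts (the field names here are illustrative and not taken from the package; the Anthropic usage keys are the documented ones):

```python
from dataclasses import dataclass


@dataclass
class Usage:
    # Field names are assumptions for illustration; see lm_deluge/usage.py for the real API.
    input_tokens: int = 0
    output_tokens: int = 0
    cache_read_tokens: int = 0
    cache_write_tokens: int = 0

    @classmethod
    def from_anthropic_usage(cls, usage: dict) -> "Usage":
        # Anthropic responses report input_tokens/output_tokens, plus
        # cache_read_input_tokens/cache_creation_input_tokens when prompt caching is used.
        return cls(
            input_tokens=usage.get("input_tokens", 0),
            output_tokens=usage.get("output_tokens", 0),
            cache_read_tokens=usage.get("cache_read_input_tokens", 0),
            cache_write_tokens=usage.get("cache_creation_input_tokens", 0),
        )
```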
@@ -278,6 +316,5 @@
             model_internal=self.model_name,
             region=self.region,
             sampling_params=self.sampling_params,
-
-            output_tokens=output_tokens,
+            usage=usage,
         )
lm_deluge/api_requests/common.py
CHANGED
@@ -1,10 +1,11 @@
-from .openai import OpenAIRequest
+from .openai import OpenAIRequest, OpenAIResponsesRequest
 from .anthropic import AnthropicRequest
 from .mistral import MistralRequest
 from .bedrock import BedrockRequest
 
 CLASSES = {
     "openai": OpenAIRequest,
+    "openai-responses": OpenAIResponsesRequest,
     "anthropic": AnthropicRequest,
     "mistral": MistralRequest,
     "bedrock": BedrockRequest,
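A registry like `CLASSES` is typically used to pick a request class from a model's API-spec string. The dispatch below is an illustrative sketch, not code from this package; the `api_spec` lookup key is an assumption about how callers resolve the class:

```python
# Hypothetical dispatch over the CLASSES registry shown above.
def make_api_request(api_spec: str, **kwargs):
    request_cls = CLASSES.get(api_spec)
    if request_cls is None:
        raise ValueError(f"Unsupported API spec: {api_spec}")
    return request_cls(**kwargs)

# A model registered with api_spec="openai-responses" would now be routed
# to the new OpenAIResponsesRequest class instead of OpenAIRequest.
```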
lm_deluge/api_requests/mistral.py
CHANGED
@@ -1,15 +1,14 @@
-import asyncio
 import warnings
 from aiohttp import ClientResponse
 import json
 import os
-from tqdm.auto import tqdm
 from typing import Callable
 
 from .base import APIRequestBase, APIResponse
-from ..prompt import Conversation, Message
+from ..prompt import Conversation, Message, CachePattern
+from ..usage import Usage
 from ..tracker import StatusTracker
-from ..
+from ..config import SamplingParams
 from ..models import APIModel
 
 
@@ -23,18 +22,14 @@ class MistralRequest(APIRequestBase):
         prompt: Conversation,
         attempts_left: int,
         status_tracker: StatusTracker,
-        retry_queue: asyncio.Queue,
         results_arr: list,
         request_timeout: int = 30,
         sampling_params: SamplingParams = SamplingParams(),
-        logprobs: bool = False,
-        top_logprobs: int | None = None,
-        pbar: tqdm | None = None,
         callback: Callable | None = None,
-        debug: bool = False,
         all_model_names: list[str] | None = None,
         all_sampling_params: list[SamplingParams] | None = None,
         tools: list | None = None,
+        cache: CachePattern | None = None,
     ):
         super().__init__(
             task_id=task_id,
@@ -42,18 +37,21 @@ class MistralRequest(APIRequestBase):
             prompt=prompt,
             attempts_left=attempts_left,
             status_tracker=status_tracker,
-            retry_queue=retry_queue,
             results_arr=results_arr,
             request_timeout=request_timeout,
             sampling_params=sampling_params,
-            logprobs=logprobs,
-            top_logprobs=top_logprobs,
-            pbar=pbar,
             callback=callback,
-            debug=debug,
             all_model_names=all_model_names,
             all_sampling_params=all_sampling_params,
+            tools=tools,
+            cache=cache,
         )
+
+        # Warn if cache is specified for non-Anthropic model
+        if cache is not None:
+            warnings.warn(
+                f"Cache parameter '{cache}' is only supported for Anthropic models, ignoring for {model_name}"
+            )
         self.model = APIModel.from_registry(model_name)
         self.url = f"{self.model.api_base}/chat/completions"
         self.request_header = {
@@ -70,7 +68,7 @@ class MistralRequest(APIRequestBase):
             warnings.warn(
                 f"Ignoring reasoning_effort param for non-reasoning model: {model_name}"
             )
-        if logprobs:
+        if sampling_params.logprobs:
             warnings.warn(
                 f"Ignoring logprobs param for non-logprobs model: {model_name}"
             )
@@ -81,8 +79,7 @@
         is_error = False
         error_message = None
         completion = None
-
-        output_tokens = None
+        usage = None
         logprobs = None
         status_code = http_response.status
         mimetype = http_response.headers.get("Content-Type", None)
@@ -99,9 +96,11 @@
             assert data is not None, "data is None"
             try:
                 completion = data["choices"][0]["message"]["content"]
-
-
-
+                usage = Usage.from_mistral_usage(data["usage"])
+                if (
+                    self.sampling_params.logprobs
+                    and "logprobs" in data["choices"][0]
+                ):
                     logprobs = data["choices"][0]["logprobs"]["content"]
             except Exception:
                 is_error = True
@@ -134,6 +133,5 @@
             content=Message.ai(completion),
             model_internal=self.model_name,
             sampling_params=self.sampling_params,
-
-            output_tokens=output_tokens,
+            usage=usage,
         )
lm_deluge/api_requests/openai.py
CHANGED
@@ -1,18 +1,56 @@
-import asyncio
 import warnings
 from aiohttp import ClientResponse
 import json
 import os
-from tqdm.auto import tqdm
 from typing import Callable
 
+from lm_deluge.tool import Tool
+
 from .base import APIRequestBase, APIResponse
-from ..prompt import Conversation, Message, Text, ToolCall, Thinking
+from ..prompt import Conversation, Message, Text, ToolCall, Thinking, CachePattern
+from ..usage import Usage
 from ..tracker import StatusTracker
-from ..
+from ..config import SamplingParams
 from ..models import APIModel
 
 
+def _build_oa_chat_request(
+    model: APIModel,
+    prompt: Conversation,
+    tools: list[Tool] | None,
+    sampling_params: SamplingParams,
+) -> dict:
+    request_json = {
+        "model": model.name,
+        "messages": prompt.to_openai(),
+        "temperature": sampling_params.temperature,
+        "top_p": sampling_params.top_p,
+    }
+    # set max_tokens or max_completion_tokens dep. on provider
+    if "cohere" in model.api_base:
+        request_json["max_tokens"] = sampling_params.max_new_tokens
+    else:
+        request_json["max_completion_tokens"] = sampling_params.max_new_tokens
+    if model.reasoning_model:
+        request_json["temperature"] = 1.0
+        request_json["top_p"] = 1.0
+        request_json["reasoning_effort"] = sampling_params.reasoning_effort
+    else:
+        if sampling_params.reasoning_effort:
+            warnings.warn(
+                f"Ignoring reasoning_effort param for non-reasoning model: {model.name}"
+            )
+    if sampling_params.logprobs:
+        request_json["logprobs"] = True
+        if sampling_params.top_logprobs is not None:
+            request_json["top_logprobs"] = sampling_params.top_logprobs
+    if sampling_params.json_mode and model.supports_json:
+        request_json["response_format"] = {"type": "json_object"}
+    if tools:
+        request_json["tools"] = [tool.dump_for("openai-completions") for tool in tools]
+    return request_json
+
+
 class OpenAIRequest(APIRequestBase):
     def __init__(
         self,
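For reference, this is the kind of Chat Completions payload `_build_oa_chat_request` assembles for a non-reasoning model with logprobs and one tool enabled. The concrete values are made up for illustration, and the function-tool schema is assumed to be what `tool.dump_for("openai-completions")` emits:

```python
# Illustrative result of _build_oa_chat_request (not actual package output).
request_json = {
    "model": "gpt-4.1-mini",
    "messages": [{"role": "user", "content": "What's the weather in Oslo?"}],
    "temperature": 0.7,
    "top_p": 1.0,
    "max_completion_tokens": 512,
    "logprobs": True,
    "top_logprobs": 5,
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Look up the weather for a city.",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                },
            },
        }
    ],
}
```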
@@ -23,18 +61,14 @@ class OpenAIRequest(APIRequestBase):
         prompt: Conversation,
         attempts_left: int,
         status_tracker: StatusTracker,
-        retry_queue: asyncio.Queue,
         results_arr: list,
         request_timeout: int = 30,
         sampling_params: SamplingParams = SamplingParams(),
-        logprobs: bool = False,
-        top_logprobs: int | None = None,
-        pbar: tqdm | None = None,
         callback: Callable | None = None,
-        debug: bool = False,
         all_model_names: list[str] | None = None,
         all_sampling_params: list[SamplingParams] | None = None,
         tools: list | None = None,
+        cache: CachePattern | None = None,
     ):
         super().__init__(
             task_id=task_id,
@@ -42,63 +76,37 @@ class OpenAIRequest(APIRequestBase):
             prompt=prompt,
             attempts_left=attempts_left,
             status_tracker=status_tracker,
-            retry_queue=retry_queue,
             results_arr=results_arr,
             request_timeout=request_timeout,
             sampling_params=sampling_params,
-            logprobs=logprobs,
-            top_logprobs=top_logprobs,
-            pbar=pbar,
             callback=callback,
-            debug=debug,
             all_model_names=all_model_names,
             all_sampling_params=all_sampling_params,
             tools=tools,
+            cache=cache,
         )
+
+        # Warn if cache is specified for non-Anthropic model
+        if cache is not None:
+            warnings.warn(
+                f"Cache parameter '{cache}' is only supported for Anthropic models, ignoring for {model_name}"
+            )
         self.model = APIModel.from_registry(model_name)
         self.url = f"{self.model.api_base}/chat/completions"
         self.request_header = {
             "Authorization": f"Bearer {os.getenv(self.model.api_key_env_var)}"
         }
 
-        self.request_json =
-
-
-            "temperature": sampling_params.temperature,
-            "top_p": sampling_params.top_p,
-        }
-        # set max_tokens or max_completion_tokens dep. on provider
-        if "cohere" in self.model.api_base:
-            self.request_json["max_tokens"] = sampling_params.max_new_tokens
-        elif "openai" in self.model.api_base:
-            self.request_json["max_completion_tokens"] = sampling_params.max_new_tokens
-        if self.model.reasoning_model:
-            self.request_json["temperature"] = 1.0
-            self.request_json["top_p"] = 1.0
-            self.request_json["reasoning_effort"] = sampling_params.reasoning_effort
-        else:
-            if sampling_params.reasoning_effort:
-                warnings.warn(
-                    f"Ignoring reasoning_effort param for non-reasoning model: {model_name}"
-                )
-        if logprobs:
-            self.request_json["logprobs"] = True
-            if top_logprobs is not None:
-                self.request_json["top_logprobs"] = top_logprobs
-        if sampling_params.json_mode and self.model.supports_json:
-            self.request_json["response_format"] = {"type": "json_object"}
-        if tools:
-            self.request_json["tools"] = [
-                tool.dump_for("openai-completions") for tool in tools
-            ]
+        self.request_json = _build_oa_chat_request(
+            self.model, prompt, tools, sampling_params
+        )
 
     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
         is_error = False
         error_message = None
         thinking = None
         content = None
-
-        output_tokens = None
+        usage = None
         logprobs = None
         status_code = http_response.status
         mimetype = http_response.headers.get("Content-Type", None)
@@ -142,9 +150,11 @@ class OpenAIRequest(APIRequestBase):
 
                 content = Message("assistant", parts)
 
-
-
-
+                usage = Usage.from_openai_usage(data["usage"])
+                if (
+                    self.sampling_params.logprobs
+                    and "logprobs" in data["choices"][0]
+                ):
                     logprobs = data["choices"][0]["logprobs"]["content"]
             except Exception:
                 is_error = True
@@ -178,6 +188,228 @@ class OpenAIRequest(APIRequestBase):
             content=content,
             model_internal=self.model_name,
             sampling_params=self.sampling_params,
-
-
+            usage=usage,
+            raw_response=data,
+        )
+
+
+class OpenAIResponsesRequest(APIRequestBase):
+    def __init__(
+        self,
+        task_id: int,
+        model_name: str,
+        prompt: Conversation,
+        attempts_left: int,
+        status_tracker: StatusTracker,
+        results_arr: list,
+        request_timeout: int = 30,
+        sampling_params: SamplingParams = SamplingParams(),
+        callback: Callable | None = None,
+        all_model_names: list[str] | None = None,
+        all_sampling_params: list[SamplingParams] | None = None,
+        tools: list | None = None,
+        cache: CachePattern | None = None,
+        computer_use: bool = False,
+        display_width: int = 1024,
+        display_height: int = 768,
+    ):
+        super().__init__(
+            task_id=task_id,
+            model_name=model_name,
+            prompt=prompt,
+            attempts_left=attempts_left,
+            status_tracker=status_tracker,
+            results_arr=results_arr,
+            request_timeout=request_timeout,
+            sampling_params=sampling_params,
+            callback=callback,
+            all_model_names=all_model_names,
+            all_sampling_params=all_sampling_params,
+            tools=tools,
+            cache=cache,
+        )
+
+        # Store computer use parameters
+        self.computer_use = computer_use
+        self.display_width = display_width
+        self.display_height = display_height
+
+        # Validate computer use requirements
+        if computer_use and model_name != "openai-computer-use-preview":
+            raise ValueError(
+                f"Computer use is only supported with openai-computer-use-preview model, got {model_name}"
+            )
+
+        # Warn if cache is specified for non-Anthropic model
+        if cache is not None:
+            warnings.warn(
+                f"Cache parameter '{cache}' is only supported for Anthropic models, ignoring for {model_name}"
+            )
+        self.model = APIModel.from_registry(model_name)
+        self.url = f"{self.model.api_base}/responses"
+        self.request_header = {
+            "Authorization": f"Bearer {os.getenv(self.model.api_key_env_var)}"
+        }
+
+        # Convert conversation to input format for Responses API
+        openai_responses_format = prompt.to_openai_responses()
+
+        self.request_json = {
+            "model": self.model.name,
+            "input": openai_responses_format["input"],
+            "temperature": sampling_params.temperature,
+            "top_p": sampling_params.top_p,
+        }
+
+        # Add max_output_tokens for responses API
+        if sampling_params.max_new_tokens:
+            self.request_json["max_output_tokens"] = sampling_params.max_new_tokens
+
+        if self.model.reasoning_model:
+            self.request_json["temperature"] = 1.0
+            self.request_json["top_p"] = 1.0
+            self.request_json["reasoning"] = {
+                "effort": sampling_params.reasoning_effort
+            }
+        else:
+            if sampling_params.reasoning_effort:
+                warnings.warn(
+                    f"Ignoring reasoning_effort param for non-reasoning model: {model_name}"
+                )
+
+        if sampling_params.json_mode and self.model.supports_json:
+            self.request_json["text"] = {"format": {"type": "json_object"}}
+
+        # Handle tools
+        request_tools = []
+        if computer_use:
+            # Add computer use tool
+            request_tools.append(
+                {
+                    "type": "computer_use_preview",
+                    "display_width": display_width,
+                    "display_height": display_height,
+                    "environment": "browser",  # Default to browser, could be configurable
+                }
+            )
+            # Set truncation to auto as required for computer use
+            self.request_json["truncation"] = "auto"
+
+        if tools:
+            # Add regular function tools
+            request_tools.extend([tool.dump_for("openai-responses") for tool in tools])
+
+        if request_tools:
+            self.request_json["tools"] = request_tools
+
+    async def handle_response(self, http_response: ClientResponse) -> APIResponse:
+        is_error = False
+        error_message = None
+        thinking = None
+        content = None
+        usage = None
+        logprobs = None
+        status_code = http_response.status
+        mimetype = http_response.headers.get("Content-Type", None)
+        data = None
+
+        if status_code >= 200 and status_code < 300:
+            try:
+                data = await http_response.json()
+            except Exception:
+                is_error = True
+                error_message = (
+                    f"Error calling .json() on response w/ status {status_code}"
+                )
+            if not is_error:
+                assert data is not None, "data is None"
+                try:
+                    # Parse Responses API format
+                    parts = []
+
+                    # Get the output array from the response
+                    output = data.get("output", [])
+                    if not output:
+                        is_error = True
+                        error_message = "No output in response"
+                    else:
+                        # Process each output item
+                        for item in output:
+                            if item.get("type") == "message":
+                                message_content = item.get("content", [])
+                                for content_item in message_content:
+                                    if content_item.get("type") == "output_text":
+                                        parts.append(Text(content_item["text"]))
+                                    # Handle tool calls if present
+                                    elif content_item.get("type") == "tool_call":
+                                        tool_call = content_item["tool_call"]
+                                        parts.append(
+                                            ToolCall(
+                                                id=tool_call["id"],
+                                                name=tool_call["function"]["name"],
+                                                arguments=json.loads(
+                                                    tool_call["function"]["arguments"]
+                                                ),
+                                            )
+                                        )
+                            elif item.get("type") == "computer_call":
+                                # Handle computer use actions
+                                action = item.get("action", {})
+                                parts.append(
+                                    ToolCall(
+                                        id=item["call_id"],
+                                        name=f"_computer_{action.get('type', 'action')}",
+                                        arguments=action,
+                                    )
+                                )
+
+                    # Handle reasoning if present
+                    if "reasoning" in data and data["reasoning"].get("summary"):
+                        thinking = data["reasoning"]["summary"]
+                        parts.append(Thinking(thinking))
+
+                    content = Message("assistant", parts)
+
+                    # Extract usage information
+                    if "usage" in data:
+                        usage = Usage.from_openai_usage(data["usage"])
+
+                    # Extract response_id for computer use continuation
+                    # response_id = data.get("id")
+
+                except Exception as e:
+                    is_error = True
+                    error_message = f"Error parsing {self.model.name} responses API response: {str(e)}"
+
+        elif mimetype and "json" in mimetype.lower():
+            is_error = True
+            data = await http_response.json()
+            error_message = json.dumps(data)
+        else:
+            is_error = True
+            text = await http_response.text()
+            error_message = text
+
+        # Handle special kinds of errors
+        if is_error and error_message is not None:
+            if "rate limit" in error_message.lower() or status_code == 429:
+                error_message += " (Rate limit error, triggering cooldown.)"
+                self.status_tracker.rate_limit_exceeded()
+            if "context length" in error_message:
+                error_message += " (Context length exceeded, set retries to 0.)"
+                self.attempts_left = 0
+
+        return APIResponse(
+            id=self.task_id,
+            status_code=status_code,
+            is_error=is_error,
+            error_message=error_message,
+            prompt=self.prompt,
+            logprobs=logprobs,
+            thinking=thinking,
+            content=content,
+            model_internal=self.model_name,
+            sampling_params=self.sampling_params,
+            usage=usage,
+            raw_response=data,
         )
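To make the Responses API parsing above concrete, here is a trimmed, hypothetical response body and what the handler would extract from it. The payload values are illustrative and not taken from this diff:

```python
# Illustrative Responses API body (abridged; values made up).
data = {
    "output": [
        {
            "type": "message",
            "content": [{"type": "output_text", "text": "Done. I opened the page."}],
        },
        {
            "type": "computer_call",
            "call_id": "call_123",
            "action": {"type": "click", "x": 200, "y": 340},
        },
    ],
    "usage": {"input_tokens": 812, "output_tokens": 64},
}

# Walking data["output"] as in handle_response above yields roughly:
#   parts == [
#       Text("Done. I opened the page."),
#       ToolCall(id="call_123", name="_computer_click",
#                arguments={"type": "click", "x": 200, "y": 340}),
#   ]
# and usage == Usage.from_openai_usage(data["usage"]).
```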